This commit is contained in:
Chao Liu
2019-05-31 01:25:14 -05:00
parent 9ef124cc15
commit 3a6044aa84
6 changed files with 135 additions and 60 deletions

View File

@@ -66,14 +66,20 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
using WeiBlockCopySubLengths_E_K = Sequence<1, 4>;
using WeiBlockCopyClusterLengths_E_K = Sequence<8, 32>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [E, K]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [E, K]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [K, E]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 4;
@@ -114,10 +120,16 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
GemmDataPerReadB,
InBlockCopySubLengths_E_N1_B_N2,
InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2,
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>{};

View File

@@ -410,7 +410,7 @@ int main(int argc, char* argv[])
{
#if 0
constexpr index_t N = 8;
constexpr index_t C = 8;
constexpr index_t C = 16;
constexpr index_t HI = 3;
constexpr index_t WI = 18;
constexpr index_t K = 128;