improve implicit gemm NCHW, SRCK, NKHW, and tuned

This commit is contained in:
Chao Liu
2019-01-24 22:24:30 -06:00
parent 1de6fd0753
commit 8bd6ea1a97
4 changed files with 60 additions and 22 deletions

View File

@@ -361,7 +361,7 @@ int main()
constexpr unsigned K = 1;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
#elif 1
#elif 0
// 3x3, 34x34
constexpr unsigned N = 64;
constexpr unsigned C = 256;
@@ -370,15 +370,6 @@ int main()
constexpr unsigned K = 64;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
#elif 0
// 3x3, 54x54
constexpr unsigned N = 64;
constexpr unsigned C = 64;
constexpr unsigned HI = 54;
constexpr unsigned WI = 54;
constexpr unsigned K = 64;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
#elif 0
// 3x3, 56x56
constexpr unsigned N = 64;
@@ -415,6 +406,15 @@ int main()
constexpr unsigned K = 64;
constexpr unsigned S = 7;
constexpr unsigned R = 7;
#elif 1
// 3x3, 58x58
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 256;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
#endif
auto in_nchw_desc = make_ConstantTensorDescriptor(Sequence<N, C, HI, WI>{});
@@ -449,7 +449,7 @@ int main()
device_direct_convolution_2
#elif 0
device_implicit_gemm_convolution_1_nchw_kcsr
#elif 1
#elif 0
device_implicit_gemm_convolution_1_nchw_srck_nkhw
#elif 1
device_implicit_gemm_convolution_1_chwn_csrk_khwn

View File

@@ -87,8 +87,8 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 8;
#elif 1
// for 3x3, 34x34 | 3x3 58x58
#elif 0
// for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
@@ -101,6 +101,21 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 1
// 3x3 58x58, NKC = 16,256,128
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 0
// for 5x5, 36x36

View File

@@ -65,7 +65,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
constexpr unsigned WoPerThread = 2;
constexpr unsigned BlockSize = 16;
#elif 1
#elif 0
// for 3x3, 34x34
constexpr unsigned NPerBlock = 1;
constexpr unsigned KPerBlock = 64;
@@ -73,6 +73,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 32;
constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 2;
@@ -80,16 +81,32 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
constexpr unsigned BlockSize = 128;
#elif 0
// for 3x3, 34x34
constexpr unsigned NPerBlock = 2;
// for 3x3, 58x58
constexpr unsigned NPerBlock = 4;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 8;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 2;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 1
// for 3x3, 56x56
constexpr unsigned NPerBlock = 32;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
@@ -123,6 +140,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
CPerThread,
HoPerThread,