diff --git a/driver/conv.cu b/driver/conv.cu index 82d711e447..25c42ec611 100644 --- a/driver/conv.cu +++ b/driver/conv.cu @@ -361,7 +361,7 @@ int main() constexpr unsigned K = 1; constexpr unsigned S = 3; constexpr unsigned R = 3; -#elif 1 +#elif 0 // 3x3, 34x34 constexpr unsigned N = 64; constexpr unsigned C = 256; @@ -370,15 +370,6 @@ int main() constexpr unsigned K = 64; constexpr unsigned S = 3; constexpr unsigned R = 3; -#elif 0 - // 3x3, 54x54 - constexpr unsigned N = 64; - constexpr unsigned C = 64; - constexpr unsigned HI = 54; - constexpr unsigned WI = 54; - constexpr unsigned K = 64; - constexpr unsigned S = 3; - constexpr unsigned R = 3; #elif 0 // 3x3, 56x56 constexpr unsigned N = 64; @@ -415,6 +406,15 @@ int main() constexpr unsigned K = 64; constexpr unsigned S = 7; constexpr unsigned R = 7; +#elif 1 + // 3x3, 58x58 + constexpr unsigned N = 16; + constexpr unsigned C = 128; + constexpr unsigned HI = 58; + constexpr unsigned WI = 58; + constexpr unsigned K = 256; + constexpr unsigned S = 3; + constexpr unsigned R = 3; #endif auto in_nchw_desc = make_ConstantTensorDescriptor(Sequence{}); @@ -449,7 +449,7 @@ int main() device_direct_convolution_2 #elif 0 device_implicit_gemm_convolution_1_nchw_kcsr -#elif 1 +#elif 0 device_implicit_gemm_convolution_1_nchw_srck_nkhw #elif 1 device_implicit_gemm_convolution_1_chwn_csrk_khwn diff --git a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh index 0dd5309255..c8996f1ad7 100644 --- a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh +++ b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh @@ -87,8 +87,8 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, constexpr unsigned WoPerThread = 1; constexpr unsigned BlockSize = 8; -#elif 1 - // for 3x3, 34x34 | 3x3 58x58 +#elif 0 + // for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256 constexpr unsigned NPerBlock = 16; constexpr unsigned KPerBlock = 64; constexpr unsigned CPerBlock = 4; @@ -101,6 +101,21 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, constexpr unsigned HoPerThread = 1; constexpr unsigned WoPerThread = 1; + constexpr unsigned BlockSize = 128; +#elif 1 + // 3x3 58x58, NKC = 16,256,128 + constexpr unsigned NPerBlock = 8; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 2; + constexpr unsigned HoPerBlock = 4; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + constexpr unsigned BlockSize = 128; #elif 0 // for 5x5, 36x36 diff --git a/driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh b/driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh index f948f5205e..f6f9ccdbc1 100644 --- a/driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh +++ b/driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh @@ -65,7 +65,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc, constexpr unsigned WoPerThread = 2; constexpr unsigned BlockSize = 16; -#elif 1 +#elif 0 // for 3x3, 34x34 constexpr unsigned NPerBlock = 1; constexpr unsigned KPerBlock = 64; @@ -73,6 +73,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc, constexpr unsigned HoPerBlock = 4; constexpr unsigned WoPerBlock = 32; + constexpr unsigned NPerThread = 1; constexpr unsigned KPerThread = 16; constexpr unsigned CPerThread = 1; constexpr unsigned HoPerThread = 2; @@ -80,16 +81,32 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc, constexpr unsigned BlockSize = 128; #elif 0 - // for 3x3, 34x34 - constexpr unsigned NPerBlock = 2; + // for 3x3, 58x58 + constexpr unsigned NPerBlock = 4; constexpr unsigned KPerBlock = 64; constexpr unsigned CPerBlock = 2; - constexpr unsigned HoPerBlock = 2; - constexpr unsigned WoPerBlock = 32; + constexpr unsigned HoPerBlock = 4; + constexpr unsigned WoPerBlock = 8; + constexpr unsigned NPerThread = 4; constexpr unsigned KPerThread = 16; constexpr unsigned CPerThread = 1; - constexpr unsigned HoPerThread = 2; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned BlockSize = 128; +#elif 1 + // for 3x3, 56x56 + constexpr unsigned NPerBlock = 32; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 4; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 2; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; constexpr unsigned WoPerThread = 1; constexpr unsigned BlockSize = 128; @@ -123,6 +140,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc, CPerBlock, HoPerBlock, WoPerBlock, + NPerThread, KPerThread, CPerThread, HoPerThread, diff --git a/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh b/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh index e0a416ebf2..a5aa0feaa4 100644 --- a/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh @@ -17,6 +17,7 @@ template {}; constexpr auto I1 = Number<1>{}; @@ -207,7 +210,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc, const unsigned ho_thread_data_begin = matrix_c_index.batch_begin; const unsigned k_thread_data_begin = matrix_c_index.row_begin; - const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerThread; + const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock; + const unsigned n_thread_data_begin = + matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock; // output: register to global mem, // convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo] @@ -217,7 +222,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc, out_hkwn_thread_desc, p_out_thread, out_nkhw_global_desc, - p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin, + p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, k_block_data_begin + k_thread_data_begin, ho_block_data_begin + ho_thread_data_begin, wo_block_data_begin + wo_thread_data_begin),