diff --git a/driver/conv.cu b/driver/conv.cu index 0941bf2733..9826230284 100644 --- a/driver/conv.cu +++ b/driver/conv.cu @@ -391,7 +391,7 @@ int main() constexpr unsigned HPad = 0; constexpr unsigned WPad = 0; -#elif 1 +#elif 0 // 3x3, 34x34 constexpr unsigned N = 64; constexpr unsigned C = 256; @@ -587,11 +587,11 @@ int main() device_implicit_gemm_convolution_1_nchw_kcsr #elif 0 device_implicit_gemm_convolution_1_nchw_srck_nkhw -#elif 0 +#elif 1 device_implicit_gemm_convolution_1_chwn_csrk_khwn #elif 0 device_implicit_gemm_convolution_2_cnhw_srck_knhw -#elif 1 +#elif 0 device_implicit_gemm_convolution_2_cnhw_csrk_knhw #endif (in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat); @@ -608,7 +608,7 @@ int main() nrepeat); #endif -#if 1 +#if 0 if(S == 3 && R == 3) { host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads); diff --git a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh index 7bf43cf2a7..bf87dc1cf3 100644 --- a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh +++ b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh @@ -87,7 +87,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, constexpr unsigned WoPerThread = 1; constexpr unsigned BlockSize = 8; -#elif 1 +#elif 0 // for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256 constexpr unsigned NPerBlock = 16; constexpr unsigned KPerBlock = 64; @@ -101,6 +101,12 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, constexpr unsigned HoPerThread = 1; constexpr unsigned WoPerThread = 1; + constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 32; + + constexpr unsigned InBlockCopyDataPerRead = 2; // not used, yet + constexpr unsigned WeiBlockCopyDataPerRead = 4; + constexpr unsigned BlockSize = 128; #elif 0 // 3x3 58x58, NKC = 16,256,128 @@ -162,7 +168,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, constexpr unsigned WoPerThread = 1; constexpr unsigned BlockSize = 128; -#elif 0 +#elif 1 // for 1x1, 28x28 constexpr unsigned NPerBlock = 16; constexpr unsigned KPerBlock = 128; @@ -176,6 +182,12 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, constexpr unsigned HoPerThread = 1; constexpr unsigned WoPerThread = 1; + constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 32; + + constexpr unsigned InBlockCopyDataPerRead = 4; // not used, yet + constexpr unsigned WeiBlockCopyDataPerRead = 4; + constexpr unsigned BlockSize = 128; #endif @@ -211,7 +223,11 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, KPerThread, CPerThread, HoPerThread, - WoPerThread> + WoPerThread, + WeiBlockCopyThreadPerDim0, + WeiBlockCopyThreadPerDim1, + InBlockCopyDataPerRead, + WeiBlockCopyDataPerRead> <<>>(in_chwn_desc, static_cast(in_chwn_device_buf.GetDeviceBuffer()), wei_csrk_desc, diff --git a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh index d1699d1fbb..f4cc40c71c 100644 --- a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh +++ b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh @@ -108,6 +108,9 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(InDesc, constexpr unsigned HoPerThread = 1; constexpr unsigned WoPerThread = 1; + constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 32; + constexpr unsigned BlockSize = 128; #elif 0 // 3x3 58x58, NKC = 16,256,128 diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh index c25b98a801..30a8dda2a2 100644 --- a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh @@ -3,6 +3,7 @@ #include "ConstantTensorDescriptor.cuh" #include "ConstantMatrixDescriptor.cuh" #include "blockwise_4d_tensor_op.cuh" +#include "blockwise_2d_tensor_op.cuh" #include "threadwise_4d_tensor_op.cuh" #include "blockwise_gemm.cuh" @@ -21,7 +22,11 @@ template + unsigned WoPerThread, + unsigned WeiBlockCopyThreadPerDim0, + unsigned WeiBlockCopyThreadPerDim1, + unsigned InBlockCopyDataPerRead, + unsigned WeiBlockCopyDataPerRead> __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc, Float* const __restrict__ p_in_global, @@ -80,12 +85,19 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc, const unsigned hi_block_data_begin = ho_block_data_begin; const unsigned wi_block_data_begin = wo_block_data_begin; + // flattend (2d) tensor view of gridwise weight + constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); + // tensor view of blockwise input and weight in LDS + // be careful of alignment constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor(Sequence{}); - constexpr auto wei_csrk_block_desc = - make_ConstantTensorDescriptor(Sequence{}); + constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + constexpr auto wei_csrk_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); // tensor view of threadwise output in register constexpr auto out_hkwn_thread_desc = @@ -112,13 +124,31 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc, decltype(in_chwn_block_desc), decltype(in_chwn_block_desc.GetLengths())>{}; - // weight: format is [S,R,C,K] - constexpr auto blockwise_wei_copy = - Blockwise4dTensorCopy1{}; + decltype(wei_ek_global_desc), + decltype(wei_ek_block_desc), + decltype(wei_ek_block_desc.GetLengths())>{}; +#elif 0 + const auto blockwise_wei_copy = Blockwise2dTensorCopy2{}; +#elif 1 + const auto blockwise_wei_copy = Blockwise2dTensorCopy3{}; +#endif // a series of blockwise batched GEMM // C_matrix += transpose(A_matrix) * B_matrix @@ -155,12 +185,17 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc, CPerThread, true>{}; - // LDS - constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace(); - constexpr unsigned wei_block_size = wei_csrk_block_desc.GetElementSpace(); + // LDS: be careful of alignment + constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace(); + constexpr unsigned wei_block_size = + wei_csrk_block_desc.GetElementSpace(Number{}); - __shared__ Float p_in_block[in_block_size]; - __shared__ Float p_wei_block[wei_block_size]; + constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead + ? InBlockCopyDataPerRead + : WeiBlockCopyDataPerRead; + + __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)]; + __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)]; // register Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];