diff --git a/driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp b/driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp index fb36afa4db..613b55a81e 100644 --- a/driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp +++ b/driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp @@ -3,8 +3,10 @@ #include "device.hpp" #include "gridwise_convolution_wrapper.hip.hpp" #include "gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp" -#include "gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp" +#include "gridwise_convolution_implicit_gemm_v1r1_lds_double_buffer_chwn_cyxk_khwn.hip.hpp" #include "gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp" +#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp" +#include "gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp" template void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, @@ -94,9 +96,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t InBlockCopy_ThreadPerDimH = 4; constexpr index_t InBlockCopy_ThreadPerDimW = 2; constexpr index_t InBlockCopy_ThreadPerDimN = 4; - constexpr index_t InBlockCopyDataPerRead = 4; + constexpr index_t InBlockCopyDataPerRead_N = 4; - constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4; @@ -108,10 +110,10 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadB = 4; - constexpr index_t OutThreadCopyDataPerWrite = 2; + constexpr index_t OutThreadCopyDataPerWrite_N = 2; constexpr index_t BlockSize = 128; -#elif 1 +#elif 0 // for 3x3, 34x34, v1r2, Pascal, in-block-copy1 constexpr index_t NPerBlock = 4; constexpr index_t KPerBlock = 64; @@ -128,9 +130,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t InBlockCopy_ThreadPerDimH = 4; constexpr index_t InBlockCopy_ThreadPerDimW = 2; constexpr index_t InBlockCopy_ThreadPerDimN = 1; - constexpr index_t InBlockCopyDataPerRead = 4; + constexpr index_t InBlockCopyDataPerRead_N = 4; - constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4; @@ -142,7 +144,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadB = 4; - constexpr index_t OutThreadCopyDataPerWrite = 2; + constexpr index_t OutThreadCopyDataPerWrite_N = 2; constexpr index_t BlockSize = 128; #elif 0 @@ -172,14 +174,14 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t InBlockCopy_ThreadPerDimH = 4; constexpr index_t InBlockCopy_ThreadPerDimW = 2; constexpr index_t InBlockCopy_ThreadPerDimN = 8; - constexpr index_t InBlockCopyDataPerRead = 2; + constexpr index_t InBlockCopyDataPerRead_N = 2; - constexpr index_t WeiBlockCopyDataPerRead = 2; - constexpr index_t OutThreadCopyDataPerWrite = 4; + constexpr index_t WeiBlockCopyDataPerRead_K = 2; + constexpr index_t OutThreadCopyDataPerWrite_N = 4; constexpr index_t BlockSize = 256; #elif 0 - // for 3x3, 56x56, v1, Pascal + // for 3x3, 56x56, v1r1, Pascal constexpr index_t NPerBlock = 32; constexpr index_t KPerBlock = 64; constexpr index_t CPerBlock = 4; @@ -195,9 +197,9 @@ void 
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t InBlockCopy_ThreadPerDimH = 4; constexpr index_t InBlockCopy_ThreadPerDimW = 4; constexpr index_t InBlockCopy_ThreadPerDimN = 8; - constexpr index_t InBlockCopyDataPerRead = 4; + constexpr index_t InBlockCopyDataPerRead_N = 4; - constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4; @@ -207,7 +209,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t OutThreadCopyDataPerWrite = 2; + constexpr index_t OutThreadCopyDataPerWrite_N = 2; constexpr index_t BlockSize = 128; #elif 0 @@ -237,13 +239,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t InBlockCopy_ThreadPerDimH = 2; constexpr index_t InBlockCopy_ThreadPerDimW = 4; constexpr index_t InBlockCopy_ThreadPerDimN = 4; - constexpr index_t InBlockCopyDataPerRead = 4; + constexpr index_t InBlockCopyDataPerRead_N = 4; - constexpr index_t WeiBlockCopyDataPerRead = 4; - constexpr index_t OutThreadCopyDataPerWrite = 4; + constexpr index_t WeiBlockCopyDataPerRead_K = 4; + constexpr index_t OutThreadCopyDataPerWrite_N = 4; constexpr index_t BlockSize = 128; -#elif 1 +#elif 0 // for 3x3, 28x28, v1r1, Pascal constexpr index_t NPerBlock = 32; constexpr index_t KPerBlock = 64; @@ -260,9 +262,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t InBlockCopy_ThreadPerDimH = 4; constexpr index_t InBlockCopy_ThreadPerDimW = 4; constexpr index_t InBlockCopy_ThreadPerDimN = 8; - constexpr index_t InBlockCopyDataPerRead = 4; + constexpr index_t InBlockCopyDataPerRead_N = 4; - constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4; @@ -274,11 +276,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadB = 4; - constexpr index_t OutThreadCopyDataPerWrite = 2; + constexpr index_t OutThreadCopyDataPerWrite_N = 2; constexpr index_t BlockSize = 128; -#elif 1 +#elif 0 // for 3x3, 28x28, v1r2, Pascal + constexpr index_t BlockSize = 128; + constexpr index_t NPerBlock = 16; constexpr index_t KPerBlock = 128; constexpr index_t CPerBlock = 8; @@ -290,13 +294,37 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t HoPerThread = 1; constexpr index_t WoPerThread = 2; - constexpr index_t InBlockCopy_ThreadPerDimC = 4; - constexpr index_t InBlockCopy_ThreadPerDimH = 2; - constexpr index_t InBlockCopy_ThreadPerDimW = 4; - constexpr index_t InBlockCopy_ThreadPerDimN = 4; - constexpr index_t InBlockCopyDataPerRead = 4; + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 2; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; - constexpr index_t WeiBlockCopyDataPerRead = 4; + using InBlockCopyClusterLengths_CHWN = Sequence<4, 2, 4, 4>; + constexpr index_t InBlockCopyDataPerRead_N = 4; + + constexpr index_t WeiBlockCopyDataPerRead_K = 4; + + constexpr index_t 
OutThreadCopyDataPerWrite_N = 2; +#elif 1 + // for 3x3, 28x28, v1r3, Pascal + // for 3x3, 14x14, v1r3, Pascal + constexpr index_t BlockSize = 128; + + constexpr index_t NPerBlock = 16; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 8; + constexpr index_t HoPerBlock = 2; + constexpr index_t WoPerBlock = 2; + + constexpr index_t NPerThread = 4; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 2; constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4; @@ -308,11 +336,14 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadB = 4; - constexpr index_t OutThreadCopyDataPerWrite = 2; + using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>; + constexpr index_t InBlockCopyDataPerRead_N = 4; - constexpr index_t BlockSize = 128; + constexpr index_t WeiBlockCopyDataPerRead_K = 4; + + constexpr index_t OutThreadCopyDataPerWrite_N = 2; #elif 0 - // for 1x1, 28x28 + // for 1x1, 28x28, v1r1, Pascal constexpr index_t NPerBlock = 16; constexpr index_t KPerBlock = 128; constexpr index_t CPerBlock = 8; @@ -329,9 +360,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t InBlockCopy_ThreadPerDimH = 2; constexpr index_t InBlockCopy_ThreadPerDimW = 2; constexpr index_t InBlockCopy_ThreadPerDimN = 4; - constexpr index_t InBlockCopyDataPerRead = 4; + constexpr index_t InBlockCopyDataPerRead_N = 4; - constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4; @@ -341,11 +372,11 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t OutThreadCopyDataPerWrite = 2; + constexpr index_t OutThreadCopyDataPerWrite_N = 2; constexpr index_t BlockSize = 128; -#elif 1 - // for 1x1, 14x14, Pascal +#elif 0 + // for 1x1, 14x14, v1r1, Pascal constexpr index_t NPerBlock = 16; constexpr index_t KPerBlock = 128; constexpr index_t CPerBlock = 8; @@ -369,10 +400,10 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t InBlockCopy_ThreadPerDimH = 2; constexpr index_t InBlockCopy_ThreadPerDimW = 2; constexpr index_t InBlockCopy_ThreadPerDimN = 4; - constexpr index_t InBlockCopyDataPerRead = 4; + constexpr index_t InBlockCopyDataPerRead_N = 4; - constexpr index_t WeiBlockCopyDataPerRead = 4; - constexpr index_t OutThreadCopyDataPerWrite = 2; + constexpr index_t WeiBlockCopyDataPerRead_K = 4; + constexpr index_t OutThreadCopyDataPerWrite_N = 2; constexpr index_t BlockSize = 128; #endif @@ -386,12 +417,16 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, for(index_t i = 0; i < nrepeat; ++i) { constexpr auto gridwise_conv = -#if 1 +#if 0 GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn #elif 0 - GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer -#elif 1 + GridwiseConvolutionImplicitGemm_v1r1_lds_double_buffer_chwn_cyxk_khwn +#elif 0 GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn +#elif 0 + GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn +#elif 1 + GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn #endif , - InBlockCopyDataPerRead, - WeiBlockCopyDataPerRead, - OutThreadCopyDataPerWrite>{}; + InBlockCopyClusterLengths_CHWN, + InBlockCopyDataPerRead_N, + WeiBlockCopyDataPerRead_K, + 
OutThreadCopyDataPerWrite_N>{}; float time = launch_kernel(run_gridwise_convolution, dim3(GridSize), diff --git a/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp b/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp index 2447ec0c13..6dcd7eabbb 100644 --- a/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp +++ b/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp @@ -87,13 +87,13 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc, constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadB = 4; - using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 2>; - using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>; - using InBlockReorderMapThreadCluster2SrcCluster = Sequence<1, 2, 3, 0>; - constexpr index_t InBlockReorderDataPerRead_W = 2; - constexpr index_t InBlockReorderDataPerWrite_N = 4; + using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 2>; + using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>; + using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>; + constexpr index_t InBlockReorderDataPerRead_W = 2; + constexpr index_t InBlockReorderDataPerWrite_N = 4; - using WeiBlockCopyClusterLengths = Sequence<4, 1, 32>; + using WeiBlockCopyClusterLengths_CXK = Sequence<4, 1, 32>; constexpr index_t WeiBlockCopyDataPerRead_C = 4; constexpr index_t OutThreadCopyDataPerWrite_N = 2; @@ -137,10 +137,10 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc, GemmDataPerReadB, InBlockReorderSrcSubLengths_NCHW, InBlockReorderSrcClusterLengths_NCHW, - InBlockReorderMapThreadCluster2SrcCluster, + InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW, InBlockReorderDataPerRead_W, InBlockReorderDataPerWrite_N, - WeiBlockCopyClusterLengths, + WeiBlockCopyClusterLengths_CXK, WeiBlockCopyDataPerRead_C, OutThreadCopyDataPerWrite_N>{}; diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index 5937190e66..aa0c610e6d 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -451,60 +451,6 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; -#elif 0 - // 3x3, 58x58 - constexpr index_t N = 64; - constexpr index_t C = 64; - constexpr index_t HI = 58; - constexpr index_t WI = 58; - constexpr index_t K = 64; - constexpr index_t Y = 3; - constexpr index_t X = 3; -#elif 0 - // 3x3, 58x58 - constexpr index_t N = 16; - constexpr index_t C = 128; - constexpr index_t HI = 58; - constexpr index_t WI = 58; - constexpr index_t K = 256; - constexpr index_t Y = 3; - constexpr index_t X = 3; -#elif 0 - // 3x3 filter, 58x58 image, 0x0 padding - constexpr index_t N = 16; - constexpr index_t C = 128; - constexpr index_t HI = 58; - constexpr index_t WI = 58; - constexpr index_t K = 256; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; -#elif 0 - // 3x3 filter, 56x56 image, 1x1 padding - constexpr index_t N = 16; - constexpr index_t C = 128; - constexpr index_t HI = 56; - constexpr index_t WI = 56; - constexpr index_t K = 256; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - constexpr index_t HPad = 1; - constexpr index_t WPad = 1; -#elif 0 - // 3x3 filter, 28x28 image, 1x1 padding - constexpr index_t N = 16; - constexpr index_t C = 256; - constexpr index_t HI = 28; - constexpr index_t WI = 28; - constexpr index_t K = 512; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - constexpr index_t HPad = 1; - constexpr index_t WPad = 1; #elif 1 // 3x3 filter, 
28x28 image constexpr index_t N = 128; @@ -578,31 +524,19 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 2; constexpr index_t WPad = 2; #elif 0 - // 1x1 filter, 32x32 image - constexpr index_t N = 64; - constexpr index_t C = 256; - constexpr index_t HI = 32; - constexpr index_t WI = 32; - constexpr index_t K = 512; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; -#elif 1 - // 1x1 filter, 14x14 image, C = 2048 + // 3x3 filter, 14x14 image constexpr index_t N = 128; - constexpr index_t C = 2048; + constexpr index_t C = 256; constexpr index_t HI = 14; constexpr index_t WI = 14; - constexpr index_t K = 512; - constexpr index_t Y = 1; - constexpr index_t X = 1; + constexpr index_t K = 128; + constexpr index_t Y = 3; + constexpr index_t X = 3; constexpr index_t HPad = 0; constexpr index_t WPad = 0; -#elif 1 - // 1x1 filter, 14x14 image, C = 512 +#elif 0 + // 1x1 filter, 14x14 image constexpr index_t N = 128; constexpr index_t C = 512; constexpr index_t HI = 14; constexpr index_t WI = 14; @@ -673,9 +607,9 @@ int main(int argc, char* argv[]) device_direct_convolution_2_nchw_kcyx_nkhw #elif 0 device_direct_convolution_2_vectorized_nchw_kcyx_nkhw -#elif 0 - device_convolution_implicit_gemm_v1_chwn_cyxk_khwn #elif 1 + device_convolution_implicit_gemm_v1_chwn_cyxk_khwn +#elif 0 device_convolution_implicit_gemm_v1_nchw_cyxk_khwn #elif 0 device_convolution_implicit_gemm_v2_chwn_cyxk_khwn diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r1_lds_double_buffer_chwn_cyxk_khwn.hip.hpp similarity index 99% rename from src/include/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp rename to src/include/gridwise_convolution_implicit_gemm_v1r1_lds_double_buffer_chwn_cyxk_khwn.hip.hpp index 3a024bbaaa..34cb38822e 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1r1_lds_double_buffer_chwn_cyxk_khwn.hip.hpp @@ -36,7 +36,7 @@ template -struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer +struct GridwiseConvolutionImplicitGemm_v1r1_lds_double_buffer_chwn_cyxk_khwn { __device__ void Run(const Float* const __restrict__ p_in_global, const Float* const __restrict__ p_wei_global, Float* const __restrict__ p_out_global) const diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp index 0c7a455fc9..74c0e5b4b6 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp @@ -213,6 +213,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn // set threadwise output tensor to 0 threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread); +#if 1 const Float* p_in_global_block_offset = p_in_global + in_c_h_w_n_global_desc.Get1dIndex( 0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin); @@ -247,6 +248,43 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn __syncthreads(); } } +#else + // this uses many more registers; haven't figured out why yet 
+ for(index_t y = 0; y < Y; ++y) + { + const Float* p_in_global_block_offset = + p_in_global + + in_c_h_w_n_global_desc.Get1dIndex( + 0, hi_block_data_begin + y, wi_block_data_begin, n_block_data_begin); + + const Float* p_wei_global_block_offset = + p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, k_block_data_begin); + + for(index_t c_block_data_begin = 0; c_block_data_begin < C; + c_block_data_begin += CPerBlock, + p_in_global_block_offset += + CPerBlock * in_c_h_w_n_global_desc.GetStride(I0), + p_wei_global_block_offset += + CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0)) + { + blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); + + blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); + + __syncthreads(); + + for(index_t x = 0; x < X; ++x) + { + blockwise_batch_gemm.Run(p_wei_block + wei_c_x_k_block_desc.Get1dIndex(0, x, 0), + p_in_block + + in_c_h_w_n_block_desc.Get1dIndex(0, 0, x, 0), + p_out_thread); + } + + __syncthreads(); + } + } +#endif // output: register to global mem, const auto c_thread_mtx_begin = diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp index 550246e38d..e7d8dee565 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp @@ -35,10 +35,10 @@ template struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn @@ -122,7 +122,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn // blockwise copy // input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N] - auto map_chwn2nchw = Sequence<1, 2, 3, 0>{}; + constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{}; + const auto blockwise_in_copy_reorder = Blockwise4dTensorCopyReorder3{}; @@ -144,7 +145,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn decltype(wei_c_x_k_global_desc), decltype(wei_c_x_k_block_desc), decltype(wei_c_x_k_block_desc.GetLengths()), - Sequence<4, 1, 32>, + WeiBlockCopyClusterLengths_CXK, WeiBlockCopyDataPerRead_C>{}; // a series of blockwise batched GEMM diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp new file mode 100644 index 0000000000..7ce9039a20 --- /dev/null +++ b/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp @@ -0,0 +1,331 @@ +#pragma once +#include "common.hip.hpp" +#include "ConstantTensorDescriptor.hip.hpp" +#include "ConstantMatrixDescriptor.hip.hpp" +#include "blockwise_2d_tensor_op.hip.hpp" +#include "blockwise_4d_tensor_op.hip.hpp" +#include "threadwise_nd_tensor_op.hip.hpp" +#include "threadwise_4d_tensor_op.hip.hpp" +#include "blockwise_batched_gemm.hip.hpp" + +template +struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn +{ + __device__ void Run(const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) const + { + // be careful of this assertion + static_assert( + NPerThread <= NPerBlock && NPerBlock % NPerThread == 0, + "wrong! 
should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{}; + constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{}; + constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{}; + + constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0); + + constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0); + constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1); + constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2); + constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3); + + constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1); + constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2); + + constexpr index_t HiPerBlock = HoPerBlock + Y - 1; + constexpr index_t WiPerBlock = WoPerBlock + X - 1; + + // divide block work: [K, Ho, Wo, N] + static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 && + Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0, + "wrong! cannot evenly divide work for workgroup "); + + constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; + constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; + constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; + constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock; + + const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); + index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); + const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork); + itmp -= h_block_work_id * (WBlockWork * NBlockWork); + const index_t w_block_work_id = itmp / NBlockWork; + const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork; + + const index_t k_block_data_begin = k_block_work_id * KPerBlock; + const index_t ho_block_data_begin = h_block_work_id * HoPerBlock; + const index_t wo_block_data_begin = w_block_work_id * WoPerBlock; + const index_t n_block_data_begin = n_block_work_id * NPerBlock; + + const index_t hi_block_data_begin = ho_block_data_begin; + const index_t wi_block_data_begin = wo_block_data_begin; + + // global tensor view + constexpr auto wei_c_k_global_desc = + make_ConstantTensorDescriptor(Sequence{}, Sequence{}); + + // LDS tensor view + // be careful of alignment + constexpr index_t max_align = mod_conv::max(InBlockCopyDataPerRead_N, + WeiBlockCopyDataPerRead_K, + GemmDataPerReadA, + GemmDataPerReadB); + + constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + // tensor view of threadwise output in register + constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); + + // blockwise copy + // input: format is [C, Hi, Wi, N] +#if 0 + const auto blockwise_in_copy = + Blockwise4dTensorCopy1{}; +#else + const auto blockwise_in_copy = + Blockwise4dTensorCopy3{}; +#endif + + // blockwise wei copy + // format is [CPerBlock, X * KPerBlock] + const auto blockwise_wei_copy = + Blockwise2dTensorCopy1{}; + + // a series of blockwise batched GEMM + // C_matrix += transpose(A_matrix) * B_matrix + // A_matrix and B_matrix saved in LDS, C_matrix saved in register + // A_matrix[C,K] is a sub-matrix of wei_block[C,K] + // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] + // 
C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N] + constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); + + constexpr auto b_c_wn_block_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); + + constexpr auto c_k_wn_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); + + const auto blockwise_batch_gemm = + BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< + BlockSize, + decltype(a_c_k_block_mtx_desc), + decltype(b_c_wn_block_mtx_desc), + decltype(c_k_wn_thread_mtx_desc), + 0, + in_c_h_w_n_block_desc.GetStride(I1), + out_k_h_w_n_thread_desc.GetStride(I1), + HoPerBlock, + GemmMPerThreadSubC, + GemmNPerThreadSubC, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmKPerThreadLoop, + HoPerThread, + GemmDataPerReadA, + GemmDataPerReadB>{}; + + // LDS: be careful of alignment + constexpr index_t in_block_space = + in_c_h_w_n_block_desc.GetElementSpace(Number{}); + constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number{}); + + __shared__ Float p_in_block[in_block_space]; + __shared__ Float p_wei_block[wei_block_space]; + + // register + Float p_out_thread[out_k_h_w_n_thread_desc.GetElementSpace()]; + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc"); + print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc"); + + print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc"); + print_ConstantTensorDescriptor(wei_c_x_k_block_desc, "wei_c_x_k_block_desc"); + + printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space); + } +#endif + + // set threadwise output tensor to 0 + threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread); + +#if 1 + const Float* p_in_global_block_offset = + p_in_global + in_c_h_w_n_global_desc.Get1dIndex( + 0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin); + + const Float* p_wei_global_block_offset = + p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); + + for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, + p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0), + p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0)) + { + for(index_t y = 0; y < Y; ++y) + { + for(index_t x = 0; x < X; ++x) + { + blockwise_in_copy.Run(p_in_global_block_offset + + in_c_h_w_n_global_desc.Get1dIndex(0, y, x, 0), + p_in_block); + + blockwise_wei_copy.Run(p_wei_global_block_offset + + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, 0), + p_wei_block); + + __syncthreads(); + + blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread); + + __syncthreads(); + } + } + } +#else + for(index_t y = 0; y < Y; ++y) + { + for(index_t x = 0; x < X; ++x) + { + const Float* p_in_global_block_offset = + p_in_global + + in_c_h_w_n_global_desc.Get1dIndex( + 0, hi_block_data_begin + y, wi_block_data_begin + x, n_block_data_begin); + + const Float* p_wei_global_block_offset = + p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, k_block_data_begin); + + for(index_t c_block_data_begin = 0; c_block_data_begin < C; + c_block_data_begin += CPerBlock, + p_in_global_block_offset += + CPerBlock * in_c_h_w_n_global_desc.GetStride(I0), + p_wei_global_block_offset += + CPerBlock * 
wei_c_y_x_k_global_desc.GetStride(I0)) + { + blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); + + blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); + + __syncthreads(); + + blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread); + + __syncthreads(); + } + } + } +#endif + + // output: register to global mem, + const auto c_thread_mtx_begin = + blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const index_t k_thread_data_begin = c_thread_mtx_begin.row; + const index_t ho_thread_data_begin = c_thread_mtx_begin.batch; + const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; + const index_t n_thread_data_begin = + c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin; + + // output is a 10d tensor + constexpr index_t N2 = GemmNPerThreadSubC; + constexpr index_t N1 = NPerBlock / N2; + + constexpr index_t W2 = + (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC); + constexpr index_t W1 = WoPerBlock / W2; + + constexpr index_t K2 = GemmMPerThreadSubC; + constexpr index_t K1 = KPerBlock / KPerThread; + + constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor( + Sequence{}); + + constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc"); + print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); + + print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc"); + print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); + } +#endif + + threadwise_10d_tensor_copy(out_10d_thread_desc, + p_out_thread, + out_10d_global_desc, + p_out_global + out_k_h_w_n_global_desc.Get1dIndex( + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin, + n_block_data_begin + n_thread_data_begin), + out_10d_thread_desc.GetLengths(), + Number{}); + } +}; diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp new file mode 100644 index 0000000000..a2c7dba241 --- /dev/null +++ b/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp @@ -0,0 +1,297 @@ +#pragma once +#include "common.hip.hpp" +#include "ConstantTensorDescriptor.hip.hpp" +#include "ConstantMatrixDescriptor.hip.hpp" +#include "blockwise_2d_tensor_op.hip.hpp" +#include "blockwise_4d_tensor_op.hip.hpp" +#include "threadwise_nd_tensor_op.hip.hpp" +#include "threadwise_4d_tensor_op.hip.hpp" +#include "blockwise_batched_gemm.hip.hpp" + +template +struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn +{ + __device__ void Run(const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) const + { + // be careful of this assertion + static_assert( + NPerThread <= NPerBlock && NPerBlock % NPerThread == 0, + "wrong! 
should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{}; + constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{}; + constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{}; + + constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0); + + constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0); + constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1); + constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2); + constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3); + + constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1); + constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2); + + constexpr index_t HiPerBlock = HoPerBlock + Y - 1; + constexpr index_t WiPerBlock = WoPerBlock + X - 1; + + // divide block work: [K, Ho, Wo, N] + static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 && + Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0, + "wrong! cannot evenly divide work for workgroup "); + + constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; + constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; + constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; + constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock; + + const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); + index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); + const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork); + itmp -= h_block_work_id * (WBlockWork * NBlockWork); + const index_t w_block_work_id = itmp / NBlockWork; + const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork; + + const index_t k_block_data_begin = k_block_work_id * KPerBlock; + const index_t ho_block_data_begin = h_block_work_id * HoPerBlock; + const index_t wo_block_data_begin = w_block_work_id * WoPerBlock; + const index_t n_block_data_begin = n_block_work_id * NPerBlock; + + const index_t hi_block_data_begin = ho_block_data_begin; + const index_t wi_block_data_begin = wo_block_data_begin; + + // global tensor view + constexpr auto wei_c_k_global_desc = + make_ConstantTensorDescriptor(Sequence{}, Sequence{}); + + // LDS tensor view + // be careful of alignment + constexpr index_t max_align = mod_conv::max(InBlockCopyDataPerRead_N, + WeiBlockCopyDataPerRead_K, + GemmDataPerReadA, + GemmDataPerReadB); + + constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + // tensor view of threadwise output in register + constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); + + // blockwise copy + // input: format is [C, Hi, Wi, N] +#if 0 + const auto blockwise_in_copy = + Blockwise4dTensorCopy1{}; +#else + const auto blockwise_in_copy = + Blockwise4dTensorCopy3{}; +#endif + + // blockwise wei copy + // format is [CPerBlock, X * KPerBlock] + const auto blockwise_wei_copy = + Blockwise2dTensorCopy1{}; + + // a series of blockwise batched GEMM + // C_matrix += transpose(A_matrix) * B_matrix + // A_matrix and B_matrix saved in LDS, C_matrix saved in register + // A_matrix[C,K] is a sub-matrix of wei_block[C,K] + // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] + // 
C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N] + constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); + + constexpr auto b_c_wn_block_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); + + constexpr auto c_k_wn_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); + + const auto blockwise_batch_gemm = + BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< + BlockSize, + decltype(a_c_k_block_mtx_desc), + decltype(b_c_wn_block_mtx_desc), + decltype(c_k_wn_thread_mtx_desc), + 0, + in_c_h_w_n_block_desc.GetStride(I1), + out_k_h_w_n_thread_desc.GetStride(I1), + HoPerBlock, + GemmMPerThreadSubC, + GemmNPerThreadSubC, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmKPerThreadLoop, + HoPerThread, + GemmDataPerReadA, + GemmDataPerReadB>{}; + + // LDS: be careful of alignment + constexpr index_t in_block_space = + in_c_h_w_n_block_desc.GetElementSpace(Number{}); + constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number{}); + + __shared__ Float p_in_block[in_block_space]; + __shared__ Float p_wei_block[wei_block_space]; + + // register + Float p_out_thread[out_k_h_w_n_thread_desc.GetElementSpace()]; + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc"); + print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc"); + + print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc"); + print_ConstantTensorDescriptor(wei_c_x_k_block_desc, "wei_c_x_k_block_desc"); + + printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space); + } +#endif + + // set threadwise output tensor to 0 + threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread); + + for(index_t y = 0; y < Y; ++y) + { + for(index_t x = 0; x < X; ++x) + { + const Float* p_in_global_block_offset = + p_in_global + + in_c_h_w_n_global_desc.Get1dIndex( + 0, hi_block_data_begin + y, wi_block_data_begin + x, n_block_data_begin); + + const Float* p_wei_global_block_offset = + p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, k_block_data_begin); + + for(index_t c_block_data_begin = 0; c_block_data_begin < C; + c_block_data_begin += CPerBlock, + p_in_global_block_offset += + CPerBlock * in_c_h_w_n_global_desc.GetStride(I0), + p_wei_global_block_offset += + CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0)) + { + blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); + + blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); + + __syncthreads(); + + blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread); + + __syncthreads(); + } + } + } + + // output: register to global mem, + const auto c_thread_mtx_begin = + blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const index_t k_thread_data_begin = c_thread_mtx_begin.row; + const index_t ho_thread_data_begin = c_thread_mtx_begin.batch; + const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; + const index_t n_thread_data_begin = + c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin; + + // output is a 10d tensor + constexpr index_t N2 = GemmNPerThreadSubC; + constexpr index_t N1 = NPerBlock / N2; + + constexpr index_t W2 = + (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC); + constexpr index_t W1 = WoPerBlock / W2; + + constexpr 
index_t K2 = GemmMPerThreadSubC; + constexpr index_t K1 = KPerBlock / KPerThread; + + constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor( + Sequence{}); + + constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc"); + print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); + + print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc"); + print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); + } +#endif + + threadwise_10d_tensor_copy(out_10d_thread_desc, + p_out_thread, + out_10d_global_desc, + p_out_global + out_k_h_w_n_global_desc.Get1dIndex( + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin, + n_block_data_begin + n_thread_data_begin), + out_10d_thread_desc.GetLengths(), + Number{}); + } +};