From 5ce19234a4538d52e18837a84ebe7c1fef224c71 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 19 Apr 2019 14:22:02 -0500 Subject: [PATCH] added GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn --- ...lution_implicit_gemm_v1_chwn_cyxk_khwn.hpp | 2 +- ...lution_implicit_gemm_v1_nchw_cyxk_khwn.hpp | 300 ++---------------- driver/driver.hip.cpp | 2 +- src/include/blockwise_3d_tensor_op.hip.hpp | 2 +- src/include/blockwise_4d_tensor_op.hip.hpp | 23 +- ..._implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp | 134 +++----- src/include/threadwise_4d_tensor_op.hip.hpp | 78 ++++- 7 files changed, 135 insertions(+), 406 deletions(-) diff --git a/driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp b/driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp index 6f5e29410c..fb36afa4db 100644 --- a/driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp +++ b/driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp @@ -111,7 +111,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t OutThreadCopyDataPerWrite = 2; constexpr index_t BlockSize = 128; -#elif 0 +#elif 1 // for 3x3, 34x34, v1r2, Pascal, in-block-copy1 constexpr index_t NPerBlock = 4; constexpr index_t KPerBlock = 64; diff --git a/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp b/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp index 6762bb1d2a..2447ec0c13 100644 --- a/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp +++ b/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp @@ -62,208 +62,10 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc, wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data()); out_khwn_device_buf.ToDevice(out_khwn.mData.data()); -#if 0 - // for 3x3, 34x34, v1r1, Pascal - constexpr index_t NPerBlock = 16; - constexpr index_t KPerBlock = 64; - constexpr index_t CPerBlock = 4; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 4; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t InBlockCopy_ThreadPerDimC = 4; - constexpr index_t InBlockCopy_ThreadPerDimH = 4; - constexpr index_t InBlockCopy_ThreadPerDimW = 2; - constexpr index_t InBlockCopy_ThreadPerDimN = 4; - constexpr index_t InBlockCopyDataPerRead = 4; - - constexpr index_t WeiBlockCopyDataPerRead = 4; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - constexpr index_t OutThreadCopyDataPerWrite = 2; - - constexpr index_t BlockSize = 128; -#elif 0 - // for 3x3, 34x34, v1r2, Pascal, in-block-copy1 - constexpr index_t NPerBlock = 4; - constexpr index_t KPerBlock = 64; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 4; - constexpr index_t WoPerBlock = 8; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t InBlockCopy_ThreadPerDimC = 4; - constexpr index_t InBlockCopy_ThreadPerDimH = 4; - constexpr index_t InBlockCopy_ThreadPerDimW = 2; - constexpr index_t InBlockCopy_ThreadPerDimN = 1; - constexpr index_t InBlockCopyDataPerRead = 4; - - constexpr index_t WeiBlockCopyDataPerRead = 4; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - constexpr index_t OutThreadCopyDataPerWrite = 2; - - constexpr index_t BlockSize = 128; -#elif 0 - // for 3x3, 34x34, v1r1, Vega 20 - constexpr index_t NPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 4; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 4; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - constexpr index_t InBlockCopy_ThreadPerDimC = 4; - constexpr index_t InBlockCopy_ThreadPerDimH = 4; - constexpr index_t InBlockCopy_ThreadPerDimW = 2; - constexpr index_t InBlockCopy_ThreadPerDimN = 8; - constexpr index_t InBlockCopyDataPerRead = 2; - - constexpr index_t WeiBlockCopyDataPerRead = 2; - constexpr index_t OutThreadCopyDataPerWrite = 4; - - constexpr index_t BlockSize = 256; -#elif 0 - // for 3x3, 56x56, v1, Pascal - constexpr index_t NPerBlock = 32; - constexpr index_t KPerBlock = 64; - constexpr index_t CPerBlock = 4; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 2; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t InBlockCopy_ThreadPerDimC = 1; - constexpr index_t InBlockCopy_ThreadPerDimH = 4; - constexpr index_t InBlockCopy_ThreadPerDimW = 4; - constexpr index_t InBlockCopy_ThreadPerDimN = 8; - constexpr index_t InBlockCopyDataPerRead = 4; - - constexpr index_t WeiBlockCopyDataPerRead = 4; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - - constexpr index_t OutThreadCopyDataPerWrite = 2; - - constexpr index_t BlockSize = 128; -#elif 0 - // for 3x3, 56x56, v1r2, Pascal - constexpr index_t NPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 2; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 1; - constexpr index_t GemmDataPerReadB = 1; - - constexpr index_t InBlockCopy_ThreadPerDimC = 1; - constexpr index_t InBlockCopy_ThreadPerDimH = 2; - constexpr index_t InBlockCopy_ThreadPerDimW = 4; - constexpr index_t InBlockCopy_ThreadPerDimN = 4; - constexpr index_t InBlockCopyDataPerRead = 4; - - constexpr index_t WeiBlockCopyDataPerRead = 4; - constexpr index_t OutThreadCopyDataPerWrite = 4; - - constexpr index_t BlockSize = 128; -#elif 0 - // for 3x3, 28x28, v1r1, Pacal - constexpr index_t NPerBlock = 32; - constexpr index_t KPerBlock = 64; - constexpr index_t CPerBlock = 4; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 2; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t InBlockCopy_ThreadPerDimC = 1; - constexpr index_t InBlockCopy_ThreadPerDimH = 4; - constexpr index_t InBlockCopy_ThreadPerDimW = 4; - constexpr index_t InBlockCopy_ThreadPerDimN = 8; - constexpr index_t InBlockCopyDataPerRead = 4; - - constexpr index_t WeiBlockCopyDataPerRead = 4; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - constexpr index_t OutThreadCopyDataPerWrite = 2; - - constexpr index_t BlockSize = 128; -#elif 1 +#if 1 // for 3x3, 28x28, v1r2, Pascal + constexpr index_t BlockSize = 128; + constexpr index_t NPerBlock = 16; constexpr index_t KPerBlock = 128; constexpr index_t CPerBlock = 8; @@ -275,14 +77,6 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc, constexpr index_t HoPerThread = 1; constexpr index_t WoPerThread = 2; - constexpr index_t InBlockCopy_ThreadPerDimN = 4; - constexpr index_t InBlockCopy_ThreadPerDimC = 8; - constexpr index_t InBlockCopy_ThreadPerDimH = 2; - constexpr index_t InBlockCopy_ThreadPerDimW = 2; - constexpr index_t InBlockCopyDataPerRead = 2; - - constexpr index_t WeiBlockCopyDataPerRead = 4; - constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmMLevel0Cluster = 4; @@ -293,73 +87,16 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc, constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadB = 4; - constexpr index_t OutThreadCopyDataPerWrite = 2; + using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 2>; + using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>; + using InBlockReorderMapThreadCluster2SrcCluster = Sequence<1, 2, 3, 0>; + constexpr index_t InBlockReorderDataPerRead_W = 2; + constexpr index_t InBlockReorderDataPerWrite_N = 4; - constexpr index_t BlockSize = 128; -#elif 0 - // for 1x1, 28x28 - constexpr index_t NPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 2; + using WeiBlockCopyClusterLengths = Sequence<4, 1, 32>; + constexpr index_t WeiBlockCopyDataPerRead_C = 4; - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 16; - constexpr index_t CPerThread = 1; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 1; - - constexpr index_t InBlockCopy_ThreadPerDimC = 8; - constexpr index_t InBlockCopy_ThreadPerDimH = 2; - constexpr index_t InBlockCopy_ThreadPerDimW = 2; - constexpr index_t InBlockCopy_ThreadPerDimN = 4; - constexpr index_t InBlockCopyDataPerRead = 4; - - constexpr index_t WeiBlockCopyDataPerRead = 4; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - - constexpr index_t OutThreadCopyDataPerWrite = 2; - - constexpr index_t BlockSize = 128; -#elif 1 - // for 1x1, 14x14, Pascal - constexpr index_t NPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 2; - - constexpr index_t NPerThread = 8; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 1; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - - constexpr index_t InBlockCopy_ThreadPerDimC = 8; - constexpr index_t InBlockCopy_ThreadPerDimH = 2; - constexpr index_t InBlockCopy_ThreadPerDimW = 2; - constexpr index_t InBlockCopy_ThreadPerDimN = 4; - constexpr index_t InBlockCopyDataPerRead = 4; - - constexpr index_t WeiBlockCopyDataPerRead = 4; - constexpr index_t OutThreadCopyDataPerWrite = 2; - - constexpr index_t BlockSize = 128; + constexpr index_t OutThreadCopyDataPerWrite_N = 2; #endif constexpr index_t GridSize = @@ -398,13 +135,14 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc, GemmKPerThreadLoop, GemmDataPerReadA, GemmDataPerReadB, - Sequence, - InBlockCopyDataPerRead, - WeiBlockCopyDataPerRead, - OutThreadCopyDataPerWrite>{}; + InBlockReorderSrcSubLengths_NCHW, + InBlockReorderSrcClusterLengths_NCHW, + InBlockReorderMapThreadCluster2SrcCluster, + InBlockReorderDataPerRead_W, + InBlockReorderDataPerWrite_N, + WeiBlockCopyClusterLengths, + WeiBlockCopyDataPerRead_C, + OutThreadCopyDataPerWrite_N>{}; float time = launch_kernel(run_gridwise_convolution, dim3(GridSize), diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index 89106fe85f..5937190e66 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -673,7 +673,7 @@ int main(int argc, char* argv[]) device_direct_convolution_2_nchw_kcyx_nkhw #elif 0 device_direct_convolution_2_vectorized_nchw_kcyx_nkhw -#elif 1 +#elif 0 device_convolution_implicit_gemm_v1_chwn_cyxk_khwn #elif 1 device_convolution_implicit_gemm_v1_nchw_cyxk_khwn diff --git a/src/include/blockwise_3d_tensor_op.hip.hpp b/src/include/blockwise_3d_tensor_op.hip.hpp index 6a88757075..9e73be106d 100644 --- a/src/include/blockwise_3d_tensor_op.hip.hpp +++ b/src/include/blockwise_3d_tensor_op.hip.hpp @@ -231,7 +231,7 @@ struct Blockwise3dTensorCopy3 } } - __device__ constexpr index_t GetRegisterClipboardSize() const + __device__ static constexpr index_t GetRegisterClipboardSize() { static_assert(is_same::value, "wrong! only support float!\n"); diff --git a/src/include/blockwise_4d_tensor_op.hip.hpp b/src/include/blockwise_4d_tensor_op.hip.hpp index eea37a2b2e..e301631e46 100644 --- a/src/include/blockwise_4d_tensor_op.hip.hpp +++ b/src/include/blockwise_4d_tensor_op.hip.hpp @@ -761,7 +761,6 @@ struct Blockwise4dTensorCopyReorder1 } }; -#if 1 template + class InBlockReorderSrcSubLengths_NCHW, + class InBlockReorderSrcClusterLengths_NCHW, + class InBlockReorderMapThreadCluster2SrcCluster, + index_t InBlockReorderDataPerRead_W, + index_t InBlockReorderDataPerWrite_N, + class WeiBlockCopyClusterLengths_KXC, + index_t WeiBlockCopyDataPerRead_C, + index_t OutThreadCopyDataPerWrite_N> struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn { __device__ void Run(const Float* const __restrict__ p_in_global, @@ -101,8 +105,10 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn // LDS tensor view // be careful of alignment - constexpr index_t max_align = mod_conv::max( - InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB); + constexpr index_t max_align = mod_conv::max(InBlockReorderDataPerWrite_N, + WeiBlockCopyDataPerRead_C, + GemmDataPerReadA, + GemmDataPerReadB); constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned( Sequence{}, Number{}); @@ -117,68 +123,38 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn // blockwise copy // input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N] auto map_chwn2nchw = Sequence<1, 2, 3, 0>{}; -#if 0 - const auto blockwise_in_copy_reorder = - Blockwise4dTensorCopyReorder1, - decltype(map_chwn2nchw)>{}; -#else - auto map_thread_cluster_2_src_cluster = Sequence<1, 2, 0, 3>{}; - const auto blockwise_in_copy_reorder = Blockwise4dTensorCopyReorder3, - Sequence<4, 1, 1, 2>, - Sequence<4, 8, 2, 2>, + InBlockReorderSrcSubLengths_NCHW, + InBlockReorderSrcClusterLengths_NCHW, decltype(map_chwn2nchw), - decltype(map_thread_cluster_2_src_cluster), - 2, - 4>{}; - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - printf("size %u\n", blockwise_in_copy_reorder.GetRegisterClipboardSize()); - } -#endif -#endif + InBlockReorderMapThreadCluster2SrcCluster, + InBlockReorderDataPerRead_W, + InBlockReorderDataPerWrite_N>{}; // blockwise wei copy // format is [CPerBlock, X * KPerBlock] const auto blockwise_wei_copy = -#if 0 - Blockwise3dTensorCopy1{}; -#else Blockwise3dTensorCopy3, - WeiBlockCopyDataPerRead>{}; -#endif + WeiBlockCopyDataPerRead_C>{}; - // a series of blockwise batched GEMM - // C_matrix += transpose(A_matrix) * B_matrix - // A_matrix and B_matrix saved in LDS, C_matrix saved in register - // A_matrix[C,K] is a sub-matrix of wei_block[C,K] - // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] - // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N] - constexpr auto a_c_k_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); + // a series of blockwise batched GEMM + // C_matrix += transpose(A_matrix) * B_matrix + // A_matrix and B_matrix saved in LDS, C_matrix saved in register + // A_matrix[C,K] is a sub-matrix of wei_block[C,K] + // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] + // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N] + constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); constexpr auto b_c_wn_block_mtx_desc = make_ConstantMatrixDescriptor(Number{}, @@ -252,6 +228,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn { for(index_t y = 0; y < Y; ++y) { +#if 1 blockwise_in_copy_reorder.Run(p_in_global_block_offset + in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, 0), p_in_block); @@ -259,6 +236,23 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn blockwise_wei_copy.Run(p_wei_global_block_offset + wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, 0), p_wei_block); +#else + Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()]; + Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + + blockwise_in_copy_reorder.RunLoadRegisterClipboard( + p_in_global_block_offset + in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, 0), + p_in_clipboard); + + blockwise_wei_copy.RunLoadRegisterClipboard( + p_wei_global_block_offset + wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, 0), + p_wei_clipboard); + + blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block); + + blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_clipboard, p_in_block); + +#endif __syncthreads(); @@ -274,42 +268,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn } } -// output: register to global mem, -#if 0 - const auto c_thread_mtx_begin = - blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k) - { - for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho) - { - for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo) - { - for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n) - { - const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n); - - const auto c_thread_mtx_distance = - blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b); - - const index_t ho_thread = - c_thread_mtx_begin.batch + c_thread_mtx_distance.batch; - const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row; - const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col; - - const index_t wo_thread = b_thread / NPerBlock; - const index_t n_thread = b_thread % NPerBlock; - - p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread, - ho_block_data_begin + ho_thread, - wo_block_data_begin + wo_thread, - n_block_data_begin + n_thread)] = - p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)]; - } - } - } - } -#elif 1 + // output: register to global mem, const auto c_thread_mtx_begin = blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); @@ -356,7 +315,6 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn wo_block_data_begin + wo_thread_data_begin, n_block_data_begin + n_thread_data_begin), out_10d_thread_desc.GetLengths(), - Number{}); -#endif + Number{}); } }; diff --git a/src/include/threadwise_4d_tensor_op.hip.hpp b/src/include/threadwise_4d_tensor_op.hip.hpp index 05894d434f..37427c0b8b 100644 --- a/src/include/threadwise_4d_tensor_op.hip.hpp +++ b/src/include/threadwise_4d_tensor_op.hip.hpp @@ -44,7 +44,7 @@ template -__device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src( +__device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_given_dst2src( SrcDesc, const SrcData* __restrict__ p_src, DstDesc, @@ -82,9 +82,9 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]); -#if 1 f(p_src[aindex], p_dst[bindex]); -#else + +#if 0 if(get_block_1d_id() == 0) { printf("tid %5u, " @@ -126,17 +126,16 @@ template -__device__ void -threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc, - const SrcData* __restrict__ p_src, - DstDesc, - DstData* __restrict__ p_dst, - SrcOpLengths, - MapDst2Src) +__device__ void threadwise_4d_tensor_copy_reorder_given_dst2src(SrcDesc, + const SrcData* __restrict__ p_src, + DstDesc, + DstData* __restrict__ p_dst, + SrcOpLengths, + MapDst2Src) { auto f_copy = [](const SrcData& src, DstData& dst) { dst = static_cast(src); }; - threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src( + threadwise_4d_tensor_pointwise_operation_binary_reorder_given_dst2src( SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy); } @@ -146,7 +145,7 @@ __device__ void threadwise_4d_tensor_copy( { auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{}; - threadwise_4d_tensor_copy_reorder_by_get_dst_from_src( + threadwise_4d_tensor_copy_reorder_given_dst2src( SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder); } @@ -212,6 +211,61 @@ __device__ void threadwise_4d_tensor_copy_v2(SrcDesc, } } +template +__device__ void +threadwise_4d_tensor_copy_reorder_given_dst2src_v2(SrcDesc, + const SrcData* __restrict__ p_src, + DstDesc, + DstData* __restrict__ p_dst, + SrcOpLengths, + MapDst2Src) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr index_t IR0 = MapDst2Src{}.Get(I0); + constexpr index_t IR1 = MapDst2Src{}.Get(I1); + constexpr index_t IR2 = MapDst2Src{}.Get(I2); + constexpr index_t IR3 = MapDst2Src{}.Get(I3); + + constexpr auto src_desc = SrcDesc{}; + constexpr auto dst_desc = DstDesc{}; + + // ref_desc has dst_desc's ordering + constexpr auto ref_desc = + make_ConstantTensorDescriptor(SrcOpLengths{}.ReorderGivenNew2Old(MapDst2Src{})); + + for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0) + { + for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1) + { + for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2) + { + for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3) + { + const auto dst_multi_id = Array{did0, did1, did2, did3}; + + const auto src_multi_id = + reorder_array_given_old2new(dst_multi_id, MapDst2Src{}); + + const index_t dst_index = dst_desc.Get1dIndex(dst_multi_id); + + const index_t src_index = src_desc.Get1dIndex(src_multi_id); + + p_dst[dst_index] = p_src[src_index]; + } + } + } + } +} + template __device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift) {