diff --git a/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp b/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp
index 26fa9c8ca8..e5c20994f2 100644
--- a/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp
+++ b/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp
@@ -95,7 +95,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
     constexpr index_t InBlockReorderDataPerRead_W  = 1; // v1r3 cannot do vector load input for NCHW
     constexpr index_t InBlockReorderDataPerWrite_N = 2;
 
-    using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
+    using WeiBlockCopyClusterLengths = void;
     constexpr index_t WeiBlockCopyDataPerRead_K = 4;
 
     constexpr index_t OutThreadCopyDataPerWrite_N = 2;
@@ -130,7 +130,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
     constexpr index_t InBlockReorderDataPerRead_W  = 1; // v1r3 cannot do vector load input for NCHW
     constexpr index_t InBlockReorderDataPerWrite_N = 2;
 
-    using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
+    using WeiBlockCopyClusterLengths = void;
    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
 
     constexpr index_t OutThreadCopyDataPerWrite_N = 2;
@@ -200,7 +200,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
     constexpr index_t InBlockReorderDataPerRead_W  = 1; // v1r3 cannot do vector load input for NCHW
     constexpr index_t InBlockReorderDataPerWrite_N = 1;
 
-    using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
+    using WeiBlockCopyClusterLengths = void;
     constexpr index_t WeiBlockCopyDataPerRead_K = 4;
 
     constexpr index_t OutThreadCopyDataPerWrite_N = 2;
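
Note on the `WeiBlockCopyClusterLengths = void` change: the weight copy path used by v1r3 derives its own thread layout, so this parameter is dead in these configs. Replacing the dummy `Sequence<0, 0>` with `void` makes "unused" explicit and uninstantiable. A minimal sketch of the idea, with hypothetical names (`KernelConfig` is not the repo's API):

#include <type_traits>

template <int... Is>
struct Sequence; // stand-in for the repo's Sequence

// Hypothetical kernel-config wrapper: the cluster-lengths type is only
// meaningful on code paths that actually use it.
template <class WeiBlockCopyClusterLengths>
struct KernelConfig
{
    static constexpr bool wei_cluster_lengths_provided =
        !std::is_void<WeiBlockCopyClusterLengths>::value;
};

// `void` documents "not used" and cannot be instantiated by mistake, unlike a
// dummy Sequence<0, 0> that would silently flow into the kernel.
static_assert(!KernelConfig<void>::wei_cluster_lengths_provided, "marked unused");
static_assert(KernelConfig<Sequence<4, 64>>::wei_cluster_lengths_provided, "provided");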
diff --git a/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp b/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp
index f74b05e750..f366f6664c 100644
--- a/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp
+++ b/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp
@@ -3,6 +3,7 @@
 #include "device.hpp"
 #include "gridwise_convolution_wrapper.hip.hpp"
 #include "gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp"
+#include "gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp"
 
 template <class T, class InDesc, class WeiDesc, class OutDesc>
 void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
@@ -92,7 +93,42 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
     constexpr index_t OutThreadCopyDataPerWrite_W = 2;
 #elif 0
-    // for 3x3, 34x34, v1r3, Vega 20
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 32
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t NPerBlock  = 1;
+    constexpr index_t KPerBlock  = 128;
+    constexpr index_t CPerBlock  = 8;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 32;
+
+    constexpr index_t NPerThread  = 1;
+    constexpr index_t KPerThread  = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 8;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 2;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA   = 4;
+    constexpr index_t GemmDataPerReadB   = 4;
+
+    using InBlockReorderSrcSubLengths_NCHW     = Sequence<1, 2, 2, 1>;
+    using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 4, 2, 32>;
+    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
+    constexpr index_t InBlockReorderDataPerRead_W  = 1; // v1r3 cannot do vector load input for NCHW
+    constexpr index_t InBlockReorderDataPerWrite_N = 1;
+
+    using WeiBlockCopyClusterLengths = void;
+    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
+
+    constexpr index_t OutThreadCopyDataPerWrite_W = 4;
+#elif 0
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 16
     constexpr index_t BlockSize = 256;
 
     constexpr index_t NPerBlock = 2;
@@ -125,9 +161,9 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
     using WeiBlockCopyClusterLengths = void;
     constexpr index_t WeiBlockCopyDataPerRead_K = 4;
 
-    constexpr index_t OutThreadCopyDataPerWrite_W = 4;
+    constexpr index_t OutThreadCopyDataPerWrite_W = 2;
 #elif 1
-    // for 3x3, 34x34, v1r3, Vega 20, try
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 8
     constexpr index_t BlockSize = 256;
 
     constexpr index_t NPerBlock = 4;
@@ -160,7 +196,77 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
     using WeiBlockCopyClusterLengths = void;
     constexpr index_t WeiBlockCopyDataPerRead_K = 4;
 
-    constexpr index_t OutThreadCopyDataPerWrite_W = 2;
+    constexpr index_t OutThreadCopyDataPerWrite_W = 1;
+#elif 0
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 4
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t NPerBlock  = 8;
+    constexpr index_t KPerBlock  = 128;
+    constexpr index_t CPerBlock  = 8;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 4;
+
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 2;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 2;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA   = 4;
+    constexpr index_t GemmDataPerReadB   = 4;
+
+    using InBlockReorderSrcSubLengths_NCHW     = Sequence<4, 1, 1, 1>;
+    using InBlockReorderSrcClusterLengths_NCHW = Sequence<2, 8, 4, 4>;
+    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
+    constexpr index_t InBlockReorderDataPerRead_W  = 1; // v1r3 cannot do vector load input for NCHW
+    constexpr index_t InBlockReorderDataPerWrite_N = 4;
+
+    using WeiBlockCopyClusterLengths = void;
+    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
+
+    constexpr index_t OutThreadCopyDataPerWrite_W = 1;
+#elif 0
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 2
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t NPerBlock  = 32;
+    constexpr index_t KPerBlock  = 128;
+    constexpr index_t CPerBlock  = 8;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;
+
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 2;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 2;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA   = 4;
+    constexpr index_t GemmDataPerReadB   = 4;
+
+    using InBlockReorderSrcSubLengths_NCHW     = Sequence<4, 1, 1, 1>;
+    using InBlockReorderSrcClusterLengths_NCHW = Sequence<8, 8, 2, 2>;
+    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
+    constexpr index_t InBlockReorderDataPerRead_W  = 1; // v1r3 cannot do vector load input for NCHW
+    constexpr index_t InBlockReorderDataPerWrite_N = 4;
+
+    using WeiBlockCopyClusterLengths = void;
+    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
+
+    constexpr index_t OutThreadCopyDataPerWrite_W = 1;
 #elif 1
     // for 3x3, 28x28, v1r3, Pascal
     constexpr index_t BlockSize = 128;
@@ -206,39 +312,44 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
     for(index_t i = 0; i < nrepeat; ++i)
     {
-        constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw<
-            GridSize,
-            BlockSize,
-            T,
-            decltype(in_nchw_desc),
-            decltype(wei_cyxk_desc),
-            decltype(out_nkhw_desc),
-            NPerBlock,
-            KPerBlock,
-            CPerBlock,
-            HoPerBlock,
-            WoPerBlock,
-            NPerThread,
-            KPerThread,
-            HoPerThread,
-            WoPerThread,
-            GemmMPerThreadSubC,
-            GemmNPerThreadSubC,
-            GemmMLevel0Cluster,
-            GemmNLevel0Cluster,
-            GemmMLevel1Cluster,
-            GemmNLevel1Cluster,
-            GemmKPerThreadLoop,
-            GemmDataPerReadA,
-            GemmDataPerReadB,
-            InBlockReorderSrcSubLengths_NCHW,
-            InBlockReorderSrcClusterLengths_NCHW,
-            InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
-            InBlockReorderDataPerRead_W,
-            InBlockReorderDataPerWrite_N,
-            WeiBlockCopyClusterLengths,
-            WeiBlockCopyDataPerRead_K,
-            OutThreadCopyDataPerWrite_W>{};
+        constexpr auto gridwise_conv =
+#if 0
+            GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
+#else
+            GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
+#endif
+            <GridSize,
+             BlockSize,
+             T,
+             decltype(in_nchw_desc),
+             decltype(wei_cyxk_desc),
+             decltype(out_nkhw_desc),
+             NPerBlock,
+             KPerBlock,
+             CPerBlock,
+             HoPerBlock,
+             WoPerBlock,
+             NPerThread,
+             KPerThread,
+             HoPerThread,
+             WoPerThread,
+             GemmMPerThreadSubC,
+             GemmNPerThreadSubC,
+             GemmMLevel0Cluster,
+             GemmNLevel0Cluster,
+             GemmMLevel1Cluster,
+             GemmNLevel1Cluster,
+             GemmKPerThreadLoop,
+             GemmDataPerReadA,
+             GemmDataPerReadB,
+             InBlockReorderSrcSubLengths_NCHW,
+             InBlockReorderSrcClusterLengths_NCHW,
+             InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
+             InBlockReorderDataPerRead_W,
+             InBlockReorderDataPerWrite_N,
+             WeiBlockCopyClusterLengths,
+             WeiBlockCopyDataPerRead_K,
+             OutThreadCopyDataPerWrite_W>{};
 
         float time = launch_kernel(run_gridwise_convolution,
                                    dim3(GridSize),
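
The tuning blocks above follow conventions that are easy to break when editing by hand. A compile-time consistency sketch, using the values of the `WoPerBlock = 32` block (the checks are mine; the repo spells them differently, if at all):

// 1) The input-reorder thread cluster must tile the whole workgroup.
constexpr int BlockSize          = 256;
constexpr int cluster_lengths[4] = {1, 4, 2, 32}; // InBlockReorderSrcClusterLengths_NCHW
static_assert(cluster_lengths[0] * cluster_lengths[1] * cluster_lengths[2] * cluster_lengths[3] ==
                  BlockSize,
              "reorder cluster lengths must multiply to BlockSize");

// 2) Sub-lengths times cluster-lengths must equal the [N, C, Ho, Wo] block tile.
constexpr int sub_lengths[4] = {1, 2, 2, 1}; // InBlockReorderSrcSubLengths_NCHW
static_assert(sub_lengths[0] * cluster_lengths[0] == 1 &&  // NPerBlock
                  sub_lengths[1] * cluster_lengths[1] == 8 &&  // CPerBlock
                  sub_lengths[2] * cluster_lengths[2] == 4 &&  // HoPerBlock
                  sub_lengths[3] * cluster_lengths[3] == 32,   // WoPerBlock
              "sub-lengths * cluster-lengths must cover the block tile");

// 3) The batched GEMM partitions threads as (M-clusters x N-clusters) per batch,
//    times HoPerBlock batches (an assumption inferred from the kernels'
//    BlockwiseBatchGemm parameters).
constexpr int GemmMLevel0Cluster = 4, GemmNLevel0Cluster = 2;
constexpr int GemmMLevel1Cluster = 4, GemmNLevel1Cluster = 2;
constexpr int HoPerBlock         = 4;
static_assert(GemmMLevel0Cluster * GemmNLevel0Cluster * GemmMLevel1Cluster * GemmNLevel1Cluster *
                      HoPerBlock ==
                  BlockSize,
              "GEMM thread clusters must multiply to BlockSize");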
diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp
index f25b99ca27..ed7fa09d1d 100644
--- a/driver/driver.hip.cpp
+++ b/driver/driver.hip.cpp
@@ -371,7 +371,7 @@ void host_winograd_3x3_convolution(const Tensor<T>& in_nchw,
                 std::size_t ho = HoPerTile * htile + j;
                 for(int i = 0; i < WoPerTile; ++i)
                 {
-                    std::size_t wo = WoPerTile * wtile + i;
+                    std::size_t wo         = WoPerTile * wtile + i;
                     out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
                 }
             }
@@ -413,13 +413,13 @@ int main(int argc, char* argv[])
 {
 #if 1
     // 3x3, 34x34
-    constexpr index_t N = 64;
-    constexpr index_t C = 256;
+    constexpr index_t N  = 64;
+    constexpr index_t C  = 256;
     constexpr index_t HI = 34;
     constexpr index_t WI = 34;
-    constexpr index_t K = 128;
-    constexpr index_t Y = 3;
-    constexpr index_t X = 3;
+    constexpr index_t K  = 128;
+    constexpr index_t Y  = 3;
+    constexpr index_t X  = 3;
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp
index ab67f3afff..5f9c7a75bc 100644
--- a/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp
+++ b/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp
@@ -74,22 +74,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                           Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                       "wrong! cannot evenly divide work for workgroup ");
 
-        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
-        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
-        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
-        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
+        constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
+        constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
+        constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
+        constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
 
-        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
-        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
-        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
-        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
-        const index_t w_block_work_id = itmp / NBlockWork;
-        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
+        constexpr auto block_work_desc = make_ConstantTensorDescriptor(
+            Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
 
-        const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
-        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
-        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
-        const index_t n_block_data_begin  = n_block_work_id * NPerBlock;
+        const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());
+
+        const index_t n_block_data_begin  = block_work_multi_id[0] * NPerBlock;
+        const index_t k_block_data_begin  = block_work_multi_id[1] * KPerBlock;
+        const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
+        const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;
 
         const index_t hi_block_data_begin = ho_block_data_begin;
         const index_t wi_block_data_begin = wo_block_data_begin;
@@ -185,7 +183,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
         // choose GEMM implementation here
         const auto run_blockwise_batch_gemm = [&](auto... Xs) {
-#if 1
+#if 0
             return blockwise_batch_gemm.Run(Xs...);
 #elif 0
             return blockwise_batch_gemm.Run_asm(Xs...);
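
The rewrite above replaces the hand-rolled div/mod chain with a `block_work_desc` whose `GetMultiIndex` performs the same decomposition, now in N-K-H-W order. A host-side sketch of the arithmetic (`decompose` is illustrative; a packed descriptor computes its strides the same way):

#include <array>
#include <cstddef>

using index_t = std::size_t;

// Decompose a flat block id against packed lengths [NBlockWork, KBlockWork,
// HBlockWork, WBlockWork] -- what block_work_desc.GetMultiIndex() computes.
std::array<index_t, 4> decompose(index_t block_id, const std::array<index_t, 4>& lengths)
{
    // packed strides: stride[i] = product of lengths[i+1..3]
    std::array<index_t, 4> stride{};
    stride[3] = 1;
    for(int i = 2; i >= 0; --i)
        stride[i] = stride[i + 1] * lengths[i + 1];

    std::array<index_t, 4> multi_id{};
    for(int i = 0; i < 4; ++i)
    {
        multi_id[i] = block_id / stride[i]; // the same div/mod the old code spelled out
        block_id -= multi_id[i] * stride[i];
    }
    return multi_id;
}

// e.g. lengths {2, 2, 8, 8}: block 77 -> {0, 1, 1, 5}, i.e. the block owns the
// n=0, k=1, ho=1, wo=5 work tile (each index then scaled by its *PerBlock).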
diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp
index ac96fff9fc..732443adf3 100644
--- a/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp
+++ b/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp
@@ -81,22 +81,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
                           Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                       "wrong! cannot evenly divide work for workgroup ");
 
-        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
-        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
-        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
-        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
+        constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
+        constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
+        constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
+        constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
 
-        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
-        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
-        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
-        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
-        const index_t w_block_work_id = itmp / NBlockWork;
-        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
+        constexpr auto block_work_desc = make_ConstantTensorDescriptor(
+            Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
 
-        const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
-        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
-        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
-        const index_t n_block_data_begin  = n_block_work_id * NPerBlock;
+        const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());
+
+        const index_t n_block_data_begin  = block_work_multi_id[0] * NPerBlock;
+        const index_t k_block_data_begin  = block_work_multi_id[1] * KPerBlock;
+        const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
+        const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;
 
         const index_t hi_block_data_begin = ho_block_data_begin;
         const index_t wi_block_data_begin = wo_block_data_begin;
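
The khwn double-buffer kernel above and the new nkhw kernel below share one schedule: while one half of LDS is consumed by the GEMM, the other half is being filled, which is why both assert `C % (2 * CPerBlock) == 0`. A stripped-down host-side sketch of that schedule (`load`/`compute` stand in for the blockwise copy and the blockwise batched GEMM; the function shapes are mine, not the repo's):

#include <cassert>

void pipeline(int C, int CPerBlock, void (*load)(int c, int half), void (*compute)(int half))
{
    assert(C % (2 * CPerBlock) == 0);

    load(0, 0); // preload the first C-slice into half 0

    int c = 0;
    for(; c + 2 * CPerBlock < C; c += 2 * CPerBlock)
    {
        for(int iloop = 0; iloop < 2; ++iloop)
        {
            const int now  = iloop % 2; // half being computed
            const int next = 1 - now;   // half being filled
            load(c + (iloop + 1) * CPerBlock, next); // overlaps with compute
            compute(now);
        }
    }

    // tail: one last load for the final odd slice, then drain both halves
    load(c + CPerBlock, 1);
    compute(0);
    compute(1);
}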
diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp
new file mode 100644
index 0000000000..a2dee6a01f
--- /dev/null
+++ b/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp
@@ -0,0 +1,475 @@
+#pragma once
+#include "common.hip.hpp"
+#include "ConstantTensorDescriptor.hip.hpp"
+#include "ConstantMatrixDescriptor.hip.hpp"
+#include "blockwise_2d_tensor_op.hip.hpp"
+#include "blockwise_nd_tensor_op.hip.hpp"
+#include "threadwise_nd_tensor_op.hip.hpp"
+#include "threadwise_4d_tensor_op.hip.hpp"
+#include "blockwise_batched_gemm.hip.hpp"
+
+template <index_t GridSize,
+          index_t BlockSize,
+          class Float,
+          class InGlobalDesc,
+          class WeiGlobalDesc,
+          class OutGlobalDesc,
+          index_t NPerBlock,
+          index_t KPerBlock,
+          index_t CPerBlock,
+          index_t HoPerBlock,
+          index_t WoPerBlock,
+          index_t NPerThread,
+          index_t KPerThread,
+          index_t HoPerThread,
+          index_t WoPerThread,
+          index_t GemmMPerThreadSubC,
+          index_t GemmNPerThreadSubC,
+          index_t GemmMLevel0Cluster,
+          index_t GemmNLevel0Cluster,
+          index_t GemmMLevel1Cluster,
+          index_t GemmNLevel1Cluster,
+          index_t GemmKPerThreadLoop,
+          index_t GemmDataPerReadA,
+          index_t GemmDataPerReadB,
+          class InBlockReorderSrcSubLengths_NCHW,
+          class InBlockReorderSrcClusterLengths_NCHW,
+          class InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
+          index_t InBlockReorderDataPerRead_W,
+          index_t InBlockReorderDataPerWrite_N,
+          class WeiBlockCopyClusterLengths,
+          index_t WeiBlockCopyDataPerRead_K,
+          index_t OutThreadCopyDataPerWrite_W>
+struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
+{
+    __device__ void Run(const Float* const __restrict__ p_in_global,
+                        const Float* const __restrict__ p_wei_global,
+                        Float* const __restrict__ p_out_global) const
+    {
+        // be careful of this assertion
+        static_assert(
+            NPerBlock % NPerThread == 0 &&
+                ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
+                 (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
+                  GemmNPerThreadSubC % NPerThread == 0)),
+            "wrong!");
+
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+
+        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
+        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
+        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
+
+        constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
+
+        constexpr index_t N  = out_n_k_h_w_global_desc.GetLength(I0);
+        constexpr index_t K  = out_n_k_h_w_global_desc.GetLength(I1);
+        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
+        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
+
+        constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
+        constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
+
+        // assert for LDS double buffer
+        static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided");
+
+        // divide block work: [N, K, Ho, Wo]
+        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
+                          Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
+                      "wrong! cannot evenly divide work for workgroup");
+
+        constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
+        constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
+        constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
+        constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
+
+        constexpr auto block_work_desc = make_ConstantTensorDescriptor(
+            Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
+
+        const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());
+
+        const index_t n_block_data_begin  = block_work_multi_id[0] * NPerBlock;
+        const index_t k_block_data_begin  = block_work_multi_id[1] * KPerBlock;
+        const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
+        const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;
+
+        const index_t hi_block_data_begin = ho_block_data_begin;
+        const index_t wi_block_data_begin = wo_block_data_begin;
+
+        // global tensor view
+        constexpr auto wei_c_k_global_desc =
+            make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});
+
+        // LDS tensor view
+        // be careful of alignment
+        constexpr index_t max_align = mod_conv::max(InBlockReorderDataPerWrite_N,
+                                                    WeiBlockCopyDataPerRead_K,
+                                                    GemmDataPerReadA,
+                                                    GemmDataPerReadB);
+
+        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
+            Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{}, Number<max_align>{});
+
+        // this check is ad-hoc
+        // TODO: need to properly implement tensor descriptor with alignment
+        static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
+                      "GemmDataPerReadB alignment requirement is not met");
+
+        constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
+            Sequence<CPerBlock, KPerBlock>{}, Number<max_align>{});
+
+        // tensor view of threadwise output in register
+        constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
+            Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
+
+        // blockwise input copy
+        // reorder from [N, C, Hi, Wi] to [C, Hi, Wi, N]
+        constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};
+
+        const auto blockwise_in_copy_reorder = BlockwiseNdTensorCopyReorder_v3<
+            BlockSize,
+            Float,
+            decltype(in_n_c_h_w_global_desc),
+            decltype(in_c_h_w_n_block_desc),
+            Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
+            InBlockReorderSrcSubLengths_NCHW,
+            InBlockReorderSrcClusterLengths_NCHW,
+            decltype(map_chwn2nchw),
+            InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
+            InBlockReorderDataPerRead_W,
+            InBlockReorderDataPerWrite_N>{};
+
+        // blockwise wei copy
+        // format is [CPerBlock, KPerBlock]
+        const auto blockwise_wei_copy =
+            Blockwise2dTensorCopy3<BlockSize,
+                                   Float,
+                                   decltype(wei_c_k_global_desc),
+                                   decltype(wei_c_k_block_desc),
+                                   decltype(wei_c_k_block_desc.GetLengths()),
+                                   WeiBlockCopyDataPerRead_K>{};
+
+        // a series of blockwise batched GEMM
+        // C_matrix += transpose(A_matrix) * B_matrix
+        // A_matrix and B_matrix saved in LDS, C_matrix saved in register
+        // A_matrix[C,K] is a sub-matrix of wei_block[C,K]
+        // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
+        // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
+        constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
+            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
+
+        constexpr auto b_c_wn_block_mtx_desc =
+            make_ConstantMatrixDescriptor(Number<CPerBlock>{},
+                                          Number<WoPerBlock * NPerBlock>{},
+                                          Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
+
+        constexpr auto c_k_wn_thread_mtx_desc =
+            make_ConstantMatrixDescriptor(Number<KPerThread>{},
+                                          Number<WoPerThread * NPerThread>{},
+                                          Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
+
+        const auto blockwise_batch_gemm =
+            BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
+                BlockSize,
+                decltype(a_c_k_block_mtx_desc),
+                decltype(b_c_wn_block_mtx_desc),
+                decltype(c_k_wn_thread_mtx_desc),
+                0,
+                in_c_h_w_n_block_desc.GetStride(I1),
+                out_k_h_w_n_thread_desc.GetStride(I1),
+                HoPerBlock,
+                GemmMPerThreadSubC,
+                GemmNPerThreadSubC,
+                GemmMLevel0Cluster,
+                GemmNLevel0Cluster,
+                GemmMLevel1Cluster,
+                GemmNLevel1Cluster,
+                GemmKPerThreadLoop,
+                HoPerThread,
+                GemmDataPerReadA,
+                GemmDataPerReadB>{};
+
+        // choose GEMM implementation here
+        const auto run_blockwise_batch_gemm = [&](auto... Xs) {
+#if 0
+            return blockwise_batch_gemm.Run(Xs...);
+#elif 0
+            return blockwise_batch_gemm.Run_asm(Xs...);
+#else
+            return blockwise_batch_gemm.Run_asm_v2(Xs...);
+#endif
+        };
+
+        // LDS: be careful of alignment
+        constexpr index_t in_block_space =
+            in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
+        constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});
+
+        // LDS double buffer
+        __shared__ Float p_in_block_double[2 * in_block_space];
+        __shared__ Float p_wei_block_double[2 * wei_block_space];
+
+        // register
+        // a C++ lambda cannot capture an array, so use a pointer instead
+        Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
+        Float* const p_out_thread = p_out_thread_data;
+
+#if 0
+        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+        {
+            print_ConstantTensorDescriptor(in_n_c_h_w_global_desc, "in_n_c_h_w_global_desc");
+            print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
+
+            print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
+            print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc");
+
+            printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
+        }
+#endif
+
+        // set threadwise output tensor to 0
+        threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
+
+        for(index_t y = 0; y < Y; ++y)
+        {
+            for(index_t x = 0; x < X; ++x)
+            {
+                const Float* p_in_global_block_offset =
+                    p_in_global +
+                    in_n_c_h_w_global_desc.Get1dIndex(
+                        n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x);
+
+                const Float* p_wei_global_block_offset =
+                    p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, k_block_data_begin);
+
+                // LDS double buffer: preload data into LDS
+                {
+                    Float p_in_register_clipboard[blockwise_in_copy_reorder
+                                                      .GetRegisterClipboardSize()];
+                    Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
+
+                    blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
+                                                                       p_in_register_clipboard);
+                    blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
+                                                                p_wei_register_clipboard);
+
+                    blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard,
+                                                                        p_in_block_double);
+                    blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
+                                                                 p_wei_block_double);
+                }
+
+                // LDS double buffer: main body
+                for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
+                    c_block_data_begin += 2 * CPerBlock)
+                {
+#pragma unroll
+                    for(index_t iloop = 0; iloop < 2; ++iloop)
+                    {
+                        const bool even_loop = (iloop % 2 == 0);
+
+                        Float* p_in_block_now =
+                            even_loop ? p_in_block_double : p_in_block_double + in_block_space;
+                        Float* p_wei_block_now =
+                            even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
+
+                        Float* p_in_block_next =
+                            even_loop ? p_in_block_double + in_block_space : p_in_block_double;
+                        Float* p_wei_block_next =
+                            even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
+
+                        Float p_in_register_clipboard[blockwise_in_copy_reorder
+                                                          .GetRegisterClipboardSize()];
+                        Float
+                            p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
+
+                        p_in_global_block_offset +=
+                            CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
+                        p_wei_global_block_offset +=
+                            CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
+
+                        __syncthreads();
+
+                        // LDS double buffer: load next data from device mem
+                        blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
+                                                                           p_in_register_clipboard);
+                        blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
+                                                                    p_wei_register_clipboard);
+
+                        // LDS double buffer: GEMM on current data
+                        run_blockwise_batch_gemm(p_wei_block_now, p_in_block_now, p_out_thread);
+
+                        // LDS double buffer: store next data to LDS
+                        blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard,
+                                                                            p_in_block_next);
+                        blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
+                                                                     p_wei_block_next);
+                    }
+                }
+
+                // LDS double buffer: tail
+                {
+                    Float p_in_register_clipboard[blockwise_in_copy_reorder
+                                                      .GetRegisterClipboardSize()];
+                    Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
+
+                    // even iteration
+                    p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
+                    p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
+
+                    __syncthreads();
+
+                    // LDS double buffer: load next data from device mem
+                    blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
+                                                                       p_in_register_clipboard);
+                    blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
+                                                                p_wei_register_clipboard);
+
+                    // LDS double buffer: GEMM on current data
+                    run_blockwise_batch_gemm(p_wei_block_double, p_in_block_double, p_out_thread);
+
+                    // LDS double buffer: store next data to LDS
+                    blockwise_in_copy_reorder.RunStoreRegisterClipboard(
+                        p_in_register_clipboard, p_in_block_double + in_block_space);
+                    blockwise_wei_copy.RunStoreRegisterClipboard(
+                        p_wei_register_clipboard, p_wei_block_double + wei_block_space);
+
+                    // odd iteration
+                    __syncthreads();
+
+                    // LDS double buffer: GEMM on current data
+                    run_blockwise_batch_gemm(p_wei_block_double + wei_block_space,
+                                             p_in_block_double + in_block_space,
+                                             p_out_thread);
+                }
+            }
+        }
+
+        // output: copy from register to global memory
+        const auto c_thread_mtx_begin =
+            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+
+        const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
+        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
+        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
+        const index_t n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;
+
+        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) {
+            // f_dummy does nothing but perfect forwarding. Using this trick to make this lambda
+            // a generic lambda, so it won't be compiled until instantiated
+            static_assert(
+                (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
+                "wrong!");
+
+            // output is a 10d tensor
+            constexpr index_t N2 = GemmNPerThreadSubC;
+            constexpr index_t N1 = NPerBlock / N2;
+
+            constexpr index_t W2 =
+                (GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
+            constexpr index_t W1 = WoPerBlock / W2;
+
+            constexpr index_t K2 = GemmMPerThreadSubC;
+            constexpr index_t K1 = KPerBlock / KPerThread;
+
+            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
+                Sequence<N / (N1 * N2), N1, N2, K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2>{});
+
+            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
+                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
+
+#if 0
+            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+            {
+                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
+                                               "out_k_h_w_n_thread_desc");
+                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
+
+                print_ConstantTensorDescriptor(out_n_k_h_w_global_desc,
+                                               "out_n_k_h_w_global_desc");
+                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
+            }
+#endif
+
+            constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{};
+
+            threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
+                out_10d_thread_desc,
+                p_out_thread,
+                out_10d_global_desc,
+                p_out_global +
+                    out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
+                                                       k_block_data_begin + k_thread_data_begin,
+                                                       ho_block_data_begin + ho_thread_data_begin,
+                                                       wo_block_data_begin + wo_thread_data_begin),
+                out_10d_thread_desc.GetLengths(),
+                map_out_global2thread);
+            // Number<OutThreadCopyDataPerWrite_W>{});
+        }).else_([&](auto f_dummy) {
+            static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
+                              GemmNPerThreadSubC % NPerThread == 0,
+                          "wrong!");
+
+            // output is a 10d tensor
+            constexpr index_t N1 = NPerBlock;
+
+            constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
+            constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
+            constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
+
+            constexpr index_t K2 = GemmMPerThreadSubC;
+            constexpr index_t K1 = KPerBlock / KPerThread;
+
+            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
+                Sequence<N / N1, N1, K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3>{});
+
+            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
+                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
+
+#if 0
+            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+            {
+                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
+                                               "out_k_h_w_n_thread_desc");
+                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
+
+                print_ConstantTensorDescriptor(out_n_k_h_w_global_desc,
+                                               "out_n_k_h_w_global_desc");
+                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
+            }
+#endif
+
+            constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};
+
+            threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
+                out_10d_thread_desc,
+                p_out_thread,
+                out_10d_global_desc,
+                p_out_global +
+                    out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
+                                                       k_block_data_begin + k_thread_data_begin,
+                                                       ho_block_data_begin + ho_thread_data_begin,
+                                                       wo_block_data_begin + wo_thread_data_begin),
+                out_10d_thread_desc.GetLengths(),
+                map_out_global2thread);
+            // Number<OutThreadCopyDataPerWrite_W>{});
+        });
+    }
+};
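
The two `static_if` branches in the new file fold the 4-d output into a 10-d view so that a single reorder-copy maps each thread's register tile onto global memory. A worked check of the case-1 constants, using the values of the `WoPerBlock = 4` tuning block above (the asserts replay the kernel's arithmetic; the dimension ordering is my reading of the code):

// Values from the "WoPerBlock = 4" tuning block.
constexpr int NPerBlock = 8, WoPerBlock = 4, KPerBlock = 128;
constexpr int NPerThread = 4, WoPerThread = 2, KPerThread = 8;
constexpr int GemmMPerThreadSubC = 4, GemmNPerThreadSubC = 4;
constexpr int GemmNLevel0Cluster = 2, GemmNLevel1Cluster = 2;

// Case 1 applies: GemmNPerThreadSubC (4) <= NPerBlock (8).
constexpr int N2 = GemmNPerThreadSubC;     // n's a thread owns contiguously
constexpr int N1 = NPerBlock / N2;         // = 2, n-clusters per block
constexpr int W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) / N1; // = 2, wo's per cluster pass
constexpr int W1 = WoPerBlock / W2;        // = 2, per-thread wo repeats
constexpr int K2 = GemmMPerThreadSubC;     // k's a thread owns contiguously
constexpr int K1 = KPerBlock / KPerThread; // = 16, k thread clusters

static_assert(N1 * N2 == NPerBlock && W1 * W2 == WoPerBlock, "columns tile the block");
static_assert((KPerThread / K2) * K1 * K2 == KPerBlock, "rows tile the block");
static_assert(W1 == WoPerThread && N2 == NPerThread, "per-thread tile matches the config");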
diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp
index 2a85725a50..ffda830d67 100644
--- a/src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp
+++ b/src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp
@@ -193,7 +193,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
         // choose GEMM implementation here
         const auto run_blockwise_batch_gemm = [&](auto... Xs) {
-#if 1
+#if 0
             return blockwise_batch_gemm.Run(Xs...);
 #elif 0
             return blockwise_batch_gemm.Run_asm(Xs...);
@@ -340,39 +340,40 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
         const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
         const index_t n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;
 
-        static_if<GemmNPerThreadSubC <= NPerBlock>{}(
-            [&](auto f_dummy) { // f_dummy does nothing but perfect forwarding. Using this trick to
-                                // make this lambda a generic lambda, so it won't be compiled until
-                                // instantiated
-                static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
-                               NPerBlock % GemmNPerThreadSubC == 0),
-                              "wrong!");
+        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) {
+            // f_dummy does nothing but perfect forwarding. Using this trick to make this lambda
+            // a generic lambda, so it won't be compiled until instantiated
+            static_assert(
+                (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
+                "wrong!");
 
-                // output is a 10d tensor
-                constexpr index_t N2 = GemmNPerThreadSubC;
-                constexpr index_t N1 = NPerBlock / N2;
+            // output is a 10d tensor
+            constexpr index_t N2 = GemmNPerThreadSubC;
+            constexpr index_t N1 = NPerBlock / N2;
 
-                constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
-                                       f_dummy(NPerBlock / GemmNPerThreadSubC);
-                constexpr index_t W1 = WoPerBlock / W2;
+            constexpr index_t W2 =
+                (GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
+            constexpr index_t W1 = WoPerBlock / W2;
 
-                constexpr index_t K2 = GemmMPerThreadSubC;
-                constexpr index_t K1 = KPerBlock / KPerThread;
+            constexpr index_t K2 = GemmMPerThreadSubC;
+            constexpr index_t K1 = KPerBlock / KPerThread;
 
-                constexpr auto out_10d_global_desc =
-                    make_ConstantTensorDescriptor(Sequence<N / (N1 * N2), N1, N2, K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2>{});
+            constexpr auto out_10d_global_desc =
+                make_ConstantTensorDescriptor(Sequence<N / (N1 * N2), N1, N2, K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2>{});
 
-                constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
-                    Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
+            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
+                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
 
 #if 0
         if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
@@ -387,51 +388,40 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
         }
 #endif
 
-                constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{};
+            constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{};
 
-                threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
-                    out_10d_thread_desc,
-                    p_out_thread,
-                    out_10d_global_desc,
-                    p_out_global +
-                        out_n_k_h_w_global_desc.Get1dIndex(
-                            n_block_data_begin + n_thread_data_begin,
-                            k_block_data_begin + k_thread_data_begin,
-                            ho_block_data_begin + ho_thread_data_begin,
-                            wo_block_data_begin + wo_thread_data_begin),
-                    out_10d_thread_desc.GetLengths(),
-                    map_out_global2thread);
-                // Number<OutThreadCopyDataPerWrite_W>{});
-            })
-            .else_([&](auto f_dummy) {
-                static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
-                                  GemmNPerThreadSubC % NPerThread == 0,
-                              "wrong!");
+            threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
+                out_10d_thread_desc,
+                p_out_thread,
+                out_10d_global_desc,
+                p_out_global +
+                    out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
+                                                       k_block_data_begin + k_thread_data_begin,
+                                                       ho_block_data_begin + ho_thread_data_begin,
+                                                       wo_block_data_begin + wo_thread_data_begin),
+                out_10d_thread_desc.GetLengths(),
+                map_out_global2thread);
+            // Number<OutThreadCopyDataPerWrite_W>{});
+        }).else_([&](auto f_dummy) {
+            static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
+                              GemmNPerThreadSubC % NPerThread == 0,
+                          "wrong!");
 
-                // output is a 10d tensor
-                constexpr index_t N1 = NPerBlock;
+            // output is a 10d tensor
+            constexpr index_t N1 = NPerBlock;
 
-                constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
-                constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
-                constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
+            constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
+            constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
+            constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
 
-                constexpr index_t K2 = GemmMPerThreadSubC;
-                constexpr index_t K1 = KPerBlock / KPerThread;
+            constexpr index_t K2 = GemmMPerThreadSubC;
+            constexpr index_t K1 = KPerBlock / KPerThread;
 
-                constexpr auto out_10d_global_desc =
-                    make_ConstantTensorDescriptor(Sequence<N / N1, N1, K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3>{});
+            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
+                Sequence<N / N1, N1, K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3>{});
 
-                constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
-                    Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
+            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
+                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
 
 #if 0
         if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
@@ -447,21 +437,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
         }
 #endif
 
-                constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};
+            constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};
 
-                threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
-                    out_10d_thread_desc,
-                    p_out_thread,
-                    out_10d_global_desc,
-                    p_out_global +
-                        out_n_k_h_w_global_desc.Get1dIndex(
-                            n_block_data_begin + n_thread_data_begin,
-                            k_block_data_begin + k_thread_data_begin,
-                            ho_block_data_begin + ho_thread_data_begin,
-                            wo_block_data_begin + wo_thread_data_begin),
-                    out_10d_thread_desc.GetLengths(),
-                    map_out_global2thread);
-                // Number<OutThreadCopyDataPerWrite_W>{});
-            });
+            threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
+                out_10d_thread_desc,
+                p_out_thread,
+                out_10d_global_desc,
+                p_out_global +
+                    out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
+                                                       k_block_data_begin + k_thread_data_begin,
+                                                       ho_block_data_begin + ho_thread_data_begin,
+                                                       wo_block_data_begin + wo_thread_data_begin),
+                out_10d_thread_desc.GetLengths(),
+                map_out_global2thread);
+            // Number<OutThreadCopyDataPerWrite_W>{});
+        });
     }
 };
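
A closing note on the `static_if`/`f_dummy` idiom both gridwise files use for the writeback dispatch. A self-contained sketch, simplified relative to the repo's `static_if`: the branch body is a generic lambda, so it is instantiated only if actually called, and wrapping values in `f_dummy(...)` keeps the `static_assert`s dependent, which is why the not-taken branch's assert never fires.

#include <utility>

// Perfect-forwarding pass-through handed to the branch body as f_dummy.
struct forwarder
{
    template <class T>
    constexpr T&& operator()(T&& x) const
    {
        return std::forward<T>(x);
    }
};

template <bool>
struct static_if;

template <>
struct static_if<true>
{
    template <class F>
    const static_if& operator()(F f) const
    {
        f(forwarder{}); // taken branch: instantiate and run the lambda body
        return *this;
    }

    template <class F>
    const static_if& else_(F) const { return *this; } // not taken: never instantiated
};

template <>
struct static_if<false>
{
    template <class F>
    const static_if& operator()(F) const { return *this; } // not taken: never instantiated

    template <class F>
    const static_if& else_(F f) const
    {
        f(forwarder{});
        return *this;
    }
};

// usage mirroring the kernels' writeback dispatch:
template <int GemmNPerThreadSubC, int NPerBlock>
void dispatch()
{
    static_if<(GemmNPerThreadSubC <= NPerBlock)>{}([&](auto f_dummy) {
        static_assert(f_dummy(GemmNPerThreadSubC) <= NPerBlock, "case 1 only");
    }).else_([&](auto f_dummy) {
        static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock, "case 2 only");
    });
}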