From 1cb98850580a51969cdd96e00c5b6d85299768b3 Mon Sep 17 00:00:00 2001
From: Chao Liu
Date: Sun, 17 Feb 2019 01:50:57 -0600
Subject: [PATCH] add another version of batch gemm

---
 ...icit_gemm_convolution_1_chwn_csrk_khwn.hpp |  56 +++-
 driver/driver.hip.cpp                         |   6 +-
 src/include/blockwise_2d_tensor_op.hip.hpp    |  90 ++---
 src/include/blockwise_4d_tensor_op.hip.hpp    | 172 ++++++++++
 src/include/blockwise_gemm.hip.hpp            | 308 +++++++++++++++++-
 src/include/common.hip.hpp                    |  17 +
 ..._gemm_convolution_1_chwn_csrk_khwn.hip.hpp | 107 ++++--
 ...2_chwn_csrk_khwn_lds_double_buffer.hip.hpp |  11 +-
 8 files changed, 657 insertions(+), 110 deletions(-)

diff --git a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp
index fb0ae4a8cd..3e92a157eb 100644
--- a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp
+++ b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp
@@ -75,6 +75,39 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
     out_khwn_device_buf.ToDevice(out_khwn.mData.data());
 
 #if 1
+    // for 3x3, 34x34, try
+    constexpr unsigned NPerBlock = 16;
+    constexpr unsigned KPerBlock = 64;
+    constexpr unsigned CPerBlock = 4;
+    constexpr unsigned HoPerBlock = 2;
+    constexpr unsigned WoPerBlock = 4;
+
+    constexpr unsigned NPerThread = 8;
+    constexpr unsigned KPerThread = 8;
+    constexpr unsigned HoPerThread = 1;
+    constexpr unsigned WoPerThread = 1;
+
+    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
+    constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
+
+    constexpr unsigned InBlockCopy_ThreadPerDimC = 4;
+    constexpr unsigned InBlockCopy_ThreadPerDimH = 4;
+    constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
+    constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
+    constexpr unsigned InBlockCopyDataPerRead = 4;
+
+    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+
+    constexpr unsigned GemmMPerThreadSubC = 4;
+    constexpr unsigned GemmNPerThreadSubC = 4;
+    constexpr unsigned GemmMLevel0Cluster = 4;
+    constexpr unsigned GemmNLevel0Cluster = 2;
+    constexpr unsigned GemmMLevel1Cluster = 2;
+    constexpr unsigned GemmNLevel1Cluster = 4;
+    constexpr unsigned GemmKPerThreadLoop = 1;
+
+    constexpr unsigned BlockSize = 128;
+#elif 0
     // for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
     constexpr unsigned NPerBlock = 16;
     constexpr unsigned KPerBlock = 64;
@@ -131,7 +164,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
     constexpr unsigned WeiBlockCopyDataPerRead = 4;
 
     constexpr unsigned BlockSize = 128;
-#elif 1
+#elif 0
     // for 7x7, 38x38
     constexpr unsigned NPerBlock = 8;
     constexpr unsigned KPerBlock = 64;
@@ -184,7 +217,12 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
     constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
     constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
 
-    constexpr unsigned InBlockCopyDataPerRead = 4; // not used, yet
+    constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
+    constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
+    constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
+    constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
+    constexpr unsigned InBlockCopyDataPerRead = 4;
+
     constexpr unsigned WeiBlockCopyDataPerRead = 4;
 
     constexpr unsigned BlockSize = 128;
@@ -212,13 +250,23 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
                                                         WoPerBlock,
                                                         NPerThread,
                                                         KPerThread,
-                                                        CPerThread,
                                                         HoPerThread,
                                                         WoPerThread,
                                                         WeiBlockCopyThreadPerDim0,
                                                         WeiBlockCopyThreadPerDim1,
+                                                        Sequence<InBlockCopy_ThreadPerDimC,
+                                                                 InBlockCopy_ThreadPerDimH,
+                                                                 InBlockCopy_ThreadPerDimW,
+                                                                 InBlockCopy_ThreadPerDimN>,
                                                         InBlockCopyDataPerRead,
-                                                        WeiBlockCopyDataPerRead>,
+                                                        WeiBlockCopyDataPerRead,
+                                                        GemmMPerThreadSubC,
+                                                        GemmNPerThreadSubC,
+                                                        GemmMLevel0Cluster,
+                                                        GemmNLevel0Cluster,
+                                                        GemmMLevel1Cluster,
+                                                        GemmNLevel1Cluster,
+                                                        GemmKPerThreadLoop>,
         dim3(GridSize),
         dim3(BlockSize),
        static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
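The new "3x3, 34x34" block is internally consistent with the constraints the V2 batch GEMM (added below in blockwise_gemm.hip.hpp) asserts at compile time. A host-side sketch of the arithmetic, for reference only; it reuses the constant names above and is not part of the patch:

    // threads per level-1 cluster: 4 * 2 * 2 * 4 = 64
    constexpr unsigned ThreadPerLevel1Cluster =
        GemmMLevel0Cluster * GemmNLevel0Cluster * GemmMLevel1Cluster * GemmNLevel1Cluster;
    // batch work along Ho: HoPerBlock / HoPerThread = 2
    static_assert(BlockSize == (HoPerBlock / HoPerThread) * ThreadPerLevel1Cluster,
                  "128 == 2 * 64");
    // GEMM M = KPerBlock: 4 (sub-tile) * 4 (level-0) * 2 (level-1) * 2 (repeat) = 64
    static_assert(KPerBlock == GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster *
                                   (KPerThread / GemmMPerThreadSubC),
                  "64 == 4 * 4 * 2 * 2");
    // GEMM N = WoPerBlock * NPerBlock: 4 (sub-tile) * 2 (level-0) * 4 (level-1) * 2 (repeat) = 64
    static_assert(WoPerBlock * NPerBlock ==
                      GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster *
                          ((WoPerThread * NPerThread) / GemmNPerThreadSubC),
                  "64 == 4 * 2 * 4 * 2");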
diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp
index d4a84e34e5..4fc78491bb 100644
--- a/driver/driver.hip.cpp
+++ b/driver/driver.hip.cpp
@@ -391,7 +391,7 @@ int main()
     constexpr unsigned HPad = 0;
     constexpr unsigned WPad = 0;
 
-#elif 0
+#elif 1
     // 3x3, 34x34
     constexpr unsigned N = 64;
     constexpr unsigned C = 256;
@@ -593,11 +593,11 @@ int main()
     device_implicit_gemm_convolution_1_nchw_kcsr_nkhw
 #elif 0
     device_implicit_gemm_convolution_1_nchw_srck_nkhw
-#elif 0
+#elif 1
    device_implicit_gemm_convolution_1_chwn_csrk_khwn
 #elif 0
     device_implicit_gemm_convolution_2_cnhw_csrk_knhw
-#elif 1
+#elif 0
     device_implicit_gemm_convolution_2_chwn_csrk_khwn
 #endif
     (in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat);
diff --git a/src/include/blockwise_2d_tensor_op.hip.hpp b/src/include/blockwise_2d_tensor_op.hip.hpp
index a90007e246..a178b5dade 100644
--- a/src/include/blockwise_2d_tensor_op.hip.hpp
+++ b/src/include/blockwise_2d_tensor_op.hip.hpp
@@ -453,8 +453,7 @@ struct Blockwise2dTensorCopy3
         constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
         constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
 
-        for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
-        {
+        auto f_copy = [&](unsigned iloop) {
             if(DataPerRead == 1)
             {
                 p_dst[mDstMyThreadOffset + iloop * dst_loop_stride] =
@@ -476,6 +475,11 @@ struct Blockwise2dTensorCopy3
             {
                 assert(false);
             }
+        };
+
+        for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
+        {
+            f_copy(iloop);
         }
 
         constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
@@ -486,29 +490,7 @@ struct Blockwise2dTensorCopy3
 
         if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
         {
-            if(DataPerRead == 1)
-            {
-                p_dst[mDstMyThreadOffset + nloop_d0 * dst_loop_stride] =
-                    p_src[mSrcMyThreadOffset + nloop_d0 * src_loop_stride];
-            }
-            else if(DataPerRead == 2)
-            {
-                *(reinterpret_cast<Float2*>(p_dst + mDstMyThreadOffset +
-                                            nloop_d0 * dst_loop_stride)) =
-                    *(reinterpret_cast<const Float2*>(p_src + mSrcMyThreadOffset +
-                                                      nloop_d0 * src_loop_stride));
-            }
-            else if(DataPerRead == 4)
-            {
-                *(reinterpret_cast<Float4*>(p_dst + mDstMyThreadOffset +
-                                            nloop_d0 * dst_loop_stride)) =
-                    *(reinterpret_cast<const Float4*>(p_src + mSrcMyThreadOffset +
-                                                      nloop_d0 * src_loop_stride));
-            }
-            else
-            {
-                assert(false);
-            }
+            f_copy(nloop_d0);
         }
     }
@@ -561,8 +543,7 @@ struct Blockwise2dTensorCopy3
         constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
         constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
 
-        for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
-        {
+        auto f_copy = [&](unsigned iloop) {
             if(DataPerRead == 1)
             {
                 p_clipboard[iloop] = p_src[mSrcMyThreadOffset + iloop * src_loop_stride];
@@ -583,6 +564,11 @@ struct Blockwise2dTensorCopy3
             {
                 assert(false);
             }
+        };
+
+        for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
+        {
+            f_copy(iloop);
         }
 
         constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
@@ -593,26 +579,7 @@ struct Blockwise2dTensorCopy3
 
         if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
         {
-            if(DataPerRead == 1)
-            {
-                p_clipboard[nloop_d0] = p_src[mSrcMyThreadOffset + nloop_d0 * src_loop_stride];
-            }
-            else if(DataPerRead == 2)
-            {
-                *(reinterpret_cast<Float2*>(p_clipboard + nloop_d0 * 2)) =
-                    *(reinterpret_cast<const Float2*>(p_src + mSrcMyThreadOffset +
-                                                      nloop_d0 * src_loop_stride));
-            }
-            else if(DataPerRead == 4)
-            {
-                *(reinterpret_cast<Float4*>(p_clipboard + nloop_d0 * 4)) =
-                    *(reinterpret_cast<const Float4*>(p_src + mSrcMyThreadOffset +
-                                                      nloop_d0 * src_loop_stride));
-            }
-            else
-            {
-                assert(false);
-            }
+            f_copy(nloop_d0);
         }
     }
@@ -649,8 +616,7 @@ struct Blockwise2dTensorCopy3
         constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
         constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
 
-        for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
-        {
+        auto f_copy = [&](unsigned iloop) {
             if(DataPerRead == 1)
             {
                 p_dst[mDstMyThreadOffset + iloop * dst_loop_stride] = p_clipboard[iloop];
@@ -669,6 +635,11 @@ struct Blockwise2dTensorCopy3
             {
                 assert(false);
             }
+        };
+
+        for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
+        {
+            f_copy(iloop);
         }
 
         constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
@@ -679,26 +650,7 @@ struct Blockwise2dTensorCopy3
 
         if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
         {
-            if(DataPerRead == 1)
-            {
-                p_dst[mDstMyThreadOffset + nloop_d0 * dst_loop_stride] = p_clipboard[nloop_d0];
-            }
-            else if(DataPerRead == 2)
-            {
-                *(reinterpret_cast<Float2*>(p_dst + mDstMyThreadOffset +
-                                            nloop_d0 * dst_loop_stride)) =
-                    *(reinterpret_cast<Float2*>(p_clipboard + nloop_d0 * 2));
-            }
-            else if(DataPerRead == 4)
-            {
-                *(reinterpret_cast<Float4*>(p_dst + mDstMyThreadOffset +
-                                            nloop_d0 * dst_loop_stride)) =
-                    *(reinterpret_cast<Float4*>(p_clipboard + nloop_d0 * 4));
-            }
-            else
-            {
-                assert(false);
-            }
+            f_copy(nloop_d0);
         }
     }
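All three f_copy refactors above funnel the same vector-width dispatch through one lambda so the tail iteration can share the body instead of duplicating the if/else chain. Reduced to its essentials, the pattern is (standalone sketch, illustrative names):

    // One guarded copy step; DataPerRead in {1, 2, 4} selects scalar/float2/float4 moves.
    template <unsigned DataPerRead>
    __device__ void copy_step(float* p_dst, const float* p_src)
    {
        if(DataPerRead == 1)
            *p_dst = *p_src;
        else if(DataPerRead == 2)
            *(reinterpret_cast<float2*>(p_dst)) = *(reinterpret_cast<const float2*>(p_src));
        else if(DataPerRead == 4)
            *(reinterpret_cast<float4*>(p_dst)) = *(reinterpret_cast<const float4*>(p_src));
        else
            assert(false);
    }

Calling it once per full loop iteration and once more as f_copy(nloop_d0) for the tail keeps the vectorized and scalar paths in a single place.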
diff --git a/src/include/blockwise_4d_tensor_op.hip.hpp b/src/include/blockwise_4d_tensor_op.hip.hpp
index b81063fed5..1b2f5e5d15 100644
--- a/src/include/blockwise_4d_tensor_op.hip.hpp
+++ b/src/include/blockwise_4d_tensor_op.hip.hpp
@@ -337,3 +337,175 @@ struct BlockwiseChwnTensorCopyPadded
         }
     }
 };
+
+// starting point needs to be aligned to float4, float2 or float
+// stride3 needs to be 1 for both source and destination
+template <unsigned BlockSize,
+          class Float,
+          class SrcDesc,
+          class DstDesc,
+          class CopyLengths,
+          class ThreadPerDims,
+          unsigned DataPerRead>
+struct Blockwise4dTensorCopy3
+{
+    unsigned mSrcMyThreadOffset;
+    unsigned mDstMyThreadOffset;
+
+    __device__ Blockwise4dTensorCopy3()
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+
+        static_assert(SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1,
+                      "wrong! only support stride3 == 1!\n");
+
+        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
+                      "wrong! only support DataPerRead == 1, 2 or 4!\n");
+
+        static_assert(SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
+                          DstDesc{}.GetStride(I2) % DataPerRead == 0,
+                      "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");
+
+        constexpr unsigned L0 = CopyLengths{}.Get(I0);
+        constexpr unsigned L1 = CopyLengths{}.Get(I1);
+        constexpr unsigned L2 = CopyLengths{}.Get(I2);
+        constexpr unsigned L3 = CopyLengths{}.Get(I3);
+
+        constexpr unsigned thread_per_d0 = ThreadPerDims{}.Get(I0);
+        constexpr unsigned thread_per_d1 = ThreadPerDims{}.Get(I1);
+        constexpr unsigned thread_per_d2 = ThreadPerDims{}.Get(I2);
+        constexpr unsigned thread_per_d3 = ThreadPerDims{}.Get(I3);
+
+        // we allow out-of-bound read from src in D3 dimension,
+        // but we need to make sure dst stride is big enough,
+        // so that the out-of-bound write won't contaminate next line in dst
+        constexpr unsigned nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
+
+        static_assert(nloop_d3 * thread_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
+                      "wrong! out-of-bound write will contaminate next line!\n");
+
+        static_assert(L0 % thread_per_d0 == 0 && L1 % thread_per_d1 == 0 &&
+                          L2 % thread_per_d2 == 0,
+                      "wrong! L0, L1, L2 should be divided evenly!\n");
+
+        static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3,
+                      "wrong! BlockSize is not big enough for ThreadPerDims!");
+
+        constexpr unsigned num_active_thread =
+            thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
+
+        if(BlockSize > num_active_thread)
+        {
+            if(get_thread_local_1d_id() >= num_active_thread)
+            {
+                return;
+            }
+        }
+
+        const unsigned thread_id_d0 =
+            get_thread_local_1d_id() / (thread_per_d1 * thread_per_d2 * thread_per_d3);
+        unsigned itmp = get_thread_local_1d_id() -
+                        thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
+        const unsigned thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
+        itmp -= thread_id_d1 * (thread_per_d2 * thread_per_d3);
+        const unsigned thread_id_d2 = itmp / thread_per_d3;
+        const unsigned thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;
+
+        mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(
+            thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
+        mDstMyThreadOffset = DstDesc{}.Get1dIndex(
+            thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
+    }
+
+    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
+    {
+        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
+
+        using Float2 = float2;
+        using Float4 = float4;
+
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+
+        constexpr unsigned L0 = CopyLengths{}.Get(I0);
+        constexpr unsigned L1 = CopyLengths{}.Get(I1);
+        constexpr unsigned L2 = CopyLengths{}.Get(I2);
+        constexpr unsigned L3 = CopyLengths{}.Get(I3);
+
+        constexpr unsigned thread_per_d0 = ThreadPerDims{}.Get(I0);
+        constexpr unsigned thread_per_d1 = ThreadPerDims{}.Get(I1);
+        constexpr unsigned thread_per_d2 = ThreadPerDims{}.Get(I2);
+        constexpr unsigned thread_per_d3 = ThreadPerDims{}.Get(I3);
+
+        constexpr unsigned num_active_thread =
+            thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
+
+        if(BlockSize > num_active_thread)
+        {
+            if(get_thread_local_1d_id() >= num_active_thread)
+            {
+                return;
+            }
+        }
+
+        constexpr unsigned nloop_d0 = L0 / thread_per_d0;
+        constexpr unsigned nloop_d1 = L1 / thread_per_d1;
+        constexpr unsigned nloop_d2 = L2 / thread_per_d2;
+        constexpr unsigned nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
+
+#pragma unroll
+        for(unsigned iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
+        {
+#pragma unroll
+            for(unsigned iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
+            {
+#pragma unroll
+                for(unsigned iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
+                {
+#pragma unroll
+                    for(unsigned iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
+                    {
+                        const unsigned src_offset =
+                            SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
+                                                 iloop_d1 * thread_per_d1,
+                                                 iloop_d2 * thread_per_d2,
+                                                 iloop_d3 * thread_per_d3 * DataPerRead);
+
+                        const unsigned dst_offset =
+                            DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
+                                                 iloop_d1 * thread_per_d1,
+                                                 iloop_d2 * thread_per_d2,
+                                                 iloop_d3 * thread_per_d3 * DataPerRead);
+
+                        if(DataPerRead == 1)
+                        {
+                            p_dst[dst_offset + mDstMyThreadOffset] =
+                                p_src[src_offset + mSrcMyThreadOffset];
+                        }
+                        else if(DataPerRead == 2)
+                        {
+                            *(reinterpret_cast<Float2*>(p_dst + dst_offset + mDstMyThreadOffset)) =
+                                *(reinterpret_cast<const Float2*>(p_src + src_offset +
+                                                                  mSrcMyThreadOffset));
+                        }
+                        else if(DataPerRead == 4)
+                        {
+                            *(reinterpret_cast<Float4*>(p_dst + dst_offset + mDstMyThreadOffset)) =
+                                *(reinterpret_cast<const Float4*>(p_src + src_offset +
+                                                                  mSrcMyThreadOffset));
+                        }
+                        else
+                        {
+                            assert(false);
+                        }
+                    }
+                }
+            }
+        }
+    }
+};
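Blockwise4dTensorCopy3 assigns threads to the copy volume in row-major order over ThreadPerDims. A worked example with ThreadPerDims = Sequence<4, 4, 2, 4> (the C/H/W/N split in the first driver config) for thread id 57:

    // thread_per_d1 * thread_per_d2 * thread_per_d3 = 4 * 2 * 4 = 32
    thread_id_d0 = 57 / 32 = 1;        itmp = 57 - 1 * 32 = 25
    thread_id_d1 = 25 / (2 * 4) = 3;   itmp = 25 - 3 * 8 = 1
    thread_id_d2 = 1 / 4 = 0
    thread_id_d3 = 1 - 0 * 4 = 1

so thread 57 starts at coordinate (1, 3, 0, 1 * DataPerRead); the d3 index is scaled by DataPerRead because each thread moves a whole float2/float4 per read.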
diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp
index 2df5caec82..46773388ed 100644
--- a/src/include/blockwise_gemm.hip.hpp
+++ b/src/include/blockwise_gemm.hip.hpp
@@ -116,6 +116,13 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
         }
     }
 
+    // this should be optimized away if input is known
+    __device__ static MatrixIndex
+    GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c)
+    {
+        return MatrixIndex{batch_in_c, m_in_c, n_in_c};
+    }
+
     template <class FloatA, class FloatB, class FloatC, class Accumulator>
     __device__ void Run(const FloatA* __restrict__ p_a_block,
                         const FloatB* __restrict__ p_b_block,
@@ -219,6 +226,306 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
     }
 };
 
+template <unsigned BlockSize,
+          class BlockMatrixA,
+          class BlockMatrixB,
+          class ThreadMatrixC,
+          unsigned BlockMatrixStrideA,
+          unsigned BlockMatrixStrideB,
+          unsigned ThreadMatrixStrideC,
+          unsigned BatchSize,
+          unsigned MPerThreadSubC,
+          unsigned NPerThreadSubC,
+          unsigned MLevel0Cluster,
+          unsigned NLevel0Cluster,
+          unsigned MLevel1Cluster,
+          unsigned NLevel1Cluster,
+          unsigned KPerThreadLoop,
+          unsigned BatchPerThread>
+struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
+{
+    unsigned mMyThreadOffsetA = 0;
+    unsigned mMyThreadOffsetB = 0;
+
+    struct MatrixIndex
+    {
+        unsigned batch;
+        unsigned row;
+        unsigned col;
+    };
+
+    __device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
+    {
+        static_assert(BatchSize % BatchPerThread == 0,
+                      "wrong! BatchSize is not divisible by BatchPerThread");
+
+        constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread;
+
+        constexpr unsigned ThreadPerLevel1Cluster =
+            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
+
+        static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
+                      "wrong! wrong blocksize\n");
+
+        constexpr auto a_block_mtx  = BlockMatrixA{};
+        constexpr auto b_block_mtx  = BlockMatrixB{};
+        constexpr auto c_thread_mtx = ThreadMatrixC{};
+
+        static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
+                      "wrong! K dimension not consistent\n");
+
+        constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
+        constexpr unsigned N = b_block_mtx.NCol();
+        constexpr unsigned K = a_block_mtx.NRow();
+
+        constexpr unsigned MPerThread = c_thread_mtx.NRow();
+        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+
+        static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
+                      "wrong! Cannot evenly divide thread work among repeat\n");
+
+        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
+        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+
+        static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
+                      "wrong! Cannot evenly divide work among repeat\n");
+
+        constexpr unsigned MPerLevel1Cluster = M / MRepeat;
+        constexpr unsigned NPerLevel1Cluster = N / NRepeat;
+
+        static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
+                          (NPerLevel1Cluster % NLevel1Cluster == 0),
+                      "wrong! Cannot evenly divide work among Level1Cluster\n");
+
+        constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
+        constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
+
+        static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
+                          (NPerLevel0Cluster % NLevel0Cluster == 0),
+                      "wrong! Cannot evenly divide work among Level0Cluster\n");
+
+        static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
+                          (NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
+                      "wrong! thread work size is wrong\n");
+
+        const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+
+        mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
+                           a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row);
+
+        mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
+                           b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);
+
+#if 0
+        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+        {
+            print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
+            print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
+            print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");
+
+            printf("%u %u, %u %u %u, %u %u\n",
+                   get_block_1d_id(),
+                   get_thread_local_1d_id(),
+                   c_thread_mtx_index.batch,
+                   c_thread_mtx_index.row,
+                   c_thread_mtx_index.col,
+                   mMyThreadOffsetA,
+                   mMyThreadOffsetB);
+        }
+#endif
+    }
+
+    __device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
+    {
+        constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread;
+
+        constexpr unsigned ThreadPerLevel1Cluster =
+            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
+
+        constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
+
+        unsigned batch_work_id = thread_id / ThreadPerLevel1Cluster;
+        unsigned cluster_id    = thread_id - batch_work_id * ThreadPerLevel1Cluster;
+
+        unsigned level1_id   = cluster_id / ThreadPerLevel0Cluster;
+        unsigned level1_m_id = level1_id / NLevel1Cluster;
+        unsigned level1_n_id = level1_id % NLevel1Cluster;
+
+        unsigned level0_id   = cluster_id % ThreadPerLevel0Cluster;
+        unsigned level0_m_id = level0_id / NLevel0Cluster;
+        unsigned level0_n_id = level0_id % NLevel0Cluster;
+
+        constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
+        constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
+
+        return MatrixIndex{batch_work_id * BatchPerThread,
+                           level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
+                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
+    }
+
+    // this should be optimized away if input is known
+    __device__ static MatrixIndex
+    GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c)
+    {
+        constexpr auto c_thread_mtx = ThreadMatrixC{};
+
+        constexpr unsigned MPerThread = c_thread_mtx.NRow();
+        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+
+        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
+        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+
+        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
+        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+
+        unsigned m_repeat = m_in_c / MPerThreadSubC;
+        unsigned n_repeat = n_in_c / NPerThreadSubC;
+
+        unsigned m_in_sub_c = m_in_c % MPerThreadSubC;
+        unsigned n_in_sub_c = n_in_c % NPerThreadSubC;
+
+        return MatrixIndex{batch_in_c,
+                           m_repeat * MPerLevel1Cluster + m_in_sub_c,
+                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
+    }
+    template <class FloatA, class FloatB, class FloatC, class Accumulator>
+    __device__ void Run(const FloatA* __restrict__ p_a_block,
+                        const FloatB* __restrict__ p_b_block,
+                        FloatC* __restrict__ p_c_thread,
+                        Accumulator f_accum) const
+    {
+        constexpr auto True  = integral_constant<bool, true>{};
+        constexpr auto False = integral_constant<bool, false>{};
+
+        constexpr auto a_block_mtx  = BlockMatrixA{};
+        constexpr auto b_block_mtx  = BlockMatrixB{};
+        constexpr auto c_thread_mtx = ThreadMatrixC{};
+
+        constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
+
+        constexpr unsigned MPerThread = c_thread_mtx.NRow();
+        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+
+        // thread A, B for GEMM
+        // A is transposed, B is not
+        constexpr auto a_thread_mtx =
+            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
+
+        constexpr auto b_thread_mtx =
+            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
+
+        // thread A-sub, B-sub for copy
+        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
+            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
+
+        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
+            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
+
+        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
+        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
+
+        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
+        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+
+        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
+        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+
+        // loop over k
+#pragma unroll
+        for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
+        {
+            // read first batch of A, B
+            // copy A-sub to form A
+#pragma unroll
+            for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
+            {
+                threadwise_matrix_copy(
+                    a_block_mtx,
+                    p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
+                        mMyThreadOffsetA,
+                    a_thread_mtx,
+                    p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
+                    a_thread_sub_mtx.GetLengths());
+            }
+
+            // copy B-sub to form B
+#pragma unroll
+            for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+            {
+                threadwise_matrix_copy(
+                    b_block_mtx,
+                    p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
+                        mMyThreadOffsetB,
+                    b_thread_mtx,
+                    p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
+                    b_thread_sub_mtx.GetLengths());
+            }
+
+            // loop over batch
+#pragma unroll
+            for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
+            {
+                // do current batch of gemm
+                threadwise_gemm(a_thread_mtx,
+                                True,
+                                p_a_thread,
+                                b_thread_mtx,
+                                False,
+                                p_b_thread,
+                                c_thread_mtx,
+                                False,
+                                p_c_thread + ib * ThreadMatrixStrideC,
+                                f_accum);
+
+                // read next batch of A, B
+                if(BlockMatrixStrideA != 0)
+                {
+#pragma unroll
+                    for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
+                    {
+                        threadwise_matrix_copy(
+                            a_block_mtx,
+                            p_a_block +
+                                a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
+                                (ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA,
+                            a_thread_mtx,
+                            p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
+                            a_thread_sub_mtx.GetLengths());
+                    }
+                }
+
+                if(BlockMatrixStrideB != 0)
+                {
+#pragma unroll
+                    for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+                    {
+                        threadwise_matrix_copy(
+                            b_block_mtx,
+                            p_b_block +
+                                b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
+                                (ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB,
+                            b_thread_mtx,
+                            p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
+                            b_thread_sub_mtx.GetLengths());
+                    }
+                }
+            }
+
+            // do last batch of gemm
+            threadwise_gemm(a_thread_mtx,
+                            True,
+                            p_a_thread,
+                            b_thread_mtx,
+                            False,
+                            p_b_thread,
+                            c_thread_mtx,
+                            False,
+                            p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC,
+                            f_accum);
+        }
+    }
+};
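GetBeginOfThreadMatrixC maps a flat thread id to the (batch, row, col) origin of that thread's C sub-tiles. A worked example with the first driver config (MPerThreadSubC = NPerThreadSubC = 4, MLevel0Cluster = 4, NLevel0Cluster = 2, MLevel1Cluster = 2, NLevel1Cluster = 4, BatchPerThread = 1) for thread id 100:

    batch_work_id = 100 / 64 = 1,  cluster_id = 36            // ThreadPerLevel1Cluster = 64
    level1_id = 36 / 8 = 4  ->  level1_m_id = 1, level1_n_id = 0   // ThreadPerLevel0Cluster = 8
    level0_id = 36 % 8 = 4  ->  level0_m_id = 2, level0_n_id = 0
    row = 1 * 16 + 2 * 4 = 24                                 // MPerLevel0Cluster = 16
    col = 0 * 8  + 0 * 4 = 0                                  // NPerLevel0Cluster = 8

Each thread then owns MRepeat x NRepeat sub-tiles of MPerThreadSubC x NPerThreadSubC elements, with the repeats spaced MPerLevel1Cluster/NPerLevel1Cluster apart; that spacing is exactly what GetDistanceFromBeginOfThreadMatrixC reproduces.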
diff --git a/src/include/common.hip.hpp b/src/include/common.hip.hpp
--- a/src/include/common.hip.hpp
+++ b/src/include/common.hip.hpp
@@ ... @@
+template <class T>
+__host__ __device__ constexpr T max(T a, T b)
+{
+    return a > b ? a : b;
+}
+
+template <class T>
+__host__ __device__ constexpr T min(T a, T b)
+{
+    return a < b ? a : b;
+}
+
+__host__ __device__ constexpr unsigned integer_divide_ceil(unsigned a, unsigned b)
+{
+    return (a + b - 1) / b;
+}
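integer_divide_ceil is what lets nloop_d3 round up and over-read in the D3 dimension (guarded by the dst-stride static_assert in Blockwise4dTensorCopy3). For instance (illustrative checks):

    static_assert(integer_divide_ceil(34, 8) == 5, "");  // 5 passes of 8 cover 40 >= 34
    static_assert(integer_divide_ceil(64, 8) == 8, "");  // exact division adds no extra pass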
a : b; +} + +__host__ __device__ constexpr unsigned integer_divide_ceil(unsigned a, unsigned b) +{ + return (a + b - 1) / b; +} diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp index 84a414147f..edde97b893 100644 --- a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp @@ -20,13 +20,20 @@ template + unsigned WeiBlockCopyDataPerRead, + unsigned GemmMPerThreadSubC, + unsigned GemmNPerThreadSubC, + unsigned GemmMLevel0Cluster, + unsigned GemmNLevel0Cluster, + unsigned GemmMLevel1Cluster, + unsigned GemmNLevel1Cluster, + unsigned GemmKPerThreadLoop> __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restrict__ p_in_global, const Float* const __restrict__ p_wei_global, @@ -114,12 +121,22 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric // blockwise copy // input: format is [C, Hi, Wi, N] +#if 0 constexpr auto blockwise_in_copy = Blockwise4dTensorCopy1{}; +#elif 1 + const auto blockwise_in_copy = Blockwise4dTensorCopy3{}; +#endif // blockwise wei copy // format is [CPerBlock*S*R,KPerBlock] @@ -131,7 +148,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric decltype(wei_ek_block_desc), decltype(wei_ek_block_desc.GetLengths())>{}; #elif 0 - const auto blockwise_wei_copy = Blockwise2dTensorCopy2{}, Number{}); +#if 0 const auto blockwise_batch_gemm = Blockwise1dStridedBatchedGemmBlockABlockBThreadC{}; +#else + const auto blockwise_batch_gemm = BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< + BlockSize, + decltype(a_cxk_block_mtx_desc), + decltype(b_cxwn_block_mtx_desc), + decltype(c_kxwn_thread_mtx_desc), + 0, + in_chwn_block_desc.GetStride(I1), + out_hkwn_thread_desc.GetStride(I0), + HoPerBlock, + GemmMPerThreadSubC, + GemmNPerThreadSubC, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmKPerThreadLoop, + HoPerThread>{}; +#endif // LDS: be careful of alignment constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace(); @@ -210,10 +247,10 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric p_wei_global_block_begin += CPerBlock * wei_csrk_global_desc.GetStride(I0), __syncthreads()) { - // input: global mem to LDS, + // input: global mem to LDS blockwise_in_copy.Run(p_in_global_block_begin, p_in_block); - // weight: global mem to LDS, + // weight: global mem to LDS blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block); __syncthreads(); @@ -223,34 +260,26 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric { for(unsigned r = 0; r < R; ++r) { - auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; - blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0), p_out_thread, - f_accum); + [](auto& acc, const auto&& v) { acc += v; }); } } } - const auto matrix_c_index = + const auto c_thread_mtx_begin = blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - const unsigned ho_thread_data_begin = matrix_c_index.batch; - const unsigned k_thread_data_begin = matrix_c_index.row; - const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock; - const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock; - -#if 0 - 
printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n", - get_block_1d_id(), get_thread_local_1d_id(), - ho_block_data_begin, k_block_data_begin, wo_block_data_begin, n_block_data_begin, - ho_thread_data_begin, k_thread_data_begin, wo_thread_data_begin, n_thread_data_begin, - p_out_thread[0]); -#endif - // output: register to global mem, // convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N] +#if 0 + // for v1 batch-gemm + const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch; + const unsigned k_thread_data_begin = c_thread_mtx_begin.row; + const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; + const unsigned n_thread_data_begin = c_thread_mtx_begin.col - wo_thread_data_begin * NPerBlock; + constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{}; threadwise_4d_tensor_copy_reorder_by_get_dst_from_src( @@ -263,4 +292,36 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric n_block_data_begin + n_thread_data_begin), out_hkwn_thread_desc.GetLengths(), reorder_khwn_from_hkwn); +#else + for(unsigned ho = 0; ho < out_hkwn_thread_desc.GetLength(I0); ++ho) + { + for(unsigned k = 0; k < out_hkwn_thread_desc.GetLength(I1); ++k) + { + for(unsigned wo = 0; wo < out_hkwn_thread_desc.GetLength(I2); ++wo) + { + for(unsigned n = 0; n < out_hkwn_thread_desc.GetLength(I3); ++n) + { + const unsigned b = out_hkwn_thread_desc.Get1dIndex(0, 0, wo, n); + + const auto c_thread_mtx_distance = + blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b); + + const unsigned ho_thread = + c_thread_mtx_begin.batch + c_thread_mtx_distance.batch; + const unsigned k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row; + const unsigned b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col; + + const unsigned wo_thread = b_thread / NPerBlock; + const unsigned n_thread = b_thread - NPerBlock * wo_thread; + + p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread, + ho_block_data_begin + ho_thread, + wo_block_data_begin + wo_thread, + n_block_data_begin + n_thread)] = + p_out_thread[out_hkwn_thread_desc.Get1dIndex(ho, k, wo, n)]; + } + } + } + } +#endif } diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp index 15ce27dddb..e7070ae978 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp @@ -259,16 +259,9 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b __syncthreads(); // load next data -#if 0 +#if 1 blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next); blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next); -#elif 0 - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next); - - Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; - - blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, - p_wei_register_clipboard); #elif 1 Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; @@ -300,8 +293,6 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b } #if 0 - blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next); -#elif 1 
diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp
index 15ce27dddb..e7070ae978 100644
--- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp
+++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp
@@ -259,16 +259,9 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b
             __syncthreads();
 
             // load next data
-#if 0
+#if 1
             blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next);
             blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next);
-#elif 0
-            blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next);
-
-            Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
-
-            blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
-                                                        p_wei_register_clipboard);
 #elif 1
             Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
             Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
@@ -300,8 +293,6 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b
         }
 
 #if 0
-        blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next);
-#elif 1
         blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next);
         blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next);
 #endif
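For context, the toggle above selects the plain global-to-LDS prefetch shape of the double-buffer main loop; the register-clipboard staging path (RunLoadRegisterClipboard / RunStoreRegisterClipboard) is left behind dead #elif / #if 0 branches for experimentation. A sketch of the selected structure, with illustrative names rather than the kernel's literal code:

    // Ping-pong two LDS tiles: prefetch the next tile while computing on the current one.
    Float* p_current = p_lds_buffer_0;
    Float* p_next    = p_lds_buffer_1;
    for(unsigned c = 0; c < C; c += CPerBlock)
    {
        __syncthreads();
        blockwise_in_copy.Run(p_in_global_offset_for(c + CPerBlock), p_next);  // load next tile
        blockwise_wei_copy.Run(p_wei_global_offset_for(c + CPerBlock), p_next_wei);
        blockwise_gemm_on(p_current);                                          // compute current
        swap(p_current, p_next);                                               // flip buffers
    }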