From 1b323316a8448499ad835d495ffbe197ef97761a Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 6 Feb 2019 23:10:08 -0600 Subject: [PATCH] add another blockwise gemm --- driver/conv.cu | 11 +- ...icit_gemm_convolution_2_cnhw_csrk_knhw.cuh | 27 +- ...mm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh | 210 +++++++++++ src/include/blockwise_gemm.cuh | 269 ++++++++++++-- ...icit_gemm_convolution_1_chwn_csrk_khwn.cuh | 11 +- ...mm_convolution_1_chwn_csrk_khwn_padded.cuh | 11 +- ...n_1_chwn_csrk_khwn_padded_lds_pipeline.cuh | 11 +- ..._implicit_gemm_convolution_1_nchw_kcsr.cuh | 10 +- ...icit_gemm_convolution_1_nchw_srck_nkhw.cuh | 11 +- ...icit_gemm_convolution_2_cnhw_csrk_knhw.cuh | 14 +- ...mm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh | 327 ++++++++++++++++++ ...volution_2_cnhw_csrk_knhw_lds_pipeline.cuh | 13 +- ...icit_gemm_convolution_2_cnhw_srck_knhw.cuh | 13 +- ...volution_2_cnhw_srck_knhw_lds_pipeline.cuh | 13 +- 14 files changed, 840 insertions(+), 111 deletions(-) create mode 100644 driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh create mode 100644 src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh diff --git a/driver/conv.cu b/driver/conv.cu index 0265730949..cecf4737e8 100644 --- a/driver/conv.cu +++ b/driver/conv.cu @@ -14,6 +14,7 @@ #include "device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh" #include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh" #include "device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh" +#include "device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh" //#include "device_winograd_convolution.cuh" struct GeneratorTensor_1 @@ -391,7 +392,7 @@ int main() constexpr unsigned HPad = 0; constexpr unsigned WPad = 0; -#elif 1 +#elif 0 // 3x3, 34x34 constexpr unsigned N = 64; constexpr unsigned C = 256; @@ -484,7 +485,7 @@ int main() constexpr unsigned HPad = 1; constexpr unsigned WPad = 1; -#elif 0 +#elif 1 // 1x1 filter, 28x28 image constexpr unsigned N = 16; constexpr unsigned C = 256; @@ -591,8 +592,10 @@ int main() device_implicit_gemm_convolution_1_chwn_csrk_khwn #elif 0 device_implicit_gemm_convolution_2_cnhw_srck_knhw -#elif 1 +#elif 0 device_implicit_gemm_convolution_2_cnhw_csrk_knhw +#elif 1 + device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2 #endif (in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat); @@ -608,7 +611,7 @@ int main() nrepeat); #endif -#if 1 +#if 0 if(S == 3 && R == 3) { host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads); diff --git a/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh b/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh index c765f8aa58..a3d66b8a94 100644 --- a/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh +++ b/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh @@ -67,7 +67,7 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc, Tensor out_knhw(make_TensorDescriptor(out_knhw_desc)); -#if 1 +#if 0 // 3x3, 34x34 constexpr unsigned BPerBlock = 128; constexpr unsigned KPerBlock = 64; @@ -90,31 +90,8 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc, constexpr unsigned WeiBlockCopyDataPerRead = 4; constexpr unsigned BlockSize = 128; -#elif 0 - // 1x1, 28x28 - constexpr unsigned BPerBlock = 64; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 8; - - constexpr unsigned BPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - - constexpr unsigned 
GemmThreadPerColumnPerCluster = 4;
-    constexpr unsigned GemmThreadPerRowPerCluster = 8;
-
-    constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned InBlockCopyThreadPerDim1 = 16;
-
-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
-
-    constexpr unsigned InBlockCopyDataPerRead = 4;
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
-
-    constexpr unsigned BlockSize = 64;
 #elif 1
-    // 1x1, 28x28 try
+    // 1x1, 28x28
     constexpr unsigned BPerBlock = 64;
     constexpr unsigned KPerBlock = 64;
     constexpr unsigned CPerBlock = 8;
diff --git a/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh b/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh
new file mode 100644
index 0000000000..4bf88d9edd
--- /dev/null
+++ b/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh
@@ -0,0 +1,210 @@
+#pragma once
+#include "gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh"
+#include
+
+template <class T, class InDesc, class WeiDesc, class OutDesc>
+void device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InDesc,
+                                                              const Tensor<T>& in_nchw,
+                                                              WeiDesc,
+                                                              const Tensor<T>& wei_kcsr,
+                                                              OutDesc,
+                                                              Tensor<T>& out_nkhw,
+                                                              unsigned nrepeat)
+{
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};
+
+    constexpr auto in_nchw_desc  = InDesc{};
+    constexpr auto wei_kcsr_desc = WeiDesc{};
+    constexpr auto out_nkhw_desc = OutDesc{};
+
+    constexpr unsigned N  = in_nchw_desc.GetLength(I0);
+    constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
+    constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
+
+    constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
+    constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
+
+    constexpr unsigned K = wei_kcsr_desc.GetLength(I0);
+    constexpr unsigned C = wei_kcsr_desc.GetLength(I1);
+    constexpr unsigned S = wei_kcsr_desc.GetLength(I2);
+    constexpr unsigned R = wei_kcsr_desc.GetLength(I3);
+
+    constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1);
+
+    // convert in_nchw to in_cnhw
+    auto in_cnhw_desc = make_ConstantTensorDescriptor(Sequence<C, N, Hi, Wi>{});
+    ostream_ConstantTensorDescriptor(in_cnhw_desc, std::cout << "in_cnhw_desc: ");
+
+    Tensor<T> in_cnhw(make_TensorDescriptor(in_cnhw_desc));
+
+    auto f_reorder_nchw2cnhw = [&](auto n, auto c, auto hi, auto wi) {
+        in_cnhw(c, n, hi, wi) = in_nchw(n, c, hi, wi);
+    };
+
+    make_ParallelTensorFunctor(f_reorder_nchw2cnhw, N, C, Hi, Wi)(
+        std::thread::hardware_concurrency());
+
+    // convert wei_kcsr to wei_csrk
+    auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence<C, S, R, K>{});
+    ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: ");
+
+    Tensor<T> wei_csrk(make_TensorDescriptor(wei_csrk_desc));
+
+    auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto s, auto r) {
+        wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r);
+    };
+
+    make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, S, R)(
+        std::thread::hardware_concurrency());
+
+    // convert out_nkhw to out_knhw
+    auto out_knhw_desc = make_ConstantTensorDescriptor(Sequence<K, N, Ho, Wo>{});
+    ostream_ConstantTensorDescriptor(out_knhw_desc, std::cout << "out_knhw_desc: ");
+
+    Tensor<T> out_knhw(make_TensorDescriptor(out_knhw_desc));
+
+#if 0
+    // 1x1, 28x28
+    constexpr unsigned BPerBlock = 64;
+    constexpr unsigned KPerBlock = 64;
+    constexpr unsigned CPerBlock = 8;
+
+    constexpr unsigned BPerThread = 4;
+    constexpr unsigned KPerThread = 16;
+
+    constexpr unsigned GemmMPerThreadSubC = 16;
+    constexpr unsigned GemmNPerThreadSubC = 4;
+    constexpr unsigned GemmMLevel0Cluster = 4;
+    constexpr unsigned GemmNLevel0Cluster = 8;
+    constexpr unsigned GemmMLevel1Cluster = 1;
+    constexpr unsigned GemmNLevel1Cluster = 2;
+    constexpr unsigned GemmKPerThreadLoop = 1;
+
+    constexpr unsigned GemmThreadPerColumnPerCluster = 4;
+    constexpr unsigned GemmThreadPerRowPerCluster = 8;
+
+    constexpr unsigned InBlockCopyThreadPerDim0 = 4;
+    constexpr unsigned InBlockCopyThreadPerDim1 = 16;
+
+    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
+    constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
+
+    constexpr unsigned InBlockCopyDataPerRead = 4;
+    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+
+    constexpr unsigned BlockSize = 64;
+#elif 1
+    // 1x1, 28x28 try
+    constexpr unsigned BPerBlock = 64;
+    constexpr unsigned KPerBlock = 64;
+    constexpr unsigned CPerBlock = 8;
+
+    constexpr unsigned BPerThread = 8;
+    constexpr unsigned KPerThread = 8;
+
+    constexpr unsigned GemmMPerThreadSubC = 4;
+    constexpr unsigned GemmNPerThreadSubC = 4;
+    constexpr unsigned GemmMLevel0Cluster = 8;
+    constexpr unsigned GemmNLevel0Cluster = 2;
+    constexpr unsigned GemmMLevel1Cluster = 1;
+    constexpr unsigned GemmNLevel1Cluster = 4;
+    constexpr unsigned GemmKPerThreadLoop = 1;
+
+    constexpr unsigned GemmThreadPerColumnPerCluster = 8;
+    constexpr unsigned GemmThreadPerRowPerCluster = 8;
+
+    constexpr unsigned InBlockCopyThreadPerDim0 = 4;
+    constexpr unsigned InBlockCopyThreadPerDim1 = 16;
+
+    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
+    constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
+
+    constexpr unsigned InBlockCopyDataPerRead = 4;
+    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+
+    constexpr unsigned BlockSize = 64;
+#endif
+
+    constexpr unsigned GridSize =
+        ((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
+
+    dim3 block_dim(BlockSize);
+    dim3 grid_dim(GridSize);
+
+    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
+
+    // mem
+    std::size_t data_sz = sizeof(T);
+    DeviceMem in_cnhw_device_buf(data_sz * (in_cnhw.mDesc.GetElementSpace() + BGhostRead +
+                                            BPerBlock)); // reserve extra space for BGhostRead
+    DeviceMem wei_csrk_device_buf(data_sz * wei_csrk.mDesc.GetElementSpace());
+    DeviceMem out_knhw_device_buf(data_sz * out_knhw.mDesc.GetElementSpace());
+
+    in_cnhw_device_buf.ToDevice(in_cnhw.mData.data());
+    wei_csrk_device_buf.ToDevice(wei_csrk.mData.data());
+    out_knhw_device_buf.ToDevice(out_knhw.mData.data());
+
+    for(unsigned i = 0; i < nrepeat; ++i)
+    {
+        cudaEvent_t start, stop;
+        float elapsedTime;
+        cudaEventCreate(&start);
+        cudaEventRecord(start, 0);
+
+        gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2
+            <<<grid_dim, block_dim>>>(in_cnhw_desc,
+                                      static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()),
+                                      wei_csrk_desc,
+                                      static_cast<T*>(wei_csrk_device_buf.GetDeviceBuffer()),
+                                      out_knhw_desc,
+                                      static_cast<T*>(out_knhw_device_buf.GetDeviceBuffer()));
+
+        cudaEventCreate(&stop);
+        cudaEventRecord(stop, 0);
+        cudaEventSynchronize(stop);
+
+        cudaEventElapsedTime(&elapsedTime, start, stop);
+        printf("Elapsed time : %f ms\n", elapsedTime);
+
+        usleep(std::min(elapsedTime * 1000, float(10000)));
+    }
+
+    checkCudaErrors(cudaGetLastError());
+    out_knhw_device_buf.FromDevice(out_knhw.mData.data());
+
+    // convert out_knhw to out_nkhw
+    auto f_reorder_knhw2nkhw = [&](auto n, auto k, auto ho, auto wo) {
+        out_nkhw(n, k, ho, wo) = out_knhw(k, n, ho, wo);
+    };
+
+    make_ParallelTensorFunctor(f_reorder_knhw2nkhw, N, K, Ho, Wo)(
+        std::thread::hardware_concurrency());
+}
diff --git a/src/include/blockwise_gemm.cuh b/src/include/blockwise_gemm.cuh
index 
7d99a96eab..49ceeec168 100644 --- a/src/include/blockwise_gemm.cuh +++ b/src/include/blockwise_gemm.cuh @@ -22,9 +22,9 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC struct MatrixIndex { - unsigned batch_begin; - unsigned row_begin; - unsigned col_begin; + unsigned batch; + unsigned row; + unsigned col; }; __device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC() @@ -32,15 +32,15 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile - const auto c_thread_mtx_index = CalculateThreadMatrixCIndex(get_thread_local_1d_id()); + const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - mMyThreadOffsetA = c_thread_mtx_index.batch_begin * BlockMatrixStrideA + - ((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row_begin, 0) - : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row_begin)); + mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA + + ((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row, 0) + : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row)); - mMyThreadOffsetB = c_thread_mtx_index.batch_begin * BlockMatrixStrideB + - ((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col_begin) - : b_block_mtx.Get1dIndex(c_thread_mtx_index.col_begin, 0)); + mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB + + ((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col) + : b_block_mtx.Get1dIndex(c_thread_mtx_index.col, 0)); #if 0 if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) @@ -52,16 +52,16 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC printf("%u %u, %u %u %u, %u %u\n", get_block_1d_id(), get_thread_local_1d_id(), - c_thread_mtx_index.batch_begin, - c_thread_mtx_index.row_begin, - c_thread_mtx_index.col_begin, + c_thread_mtx_index.batch, + c_thread_mtx_index.row, + c_thread_mtx_index.col, mMyThreadOffsetA, mMyThreadOffsetB); } #endif } - __device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const + __device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const { if(TransA && (!TransB) && (!TransC)) @@ -237,8 +237,8 @@ struct BlockwiseGemmBlockABlockBThreadC struct MatrixIndex { - unsigned row_begin; - unsigned col_begin; + unsigned row; + unsigned col; }; __device__ BlockwiseGemmBlockABlockBThreadC() @@ -246,13 +246,13 @@ struct BlockwiseGemmBlockABlockBThreadC const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile - const auto c_thread_mtx_index = CalculateThreadMatrixCIndex(get_thread_local_1d_id()); + const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - mMyThreadOffsetA = (!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row_begin, 0) - : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row_begin); + mMyThreadOffsetA = (!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row, 0) + : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row); - mMyThreadOffsetB = (!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col_begin) - : b_block_mtx.Get1dIndex(c_thread_mtx_index.col_begin, 0); + mMyThreadOffsetB = (!TransB) ? 
b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col)
+                                   : b_block_mtx.Get1dIndex(c_thread_mtx_index.col, 0);
 
 #if 0
         if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
@@ -264,16 +264,16 @@ struct BlockwiseGemmBlockABlockBThreadC
             printf("%u %u, %u %u %u, %u %u\n",
                    get_block_1d_id(),
                    get_thread_local_1d_id(),
-                   c_thread_mtx_index.batch_begin,
-                   c_thread_mtx_index.row_begin,
-                   c_thread_mtx_index.col_begin,
+                   c_thread_mtx_index.batch,
+                   c_thread_mtx_index.row,
+                   c_thread_mtx_index.col,
                    mMyThreadOffsetA,
                    mMyThreadOffsetB);
         }
 #endif
     }
 
-    __device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const
+    __device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
     {
 
         if(TransA && (!TransB) && (!TransC))
@@ -359,6 +359,13 @@ struct BlockwiseGemmBlockABlockBThreadC
         }
     }
 
+    // this should be optimized away if input is known
+    __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
+                                                                      unsigned n_in_c)
+    {
+        return MatrixIndex{m_in_c, n_in_c};
+    }
+
     template <class FloatA, class FloatB, class FloatC, class Accumulator>
     __device__ void Run(FloatA* const p_a_block,
                         FloatB* const p_b_block,
                         FloatC* p_c_thread,
                         Accumulator f_accum) const
@@ -420,3 +427,215 @@ struct BlockwiseGemmBlockABlockBThreadC
         }
     }
 };
+
+// if the following numbers are powers of 2, the index calculations are greatly simplified:
+// MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
+template <unsigned BlockSize,
+          class BlockMatrixA,
+          class BlockMatrixB,
+          class ThreadMatrixC,
+          unsigned MPerThreadSubC,
+          unsigned NPerThreadSubC,
+          unsigned MLevel0Cluster,
+          unsigned NLevel0Cluster,
+          unsigned MLevel1Cluster,
+          unsigned NLevel1Cluster,
+          unsigned KPerThreadLoop>
+struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
+{
+    struct MatrixIndex
+    {
+        unsigned row;
+        unsigned col;
+    };
+
+    unsigned mMyThreadOffsetA;
+    unsigned mMyThreadOffsetB;
+
+    __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
+    {
+        constexpr unsigned ThreadPerLevel1Cluster =
+            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
+
+        static_assert(BlockSize == ThreadPerLevel1Cluster,
+                      "wrong! BlockSize must equal ThreadPerLevel1Cluster\n");
+
+        const auto a_block_mtx  = BlockMatrixA{};  // constexpr doesn't compile
+        const auto b_block_mtx  = BlockMatrixB{};  // constexpr doesn't compile
+        const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile
+
+        static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
+                      "wrong! K dimension not consistent\n");
+
+        constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
+        constexpr unsigned N = b_block_mtx.NCol();
+        constexpr unsigned K = a_block_mtx.NRow();
+
+        constexpr unsigned MPerThread = c_thread_mtx.NRow();
+        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+
+        static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
+                      "wrong! Cannot evenly divide thread work among repeats\n");
+
+        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
+        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+
+        static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
+                      "wrong! Cannot evenly divide work among repeats\n");
+
+        constexpr unsigned MPerLevel1Cluster = M / MRepeat;
+        constexpr unsigned NPerLevel1Cluster = N / NRepeat;
+
+        static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
+                          (NPerLevel1Cluster % NLevel1Cluster == 0),
+                      "wrong! Cannot evenly divide work among Level1Cluster\n");
+
+        constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
+        constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
+
+        static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
+                          (NPerLevel0Cluster % NLevel0Cluster == 0),
+                      "wrong! Cannot evenly divide work among Level0Cluster\n");
+
+        static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
+                          (NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
+                      "wrong! 
thread work size is wrong\n"); + + auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + mMyThreadOffsetA = a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row); + mMyThreadOffsetB = b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col); + } + + __device__ static MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) + { + constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster; + + unsigned level1_id = thread_id / ThreadPerLevel0Cluster; + unsigned level1_m_id = level1_id / NLevel1Cluster; + unsigned level1_n_id = level1_id % NLevel1Cluster; + + unsigned level0_id = thread_id % ThreadPerLevel0Cluster; + unsigned level0_m_id = level0_id / NLevel0Cluster; + unsigned level0_n_id = level0_id % NLevel0Cluster; + + constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster; + constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster; + + return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC, + level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC}; + } + + // this should be optimized away if input is known + __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c, + unsigned n_in_c) + { + const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile + + constexpr unsigned MPerThread = c_thread_mtx.NRow(); + constexpr unsigned NPerThread = c_thread_mtx.NCol(); + + constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; + constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; + + constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; + constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; + + unsigned m_repeat = m_in_c / MPerThreadSubC; + unsigned n_repeat = n_in_c / NPerThreadSubC; + + unsigned m_in_sub_c = m_in_c % MPerThreadSubC; + unsigned n_in_sub_c = n_in_c % NPerThreadSubC; + + return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c, + n_repeat * NPerLevel1Cluster + n_in_sub_c}; + } + + template + __device__ void Run(FloatA* const p_a_block, + FloatB* const p_b_block, + FloatC* p_c_thread, + Accumulator f_accum) const + { + constexpr auto True = Constant{}; + constexpr auto False = Constant{}; + + const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile + const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile + const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile + + constexpr unsigned M = a_block_mtx.NCol(); + constexpr unsigned N = b_block_mtx.NCol(); + constexpr unsigned K = a_block_mtx.NRow(); + + constexpr unsigned MPerThread = c_thread_mtx.NRow(); + constexpr unsigned NPerThread = c_thread_mtx.NCol(); + + // thread A, B for GEMM + const auto a_thread_mtx = make_ConstantMatrixDescriptor( + Number{}, Number{}); // constexpr doesn't compile + + const auto b_thread_mtx = make_ConstantMatrixDescriptor( + Number{}, Number{}); // constexpr doesn't compile + + // thread A-sub, B-sub for copy + const auto a_thread_sub_mtx = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); // constexpr doesn't compile + + const auto b_thread_sub_mtx = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); // constexpr doesn't compile + + FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; + FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; + + constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; + + constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * 
NLevel1Cluster; + + constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; + constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; + + // loop over k + for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) + { + // copy A-sub to form A + for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) + { + threadwise_matrix_copy(a_block_mtx, + p_a_block + mMyThreadOffsetA + + k_begin * a_block_mtx.RowStride() + + m_repeat * MPerLevel1Cluster, + a_thread_sub_mtx, + p_a_thread + m_repeat * MPerThreadSubC, + a_thread_sub_mtx.GetLengths()); + } + + // copy B-sub to form B + for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) + { + threadwise_matrix_copy(b_block_mtx, + p_b_block + mMyThreadOffsetB + + k_begin * b_block_mtx.RowStride() + + n_repeat * NPerLevel1Cluster, + b_thread_sub_mtx, + p_b_thread + n_repeat * NPerThreadSubC, + b_thread_sub_mtx.GetLengths()); + } + + // C = A * B + threadwise_gemm(a_thread_mtx, + True, + p_a_thread, + b_thread_mtx, + False, + p_b_thread, + c_thread_mtx, + False, + p_c_thread, + f_accum); + } + } +}; diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh index 0f32ca42ec..c25b98a801 100644 --- a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh @@ -208,13 +208,12 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc, } const auto matrix_c_index = - blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); + blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - const unsigned ho_thread_data_begin = matrix_c_index.batch_begin; - const unsigned k_thread_data_begin = matrix_c_index.row_begin; - const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock; - const unsigned n_thread_data_begin = - matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock; + const unsigned ho_thread_data_begin = matrix_c_index.batch; + const unsigned k_thread_data_begin = matrix_c_index.row; + const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock; + const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock; #if 0 printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n", diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh index 0c58bcc4de..91830ac076 100644 --- a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh @@ -262,13 +262,12 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(Float* const __restri } const auto matrix_c_index = - blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); + blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - const unsigned ho_thread_data_begin = matrix_c_index.batch_begin; - const unsigned k_thread_data_begin = matrix_c_index.row_begin; - const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock; - const unsigned n_thread_data_begin = - matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock; + const unsigned ho_thread_data_begin = matrix_c_index.batch; + const unsigned k_thread_data_begin = matrix_c_index.row; + const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock; + const unsigned n_thread_data_begin = matrix_c_index.col - 
wo_thread_data_begin * NPerBlock; #if 0 printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n", diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_pipeline.cuh b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_pipeline.cuh index b1be45968a..3726f12f1c 100644 --- a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_pipeline.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_pipeline.cuh @@ -318,13 +318,12 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p } const auto matrix_c_index = - blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); + blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - const unsigned ho_thread_data_begin = matrix_c_index.batch_begin; - const unsigned k_thread_data_begin = matrix_c_index.row_begin; - const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock; - const unsigned n_thread_data_begin = - matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock; + const unsigned ho_thread_data_begin = matrix_c_index.batch; + const unsigned k_thread_data_begin = matrix_c_index.row; + const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock; + const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock; #if 0 printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n", diff --git a/src/include/gridwise_implicit_gemm_convolution_1_nchw_kcsr.cuh b/src/include/gridwise_implicit_gemm_convolution_1_nchw_kcsr.cuh index a7e6810883..62d2fef543 100644 --- a/src/include/gridwise_implicit_gemm_convolution_1_nchw_kcsr.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_1_nchw_kcsr.cuh @@ -228,15 +228,15 @@ gridwise_implicit_gemm_convolution_1_nchw_kcsr(InGlobalDesc, } const auto matrix_c_index = - blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); + blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); #if 0 - printf("%u %u, %u %u %u\n",get_block_1d_id(), get_thread_local_1d_id(), matrix_c_index.batch_begin, matrix_c_index.row_begin, matrix_c_index.col_begin); + printf("%u %u, %u %u %u\n",get_block_1d_id(), get_thread_local_1d_id(), matrix_c_index.batch, matrix_c_index.row, matrix_c_index.col); #endif - const unsigned ho_thread_data_begin = matrix_c_index.batch_begin; - const unsigned k_thread_data_begin = matrix_c_index.row_begin; - const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerThread; + const unsigned ho_thread_data_begin = matrix_c_index.batch; + const unsigned k_thread_data_begin = matrix_c_index.row; + const unsigned wo_thread_data_begin = matrix_c_index.col / NPerThread; #if 1 // output: register to global mem, diff --git a/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh b/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh index abd143c02d..9c761bbc0f 100644 --- a/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh @@ -205,13 +205,12 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc, } const auto matrix_c_index = - blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); + blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - const unsigned ho_thread_data_begin = matrix_c_index.batch_begin; - const unsigned k_thread_data_begin = matrix_c_index.row_begin; - const unsigned 
wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock; - const unsigned n_thread_data_begin = - matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock; + const unsigned ho_thread_data_begin = matrix_c_index.batch; + const unsigned k_thread_data_begin = matrix_c_index.row; + const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock; + const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock; // output: register to global mem, // convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo] diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh index 877c595ba5..5d2013ea56 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh @@ -75,6 +75,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc, constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); // tensor view of blockwise input and weight + // be careful of alignment constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned( Sequence{}, Number{}); @@ -245,11 +246,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc, } // output: register to global mem, - const auto matrix_c_index = - blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); + const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - const unsigned k_thread_data_begin = matrix_c_index.row_begin; - const unsigned b_thread_data_begin = matrix_c_index.col_begin; + const unsigned k_thread_data_begin = matrix_c_index.row; + const unsigned b_thread_data_begin = matrix_c_index.col; const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin; const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin; @@ -257,11 +257,11 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc, #if 0 if(get_block_1d_id() == 0) { - printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", + printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", get_block_1d_id(), get_thread_local_1d_id(), - matrix_c_index.row_begin, - matrix_c_index.col_begin, + matrix_c_index.row, + matrix_c_index.col, k_data_begin, b_data_begin, p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]); diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh new file mode 100644 index 0000000000..3574381026 --- /dev/null +++ b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh @@ -0,0 +1,327 @@ +#pragma once +#include "common.cuh" +#include "ConstantTensorDescriptor.cuh" +#include "ConstantMatrixDescriptor.cuh" +#include "blockwise_4d_tensor_op.cuh" +#include "blockwise_2d_tensor_op.cuh" +#include "threadwise_2d_tensor_op.cuh" +#include "blockwise_gemm.cuh" + +// define B = flatten(N, Hi, Wi) +template +__global__ void +gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InGlobalDesc, + Float* const __restrict__ p_in_global, + WeiGlobalDesc, + Float* const __restrict__ p_wei_global, + OutGlobalDesc, + Float* __restrict__ p_out_global) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_cnhw_global_desc = InGlobalDesc{}; + 
constexpr auto wei_csrk_global_desc = WeiGlobalDesc{};
+    constexpr auto out_knhw_global_desc = OutGlobalDesc{};
+
+    constexpr unsigned C  = in_cnhw_global_desc.GetLength(I0);
+    constexpr unsigned N  = in_cnhw_global_desc.GetLength(I1);
+    constexpr unsigned Hi = in_cnhw_global_desc.GetLength(I2);
+    constexpr unsigned Wi = in_cnhw_global_desc.GetLength(I3);
+
+    constexpr unsigned K  = out_knhw_global_desc.GetLength(I0);
+    constexpr unsigned Ho = out_knhw_global_desc.GetLength(I2);
+    constexpr unsigned Wo = out_knhw_global_desc.GetLength(I3);
+
+    constexpr unsigned S = wei_csrk_global_desc.GetLength(I1);
+    constexpr unsigned R = wei_csrk_global_desc.GetLength(I2);
+
+    constexpr unsigned B          = N * Hi * Wi;
+    constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1);
+
+    // divide block work by 2d: [K, B]
+    constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
+    constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock;
+
+    const unsigned k_block_work_id = get_block_1d_id() / BBlockWork;
+    const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
+
+    const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
+    const unsigned b_block_data_begin = b_block_work_id * BPerBlock;
+
+    // flattened (2d) tensor view of gridwise input
+    constexpr auto in_cb_global_desc  = make_ConstantTensorDescriptor(Sequence<C, B>{});
+    constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * S * R, K>{});
+
+    // tensor view of blockwise input and weight
+    // be careful of alignment
+    constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
+        Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});
+
+    constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
+        Sequence<CPerBlock * S * R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
+
+    constexpr auto wei_csrk_block_desc = make_ConstantTensorDescriptor_aligned(
+        Sequence<CPerBlock, S, R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
+
+    // tensor view of threadwise output in register
+    constexpr auto out_kb_thread_desc =
+        make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});
+
+#if 0
+    if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+    {
+        print_ConstantTensorDescriptor(in_cnhw_global_desc, "in_cnhw_global_desc");
+        print_ConstantTensorDescriptor(wei_csrk_global_desc, "wei_csrk_global_desc");
+        print_ConstantTensorDescriptor(out_knhw_global_desc, "out_knhw_global_desc");
+
+        print_ConstantTensorDescriptor(in_cb_global_desc, "in_cb_global_desc");
+        print_ConstantTensorDescriptor(wei_ek_global_desc, "wei_ek_global_desc");
+
+        print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc");
+        print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc");
+        print_ConstantTensorDescriptor(wei_ek_block_desc, "wei_ek_block_desc");
+        print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc");
+
+        printf("KPerBlock %u\n", KPerBlock);
+    }
+#endif
+
+    // blockwise in copy
+    // format is [CPerBlock,BPerBlock + BGhostRead]
+#if 0
+    const auto blockwise_in_copy =
+        Blockwise2dTensorCopy1{};
+#elif 0
+    const auto blockwise_in_copy = Blockwise2dTensorCopy2{};
+#elif 1
+    const auto blockwise_in_copy = Blockwise2dTensorCopy3{};
+#endif
+
+    // blockwise wei copy
+    // format is [CPerBlock*S*R,KPerBlock]
+#if 0
+    const auto blockwise_wei_copy =
+        Blockwise2dTensorCopy1{};
+#elif 0
+    const auto blockwise_wei_copy = Blockwise2dTensorCopy2{};
+#elif 1
+    const auto blockwise_wei_copy = Blockwise2dTensorCopy3{};
+#endif
+
+    // a series of blockwise GEMM
+    // c_mtx += transpose(a_mtx) * b_mtx
+    // a_mtx and b_mtx saved in LDS, c_mtx saved in register
+    // a_mtx[C,K] is a sub-matrix of wei_block[C,S,R,K]
+    // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
+    // c_mtx[K,B] is out_block[K,B]
+    const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
+        Number{},
+        Number{},
+        Number{}); // constexpr doesn't compile
+
+    const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
+        Number{},
+        Number{},
+        Number{}); // constexpr doesn't compile
+
+    const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
+        Number{}, Number{}); // constexpr doesn't compile
+
+#if 0
+    const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC{};
+#else
+    const auto blockwise_gemm =
+        BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2{};
+#endif
+
+    // LDS: be careful of alignment
+    constexpr unsigned in_block_size =
+        in_cb_block_desc.GetElementSpace(Number{});
+
+    constexpr unsigned wei_block_size =
+        wei_csrk_block_desc.GetElementSpace(Number{});
+
+    constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
+                                       ? InBlockCopyDataPerRead
+                                       : WeiBlockCopyDataPerRead;
+
+    __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
+    __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
+
+    // register
+    Float p_out_thread[out_kb_thread_desc.GetElementSpace()];
+
+    // set threadwise output tensor to 0
+    threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread);
+
+    Float* p_in_global_block_offset =
+        p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin);
+
+    Float* p_wei_global_block_offset =
+        p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
+
+    for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
+                 p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
+                 p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I0),
+                 __syncthreads())
+    {
+        // input: global mem to LDS,
+        blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
+
+        // weight: global mem to LDS,
+        blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
+
+        __syncthreads();
+
+        // a series of GEMM
+        for(unsigned s = 0; s < S; ++s)
+        {
+            for(unsigned r = 0; r < R; ++r)
+            {
+                auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
+
+                blockwise_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
+                                   p_in_block + s * Wi + r,
+                                   p_out_thread,
+                                   f_accum);
+            }
+        }
+    }
+
+    // output: register to global mem,
+    const auto c_thread_mtx_begin =
+        blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+
+    const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
+    const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
+
+#if 0
+    if(get_block_1d_id() == 0)
+    {
+        printf("%u %u, row %u col %u, k_thread_data_begin %u b_thread_data_begin %u, %f %f %f %f\n",
+               get_block_1d_id(),
+               get_thread_local_1d_id(),
+               c_thread_mtx_begin.row,
+               c_thread_mtx_begin.col,
+               k_thread_data_begin,
+               b_thread_data_begin,
+               p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
+    }
+#endif
+
+    for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
+    {
+        for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
+        {
+            const auto c_thread_mtx_distance =
+                blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);
+
+            unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row;
+            unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col;
+
+            unsigned n_data = b_data / (Hi * Wi);
+            unsigned itmp   = b_data - n_data * (Hi * Wi);
+            unsigned h_data = itmp / Wi;
+            unsigned w_data = itmp - 
h_data * Wi; + +#if 0 + if(get_block_1d_id() == 0) + { + printf("%u %u, k %u b %u, k_data %u n_data %u h_data %u w_data %u %f\n", + get_block_1d_id(), + get_thread_local_1d_id(), + k, + b, + k_data, + n_data, + h_data, + w_data, + p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]); + } +#endif + if(n_data < N && h_data < Ho && w_data < Wo) + { + p_out_global[out_knhw_global_desc.Get1dIndex(k_data, n_data, h_data, w_data)] = + p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; + } + } + } +} diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline.cuh b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline.cuh index 70cb8a465f..26ab71fa26 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline.cuh @@ -290,11 +290,10 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline } // output: register to global mem, - const auto matrix_c_index = - blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); + const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - const unsigned k_thread_data_begin = matrix_c_index.row_begin; - const unsigned b_thread_data_begin = matrix_c_index.col_begin; + const unsigned k_thread_data_begin = matrix_c_index.row; + const unsigned b_thread_data_begin = matrix_c_index.col; const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin; const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin; @@ -302,11 +301,11 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline #if 0 if(get_block_1d_id() == 0) { - printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", + printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", get_block_1d_id(), get_thread_local_1d_id(), - matrix_c_index.row_begin, - matrix_c_index.col_begin, + matrix_c_index.row, + matrix_c_index.col, k_data_begin, b_data_begin, p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]); diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh index c60108e003..1f8cb152eb 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh @@ -217,11 +217,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, } // output: register to global mem, - const auto matrix_c_index = - blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); + const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - const unsigned k_thread_data_begin = matrix_c_index.row_begin; - const unsigned b_thread_data_begin = matrix_c_index.col_begin; + const unsigned k_thread_data_begin = matrix_c_index.row; + const unsigned b_thread_data_begin = matrix_c_index.col; const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin; const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin; @@ -229,11 +228,11 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, #if 0 if(get_block_1d_id() == 0) { - printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", + printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", get_block_1d_id(), 
get_thread_local_1d_id(), - matrix_c_index.row_begin, - matrix_c_index.col_begin, + matrix_c_index.row, + matrix_c_index.col, k_data_begin, b_data_begin, p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]); diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline.cuh b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline.cuh index c54919cdf4..d9de0da7b0 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline.cuh @@ -276,11 +276,10 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline } // output: register to global mem, - const auto matrix_c_index = - blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); + const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - const unsigned k_thread_data_begin = matrix_c_index.row_begin; - const unsigned b_thread_data_begin = matrix_c_index.col_begin; + const unsigned k_thread_data_begin = matrix_c_index.row; + const unsigned b_thread_data_begin = matrix_c_index.col; const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin; const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin; @@ -288,11 +287,11 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline #if 0 if(get_block_1d_id() == 0) { - printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", + printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", get_block_1d_id(), get_thread_local_1d_id(), - matrix_c_index.row_begin, - matrix_c_index.col_begin, + matrix_c_index.row, + matrix_c_index.col, k_data_begin, b_data_begin, p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
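---

For reference, the thread-to-C-sub-tile mapping of the new
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2::GetBeginOfThreadMatrixC can be
checked on the host with the minimal standalone sketch below (not part of the patch).
It hard-codes the "1x1, 28x28 try" configuration selected in the new driver
(MPerThreadSubC = NPerThreadSubC = 4, MLevel0Cluster = 8, NLevel0Cluster = 2,
MLevel1Cluster = 1, NLevel1Cluster = 4, BlockSize = 64); the expected coordinates in
the comments are hand-derived from the same arithmetic, not kernel output.

#include <cstdio>
#include <initializer_list>

constexpr unsigned MPerThreadSubC = 4, NPerThreadSubC = 4;
constexpr unsigned MLevel0Cluster = 8, NLevel0Cluster = 2;
constexpr unsigned MLevel1Cluster = 1, NLevel1Cluster = 4;

struct MatrixIndex
{
    unsigned row;
    unsigned col;
};

// same arithmetic as the device function: a thread first lands in an
// MLevel0Cluster x NLevel0Cluster grid of 4x4 sub-tiles, and level-1
// clustering tiles that grid NLevel1Cluster times along N
MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id)
{
    constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster; // 16

    unsigned level1_id   = thread_id / ThreadPerLevel0Cluster;
    unsigned level1_m_id = level1_id / NLevel1Cluster;
    unsigned level1_n_id = level1_id % NLevel1Cluster;

    unsigned level0_id   = thread_id % ThreadPerLevel0Cluster;
    unsigned level0_m_id = level0_id / NLevel0Cluster;
    unsigned level0_n_id = level0_id % NLevel0Cluster;

    constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster; // 32
    constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster; // 8

    return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                       level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}

int main()
{
    // expected: thread 0 -> (0,0), thread 1 -> (0,4), thread 2 -> (4,0),
    //           thread 17 -> (0,12), thread 63 -> (28,28)
    for(unsigned tid : {0u, 1u, 2u, 17u, 63u})
    {
        MatrixIndex c = GetBeginOfThreadMatrixC(tid);
        std::printf("thread %2u -> C sub-tile begin (row %2u, col %2u)\n", tid, c.row, c.col);
    }
    return 0;
}

With KPerThread = BPerThread = 8 and MPerThreadSubC = NPerThreadSubC = 4, MRepeat =
NRepeat = 2, so each thread accumulates four 4x4 sub-tiles of the 64x64 block tile;
the second copy in each dimension sits MPerLevel1Cluster = NPerLevel1Cluster = 32
rows/columns away, which is exactly the offset GetDistanceFromBeginOfThreadMatrixC
produces when writing the thread's C registers back to global memory.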