From 07f16673c98ab0952fe09a58e362714f58998ebb Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 7 Feb 2019 00:56:53 -0600 Subject: [PATCH] add lds double buffer for cnhw implicit gemm --- driver/conv.cu | 7 +- ...icit_gemm_convolution_2_cnhw_csrk_knhw.cuh | 78 +++- ...mm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh | 180 --------- src/include/blockwise_gemm.cuh | 7 +- ...icit_gemm_convolution_2_cnhw_csrk_knhw.cuh | 52 ++- ...on_2_cnhw_csrk_knhw_lds_double_buffer.cuh} | 102 +++-- ...volution_2_cnhw_csrk_knhw_lds_pipeline.cuh | 351 ------------------ 7 files changed, 180 insertions(+), 597 deletions(-) delete mode 100644 driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh rename src/include/{gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh => gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer.cuh} (84%) delete mode 100644 src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline.cuh diff --git a/driver/conv.cu b/driver/conv.cu index cbb783d344..0941bf2733 100644 --- a/driver/conv.cu +++ b/driver/conv.cu @@ -14,7 +14,6 @@ #include "device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh" #include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh" #include "device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh" -#include "device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh" //#include "device_winograd_convolution.cuh" struct GeneratorTensor_1 @@ -392,7 +391,7 @@ int main() constexpr unsigned HPad = 0; constexpr unsigned WPad = 0; -#elif 0 +#elif 1 // 3x3, 34x34 constexpr unsigned N = 64; constexpr unsigned C = 256; @@ -592,10 +591,8 @@ int main() device_implicit_gemm_convolution_1_chwn_csrk_khwn #elif 0 device_implicit_gemm_convolution_2_cnhw_srck_knhw -#elif 0 - device_implicit_gemm_convolution_2_cnhw_csrk_knhw #elif 1 - device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2 + device_implicit_gemm_convolution_2_cnhw_csrk_knhw #endif (in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat); diff --git a/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh b/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh index a3d66b8a94..de439b54bf 100644 --- a/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh +++ b/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh @@ -1,6 +1,6 @@ #pragma once #include "gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh" -#include "gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline.cuh" +#include "gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer.cuh" #include template @@ -67,17 +67,24 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc, Tensor out_knhw(make_TensorDescriptor(out_knhw_desc)); -#if 0 +#if 1 // 3x3, 34x34 constexpr unsigned BPerBlock = 128; constexpr unsigned KPerBlock = 64; constexpr unsigned CPerBlock = 4; - constexpr unsigned BPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; + constexpr unsigned BPerThread = 8; + constexpr unsigned KPerThread = 8; - constexpr unsigned GemmThreadPerColumnPerCluster = 4; + constexpr unsigned GemmMPerThreadSubC = 4; + constexpr unsigned GemmNPerThreadSubC = 4; + constexpr unsigned GemmMLevel0Cluster = 8; + constexpr unsigned GemmNLevel0Cluster = 2; + constexpr unsigned GemmMLevel1Cluster = 1; + constexpr unsigned GemmNLevel1Cluster = 8; + constexpr unsigned GemmKPerThreadLoop = 1; + + constexpr unsigned GemmThreadPerColumnPerCluster = 8; constexpr unsigned GemmThreadPerRowPerCluster = 8; constexpr unsigned InBlockCopyThreadPerDim0 = 4; @@ -90,17 +97,24 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc, constexpr unsigned WeiBlockCopyDataPerRead = 4; constexpr unsigned BlockSize = 128; -#elif 1 +#elif 0 // 1x1, 28x28 constexpr unsigned BPerBlock = 64; constexpr unsigned KPerBlock = 64; constexpr unsigned CPerBlock = 8; - constexpr unsigned BPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; + constexpr unsigned BPerThread = 8; + constexpr unsigned KPerThread = 8; - constexpr unsigned GemmThreadPerColumnPerCluster = 4; + constexpr unsigned GemmMPerThreadSubC = 4; + constexpr unsigned GemmNPerThreadSubC = 4; + constexpr unsigned GemmMLevel0Cluster = 8; + constexpr unsigned GemmNLevel0Cluster = 2; + constexpr unsigned GemmMLevel1Cluster = 1; + constexpr unsigned GemmNLevel1Cluster = 4; + constexpr unsigned GemmKPerThreadLoop = 1; + + constexpr unsigned GemmThreadPerColumnPerCluster = 8; constexpr unsigned GemmThreadPerRowPerCluster = 8; constexpr unsigned InBlockCopyThreadPerDim0 = 4; @@ -113,6 +127,36 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc, constexpr unsigned WeiBlockCopyDataPerRead = 4; constexpr unsigned BlockSize = 64; +#elif 1 + // 1x1, 28x28 try + constexpr unsigned BPerBlock = 128; + constexpr unsigned KPerBlock = 128; + constexpr unsigned CPerBlock = 8; + + constexpr unsigned BPerThread = 8; + constexpr unsigned KPerThread = 8; + + constexpr unsigned GemmMPerThreadSubC = 4; + constexpr unsigned GemmNPerThreadSubC = 4; + constexpr unsigned GemmMLevel0Cluster = 8; + constexpr unsigned GemmNLevel0Cluster = 4; + constexpr unsigned GemmMLevel1Cluster = 2; + constexpr unsigned GemmNLevel1Cluster = 4; + constexpr unsigned GemmKPerThreadLoop = 1; + + constexpr unsigned GemmThreadPerColumnPerCluster = 8; + constexpr unsigned GemmThreadPerRowPerCluster = 8; + + constexpr unsigned InBlockCopyThreadPerDim0 = 4; + constexpr unsigned InBlockCopyThreadPerDim1 = 16; + + constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 16; + + constexpr unsigned InBlockCopyDataPerRead = 4; + constexpr unsigned WeiBlockCopyDataPerRead = 4; + + constexpr unsigned BlockSize = 256; #endif constexpr unsigned GridSize = @@ -143,8 +187,8 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc, #if 1 gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw -#elif 1 - gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline +#else + gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer #endif - -template -void device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcsr, - OutDesc, - Tensor& out_nkhw, - unsigned nrepeat) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_desc = InDesc{}; - constexpr auto wei_kcsr_desc = WeiDesc{}; - constexpr auto out_nkhw_desc = OutDesc{}; - - constexpr unsigned N = in_nchw_desc.GetLength(I0); - constexpr unsigned Hi = in_nchw_desc.GetLength(I2); - constexpr unsigned Wi = in_nchw_desc.GetLength(I3); - - constexpr unsigned Ho = out_nkhw_desc.GetLength(I2); - constexpr unsigned Wo = out_nkhw_desc.GetLength(I3); - - constexpr unsigned K = wei_kcsr_desc.GetLength(I0); - constexpr unsigned C = wei_kcsr_desc.GetLength(I1); - constexpr unsigned S = wei_kcsr_desc.GetLength(I2); - constexpr unsigned R = wei_kcsr_desc.GetLength(I3); - - constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1); - - // convert in_nchw to in_cnhw - auto in_cnhw_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(in_cnhw_desc, std::cout << "in_cnhw_desc: "); - - Tensor in_cnhw(make_TensorDescriptor(in_cnhw_desc)); - - auto f_reorder_nchw2cnhw = [&](auto n, auto c, auto hi, auto wi) { - in_cnhw(c, n, hi, wi) = in_nchw(n, c, hi, wi); - }; - - make_ParallelTensorFunctor(f_reorder_nchw2cnhw, N, C, Hi, Wi)( - std::thread::hardware_concurrency()); - - // convert wei_kcsr to wei_csrk - auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: "); - - Tensor wei_csrk(make_TensorDescriptor(wei_csrk_desc)); - - auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto s, auto r) { - wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r); - }; - - make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, S, R)( - std::thread::hardware_concurrency()); - - // conver out_nkhw to out_knhw - auto out_knhw_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(out_knhw_desc, std::cout << "out_knhw_desc: "); - - Tensor out_knhw(make_TensorDescriptor(out_knhw_desc)); - -#if 1 - // 1x1, 28x28 - constexpr unsigned BPerBlock = 64; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 8; - - constexpr unsigned BPerThread = 8; - constexpr unsigned KPerThread = 8; - - constexpr unsigned GemmMPerThreadSubC = 4; - constexpr unsigned GemmNPerThreadSubC = 4; - constexpr unsigned GemmMLevel0Cluster = 8; - constexpr unsigned GemmNLevel0Cluster = 2; - constexpr unsigned GemmMLevel1Cluster = 1; - constexpr unsigned GemmNLevel1Cluster = 4; - constexpr unsigned GemmKPerThreadLoop = 1; - - constexpr unsigned GemmThreadPerColumnPerCluster = 8; - constexpr unsigned GemmThreadPerRowPerCluster = 8; - - constexpr unsigned InBlockCopyThreadPerDim0 = 4; - constexpr unsigned InBlockCopyThreadPerDim1 = 16; - - constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; - constexpr unsigned WeiBlockCopyThreadPerDim1 = 16; - - constexpr unsigned InBlockCopyDataPerRead = 4; - constexpr unsigned WeiBlockCopyDataPerRead = 4; - - constexpr unsigned BlockSize = 64; -#endif - - constexpr unsigned GridSize = - ((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); - - dim3 block_dim(BlockSize); - dim3 grid_dim(GridSize); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - // mem - std::size_t data_sz = sizeof(T); - DeviceMem in_cnhw_device_buf(data_sz * (in_cnhw.mDesc.GetElementSpace() + BGhostRead + - BPerBlock)); // reserve extra space for BGhostRead - DeviceMem wei_csrk_device_buf(data_sz * wei_csrk.mDesc.GetElementSpace()); - DeviceMem out_knhw_device_buf(data_sz * out_knhw.mDesc.GetElementSpace()); - - in_cnhw_device_buf.ToDevice(in_cnhw.mData.data()); - wei_csrk_device_buf.ToDevice(wei_csrk.mData.data()); - out_knhw_device_buf.ToDevice(out_knhw.mData.data()); - - for(unsigned i = 0; i < nrepeat; ++i) - { - cudaEvent_t start, stop; - float elapsedTime; - cudaEventCreate(&start); - cudaEventRecord(start, 0); - - gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2 - <<>>(in_cnhw_desc, - static_cast(in_cnhw_device_buf.GetDeviceBuffer()), - wei_csrk_desc, - static_cast(wei_csrk_device_buf.GetDeviceBuffer()), - out_knhw_desc, - static_cast(out_knhw_device_buf.GetDeviceBuffer())); - - cudaEventCreate(&stop); - cudaEventRecord(stop, 0); - cudaEventSynchronize(stop); - - cudaEventElapsedTime(&elapsedTime, start, stop); - printf("Elapsed time : %f ms\n", elapsedTime); - - usleep(std::min(elapsedTime * 1000, float(10000))); - } - - checkCudaErrors(cudaGetLastError()); - out_knhw_device_buf.FromDevice(out_knhw.mData.data()); - - // convert out_knhw to out_nkhw - auto f_reorder_knhw2nkhw = [&](auto n, auto k, auto ho, auto wo) { - out_nkhw(n, k, ho, wo) = out_knhw(k, n, ho, wo); - }; - - make_ParallelTensorFunctor(f_reorder_knhw2nkhw, N, K, Ho, Wo)( - std::thread::hardware_concurrency()); -} diff --git a/src/include/blockwise_gemm.cuh b/src/include/blockwise_gemm.cuh index 29c0cf7695..dfcd1c4c88 100644 --- a/src/include/blockwise_gemm.cuh +++ b/src/include/blockwise_gemm.cuh @@ -680,6 +680,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 Number{}, Number{}); // constexpr doesn't compile + // register FloatA p_a_thread_0[a_thread_mtx.GetElementSpace()]; FloatB p_b_thread_0[b_thread_mtx.GetElementSpace()]; @@ -687,7 +688,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()]; constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; @@ -717,7 +717,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 bool even_loop = true; #pragma unroll - for(unsigned k_begin = 0; k_begin + 1 < K; + for(unsigned k_begin = 0; k_begin + KPerThreadLoop < K; k_begin += KPerThreadLoop, even_loop = !even_loop) { // loop over k FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1; @@ -727,7 +727,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 FloatB* p_b_thread_next = even_loop ? p_b_thread_1 : p_b_thread_0; // preload next A, B - #pragma unroll for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) { // copy A-sub to form A @@ -767,8 +766,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 // last loop { - even_loop = !even_loop; - FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1; FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1; diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh index 92b549bb37..d4bfd47396 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh @@ -19,9 +19,15 @@ template {}, Number{}); // constexpr doesn't compile +#if 0 const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC{}; +#else + const auto blockwise_gemm = + BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2{}; +#endif // LDS: be careful of alignment constexpr unsigned in_block_size = @@ -237,27 +258,25 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc, { auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; -#if 1 +#if 0 blockwise_gemm.Run #else blockwise_gemm.Run_RegisterDoubleBuffer #endif - (p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), - p_in_block + s * Wi + r, - p_out_thread, - f_accum); + (p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), + p_in_block + s * Wi + r, + p_out_thread, + f_accum); } } } // output: register to global mem, - const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + const auto c_thread_mtx_begin = + blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - const unsigned k_thread_data_begin = matrix_c_index.row; - const unsigned b_thread_data_begin = matrix_c_index.col; - - const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin; - const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin; + const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; + const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; #if 0 if(get_block_1d_id() == 0) @@ -277,8 +296,11 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc, { for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) { - unsigned k_data = k_data_begin + k; - unsigned b_data = b_data_begin + b; + const auto c_thread_mtx_distance = + blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); + + unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row; + unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col; unsigned n_data = b_data / (Hi * Wi); unsigned itmp = b_data - n_data * (Hi * Wi); diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer.cuh similarity index 84% rename from src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh rename to src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer.cuh index 3574381026..7140de8c18 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer.cuh @@ -34,13 +34,13 @@ template -__global__ void -gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InGlobalDesc, - Float* const __restrict__ p_in_global, - WeiGlobalDesc, - Float* const __restrict__ p_wei_global, - OutGlobalDesc, - Float* __restrict__ p_out_global) +__global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer( + InGlobalDesc, + Float* const __restrict__ p_in_global, + WeiGlobalDesc, + Float* const __restrict__ p_wei_global, + OutGlobalDesc, + Float* __restrict__ p_out_global) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -123,7 +123,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InGlobalDesc, decltype(in_cb_global_desc), decltype(in_cb_block_desc), decltype(in_cb_block_desc.GetLengths())>{}; -#elif 0 +#elif 1 const auto blockwise_in_copy = Blockwise2dTensorCopy2{}; -#elif 0 +#elif 1 const auto blockwise_wei_copy = Blockwise2dTensorCopy2 -__global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline( - InGlobalDesc, - Float* const __restrict__ p_in_global, - WeiGlobalDesc, - Float* const __restrict__ p_wei_global, - OutGlobalDesc, - Float* __restrict__ p_out_global) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_cnhw_global_desc = InGlobalDesc{}; - constexpr auto wei_csrk_global_desc = WeiGlobalDesc{}; - constexpr auto out_knhw_global_desc = OutGlobalDesc{}; - - constexpr unsigned C = in_cnhw_global_desc.GetLength(I0); - constexpr unsigned N = in_cnhw_global_desc.GetLength(I1); - constexpr unsigned Hi = in_cnhw_global_desc.GetLength(I2); - constexpr unsigned Wi = in_cnhw_global_desc.GetLength(I3); - - constexpr unsigned K = out_knhw_global_desc.GetLength(I0); - constexpr unsigned Ho = out_knhw_global_desc.GetLength(I2); - constexpr unsigned Wo = out_knhw_global_desc.GetLength(I3); - - constexpr unsigned S = wei_csrk_global_desc.GetLength(I1); - constexpr unsigned R = wei_csrk_global_desc.GetLength(I2); - - constexpr unsigned B = N * Hi * Wi; - constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1); - - // divide block work by 2d: [K, B] - constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock; - - const unsigned k_block_work_id = get_block_1d_id() / BBlockWork; - const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; - - const unsigned k_block_data_begin = k_block_work_id * KPerBlock; - const unsigned b_block_data_begin = b_block_work_id * BPerBlock; - -#if 0 - if(get_thread_local_1d_id() == 0) - { - printf("K %u B %u, BGhostRead %u\n", K, B, BGhostRead); - - printf("%u %u, KBlockWork %u BBlockWork %u, k_block_data_begin %u b_block_data_begin %u\n", - get_block_1d_id(), - get_thread_local_1d_id(), - KBlockWork, - BBlockWork, - k_block_data_begin, - b_block_data_begin); - } -#endif - - // flattend (2d) tensor view of gridwise input - constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence{}); - - constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of blockwise input and weight - constexpr auto in_cb_block_desc = - make_ConstantTensorDescriptor(Sequence{}); - - constexpr auto wei_ek_block_desc = - make_ConstantTensorDescriptor(Sequence{}); - - constexpr auto wei_csrk_block_desc = - make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of threadwise output in register - constexpr auto out_kb_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc"); - print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc"); - print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc"); - - printf("KPerBlock %u\n", KPerBlock); - } -#endif - - // blockwise in copy - // formmat is [CPerBlock,BPerBlock + BGhostRead] -#if 0 - const auto blockwise_in_copy = - Blockwise2dTensorCopy1{}; -#elif 0 - const auto blockwise_in_copy = Blockwise2dTensorCopy2{}; -#elif 1 - const auto blockwise_in_copy = Blockwise2dTensorCopy3{}; -#endif - - // blockwise wei copy - // format is [CPerBlock*S*R,KPerBlock] -#if 0 - const auto blockwise_wei_copy = - Blockwise2dTensorCopy1{}; -#elif 0 - const auto blockwise_wei_copy = Blockwise2dTensorCopy2{}; -#elif 1 - const auto blockwise_wei_copy = Blockwise2dTensorCopy3{}; -#endif - - // a series of blockwise GEMM - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx and b_mtx saved in LDS, c_mtx saved in register - // a_mtx[C,K] is a sub-matrix of wei_block[S,R,C,K] - // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead] - // c_mtx[K,B] is out_block[K,B] - const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}); // constexpr doesn't compile - - const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, - Number{}, - Number{}); // constexpr doesn't compile - - const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}); // constexpr doesn't compile - - const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC{}; - - // LDS - constexpr unsigned in_block_size = in_cb_block_desc.GetElementSpace(); - constexpr unsigned wei_block_size = wei_csrk_block_desc.GetElementSpace(); - - // LDS double buffer - __shared__ Float p_in_block_0[in_block_size]; - __shared__ Float p_wei_block_0[wei_block_size]; - - __shared__ Float p_in_block_1[in_block_size]; - __shared__ Float p_wei_block_1[wei_block_size]; - - // register - Float p_out_thread[out_kb_thread_desc.GetElementSpace()]; - - Float* p_in_global_block_offset = - p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin); - - Float* p_wei_global_block_offset = - p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); - - // prelog : preload data - // input: global mem to LDS, - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_0); - - // weight: global mem to LDS, - blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_0); - - p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0); - - p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I2); - - // set threadwise output tensor to 0 - threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread); - - bool even_loop = true; - - for(unsigned c_block_data_begin = CPerBlock; c_block_data_begin < C; - c_block_data_begin += CPerBlock, - p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), - p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I2), - even_loop = !even_loop) - { - __syncthreads(); - - Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1; - Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1; - - Float* p_in_block_next = even_loop ? p_in_block_1 : p_in_block_0; - Float* p_wei_block_next = even_loop ? p_wei_block_1 : p_wei_block_0; - - // input: global mem to LDS, - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next); - - // weight: global mem to LDS, - blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next); - - // a series of GEMM - for(unsigned s = 0; s < S; ++s) - { - for(unsigned r = 0; r < R; ++r) - { - auto f_accum = [](auto& c, const auto&& ab) { c += ab; }; - - blockwise_gemm.Run(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), - p_in_block_now + s * Wi + r, - p_out_thread, - f_accum); - } - } - } - - // last computation - { - __syncthreads(); - - Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1; - Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1; - - // a series of GEMM - for(unsigned s = 0; s < S; ++s) - { - for(unsigned r = 0; r < R; ++r) - { - auto f_accum = [](auto& c, const auto&& ab) { c += ab; }; - - blockwise_gemm.Run(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(s, r, 0, 0), - p_in_block_now + s * Wi + r, - p_out_thread, - f_accum); - } - } - } - - // output: register to global mem, - const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const unsigned k_thread_data_begin = matrix_c_index.row; - const unsigned b_thread_data_begin = matrix_c_index.col; - - const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin; - const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin; - -#if 0 - if(get_block_1d_id() == 0) - { - printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", - get_block_1d_id(), - get_thread_local_1d_id(), - matrix_c_index.row, - matrix_c_index.col, - k_data_begin, - b_data_begin, - p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]); - } -#endif - - for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) - { - for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) - { - unsigned k_data = k_data_begin + k; - unsigned b_data = b_data_begin + b; - - unsigned n_data = b_data / (Hi * Wi); - unsigned itmp = b_data - n_data * (Hi * Wi); - unsigned h_data = itmp / Wi; - unsigned w_data = itmp - h_data * Wi; - -#if 0 - if(get_block_1d_id() == 0) - { - printf("%u %u, k %u b %u, k_data %u n_data %u h_data %u w_data %u %f\n", - get_block_1d_id(), - get_thread_local_1d_id(), - k, - b, - k_data, - n_data, - h_data, - w_data, - p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]); - } -#endif - if(n_data < N && h_data < Ho && w_data < Wo) - { -#if 1 - p_out_global[out_knhw_global_desc.Get1dIndex(k_data, n_data, h_data, w_data)] = - p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; -#endif - } - } - } -}