From c9af4dece0632d688c45e31581e2d389491a8f6b Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 24 Jan 2019 14:28:46 -0600 Subject: [PATCH] implicit gemm: LDS double buffer --- driver/conv.cu | 8 +- driver/device_direct_convolution_1.cuh | 77 +++-- driver/device_direct_convolution_2.cuh | 77 +++-- ...icit_gemm_convolution_2_cnhw_srck_knhw.cuh | 1 + src/include/blockwise_2d_tensor_op.cuh | 47 +++ src/include/blockwise_4d_tensor_op.cuh | 43 ++- src/include/blockwise_direct_convolution.cuh | 4 +- src/include/gridwise_direct_convolution_1.cuh | 55 +-- src/include/gridwise_direct_convolution_2.cuh | 39 ++- ...icit_gemm_convolution_2_cnhw_srck_knhw.cuh | 37 +- ...icit_gemm_convolution_3_cnhw_srck_knhw.cuh | 323 ++++++++++++++++++ 11 files changed, 581 insertions(+), 130 deletions(-) create mode 100644 src/include/gridwise_implicit_gemm_convolution_3_cnhw_srck_knhw.cuh diff --git a/driver/conv.cu b/driver/conv.cu index b09e595e07..d669f5b12c 100644 --- a/driver/conv.cu +++ b/driver/conv.cu @@ -8,8 +8,8 @@ #include "conv_common.cuh" #include "device_direct_convolution_1.cuh" #include "device_direct_convolution_2.cuh" -#include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh" -#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh" +//#include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh" +//#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh" #include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh" //#include "device_winograd_convolution.cuh" @@ -410,7 +410,7 @@ int main() wei_kcsr.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); #endif - unsigned nrepeat = 100; + unsigned nrepeat = 50; #if 0 device_direct_convolution_1 @@ -427,7 +427,7 @@ int main() #endif (in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat); -#if 0 +#if 1 host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host); check_error(out_nkhw_host, out_nkhw_device); #elif 0 diff --git a/driver/device_direct_convolution_1.cuh b/driver/device_direct_convolution_1.cuh index 68dc000173..1029026e67 100644 --- a/driver/device_direct_convolution_1.cuh +++ b/driver/device_direct_convolution_1.cuh @@ -1,9 +1,15 @@ #pragma once #include "gridwise_direct_convolution_1.cuh" +#include template -void device_direct_convolution_1( - InDesc, const Tensor& in, WeiDesc, const Tensor& wei, OutDesc, Tensor& out) +void device_direct_convolution_1(InDesc, + const Tensor& in, + WeiDesc, + const Tensor& wei, + OutDesc, + Tensor& out, + unsigned nrepeat) { std::size_t data_sz = sizeof(T); DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace()); @@ -65,41 +71,46 @@ void device_direct_convolution_1( printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - cudaEvent_t start, stop; - float elapsedTime; + for(unsigned i = 0; i < nrepeat; ++i) + { + cudaEvent_t start, stop; + float elapsedTime; - cudaEventCreate(&start); - cudaEventRecord(start, 0); + cudaEventCreate(&start); + cudaEventRecord(start, 0); - gridwise_direct_convolution_1 - <<>>(InDesc{}, - static_cast(in_device_buf.GetDeviceBuffer()), - WeiDesc{}, - static_cast(wei_device_buf.GetDeviceBuffer()), - OutDesc{}, - static_cast(out_device_buf.GetDeviceBuffer())); + gridwise_direct_convolution_1 + <<>>(InDesc{}, + static_cast(in_device_buf.GetDeviceBuffer()), + WeiDesc{}, + static_cast(wei_device_buf.GetDeviceBuffer()), + OutDesc{}, + static_cast(out_device_buf.GetDeviceBuffer())); - cudaEventCreate(&stop); - cudaEventRecord(stop, 0); - cudaEventSynchronize(stop); + cudaEventCreate(&stop); + cudaEventRecord(stop, 0); + cudaEventSynchronize(stop); - cudaEventElapsedTime(&elapsedTime, start, stop); - printf("Elapsed time : %f ms\n", elapsedTime); + cudaEventElapsedTime(&elapsedTime, start, stop); + printf("Elapsed time : %f ms\n", elapsedTime); + + usleep(10000); + } checkCudaErrors(cudaGetLastError()); out_device_buf.FromDevice(out.mData.data()); diff --git a/driver/device_direct_convolution_2.cuh b/driver/device_direct_convolution_2.cuh index 9e3d8b2d92..95b8fedf1f 100644 --- a/driver/device_direct_convolution_2.cuh +++ b/driver/device_direct_convolution_2.cuh @@ -1,9 +1,15 @@ #pragma once #include "gridwise_direct_convolution_2.cuh" +#include template -void device_direct_convolution_2( - InDesc, const Tensor& in, WeiDesc, const Tensor& wei, OutDesc, Tensor& out) +void device_direct_convolution_2(InDesc, + const Tensor& in, + WeiDesc, + const Tensor& wei, + OutDesc, + Tensor& out, + unsigned nrepeat) { std::size_t data_sz = sizeof(T); DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace()); @@ -79,41 +85,46 @@ void device_direct_convolution_2( printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - cudaEvent_t start, stop; - float elapsedTime; + for(unsigned i = 0; i < nrepeat; ++i) + { + cudaEvent_t start, stop; + float elapsedTime; - cudaEventCreate(&start); - cudaEventRecord(start, 0); + cudaEventCreate(&start); + cudaEventRecord(start, 0); - gridwise_direct_convolution_2 - <<>>(InDesc{}, - static_cast(in_device_buf.GetDeviceBuffer()), - WeiDesc{}, - static_cast(wei_device_buf.GetDeviceBuffer()), - OutDesc{}, - static_cast(out_device_buf.GetDeviceBuffer())); + gridwise_direct_convolution_2 + <<>>(InDesc{}, + static_cast(in_device_buf.GetDeviceBuffer()), + WeiDesc{}, + static_cast(wei_device_buf.GetDeviceBuffer()), + OutDesc{}, + static_cast(out_device_buf.GetDeviceBuffer())); - cudaEventCreate(&stop); - cudaEventRecord(stop, 0); - cudaEventSynchronize(stop); + cudaEventCreate(&stop); + cudaEventRecord(stop, 0); + cudaEventSynchronize(stop); - cudaEventElapsedTime(&elapsedTime, start, stop); - printf("Elapsed time : %f ms\n", elapsedTime); + cudaEventElapsedTime(&elapsedTime, start, stop); + printf("Elapsed time : %f ms\n", elapsedTime); + + usleep(10000); + } checkCudaErrors(cudaGetLastError()); out_device_buf.FromDevice(out.mData.data()); diff --git a/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh b/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh index dd7006d33c..892eff2083 100644 --- a/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh +++ b/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh @@ -1,5 +1,6 @@ #pragma once #include "gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh" +#include "gridwise_implicit_gemm_convolution_3_cnhw_srck_knhw.cuh" #include template diff --git a/src/include/blockwise_2d_tensor_op.cuh b/src/include/blockwise_2d_tensor_op.cuh index 5ce3fad347..5d347de8ef 100644 --- a/src/include/blockwise_2d_tensor_op.cuh +++ b/src/include/blockwise_2d_tensor_op.cuh @@ -347,3 +347,50 @@ struct blockwise_2d_tensor_copy_2 } } }; + +template +struct blockwise_2d_tensor_copy_dummy_1 +{ + unsigned mBegin; + + __device__ blockwise_2d_tensor_copy_dummy_1() + { + constexpr unsigned n_total = + make_ConstantTensorDescriptor(SrcOpLengths{}).GetElementSpace(); + + constexpr unsigned n_per_thread = n_total / BlockSize; + + mBegin = n_per_thread * get_thread_local_1d_id(); + } + + __device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const + { + constexpr unsigned n_total = + make_ConstantTensorDescriptor(SrcOpLengths{}).GetElementSpace(); + + constexpr unsigned n_per_thread = n_total / BlockSize; + + for(unsigned i = 0; i < n_per_thread; ++i) + { + p_dst[mBegin + i] = p_src[mBegin + i]; + } + } +}; + +template +struct blockwise_2d_tensor_copy_dummy_2 +{ + __device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const + { + constexpr unsigned n_total = + make_ConstantTensorDescriptor(SrcOpLengths{}).GetElementSpace(); + + constexpr unsigned n_per_thread = n_total / BlockSize; + + for(unsigned i = 0; i < n_per_thread; ++i) + { + unsigned index = get_thread_local_1d_id() + BlockSize * i; + p_dst[index] = p_src[index]; + } + } +}; diff --git a/src/include/blockwise_4d_tensor_op.cuh b/src/include/blockwise_4d_tensor_op.cuh index 13e2093333..4619140b05 100644 --- a/src/include/blockwise_4d_tensor_op.cuh +++ b/src/include/blockwise_4d_tensor_op.cuh @@ -200,11 +200,42 @@ blockwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc, } template -__device__ void blockwise_4d_tensor_copy( - SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths) +struct blockwise_4d_tensor_copy_1 { - constexpr auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{}; + __device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const + { + constexpr auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{}; - blockwise_4d_tensor_copy_reorder_by_get_dst_from_src( - SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder); -} + blockwise_4d_tensor_copy_reorder_by_get_dst_from_src( + SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder); + } +}; + +template +struct blockwise_4d_tensor_copy_dummy +{ + unsigned mBegin; + + __device__ blockwise_4d_tensor_copy_dummy() + { + constexpr unsigned n_total = + make_ConstantTensorDescriptor(SrcOpLengths{}).GetElementSpace(); + + constexpr unsigned n_per_thread = n_total / BlockSize; + + mBegin = n_per_thread * get_thread_local_1d_id(); + } + + __device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const + { + constexpr unsigned n_total = + make_ConstantTensorDescriptor(SrcOpLengths{}).GetElementSpace(); + + constexpr unsigned n_per_thread = n_total / BlockSize; + + for(unsigned i = 0; i < n_per_thread; ++i) + { + p_dst[mBegin + i] = p_src[mBegin + i]; + } + } +}; diff --git a/src/include/blockwise_direct_convolution.cuh b/src/include/blockwise_direct_convolution.cuh index be2e58c350..b60c458af3 100644 --- a/src/include/blockwise_direct_convolution.cuh +++ b/src/include/blockwise_direct_convolution.cuh @@ -101,7 +101,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc, wo_thread_data_begin), out_thread_desc, p_out_thread, - out_thread_desc); + out_thread_desc.GetLengths()); for(unsigned c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1); c_thread_data_begin += CPerThread) @@ -128,6 +128,6 @@ __device__ void blockwise_direct_convolution(InBlockDesc, k_thread_data_begin, ho_thread_data_begin, wo_thread_data_begin), - out_thread_desc); + out_thread_desc.GetLengths()); } } diff --git a/src/include/gridwise_direct_convolution_1.cuh b/src/include/gridwise_direct_convolution_1.cuh index d4358ac3c6..e8a99536bb 100644 --- a/src/include/gridwise_direct_convolution_1.cuh +++ b/src/include/gridwise_direct_convolution_1.cuh @@ -121,6 +121,27 @@ __global__ void gridwise_direct_convolution_1(InGlobalDesc, } #endif + constexpr auto blockwise_in_copy = + blockwise_4d_tensor_copy_1{}; + + constexpr auto blockwise_wei_copy = + blockwise_4d_tensor_copy_1{}; + + constexpr auto blockwise_out_copy = + blockwise_4d_tensor_copy_1{}; + // set output tensor in LDS to 0 blockwise_4d_tensor_set_zero(out_block_desc, p_out_block); @@ -128,23 +149,16 @@ __global__ void gridwise_direct_convolution_1(InGlobalDesc, c_block_work_begin += CPerBlock) { // copy input tensor to LDS - blockwise_4d_tensor_copy(in_block_global_desc, - p_in_global + - in_global_desc.Get1dIndex(n_block_work_begin, - c_block_work_begin, - hi_block_work_begin, - wi_block_work_begin), - in_block_desc, - p_in_block, - in_block_desc); + blockwise_in_copy.run(p_in_global + in_global_desc.Get1dIndex(n_block_work_begin, + c_block_work_begin, + hi_block_work_begin, + wi_block_work_begin), + p_in_block); // copy weight tensor to LDS - blockwise_4d_tensor_copy( - wei_block_global_desc, + blockwise_wei_copy.run( p_wei_global + wei_global_desc.Get1dIndex(k_block_work_begin, c_block_work_begin, 0, 0), - wei_block_desc, - p_wei_block, - wei_block_desc); + p_wei_block); __syncthreads(); @@ -165,12 +179,9 @@ __global__ void gridwise_direct_convolution_1(InGlobalDesc, } // copy output tensor from LDS to device mem - blockwise_4d_tensor_copy( - out_block_desc, - p_out_block, - out_block_global_desc, - p_out_global + - out_global_desc.Get1dIndex( - n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin), - out_block_desc); + blockwise_out_copy.run(p_out_block, + p_out_global + out_global_desc.Get1dIndex(n_block_work_begin, + k_block_work_begin, + ho_block_work_begin, + wo_block_work_begin)); } diff --git a/src/include/gridwise_direct_convolution_2.cuh b/src/include/gridwise_direct_convolution_2.cuh index 744496f6cf..aeafbacca3 100644 --- a/src/include/gridwise_direct_convolution_2.cuh +++ b/src/include/gridwise_direct_convolution_2.cuh @@ -144,6 +144,20 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc, wo_thread_data_begin); #endif + constexpr auto blockwise_in_copy = + blockwise_4d_tensor_copy_1{}; + + constexpr auto blockwise_wei_copy = + blockwise_4d_tensor_copy_1{}; + // set threadwise output tensor to 0 threadwise_4d_tensor_set_zero(out_thread_desc, p_out_thread); @@ -151,30 +165,23 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc, c_block_data_begin += CPerBlock, __syncthreads()) { // copy input tensor to LDS - blockwise_4d_tensor_copy(in_global_desc, - p_in_global + - in_global_desc.Get1dIndex(n_block_data_begin, - c_block_data_begin, - hi_block_data_begin, - wi_block_data_begin), - in_block_desc, - p_in_block, - in_block_desc.GetLengths()); + blockwise_in_copy.run(p_in_global + in_global_desc.Get1dIndex(n_block_data_begin, + c_block_data_begin, + hi_block_data_begin, + wi_block_data_begin), + p_in_block); // copy weight tensor to LDS - blockwise_4d_tensor_copy( - wei_global_desc, + blockwise_wei_copy.run( p_wei_global + wei_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0), - wei_block_desc, - p_wei_block, - wei_block_desc.GetLengths()); + p_wei_block); __syncthreads(); for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread) { // threadwise convolution -#if 0 +#if 1 threadwise_direct_convolution_2( in_thread_block_desc, p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin, @@ -185,7 +192,7 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc, p_wei_block + wei_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), out_thread_desc, p_out_thread); -#elif 1 +#elif 0 threadwise_direct_convolution_3( in_thread_block_desc, p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin, diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh index 0fc4381a9b..43d47c2ceb 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh @@ -106,17 +106,17 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, } #endif + // blockwise in copy + // formmat is [CPerBlock,BPerBlock + BGhostRead] #if 1 - // blockwise 2d copy - const auto blockwise_2d_copy = + const auto blockwise_in_copy = blockwise_2d_tensor_copy_1{}; #elif 0 - // blockwise 2d copy - const auto blockwise_2d_copy = + const auto blockwise_in_copy = blockwise_2d_tensor_copy_2{}; #endif + // blockwise wei copy + // format is [S,R,CPerBlock,KPerBlock] +#if 1 + const auto blockwise_wei_copy = + blockwise_4d_tensor_copy_1{}; +#endif + // a series of blockwise GEMM // c_mtx += transpose(a_mtx) * b_mtx // a_mtx and b_mtx saved in LDS, c_mtx saved in register @@ -172,21 +183,19 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, __syncthreads()) { +#if 1 // input: global mem to LDS, - // formmat is [CPerBlock,BPerBlock + BGhostRead] - blockwise_2d_copy.run( + blockwise_in_copy.run( p_in_global + in_cb_global_desc.Get1dIndex(c_block_data_begin, b_block_data_begin), p_in_block); +#endif +#if 1 // weight: global mem to LDS, - // format is [S,R,CPerBlock,KPerBlock] - blockwise_4d_tensor_copy( - wei_srck_global_desc, - p_wei_global + - wei_srck_global_desc.Get1dIndex(0, 0, c_block_data_begin, k_block_data_begin), - wei_srck_block_desc, - p_wei_block, - wei_srck_block_desc.GetLengths()); + blockwise_wei_copy.run(p_wei_global + wei_srck_global_desc.Get1dIndex( + 0, 0, c_block_data_begin, k_block_data_begin), + p_wei_block); +#endif __syncthreads(); diff --git a/src/include/gridwise_implicit_gemm_convolution_3_cnhw_srck_knhw.cuh b/src/include/gridwise_implicit_gemm_convolution_3_cnhw_srck_knhw.cuh new file mode 100644 index 0000000000..d0d5797790 --- /dev/null +++ b/src/include/gridwise_implicit_gemm_convolution_3_cnhw_srck_knhw.cuh @@ -0,0 +1,323 @@ +#pragma once +#include "common.cuh" +#include "ConstantTensorDescriptor.cuh" +#include "ConstantMatrixDescriptor.cuh" +#include "blockwise_4d_tensor_op.cuh" +#include "blockwise_2d_tensor_op.cuh" +#include "threadwise_2d_tensor_op.cuh" +#include "gemm.cuh" + +// define B = N*Hi*Wi +template +__global__ void +gridwise_implicit_gemm_convolution_3_cnhw_srck_knhw(InGlobalDesc, + Float* const __restrict__ p_in_global, + WeiGlobalDesc, + Float* const __restrict__ p_wei_global, + OutGlobalDesc, + Float* __restrict__ p_out_global) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_cnhw_global_desc = InGlobalDesc{}; + constexpr auto wei_srck_global_desc = WeiGlobalDesc{}; + constexpr auto out_knhw_global_desc = OutGlobalDesc{}; + + constexpr unsigned C = in_cnhw_global_desc.GetLength(I0); + constexpr unsigned N = in_cnhw_global_desc.GetLength(I1); + constexpr unsigned Hi = in_cnhw_global_desc.GetLength(I2); + constexpr unsigned Wi = in_cnhw_global_desc.GetLength(I3); + + constexpr unsigned K = out_knhw_global_desc.GetLength(I0); + constexpr unsigned Ho = out_knhw_global_desc.GetLength(I2); + constexpr unsigned Wo = out_knhw_global_desc.GetLength(I3); + + constexpr unsigned S = wei_srck_global_desc.GetLength(I0); + constexpr unsigned R = wei_srck_global_desc.GetLength(I1); + + constexpr unsigned B = N * Hi * Wi; + constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1); + + // divide block work by 2d: [K, B] + constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; + constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock; + + const unsigned k_block_work_id = get_block_1d_id() / BBlockWork; + const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; + + const unsigned k_block_data_begin = k_block_work_id * KPerBlock; + const unsigned b_block_data_begin = b_block_work_id * BPerBlock; + +#if 0 + if(get_thread_local_1d_id() == 0) + { + printf("K %u B %u, BGhostRead %u\n", K, B, BGhostRead); + + printf("%u %u, KBlockWork %u BBlockWork %u, k_block_data_begin %u b_block_data_begin %u\n", + get_block_1d_id(), + get_thread_local_1d_id(), + KBlockWork, + BBlockWork, + k_block_data_begin, + b_block_data_begin); + } +#endif + + // flattend (2d) tensor view of gridwise input + constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence{}); + + // tensor view of blockwise input and weight + constexpr auto in_cb_block_desc = + make_ConstantTensorDescriptor(Sequence{}); + + constexpr auto wei_srck_block_desc = + make_ConstantTensorDescriptor(Sequence{}); + + // tensor view of threadwise output in register + constexpr auto out_kb_thread_desc = + make_ConstantTensorDescriptor(Sequence{}); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc"); + print_ConstantTensorDescriptor(wei_srck_block_desc, "wei_srck_block_desc"); + print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc"); + + printf("KPerBlock %u\n", KPerBlock); + } +#endif + + // in: global mem to LDS + // formmat is [CPerBlock,BPerBlock + BGhostRead] +#if 1 + const auto blockwise_in_copy = + blockwise_2d_tensor_copy_1{}; +#elif 0 + const auto blockwise_in_copy = + blockwise_2d_tensor_copy_2{}; +#elif 0 + const auto blockwise_in_copy = + blockwise_2d_tensor_copy_dummy_2{}; +#endif + + // weight: global mem to LDS, + // format is [S,R,CPerBlock,KPerBlock] +#if 1 + const auto blockwise_wei_copy = + blockwise_4d_tensor_copy_1{}; +#else + const auto blockwise_wei_copy = + blockwise_4d_tensor_copy_dummy{}; +#endif + + // a series of blockwise GEMM + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx and b_mtx saved in LDS, c_mtx saved in register + // a_mtx[C,K] is a sub-matrix of wei_block[S,R,C,K] + // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead] + // c_mtx[K,B] is out_block[K,B] + const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}); // constexpr doesn't compile + + const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, + Number{}, + Number{}); // constexpr doesn't compile + + const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}); // constexpr doesn't compile + + const auto blockwise_gemm = + blockwise_gemm_block_a_block_b_thread_c{}; + + // LDS + constexpr unsigned in_block_size = in_cb_block_desc.GetElementSpace(); + constexpr unsigned wei_block_size = wei_srck_block_desc.GetElementSpace(); + + // double buffer + __shared__ Float p_in_block_0[in_block_size]; + __shared__ Float p_wei_block_0[wei_block_size]; + + __shared__ Float p_in_block_1[in_block_size]; + __shared__ Float p_wei_block_1[wei_block_size]; + + // register + Float p_out_thread[out_kb_thread_desc.GetElementSpace()]; + + // set threadwise output tensor to 0 + threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread); + + // prelog: load data +#if 1 + // input: global mem to LDS, + blockwise_in_copy.run(p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin), + p_in_block_0); +#endif + +#if 1 + // weight: global mem to LDS, + blockwise_wei_copy.run( + p_wei_global + wei_srck_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin), p_wei_block_0); +#endif + + unsigned cloop = 0; + + for(unsigned c_block_data_begin = 0; c_block_data_begin + CPerBlock < C; + c_block_data_begin += CPerBlock, ++cloop) + { + __syncthreads(); + + Float* p_in_block_now = (cloop % 2 == 0) ? p_in_block_0 : p_in_block_1; + Float* p_wei_block_now = (cloop % 2 == 0) ? p_wei_block_0 : p_wei_block_1; + + Float* p_in_block_next = (cloop % 2 == 0) ? p_in_block_1 : p_in_block_0; + Float* p_wei_block_next = (cloop % 2 == 0) ? p_wei_block_1 : p_wei_block_0; + +#if 1 + // preload next data + // input: global mem to LDS, + blockwise_in_copy.run(p_in_global + in_cb_global_desc.Get1dIndex( + c_block_data_begin + CPerBlock, b_block_data_begin), + p_in_block_next); +#endif + +#if 1 + // weight: global mem to LDS, + blockwise_wei_copy.run(p_wei_global + + wei_srck_global_desc.Get1dIndex( + 0, 0, c_block_data_begin + CPerBlock, k_block_data_begin), + p_wei_block_next); +#endif + + // a series of GEMM + for(unsigned s = 0; s < S; ++s) + { + for(unsigned r = 0; r < R; ++r) + { + auto f_accum = [](auto& c, const auto&& ab) { c += ab; }; + + blockwise_gemm.run(p_wei_block_now + wei_srck_block_desc.Get1dIndex(s, r, 0, 0), + p_in_block_now + s * Wi + r, + p_out_thread, + f_accum); + } + } + } + + { + // last cloop + __syncthreads(); + + Float* p_in_block_now = (cloop % 2 == 0) ? p_in_block_0 : p_in_block_1; + Float* p_wei_block_now = (cloop % 2 == 0) ? p_wei_block_0 : p_wei_block_1; + + // a series of GEMM + for(unsigned s = 0; s < S; ++s) + { + for(unsigned r = 0; r < R; ++r) + { + auto f_accum = [](auto& c, const auto&& ab) { c += ab; }; + + blockwise_gemm.run(p_wei_block_now + wei_srck_block_desc.Get1dIndex(s, r, 0, 0), + p_in_block_now + s * Wi + r, + p_out_thread, + f_accum); + } + } + } + + // output: register to global mem, + const auto matrix_c_index = + blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); + + const unsigned k_thread_data_begin = matrix_c_index.row_begin; + const unsigned b_thread_data_begin = matrix_c_index.col_begin; + + const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin; + const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin; + +#if 0 + if(get_block_1d_id() == 0) + { + printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", + get_block_1d_id(), + get_thread_local_1d_id(), + matrix_c_index.row_begin, + matrix_c_index.col_begin, + k_data_begin, + b_data_begin, + p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]); + } +#endif + + for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) + { + for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) + { + unsigned k_data = k_data_begin + k; + unsigned b_data = b_data_begin + b; + + unsigned n_data = b_data / (Hi * Wi); + unsigned itmp = b_data - n_data * (Hi * Wi); + unsigned h_data = itmp / Wi; + unsigned w_data = itmp - h_data * Wi; + + if(n_data < N && h_data < Ho && w_data < Wo) + { + p_out_global[out_knhw_global_desc.Get1dIndex(k_data, n_data, h_data, w_data)] = + p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; + } + } + } +}