From 9f2e8f8bb48897c3669b5de1855cea67b07abea8 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 15 Feb 2019 22:51:51 -0600 Subject: [PATCH] 2-type implicit gemm using chwn --- ...icit_gemm_convolution_2_chwn_csrk_khwn.hpp | 263 ++++++++++++ driver/driver.hip.cpp | 7 +- ...2_chwn_csrk_khwn_lds_double_buffer.hip.hpp | 378 ++++++++++++++++++ 3 files changed, 646 insertions(+), 2 deletions(-) create mode 100644 driver/device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp create mode 100644 src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp diff --git a/driver/device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp new file mode 100644 index 0000000000..1736a5e874 --- /dev/null +++ b/driver/device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp @@ -0,0 +1,263 @@ +#pragma once +#include +#include "device.hpp" +#include "gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp" + +template +void device_implicit_gemm_convolution_2_chwn_csrk_khwn(InDesc, + const Tensor& in_nchw, + WeiDesc, + const Tensor& wei_kcsr, + OutDesc, + Tensor& out_nkhw, + unsigned nrepeat) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_nchw_desc = InDesc{}; + constexpr auto wei_kcsr_desc = WeiDesc{}; + constexpr auto out_nkhw_desc = OutDesc{}; + + constexpr unsigned N = in_nchw_desc.GetLength(I0); + constexpr unsigned Hi = in_nchw_desc.GetLength(I2); + constexpr unsigned Wi = in_nchw_desc.GetLength(I3); + + constexpr unsigned Ho = out_nkhw_desc.GetLength(I2); + constexpr unsigned Wo = out_nkhw_desc.GetLength(I3); + + constexpr unsigned K = wei_kcsr_desc.GetLength(I0); + constexpr unsigned C = wei_kcsr_desc.GetLength(I1); + constexpr unsigned S = wei_kcsr_desc.GetLength(I2); + constexpr unsigned R = wei_kcsr_desc.GetLength(I3); + + constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1); + + // convert in_nchw to in_cnhw + auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: "); + + Tensor in_chwn(make_TensorDescriptor(in_chwn_desc)); + + make_ParallelTensorFunctor( + [&](auto n, auto c, auto hi, auto wi) { in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi); }, + N, + C, + Hi, + Wi)(std::thread::hardware_concurrency()); + + // convert wei_kcsr to wei_csrk + auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: "); + + Tensor wei_csrk(make_TensorDescriptor(wei_csrk_desc)); + + make_ParallelTensorFunctor( + [&](auto k, auto c, auto s, auto r) { wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r); }, + K, + C, + S, + R)(std::thread::hardware_concurrency()); + + // conver out_nkhw to out_knhw + auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: "); + + Tensor out_khwn(make_TensorDescriptor(out_khwn_desc)); + +#if 0 + // 3x3, 34x34 + // need to use register double buffer for GEMM + constexpr unsigned BPerBlock = 128; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 4; + + constexpr unsigned BPerThread = 8; + constexpr unsigned KPerThread = 8; + + constexpr unsigned GemmMPerThreadSubC = 4; + constexpr unsigned GemmNPerThreadSubC = 4; + constexpr unsigned GemmMLevel0Cluster = 4; + constexpr unsigned GemmNLevel0Cluster = 2; + constexpr unsigned GemmMLevel1Cluster = 2; + constexpr unsigned GemmNLevel1Cluster = 8; + constexpr unsigned GemmKPerThreadLoop = 1; + + constexpr unsigned GemmThreadPerColumnPerCluster = 8; + constexpr unsigned GemmThreadPerRowPerCluster = 8; + + constexpr unsigned InBlockCopyThreadPerDim0 = 4; + constexpr unsigned InBlockCopyThreadPerDim1 = 16; + + constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 16; + + constexpr unsigned InBlockCopyDataPerRead = 4; + constexpr unsigned WeiBlockCopyDataPerRead = 4; + + constexpr unsigned BlockSize = 128; +#elif 0 + // 1x1, 28x28, 64 threads + constexpr unsigned BPerBlock = 64; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 8; + + constexpr unsigned BPerThread = 8; + constexpr unsigned KPerThread = 8; + + constexpr unsigned GemmMPerThreadSubC = 4; + constexpr unsigned GemmNPerThreadSubC = 4; + constexpr unsigned GemmMLevel0Cluster = 4; + constexpr unsigned GemmNLevel0Cluster = 2; + constexpr unsigned GemmMLevel1Cluster = 2; + constexpr unsigned GemmNLevel1Cluster = 4; + constexpr unsigned GemmKPerThreadLoop = 1; + + constexpr unsigned GemmThreadPerColumnPerCluster = 8; + constexpr unsigned GemmThreadPerRowPerCluster = 8; + + constexpr unsigned InBlockCopyThreadPerDim0 = 4; + constexpr unsigned InBlockCopyThreadPerDim1 = 16; + + constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 16; + + constexpr unsigned InBlockCopyDataPerRead = 4; + constexpr unsigned WeiBlockCopyDataPerRead = 4; + + constexpr unsigned BlockSize = 64; +#elif 1 + // 1x1, 28x28, 128 threads, no lds-double-buffer + // 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128 + constexpr unsigned BPerBlock = 64; + constexpr unsigned KPerBlock = 128; + constexpr unsigned CPerBlock = 8; + + constexpr unsigned BPerThread = 8; + constexpr unsigned KPerThread = 8; + + constexpr unsigned GemmMPerThreadSubC = 4; + constexpr unsigned GemmNPerThreadSubC = 4; + constexpr unsigned GemmMLevel0Cluster = 4; + constexpr unsigned GemmNLevel0Cluster = 2; + constexpr unsigned GemmMLevel1Cluster = 4; + constexpr unsigned GemmNLevel1Cluster = 4; + constexpr unsigned GemmKPerThreadLoop = 1; + + constexpr unsigned GemmThreadPerColumnPerCluster = 8; + constexpr unsigned GemmThreadPerRowPerCluster = 8; + + constexpr unsigned InBlockCopyThreadPerDim0 = 4; + constexpr unsigned InBlockCopyThreadPerDim1 = 16; + + constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 16; + + constexpr unsigned InBlockCopyDataPerRead = 4; + constexpr unsigned WeiBlockCopyDataPerRead = 4; + + constexpr unsigned BlockSize = 128; +#elif 1 + // 1x1, 28x28, 256 thread + constexpr unsigned BPerBlock = 128; + constexpr unsigned KPerBlock = 128; + constexpr unsigned CPerBlock = 8; + + constexpr unsigned BPerThread = 8; + constexpr unsigned KPerThread = 8; + + constexpr unsigned GemmMPerThreadSubC = 4; + constexpr unsigned GemmNPerThreadSubC = 4; + constexpr unsigned GemmMLevel0Cluster = 4; + constexpr unsigned GemmNLevel0Cluster = 4; + constexpr unsigned GemmMLevel1Cluster = 4; + constexpr unsigned GemmNLevel1Cluster = 4; + constexpr unsigned GemmKPerThreadLoop = 1; + + constexpr unsigned GemmThreadPerColumnPerCluster = 8; + constexpr unsigned GemmThreadPerRowPerCluster = 8; + + constexpr unsigned InBlockCopyThreadPerDim0 = 4; + constexpr unsigned InBlockCopyThreadPerDim1 = 16; + + constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 16; + + constexpr unsigned InBlockCopyDataPerRead = 4; + constexpr unsigned WeiBlockCopyDataPerRead = 4; + + constexpr unsigned BlockSize = 256; +#endif + + constexpr unsigned GridSize = + ((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); + + printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); + + // mem + std::size_t data_sz = sizeof(T); + DeviceMem in_chwn_device_buf(data_sz * (in_chwn.mDesc.GetElementSpace() + BGhostRead + + BPerBlock)); // reserve extra space for BGhostRead + DeviceMem wei_csrk_device_buf(data_sz * wei_csrk.mDesc.GetElementSpace()); + DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace()); + + in_chwn_device_buf.ToDevice(in_chwn.mData.data()); + wei_csrk_device_buf.ToDevice(wei_csrk.mData.data()); + out_khwn_device_buf.ToDevice(out_khwn.mData.data()); + + for(unsigned i = 0; i < nrepeat; ++i) + { + float time = launch_kernel( +#if 0 + gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn +#else + gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer +#endif + , + dim3(GridSize), + dim3(BlockSize), + static_cast(in_chwn_device_buf.GetDeviceBuffer()), + static_cast(wei_csrk_device_buf.GetDeviceBuffer()), + static_cast(out_khwn_device_buf.GetDeviceBuffer())); + + printf("Elapsed time : %f ms\n", time); + usleep(std::min(time * 1000, float(10000))); + } + + out_khwn_device_buf.FromDevice(out_khwn.mData.data()); + + // convert out_khwn to out_nkhw + make_ParallelTensorFunctor( + [&](auto n, auto k, auto ho, auto wo) { out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n); }, + N, + K, + Ho, + Wo)(std::thread::hardware_concurrency()); +} diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index 17f5a6447f..d4a84e34e5 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -13,6 +13,7 @@ #include "device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp" #include "device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp" #include "device_implicit_gemm_convolution_2_cnhw_csrk_knhw.hpp" +#include "device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp" //#include "device_winograd_convolution.hip.hpp" struct GeneratorTensor_1 @@ -594,8 +595,10 @@ int main() device_implicit_gemm_convolution_1_nchw_srck_nkhw #elif 0 device_implicit_gemm_convolution_1_chwn_csrk_khwn -#elif 1 +#elif 0 device_implicit_gemm_convolution_2_cnhw_csrk_knhw +#elif 1 + device_implicit_gemm_convolution_2_chwn_csrk_khwn #endif (in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat); @@ -611,7 +614,7 @@ int main() nrepeat); #endif -#if 0 +#if 1 if(S == 3 && R == 3) { host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads); diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp new file mode 100644 index 0000000000..15ce27dddb --- /dev/null +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp @@ -0,0 +1,378 @@ +#pragma once +#include "common.hip.hpp" +#include "ConstantTensorDescriptor.hip.hpp" +#include "ConstantMatrixDescriptor.hip.hpp" +#include "blockwise_4d_tensor_op.hip.hpp" +#include "blockwise_2d_tensor_op.hip.hpp" +#include "threadwise_2d_tensor_op.hip.hpp" +#include "blockwise_gemm.hip.hpp" + +// define B = flatten(N, Hi, Wi) +template +__global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer( + const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_chwn_global_desc = InGlobalDesc{}; + constexpr auto wei_csrk_global_desc = WeiGlobalDesc{}; + constexpr auto out_khwn_global_desc = OutGlobalDesc{}; + + constexpr unsigned C = in_chwn_global_desc.GetLength(I0); + constexpr unsigned Hi = in_chwn_global_desc.GetLength(I1); + constexpr unsigned Wi = in_chwn_global_desc.GetLength(I2); + constexpr unsigned N = in_chwn_global_desc.GetLength(I3); + + constexpr unsigned K = out_khwn_global_desc.GetLength(I0); + constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1); + constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2); + + constexpr unsigned S = wei_csrk_global_desc.GetLength(I1); + constexpr unsigned R = wei_csrk_global_desc.GetLength(I2); + + constexpr unsigned B = N * Hi * Wi; + constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1); + + // divide block work by 2d: [K, B] + constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; + constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock; + + const unsigned k_block_work_id = get_block_1d_id() / BBlockWork; + const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; + + const unsigned k_block_data_begin = k_block_work_id * KPerBlock; + const unsigned b_block_data_begin = b_block_work_id * BPerBlock; + + // flattend (2d) tensor view of gridwise input + constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence{}); + constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); + + // tensor view of blockwise input and weight + // be careful of alignment + constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + constexpr auto wei_csrk_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + // tensor view of threadwise output in register + constexpr auto out_kb_thread_desc = + make_ConstantTensorDescriptor(Sequence{}); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(in_chwn_global_desc, "in_chwn_global_desc"); + print_ConstantTensorDescriptor(wei_csrk_global_desc, "wei_csrk_global_desc"); + print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc"); + + print_ConstantTensorDescriptor(in_cb_global_desc, "in_cb_global_desc"); + print_ConstantTensorDescriptor(wei_ek_global_desc, "wei_ek_global_desc"); + + print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc"); + print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc"); + print_ConstantTensorDescriptor(wei_ek_block_desc, "wei_ek_block_desc"); + print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc"); + + printf("KPerBlock %u\n", KPerBlock); + } +#endif + +// blockwise in copy +// formmat is [CPerBlock,BPerBlock + BGhostRead] +#if 0 + const auto blockwise_in_copy = + Blockwise2dTensorCopy1{}; +#elif 0 + const auto blockwise_in_copy = Blockwise2dTensorCopy2{}; +#elif 1 + const auto blockwise_in_copy = Blockwise2dTensorCopy3{}; +#endif + +// blockwise wei copy +// format is [CPerBlock*S*R,KPerBlock] +#if 0 + const auto blockwise_wei_copy = + Blockwise2dTensorCopy1{}; +#elif 0 + const auto blockwise_wei_copy = Blockwise2dTensorCopy2{}; +#elif 1 + const auto blockwise_wei_copy = Blockwise2dTensorCopy3{}; +#endif + + // a series of blockwise GEMM + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx and b_mtx saved in LDS, c_mtx saved in register + // a_mtx[C,K] is a sub-matrix of wei_block[C,S,R,K] + // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead] + // c_mtx[K,B] is out_block[K,B] + constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); + + constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); + + constexpr auto c_kxb_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, Number{}); + +#if 0 + const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC{}; +#else + const auto blockwise_gemm = + BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2{}; +#endif + + // LDS: be careful of alignment + constexpr unsigned in_block_size = + in_cb_block_desc.GetElementSpace(Number{}); + + constexpr unsigned wei_block_size = + wei_csrk_block_desc.GetElementSpace(Number{}); + + constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead + ? InBlockCopyDataPerRead + : WeiBlockCopyDataPerRead; + + // LDS double buffer + __shared__ Float p_in_block_0[max_align * ((in_block_size + max_align - 1) / max_align)]; + __shared__ Float p_wei_block_0[max_align * ((wei_block_size + max_align - 1) / max_align)]; + + __shared__ Float p_in_block_1[max_align * ((in_block_size + max_align - 1) / max_align)]; + __shared__ Float p_wei_block_1[max_align * ((wei_block_size + max_align - 1) / max_align)]; + + const Float* p_in_global_block_offset = + p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin); + + const Float* p_wei_global_block_offset = + p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); + + // preload data into LDS + blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_0); + blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_0); + + p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0); + p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I0); + + // register + Float p_out_thread[out_kb_thread_desc.GetElementSpace()]; + + // set threadwise output tensor to 0 + threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread); + + bool even_loop = true; + + for(unsigned c_block_data_begin = 0; c_block_data_begin + CPerBlock < C; + c_block_data_begin += CPerBlock, + p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), + p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I0), + even_loop = !even_loop) + { + Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1; + Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1; + + Float* p_in_block_next = even_loop ? p_in_block_1 : p_in_block_0; + Float* p_wei_block_next = even_loop ? p_wei_block_1 : p_wei_block_0; + + __syncthreads(); + +// load next data +#if 0 + blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next); + blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next); +#elif 0 + blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next); + + Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + + blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, + p_wei_register_clipboard); +#elif 1 + Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; + Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + + blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, + p_in_register_clipboard); + + blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, + p_wei_register_clipboard); +#endif + + // compute on current data + // a series of GEMM + for(unsigned s = 0; s < S; ++s) + { + for(unsigned r = 0; r < R; ++r) + { + auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; +#if 1 + blockwise_gemm.Run +#else + blockwise_gemm.Run_RegisterDoubleBuffer +#endif + (p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), + p_in_block_now + s * Wi + r, + p_out_thread, + f_accum); + } + } + +#if 0 + blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next); +#elif 1 + blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next); + blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next); +#endif + } + + // last computation + { + Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1; + Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1; + + __syncthreads(); + + for(unsigned s = 0; s < S; ++s) + { + for(unsigned r = 0; r < R; ++r) + { + auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; +#if 0 + blockwise_gemm.Run +#else + blockwise_gemm.Run_RegisterDoubleBuffer +#endif + (p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), + p_in_block_now + s * Wi + r, + p_out_thread, + f_accum); + } + } + } + + // output: register to global mem, + const auto c_thread_mtx_begin = + blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; + const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; + +#if 0 + if(get_block_1d_id() == 0) + { + printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", + get_block_1d_id(), + get_thread_local_1d_id(), + matrix_c_index.row, + matrix_c_index.col, + k_data_begin, + b_data_begin, + p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]); + } +#endif + + for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) + { + for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) + { + const auto c_thread_mtx_distance = + blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); + + unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row; + unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col; + + unsigned h_data = b_data / (Wi * N); + unsigned itmp = b_data - h_data * (Wi * N); + unsigned w_data = itmp / N; + unsigned n_data = itmp - w_data * N; + + if(n_data < N && h_data < Ho && w_data < Wo) + { + p_out_global[out_khwn_global_desc.Get1dIndex(k_data, h_data, w_data, n_data)] = + p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; + } + } + } +}