diff --git a/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.hpp b/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.hpp deleted file mode 100644 index 2ede247ff0..0000000000 --- a/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.hpp +++ /dev/null @@ -1,183 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.hip.hpp" -#include "gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline.hip.hpp" - -template -void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcsr, - OutDesc, - Tensor& out_nkhw, - unsigned nrepeat) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_desc = InDesc{}; - constexpr auto wei_kcsr_desc = WeiDesc{}; - constexpr auto out_nkhw_desc = OutDesc{}; - - constexpr unsigned N = in_nchw_desc.GetLength(I0); - constexpr unsigned Hi = in_nchw_desc.GetLength(I2); - constexpr unsigned Wi = in_nchw_desc.GetLength(I3); - - constexpr unsigned Ho = out_nkhw_desc.GetLength(I2); - constexpr unsigned Wo = out_nkhw_desc.GetLength(I3); - - constexpr unsigned K = wei_kcsr_desc.GetLength(I0); - constexpr unsigned C = wei_kcsr_desc.GetLength(I1); - constexpr unsigned S = wei_kcsr_desc.GetLength(I2); - constexpr unsigned R = wei_kcsr_desc.GetLength(I3); - - constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1); - - // convert in_nchw to in_cnhw - auto in_cnhw_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(in_cnhw_desc, std::cout << "in_cnhw_desc: "); - - Tensor in_cnhw(make_TensorDescriptor(in_cnhw_desc)); - - auto f_reorder_nchw2cnhw = [&](auto n, auto c, auto hi, auto wi) { - in_cnhw(c, n, hi, wi) = in_nchw(n, c, hi, wi); - }; - - make_ParallelTensorFunctor(f_reorder_nchw2cnhw, N, C, Hi, Wi)( - std::thread::hardware_concurrency()); - - // convert wei_kcsr to wei_srck - auto wei_srck_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(wei_srck_desc, std::cout << "wei_srck_desc: "); - - Tensor wei_srck(make_TensorDescriptor(wei_srck_desc)); - - auto f_reorder_kcsr2srck = [&](auto k, auto c, auto s, auto r) { - wei_srck(s, r, c, k) = wei_kcsr(k, c, s, r); - }; - - make_ParallelTensorFunctor(f_reorder_kcsr2srck, K, C, S, R)( - std::thread::hardware_concurrency()); - - // conver out_nkhw to out_knhw - auto out_knhw_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(out_knhw_desc, std::cout << "out_knhw_desc: "); - - Tensor out_knhw(make_TensorDescriptor(out_knhw_desc)); - -#if 0 - constexpr unsigned BPerBlock = 128; - constexpr unsigned KPerBlock = 1; - constexpr unsigned CPerBlock = 1; - - constexpr unsigned BPerThread = 4; - constexpr unsigned KPerThread = 1; - constexpr unsigned CPerThread = 1; - - constexpr unsigned GemmThreadPerColumnPerCluster = 1; - constexpr unsigned GemmThreadPerRowPerCluster = 1; - - constexpr unsigned InBlockCopyThreadPerDim0 = 4; - constexpr unsigned InBlockCopyThreadPerDim1 = 16; - - constexpr unsigned BlockSize = 32; -#elif 1 - // 3x3, 34x34 - constexpr unsigned BPerBlock = 128; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 4; - - constexpr unsigned BPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - - constexpr unsigned GemmThreadPerColumnPerCluster = 4; - constexpr unsigned GemmThreadPerRowPerCluster = 8; - - constexpr unsigned InBlockCopyThreadPerDim0 = 4; - constexpr unsigned InBlockCopyThreadPerDim1 = 16; - - constexpr unsigned BlockSize = 128; -#elif 0 - // 1x1, 28x28 - constexpr unsigned BPerBlock = 64; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 8; - - constexpr unsigned BPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - - constexpr unsigned GemmThreadPerColumnPerCluster = 4; - constexpr unsigned GemmThreadPerRowPerCluster = 8; - - constexpr unsigned InBlockCopyThreadPerDim0 = 4; - constexpr unsigned InBlockCopyThreadPerDim1 = 16; - - constexpr unsigned BlockSize = 64; -#endif - - constexpr unsigned GridSize = - ((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - // mem - std::size_t data_sz = sizeof(T); - DeviceMem in_cnhw_device_buf(data_sz * (in_cnhw.mDesc.GetElementSpace() + BGhostRead + - BPerBlock)); // reserve extra space for BGhostRead - DeviceMem wei_srck_device_buf(data_sz * wei_srck.mDesc.GetElementSpace()); - DeviceMem out_knhw_device_buf(data_sz * out_knhw.mDesc.GetElementSpace()); - - in_cnhw_device_buf.ToDevice(in_cnhw.mData.data()); - wei_srck_device_buf.ToDevice(wei_srck.mData.data()); - out_knhw_device_buf.ToDevice(out_knhw.mData.data()); - - for(unsigned i = 0; i < nrepeat; ++i) - { - float time = launch_kernel( -#if 1 - gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw -#else - gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline -#endif - , - dim3(GridSize), - dim3(BlockSize), - - static_cast(in_cnhw_device_buf.GetDeviceBuffer()), - static_cast(wei_srck_device_buf.GetDeviceBuffer()), - static_cast(out_knhw_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms\n", time); - usleep(std::min(time * 1000, float(10000))); - } - - out_knhw_device_buf.FromDevice(out_knhw.mData.data()); - - // convert out_knhw to out_nkhw - auto f_reorder_knhw2nkhw = [&](auto n, auto k, auto ho, auto wo) { - out_nkhw(n, k, ho, wo) = out_knhw(k, n, ho, wo); - }; - - make_ParallelTensorFunctor(f_reorder_knhw2nkhw, N, K, Ho, Wo)( - std::thread::hardware_concurrency()); -} diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index 7b669f47d0..17f5a6447f 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -12,7 +12,6 @@ #include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.hpp" #include "device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp" #include "device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp" -#include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.hpp" #include "device_implicit_gemm_convolution_2_cnhw_csrk_knhw.hpp" //#include "device_winograd_convolution.hip.hpp" @@ -595,8 +594,6 @@ int main() device_implicit_gemm_convolution_1_nchw_srck_nkhw #elif 0 device_implicit_gemm_convolution_1_chwn_csrk_khwn -#elif 0 - device_implicit_gemm_convolution_2_cnhw_srck_knhw #elif 1 device_implicit_gemm_convolution_2_cnhw_csrk_knhw #endif @@ -614,7 +611,7 @@ int main() nrepeat); #endif -#if 1 +#if 0 if(S == 3 && R == 3) { host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads); diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.hip.hpp deleted file mode 100644 index 44abde0336..0000000000 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.hip.hpp +++ /dev/null @@ -1,273 +0,0 @@ -#pragma once -#include "common.hip.hpp" -#include "ConstantTensorDescriptor.hip.hpp" -#include "ConstantMatrixDescriptor.hip.hpp" -#include "blockwise_4d_tensor_op.hip.hpp" -#include "blockwise_2d_tensor_op.hip.hpp" -#include "threadwise_2d_tensor_op.hip.hpp" -#include "blockwise_gemm.hip.hpp" - -// define B = flatten(N, Hi, Wi) -template -__global__ void -gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_cnhw_global_desc = InGlobalDesc{}; - constexpr auto wei_srck_global_desc = WeiGlobalDesc{}; - constexpr auto out_knhw_global_desc = OutGlobalDesc{}; - - constexpr unsigned C = in_cnhw_global_desc.GetLength(I0); - constexpr unsigned N = in_cnhw_global_desc.GetLength(I1); - constexpr unsigned Hi = in_cnhw_global_desc.GetLength(I2); - constexpr unsigned Wi = in_cnhw_global_desc.GetLength(I3); - - constexpr unsigned K = out_knhw_global_desc.GetLength(I0); - constexpr unsigned Ho = out_knhw_global_desc.GetLength(I2); - constexpr unsigned Wo = out_knhw_global_desc.GetLength(I3); - - constexpr unsigned S = wei_srck_global_desc.GetLength(I0); - constexpr unsigned R = wei_srck_global_desc.GetLength(I1); - - constexpr unsigned B = N * Hi * Wi; - constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1); - - // divide block work by 2d: [K, B] - constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock; - - const unsigned k_block_work_id = get_block_1d_id() / BBlockWork; - const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; - - const unsigned k_block_data_begin = k_block_work_id * KPerBlock; - const unsigned b_block_data_begin = b_block_work_id * BPerBlock; - -#if 0 - if(get_thread_local_1d_id() == 0) - { - printf("K %u B %u, BGhostRead %u\n", K, B, BGhostRead); - - printf("%u %u, KBlockWork %u BBlockWork %u, k_block_data_begin %u b_block_data_begin %u\n", - get_block_1d_id(), - get_thread_local_1d_id(), - KBlockWork, - BBlockWork, - k_block_data_begin, - b_block_data_begin); - } -#endif - - // flattend (2d) tensor view of gridwise input - constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of blockwise input and weight - constexpr auto in_cb_block_desc = - make_ConstantTensorDescriptor(Sequence{}); - - constexpr auto wei_srck_block_desc = - make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of threadwise output in register - constexpr auto out_kb_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc"); - print_ConstantTensorDescriptor(wei_srck_block_desc, "wei_srck_block_desc"); - print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc"); - - printf("KPerBlock %u\n", KPerBlock); - } -#endif - -// blockwise in copy -// formmat is [CPerBlock,BPerBlock + BGhostRead] -#if 0 - const auto blockwise_in_copy = - Blockwise2dTensorCopy1{}; -#elif 1 - const auto blockwise_in_copy = Blockwise2dTensorCopy2{}; -#endif - - // blockwise wei copy - // format is [S,R,CPerBlock,KPerBlock] - const auto blockwise_wei_copy = - Blockwise4dTensorCopy1{}; - - // a series of blockwise GEMM - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx and b_mtx saved in LDS, c_mtx saved in register - // a_mtx[C,K] is a sub-matrix of wei_block[S,R,C,K] - // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead] - // c_mtx[K,B] is out_block[K,B] - constexpr auto a_cxk_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto c_kxb_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC{}; - - // LDS - constexpr unsigned in_block_size = in_cb_block_desc.GetElementSpace(); - constexpr unsigned wei_block_size = wei_srck_block_desc.GetElementSpace(); - - __shared__ Float p_in_block[in_block_size]; - __shared__ Float p_wei_block[wei_block_size]; - - // register - Float p_out_thread[out_kb_thread_desc.GetElementSpace()]; - - // set threadwise output tensor to 0 - threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread); - - const Float* p_in_global_block_offset = - p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin); - - const Float* p_wei_global_block_offset = - p_wei_global + wei_srck_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); - - for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, - p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), - p_wei_global_block_offset += CPerBlock * wei_srck_global_desc.GetStride(I2), - __syncthreads()) - { -#if 1 - // input: global mem to LDS, - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); -#endif - -#if 1 - // weight: global mem to LDS, - blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); -#endif - - __syncthreads(); - -#if 1 - // a series of GEMM - for(unsigned s = 0; s < S; ++s) - { - for(unsigned r = 0; r < R; ++r) - { - auto f_accum = [](auto& c, const auto&& ab) { c += ab; }; - - blockwise_gemm.Run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0), - p_in_block + s * Wi + r, - p_out_thread, - f_accum); - } - } -#endif - } - - // output: register to global mem, - const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const unsigned k_thread_data_begin = matrix_c_index.row; - const unsigned b_thread_data_begin = matrix_c_index.col; - - const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin; - const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin; - -#if 0 - if(get_block_1d_id() == 0) - { - printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", - get_block_1d_id(), - get_thread_local_1d_id(), - matrix_c_index.row, - matrix_c_index.col, - k_data_begin, - b_data_begin, - p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]); - } -#endif - - for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) - { - for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) - { - unsigned k_data = k_data_begin + k; - unsigned b_data = b_data_begin + b; - - unsigned n_data = b_data / (Hi * Wi); - unsigned itmp = b_data - n_data * (Hi * Wi); - unsigned h_data = itmp / Wi; - unsigned w_data = itmp - h_data * Wi; - -#if 0 - if(get_block_1d_id() == 0) - { - printf("%u %u, k %u b %u, k_data %u n_data %u h_data %u w_data %u %f\n", - get_block_1d_id(), - get_thread_local_1d_id(), - k, - b, - k_data, - n_data, - h_data, - w_data, - p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]); - } -#endif - if(n_data < N && h_data < Ho && w_data < Wo) - { -#if 1 - p_out_global[out_knhw_global_desc.Get1dIndex(k_data, n_data, h_data, w_data)] = - p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; -#endif - } - } - } -} diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline.hip.hpp deleted file mode 100644 index 0c6577f4a4..0000000000 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline.hip.hpp +++ /dev/null @@ -1,315 +0,0 @@ -#pragma once -#include "common.hip.hpp" -#include "ConstantTensorDescriptor.hip.hpp" -#include "ConstantMatrixDescriptor.hip.hpp" -#include "blockwise_4d_tensor_op.hip.hpp" -#include "blockwise_2d_tensor_op.hip.hpp" -#include "threadwise_2d_tensor_op.hip.hpp" -#include "blockwise_gemm.hip.hpp" - -// define B = N*Hi*Wi -template -__global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline( - const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_cnhw_global_desc = InGlobalDesc{}; - constexpr auto wei_srck_global_desc = WeiGlobalDesc{}; - constexpr auto out_knhw_global_desc = OutGlobalDesc{}; - - constexpr unsigned C = in_cnhw_global_desc.GetLength(I0); - constexpr unsigned N = in_cnhw_global_desc.GetLength(I1); - constexpr unsigned Hi = in_cnhw_global_desc.GetLength(I2); - constexpr unsigned Wi = in_cnhw_global_desc.GetLength(I3); - - constexpr unsigned K = out_knhw_global_desc.GetLength(I0); - constexpr unsigned Ho = out_knhw_global_desc.GetLength(I2); - constexpr unsigned Wo = out_knhw_global_desc.GetLength(I3); - - constexpr unsigned S = wei_srck_global_desc.GetLength(I0); - constexpr unsigned R = wei_srck_global_desc.GetLength(I1); - - constexpr unsigned B = N * Hi * Wi; - constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1); - - // divide block work by 2d: [K, B] - constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock; - - const unsigned k_block_work_id = get_block_1d_id() / BBlockWork; - const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; - - const unsigned k_block_data_begin = k_block_work_id * KPerBlock; - const unsigned b_block_data_begin = b_block_work_id * BPerBlock; - -#if 0 - if(get_thread_local_1d_id() == 0) - { - printf("K %u B %u, BGhostRead %u\n", K, B, BGhostRead); - - printf("%u %u, KBlockWork %u BBlockWork %u, k_block_data_begin %u b_block_data_begin %u\n", - get_block_1d_id(), - get_thread_local_1d_id(), - KBlockWork, - BBlockWork, - k_block_data_begin, - b_block_data_begin); - } -#endif - - // flattend (2d) tensor view of gridwise input - constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of blockwise input and weight - constexpr auto in_cb_block_desc = - make_ConstantTensorDescriptor(Sequence{}); - - constexpr auto wei_srck_block_desc = - make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of threadwise output in register - constexpr auto out_kb_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc"); - print_ConstantTensorDescriptor(wei_srck_block_desc, "wei_srck_block_desc"); - print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc"); - - printf("KPerBlock %u\n", KPerBlock); - } -#endif - -// in: global mem to LDS -// formmat is [CPerBlock,BPerBlock + BGhostRead] -#if 1 - const auto blockwise_in_copy = - Blockwise2dTensorCopy1{}; -#elif 1 - const auto blockwise_in_copy = Blockwise2dTensorCopy2{}; -#elif 0 - const auto blockwise_in_copy = - blockwise_2d_tensor_copy_dummy_2{}; -#endif - -// weight: global mem to LDS, -// format is [S,R,CPerBlock,KPerBlock] -#if 1 - const auto blockwise_wei_copy = - Blockwise4dTensorCopy1{}; -#else - const auto blockwise_wei_copy = - blockwise_4d_tensor_copy_dummy{}; -#endif - - // a series of blockwise GEMM - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx and b_mtx saved in LDS, c_mtx saved in register - // a_mtx[C,K] is a sub-matrix of wei_block[S,R,C,K] - // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead] - // c_mtx[K,B] is out_block[K,B] - constexpr auto a_cxk_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto c_kxb_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC{}; - - // LDS - constexpr unsigned in_block_size = in_cb_block_desc.GetElementSpace(); - constexpr unsigned wei_block_size = wei_srck_block_desc.GetElementSpace(); - - // double buffer - __shared__ Float p_in_block_0[in_block_size]; - __shared__ Float p_wei_block_0[wei_block_size]; - - __shared__ Float p_in_block_1[in_block_size]; - __shared__ Float p_wei_block_1[wei_block_size]; - - // register - Float p_out_thread[out_kb_thread_desc.GetElementSpace()]; - - // set threadwise output tensor to 0 - threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread); - -// prelog: load data -#if 1 - // input: global mem to LDS, - blockwise_in_copy.Run(p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin), - p_in_block_0); -#endif - -#if 1 - // weight: global mem to LDS, - blockwise_wei_copy.Run( - p_wei_global + wei_srck_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin), p_wei_block_0); -#endif - - unsigned cloop = 0; - - for(unsigned c_block_data_begin = 0; c_block_data_begin + CPerBlock < C; - c_block_data_begin += CPerBlock, ++cloop) - { - __syncthreads(); - - Float* p_in_block_now = (cloop % 2 == 0) ? p_in_block_0 : p_in_block_1; - Float* p_wei_block_now = (cloop % 2 == 0) ? p_wei_block_0 : p_wei_block_1; - - Float* p_in_block_next = (cloop % 2 == 0) ? p_in_block_1 : p_in_block_0; - Float* p_wei_block_next = (cloop % 2 == 0) ? p_wei_block_1 : p_wei_block_0; - -#if 1 - // preload next data - // input: global mem to LDS, - blockwise_in_copy.Run(p_in_global + in_cb_global_desc.Get1dIndex( - c_block_data_begin + CPerBlock, b_block_data_begin), - p_in_block_next); -#endif - -#if 1 - // weight: global mem to LDS, - blockwise_wei_copy.Run(p_wei_global + - wei_srck_global_desc.Get1dIndex( - 0, 0, c_block_data_begin + CPerBlock, k_block_data_begin), - p_wei_block_next); -#endif - - // a series of GEMM - for(unsigned s = 0; s < S; ++s) - { - for(unsigned r = 0; r < R; ++r) - { - auto f_accum = [](auto& c, const auto&& ab) { c += ab; }; - - blockwise_gemm.Run(p_wei_block_now + wei_srck_block_desc.Get1dIndex(s, r, 0, 0), - p_in_block_now + s * Wi + r, - p_out_thread, - f_accum); - } - } - } - - { - // last cloop - __syncthreads(); - - Float* p_in_block_now = (cloop % 2 == 0) ? p_in_block_0 : p_in_block_1; - Float* p_wei_block_now = (cloop % 2 == 0) ? p_wei_block_0 : p_wei_block_1; - - // a series of GEMM - for(unsigned s = 0; s < S; ++s) - { - for(unsigned r = 0; r < R; ++r) - { - auto f_accum = [](auto& c, const auto&& ab) { c += ab; }; - - blockwise_gemm.Run(p_wei_block_now + wei_srck_block_desc.Get1dIndex(s, r, 0, 0), - p_in_block_now + s * Wi + r, - p_out_thread, - f_accum); - } - } - } - - // output: register to global mem, - const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const unsigned k_thread_data_begin = matrix_c_index.row; - const unsigned b_thread_data_begin = matrix_c_index.col; - - const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin; - const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin; - -#if 0 - if(get_block_1d_id() == 0) - { - printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", - get_block_1d_id(), - get_thread_local_1d_id(), - matrix_c_index.row, - matrix_c_index.col, - k_data_begin, - b_data_begin, - p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]); - } -#endif - - for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) - { - for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) - { - unsigned k_data = k_data_begin + k; - unsigned b_data = b_data_begin + b; - - unsigned n_data = b_data / (Hi * Wi); - unsigned itmp = b_data - n_data * (Hi * Wi); - unsigned h_data = itmp / Wi; - unsigned w_data = itmp - h_data * Wi; - - if(n_data < N && h_data < Ho && w_data < Wo) - { - p_out_global[out_knhw_global_desc.Get1dIndex(k_data, n_data, h_data, w_data)] = - p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; - } - } - } -}