From fd8de384170d6100a837b19e37139665c89e2054 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sat, 16 Mar 2019 10:50:46 -0500 Subject: [PATCH] refactor --- ...e_direct_convolution_2_nchw_kcyx_nkhw.hpp} | 48 +-- ...icit_gemm_convolution_2_chwn_cyxk_khwn.hpp | 71 +++-- driver/driver.hip.cpp | 19 +- ...rect_convolution_2_nchw_kcyx_nkhw.hip.hpp} | 7 +- ..._gemm_convolution_2_chwn_cyxk_khwn.hip.hpp | 281 ++++++++++++++++++ ...2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp | 4 +- 6 files changed, 359 insertions(+), 71 deletions(-) rename driver/{device_direct_convolution_2.hpp => device_direct_convolution_2_nchw_kcyx_nkhw.hpp} (77%) rename src/include/{gridwise_direct_convolution_2.hip.hpp => gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp} (96%) create mode 100644 src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp diff --git a/driver/device_direct_convolution_2.hpp b/driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp similarity index 77% rename from driver/device_direct_convolution_2.hpp rename to driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp index 1baedafc46..602702949e 100644 --- a/driver/device_direct_convolution_2.hpp +++ b/driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp @@ -1,16 +1,16 @@ #pragma once #include #include "device.hpp" -#include "gridwise_direct_convolution_2.hip.hpp" +#include "gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp" template -void device_direct_convolution_2(InDesc, - const Tensor& in, - WeiDesc, - const Tensor& wei, - OutDesc, - Tensor& out, - unsigned nrepeat) +void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc, + const Tensor& in, + WeiDesc, + const Tensor& wei, + OutDesc, + Tensor& out, + unsigned nrepeat) { std::size_t data_sz = sizeof(T); DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace()); @@ -57,22 +57,22 @@ void device_direct_convolution_2(InDesc, for(unsigned i = 0; i < nrepeat; ++i) { - float time = launch_kernel(gridwise_direct_convolution_2, + float time = launch_kernel(gridwise_direct_convolution_2_nchw_kcyx_nkhw, dim3(GridSize), dim3(BlockSize), static_cast(in_device_buf.GetDeviceBuffer()), diff --git a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp index a657949f35..c885894165 100644 --- a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp @@ -1,6 +1,7 @@ #pragma once #include #include "device.hpp" +#include "gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp" #include "gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp" template @@ -209,39 +210,43 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, for(unsigned i = 0; i < nrepeat; ++i) { - float time = - launch_kernel(gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer< - GridSize, - BlockSize, - T, - decltype(in_chwn_desc), - decltype(wei_cyxk_desc), - decltype(out_khwn_desc), - BPerBlock, - KPerBlock, - CPerBlock, - BPerThread, - KPerThread, - GemmThreadPerColumnPerCluster, - GemmThreadPerRowPerCluster, - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - InBlockCopyThreadPerDim0, - InBlockCopyThreadPerDim1, - WeiBlockCopyThreadPerDim0, - WeiBlockCopyThreadPerDim1, - InBlockCopyDataPerRead, - WeiBlockCopyDataPerRead>, - dim3(GridSize), - dim3(BlockSize), - static_cast(in_chwn_device_buf.GetDeviceBuffer()), - static_cast(wei_cyxk_device_buf.GetDeviceBuffer()), - static_cast(out_khwn_device_buf.GetDeviceBuffer())); + float time = launch_kernel( +#if 1 + gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn +#else + gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer +#endif + , + dim3(GridSize), + dim3(BlockSize), + static_cast(in_chwn_device_buf.GetDeviceBuffer()), + static_cast(wei_cyxk_device_buf.GetDeviceBuffer()), + static_cast(out_khwn_device_buf.GetDeviceBuffer())); printf("Elapsed time : %f ms\n", time); usleep(std::min(time * 1000, float(10000))); diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index aca345acfd..6cd75afd79 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -8,7 +8,7 @@ #include "ConstantTensorDescriptor.hip.hpp" #include "conv_common.hip.hpp" #include "device_direct_convolution_1.hpp" -#include "device_direct_convolution_2.hpp" +#include "device_direct_convolution_2_nchw_kcyx_nkhw.hpp" #include "device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp" #include "device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp" #include "device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp" @@ -503,7 +503,7 @@ int main(int argc, char* argv[]) constexpr unsigned HPad = 1; constexpr unsigned WPad = 1; -#elif 0 +#elif 1 // 1x1 filter, 28x28 image constexpr unsigned N = 16; constexpr unsigned C = 256; @@ -577,10 +577,11 @@ int main(int argc, char* argv[]) ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: "); ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: "); - Tensor in_nchw(make_TensorDescriptor(in_nchw_desc)); - Tensor wei_kcyx(make_TensorDescriptor(wei_kcyx_desc)); - Tensor out_nkhw_host(make_TensorDescriptor(out_nkhw_desc)); - Tensor out_nkhw_device(make_TensorDescriptor(out_nkhw_desc)); + using Float = float; + Tensor in_nchw(make_TensorDescriptor(in_nchw_desc)); + Tensor wei_kcyx(make_TensorDescriptor(wei_kcyx_desc)); + Tensor out_nkhw_host(make_TensorDescriptor(out_nkhw_desc)); + Tensor out_nkhw_device(make_TensorDescriptor(out_nkhw_desc)); std::size_t num_thread = std::thread::hardware_concurrency(); @@ -610,9 +611,9 @@ int main(int argc, char* argv[]) #if 1 #if 0 device_direct_convolution_1 -#elif 0 - device_direct_convolution_2 #elif 1 + device_direct_convolution_2_nchw_kcyx_nkhw +#elif 0 device_implicit_gemm_convolution_1_chwn_cyxk_khwn #elif 0 device_implicit_gemm_convolution_2_chwn_cyxk_khwn @@ -633,7 +634,7 @@ int main(int argc, char* argv[]) if(do_verification) { -#if 0 +#if 1 if(Y == 3 && X == 3) { host_winograd_3x3_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads); diff --git a/src/include/gridwise_direct_convolution_2.hip.hpp b/src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp similarity index 96% rename from src/include/gridwise_direct_convolution_2.hip.hpp rename to src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp index 13f9e6cf1d..322d5fd9c2 100644 --- a/src/include/gridwise_direct_convolution_2.hip.hpp +++ b/src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp @@ -22,9 +22,10 @@ template -__global__ void gridwise_direct_convolution_2(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) +__global__ void +gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp new file mode 100644 index 0000000000..afa3d3ee90 --- /dev/null +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp @@ -0,0 +1,281 @@ +#pragma once +#include "common.hip.hpp" +#include "ConstantTensorDescriptor.hip.hpp" +#include "ConstantMatrixDescriptor.hip.hpp" +#include "blockwise_4d_tensor_op.hip.hpp" +#include "blockwise_2d_tensor_op.hip.hpp" +#include "threadwise_2d_tensor_op.hip.hpp" +#include "blockwise_gemm.hip.hpp" + +// define B = flatten(N, Hi, Wi) +template +__global__ void +gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_chwn_global_desc = InGlobalDesc{}; + constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; + constexpr auto out_khwn_global_desc = OutGlobalDesc{}; + + constexpr unsigned C = in_chwn_global_desc.GetLength(I0); + constexpr unsigned Hi = in_chwn_global_desc.GetLength(I1); + constexpr unsigned Wi = in_chwn_global_desc.GetLength(I2); + constexpr unsigned N = in_chwn_global_desc.GetLength(I3); + + constexpr unsigned K = out_khwn_global_desc.GetLength(I0); + constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1); + constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2); + + constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1); + constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2); + + constexpr unsigned B = N * Hi * Wi; + constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1); + + // divide block work by 2d: [K, B] + constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; + constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock; + + const unsigned k_block_work_id = get_block_1d_id() / BBlockWork; + const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; + + const unsigned k_block_data_begin = k_block_work_id * KPerBlock; + const unsigned b_block_data_begin = b_block_work_id * BPerBlock; + + // flattend (2d) tensor view of gridwise input + constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence{}); + constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); + + // tensor view of blockwise input and weight + // be careful of alignment + constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + // tensor view of threadwise output in register + constexpr auto out_kb_thread_desc = + make_ConstantTensorDescriptor(Sequence{}); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(in_chwn_global_desc, "in_chwn_global_desc"); + print_ConstantTensorDescriptor(wei_cyxk_global_desc, "wei_cyxk_global_desc"); + print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc"); + + print_ConstantTensorDescriptor(in_cb_global_desc, "in_cb_global_desc"); + print_ConstantTensorDescriptor(wei_ek_global_desc, "wei_ek_global_desc"); + + print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc"); + print_ConstantTensorDescriptor(wei_cyxk_block_desc, "wei_cyxk_block_desc"); + print_ConstantTensorDescriptor(wei_ek_block_desc, "wei_ek_block_desc"); + print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc"); + + printf("KPerBlock %u\n", KPerBlock); + } +#endif + +// blockwise in copy +// formmat is [CPerBlock,BPerBlock + BGhostRead] +#if 0 + const auto blockwise_in_copy = + Blockwise2dTensorCopy1{}; +#elif 0 + const auto blockwise_in_copy = Blockwise2dTensorCopy2{}; +#elif 1 + const auto blockwise_in_copy = Blockwise2dTensorCopy3{}; +#endif + +// blockwise wei copy +// format is [CPerBlock*Y*X,KPerBlock] +#if 0 + const auto blockwise_wei_copy = + Blockwise2dTensorCopy1{}; +#elif 0 + const auto blockwise_wei_copy = Blockwise2dTensorCopy2{}; +#elif 1 + const auto blockwise_wei_copy = Blockwise2dTensorCopy3{}; +#endif + + // a series of blockwise GEMM + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx and b_mtx saved in LDS, c_mtx saved in register + // a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K] + // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead] + // c_mtx[K,B] is out_block[K,B] + constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); + + constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); + + constexpr auto c_kxb_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, Number{}); + + const auto blockwise_gemm = + BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2{}; + + // LDS: be careful of alignment + constexpr unsigned in_block_size = + in_cb_block_desc.GetElementSpace(Number{}); + + constexpr unsigned wei_block_size = + wei_cyxk_block_desc.GetElementSpace(Number{}); + + constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead + ? InBlockCopyDataPerRead + : WeiBlockCopyDataPerRead; + + // LDS + __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)]; + __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)]; + + const Float* p_in_global_block_offset = + p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin); + + const Float* p_wei_global_block_offset = + p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); + + // register + Float p_out_thread[out_kb_thread_desc.GetElementSpace()]; + + // set threadwise output tensor to 0 + threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread); + + for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, + p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), + p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), + __syncthreads()) + { + // load data + blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); + blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); + + __syncthreads(); + + // compute on current data + // a series of GEMM + for(unsigned y = 0; y < Y; ++y) + { + for(unsigned x = 0; x < X; ++x) + { + auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; +#if 1 + blockwise_gemm.Run +#else + blockwise_gemm.Run_RegisterDoubleBuffer +#endif + (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block + y * Wi + x, + p_out_thread, + f_accum); + } + } + } + + // output: register to global mem, + const auto c_thread_mtx_begin = + blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; + const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; + + for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) + { + for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) + { + const auto c_thread_mtx_distance = + blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); + + unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row; + unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col; + + unsigned h_data = b_data / (Wi * N); + unsigned itmp = b_data - h_data * (Wi * N); + unsigned w_data = itmp / N; + unsigned n_data = itmp - w_data * N; + + if(n_data < N && h_data < Ho && w_data < Wo) + { + p_out_global[out_khwn_global_desc.Get1dIndex(k_data, h_data, w_data, n_data)] = + p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; + } + } + } +} diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp index 7c802266d8..60d827293b 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp @@ -259,7 +259,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_b __syncthreads(); // load next data -#if 1 +#if 0 blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next); blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next); #elif 1 @@ -292,7 +292,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_b } } -#if 0 +#if 1 blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next); blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next); #endif