From c2d246696f30bec91eaa402fef626af5b793339e Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sat, 3 Aug 2019 00:19:19 -0500 Subject: [PATCH] added implicit gemm v4r4 and double buffer --- ..._v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp | 380 ++++++++++++++++++ ...tion_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp | 73 ++-- driver/src/driver.cpp | 155 +------ 3 files changed, 420 insertions(+), 188 deletions(-) create mode 100644 composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp new file mode 100644 index 0000000000..00d98cf7cf --- /dev/null +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -0,0 +1,380 @@ +#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP +#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP + +#include "common_header.hpp" +#include "ConstantTensorDescriptor.hpp" +#include "ConstantMergedTensorDescriptor.hpp" +#include "ConstantMatrixDescriptor.hpp" +#include "blockwise_generic_tensor_slice_copy.hpp" +#include "blockwise_gemm.hpp" +#include "threadwise_generic_tensor_slice_copy.hpp" + +namespace ck { + +// B = merge(N, Ho, Wo) +template +struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer +{ + __device__ void Run(const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I5 = Number<5>{}; + + constexpr auto True = integral_constant{}; + + constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; + constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; + constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{}; + + constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0]; + constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1]; + + constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1]; + constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2]; + constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3]; + + constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2]; + constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3]; + + constexpr index_t ConvStrideH = ConvStrides{}[0]; + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationH = ConvDilations{}[0]; + constexpr index_t ConvDilationW = ConvDilations{}[1]; + + constexpr index_t E = C * Y * X; + constexpr index_t B = N * Ho * Wo; + + static_assert((X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0), + "wrong! aligment requirement for vectorized global load of input tensor will " + "be violated"); + + // divide block work by [K, B] + static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0, + "wrong! cannot divide work evenly among block"); + + constexpr index_t KBlockWork = K / KPerBlock; + constexpr index_t BBlockWork = B / BPerBlock; + + constexpr auto block_work_desc = + make_ConstantTensorDescriptor_packed(Sequence{}); + + const auto block_work_multi_id = + block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id()); + + const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock; + const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock; + + // input tensor + // tensor descriptor in device memory [N, Ho, Wo] + constexpr auto in_n_ho_wo_global_desc = + in_n_c_h_w_global_desc.Extract(I0, I2, I3) + .StridedSlice(I1, Number{}, Number{}) + .StridedSlice(I2, Number{}, Number{}); + + // batch descritpor for device memory + constexpr auto in_c_y_x_global_desc = + in_n_c_h_w_global_desc.StridedSlice(I2, Number{}, Number{}) + .StridedSlice(I3, Number{}, Number{}) + .Extract(Sequence<1, 2, 3>{}); + + // merged tensor descriptor in device memory [E, B], src of blockwise copy + constexpr auto in_e_b_global_desc = + make_ConstantMergedTensorDescriptor(in_c_y_x_global_desc.Embed(in_n_ho_wo_global_desc), + Sequence<0, 1, 2>{}, + Sequence<3, 4, 5>{}); + + // memory layout descriptor in LDS [E, B], dst of blockwise copy + // be careful of LDS alignment + constexpr auto in_e_b_block_desc = + make_ConstantTensorDescriptor_packed(Sequence{}); + + // input blockwise copy + // slice a merged tensor, reorder and copy to a normal tensor + // this copy operator already has blockwise offset built-in + auto blockwise_in_copy = + BlockwiseGenericTensorSliceCopy_v2, + NormalTensorCoordinate, + decltype(in_e_b_block_desc.GetLengths()), + InBlockCopySubLengths_E_B, + InBlockCopyClusterLengths_E_B, + InBlockCopyThreadClusterArrangeOrder>( + {0, b_block_data_on_global}, {0, 0}); + + // weight tensor + // tensor descriptor in device memory, src of blockwise copy + constexpr auto wei_e_k_global_desc = + wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{}); + + // tensor descriptor in LDS, dst of blockwise copy + // be careful of LDS alignment + constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, + Number{}); + + // operator for blockwise copy of weight into LDS + // slice a tensor, and copy it into another tensor + // this copy operator already have blockwise offset built-in + auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2< + BlockSize, + Float, + decltype(wei_e_k_global_desc), + decltype(wei_e_k_block_desc), + NormalTensorCoordinate, + NormalTensorCoordinate, + decltype(wei_e_k_block_desc.GetLengths()), + WeiBlockCopySubLengths_E_K, + WeiBlockCopyClusterLengths_E_K, + WeiBlockCopyThreadClusterArrangeOrder>({0, k_block_data_on_global}, {0, 0}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[EPerBlock, KPerBlock] is in LDS + // b_mtx[EPerBlocl, BPerBlock] is in LDS + // c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in + // register + constexpr auto a_e_k_block_mtx_desc = + make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc); + + constexpr auto b_e_b_block_mtx_desc = + make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(in_e_b_block_desc); + + // sanity check + static_assert( + KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 && + BPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0, + "wrong!"); + + constexpr index_t GemmMRepeat = + KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster); + + constexpr index_t GemmNRepeat = + BPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster); + + // c_thread_mtx definition: this is a mess + // TODO:: more elegent way of defining c_thread_mtx + constexpr auto c_k0k1_b0b1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed( + Number{}, Number{}); + + const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< + BlockSize, + decltype(a_e_k_block_mtx_desc), + decltype(b_e_b_block_mtx_desc), + decltype(c_k0k1_b0b1_thread_mtx_desc), + GemmMPerThreadSubC, + GemmNPerThreadSubC, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmKPerThreadLoop, + GemmDataPerReadA, + GemmDataPerReadB>{}; + + // LDS allocation for input and weight: be careful of alignment + constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B, + WeiBlockCopyDstDataPerWrite_K, + GemmDataPerReadA, + GemmDataPerReadB); + + constexpr index_t in_block_space = + math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align); + + constexpr index_t wei_block_space = + math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align); + + __shared__ Float p_in_block_double[2 * in_block_space]; + __shared__ Float p_wei_block_double[2 * wei_block_space]; + + // register allocation for output + Float p_out_thread[c_k0k1_b0b1_thread_mtx_desc.GetElementSpace()]; + + // zero out threadwise output + threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread); + + const Float* p_wei_block_on_global = p_wei_global; + + // LDS double buffer: preload data into LDS + { + blockwise_in_copy.Run(p_in_global, p_in_block_double); + blockwise_wei_copy.Run(p_wei_global, p_wei_block_double); + } + + // LDS double buffer: main body + for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E; + e_block_data_begin += 2 * EPerBlock) + { +#pragma unroll + for(index_t iloop = 0; iloop < 2; ++iloop) + { + const bool even_loop = (iloop % 2 == 0); + + Float* p_in_block_now = + even_loop ? p_in_block_double : p_in_block_double + in_block_space; + Float* p_wei_block_now = + even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space; + + Float* p_in_block_next = + even_loop ? p_in_block_double + in_block_space : p_in_block_double; + Float* p_wei_block_next = + even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; + + Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; + Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; + + blockwise_in_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true); + blockwise_wei_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); + blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, + p_wei_register_buffer); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); + + // LDS double buffer: store next data to LDS + blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next); + blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next); + } + } + + // LDS double buffer: tail + { + Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; + Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; + + // even iteration + blockwise_in_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true); + blockwise_wei_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); + blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); + + // LDS double buffer: store next data to LDS + blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, + p_in_block_double + in_block_space); + blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, + p_wei_block_double + wei_block_space); + + // odd iteration + __syncthreads(); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(p_wei_block_double + wei_block_space, + p_in_block_double + in_block_space, + p_out_thread); + } + + // copy output: register to global memory + { + constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster; + constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster; + + // define tensor descriptor for threadwise copy + // output global descriptor, for calculating origin of thread tensor + // in global memory + constexpr auto out_k_b_global_desc = make_ConstantMergedTensorDescriptor( + out_n_k_h_w_global_desc, Sequence<1>{}, Sequence<0, 2, 3>{}); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const index_t k_thread_data_on_global = + k_block_data_on_global + c_thread_mtx_on_block.row; + + const index_t b_thread_data_on_global = + b_block_data_on_global + c_thread_mtx_on_block.col; + + // This is a hack, because slicing a merged dimension is not supported yet. + // This should be replaced with logic above, once slicing a merged dimension support + // become available + // dst descriptor + constexpr auto out_k0_k1_b_global_desc = + make_ConstantMergedTensorDescriptor(out_n_k_h_w_global_desc.Fold(I1, Number{}), + Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3, 4>{}); + + // src descriptor + constexpr auto out_k0_k1_b_thread_desc = make_ConstantTensorDescriptor_packed( + Sequence{}); + + using OutThreadCopySliceLengths = + Sequence; + + auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2< + Float, + decltype(out_k0_k1_b_thread_desc), + decltype(out_k0_k1_b_global_desc), + NormalTensorCoordinate, + MergedTensorCoordinate, + OutThreadCopySliceLengths>({0, 0, 0}, + {k_thread_data_on_global / K1, + k_thread_data_on_global % K1, + b_thread_data_on_global}); + + for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat) + { + threadwise_out_copy.Run(p_out_thread, p_out_global); + + threadwise_out_copy.MoveSrcSlicingWindow({0, 0, GemmNPerThreadSubC}, true); + threadwise_out_copy.MoveDstSlicingWindow({0, 0, B1}, true); + } + } + } +}; + +} // namespace ck +#endif diff --git a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index 529e51378c..e1f950739a 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -4,7 +4,7 @@ #include "tensor.hpp" #include "gridwise_convolution_kernel_wrapper.hpp" #include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" -//#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp" +#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp" using namespace ck; @@ -132,39 +132,44 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); constexpr auto gridwise_conv = - GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw{}; +#if 0 + GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw +#else + GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer +#endif + {}; for(index_t i = 0; i < nrepeat; ++i) { diff --git a/driver/src/driver.cpp b/driver/src/driver.cpp index c9488b211a..4a75628952 100644 --- a/driver/src/driver.cpp +++ b/driver/src/driver.cpp @@ -101,159 +101,6 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; #elif 0 - // 3x3, 56x56 - constexpr index_t N = 64; - constexpr index_t C = 64; - constexpr index_t HI = 56; - constexpr index_t WI = 56; - constexpr index_t K = 128; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; -#elif 0 - // 3x3 filter, 28x28 image - constexpr index_t N = 128; - constexpr index_t C = 256; - constexpr index_t HI = 28; - constexpr index_t WI = 28; - constexpr index_t K = 128; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; -#elif 0 - // 1x1 filter, 28x28 image - constexpr index_t N = 128; - constexpr index_t C = 512; - constexpr index_t HI = 28; - constexpr index_t WI = 28; - constexpr index_t K = 512; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; -#elif 0 - // 3x3 filter, 20x84 image, 1x1 padding - constexpr index_t N = 16; - constexpr index_t C = 256; - constexpr index_t HI = 20; - constexpr index_t WI = 84; - constexpr index_t K = 256; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - constexpr index_t HPad = 1; - constexpr index_t WPad = 1; -#elif 0 - // 3x3 filter, 112x112 image, 1x1 padding - constexpr index_t N = 16; - constexpr index_t C = 64; - constexpr index_t HI = 112; - constexpr index_t WI = 112; - constexpr index_t K = 128; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - constexpr index_t HPad = 1; - constexpr index_t WPad = 1; -#elif 0 - // 5x5 filter, 20x86 image - constexpr index_t N = 16; - constexpr index_t C = 256; - constexpr index_t HI = 20; - constexpr index_t WI = 86; - constexpr index_t K = 512; - constexpr index_t Y = 5; - constexpr index_t X = 5; - - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; -#elif 0 - // 5x5 filter, 20x86 image, 1x1 padding - constexpr index_t N = 16; - constexpr index_t C = 256; - constexpr index_t HI = 20; - constexpr index_t WI = 86; - constexpr index_t K = 512; - constexpr index_t Y = 5; - constexpr index_t X = 5; - - constexpr index_t HPad = 1; - constexpr index_t WPad = 1; -#elif 0 - // 5x5 filter, 28x28 image, 2x2 padding - constexpr index_t N = 16; - constexpr index_t C = 192; - constexpr index_t HI = 28; - constexpr index_t WI = 28; - constexpr index_t K = 32; - constexpr index_t Y = 5; - constexpr index_t X = 5; - - constexpr index_t HPad = 2; - constexpr index_t WPad = 2; -#elif 0 - // 3x3 filter, 14x14 image - constexpr index_t N = 128; - constexpr index_t C = 256; - constexpr index_t HI = 14; - constexpr index_t WI = 14; - constexpr index_t K = 128; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; -#elif 0 - // 1x1 filter, 14x14 image - constexpr index_t N = 128; - constexpr index_t C = 512; - constexpr index_t HI = 14; - constexpr index_t WI = 14; - constexpr index_t K = 512; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; -#elif 0 - // 1x1 filter, 7x7 image - constexpr index_t N = 128; - constexpr index_t C = 512; - constexpr index_t HI = 7; - constexpr index_t WI = 7; - constexpr index_t K = 2048; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; -#elif 0 - // 1x1 filter, 73x73 image - constexpr index_t N = 128; - constexpr index_t C = 512; - constexpr index_t HI = 73; - constexpr index_t WI = 73; - constexpr index_t K = 128; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; -#elif 1 // 1x1 filter, 8x8 image // cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42% constexpr index_t N = 64; @@ -532,7 +379,7 @@ int main(int argc, char* argv[]) #elif 0 device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw( (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); -#elif 1 +#elif 0 device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc, in_nchw, wei_kcyx_desc,