From 3439e4b5b72c0533cd60bb06ff076df6e4b004f5 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 25 Jan 2019 02:50:28 -0600 Subject: [PATCH] padding works (sort of), but code looks ugly. Tuned some resnet configs --- driver/conv.cu | 127 +++++++-- ...volution_1_chwn_csrk_khwn_with_padding.cuh | 245 ++++++++++++++++ src/include/blockwise_4d_tensor_op.cuh | 127 +++++++++ src/include/conv_common.cuh | 41 ++- ...volution_1_chwn_csrk_khwn_with_padding.cuh | 265 ++++++++++++++++++ 5 files changed, 785 insertions(+), 20 deletions(-) create mode 100644 driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh create mode 100644 src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh diff --git a/driver/conv.cu b/driver/conv.cu index 25c42ec611..de5996c7aa 100644 --- a/driver/conv.cu +++ b/driver/conv.cu @@ -11,6 +11,7 @@ #include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh" #include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh" #include "device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh" +#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh" #include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh" //#include "device_winograd_convolution.cuh" @@ -107,20 +108,31 @@ auto make_TensorDescriptor(TConstTensorDesc) return TensorDescriptor(lengths, strides); } -template -void host_direct_convolution(const Tensor& in_nchw, const Tensor& wei_kcsr, Tensor& out) +template +void host_direct_convolution( + const Tensor& in_nchw, const Tensor& wei_kcsr, Tensor& out, LowerPads, UpperPads) { + unsigned h_pad_low = LowerPads{}.Get(Number<0>{}); + unsigned w_pad_low = LowerPads{}.Get(Number<1>{}); + + unsigned h_pad_up = UpperPads{}.Get(Number<0>{}); + unsigned w_pad_up = UpperPads{}.Get(Number<1>{}); + auto f = [&](auto n, auto k, auto ho, auto wo) { double v = 0; for(int c = 0; c < wei_kcsr.mDesc.GetLengths()[1]; ++c) { for(int y = 0; y < wei_kcsr.mDesc.GetLengths()[2]; ++y) { - int hi = ho + y; + int hi = ho + y - h_pad_low; for(int x = 0; x < wei_kcsr.mDesc.GetLengths()[3]; ++x) { - int wi = wo + x; - v += in_nchw(n, c, hi, wi) * wei_kcsr(k, c, y, x); + int wi = wo + x - w_pad_low; + if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 && + wi < in_nchw.mDesc.GetLengths()[3]) + { + v += in_nchw(n, c, hi, wi) * wei_kcsr(k, c, y, x); + } } } } @@ -136,10 +148,9 @@ void host_direct_convolution(const Tensor& in_nchw, const Tensor& wei_kcsr f_par(std::thread::hardware_concurrency()); } -template -void host_winograd_3x3_convolution(const Tensor& in_nchw, - const Tensor& wei_kcsr, - Tensor& out) +template +void host_winograd_3x3_convolution( + const Tensor& in_nchw, const Tensor& wei_kcsr, Tensor& out, LowerPads, UpperPads) { constexpr std::size_t OutTileSizeH = 2; constexpr std::size_t OutTileSizeW = 2; @@ -156,6 +167,12 @@ void host_winograd_3x3_convolution(const Tensor& in_nchw, std::size_t HO = out.mDesc.GetLengths()[2]; std::size_t WO = out.mDesc.GetLengths()[3]; + unsigned h_pad_low = LowerPads{}.Get(Number<0>{}); + unsigned w_pad_low = LowerPads{}.Get(Number<1>{}); + + unsigned h_pad_up = UpperPads{}.Get(Number<0>{}); + unsigned w_pad_up = UpperPads{}.Get(Number<1>{}); + std::size_t InTileSizeH = OutTileSizeH + S - 1; std::size_t InTileSizeW = OutTileSizeW + R - 1; @@ -171,11 +188,20 @@ void host_winograd_3x3_convolution(const Tensor& in_nchw, auto f_in_hold = [&](auto n, auto c, auto y, auto x) { for(int j = 0; j < InTileSizeH; ++j) { - std::size_t hi = OutTileSizeH * y + j; + int hi = OutTileSizeH * y + j - 
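+            // with padding, row j of input tile y maps to input row
+            // hi = OutTileSizeH * y + j - h_pad_low; e.g. with h_pad_low = 1 the
+            // first tile (y = 0, j = 0) starts at hi = -1, and the bounds check
+            // below substitutes T(0) for such out-of-range reads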
h_pad_low; for(int i = 0; i < InTileSizeW; ++i) { - std::size_t wi = OutTileSizeW * x + i; - in_hold(n, c, y, x, j, i) = in_nchw(n, c, hi, wi); + int wi = OutTileSizeW * x + i - w_pad_low; + + if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 && + wi < in_nchw.mDesc.GetLengths()[3]) + { + in_hold(n, c, y, x, j, i) = in_nchw(n, c, hi, wi); + } + else + { + in_hold(n, c, y, x, j, i) = T(0); + } } } }; @@ -406,7 +432,7 @@ int main() constexpr unsigned K = 64; constexpr unsigned S = 7; constexpr unsigned R = 7; -#elif 1 +#elif 0 // 3x3, 58x58 constexpr unsigned N = 16; constexpr unsigned C = 128; @@ -415,12 +441,63 @@ int main() constexpr unsigned K = 256; constexpr unsigned S = 3; constexpr unsigned R = 3; +#elif 0 + // 3x3 filter, 58x58 image, 0x0 padding + constexpr unsigned N = 16; + constexpr unsigned C = 128; + constexpr unsigned HI = 58; + constexpr unsigned WI = 58; + constexpr unsigned K = 256; + constexpr unsigned S = 3; + constexpr unsigned R = 3; + + constexpr unsigned HPad = 0; + constexpr unsigned WPad = 0; +#elif 1 + // 3x3 filter, 56x56 image, 1x1 padding + constexpr unsigned N = 16; + constexpr unsigned C = 128; + constexpr unsigned HI = 56; + constexpr unsigned WI = 56; + constexpr unsigned K = 256; + constexpr unsigned S = 3; + constexpr unsigned R = 3; + + constexpr unsigned HPad = 1; + constexpr unsigned WPad = 1; +#elif 0 + // 3x3 filter, 28x28 image, 1x1 padding + constexpr unsigned N = 16; + constexpr unsigned C = 256; + constexpr unsigned HI = 28; + constexpr unsigned WI = 28; + constexpr unsigned K = 512; + constexpr unsigned S = 3; + constexpr unsigned R = 3; + + constexpr unsigned HPad = 1; + constexpr unsigned WPad = 1; +#elif 0 + // 3x3 filter, 20x84 image, 1x1 padding + constexpr unsigned N = 16; + constexpr unsigned C = 256; + constexpr unsigned HI = 20; + constexpr unsigned WI = 84; + constexpr unsigned K = 256; + constexpr unsigned S = 3; + constexpr unsigned R = 3; + + constexpr unsigned HPad = 1; + constexpr unsigned WPad = 1; #endif + auto lower_pads = Sequence{}; + auto upper_pads = Sequence{}; + auto in_nchw_desc = make_ConstantTensorDescriptor(Sequence{}); auto wei_kcsr_desc = make_ConstantTensorDescriptor(Sequence{}); - auto out_nkhw_desc = - get_convolution_output_default_4d_tensor_descriptor(in_nchw_desc, wei_kcsr_desc); + auto out_nkhw_desc = get_convolution_with_padding_output_default_4d_tensor_descriptor( + in_nchw_desc, wei_kcsr_desc, lower_pads, upper_pads); ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: "); ostream_ConstantTensorDescriptor(wei_kcsr_desc, std::cout << "wei_kcsr_desc: "); @@ -443,6 +520,7 @@ int main() unsigned nrepeat = 50; +#if 0 #if 0 device_direct_convolution_1 #elif 0 @@ -451,7 +529,7 @@ int main() device_implicit_gemm_convolution_1_nchw_kcsr #elif 0 device_implicit_gemm_convolution_1_nchw_srck_nkhw -#elif 1 +#elif 0 device_implicit_gemm_convolution_1_chwn_csrk_khwn #elif 0 device_implicit_gemm_convolution_2_cnhw_srck_knhw @@ -459,15 +537,28 @@ int main() device_winograd_convolution #endif (in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat); +#endif + +#if 1 + device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding(in_nchw_desc, + in_nchw, + wei_kcsr_desc, + wei_kcsr, + out_nkhw_desc, + out_nkhw_device, + lower_pads, + upper_pads, + nrepeat); +#endif #if 1 if(S == 3 && R == 3) { - host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host); + host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads); } else { - 
host_direct_convolution(in_nchw, wei_kcsr, out_nkhw_host); + host_direct_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads); } check_error(out_nkhw_host, out_nkhw_device); #endif diff --git a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh new file mode 100644 index 0000000000..94da496f5a --- /dev/null +++ b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh @@ -0,0 +1,245 @@ +#pragma once +#include "gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh" +#include + +template +void device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding(InDesc, + const Tensor& in_nchw, + WeiDesc, + const Tensor& wei_kcsr, + OutDesc, + Tensor& out_nkhw, + LowerPads, + UpperPads, + unsigned nrepeat) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_nchw_desc = InDesc{}; + constexpr auto wei_kcsr_desc = WeiDesc{}; + constexpr auto out_nkhw_desc = OutDesc{}; + + constexpr unsigned Hi = in_nchw_desc.GetLength(I2); + constexpr unsigned Wi = in_nchw_desc.GetLength(I3); + + constexpr unsigned N = out_nkhw_desc.GetLength(I0); + constexpr unsigned Ho = out_nkhw_desc.GetLength(I2); + constexpr unsigned Wo = out_nkhw_desc.GetLength(I3); + + constexpr unsigned K = wei_kcsr_desc.GetLength(I0); + constexpr unsigned C = wei_kcsr_desc.GetLength(I1); + constexpr unsigned S = wei_kcsr_desc.GetLength(I2); + constexpr unsigned R = wei_kcsr_desc.GetLength(I3); + + // reorder weight + auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: "); + + Tensor wei_csrk(make_TensorDescriptor(wei_csrk_desc)); + + auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto s, auto r) { + wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r); + }; + + make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, S, R)( + std::thread::hardware_concurrency()); + + // reorder input + auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: "); + + Tensor in_chwn(make_TensorDescriptor(in_chwn_desc)); + + auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) { + in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi); + }; + + make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)( + std::thread::hardware_concurrency()); + + // output + auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: "); + + Tensor out_khwn(make_TensorDescriptor(out_khwn_desc)); + + std::size_t data_sz = sizeof(T); + DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace()); + DeviceMem wei_csrk_device_buf(data_sz * wei_csrk.mDesc.GetElementSpace()); + DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace()); + + in_chwn_device_buf.ToDevice(in_chwn.mData.data()); + wei_csrk_device_buf.ToDevice(wei_csrk.mData.data()); + out_khwn_device_buf.ToDevice(out_khwn.mData.data()); + +#if 0 + constexpr unsigned NPerBlock = 1; + constexpr unsigned KPerBlock = 1; + constexpr unsigned CPerBlock = 1; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 1; + constexpr unsigned KPerThread = 1; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned 
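+    // each thread accumulates an [NPerThread, KPerThread, HoPerThread, WoPerThread]
+    // sub-tile of the block's tile; every config below satisfies
+    //   BlockSize = (NPerBlock / NPerThread) * (KPerBlock / KPerThread)
+    //             * (HoPerBlock / HoPerThread) * (WoPerBlock / WoPerThread)
+    // e.g. this debug config: (1/1) * (1/1) * (2/1) * (4/1) = 8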
WoPerThread = 1; + + constexpr unsigned BlockSize = 8; +#elif 0 + // for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256 + constexpr unsigned NPerBlock = 16; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 4; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned BlockSize = 128; +#elif 0 + // 3x3 58x58, NKC = 16,256,128 + constexpr unsigned NPerBlock = 8; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 2; + constexpr unsigned HoPerBlock = 4; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned BlockSize = 128; +#elif 0 + // for 5x5, 36x36 + constexpr unsigned NPerBlock = 16; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 2; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned BlockSize = 128; +#elif 0 + // for 7x7, 38x38 + constexpr unsigned NPerBlock = 8; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 2; + constexpr unsigned HoPerBlock = 4; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned BlockSize = 128; +#elif 0 + // for 3x3, 56x56 + constexpr unsigned NPerBlock = 32; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 4; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 2; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned BlockSize = 128; +#elif 1 + // 3x3 56x56, NKC = 16,256,128, with padding + // 3x3 28x28, NKC = 16,512,256, with padding + // 3x3 20x84, NKC = 16,256,256, with padding + constexpr unsigned NPerBlock = 16; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 2; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned BlockSize = 128; +#endif + + constexpr unsigned GridSize = + ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) * + ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock); + + dim3 block_dim(BlockSize); + dim3 grid_dim(GridSize); + + printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); + + for(unsigned i = 0; i < nrepeat; ++i) + { + cudaEvent_t start, stop; + float elapsedTime; + + cudaEventCreate(&start); + cudaEventRecord(start, 0); + + gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding + <<>>(static_cast(in_chwn_device_buf.GetDeviceBuffer()), + static_cast(wei_csrk_device_buf.GetDeviceBuffer()), + 
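+            // kernel arguments are the device copies of the reordered tensors
+            // prepared above: input [C,Hi,Wi,N], weight [C,S,R,K], output [K,Ho,Wo,N]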
static_cast(out_khwn_device_buf.GetDeviceBuffer())); + + cudaEventCreate(&stop); + cudaEventRecord(stop, 0); + cudaEventSynchronize(stop); + + cudaEventElapsedTime(&elapsedTime, start, stop); + printf("Elapsed time : %f ms\n", elapsedTime); + + usleep(10000); + } + + checkCudaErrors(cudaGetLastError()); + out_khwn_device_buf.FromDevice(out_khwn.mData.data()); + + // reorder output + auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) { + out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n); + }; + + make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)( + std::thread::hardware_concurrency()); +} diff --git a/src/include/blockwise_4d_tensor_op.cuh b/src/include/blockwise_4d_tensor_op.cuh index 4619140b05..eca267e2e3 100644 --- a/src/include/blockwise_4d_tensor_op.cuh +++ b/src/include/blockwise_4d_tensor_op.cuh @@ -211,6 +211,133 @@ struct blockwise_4d_tensor_copy_1 } }; +template +struct blockwise_chwn_tensor_copy_with_padding +{ + __device__ void run(Float* const __restrict__ p_src, + unsigned c_block_data_begin, + unsigned ho_block_data_begin, + unsigned wo_block_data_begin, + unsigned n_block_data_begin, + Float* __restrict__ p_dst, + unsigned h_block_pad_low, + unsigned w_block_pad_low, + unsigned h_block_pad_up, + unsigned w_block_pad_up) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto src_desc = SrcDesc{}; + constexpr auto dst_desc = DstDesc{}; + constexpr auto ref_desc = make_ConstantTensorDescriptor(DstOpLengths{}); + + constexpr auto h_global_pad_low = GlobalLowerPads{}.Get(I0); + constexpr auto w_global_pad_low = GlobalLowerPads{}.Get(I1); + + constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize; + + Float* const p_src_tmp = + p_src + src_desc.Get1dIndex(c_block_data_begin, + (ho_block_data_begin + h_block_pad_low) - h_global_pad_low, + (wo_block_data_begin + w_block_pad_low) - w_global_pad_low, + n_block_data_begin); + +#if 0 + if(get_thread_local_1d_id() == 0) + { + print_ConstantTensorDescriptor(src_desc, "src_desc: "); + print_ConstantTensorDescriptor(dst_desc, "dst_desc: "); + print_ConstantTensorDescriptor(ref_desc, "ref_desc: "); + + printf("%u %u, \t" + "h_global_pad_low %u w_global_pad_low %u \t" + "h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u \t" + "\n", + get_block_1d_id(), + get_thread_local_1d_id(), + h_global_pad_low, + w_global_pad_low, + h_block_pad_low, + w_block_pad_low, + h_block_pad_up, + w_block_pad_up); + } +#endif + + for(unsigned iloop = 0; iloop < NLoop; ++iloop) + { + unsigned is = threadIdx.x + iloop * BlockSize; + + unsigned did[4]; + + did[0] = is / ref_desc.GetStride(I0); + + is -= did[0] * ref_desc.GetStride(I0); + + did[1] = is / ref_desc.GetStride(I1); + + is -= did[1] * ref_desc.GetStride(I1); + + did[2] = is / ref_desc.GetStride(I2); + + is -= did[2] * ref_desc.GetStride(I2); + + did[3] = is / ref_desc.GetStride(I3); + + const unsigned bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]); + + p_dst[bindex] = + (did[1] < h_block_pad_low || did[1] + h_block_pad_up >= ref_desc.GetLength(I1) || + did[2] < w_block_pad_low || did[2] + w_block_pad_up >= ref_desc.GetLength(I2)) + ? 
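+                // destination coordinates that fall inside this block's pad halo
+                // are zero-filled; only in-range elements are loaded from global
+                // memory through p_src_tmp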
Float(0) + : p_src_tmp[src_desc.Get1dIndex(did[0], did[1], did[2], did[3])]; + } + + constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize); + + if(has_tail) + { + unsigned is = threadIdx.x + NLoop * BlockSize; + + if(is < ref_desc.GetElementSize()) + { + unsigned did[4]; + + did[0] = is / ref_desc.GetStride(I0); + + is -= did[0] * ref_desc.GetStride(I0); + + did[1] = is / ref_desc.GetStride(I1); + + is -= did[1] * ref_desc.GetStride(I1); + + did[2] = is / ref_desc.GetStride(I2); + + is -= did[2] * ref_desc.GetStride(I2); + + did[3] = is / ref_desc.GetStride(I3); + + const unsigned bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]); + + p_dst[bindex] = + (did[1] < h_block_pad_low || + did[1] + h_block_pad_up >= ref_desc.GetLength(I1) || + did[2] < w_block_pad_low || did[2] + w_block_pad_up >= ref_desc.GetLength(I2)) + ? Float(0) + : p_src_tmp[src_desc.Get1dIndex(did[0], did[1], did[2], did[3])]; + } + } + } +}; + template struct blockwise_4d_tensor_copy_dummy { diff --git a/src/include/conv_common.cuh b/src/include/conv_common.cuh index f1e2b2c9f9..1b1c655b6f 100644 --- a/src/include/conv_common.cuh +++ b/src/include/conv_common.cuh @@ -27,8 +27,45 @@ __host__ __device__ constexpr auto get_convolution_output_default_4d_tensor_desc constexpr auto S = wei_desc.GetLength(I2); constexpr auto R = wei_desc.GetLength(I3); - constexpr auto HO = HI - S + 1; - constexpr auto WO = WI - R + 1; + constexpr auto HO = HI + 1 - S; + constexpr auto WO = WI + 1 - R; + + return make_ConstantTensorDescriptor(Sequence{}); +} + +template +__host__ __device__ constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor( + InDesc, WeiDesc, LowerPads, UpperPads) +{ + constexpr auto in_desc = InDesc{}; + constexpr auto wei_desc = WeiDesc{}; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + static_assert(in_desc.GetDimension() == 4, "input nDim is not 4"); + static_assert(wei_desc.GetDimension() == 4, "weight nDim is not 4"); + static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1), + "input & weight dimension not consistent"); + + constexpr auto N = in_desc.GetLength(I0); + constexpr auto HI = in_desc.GetLength(I2); + constexpr auto WI = in_desc.GetLength(I3); + + constexpr auto K = wei_desc.GetLength(I0); + constexpr auto S = wei_desc.GetLength(I2); + constexpr auto R = wei_desc.GetLength(I3); + + constexpr auto HPadLow = LowerPads{}.Get(I0); + constexpr auto WPadLow = LowerPads{}.Get(I1); + + constexpr auto HPadUp = UpperPads{}.Get(I0); + constexpr auto WPadUp = UpperPads{}.Get(I1); + + constexpr auto HO = HI + HPadLow + HPadUp + 1 - S; + constexpr auto WO = WI + WPadLow + WPadUp + 1 - R; return make_ConstantTensorDescriptor(Sequence{}); } diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh new file mode 100644 index 0000000000..6234e5d0de --- /dev/null +++ b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh @@ -0,0 +1,265 @@ +#pragma once +#include "common.cuh" +#include "ConstantTensorDescriptor.cuh" +#include "ConstantMatrixDescriptor.cuh" +#include "blockwise_4d_tensor_op.cuh" +#include "threadwise_4d_tensor_op.cuh" +#include "gemm.cuh" + +template +__global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding( + Float* const __restrict__ p_in_global, + Float* const __restrict__ 
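+    // p_in_global is [C,Hi,Wi,N], p_wei_global is [C,S,R,K], p_out_global is
+    // [K,Ho,Wo,N]; LowerPads/UpperPads carry the global H/W padding of the input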
p_wei_global, + Float* __restrict__ p_out_global) +{ + // NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N] + // for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N" + // if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock + static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0"); + static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock, + "wrong!"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_chwn_global_desc = InGlobalDesc{}; + constexpr auto wei_csrk_global_desc = WeiGlobalDesc{}; + constexpr auto out_khwn_global_desc = OutGlobalDesc{}; + + constexpr unsigned C = in_chwn_global_desc.GetLength(I0); + + constexpr unsigned K = out_khwn_global_desc.GetLength(I0); + constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1); + constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2); + constexpr unsigned N = out_khwn_global_desc.GetLength(I3); + + constexpr unsigned S = wei_csrk_global_desc.GetLength(I1); + constexpr unsigned R = wei_csrk_global_desc.GetLength(I2); + + constexpr unsigned HPadLow = LowerPads{}.Get(I0); + constexpr unsigned WPadLow = LowerPads{}.Get(I1); + + constexpr unsigned HPadUp = UpperPads{}.Get(I0); + constexpr unsigned WPadUp = UpperPads{}.Get(I1); + + constexpr unsigned HiPerBlock = HoPerBlock + S - 1; + constexpr unsigned WiPerBlock = WoPerBlock + R - 1; + + // divide block work: [K, Ho, Wo, N] + constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; + constexpr unsigned HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; + constexpr unsigned WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; + constexpr unsigned NBlockWork = (N + NPerBlock - 1) / NPerBlock; + + const unsigned k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); + unsigned itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); + const unsigned h_block_work_id = itmp / (WBlockWork * NBlockWork); + itmp -= h_block_work_id * (WBlockWork * NBlockWork); + const unsigned w_block_work_id = itmp / NBlockWork; + const unsigned n_block_work_id = itmp - w_block_work_id * NBlockWork; + + const unsigned k_block_data_begin = k_block_work_id * KPerBlock; + const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock; + const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock; + const unsigned n_block_data_begin = n_block_work_id * NPerBlock; + + // tensor view of blockwise input and weight in LDS + constexpr auto in_chwn_block_desc = + make_ConstantTensorDescriptor(Sequence{}); + + constexpr auto wei_csrk_block_desc = + make_ConstantTensorDescriptor(Sequence{}); + + // tensor view of threadwise output in register + constexpr auto out_hkwn_thread_desc = + make_ConstantTensorDescriptor(Sequence{}); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc"); + print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc"); + print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc"); + } +#endif + + // blockwise copy + // input: format is [C, Hi, Wi, N] + const unsigned h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0; + const unsigned w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0; + + const unsigned h_block_pad_up = h_block_work_id == HBlockWork - 1 ? 
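+    // only blocks on the image boundary get a non-zero block-local pad: the first
+    // block along H/W inherits the global lower pad, the last block inherits the
+    // global upper pad; interior blocks copy fully valid data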
HPadUp : 0; + const unsigned w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0; + +#if 0 + if(get_thread_local_1d_id() == 0) + ; + { + printf( + "%u %u, h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u\n", + get_block_1d_id(), + get_thread_local_1d_id(), + h_block_pad_low, + w_block_pad_low, + h_block_pad_up, + w_block_pad_up); + } +#endif + + constexpr auto blockwise_in_copy = + blockwise_chwn_tensor_copy_with_padding{}; + + // weight: format is [S,R,C,K] + constexpr auto blockwise_wei_copy = + blockwise_4d_tensor_copy_1{}; + + // a series of blockwise batched GEMM + // C_matrix += transpose(A_matrix) * B_matrix + // A_matrix and B_matrix saved in LDS, C_matrix saved in register + // A_matrix[C,K] is a sub-matrix of wei_block[S,R,C,K] + // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] + // C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N] + const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, + Number{}, + Number{}); // constexpr doesn't compile + + const auto b_cxwn_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, + Number{}, + Number{}); // constexpr doesn't compile + + const auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}); // constexpr doesn't compile + + const auto blockwise_batch_gemm = + blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c{}; + + // LDS + constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace(); + constexpr unsigned wei_block_size = wei_csrk_block_desc.GetElementSpace(); + + __shared__ Float p_in_block[in_block_size]; + __shared__ Float p_wei_block[wei_block_size]; + + // register + Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()]; + + // set threadwise output tensor to 0 + threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread); + + for(unsigned c_block_data_begin = 0; c_block_data_begin < C; + c_block_data_begin += CPerBlock, __syncthreads()) + { +#if 1 + // input: global mem to LDS, + blockwise_in_copy.run(p_in_global, + c_block_data_begin, + ho_block_data_begin, + wo_block_data_begin, + n_block_data_begin, + p_in_block, + h_block_pad_low, + w_block_pad_low, + h_block_pad_up, + w_block_pad_up); +#endif + +#if 1 + // weight: global mem to LDS, + blockwise_wei_copy.run(p_wei_global + wei_csrk_global_desc.Get1dIndex( + c_block_data_begin, 0, 0, k_block_data_begin), + p_wei_block); +#endif + + __syncthreads(); + + // a series of batched GEMM + for(unsigned s = 0; s < S; ++s) + { + for(unsigned r = 0; r < R; ++r) + { + auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; + + blockwise_batch_gemm.run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), + p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0), + p_out_thread, + f_accum); + } + } + } + + const auto matrix_c_index = + blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); + + const unsigned ho_thread_data_begin = matrix_c_index.batch_begin; + const unsigned k_thread_data_begin = matrix_c_index.row_begin; + const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock; + const unsigned n_thread_data_begin = + matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock; + +#if 0 + printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n", + get_block_1d_id(), get_thread_local_1d_id(), + ho_block_data_begin, k_block_data_begin, wo_block_data_begin, n_block_data_begin, + ho_thread_data_begin, k_thread_data_begin, wo_thread_data_begin, n_thread_data_begin, + p_out_thread[0]); +#endif + + // output: 
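+    // each thread's accumulator tile is laid out [Ho,K,Wo,N] (out_hkwn_thread_desc)
+    // while the global output is [K,Ho,Wo,N], so the copy below permutes
+    // dimensions with the reorder sequence <1, 0, 2, 3>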
+    // register to global mem,
+    // convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N]
+    constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{};
+
+    threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
+        out_hkwn_thread_desc,
+        p_out_thread,
+        out_khwn_global_desc,
+        p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
+                                                       ho_block_data_begin + ho_thread_data_begin,
+                                                       wo_block_data_begin + wo_thread_data_begin,
+                                                       n_block_data_begin + n_thread_data_begin),
+        out_hkwn_thread_desc.GetLengths(),
+        reorder_khwn_from_hkwn);
+}
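
A minimal, self-contained host-side sketch of the two pieces this patch adds: the padded output-size formula from conv_common.cuh and the zero-padded reference loop from host_direct_convolution. Plain C++, single channel and filter, sizes hard-coded to the enabled 3x3 filter / 56x56 image / 1x1 padding config; the names here are illustrative, not part of the patch.

#include <cstdio>
#include <vector>

int main()
{
    // enabled driver config: 3x3 filter, 56x56 image, 1x1 padding
    const int HI = 56, WI = 56, S = 3, R = 3;
    const int HPadLow = 1, HPadUp = 1, WPadLow = 1, WPadUp = 1;

    // padded output size, as computed in
    // get_convolution_with_padding_output_default_4d_tensor_descriptor
    const int HO = HI + HPadLow + HPadUp + 1 - S; // 56
    const int WO = WI + WPadLow + WPadUp + 1 - R; // 56
    std::printf("HO = %d, WO = %d\n", HO, WO);

    // zero-padded direct convolution, same index arithmetic as the updated
    // host_direct_convolution (out-of-range input reads contribute zero)
    std::vector<float> in(HI * WI, 1.0f), wei(S * R, 1.0f), out(HO * WO, 0.0f);
    for(int ho = 0; ho < HO; ++ho)
        for(int wo = 0; wo < WO; ++wo)
        {
            double v = 0;
            for(int y = 0; y < S; ++y)
                for(int x = 0; x < R; ++x)
                {
                    const int hi = ho + y - HPadLow;
                    const int wi = wo + x - WPadLow;
                    if(hi >= 0 && hi < HI && wi >= 0 && wi < WI)
                        v += in[hi * WI + wi] * wei[y * R + x];
                }
            out[ho * WO + wo] = static_cast<float>(v);
        }

    // with all-ones data: interior outputs see the full 3x3 window (9.0),
    // the corners see only a 2x2 valid window (4.0)
    std::printf("corner = %.1f, interior = %.1f\n",
                out[0], out[(HO / 2) * WO + WO / 2]);
    return 0;
}

Expected output is HO = 56, WO = 56, corner = 4.0, interior = 9.0, i.e. the shape-preserving "same" padding that the 1x1-pad ResNet configs in the driver aim for.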