diff --git a/driver/device_direct_convolution_1.hpp b/driver/device_direct_convolution_1.hpp index 99b184ec7c..57abc62432 100644 --- a/driver/device_direct_convolution_1.hpp +++ b/driver/device_direct_convolution_1.hpp @@ -34,25 +34,24 @@ void device_direct_convolution_1(InDesc, #if 1 // 3x3, 34x34 - constexpr unsigned OutTileSizeH = 2; - constexpr unsigned OutTileSizeW = 2; - constexpr unsigned NPerBlock = 2; - constexpr unsigned KPerBlock = 16; - constexpr unsigned CPerBlock = 2; - constexpr unsigned YPerBlock = 2; - constexpr unsigned XPerBlock = 16; + constexpr unsigned NPerBlock = 2; + constexpr unsigned KPerBlock = 16; + constexpr unsigned CPerBlock = 2; + constexpr unsigned HoPerBlock = 4; + constexpr unsigned WoPerBlock = 32; - constexpr unsigned NPerThread = 2; - constexpr unsigned KPerThread = 4; - constexpr unsigned CPerThread = 2; + constexpr unsigned NPerThread = 2; + constexpr unsigned KPerThread = 4; + constexpr unsigned CPerThread = 2; + constexpr unsigned HoPerThread = 2; + constexpr unsigned WoPerThread = 2; constexpr unsigned BlockSize = 128; #endif - constexpr unsigned GridSize = (out_desc.GetLength(I0) / NPerBlock) * - (out_desc.GetLength(I1) / KPerBlock) * - (out_desc.GetLength(I2) / (OutTileSizeH * YPerBlock)) * - (out_desc.GetLength(I3) / (OutTileSizeW * XPerBlock)); + constexpr unsigned GridSize = + (out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) * + (out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock); printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); @@ -62,16 +61,16 @@ void device_direct_convolution_1(InDesc, InDesc, WeiDesc, OutDesc, - OutTileSizeH, - OutTileSizeW, NPerBlock, KPerBlock, CPerBlock, - YPerBlock, - XPerBlock, + HoPerBlock, + WoPerBlock, NPerThread, KPerThread, CPerThread, + HoPerThread, + WoPerThread, BlockSize, GridSize>, dim3(GridSize), diff --git a/driver/device_direct_convolution_2.hpp b/driver/device_direct_convolution_2.hpp index f627719026..1baedafc46 100644 --- a/driver/device_direct_convolution_2.hpp +++ b/driver/device_direct_convolution_2.hpp @@ -34,40 +34,24 @@ void device_direct_convolution_2(InDesc, #if 1 // 3x3, 34x34, 128 thread - constexpr unsigned OutTileSizeH = 2; - constexpr unsigned OutTileSizeW = 2; - constexpr unsigned NPerBlock = 2; - constexpr unsigned KPerBlock = 32; - constexpr unsigned CPerBlock = 4; - constexpr unsigned YPerBlock = 1; - constexpr unsigned XPerBlock = 16; + constexpr unsigned NPerBlock = 2; + constexpr unsigned KPerBlock = 32; + constexpr unsigned CPerBlock = 4; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 32; - constexpr unsigned NPerThread = 2; - constexpr unsigned KPerThread = 4; - constexpr unsigned CPerThread = 2; + constexpr unsigned NPerThread = 2; + constexpr unsigned KPerThread = 4; + constexpr unsigned CPerThread = 2; + constexpr unsigned HoPerThread = 2; + constexpr unsigned WoPerThread = 2; constexpr unsigned BlockSize = 128; -#elif 0 - // 3x3, 34x34, 256 thread - constexpr unsigned OutTileSizeH = 2; - constexpr unsigned OutTileSizeW = 2; - constexpr unsigned NPerBlock = 2; - constexpr unsigned KPerBlock = 32; - constexpr unsigned CPerBlock = 4; - constexpr unsigned YPerBlock = 1; - constexpr unsigned XPerBlock = 32; - - constexpr unsigned NPerThread = 2; - constexpr unsigned KPerThread = 4; - constexpr unsigned CPerThread = 2; - - constexpr unsigned BlockSize = 256; #endif - constexpr unsigned GridSize = (out_desc.GetLength(I0) / NPerBlock) * - (out_desc.GetLength(I1) / KPerBlock) * - (out_desc.GetLength(I2) / (OutTileSizeH * YPerBlock)) * - (out_desc.GetLength(I3) / (OutTileSizeW * XPerBlock)); + constexpr unsigned GridSize = + (out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) * + (out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock); printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); @@ -77,16 +61,16 @@ void device_direct_convolution_2(InDesc, InDesc, WeiDesc, OutDesc, - OutTileSizeH, - OutTileSizeW, NPerBlock, KPerBlock, CPerBlock, - YPerBlock, - XPerBlock, + HoPerBlock, + WoPerBlock, NPerThread, KPerThread, CPerThread, + HoPerThread, + WoPerThread, BlockSize, GridSize>, dim3(GridSize), diff --git a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp index 14e4d29f74..fc2b245148 100644 --- a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp @@ -30,11 +30,11 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, constexpr unsigned K = wei_kcsr_desc.GetLength(I0); constexpr unsigned C = wei_kcsr_desc.GetLength(I1); - constexpr unsigned S = wei_kcsr_desc.GetLength(I2); - constexpr unsigned R = wei_kcsr_desc.GetLength(I3); + constexpr unsigned Y = wei_kcsr_desc.GetLength(I2); + constexpr unsigned X = wei_kcsr_desc.GetLength(I3); // reorder weight - auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence{}); + auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence{}); ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: "); Tensor wei_csrk(make_TensorDescriptor(wei_csrk_desc)); @@ -43,7 +43,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r); }; - make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, S, R)( + make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, Y, X)( std::thread::hardware_concurrency()); // reorder input diff --git a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp index 0c38e5206f..db0cb3aa90 100644 --- a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp +++ b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp @@ -2,7 +2,6 @@ #include #include "device.hpp" #include "gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hip.hpp" -#include "gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_pipeline.hip.hpp" template void device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(InDesc, @@ -33,11 +32,11 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(InDesc, constexpr unsigned K = wei_kcsr_desc.GetLength(I0); constexpr unsigned C = wei_kcsr_desc.GetLength(I1); - constexpr unsigned S = wei_kcsr_desc.GetLength(I2); - constexpr unsigned R = wei_kcsr_desc.GetLength(I3); + constexpr unsigned Y = wei_kcsr_desc.GetLength(I2); + constexpr unsigned X = wei_kcsr_desc.GetLength(I3); // reorder weight - auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence{}); + auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence{}); ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: "); Tensor wei_csrk(make_TensorDescriptor(wei_csrk_desc)); @@ -46,7 +45,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(InDesc, wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r); }; - make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, S, R)( + make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, Y, X)( std::thread::hardware_concurrency()); // reorder input @@ -251,31 +250,26 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(InDesc, for(unsigned i = 0; i < nrepeat; ++i) { float time = launch_kernel( -#if 0 - gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded -#elif 1 - gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_pipeline -#endif - , + gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded, dim3(GridSize), dim3(BlockSize), diff --git a/driver/device_implicit_gemm_convolution_1_nchw_kcsr_nkhw.hpp b/driver/device_implicit_gemm_convolution_1_nchw_kcsr_nkhw.hpp deleted file mode 100644 index ea7dbea266..0000000000 --- a/driver/device_implicit_gemm_convolution_1_nchw_kcsr_nkhw.hpp +++ /dev/null @@ -1,85 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "gridwise_implicit_gemm_convolution_1_nchw_kcsr_nkhw.hip.hpp" - -template -void device_implicit_gemm_convolution_1_nchw_kcsr_nkhw(InDesc, - const Tensor& in, - WeiDesc, - const Tensor& wei, - OutDesc, - Tensor& out, - unsigned nrepeat) -{ - std::size_t data_sz = sizeof(T); - DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace()); - DeviceMem wei_device_buf(data_sz * wei.mDesc.GetElementSpace()); - DeviceMem out_device_buf(data_sz * out.mDesc.GetElementSpace()); - - int num_thread = std::thread::hardware_concurrency(); - - in_device_buf.ToDevice(in.mData.data()); - wei_device_buf.ToDevice(wei.mData.data()); - out_device_buf.ToDevice(out.mData.data()); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_desc = InDesc{}; - constexpr auto wei_desc = WeiDesc{}; - constexpr auto out_desc = OutDesc{}; - -#if 1 - // 3x3, 34x34 - constexpr unsigned NPerBlock = 1; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 2; - constexpr unsigned HoPerBlock = 4; - constexpr unsigned WoPerBlock = 32; - - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - constexpr unsigned HoPerThread = 2; - constexpr unsigned WoPerThread = 2; - - constexpr unsigned BlockSize = 128; -#endif - - constexpr unsigned GridSize = - (out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) * - (out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - for(unsigned i = 0; i < nrepeat; ++i) - { - float time = launch_kernel(gridwise_implicit_gemm_convolution_1_nchw_kcsr_nkhw, - dim3(GridSize), - dim3(BlockSize), - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms\n", time); - usleep(std::min(time * 1000, float(10000))); - } - - out_device_buf.FromDevice(out.mData.data()); -} diff --git a/driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.hpp b/driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.hpp deleted file mode 100644 index 166d392e5f..0000000000 --- a/driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.hpp +++ /dev/null @@ -1,139 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.hip.hpp" - -template -void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcsr, - OutDesc, - Tensor& out_nkhw, - unsigned nrepeat) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_desc = InDesc{}; - constexpr auto wei_kcsr_desc = WeiDesc{}; - constexpr auto out_nkhw_desc = OutDesc{}; - - constexpr unsigned N = out_nkhw_desc.GetLength(I0); - constexpr unsigned Ho = out_nkhw_desc.GetLength(I2); - constexpr unsigned Wo = out_nkhw_desc.GetLength(I3); - - constexpr unsigned K = wei_kcsr_desc.GetLength(I0); - constexpr unsigned C = wei_kcsr_desc.GetLength(I1); - constexpr unsigned S = wei_kcsr_desc.GetLength(I2); - constexpr unsigned R = wei_kcsr_desc.GetLength(I3); - - auto wei_srck_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(wei_srck_desc, std::cout << "wei_srck_desc: "); - - Tensor wei_srck(make_TensorDescriptor(wei_srck_desc)); - - auto f_reorder_kcsr2srck = [&](auto k, auto c, auto s, auto r) { - wei_srck(s, r, c, k) = wei_kcsr(k, c, s, r); - }; - - make_ParallelTensorFunctor(f_reorder_kcsr2srck, K, C, S, R)( - std::thread::hardware_concurrency()); - - std::size_t data_sz = sizeof(T); - DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); - DeviceMem wei_srck_device_buf(data_sz * wei_srck.mDesc.GetElementSpace()); - DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace()); - - int num_thread = std::thread::hardware_concurrency(); - - in_nchw_device_buf.ToDevice(in_nchw.mData.data()); - wei_srck_device_buf.ToDevice(wei_srck.mData.data()); - out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); - -#if 1 - // for 3x3, 34x34 - constexpr unsigned NPerBlock = 1; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 2; - constexpr unsigned HoPerBlock = 4; - constexpr unsigned WoPerBlock = 32; - - constexpr unsigned NPerThread = 1; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - constexpr unsigned HoPerThread = 2; - constexpr unsigned WoPerThread = 2; - - constexpr unsigned BlockSize = 128; -#elif 0 - // for 3x3, 58x58 - constexpr unsigned NPerBlock = 4; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 2; - constexpr unsigned HoPerBlock = 4; - constexpr unsigned WoPerBlock = 8; - - constexpr unsigned NPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - constexpr unsigned HoPerThread = 1; - constexpr unsigned WoPerThread = 1; - - constexpr unsigned BlockSize = 128; -#elif 1 - // for 3x3, 56x56 - constexpr unsigned NPerBlock = 32; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 4; - constexpr unsigned HoPerBlock = 2; - constexpr unsigned WoPerBlock = 2; - - constexpr unsigned NPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - constexpr unsigned HoPerThread = 1; - constexpr unsigned WoPerThread = 1; - - constexpr unsigned BlockSize = 128; -#endif - - constexpr unsigned GridSize = - ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) * - ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - for(unsigned i = 0; i < nrepeat; ++i) - { - float time = launch_kernel( - gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw, - dim3(GridSize), - dim3(BlockSize), - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_srck_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms\n", time); - usleep(std::min(time * 1000, float(10000))); - } - - out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); -} diff --git a/driver/device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp index 1736a5e874..88fd5a2dea 100644 --- a/driver/device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp @@ -30,10 +30,10 @@ void device_implicit_gemm_convolution_2_chwn_csrk_khwn(InDesc, constexpr unsigned K = wei_kcsr_desc.GetLength(I0); constexpr unsigned C = wei_kcsr_desc.GetLength(I1); - constexpr unsigned S = wei_kcsr_desc.GetLength(I2); - constexpr unsigned R = wei_kcsr_desc.GetLength(I3); + constexpr unsigned Y = wei_kcsr_desc.GetLength(I2); + constexpr unsigned X = wei_kcsr_desc.GetLength(I3); - constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1); + constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1); // convert in_nchw to in_cnhw auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence{}); @@ -49,7 +49,7 @@ void device_implicit_gemm_convolution_2_chwn_csrk_khwn(InDesc, Wi)(std::thread::hardware_concurrency()); // convert wei_kcsr to wei_csrk - auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence{}); + auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence{}); ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: "); Tensor wei_csrk(make_TensorDescriptor(wei_csrk_desc)); @@ -58,8 +58,8 @@ void device_implicit_gemm_convolution_2_chwn_csrk_khwn(InDesc, [&](auto k, auto c, auto s, auto r) { wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r); }, K, C, - S, - R)(std::thread::hardware_concurrency()); + Y, + X)(std::thread::hardware_concurrency()); // conver out_nkhw to out_knhw auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence{}); @@ -209,43 +209,39 @@ void device_implicit_gemm_convolution_2_chwn_csrk_khwn(InDesc, for(unsigned i = 0; i < nrepeat; ++i) { - float time = launch_kernel( -#if 0 - gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn -#else - gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer -#endif - , - dim3(GridSize), - dim3(BlockSize), - static_cast(in_chwn_device_buf.GetDeviceBuffer()), - static_cast(wei_csrk_device_buf.GetDeviceBuffer()), - static_cast(out_khwn_device_buf.GetDeviceBuffer())); + float time = + launch_kernel(gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer< + GridSize, + BlockSize, + T, + decltype(in_chwn_desc), + decltype(wei_csrk_desc), + decltype(out_khwn_desc), + BPerBlock, + KPerBlock, + CPerBlock, + BPerThread, + KPerThread, + GemmThreadPerColumnPerCluster, + GemmThreadPerRowPerCluster, + GemmMPerThreadSubC, + GemmNPerThreadSubC, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmKPerThreadLoop, + InBlockCopyThreadPerDim0, + InBlockCopyThreadPerDim1, + WeiBlockCopyThreadPerDim0, + WeiBlockCopyThreadPerDim1, + InBlockCopyDataPerRead, + WeiBlockCopyDataPerRead>, + dim3(GridSize), + dim3(BlockSize), + static_cast(in_chwn_device_buf.GetDeviceBuffer()), + static_cast(wei_csrk_device_buf.GetDeviceBuffer()), + static_cast(out_khwn_device_buf.GetDeviceBuffer())); printf("Elapsed time : %f ms\n", time); usleep(std::min(time * 1000, float(10000))); diff --git a/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.hpp b/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.hpp deleted file mode 100644 index 870d808bc9..0000000000 --- a/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.hpp +++ /dev/null @@ -1,264 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.hip.hpp" -#include "gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer.hip.hpp" - -template -void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcsr, - OutDesc, - Tensor& out_nkhw, - unsigned nrepeat) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_desc = InDesc{}; - constexpr auto wei_kcsr_desc = WeiDesc{}; - constexpr auto out_nkhw_desc = OutDesc{}; - - constexpr unsigned N = in_nchw_desc.GetLength(I0); - constexpr unsigned Hi = in_nchw_desc.GetLength(I2); - constexpr unsigned Wi = in_nchw_desc.GetLength(I3); - - constexpr unsigned Ho = out_nkhw_desc.GetLength(I2); - constexpr unsigned Wo = out_nkhw_desc.GetLength(I3); - - constexpr unsigned K = wei_kcsr_desc.GetLength(I0); - constexpr unsigned C = wei_kcsr_desc.GetLength(I1); - constexpr unsigned S = wei_kcsr_desc.GetLength(I2); - constexpr unsigned R = wei_kcsr_desc.GetLength(I3); - - constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1); - - // convert in_nchw to in_cnhw - auto in_cnhw_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(in_cnhw_desc, std::cout << "in_cnhw_desc: "); - - Tensor in_cnhw(make_TensorDescriptor(in_cnhw_desc)); - - auto f_reorder_nchw2cnhw = [&](auto n, auto c, auto hi, auto wi) { - in_cnhw(c, n, hi, wi) = in_nchw(n, c, hi, wi); - }; - - make_ParallelTensorFunctor(f_reorder_nchw2cnhw, N, C, Hi, Wi)( - std::thread::hardware_concurrency()); - - // convert wei_kcsr to wei_csrk - auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: "); - - Tensor wei_csrk(make_TensorDescriptor(wei_csrk_desc)); - - auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto s, auto r) { - wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r); - }; - - make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, S, R)( - std::thread::hardware_concurrency()); - - // conver out_nkhw to out_knhw - auto out_knhw_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(out_knhw_desc, std::cout << "out_knhw_desc: "); - - Tensor out_knhw(make_TensorDescriptor(out_knhw_desc)); - -#if 0 - // 3x3, 34x34 - // need to use register double buffer for GEMM - constexpr unsigned BPerBlock = 128; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 4; - - constexpr unsigned BPerThread = 8; - constexpr unsigned KPerThread = 8; - - constexpr unsigned GemmMPerThreadSubC = 4; - constexpr unsigned GemmNPerThreadSubC = 4; - constexpr unsigned GemmMLevel0Cluster = 4; - constexpr unsigned GemmNLevel0Cluster = 2; - constexpr unsigned GemmMLevel1Cluster = 2; - constexpr unsigned GemmNLevel1Cluster = 8; - constexpr unsigned GemmKPerThreadLoop = 1; - - constexpr unsigned GemmThreadPerColumnPerCluster = 8; - constexpr unsigned GemmThreadPerRowPerCluster = 8; - - constexpr unsigned InBlockCopyThreadPerDim0 = 4; - constexpr unsigned InBlockCopyThreadPerDim1 = 16; - - constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; - constexpr unsigned WeiBlockCopyThreadPerDim1 = 16; - - constexpr unsigned InBlockCopyDataPerRead = 4; - constexpr unsigned WeiBlockCopyDataPerRead = 4; - - constexpr unsigned BlockSize = 128; -#elif 0 - // 1x1, 28x28, 64 threads - constexpr unsigned BPerBlock = 64; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 8; - - constexpr unsigned BPerThread = 8; - constexpr unsigned KPerThread = 8; - - constexpr unsigned GemmMPerThreadSubC = 4; - constexpr unsigned GemmNPerThreadSubC = 4; - constexpr unsigned GemmMLevel0Cluster = 4; - constexpr unsigned GemmNLevel0Cluster = 2; - constexpr unsigned GemmMLevel1Cluster = 2; - constexpr unsigned GemmNLevel1Cluster = 4; - constexpr unsigned GemmKPerThreadLoop = 1; - - constexpr unsigned GemmThreadPerColumnPerCluster = 8; - constexpr unsigned GemmThreadPerRowPerCluster = 8; - - constexpr unsigned InBlockCopyThreadPerDim0 = 4; - constexpr unsigned InBlockCopyThreadPerDim1 = 16; - - constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; - constexpr unsigned WeiBlockCopyThreadPerDim1 = 16; - - constexpr unsigned InBlockCopyDataPerRead = 4; - constexpr unsigned WeiBlockCopyDataPerRead = 4; - - constexpr unsigned BlockSize = 64; -#elif 1 - // 1x1, 28x28, 128 threads, no lds-double-buffer - // 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128 - constexpr unsigned BPerBlock = 64; - constexpr unsigned KPerBlock = 128; - constexpr unsigned CPerBlock = 8; - - constexpr unsigned BPerThread = 8; - constexpr unsigned KPerThread = 8; - - constexpr unsigned GemmMPerThreadSubC = 4; - constexpr unsigned GemmNPerThreadSubC = 4; - constexpr unsigned GemmMLevel0Cluster = 4; - constexpr unsigned GemmNLevel0Cluster = 2; - constexpr unsigned GemmMLevel1Cluster = 4; - constexpr unsigned GemmNLevel1Cluster = 4; - constexpr unsigned GemmKPerThreadLoop = 1; - - constexpr unsigned GemmThreadPerColumnPerCluster = 8; - constexpr unsigned GemmThreadPerRowPerCluster = 8; - - constexpr unsigned InBlockCopyThreadPerDim0 = 4; - constexpr unsigned InBlockCopyThreadPerDim1 = 16; - - constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; - constexpr unsigned WeiBlockCopyThreadPerDim1 = 16; - - constexpr unsigned InBlockCopyDataPerRead = 4; - constexpr unsigned WeiBlockCopyDataPerRead = 4; - - constexpr unsigned BlockSize = 128; -#elif 1 - // 1x1, 28x28, 256 thread - constexpr unsigned BPerBlock = 128; - constexpr unsigned KPerBlock = 128; - constexpr unsigned CPerBlock = 8; - - constexpr unsigned BPerThread = 8; - constexpr unsigned KPerThread = 8; - - constexpr unsigned GemmMPerThreadSubC = 4; - constexpr unsigned GemmNPerThreadSubC = 4; - constexpr unsigned GemmMLevel0Cluster = 4; - constexpr unsigned GemmNLevel0Cluster = 4; - constexpr unsigned GemmMLevel1Cluster = 4; - constexpr unsigned GemmNLevel1Cluster = 4; - constexpr unsigned GemmKPerThreadLoop = 1; - - constexpr unsigned GemmThreadPerColumnPerCluster = 8; - constexpr unsigned GemmThreadPerRowPerCluster = 8; - - constexpr unsigned InBlockCopyThreadPerDim0 = 4; - constexpr unsigned InBlockCopyThreadPerDim1 = 16; - - constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; - constexpr unsigned WeiBlockCopyThreadPerDim1 = 16; - - constexpr unsigned InBlockCopyDataPerRead = 4; - constexpr unsigned WeiBlockCopyDataPerRead = 4; - - constexpr unsigned BlockSize = 256; -#endif - - constexpr unsigned GridSize = - ((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - // mem - std::size_t data_sz = sizeof(T); - DeviceMem in_cnhw_device_buf(data_sz * (in_cnhw.mDesc.GetElementSpace() + BGhostRead + - BPerBlock)); // reserve extra space for BGhostRead - DeviceMem wei_csrk_device_buf(data_sz * wei_csrk.mDesc.GetElementSpace()); - DeviceMem out_knhw_device_buf(data_sz * out_knhw.mDesc.GetElementSpace()); - - in_cnhw_device_buf.ToDevice(in_cnhw.mData.data()); - wei_csrk_device_buf.ToDevice(wei_csrk.mData.data()); - out_knhw_device_buf.ToDevice(out_knhw.mData.data()); - - for(unsigned i = 0; i < nrepeat; ++i) - { - float time = launch_kernel( -#if 0 - gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw -#else - gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer -#endif - , - dim3(GridSize), - dim3(BlockSize), - static_cast(in_cnhw_device_buf.GetDeviceBuffer()), - static_cast(wei_csrk_device_buf.GetDeviceBuffer()), - static_cast(out_knhw_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms\n", time); - usleep(std::min(time * 1000, float(10000))); - } - - out_knhw_device_buf.FromDevice(out_knhw.mData.data()); - - // convert out_knhw to out_nkhw - auto f_reorder_knhw2nkhw = [&](auto n, auto k, auto ho, auto wo) { - out_nkhw(n, k, ho, wo) = out_knhw(k, n, ho, wo); - }; - - make_ParallelTensorFunctor(f_reorder_knhw2nkhw, N, K, Ho, Wo)( - std::thread::hardware_concurrency()); -} diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index cc234f5091..f534e4eda9 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -9,13 +9,9 @@ #include "conv_common.hip.hpp" #include "device_direct_convolution_1.hpp" #include "device_direct_convolution_2.hpp" -#include "device_implicit_gemm_convolution_1_nchw_kcsr_nkhw.hpp" -#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.hpp" #include "device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp" #include "device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp" -#include "device_implicit_gemm_convolution_2_cnhw_csrk_knhw.hpp" #include "device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp" -//#include "device_winograd_convolution.hip.hpp" struct GeneratorTensor_1 { @@ -154,8 +150,8 @@ template void host_winograd_3x3_convolution( const Tensor& in_nchw, const Tensor& wei_kcsr, Tensor& out, LowerPads, UpperPads) { - constexpr std::size_t OutTileSizeH = 2; - constexpr std::size_t OutTileSizeW = 2; + constexpr std::size_t HoPerTile = 2; + constexpr std::size_t WoPerTile = 2; std::size_t N = in_nchw.mDesc.GetLengths()[0]; std::size_t C = in_nchw.mDesc.GetLengths()[1]; @@ -163,8 +159,8 @@ void host_winograd_3x3_convolution( std::size_t WI = in_nchw.mDesc.GetLengths()[3]; std::size_t K = wei_kcsr.mDesc.GetLengths()[0]; - std::size_t S = wei_kcsr.mDesc.GetLengths()[2]; - std::size_t R = wei_kcsr.mDesc.GetLengths()[3]; + std::size_t Y = wei_kcsr.mDesc.GetLengths()[2]; + std::size_t X = wei_kcsr.mDesc.GetLengths()[3]; std::size_t HO = out.mDesc.GetLengths()[2]; std::size_t WO = out.mDesc.GetLengths()[3]; @@ -175,75 +171,91 @@ void host_winograd_3x3_convolution( unsigned h_pad_up = UpperPads{}.Get(Number<0>{}); unsigned w_pad_up = UpperPads{}.Get(Number<1>{}); - std::size_t InTileSizeH = OutTileSizeH + S - 1; - std::size_t InTileSizeW = OutTileSizeW + R - 1; + std::size_t HiPerTile = HoPerTile + Y - 1; + std::size_t WiPerTile = WoPerTile + X - 1; - std::size_t Y = (HO + OutTileSizeH - 1) / OutTileSizeH; - std::size_t X = (WO + OutTileSizeW - 1) / OutTileSizeW; + std::size_t HTile = (HO + HoPerTile - 1) / HoPerTile; + std::size_t WTile = (WO + WoPerTile - 1) / WoPerTile; - Tensor in_hold({N, C, Y, X, InTileSizeH, InTileSizeW}); - Tensor in_transform({N, C, Y, X, InTileSizeH, InTileSizeW}); - Tensor wei_transform({K, C, InTileSizeH, InTileSizeW}); - Tensor out_transform({N, K, Y, X, InTileSizeH, InTileSizeH}); - Tensor out_hold({N, K, Y, X, OutTileSizeH, OutTileSizeW}); + Tensor in_hold({N, C, HTile, WTile, HiPerTile, WiPerTile}); + Tensor in_transform({N, C, HTile, WTile, HiPerTile, WiPerTile}); + Tensor wei_transform({K, C, HiPerTile, WiPerTile}); + Tensor out_transform({N, K, HTile, WTile, HiPerTile, HiPerTile}); + Tensor out_hold({N, K, HTile, WTile, HoPerTile, WoPerTile}); - auto f_in_hold = [&](auto n, auto c, auto y, auto x) { - for(int j = 0; j < InTileSizeH; ++j) + auto f_in_hold = [&](auto n, auto c, auto htile, auto wtile) { + for(int j = 0; j < HiPerTile; ++j) { - int hi = OutTileSizeH * y + j - h_pad_low; - for(int i = 0; i < InTileSizeW; ++i) + int hi = HoPerTile * htile + j - h_pad_low; + for(int i = 0; i < WiPerTile; ++i) { - int wi = OutTileSizeW * x + i - w_pad_low; + int wi = WoPerTile * wtile + i - w_pad_low; if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 && wi < in_nchw.mDesc.GetLengths()[3]) { - in_hold(n, c, y, x, j, i) = in_nchw(n, c, hi, wi); + in_hold(n, c, htile, wtile, j, i) = in_nchw(n, c, hi, wi); } else { - in_hold(n, c, y, x, j, i) = T(0); + in_hold(n, c, htile, wtile, j, i) = T(0); } } } }; - auto f_in_transform = [&](auto n, auto c, auto y, auto x) { - in_transform(n, c, y, x, 0, 0) = in_hold(n, c, y, x, 0, 0) - in_hold(n, c, y, x, 0, 2) - - in_hold(n, c, y, x, 2, 0) + in_hold(n, c, y, x, 2, 2); - in_transform(n, c, y, x, 0, 1) = in_hold(n, c, y, x, 0, 1) + in_hold(n, c, y, x, 0, 2) - - in_hold(n, c, y, x, 2, 1) - in_hold(n, c, y, x, 2, 2); - in_transform(n, c, y, x, 0, 2) = -in_hold(n, c, y, x, 0, 1) + in_hold(n, c, y, x, 0, 2) + - in_hold(n, c, y, x, 2, 1) - in_hold(n, c, y, x, 2, 2); - in_transform(n, c, y, x, 0, 3) = in_hold(n, c, y, x, 0, 1) - in_hold(n, c, y, x, 0, 3) - - in_hold(n, c, y, x, 2, 1) + in_hold(n, c, y, x, 2, 3); + auto f_in_transform = [&](auto n, auto c, auto htile, auto wtile) { + in_transform(n, c, htile, wtile, 0, 0) = + in_hold(n, c, htile, wtile, 0, 0) - in_hold(n, c, htile, wtile, 0, 2) - + in_hold(n, c, htile, wtile, 2, 0) + in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 0, 1) = + in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) - + in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 0, 2) = + -in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) + + in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 0, 3) = + in_hold(n, c, htile, wtile, 0, 1) - in_hold(n, c, htile, wtile, 0, 3) - + in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 3); - in_transform(n, c, y, x, 1, 0) = in_hold(n, c, y, x, 1, 0) - in_hold(n, c, y, x, 1, 2) + - in_hold(n, c, y, x, 2, 0) - in_hold(n, c, y, x, 2, 2); - in_transform(n, c, y, x, 1, 1) = in_hold(n, c, y, x, 1, 1) + in_hold(n, c, y, x, 1, 2) + - in_hold(n, c, y, x, 2, 1) + in_hold(n, c, y, x, 2, 2); - in_transform(n, c, y, x, 1, 2) = -in_hold(n, c, y, x, 1, 1) + in_hold(n, c, y, x, 1, 2) - - in_hold(n, c, y, x, 2, 1) + in_hold(n, c, y, x, 2, 2); - in_transform(n, c, y, x, 1, 3) = in_hold(n, c, y, x, 1, 1) - in_hold(n, c, y, x, 1, 3) + - in_hold(n, c, y, x, 2, 1) - in_hold(n, c, y, x, 2, 3); + in_transform(n, c, htile, wtile, 1, 0) = + in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) + + in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 1, 1) = + in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) + + in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 1, 2) = + -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) - + in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 1, 3) = + in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) + + in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3); - in_transform(n, c, y, x, 2, 0) = -in_hold(n, c, y, x, 1, 0) + in_hold(n, c, y, x, 1, 2) + - in_hold(n, c, y, x, 2, 0) - in_hold(n, c, y, x, 2, 2); - in_transform(n, c, y, x, 2, 1) = -in_hold(n, c, y, x, 1, 1) - in_hold(n, c, y, x, 1, 2) + - in_hold(n, c, y, x, 2, 1) + in_hold(n, c, y, x, 2, 2); - in_transform(n, c, y, x, 2, 2) = in_hold(n, c, y, x, 1, 1) - in_hold(n, c, y, x, 1, 2) - - in_hold(n, c, y, x, 2, 1) + in_hold(n, c, y, x, 2, 2); - in_transform(n, c, y, x, 2, 3) = -in_hold(n, c, y, x, 1, 1) + in_hold(n, c, y, x, 1, 3) + - in_hold(n, c, y, x, 2, 1) - in_hold(n, c, y, x, 2, 3); + in_transform(n, c, htile, wtile, 2, 0) = + -in_hold(n, c, htile, wtile, 1, 0) + in_hold(n, c, htile, wtile, 1, 2) + + in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 2, 1) = + -in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) + + in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 2, 2) = + in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) - + in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 2, 3) = + -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 3) + + in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3); - in_transform(n, c, y, x, 3, 0) = in_hold(n, c, y, x, 1, 0) - in_hold(n, c, y, x, 1, 2) - - in_hold(n, c, y, x, 3, 0) + in_hold(n, c, y, x, 3, 2); - in_transform(n, c, y, x, 3, 1) = in_hold(n, c, y, x, 1, 1) + in_hold(n, c, y, x, 1, 2) - - in_hold(n, c, y, x, 3, 1) - in_hold(n, c, y, x, 3, 2); - in_transform(n, c, y, x, 3, 2) = -in_hold(n, c, y, x, 1, 1) + in_hold(n, c, y, x, 1, 2) + - in_hold(n, c, y, x, 3, 1) - in_hold(n, c, y, x, 3, 2); - in_transform(n, c, y, x, 3, 3) = in_hold(n, c, y, x, 1, 1) - in_hold(n, c, y, x, 1, 3) - - in_hold(n, c, y, x, 3, 1) + in_hold(n, c, y, x, 3, 3); + in_transform(n, c, htile, wtile, 3, 0) = + in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) - + in_hold(n, c, htile, wtile, 3, 0) + in_hold(n, c, htile, wtile, 3, 2); + in_transform(n, c, htile, wtile, 3, 1) = + in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) - + in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2); + in_transform(n, c, htile, wtile, 3, 2) = + -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) + + in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2); + in_transform(n, c, htile, wtile, 3, 3) = + in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) - + in_hold(n, c, htile, wtile, 3, 1) + in_hold(n, c, htile, wtile, 3, 3); }; auto f_wei_transform = [&](auto k, auto c) { @@ -292,69 +304,69 @@ void host_winograd_3x3_convolution( wei_transform(k, c, 3, 3) = wei_kcsr(k, c, 2, 2); }; - auto f_out_transform = [&](auto n, auto k, auto y, auto x) { - for(int j = 0; j < InTileSizeH; ++j) + auto f_out_transform = [&](auto n, auto k, auto htile, auto wtile) { + for(int j = 0; j < HiPerTile; ++j) { - for(int i = 0; i < InTileSizeW; ++i) + for(int i = 0; i < WiPerTile; ++i) { double v = 0; for(int c = 0; c < C; ++c) { - v += in_transform(n, c, y, x, j, i) * wei_transform(k, c, j, i); + v += in_transform(n, c, htile, wtile, j, i) * wei_transform(k, c, j, i); } - out_transform(n, k, y, x, j, i) = v; + out_transform(n, k, htile, wtile, j, i) = v; } } }; - auto f_out_hold = [&](auto n, auto k, auto y, auto x) { - out_hold(n, k, y, x, 0, 0) = - out_transform(n, k, y, x, 0, 0) + out_transform(n, k, y, x, 0, 1) + - out_transform(n, k, y, x, 0, 2) + out_transform(n, k, y, x, 1, 0) + - out_transform(n, k, y, x, 1, 1) + out_transform(n, k, y, x, 1, 2) + - out_transform(n, k, y, x, 2, 0) + out_transform(n, k, y, x, 2, 1) + - out_transform(n, k, y, x, 2, 2); - out_hold(n, k, y, x, 0, 1) = - out_transform(n, k, y, x, 0, 1) - out_transform(n, k, y, x, 0, 2) - - out_transform(n, k, y, x, 0, 3) + out_transform(n, k, y, x, 1, 1) - - out_transform(n, k, y, x, 1, 2) - out_transform(n, k, y, x, 1, 3) + - out_transform(n, k, y, x, 2, 1) - out_transform(n, k, y, x, 2, 2) - - out_transform(n, k, y, x, 2, 3); - out_hold(n, k, y, x, 1, 0) = - out_transform(n, k, y, x, 1, 0) + out_transform(n, k, y, x, 1, 1) + - out_transform(n, k, y, x, 1, 2) - out_transform(n, k, y, x, 2, 0) - - out_transform(n, k, y, x, 2, 1) - out_transform(n, k, y, x, 2, 2) - - out_transform(n, k, y, x, 3, 0) - out_transform(n, k, y, x, 3, 1) - - out_transform(n, k, y, x, 3, 2); - out_hold(n, k, y, x, 1, 1) = - out_transform(n, k, y, x, 1, 1) - out_transform(n, k, y, x, 1, 2) - - out_transform(n, k, y, x, 1, 3) - out_transform(n, k, y, x, 2, 1) + - out_transform(n, k, y, x, 2, 2) + out_transform(n, k, y, x, 2, 3) - - out_transform(n, k, y, x, 3, 1) + out_transform(n, k, y, x, 3, 2) + - out_transform(n, k, y, x, 3, 3); + auto f_out_hold = [&](auto n, auto k, auto htile, auto wtile) { + out_hold(n, k, htile, wtile, 0, 0) = + out_transform(n, k, htile, wtile, 0, 0) + out_transform(n, k, htile, wtile, 0, 1) + + out_transform(n, k, htile, wtile, 0, 2) + out_transform(n, k, htile, wtile, 1, 0) + + out_transform(n, k, htile, wtile, 1, 1) + out_transform(n, k, htile, wtile, 1, 2) + + out_transform(n, k, htile, wtile, 2, 0) + out_transform(n, k, htile, wtile, 2, 1) + + out_transform(n, k, htile, wtile, 2, 2); + out_hold(n, k, htile, wtile, 0, 1) = + out_transform(n, k, htile, wtile, 0, 1) - out_transform(n, k, htile, wtile, 0, 2) - + out_transform(n, k, htile, wtile, 0, 3) + out_transform(n, k, htile, wtile, 1, 1) - + out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 1, 3) + + out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) - + out_transform(n, k, htile, wtile, 2, 3); + out_hold(n, k, htile, wtile, 1, 0) = + out_transform(n, k, htile, wtile, 1, 0) + out_transform(n, k, htile, wtile, 1, 1) + + out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 2, 0) - + out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) - + out_transform(n, k, htile, wtile, 3, 0) - out_transform(n, k, htile, wtile, 3, 1) - + out_transform(n, k, htile, wtile, 3, 2); + out_hold(n, k, htile, wtile, 1, 1) = + out_transform(n, k, htile, wtile, 1, 1) - out_transform(n, k, htile, wtile, 1, 2) - + out_transform(n, k, htile, wtile, 1, 3) - out_transform(n, k, htile, wtile, 2, 1) + + out_transform(n, k, htile, wtile, 2, 2) + out_transform(n, k, htile, wtile, 2, 3) - + out_transform(n, k, htile, wtile, 3, 1) + out_transform(n, k, htile, wtile, 3, 2) + + out_transform(n, k, htile, wtile, 3, 3); }; - auto f_out = [&](auto n, auto k, auto y, auto x) { - for(int j = 0; j < OutTileSizeH; ++j) + auto f_out = [&](auto n, auto k, auto htile, auto wtile) { + for(int j = 0; j < HoPerTile; ++j) { - std::size_t ho = OutTileSizeH * y + j; - for(int i = 0; i < OutTileSizeW; ++i) + std::size_t ho = HoPerTile * htile + j; + for(int i = 0; i < WoPerTile; ++i) { - std::size_t wo = OutTileSizeW * x + i; - out(n, k, ho, wo) = out_hold(n, k, y, x, j, i); + std::size_t wo = WoPerTile * wtile + i; + out(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i); } } }; std::size_t num_thread = std::thread::hardware_concurrency(); - make_ParallelTensorFunctor(f_in_hold, N, C, Y, X)(num_thread); - make_ParallelTensorFunctor(f_in_transform, N, C, Y, X)(num_thread); + make_ParallelTensorFunctor(f_in_hold, N, C, HTile, WTile)(num_thread); + make_ParallelTensorFunctor(f_in_transform, N, C, HTile, WTile)(num_thread); make_ParallelTensorFunctor(f_wei_transform, K, C)(num_thread); - make_ParallelTensorFunctor(f_out_transform, N, K, Y, X)(num_thread); - make_ParallelTensorFunctor(f_out_hold, N, K, Y, X)(num_thread); - make_ParallelTensorFunctor(f_out, N, K, Y, X)(num_thread); + make_ParallelTensorFunctor(f_out_transform, N, K, HTile, WTile)(num_thread); + make_ParallelTensorFunctor(f_out_hold, N, K, HTile, WTile)(num_thread); + make_ParallelTensorFunctor(f_out, N, K, HTile, WTile)(num_thread); } template @@ -387,8 +399,8 @@ int main(int argc, char* argv[]) constexpr unsigned HI = 28; constexpr unsigned WI = 28; constexpr unsigned K = 1; - constexpr unsigned S = 3; - constexpr unsigned R = 3; + constexpr unsigned Y = 3; + constexpr unsigned X = 3; constexpr unsigned HPad = 0; constexpr unsigned WPad = 0; @@ -399,8 +411,8 @@ int main(int argc, char* argv[]) constexpr unsigned HI = 34; constexpr unsigned WI = 34; constexpr unsigned K = 64; - constexpr unsigned S = 3; - constexpr unsigned R = 3; + constexpr unsigned Y = 3; + constexpr unsigned X = 3; constexpr unsigned HPad = 0; constexpr unsigned WPad = 0; @@ -411,8 +423,8 @@ int main(int argc, char* argv[]) constexpr unsigned HI = 56; constexpr unsigned WI = 56; constexpr unsigned K = 64; - constexpr unsigned S = 3; - constexpr unsigned R = 3; + constexpr unsigned Y = 3; + constexpr unsigned X = 3; #elif 0 // 3x3, 58x58 constexpr unsigned N = 64; @@ -420,8 +432,8 @@ int main(int argc, char* argv[]) constexpr unsigned HI = 58; constexpr unsigned WI = 58; constexpr unsigned K = 64; - constexpr unsigned S = 3; - constexpr unsigned R = 3; + constexpr unsigned Y = 3; + constexpr unsigned X = 3; #elif 0 // 5x5, 36x36 constexpr unsigned N = 64; @@ -429,8 +441,8 @@ int main(int argc, char* argv[]) constexpr unsigned HI = 36; constexpr unsigned WI = 36; constexpr unsigned K = 64; - constexpr unsigned S = 5; - constexpr unsigned R = 5; + constexpr unsigned Y = 5; + constexpr unsigned X = 5; constexpr unsigned HPad = 0; constexpr unsigned WPad = 0; @@ -441,8 +453,8 @@ int main(int argc, char* argv[]) constexpr unsigned HI = 38; constexpr unsigned WI = 38; constexpr unsigned K = 64; - constexpr unsigned S = 7; - constexpr unsigned R = 7; + constexpr unsigned Y = 7; + constexpr unsigned X = 7; constexpr unsigned HPad = 0; constexpr unsigned WPad = 0; @@ -453,8 +465,8 @@ int main(int argc, char* argv[]) constexpr unsigned HI = 58; constexpr unsigned WI = 58; constexpr unsigned K = 256; - constexpr unsigned S = 3; - constexpr unsigned R = 3; + constexpr unsigned Y = 3; + constexpr unsigned X = 3; #elif 0 // 3x3 filter, 58x58 image, 0x0 padding constexpr unsigned N = 16; @@ -462,8 +474,8 @@ int main(int argc, char* argv[]) constexpr unsigned HI = 58; constexpr unsigned WI = 58; constexpr unsigned K = 256; - constexpr unsigned S = 3; - constexpr unsigned R = 3; + constexpr unsigned Y = 3; + constexpr unsigned X = 3; constexpr unsigned HPad = 0; constexpr unsigned WPad = 0; @@ -474,8 +486,8 @@ int main(int argc, char* argv[]) constexpr unsigned HI = 56; constexpr unsigned WI = 56; constexpr unsigned K = 256; - constexpr unsigned S = 3; - constexpr unsigned R = 3; + constexpr unsigned Y = 3; + constexpr unsigned X = 3; constexpr unsigned HPad = 1; constexpr unsigned WPad = 1; @@ -486,8 +498,8 @@ int main(int argc, char* argv[]) constexpr unsigned HI = 28; constexpr unsigned WI = 28; constexpr unsigned K = 512; - constexpr unsigned S = 3; - constexpr unsigned R = 3; + constexpr unsigned Y = 3; + constexpr unsigned X = 3; constexpr unsigned HPad = 1; constexpr unsigned WPad = 1; @@ -498,8 +510,8 @@ int main(int argc, char* argv[]) constexpr unsigned HI = 28; constexpr unsigned WI = 28; constexpr unsigned K = 512; - constexpr unsigned S = 1; - constexpr unsigned R = 1; + constexpr unsigned Y = 1; + constexpr unsigned X = 1; constexpr unsigned HPad = 0; constexpr unsigned WPad = 0; @@ -510,8 +522,8 @@ int main(int argc, char* argv[]) constexpr unsigned HI = 20; constexpr unsigned WI = 84; constexpr unsigned K = 256; - constexpr unsigned S = 3; - constexpr unsigned R = 3; + constexpr unsigned Y = 3; + constexpr unsigned X = 3; constexpr unsigned HPad = 1; constexpr unsigned WPad = 1; @@ -522,8 +534,8 @@ int main(int argc, char* argv[]) constexpr unsigned HI = 112; constexpr unsigned WI = 112; constexpr unsigned K = 128; - constexpr unsigned S = 3; - constexpr unsigned R = 3; + constexpr unsigned Y = 3; + constexpr unsigned X = 3; constexpr unsigned HPad = 1; constexpr unsigned WPad = 1; @@ -534,8 +546,8 @@ int main(int argc, char* argv[]) constexpr unsigned HI = 20; constexpr unsigned WI = 86; constexpr unsigned K = 512; - constexpr unsigned S = 5; - constexpr unsigned R = 5; + constexpr unsigned Y = 5; + constexpr unsigned X = 5; constexpr unsigned HPad = 1; constexpr unsigned WPad = 1; @@ -546,8 +558,8 @@ int main(int argc, char* argv[]) constexpr unsigned HI = 28; constexpr unsigned WI = 28; constexpr unsigned K = 32; - constexpr unsigned S = 5; - constexpr unsigned R = 5; + constexpr unsigned Y = 5; + constexpr unsigned X = 5; constexpr unsigned HPad = 2; constexpr unsigned WPad = 2; @@ -557,7 +569,7 @@ int main(int argc, char* argv[]) auto upper_pads = Sequence{}; auto in_nchw_desc = make_ConstantTensorDescriptor(Sequence{}); - auto wei_kcsr_desc = make_ConstantTensorDescriptor(Sequence{}); + auto wei_kcsr_desc = make_ConstantTensorDescriptor(Sequence{}); auto out_nkhw_desc = get_convolution_with_padding_output_default_4d_tensor_descriptor( in_nchw_desc, wei_kcsr_desc, lower_pads, upper_pads); @@ -600,14 +612,8 @@ int main(int argc, char* argv[]) device_direct_convolution_1 #elif 0 device_direct_convolution_2 -#elif 0 - device_implicit_gemm_convolution_1_nchw_kcsr_nkhw -#elif 0 - device_implicit_gemm_convolution_1_nchw_srck_nkhw #elif 1 device_implicit_gemm_convolution_1_chwn_csrk_khwn -#elif 0 - device_implicit_gemm_convolution_2_cnhw_csrk_knhw #elif 0 device_implicit_gemm_convolution_2_chwn_csrk_khwn #endif @@ -627,7 +633,7 @@ int main(int argc, char* argv[]) if(do_verification) { - if(S == 3 && R == 3) + if(Y == 3 && X == 3) { host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads); } diff --git a/src/include/blockwise_direct_convolution.hip.hpp b/src/include/blockwise_direct_convolution.hip.hpp index c6f7965de0..247ff219f6 100644 --- a/src/include/blockwise_direct_convolution.hip.hpp +++ b/src/include/blockwise_direct_convolution.hip.hpp @@ -8,11 +8,11 @@ template + unsigned CPerThread, + unsigned HoPerThread, + unsigned WoPerThread> __device__ void blockwise_direct_convolution(InBlockDesc, Float* const __restrict__ p_in_block, WeiBlockDesc, @@ -29,19 +29,17 @@ __device__ void blockwise_direct_convolution(InBlockDesc, constexpr auto wei_block_desc = WeiBlockDesc{}; constexpr auto out_block_desc = OutBlockDesc{}; - constexpr unsigned S = wei_block_desc.GetLength(I2); - constexpr unsigned R = wei_block_desc.GetLength(I3); + constexpr unsigned Y = wei_block_desc.GetLength(I2); + constexpr unsigned X = wei_block_desc.GetLength(I3); - constexpr unsigned InTileSizeH = OutTileSizeH + S - 1; - constexpr unsigned InTileSizeW = OutTileSizeW + R - 1; + constexpr unsigned InTileSizeH = HoPerThread + Y - 1; + constexpr unsigned InTileSizeW = WoPerThread + X - 1; // divide thread work constexpr unsigned NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread; constexpr unsigned KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread; - constexpr unsigned YThreadWork = - (out_block_desc.GetLength(I2) + OutTileSizeH - 1) / OutTileSizeH; - constexpr unsigned XThreadWork = - (out_block_desc.GetLength(I3) + OutTileSizeW - 1) / OutTileSizeW; + constexpr unsigned YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread; + constexpr unsigned XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread; #if 0 if(threadIdx.x == 0) @@ -56,7 +54,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc, make_ConstantTensorDescriptor(Sequence{}); constexpr auto wei_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); + make_ConstantTensorDescriptor(Sequence{}); constexpr auto out_thread_desc = get_convolution_output_default_4d_tensor_descriptor(in_thread_desc, wei_thread_desc); @@ -86,8 +84,8 @@ __device__ void blockwise_direct_convolution(InBlockDesc, unsigned n_thread_data_begin = n_thread_work_id * NPerThread; unsigned k_thread_data_begin = k_thread_work_id * KPerThread; - unsigned ho_thread_data_begin = y_thread_work_id * OutTileSizeH; - unsigned wo_thread_data_begin = x_thread_work_id * OutTileSizeW; + unsigned ho_thread_data_begin = y_thread_work_id * HoPerThread; + unsigned wo_thread_data_begin = x_thread_work_id * WoPerThread; unsigned hi_thread_data_begin = ho_thread_data_begin; // minus padding unsigned wi_thread_data_begin = wo_thread_data_begin; // minus padding diff --git a/src/include/conv_common.hip.hpp b/src/include/conv_common.hip.hpp index c4ad19fcb9..7d909a7963 100644 --- a/src/include/conv_common.hip.hpp +++ b/src/include/conv_common.hip.hpp @@ -24,11 +24,11 @@ __host__ __device__ constexpr auto get_convolution_output_default_4d_tensor_desc constexpr auto WI = in_desc.GetLength(I3); constexpr auto K = wei_desc.GetLength(I0); - constexpr auto S = wei_desc.GetLength(I2); - constexpr auto R = wei_desc.GetLength(I3); + constexpr auto Y = wei_desc.GetLength(I2); + constexpr auto X = wei_desc.GetLength(I3); - constexpr auto HO = HI + 1 - S; - constexpr auto WO = WI + 1 - R; + constexpr auto HO = HI + 1 - Y; + constexpr auto WO = WI + 1 - X; return make_ConstantTensorDescriptor(Sequence{}); } @@ -55,8 +55,8 @@ __host__ __device__ constexpr auto get_convolution_with_padding_output_default_4 constexpr auto WI = in_desc.GetLength(I3); constexpr auto K = wei_desc.GetLength(I0); - constexpr auto S = wei_desc.GetLength(I2); - constexpr auto R = wei_desc.GetLength(I3); + constexpr auto Y = wei_desc.GetLength(I2); + constexpr auto X = wei_desc.GetLength(I3); constexpr auto HPadLow = LowerPads{}.Get(I0); constexpr auto WPadLow = LowerPads{}.Get(I1); @@ -64,8 +64,8 @@ __host__ __device__ constexpr auto get_convolution_with_padding_output_default_4 constexpr auto HPadUp = UpperPads{}.Get(I0); constexpr auto WPadUp = UpperPads{}.Get(I1); - constexpr auto HO = HI + HPadLow + HPadUp + 1 - S; - constexpr auto WO = WI + WPadLow + WPadUp + 1 - R; + constexpr auto HO = HI + HPadLow + HPadUp + 1 - Y; + constexpr auto WO = WI + WPadLow + WPadUp + 1 - X; return make_ConstantTensorDescriptor(Sequence{}); } diff --git a/src/include/gridwise_direct_convolution_1.hip.hpp b/src/include/gridwise_direct_convolution_1.hip.hpp index 49129b24d3..f4fe1809fc 100644 --- a/src/include/gridwise_direct_convolution_1.hip.hpp +++ b/src/include/gridwise_direct_convolution_1.hip.hpp @@ -8,16 +8,16 @@ template __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_in_global, @@ -33,25 +33,22 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_ constexpr auto wei_global_desc = WeiGlobalDesc{}; constexpr auto out_global_desc = OutGlobalDesc{}; - constexpr unsigned S = wei_global_desc.GetLength(I2); - constexpr unsigned R = wei_global_desc.GetLength(I3); + constexpr unsigned Y = wei_global_desc.GetLength(I2); + constexpr unsigned X = wei_global_desc.GetLength(I3); - constexpr unsigned HoPerBlock = OutTileSizeH * YPerBlock; - constexpr unsigned WoPerBlock = OutTileSizeW * XPerBlock; - - constexpr unsigned HiPerBlock = YPerBlock * OutTileSizeH + S - 1; - constexpr unsigned WiPerBlock = XPerBlock * OutTileSizeW + R - 1; + constexpr unsigned HiPerBlock = HoPerBlock + Y - 1; + constexpr unsigned WiPerBlock = WoPerBlock + X - 1; constexpr unsigned NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; constexpr unsigned KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; - constexpr unsigned YBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; - constexpr unsigned XBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; + constexpr unsigned HBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; + constexpr unsigned WBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; constexpr auto in_block_global_desc = make_ConstantTensorDescriptor( Sequence{}, in_global_desc.GetStrides()); constexpr auto wei_block_global_desc = make_ConstantTensorDescriptor( - Sequence{}, wei_global_desc.GetStrides()); + Sequence{}, wei_global_desc.GetStrides()); constexpr auto out_block_global_desc = make_ConstantTensorDescriptor( Sequence{}, out_global_desc.GetStrides()); @@ -73,52 +70,21 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_ const unsigned block_id = blockIdx.x; unsigned itmp = block_id; - unsigned n_block_work_id = itmp / (KBlockWork * YBlockWork * XBlockWork); - itmp -= n_block_work_id * (KBlockWork * YBlockWork * XBlockWork); - unsigned k_block_work_id = itmp / (YBlockWork * XBlockWork); - itmp -= k_block_work_id * (YBlockWork * XBlockWork); - unsigned y_block_work_id = itmp / XBlockWork; - unsigned x_block_work_id = itmp - y_block_work_id * XBlockWork; + unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork); + itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork); + unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork); + itmp -= k_block_work_id * (HBlockWork * WBlockWork); + unsigned h_block_work_id = itmp / WBlockWork; + unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork; - unsigned n_block_work_begin = n_block_work_id * NPerBlock; - unsigned k_block_work_begin = k_block_work_id * KPerBlock; - unsigned y_block_work_begin = y_block_work_id * YPerBlock; - unsigned x_block_work_begin = x_block_work_id * XPerBlock; - - unsigned ho_block_work_begin = y_block_work_begin * OutTileSizeH; - unsigned wo_block_work_begin = x_block_work_begin * OutTileSizeW; + unsigned n_block_work_begin = n_block_work_id * NPerBlock; + unsigned k_block_work_begin = k_block_work_id * KPerBlock; + unsigned ho_block_work_begin = h_block_work_id * HoPerBlock; + unsigned wo_block_work_begin = w_block_work_id * WoPerBlock; unsigned hi_block_work_begin = ho_block_work_begin; // minus padding unsigned wi_block_work_begin = wo_block_work_begin; // minus padding -#if 0 - if(threadIdx.x == 0) - { - print_ConstantTensorDescriptor( in_global_desc, "gridwise_convolution: in_global_desc: "); - print_ConstantTensorDescriptor(wei_global_desc, "gridwise_convolution: wei_global_desc: "); - print_ConstantTensorDescriptor(out_global_desc, "gridwise_convolution: out_global_desc: "); - print_ConstantTensorDescriptor( in_block_global_desc, "gridwise_convolution: in_block_global_desc: "); - print_ConstantTensorDescriptor(wei_block_global_desc, "gridwise_convolution: wei_block_global_desc: "); - print_ConstantTensorDescriptor(out_block_global_desc, "gridwise_convolution: out_block_global_desc: "); - print_ConstantTensorDescriptor( in_block_desc, "gridwise_convolution: in_block_desc: "); - print_ConstantTensorDescriptor(wei_block_desc, "gridwise_convolution: wei_block_desc: "); - print_ConstantTensorDescriptor(out_block_desc, "gridwise_convolution: out_block_desc: "); - - printf("NBlockWork %u, KBlockWork %u, YBlockWork %u, XBlockWork %u \t" - "block_id %u, n_block_work_id %u, k_block_work_id %u, y_block_work_id %u, " - "x_block_work_id %u\n", - NBlockWork, - KBlockWork, - YBlockWork, - XBlockWork, - block_id, - n_block_work_id, - k_block_work_id, - y_block_work_id, - x_block_work_id); - } -#endif - constexpr auto blockwise_in_copy = Blockwise4dTensorCopy1( + CPerThread, + HoPerThread, + WoPerThread>( in_block_desc, p_in_block, wei_block_desc, p_wei_block, out_block_desc, p_out_block); __syncthreads(); diff --git a/src/include/gridwise_direct_convolution_2.hip.hpp b/src/include/gridwise_direct_convolution_2.hip.hpp index 396f9a69a4..13f9e6cf1d 100644 --- a/src/include/gridwise_direct_convolution_2.hip.hpp +++ b/src/include/gridwise_direct_convolution_2.hip.hpp @@ -10,16 +10,16 @@ template __global__ void gridwise_direct_convolution_2(const Float* const __restrict__ p_in_global, @@ -35,20 +35,17 @@ __global__ void gridwise_direct_convolution_2(const Float* const __restrict__ p_ constexpr auto wei_global_desc = WeiGlobalDesc{}; constexpr auto out_global_desc = OutGlobalDesc{}; - constexpr unsigned S = wei_global_desc.GetLength(I2); - constexpr unsigned R = wei_global_desc.GetLength(I3); + constexpr unsigned Y = wei_global_desc.GetLength(I2); + constexpr unsigned X = wei_global_desc.GetLength(I3); - constexpr unsigned HoPerBlock = OutTileSizeH * YPerBlock; - constexpr unsigned WoPerBlock = OutTileSizeW * XPerBlock; - - constexpr unsigned HiPerBlock = YPerBlock * OutTileSizeH + S - 1; - constexpr unsigned WiPerBlock = XPerBlock * OutTileSizeW + R - 1; + constexpr unsigned HiPerBlock = HoPerBlock + Y - 1; + constexpr unsigned WiPerBlock = WoPerBlock + X - 1; constexpr auto in_block_desc = make_ConstantTensorDescriptor(Sequence{}); constexpr auto wei_block_desc = - make_ConstantTensorDescriptor(Sequence{}); + make_ConstantTensorDescriptor(Sequence{}); // shared mem constexpr unsigned in_block_size = in_block_desc.GetElementSpace(); @@ -58,14 +55,14 @@ __global__ void gridwise_direct_convolution_2(const Float* const __restrict__ p_ __shared__ Float p_wei_block[wei_block_size]; // threadwise tensors - constexpr unsigned InTileSizeH = OutTileSizeH + S - 1; - constexpr unsigned InTileSizeW = OutTileSizeW + R - 1; + constexpr unsigned HiPerThread = HoPerThread + Y - 1; + constexpr unsigned WiPerThread = WoPerThread + X - 1; constexpr auto in_thread_block_desc = make_ConstantTensorDescriptor( - Sequence{}, in_block_desc.GetStrides()); + Sequence{}, in_block_desc.GetStrides()); constexpr auto wei_thread_block_desc = make_ConstantTensorDescriptor( - Sequence{}, wei_block_desc.GetStrides()); + Sequence{}, wei_block_desc.GetStrides()); constexpr auto out_thread_desc = get_convolution_output_default_4d_tensor_descriptor( in_thread_block_desc, wei_thread_block_desc); @@ -76,26 +73,23 @@ __global__ void gridwise_direct_convolution_2(const Float* const __restrict__ p_ // divide block work constexpr unsigned NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; constexpr unsigned KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; - constexpr unsigned YBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; - constexpr unsigned XBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; + constexpr unsigned HBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; + constexpr unsigned WBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; const unsigned block_id = blockIdx.x; unsigned itmp = block_id; - const unsigned n_block_work_id = itmp / (KBlockWork * YBlockWork * XBlockWork); - itmp -= n_block_work_id * (KBlockWork * YBlockWork * XBlockWork); - const unsigned k_block_work_id = itmp / (YBlockWork * XBlockWork); - itmp -= k_block_work_id * (YBlockWork * XBlockWork); - const unsigned y_block_work_id = itmp / XBlockWork; - const unsigned x_block_work_id = itmp - y_block_work_id * XBlockWork; + const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork); + itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork); + const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork); + itmp -= k_block_work_id * (HBlockWork * WBlockWork); + const unsigned h_block_work_id = itmp / WBlockWork; + const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork; - const unsigned n_block_data_begin = n_block_work_id * NPerBlock; - const unsigned k_block_data_begin = k_block_work_id * KPerBlock; - const unsigned y_block_data_begin = y_block_work_id * YPerBlock; - const unsigned x_block_data_begin = x_block_work_id * XPerBlock; - - const unsigned ho_block_data_begin = y_block_data_begin * OutTileSizeH; - const unsigned wo_block_data_begin = x_block_data_begin * OutTileSizeW; + const unsigned n_block_data_begin = n_block_work_id * NPerBlock; + const unsigned k_block_data_begin = k_block_work_id * KPerBlock; + const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock; + const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock; const unsigned hi_block_data_begin = ho_block_data_begin; // minus padding const unsigned wi_block_data_begin = wo_block_data_begin; // minus padding @@ -103,45 +97,27 @@ __global__ void gridwise_direct_convolution_2(const Float* const __restrict__ p_ // divide thread work constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread; constexpr unsigned KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread; - constexpr unsigned YThreadWork = YPerBlock; - constexpr unsigned XThreadWork = XPerBlock; + constexpr unsigned HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread; + constexpr unsigned WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread; const unsigned thread_id = threadIdx.x; itmp = thread_id; - const unsigned n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork); - itmp -= n_thread_work_id * (KThreadWork * YThreadWork * XThreadWork); - const unsigned k_thread_work_id = itmp / (YThreadWork * XThreadWork); - itmp -= k_thread_work_id * (YThreadWork * XThreadWork); - const unsigned y_thread_work_id = itmp / XThreadWork; - const unsigned x_thread_work_id = itmp - y_thread_work_id * XThreadWork; + const unsigned n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork); + itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork); + const unsigned k_thread_work_id = itmp / (HThreadWork * WThreadWork); + itmp -= k_thread_work_id * (HThreadWork * WThreadWork); + const unsigned h_thread_work_id = itmp / WThreadWork; + const unsigned w_thread_work_id = itmp - h_thread_work_id * WThreadWork; const unsigned n_thread_data_begin = n_thread_work_id * NPerThread; const unsigned k_thread_data_begin = k_thread_work_id * KPerThread; - const unsigned ho_thread_data_begin = y_thread_work_id * OutTileSizeH; - const unsigned wo_thread_data_begin = x_thread_work_id * OutTileSizeW; + const unsigned ho_thread_data_begin = h_thread_work_id * HoPerThread; + const unsigned wo_thread_data_begin = w_thread_work_id * WoPerThread; const unsigned hi_thread_data_begin = ho_thread_data_begin; const unsigned wi_thread_data_begin = wo_thread_data_begin; -#if 0 - if(threadIdx.x == 0) - { - print_ConstantTensorDescriptor(in_global_desc, "gridwise_convolution: in_global_desc: "); - print_ConstantTensorDescriptor(wei_global_desc, "gridwise_convolution: wei_global_desc: "); - print_ConstantTensorDescriptor(out_global_desc, "gridwise_convolution: out_global_desc: "); - } - - printf("threadIdx.x %u \t" - "n_thread_data_begin %u, k_thread_data_begin %u, ho_thread_data_begin %u, " - "wo_thread_data_begin %u\n", - threadIdx.x, - n_thread_data_begin, - k_thread_data_begin, - ho_thread_data_begin, - wo_thread_data_begin); -#endif - constexpr auto blockwise_in_copy = Blockwise4dTensorCopy1{}); + constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); // tensor view of blockwise input and weight in LDS // be careful of alignment @@ -98,10 +98,10 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric Sequence{}, Number{}); constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); + Sequence{}, Number{}); constexpr auto wei_csrk_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); + Sequence{}, Number{}); // tensor view of threadwise output in register constexpr auto out_khwn_thread_desc = @@ -118,7 +118,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric InBlockCopyDataPerRead>{}; // blockwise wei copy - // format is [CPerBlock*S*R,KPerBlock] + // format is [CPerBlock*Y*X,KPerBlock] const auto blockwise_wei_copy = Blockwise2dTensorCopy3{}); + constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); // tensor view of blockwise input and weight in LDS constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor(Sequence{}); constexpr auto wei_csrk_block_desc = - make_ConstantTensorDescriptor(Sequence{}); + make_ConstantTensorDescriptor(Sequence{}); // flattened (2d) tensor view of wei in LDS constexpr auto wei_ek_block_desc = - make_ConstantTensorDescriptor(Sequence{}); + make_ConstantTensorDescriptor(Sequence{}); // tensor view of threadwise output in register constexpr auto out_hkwn_thread_desc = @@ -144,7 +144,7 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded( LowerPads>{}; #if 0 - // weight: format is [C,S,R,K] + // weight: format is [C,Y,X,K] constexpr auto blockwise_wei_copy = Blockwise4dTensorCopy1{}; #elif 0 - // weight: format is [C*S*R,K] + // weight: format is [C*Y*X,K] constexpr auto blockwise_wei_copy = Blockwise2dTensorCopy1{}; #elif 1 - // weight: format is [C*S*R,K] + // weight: format is [C*Y*X,K] const auto blockwise_wei_copy = Blockwise2dTensorCopy2 -__global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_pipeline( - const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) -{ - // NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N] - // for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N" - // if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock - static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0"); - static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock, - "wrong!"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_chwn_global_desc = InGlobalDesc{}; - constexpr auto wei_csrk_global_desc = WeiGlobalDesc{}; - constexpr auto out_khwn_global_desc = OutGlobalDesc{}; - - constexpr unsigned C = in_chwn_global_desc.GetLength(I0); - - constexpr unsigned K = out_khwn_global_desc.GetLength(I0); - constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1); - constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2); - constexpr unsigned N = out_khwn_global_desc.GetLength(I3); - - constexpr unsigned S = wei_csrk_global_desc.GetLength(I1); - constexpr unsigned R = wei_csrk_global_desc.GetLength(I2); - - constexpr unsigned HPadLow = LowerPads{}.Get(I0); - constexpr unsigned WPadLow = LowerPads{}.Get(I1); - - constexpr unsigned HPadUp = UpperPads{}.Get(I0); - constexpr unsigned WPadUp = UpperPads{}.Get(I1); - - constexpr unsigned HiPerBlock = HoPerBlock + S - 1; - constexpr unsigned WiPerBlock = WoPerBlock + R - 1; - - // divide block work: [K, Ho, Wo, N] - constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr unsigned HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; - constexpr unsigned WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; - constexpr unsigned NBlockWork = (N + NPerBlock - 1) / NPerBlock; - - const unsigned k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); - unsigned itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); - const unsigned h_block_work_id = itmp / (WBlockWork * NBlockWork); - itmp -= h_block_work_id * (WBlockWork * NBlockWork); - const unsigned w_block_work_id = itmp / NBlockWork; - const unsigned n_block_work_id = itmp - w_block_work_id * NBlockWork; - - const unsigned k_block_data_begin = k_block_work_id * KPerBlock; - const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock; - const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock; - const unsigned n_block_data_begin = n_block_work_id * NPerBlock; - - // flattened (2d) tensor view of wei in global mem - constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of blockwise input and weight in LDS - constexpr auto in_chwn_block_desc = - make_ConstantTensorDescriptor(Sequence{}); - - constexpr auto wei_csrk_block_desc = - make_ConstantTensorDescriptor(Sequence{}); - - // flattened (2d) tensor view of wei in LDS - constexpr auto wei_ek_block_desc = - make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of threadwise output in register - constexpr auto out_hkwn_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc"); - print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc"); - print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc"); - } -#endif - - // blockwise copy - // input: format is [C, Hi, Wi, N] - const unsigned h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0; - const unsigned w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0; - - const unsigned h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0; - const unsigned w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0; - -#if 0 - if(get_thread_local_1d_id() == 0) - ; - { - printf( - "%u %u, h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u\n", - get_block_1d_id(), - get_thread_local_1d_id(), - h_block_pad_low, - w_block_pad_low, - h_block_pad_up, - w_block_pad_up); - } -#endif - - constexpr auto blockwise_in_copy = - BlockwiseChwnTensorCopyPadded{}; - -#if 0 - // weight: format is [C,S,R,K] - constexpr auto blockwise_wei_copy = - Blockwise4dTensorCopy1{}; -#elif 0 - // weight: format is [C*S*R,K] - constexpr auto blockwise_wei_copy = - Blockwise2dTensorCopy1{}; -#elif 1 - // weight: format is [C*S*R,K] - const auto blockwise_wei_copy = Blockwise2dTensorCopy2{}; -#endif - - // a series of blockwise batched GEMM - // C_matrix += transpose(A_matrix) * B_matrix - // A_matrix and B_matrix saved in LDS, C_matrix saved in register - // A_matrix[C,K] is a sub-matrix of wei_block[S,R,C,K] - // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] - // C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N] - constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_cxwn_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - constexpr auto c_kxwn_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - const auto blockwise_batch_gemm = - Blockwise1dStridedBatchedGemmBlockABlockBThreadC{}; - - // LDS - constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace(); - constexpr unsigned wei_block_size = wei_csrk_block_desc.GetElementSpace(); - - // LDS double buffer - __shared__ Float p_in_block_0[in_block_size]; - __shared__ Float p_wei_block_0[wei_block_size]; - - __shared__ Float p_in_block_1[in_block_size]; - __shared__ Float p_wei_block_1[wei_block_size]; - - // register - Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()]; - - // set threadwise output tensor to 0 - threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread); - - const Float* p_wei_global_block_begin = - p_wei_global + wei_ek_global_desc.Get1dIndex(0, k_block_data_begin); - - // prelog: load data - // input: global mem to LDS, - blockwise_in_copy.Run(p_in_global, - 0, - ho_block_data_begin, - wo_block_data_begin, - n_block_data_begin, - p_in_block_0, - h_block_pad_low, - w_block_pad_low, - h_block_pad_up, - w_block_pad_up); - - // weight: global mem to LDS, - blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block_0); - - p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0); - - bool even_loop = true; - - for(unsigned c_block_data_begin = CPerBlock; c_block_data_begin < C; - c_block_data_begin += CPerBlock, - p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0), - even_loop = !even_loop) - { - __syncthreads(); - - Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1; - Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1; - - Float* p_in_block_next = even_loop ? p_in_block_1 : p_in_block_0; - Float* p_wei_block_next = even_loop ? p_wei_block_1 : p_wei_block_0; - -// preload next data -#if 1 - // input: global mem to LDS, - blockwise_in_copy.Run(p_in_global, - c_block_data_begin, - ho_block_data_begin, - wo_block_data_begin, - n_block_data_begin, - p_in_block_next, - h_block_pad_low, - w_block_pad_low, - h_block_pad_up, - w_block_pad_up); -#endif - -#if 1 - // weight: global mem to LDS, - blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block_next); -#endif - - // a series of batched GEMM - for(unsigned s = 0; s < S; ++s) - { - for(unsigned r = 0; r < R; ++r) - { - auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; - - blockwise_batch_gemm.Run(p_wei_block_now + - wei_csrk_block_desc.Get1dIndex(0, s, r, 0), - p_in_block_now + in_chwn_block_desc.Get1dIndex(0, s, r, 0), - p_out_thread, - f_accum); - } - } - } - - // last computation - { - __syncthreads(); - - Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1; - Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1; - - // a series of batched GEMM - for(unsigned s = 0; s < S; ++s) - { - for(unsigned r = 0; r < R; ++r) - { - auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; - - blockwise_batch_gemm.Run(p_wei_block_now + - wei_csrk_block_desc.Get1dIndex(0, s, r, 0), - p_in_block_now + in_chwn_block_desc.Get1dIndex(0, s, r, 0), - p_out_thread, - f_accum); - } - } - } - - const auto matrix_c_index = - blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const unsigned ho_thread_data_begin = matrix_c_index.batch; - const unsigned k_thread_data_begin = matrix_c_index.row; - const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock; - const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock; - -#if 0 - printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n", - get_block_1d_id(), get_thread_local_1d_id(), - ho_block_data_begin, k_block_data_begin, wo_block_data_begin, n_block_data_begin, - ho_thread_data_begin, k_thread_data_begin, wo_thread_data_begin, n_thread_data_begin, - p_out_thread[0]); -#endif - - // output: register to global mem, - // convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N] - constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{}; - - threadwise_4d_tensor_copy_reorder_by_get_dst_from_src( - out_hkwn_thread_desc, - p_out_thread, - out_khwn_global_desc, - p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin), - out_hkwn_thread_desc.GetLengths(), - reorder_khwn_from_hkwn); -} diff --git a/src/include/gridwise_implicit_gemm_convolution_1_nchw_kcsr_nkhw.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_1_nchw_kcsr_nkhw.hip.hpp deleted file mode 100644 index bac52b27eb..0000000000 --- a/src/include/gridwise_implicit_gemm_convolution_1_nchw_kcsr_nkhw.hip.hpp +++ /dev/null @@ -1,270 +0,0 @@ -#pragma once -#include "common.hip.hpp" -#include "ConstantTensorDescriptor.hip.hpp" -#include "ConstantMatrixDescriptor.hip.hpp" -#include "blockwise_4d_tensor_op.hip.hpp" -#include "threadwise_4d_tensor_op.hip.hpp" -#include "blockwise_gemm.hip.hpp" - -template -__global__ void -gridwise_implicit_gemm_convolution_1_nchw_kcsr_nkhw(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) -{ - // NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N] - // for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N" - // if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock - constexpr unsigned NPerThread = NPerBlock; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_global_desc = InGlobalDesc{}; - constexpr auto wei_kcsr_global_desc = WeiGlobalDesc{}; - constexpr auto out_nkhw_global_desc = OutGlobalDesc{}; - - constexpr unsigned S = wei_kcsr_global_desc.GetLength(I2); - constexpr unsigned R = wei_kcsr_global_desc.GetLength(I3); - - constexpr unsigned HiPerBlock = HoPerBlock + S - 1; - constexpr unsigned WiPerBlock = WoPerBlock + R - 1; - - // divide block work: NCHW - constexpr unsigned NBlockWork = - (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; - constexpr unsigned KBlockWork = - (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; - constexpr unsigned HBlockWork = - (out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; - constexpr unsigned WBlockWork = - (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; - - // tensor view of un-reorderd blockwise input and weight (imaginary) - constexpr auto in_nchw_block_desc = - make_ConstantTensorDescriptor(Sequence{}); - - constexpr auto wei_kcsr_block_desc = - make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of reordered blockwise input and weight in LDS - constexpr auto reorder_srck_from_kcsr = Sequence<2, 3, 1, 0>{}; - constexpr auto wei_srck_block_desc = make_ConstantTensorDescriptor( - wei_kcsr_block_desc.GetLengths().ReorderByGetNewFromOld(reorder_srck_from_kcsr)); - - constexpr auto reorder_chwn_from_nchw = Sequence<1, 2, 3, 0>{}; - constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor( - in_nchw_block_desc.GetLengths().ReorderByGetNewFromOld(reorder_chwn_from_nchw)); - - // tensor view of threadwise output in register - constexpr auto out_hkwn_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(in_nchw_block_desc, "in_nchw_block_desc"); - print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc"); - - print_ConstantTensorDescriptor(wei_kcsr_block_desc, "wei_kcsr_block_desc"); - print_ConstantTensorDescriptor(wei_srck_block_desc, "wei_srck_block_desc"); - - print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc"); - } -#endif - - // my block work - unsigned itmp = get_block_1d_id(); - const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork); - itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork); - const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork); - itmp -= k_block_work_id * (HBlockWork * WBlockWork); - const unsigned h_block_work_id = itmp / WBlockWork; - const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork; - - const unsigned n_block_data_begin = n_block_work_id * NPerBlock; - const unsigned k_block_data_begin = k_block_work_id * KPerBlock; - const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock; - const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock; - - const unsigned hi_block_data_begin = ho_block_data_begin; - const unsigned wi_block_data_begin = wo_block_data_begin; - - // a series of blockwise batched GEMM - // C_matrix += transpose(A_matrix) * B_matrix - // A_matrix and B_matrix saved in LDS, C_matrix saved in register - // A_matrix[C,K] is a sub-matrix of wei_block[S,R,C,K] - // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] - // C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N] - constexpr auto a_cxk_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - constexpr auto b_cxwn_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - constexpr auto c_kxwn_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - const auto blockwise_batch_gemm = - Blockwise1dStridedBatchedGemmBlockABlockBThreadC{}; - - // LDS - constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace(); - constexpr unsigned wei_block_size = wei_srck_block_desc.GetElementSpace(); - - __shared__ Float p_in_block[in_block_size]; - __shared__ Float p_wei_block[wei_block_size]; - - // register - Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()]; - - // set threadwise output tensor to 0 - threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread); - - for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1); - c_block_data_begin += CPerBlock, __syncthreads()) - { -#if 1 - // input: global mem to LDS, - // convert [N,C,Hi,Wi] to [C,Hi,Wi,N] - blockwise_4d_tensor_copy_reorder_by_get_dst_from_src( - in_nchw_global_desc, - p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin, - c_block_data_begin, - hi_block_data_begin, - wi_block_data_begin), - in_chwn_block_desc, - p_in_block, - in_nchw_block_desc.GetLengths(), - reorder_chwn_from_nchw); -#else - // input: global mem to LDS, - // no format conversion, this is wrong, for performance study only! - Blockwise4dTensorCopy(in_nchw_global_desc, - p_in_global + - in_nchw_global_desc.Get1dIndex(n_block_data_begin, - c_block_data_begin, - hi_block_data_begin, - wi_block_data_begin), - in_nchw_block_desc, - p_in_block, - in_nchw_block_desc.GetLengths()); -#endif - -#if 1 - // weight: global mem to LDS, - // convert [K,C,S,R] to [S,R,C,K] - blockwise_4d_tensor_copy_reorder_by_get_dst_from_src( - wei_kcsr_global_desc, - p_wei_global + - wei_kcsr_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0), - wei_srck_block_desc, - p_wei_block, - wei_kcsr_block_desc.GetLengths(), - reorder_srck_from_kcsr); -#else - // weight: global mem to LDS, - // no format conversion, this is wrong, for performance study only! - Blockwise4dTensorCopy( - wei_kcsr_global_desc, - p_wei_global + - wei_kcsr_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0), - wei_kcsr_block_desc, - p_wei_block, - wei_kcsr_block_desc.GetLengths()); -#endif - - __syncthreads(); - -#if 1 - // a series of batched GEMM - for(unsigned s = 0; s < S; ++s) - { - for(unsigned r = 0; r < R; ++r) - { - auto f_accum = [](auto& c, const auto&& ab) { c += ab; }; - - blockwise_batch_gemm.Run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0), - p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0), - p_out_thread, - f_accum); - } - } -#endif - } - - const auto matrix_c_index = - blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - -#if 0 - printf("%u %u, %u %u %u\n",get_block_1d_id(), get_thread_local_1d_id(), matrix_c_index.batch, matrix_c_index.row, matrix_c_index.col); -#endif - - const unsigned ho_thread_data_begin = matrix_c_index.batch; - const unsigned k_thread_data_begin = matrix_c_index.row; - const unsigned wo_thread_data_begin = matrix_c_index.col / NPerThread; - -#if 1 - // output: register to global mem, - // convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo] - constexpr auto reorder_nkhw_from_hkwn = Sequence<3, 1, 0, 2>{}; - - threadwise_4d_tensor_copy_reorder_by_get_dst_from_src( - out_hkwn_thread_desc, - p_out_thread, - out_nkhw_global_desc, - p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin, - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin), - out_hkwn_thread_desc.GetLengths(), - reorder_nkhw_from_hkwn); -#else - // output: register to global mem, - // no format conversion, assume register is in [N,K,Ho,Wo], this is wrong, for performance - // study only! - constexpr auto out_nkhw_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - - threadwise_4d_tensor_copy( - out_nkhw_thread_desc, - p_out_thread, - out_nkhw_global_desc, - p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin, - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin), - out_nkhw_thread_desc.GetLengths()); -#endif -} diff --git a/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.hip.hpp deleted file mode 100644 index 0ea28e9ac2..0000000000 --- a/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.hip.hpp +++ /dev/null @@ -1,226 +0,0 @@ -#pragma once -#include "common.hip.hpp" -#include "ConstantTensorDescriptor.hip.hpp" -#include "ConstantMatrixDescriptor.hip.hpp" -#include "blockwise_4d_tensor_op.hip.hpp" -#include "threadwise_4d_tensor_op.hip.hpp" -#include "blockwise_gemm.hip.hpp" - -template -__global__ void -gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) -{ - // NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N] - // for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N" - // if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock - static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0"); - static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock, - "wrong!"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_global_desc = InGlobalDesc{}; - constexpr auto wei_srck_global_desc = WeiGlobalDesc{}; - constexpr auto out_nkhw_global_desc = OutGlobalDesc{}; - - constexpr unsigned S = wei_srck_global_desc.GetLength(I0); - constexpr unsigned R = wei_srck_global_desc.GetLength(I1); - - constexpr unsigned HiPerBlock = HoPerBlock + S - 1; - constexpr unsigned WiPerBlock = WoPerBlock + R - 1; - - // divide block work: NCHW - constexpr unsigned NBlockWork = - (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; - constexpr unsigned KBlockWork = - (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; - constexpr unsigned HBlockWork = - (out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; - constexpr unsigned WBlockWork = - (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; - - unsigned itmp = get_block_1d_id(); - const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork); - itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork); - const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork); - itmp -= k_block_work_id * (HBlockWork * WBlockWork); - const unsigned h_block_work_id = itmp / WBlockWork; - const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork; - - const unsigned n_block_data_begin = n_block_work_id * NPerBlock; - const unsigned k_block_data_begin = k_block_work_id * KPerBlock; - const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock; - const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock; - - const unsigned hi_block_data_begin = ho_block_data_begin; - const unsigned wi_block_data_begin = wo_block_data_begin; - - // tensor view of un-reorderd blockwise input and weight (imaginary) - constexpr auto in_nchw_block_desc = - make_ConstantTensorDescriptor(Sequence{}); - - constexpr auto wei_srck_block_desc = - make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of reordered blockwise input and weight in LDS - constexpr auto reorder_chwn_from_nchw = Sequence<1, 2, 3, 0>{}; - constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor( - in_nchw_block_desc.GetLengths().ReorderByGetNewFromOld(reorder_chwn_from_nchw)); - - // tensor view of threadwise output in register - constexpr auto out_hkwn_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(in_nchw_block_desc, "in_nchw_block_desc"); - print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc"); - - print_ConstantTensorDescriptor(wei_srck_block_desc, "wei_srck_block_desc"); - - print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc"); - } -#endif - - // blockwise copy - // wei: format is [S,R,C,K], no conversion needed - constexpr auto blockwise_wei_copy = - Blockwise4dTensorCopy1{}; - - // a series of blockwise batched GEMM - // C_matrix += transpose(A_matrix) * B_matrix - // A_matrix and B_matrix saved in LDS, C_matrix saved in register - // A_matrix[C,K] is a sub-matrix of wei_block[S,R,C,K] - // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] - // C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N] - constexpr auto a_cxk_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - constexpr auto b_cxwn_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - constexpr auto c_kxwn_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - const auto blockwise_batch_gemm = - Blockwise1dStridedBatchedGemmBlockABlockBThreadC{}; - - // LDS - constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace(); - constexpr unsigned wei_block_size = wei_srck_block_desc.GetElementSpace(); - - __shared__ Float p_in_block[in_block_size]; - __shared__ Float p_wei_block[wei_block_size]; - - // register - Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()]; - - // set threadwise output tensor to 0 - threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread); - - for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1); - c_block_data_begin += CPerBlock, __syncthreads()) - { -#if 1 - // input: global mem to LDS, - // convert [N,C,Hi,Wi] to [C,Hi,Wi,N] - blockwise_4d_tensor_copy_reorder_by_get_dst_from_src( - in_nchw_global_desc, - p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin, - c_block_data_begin, - hi_block_data_begin, - wi_block_data_begin), - in_chwn_block_desc, - p_in_block, - in_nchw_block_desc.GetLengths(), - reorder_chwn_from_nchw); -#endif - -#if 1 - // weight: global mem to LDS, - // format is [S,R,C,K], no conversion needed - blockwise_wei_copy.Run(p_wei_global + wei_srck_global_desc.Get1dIndex( - 0, 0, c_block_data_begin, k_block_data_begin), - p_wei_block); -#endif - - __syncthreads(); - - // a series of batched GEMM - for(unsigned s = 0; s < S; ++s) - { - for(unsigned r = 0; r < R; ++r) - { - auto f_accum = [](auto& c, const auto&& ab) { c += ab; }; - - blockwise_batch_gemm.Run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0), - p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0), - p_out_thread, - f_accum); - } - } - } - - const auto matrix_c_index = - blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const unsigned ho_thread_data_begin = matrix_c_index.batch; - const unsigned k_thread_data_begin = matrix_c_index.row; - const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock; - const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock; - - // output: register to global mem, - // convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo] - constexpr auto reorder_nkhw_from_hkwn = Sequence<3, 1, 0, 2>{}; - - threadwise_4d_tensor_copy_reorder_by_get_dst_from_src( - out_hkwn_thread_desc, - p_out_thread, - out_nkhw_global_desc, - p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin), - out_hkwn_thread_desc.GetLengths(), - reorder_nkhw_from_hkwn); -} diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp index e7070ae978..eba1d09675 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp @@ -57,11 +57,11 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1); constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2); - constexpr unsigned S = wei_csrk_global_desc.GetLength(I1); - constexpr unsigned R = wei_csrk_global_desc.GetLength(I2); + constexpr unsigned Y = wei_csrk_global_desc.GetLength(I1); + constexpr unsigned X = wei_csrk_global_desc.GetLength(I2); constexpr unsigned B = N * Hi * Wi; - constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1); + constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1); // divide block work by 2d: [K, B] constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; @@ -75,7 +75,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b // flattend (2d) tensor view of gridwise input constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence{}); - constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); + constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); // tensor view of blockwise input and weight // be careful of alignment @@ -83,10 +83,10 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b Sequence{}, Number{}); constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); + Sequence{}, Number{}); constexpr auto wei_csrk_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); + Sequence{}, Number{}); // tensor view of threadwise output in register constexpr auto out_kb_thread_desc = @@ -138,7 +138,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b #endif // blockwise wei copy -// format is [CPerBlock*S*R,KPerBlock] +// format is [CPerBlock*Y*X,KPerBlock] #if 0 const auto blockwise_wei_copy = Blockwise2dTensorCopy1 -__global__ void -gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_cnhw_global_desc = InGlobalDesc{}; - constexpr auto wei_csrk_global_desc = WeiGlobalDesc{}; - constexpr auto out_knhw_global_desc = OutGlobalDesc{}; - - constexpr unsigned C = in_cnhw_global_desc.GetLength(I0); - constexpr unsigned N = in_cnhw_global_desc.GetLength(I1); - constexpr unsigned Hi = in_cnhw_global_desc.GetLength(I2); - constexpr unsigned Wi = in_cnhw_global_desc.GetLength(I3); - - constexpr unsigned K = out_knhw_global_desc.GetLength(I0); - constexpr unsigned Ho = out_knhw_global_desc.GetLength(I2); - constexpr unsigned Wo = out_knhw_global_desc.GetLength(I3); - - constexpr unsigned S = wei_csrk_global_desc.GetLength(I1); - constexpr unsigned R = wei_csrk_global_desc.GetLength(I2); - - constexpr unsigned B = N * Hi * Wi; - constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1); - - // divide block work by 2d: [K, B] - constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock; - - const unsigned k_block_work_id = get_block_1d_id() / BBlockWork; - const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; - - const unsigned k_block_data_begin = k_block_work_id * KPerBlock; - const unsigned b_block_data_begin = b_block_work_id * BPerBlock; - - // flattend (2d) tensor view of gridwise input - constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence{}); - constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of blockwise input and weight - // be careful of alignment - constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - constexpr auto wei_csrk_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - // tensor view of threadwise output in register - constexpr auto out_kb_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(in_cnhw_global_desc, "in_cnhw_global_desc"); - print_ConstantTensorDescriptor(wei_csrk_global_desc, "wei_csrk_global_desc"); - print_ConstantTensorDescriptor(out_knhw_global_desc, "out_knhw_global_desc"); - - print_ConstantTensorDescriptor(in_cb_global_desc, "in_cb_global_desc"); - print_ConstantTensorDescriptor(wei_ek_global_desc, "wei_ek_global_desc"); - - print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc"); - print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc"); - print_ConstantTensorDescriptor(wei_ek_block_desc, "wei_ek_block_desc"); - print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc"); - - printf("KPerBlock %u\n", KPerBlock); - } -#endif - -// blockwise in copy -// formmat is [CPerBlock,BPerBlock + BGhostRead] -#if 0 - const auto blockwise_in_copy = - Blockwise2dTensorCopy1{}; -#elif 0 - const auto blockwise_in_copy = Blockwise2dTensorCopy2{}; -#elif 1 - const auto blockwise_in_copy = Blockwise2dTensorCopy3{}; -#endif - -// blockwise wei copy -// format is [CPerBlock*S*R,KPerBlock] -#if 0 - const auto blockwise_wei_copy = - Blockwise2dTensorCopy1{}; -#elif 0 - const auto blockwise_wei_copy = Blockwise2dTensorCopy2{}; -#elif 1 - const auto blockwise_wei_copy = Blockwise2dTensorCopy3{}; -#endif - - // a series of blockwise GEMM - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx and b_mtx saved in LDS, c_mtx saved in register - // a_mtx[C,K] is a sub-matrix of wei_block[C,S,R,K] - // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead] - // c_mtx[K,B] is out_block[K,B] - constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto c_kxb_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, Number{}); - -#if 0 - const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC{}; -#else - const auto blockwise_gemm = - BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2{}; -#endif - - // LDS: be careful of alignment - constexpr unsigned in_block_size = - in_cb_block_desc.GetElementSpace(Number{}); - - constexpr unsigned wei_block_size = - wei_csrk_block_desc.GetElementSpace(Number{}); - - constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead - ? InBlockCopyDataPerRead - : WeiBlockCopyDataPerRead; - - __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)]; - __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)]; - - // register - Float p_out_thread[out_kb_thread_desc.GetElementSpace()]; - - // set threadwise output tensor to 0 - threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread); - - const Float* p_in_global_block_offset = - p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin); - - const Float* p_wei_global_block_offset = - p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); - - for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, - p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), - p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I0), - __syncthreads()) - { - // input: global mem to LDS, - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); - - // weight: global mem to LDS, - blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); - - __syncthreads(); - - // a series of GEMM - for(unsigned s = 0; s < S; ++s) - { - for(unsigned r = 0; r < R; ++r) - { - auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; - -#if 1 - blockwise_gemm.Run -#elif 0 - blockwise_gemm.Run_v2 -#elif 0 - blockwise_gemm.Run_RegisterDoubleBuffer -#endif - (p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), - p_in_block + s * Wi + r, - p_out_thread, - f_accum); - } - } - } - - // output: register to global mem, - const auto c_thread_mtx_begin = - blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; - const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; - -#if 0 - if(get_block_1d_id() == 0) - { - printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", - get_block_1d_id(), - get_thread_local_1d_id(), - matrix_c_index.row, - matrix_c_index.col, - k_data_begin, - b_data_begin, - p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]); - } -#endif - - for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) - { - for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) - { - const auto c_thread_mtx_distance = - blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); - - unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row; - unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col; - - unsigned n_data = b_data / (Hi * Wi); - unsigned itmp = b_data - n_data * (Hi * Wi); - unsigned h_data = itmp / Wi; - unsigned w_data = itmp - h_data * Wi; - -#if 0 - if(get_block_1d_id() == 0) - { - printf("%u %u, k %u b %u, k_data %u n_data %u h_data %u w_data %u %f\n", - get_block_1d_id(), - get_thread_local_1d_id(), - k, - b, - k_data, - n_data, - h_data, - w_data, - p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]); - } -#endif - if(n_data < N && h_data < Ho && w_data < Wo) - { - p_out_global[out_knhw_global_desc.Get1dIndex(k_data, n_data, h_data, w_data)] = - p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; - } - } - } -} diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer.hip.hpp deleted file mode 100644 index 51d3a3212b..0000000000 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer.hip.hpp +++ /dev/null @@ -1,393 +0,0 @@ -#pragma once -#include "common.hip.hpp" -#include "ConstantTensorDescriptor.hip.hpp" -#include "ConstantMatrixDescriptor.hip.hpp" -#include "blockwise_4d_tensor_op.hip.hpp" -#include "blockwise_2d_tensor_op.hip.hpp" -#include "threadwise_2d_tensor_op.hip.hpp" -#include "blockwise_gemm.hip.hpp" - -// define B = flatten(N, Hi, Wi) -template -__global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer( - const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_cnhw_global_desc = InGlobalDesc{}; - constexpr auto wei_csrk_global_desc = WeiGlobalDesc{}; - constexpr auto out_knhw_global_desc = OutGlobalDesc{}; - - constexpr unsigned C = in_cnhw_global_desc.GetLength(I0); - constexpr unsigned N = in_cnhw_global_desc.GetLength(I1); - constexpr unsigned Hi = in_cnhw_global_desc.GetLength(I2); - constexpr unsigned Wi = in_cnhw_global_desc.GetLength(I3); - - constexpr unsigned K = out_knhw_global_desc.GetLength(I0); - constexpr unsigned Ho = out_knhw_global_desc.GetLength(I2); - constexpr unsigned Wo = out_knhw_global_desc.GetLength(I3); - - constexpr unsigned S = wei_csrk_global_desc.GetLength(I1); - constexpr unsigned R = wei_csrk_global_desc.GetLength(I2); - - constexpr unsigned B = N * Hi * Wi; - constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1); - - // divide block work by 2d: [K, B] - constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock; - - const unsigned k_block_work_id = get_block_1d_id() / BBlockWork; - const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; - - const unsigned k_block_data_begin = k_block_work_id * KPerBlock; - const unsigned b_block_data_begin = b_block_work_id * BPerBlock; - - // flattend (2d) tensor view of gridwise input - constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence{}); - constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of blockwise input and weight - // be careful of alignment - constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - constexpr auto wei_csrk_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - // tensor view of threadwise output in register - constexpr auto out_kb_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(in_cnhw_global_desc, "in_cnhw_global_desc"); - print_ConstantTensorDescriptor(wei_csrk_global_desc, "wei_csrk_global_desc"); - print_ConstantTensorDescriptor(out_knhw_global_desc, "out_knhw_global_desc"); - - print_ConstantTensorDescriptor(in_cb_global_desc, "in_cb_global_desc"); - print_ConstantTensorDescriptor(wei_ek_global_desc, "wei_ek_global_desc"); - - print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc"); - print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc"); - print_ConstantTensorDescriptor(wei_ek_block_desc, "wei_ek_block_desc"); - print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc"); - - printf("KPerBlock %u\n", KPerBlock); - } -#endif - -// blockwise in copy -// formmat is [CPerBlock,BPerBlock + BGhostRead] -#if 0 - const auto blockwise_in_copy = - Blockwise2dTensorCopy1{}; -#elif 0 - const auto blockwise_in_copy = Blockwise2dTensorCopy2{}; -#elif 1 - const auto blockwise_in_copy = Blockwise2dTensorCopy3{}; -#endif - -// blockwise wei copy -// format is [CPerBlock*S*R,KPerBlock] -#if 0 - const auto blockwise_wei_copy = - Blockwise2dTensorCopy1{}; -#elif 0 - const auto blockwise_wei_copy = Blockwise2dTensorCopy2{}; -#elif 1 - const auto blockwise_wei_copy = Blockwise2dTensorCopy3{}; -#endif - - // a series of blockwise GEMM - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx and b_mtx saved in LDS, c_mtx saved in register - // a_mtx[C,K] is a sub-matrix of wei_block[C,S,R,K] - // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead] - // c_mtx[K,B] is out_block[K,B] - constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto c_kxb_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, Number{}); - -#if 0 - const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC{}; -#else - const auto blockwise_gemm = - BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2{}; -#endif - - // LDS: be careful of alignment - constexpr unsigned in_block_size = - in_cb_block_desc.GetElementSpace(Number{}); - - constexpr unsigned wei_block_size = - wei_csrk_block_desc.GetElementSpace(Number{}); - - constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead - ? InBlockCopyDataPerRead - : WeiBlockCopyDataPerRead; - - // LDS double buffer - __shared__ Float p_in_block_0[max_align * ((in_block_size + max_align - 1) / max_align)]; - __shared__ Float p_wei_block_0[max_align * ((wei_block_size + max_align - 1) / max_align)]; - - __shared__ Float p_in_block_1[max_align * ((in_block_size + max_align - 1) / max_align)]; - __shared__ Float p_wei_block_1[max_align * ((wei_block_size + max_align - 1) / max_align)]; - - const Float* p_in_global_block_offset = - p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin); - - const Float* p_wei_global_block_offset = - p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); - - // preload data into LDS - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_0); - blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_0); - - p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0); - p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I0); - - // register - Float p_out_thread[out_kb_thread_desc.GetElementSpace()]; - - // set threadwise output tensor to 0 - threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread); - - bool even_loop = true; - - for(unsigned c_block_data_begin = 0; c_block_data_begin + CPerBlock < C; - c_block_data_begin += CPerBlock, - p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), - p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I0), - even_loop = !even_loop) - { - Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1; - Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1; - - Float* p_in_block_next = even_loop ? p_in_block_1 : p_in_block_0; - Float* p_wei_block_next = even_loop ? p_wei_block_1 : p_wei_block_0; - - __syncthreads(); - -// load next data -#if 0 - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next); - blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next); -#elif 0 - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next); - - Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; - - blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, - p_wei_register_clipboard); -#elif 1 - Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; - Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; - - blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, - p_in_register_clipboard); - - blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, - p_wei_register_clipboard); -#endif - - // compute on current data - // a series of GEMM - for(unsigned s = 0; s < S; ++s) - { - for(unsigned r = 0; r < R; ++r) - { - auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; -#if 1 - blockwise_gemm.Run -#else - blockwise_gemm.Run_RegisterDoubleBuffer -#endif - (p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), - p_in_block_now + s * Wi + r, - p_out_thread, - f_accum); - } - } - -#if 0 - blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next); -#elif 1 - blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next); - blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next); -#endif - } - - // last computation - { - Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1; - Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1; - - __syncthreads(); - - for(unsigned s = 0; s < S; ++s) - { - for(unsigned r = 0; r < R; ++r) - { - auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; -#if 0 - blockwise_gemm.Run -#else - blockwise_gemm.Run_RegisterDoubleBuffer -#endif - (p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), - p_in_block_now + s * Wi + r, - p_out_thread, - f_accum); - } - } - } - - // output: register to global mem, - const auto c_thread_mtx_begin = - blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; - const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; - -#if 0 - if(get_block_1d_id() == 0) - { - printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", - get_block_1d_id(), - get_thread_local_1d_id(), - matrix_c_index.row, - matrix_c_index.col, - k_data_begin, - b_data_begin, - p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]); - } -#endif - - for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) - { - for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) - { - const auto c_thread_mtx_distance = - blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); - - unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row; - unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col; - - unsigned n_data = b_data / (Hi * Wi); - unsigned itmp = b_data - n_data * (Hi * Wi); - unsigned h_data = itmp / Wi; - unsigned w_data = itmp - h_data * Wi; - -#if 0 - if(get_block_1d_id() == 0) - { - printf("%u %u, k %u b %u, k_data %u n_data %u h_data %u w_data %u %f\n", - get_block_1d_id(), - get_thread_local_1d_id(), - k, - b, - k_data, - n_data, - h_data, - w_data, - p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]); - } -#endif - if(n_data < N && h_data < Ho && w_data < Wo) - { - p_out_global[out_knhw_global_desc.Get1dIndex(k_data, n_data, h_data, w_data)] = - p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; - } - } - } -} diff --git a/src/include/gridwise_winograd_convolution.hip.hpp b/src/include/gridwise_winograd_convolution.hip.hpp deleted file mode 100644 index 9cf30f8743..0000000000 --- a/src/include/gridwise_winograd_convolution.hip.hpp +++ /dev/null @@ -1,228 +0,0 @@ -#pragma once -#include "ConstantTensorDescriptor.hip.hpp" -#include "blockwise_winograd_transform.hip.hpp" -#include "threadwise_winograd_transform.hip.hpp" - -template -__global__ void gridwise_winograd_convolution(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_global_desc = InGlobalDesc{}; - constexpr auto wei_global_desc = WeiGlobalDesc{}; - constexpr auto out_global_desc = OutGlobalDesc{}; - - constexpr unsigned S = wei_global_desc.GetLength(I2); - constexpr unsigned R = wei_global_desc.GetLength(I3); - - constexpr unsigned HoPerBlock = OutTileSizeH * YPerBlock; - constexpr unsigned WoPerBlock = OutTileSizeW * XPerBlock; - - constexpr unsigned HiPerBlock = HoPerBlock + S - 1; - constexpr unsigned WiPerBlock = WoPerBlock + R - 1; - - constexpr unsigned InTileSizeH = OutTileSizeH + S - 1; - constexpr unsigned InTileSizeW = OutTileSizeW + R - 1; - - // divide block work - constexpr unsigned NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; - constexpr unsigned KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; - constexpr unsigned YBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; - constexpr unsigned XBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; - - const unsigned block_id = blockIdx.x; - - unsigned itmp = block_id; - const unsigned n_block_work_id = itmp / (KBlockWork * YBlockWork * XBlockWork); - itmp -= n_block_work_id * (KBlockWork * YBlockWork * XBlockWork); - const unsigned k_block_work_id = itmp / (YBlockWork * XBlockWork); - itmp -= k_block_work_id * (YBlockWork * XBlockWork); - const unsigned y_block_work_id = itmp / XBlockWork; - const unsigned x_block_work_id = itmp - y_block_work_id * XBlockWork; - - const unsigned n_block_data_begin = n_block_work_id * NPerBlock; - const unsigned k_block_data_begin = k_block_work_id * KPerBlock; - const unsigned y_block_data_begin = y_block_work_id * YPerBlock; - const unsigned x_block_data_begin = x_block_work_id * XPerBlock; - - const unsigned ho_block_data_begin = y_block_data_begin * OutTileSizeH; - const unsigned wo_block_data_begin = x_block_data_begin * OutTileSizeW; - - const unsigned hi_block_data_begin = ho_block_data_begin; // minus padding - const unsigned wi_block_data_begin = wo_block_data_begin; // minus padding - - // divide thread work - constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread; - constexpr unsigned KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread; - constexpr unsigned YThreadWork = YPerBlock; - constexpr unsigned XThreadWork = XPerBlock; - - const unsigned thread_id = threadIdx.x; - - itmp = thread_id; - const unsigned n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork); - itmp -= n_thread_work_id * (KThreadWork * YThreadWork * XThreadWork); - const unsigned k_thread_work_id = itmp / (YThreadWork * XThreadWork); - itmp -= k_thread_work_id * (YThreadWork * XThreadWork); - const unsigned y_thread_work_id = itmp / XThreadWork; - const unsigned x_thread_work_id = itmp - y_thread_work_id * XThreadWork; - - const unsigned n_thread_data_begin = n_thread_work_id * NPerThread; - const unsigned k_thread_data_begin = k_thread_work_id * KPerThread; - const unsigned y_thread_data_begin = y_thread_work_id; - const unsigned x_thread_data_begin = x_thread_work_id; - - // block data - constexpr auto in_transform_block_desc = make_ConstantTensorDescriptor( - Sequence{}); - - constexpr auto wei_transform_block_desc = - make_ConstantTensorDescriptor(Sequence{}); - - __shared__ Float p_in_transform_block[in_transform_block_desc.GetElementSpace()]; - __shared__ Float p_wei_transform_block[wei_transform_block_desc.GetElementSpace()]; - - // thread data - constexpr auto in_transform_thread_block_desc = - make_ConstantTensorDescriptor(Sequence{}, - in_transform_block_desc.GetStrides()); - - constexpr auto wei_transform_thread_block_desc = - make_ConstantTensorDescriptor(Sequence{}, - wei_transform_block_desc.GetStrides()); - - constexpr auto out_transform_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - - constexpr auto out_thread_desc = make_ConstantTensorDescriptor( - Sequence{}); - - constexpr auto out_thread_global_desc = - make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_global_desc.GetStrides()); - - Float p_out_transform_thread[out_transform_thread_desc.GetElementSpace()]; - Float p_out_thread[out_thread_desc.GetElementSpace()]; - -#if 0 - if(blockIdx.x == 0 && threadIdx.x == 0) - { - printf("in_transform_block_size %u, wei_transform_block_size %u, out_transform_thread_size " - "%u, out_thread_size %u \n", - in_transform_block_size, - wei_transform_block_size, - out_transform_thread_size, - out_thread_size); - } -#endif - - // set threadwise output transform tensor to 0 - threadwise_4d_tensor_set_zero(out_transform_thread_desc, p_out_transform_thread); - - for(unsigned c_block_data_begin = 0; c_block_data_begin < in_global_desc.GetLength(I1); - c_block_data_begin += CPerBlock, __syncthreads()) - { -#if 0 - // blockwise transform input - blockwise_winograd_transform_input( - p_in_global + in_global_desc.Get1dIndex(n_block_data_begin, - c_block_data_begin, - hi_block_data_begin, - wi_block_data_begin), - p_in_transform_block); - -#endif - // blockwise transform weights - blockwise_winograd_transform_weight( - p_wei_global + wei_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0), - p_wei_transform_block); - - for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread) - { - // threadwise point multiplication - threadwise_winograd_calculate_transformed_output< - Float, - decltype(in_transform_thread_block_desc), - decltype(wei_transform_thread_block_desc), - decltype(out_transform_thread_desc), - InTileSizeH, - InTileSizeW, - S, - R, - OutTileSizeH, - OutTileSizeW>(in_transform_thread_block_desc, - p_in_transform_block + in_transform_block_desc.Get1dIndex( - n_thread_data_begin, - c_thread_data, - y_thread_data_begin * InTileSizeH, - x_thread_data_begin * InTileSizeW), - wei_transform_thread_block_desc, - p_wei_transform_block + wei_transform_block_desc.Get1dIndex( - k_thread_data_begin, c_thread_data, 0, 0), - out_transform_thread_desc, - p_out_transform_thread); - } - }; - - // transform back - threadwise_winograd_reverse_transform_output( - out_transform_thread_desc, p_out_transform_thread, out_thread_desc, p_out_thread); - - // copy output tensor from register to global mem - threadwise_4d_tensor_copy( - out_thread_desc, - p_out_thread, - out_thread_global_desc, - p_out_global + - out_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + y_thread_data_begin * OutTileSizeH, - wo_block_data_begin + x_thread_data_begin * OutTileSizeW)); -}