From 43cd8529c240161ffbad165603d47c5d008559fc Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sat, 9 Mar 2019 12:52:16 -0600 Subject: [PATCH 1/3] refactor --- ...icit_gemm_convolution_1_chwn_csrk_khwn.hpp | 4 +-- ...mm_convolution_1_chwn_csrk_khwn_padded.hpp | 4 +-- ...icit_gemm_convolution_2_chwn_csrk_khwn.hpp | 2 +- ..._gemm_convolution_1_chwn_csrk_khwn.hip.hpp | 8 ++--- ...onvolution_1_chwn_csrk_khwn_padded.hip.hpp | 8 ++--- ...2_chwn_csrk_khwn_lds_double_buffer.hip.hpp | 16 +++++----- src/include/tensor.hpp | 6 ++-- .../threadwise_direct_convolution.hip.hpp | 30 +++++++++---------- 8 files changed, 39 insertions(+), 39 deletions(-) diff --git a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp index fc2b245148..72905ce47a 100644 --- a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp @@ -39,8 +39,8 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, Tensor wei_csrk(make_TensorDescriptor(wei_csrk_desc)); - auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto s, auto r) { - wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r); + auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto y, auto x) { + wei_csrk(c, y, x, k) = wei_kcsr(k, c, y, x); }; make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, Y, X)( diff --git a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp index db0cb3aa90..de18deefc6 100644 --- a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp +++ b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp @@ -41,8 +41,8 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(InDesc, Tensor wei_csrk(make_TensorDescriptor(wei_csrk_desc)); - auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto s, auto r) { - wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r); + auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto y, auto x) { + wei_csrk(c, y, x, k) = wei_kcsr(k, c, y, x); }; make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, Y, X)( diff --git a/driver/device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp index 88fd5a2dea..db7802902d 100644 --- a/driver/device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp @@ -55,7 +55,7 @@ void device_implicit_gemm_convolution_2_chwn_csrk_khwn(InDesc, Tensor wei_csrk(make_TensorDescriptor(wei_csrk_desc)); make_ParallelTensorFunctor( - [&](auto k, auto c, auto s, auto r) { wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r); }, + [&](auto k, auto c, auto y, auto x) { wei_csrk(c, y, x, k) = wei_kcsr(k, c, y, x); }, K, C, Y, diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp index b5e810cfd5..a6cb6f60e1 100644 --- a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp @@ -204,12 +204,12 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric __syncthreads(); // a series of batched GEMM - for(unsigned s = 0; s < Y; ++s) + for(unsigned y = 0; y < Y; ++y) { - for(unsigned r = 0; r < X; ++r) + for(unsigned x = 0; x < X; ++x) { - blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), - p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0), + blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0), p_out_thread, [](auto& acc, const auto&& v) { acc += v; }); } diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hip.hpp index 5f14fc3373..7f6d54143b 100644 --- a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hip.hpp @@ -245,14 +245,14 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded( __syncthreads(); // a series of batched GEMM - for(unsigned s = 0; s < Y; ++s) + for(unsigned y = 0; y < Y; ++y) { - for(unsigned r = 0; r < X; ++r) + for(unsigned x = 0; x < X; ++x) { auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; - blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), - p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0), + blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0), p_out_thread, f_accum); } diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp index eba1d09675..b6498f8175 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp @@ -275,9 +275,9 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b // compute on current data // a series of GEMM - for(unsigned s = 0; s < Y; ++s) + for(unsigned y = 0; y < Y; ++y) { - for(unsigned r = 0; r < X; ++r) + for(unsigned x = 0; x < X; ++x) { auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; #if 1 @@ -285,8 +285,8 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b #else blockwise_gemm.Run_RegisterDoubleBuffer #endif - (p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), - p_in_block_now + s * Wi + r, + (p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block_now + y * Wi + x, p_out_thread, f_accum); } @@ -305,9 +305,9 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b __syncthreads(); - for(unsigned s = 0; s < Y; ++s) + for(unsigned y = 0; y < Y; ++y) { - for(unsigned r = 0; r < X; ++r) + for(unsigned x = 0; x < X; ++x) { auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; #if 0 @@ -315,8 +315,8 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b #else blockwise_gemm.Run_RegisterDoubleBuffer #endif - (p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), - p_in_block_now + s * Wi + r, + (p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block_now + y * Wi + x, p_out_thread, f_accum); } diff --git a/src/include/tensor.hpp b/src/include/tensor.hpp index 09ac224007..d0c785c16e 100644 --- a/src/include/tensor.hpp +++ b/src/include/tensor.hpp @@ -8,16 +8,16 @@ #include template -std::ostream& LogRange(std::ostream& os, Range&& r, std::string delim) +std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) { bool first = true; - for(auto&& x : r) + for(auto&& v : range) { if(first) first = false; else os << delim; - os << x; + os << v; } return os; } diff --git a/src/include/threadwise_direct_convolution.hip.hpp b/src/include/threadwise_direct_convolution.hip.hpp index 04ea7b1506..32d446491b 100644 --- a/src/include/threadwise_direct_convolution.hip.hpp +++ b/src/include/threadwise_direct_convolution.hip.hpp @@ -38,16 +38,16 @@ __device__ void threadwise_direct_convolution_1(InDesc, { for(unsigned c = 0; c < wei_desc.GetLength(I1); ++c) { - for(unsigned s = 0; s < wei_desc.GetLength(I2); ++s) + for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y) { - for(unsigned r = 0; r < wei_desc.GetLength(I3); ++r) + for(unsigned x = 0; x < wei_desc.GetLength(I3); ++x) { - const unsigned hi = ho + s; - const unsigned wi = wo + r; + const unsigned hi = ho + y; + const unsigned wi = wo + x; const unsigned in_index = in_desc.Get1dIndex(n, c, hi, wi); - const unsigned wei_index = wei_desc.Get1dIndex(k, c, s, r); + const unsigned wei_index = wei_desc.Get1dIndex(k, c, y, x); const unsigned out_index = out_desc.Get1dIndex(n, k, ho, wo); @@ -153,18 +153,18 @@ __device__ void threadwise_direct_convolution_3(InDesc, #if 0 // this verison reused old input data in register, and read new data from LDS // loop over vertical direction - for(unsigned s = 0; s < wei_desc.GetLength(I2); ++s) + for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y) { // read first input threadwise_4d_tensor_copy(in_desc, - p_in + in_desc.Get1dIndex(0, 0, s, 0), + p_in + in_desc.Get1dIndex(0, 0, y, 0), in_reg_desc, p_in_reg, in_reg_desc.GetLengths()); // read first 1x1 weight threadwise_4d_tensor_copy(wei_desc, - p_wei + wei_desc.Get1dIndex(0, 0, s, 0), + p_wei + wei_desc.Get1dIndex(0, 0, y, 0), wei_reg_desc, p_wei_reg, wei_reg_desc.GetLengths()); @@ -174,11 +174,11 @@ __device__ void threadwise_direct_convolution_3(InDesc, in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out); // loop over horizontal direction - for(unsigned r = 1; r < wei_desc.GetLength(I3); ++r) + for(unsigned x = 1; x < wei_desc.GetLength(I3); ++x) { // read new weight threadwise_4d_tensor_copy(wei_desc, - p_wei + wei_desc.Get1dIndex(0, 0, s, r), + p_wei + wei_desc.Get1dIndex(0, 0, y, x), wei_reg_desc, p_wei_reg, wei_reg_desc.GetLengths()); @@ -189,7 +189,7 @@ __device__ void threadwise_direct_convolution_3(InDesc, // read new input threadwise_4d_tensor_copy( in_desc, - p_in + in_desc.Get1dIndex(0, 0, s, r + in_reg_desc.GetLength(I3) - 1), + p_in + in_desc.Get1dIndex(0, 0, y, x + in_reg_desc.GetLength(I3) - 1), in_reg_desc, p_in_reg + in_reg_desc.Get1dIndex(0, 0, 0, in_reg_desc.GetLength(I3) - in_w_new_read), @@ -203,21 +203,21 @@ __device__ void threadwise_direct_convolution_3(InDesc, #elif 1 // this version read all input from LDS when filter moves // loop over vertical direction - for(unsigned s = 0; s < wei_desc.GetLength(I2); ++s) + for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y) { // loop over horizontal direction - for(unsigned r = 0; r < wei_desc.GetLength(I3); ++r) + for(unsigned x = 0; x < wei_desc.GetLength(I3); ++x) { // read new weight threadwise_4d_tensor_copy(wei_desc, - p_wei + wei_desc.Get1dIndex(0, 0, s, r), + p_wei + wei_desc.Get1dIndex(0, 0, y, x), wei_reg_desc, p_wei_reg, wei_reg_desc.GetLengths()); // read new input threadwise_4d_tensor_copy(in_desc, - p_in + in_desc.Get1dIndex(0, 0, s, r), + p_in + in_desc.Get1dIndex(0, 0, y, x), in_reg_desc, p_in_reg, in_reg_desc.GetLengths()); From 7a970877138c03eba26bed1d5e7bd5da0b570899 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sat, 9 Mar 2019 12:59:47 -0600 Subject: [PATCH 2/3] refactor --- ...icit_gemm_convolution_1_chwn_csrk_khwn.hpp | 308 --------------- ...mm_convolution_1_chwn_csrk_khwn_padded.hpp | 293 -------------- ...icit_gemm_convolution_2_chwn_csrk_khwn.hpp | 259 ------------ driver/driver.hip.cpp | 120 +++--- ..._gemm_convolution_1_chwn_csrk_khwn.hip.hpp | 310 --------------- ...onvolution_1_chwn_csrk_khwn_padded.hip.hpp | 292 -------------- ...2_chwn_csrk_khwn_lds_double_buffer.hip.hpp | 369 ------------------ 7 files changed, 60 insertions(+), 1891 deletions(-) delete mode 100644 driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp delete mode 100644 driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp delete mode 100644 driver/device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp delete mode 100644 src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp delete mode 100644 src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hip.hpp delete mode 100644 src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp diff --git a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp deleted file mode 100644 index 72905ce47a..0000000000 --- a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp +++ /dev/null @@ -1,308 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp" - -template -void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcsr, - OutDesc, - Tensor& out_nkhw, - unsigned nrepeat) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_desc = InDesc{}; - constexpr auto wei_kcsr_desc = WeiDesc{}; - constexpr auto out_nkhw_desc = OutDesc{}; - - constexpr unsigned Hi = in_nchw_desc.GetLength(I2); - constexpr unsigned Wi = in_nchw_desc.GetLength(I3); - - constexpr unsigned N = out_nkhw_desc.GetLength(I0); - constexpr unsigned Ho = out_nkhw_desc.GetLength(I2); - constexpr unsigned Wo = out_nkhw_desc.GetLength(I3); - - constexpr unsigned K = wei_kcsr_desc.GetLength(I0); - constexpr unsigned C = wei_kcsr_desc.GetLength(I1); - constexpr unsigned Y = wei_kcsr_desc.GetLength(I2); - constexpr unsigned X = wei_kcsr_desc.GetLength(I3); - - // reorder weight - auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: "); - - Tensor wei_csrk(make_TensorDescriptor(wei_csrk_desc)); - - auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto y, auto x) { - wei_csrk(c, y, x, k) = wei_kcsr(k, c, y, x); - }; - - make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, Y, X)( - std::thread::hardware_concurrency()); - - // reorder input - auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: "); - - Tensor in_chwn(make_TensorDescriptor(in_chwn_desc)); - - auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) { - in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi); - }; - - make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)( - std::thread::hardware_concurrency()); - - // output - auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: "); - - Tensor out_khwn(make_TensorDescriptor(out_khwn_desc)); - - std::size_t data_sz = sizeof(T); - DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace()); - DeviceMem wei_csrk_device_buf(data_sz * wei_csrk.mDesc.GetElementSpace()); - DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace()); - - in_chwn_device_buf.ToDevice(in_chwn.mData.data()); - wei_csrk_device_buf.ToDevice(wei_csrk.mData.data()); - out_khwn_device_buf.ToDevice(out_khwn.mData.data()); - -#if 1 - // for 3x3, 34x34 - constexpr unsigned NPerBlock = 16; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 4; - constexpr unsigned HoPerBlock = 2; - constexpr unsigned WoPerBlock = 4; - - constexpr unsigned NPerThread = 8; - constexpr unsigned KPerThread = 8; - constexpr unsigned HoPerThread = 1; - constexpr unsigned WoPerThread = 1; - - constexpr unsigned InBlockCopy_ThreadPerDimC = 4; - constexpr unsigned InBlockCopy_ThreadPerDimH = 4; - constexpr unsigned InBlockCopy_ThreadPerDimW = 2; - constexpr unsigned InBlockCopy_ThreadPerDimN = 4; - constexpr unsigned InBlockCopyDataPerRead = 4; - - constexpr unsigned WeiBlockCopyDataPerRead = 4; - - constexpr unsigned GemmMPerThreadSubC = 4; - constexpr unsigned GemmNPerThreadSubC = 4; - constexpr unsigned GemmMLevel0Cluster = 4; - constexpr unsigned GemmNLevel0Cluster = 2; - constexpr unsigned GemmMLevel1Cluster = 2; - constexpr unsigned GemmNLevel1Cluster = 4; - constexpr unsigned GemmKPerThreadLoop = 1; - - constexpr unsigned OutThreadCopyDataPerWrite = 2; - - constexpr unsigned BlockSize = 128; -#elif 0 - // for 5x5, 36x36 - constexpr unsigned NPerBlock = 16; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 2; - constexpr unsigned HoPerBlock = 2; - constexpr unsigned WoPerBlock = 4; - - constexpr unsigned NPerThread = 8; - constexpr unsigned KPerThread = 8; - constexpr unsigned HoPerThread = 1; - constexpr unsigned WoPerThread = 1; - - constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; - constexpr unsigned WeiBlockCopyThreadPerDim1 = 32; - - constexpr unsigned InBlockCopy_ThreadPerDimC = 2; - constexpr unsigned InBlockCopy_ThreadPerDimH = 2; - constexpr unsigned InBlockCopy_ThreadPerDimW = 4; - constexpr unsigned InBlockCopy_ThreadPerDimN = 4; - constexpr unsigned InBlockCopyDataPerRead = 4; - - constexpr unsigned WeiBlockCopyDataPerRead = 2; - - constexpr unsigned GemmMPerThreadSubC = 4; - constexpr unsigned GemmNPerThreadSubC = 4; - constexpr unsigned GemmMLevel0Cluster = 4; - constexpr unsigned GemmNLevel0Cluster = 2; - constexpr unsigned GemmMLevel1Cluster = 2; - constexpr unsigned GemmNLevel1Cluster = 4; - constexpr unsigned GemmKPerThreadLoop = 1; - - constexpr unsigned OutThreadCopyDataPerWrite = 2; - - constexpr unsigned BlockSize = 128; -#elif 0 - // 3x3 58x58, NKC = 64, 64, 256 - constexpr unsigned NPerBlock = 16; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 4; - constexpr unsigned HoPerBlock = 2; - constexpr unsigned WoPerBlock = 4; - - constexpr unsigned NPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - constexpr unsigned HoPerThread = 1; - constexpr unsigned WoPerThread = 1; - - constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; - constexpr unsigned WeiBlockCopyThreadPerDim1 = 32; - - constexpr unsigned InBlockCopyDataPerRead = 2; // not used, yet - constexpr unsigned WeiBlockCopyDataPerRead = 4; - - constexpr unsigned BlockSize = 128; -#elif 0 - // 3x3 58x58, NKC = 16,256,128 - constexpr unsigned NPerBlock = 8; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 2; - constexpr unsigned HoPerBlock = 4; - constexpr unsigned WoPerBlock = 4; - - constexpr unsigned NPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - constexpr unsigned HoPerThread = 1; - constexpr unsigned WoPerThread = 1; - - constexpr unsigned BlockSize = 128; -#elif 0 - // for 7x7, 38x38 - constexpr unsigned NPerBlock = 8; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 1; - constexpr unsigned HoPerBlock = 4; - constexpr unsigned WoPerBlock = 4; - - constexpr unsigned NPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - constexpr unsigned HoPerThread = 1; - constexpr unsigned WoPerThread = 1; - - constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; - constexpr unsigned WeiBlockCopyThreadPerDim1 = 32; - - constexpr unsigned InBlockCopyDataPerRead = 4; // not used, yet - constexpr unsigned WeiBlockCopyDataPerRead = 4; - - constexpr unsigned BlockSize = 128; -#elif 0 - // for 3x3, 56x56 - constexpr unsigned NPerBlock = 32; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 4; - constexpr unsigned HoPerBlock = 2; - constexpr unsigned WoPerBlock = 2; - - constexpr unsigned NPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - constexpr unsigned HoPerThread = 1; - constexpr unsigned WoPerThread = 1; - - constexpr unsigned BlockSize = 128; -#elif 1 - // for 1x1, 28x28 - constexpr unsigned NPerBlock = 16; - constexpr unsigned KPerBlock = 128; - constexpr unsigned CPerBlock = 8; - constexpr unsigned HoPerBlock = 2; - constexpr unsigned WoPerBlock = 2; - - constexpr unsigned NPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - constexpr unsigned HoPerThread = 1; - constexpr unsigned WoPerThread = 1; - - constexpr unsigned InBlockCopy_ThreadPerDimC = 8; - constexpr unsigned InBlockCopy_ThreadPerDimH = 2; - constexpr unsigned InBlockCopy_ThreadPerDimW = 2; - constexpr unsigned InBlockCopy_ThreadPerDimN = 4; - constexpr unsigned InBlockCopyDataPerRead = 4; - - constexpr unsigned WeiBlockCopyDataPerRead = 4; - - constexpr unsigned GemmMPerThreadSubC = 4; - constexpr unsigned GemmNPerThreadSubC = 4; - constexpr unsigned GemmMLevel0Cluster = 4; - constexpr unsigned GemmNLevel0Cluster = 2; - constexpr unsigned GemmMLevel1Cluster = 2; - constexpr unsigned GemmNLevel1Cluster = 4; - constexpr unsigned GemmKPerThreadLoop = 1; - - constexpr unsigned OutThreadCopyDataPerWrite = 2; - - constexpr unsigned BlockSize = 128; -#endif - - constexpr unsigned GridSize = - ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) * - ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - for(unsigned i = 0; i < nrepeat; ++i) - { - float time = launch_kernel( - gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn, - InBlockCopyDataPerRead, - WeiBlockCopyDataPerRead, - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - OutThreadCopyDataPerWrite>, - dim3(GridSize), - dim3(BlockSize), - static_cast(in_chwn_device_buf.GetDeviceBuffer()), - static_cast(wei_csrk_device_buf.GetDeviceBuffer()), - static_cast(out_khwn_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms\n", time); - usleep(std::min(time * 1000, float(10000))); - } - - out_khwn_device_buf.FromDevice(out_khwn.mData.data()); - - // reorder output - auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) { - out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n); - }; - - make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)( - std::thread::hardware_concurrency()); -} diff --git a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp deleted file mode 100644 index de18deefc6..0000000000 --- a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp +++ /dev/null @@ -1,293 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hip.hpp" - -template -void device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcsr, - OutDesc, - Tensor& out_nkhw, - LowerPads, - UpperPads, - unsigned nrepeat) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_desc = InDesc{}; - constexpr auto wei_kcsr_desc = WeiDesc{}; - constexpr auto out_nkhw_desc = OutDesc{}; - - constexpr unsigned Hi = in_nchw_desc.GetLength(I2); - constexpr unsigned Wi = in_nchw_desc.GetLength(I3); - - constexpr unsigned N = out_nkhw_desc.GetLength(I0); - constexpr unsigned Ho = out_nkhw_desc.GetLength(I2); - constexpr unsigned Wo = out_nkhw_desc.GetLength(I3); - - constexpr unsigned K = wei_kcsr_desc.GetLength(I0); - constexpr unsigned C = wei_kcsr_desc.GetLength(I1); - constexpr unsigned Y = wei_kcsr_desc.GetLength(I2); - constexpr unsigned X = wei_kcsr_desc.GetLength(I3); - - // reorder weight - auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: "); - - Tensor wei_csrk(make_TensorDescriptor(wei_csrk_desc)); - - auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto y, auto x) { - wei_csrk(c, y, x, k) = wei_kcsr(k, c, y, x); - }; - - make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, Y, X)( - std::thread::hardware_concurrency()); - - // reorder input - auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: "); - - Tensor in_chwn(make_TensorDescriptor(in_chwn_desc)); - - auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) { - in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi); - }; - - make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)( - std::thread::hardware_concurrency()); - - // output - auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: "); - - Tensor out_khwn(make_TensorDescriptor(out_khwn_desc)); - - std::size_t data_sz = sizeof(T); - DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace()); - DeviceMem wei_csrk_device_buf(data_sz * wei_csrk.mDesc.GetElementSpace()); - DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace()); - - in_chwn_device_buf.ToDevice(in_chwn.mData.data()); - wei_csrk_device_buf.ToDevice(wei_csrk.mData.data()); - out_khwn_device_buf.ToDevice(out_khwn.mData.data()); - -#if 0 - constexpr unsigned NPerBlock = 1; - constexpr unsigned KPerBlock = 1; - constexpr unsigned CPerBlock = 1; - constexpr unsigned HoPerBlock = 2; - constexpr unsigned WoPerBlock = 4; - - constexpr unsigned NPerThread = 1; - constexpr unsigned KPerThread = 1; - constexpr unsigned CPerThread = 1; - constexpr unsigned HoPerThread = 1; - constexpr unsigned WoPerThread = 1; - - constexpr unsigned WeiBlockCopyThreadPerDim0 = 1; - constexpr unsigned WeiBlockCopyThreadPerDim1 = 1; - - constexpr unsigned BlockSize = 8; -#elif 1 - // for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256 - constexpr unsigned NPerBlock = 16; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 4; - constexpr unsigned HoPerBlock = 2; - constexpr unsigned WoPerBlock = 4; - - constexpr unsigned NPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - constexpr unsigned HoPerThread = 1; - constexpr unsigned WoPerThread = 1; - - constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; - constexpr unsigned WeiBlockCopyThreadPerDim1 = 32; - - constexpr unsigned BlockSize = 128; -#elif 0 - // 3x3 58x58, NKC = 16,256,128 - constexpr unsigned NPerBlock = 8; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 2; - constexpr unsigned HoPerBlock = 4; - constexpr unsigned WoPerBlock = 4; - - constexpr unsigned NPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - constexpr unsigned HoPerThread = 1; - constexpr unsigned WoPerThread = 1; - - constexpr unsigned BlockSize = 128; -#elif 0 - // for 5x5, 36x36 - constexpr unsigned NPerBlock = 16; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 2; - constexpr unsigned HoPerBlock = 2; - constexpr unsigned WoPerBlock = 4; - - constexpr unsigned NPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - constexpr unsigned HoPerThread = 1; - constexpr unsigned WoPerThread = 1; - - constexpr unsigned BlockSize = 128; -#elif 0 - // for 7x7, 38x38 - constexpr unsigned NPerBlock = 8; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 2; - constexpr unsigned HoPerBlock = 4; - constexpr unsigned WoPerBlock = 4; - - constexpr unsigned NPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - constexpr unsigned HoPerThread = 1; - constexpr unsigned WoPerThread = 1; - - constexpr unsigned BlockSize = 128; -#elif 0 - // for 3x3, 56x56 - constexpr unsigned NPerBlock = 32; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 4; - constexpr unsigned HoPerBlock = 2; - constexpr unsigned WoPerBlock = 2; - - constexpr unsigned NPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - constexpr unsigned HoPerThread = 1; - constexpr unsigned WoPerThread = 1; - - constexpr unsigned BlockSize = 128; -#elif 1 - // 3x3 56x56, NKC = 16,256,128, with padding - // 3x3 28x28, NKC = 16,512,256, with padding - // 3x3 20x84, NKC = 16,256,256, with padding - constexpr unsigned NPerBlock = 16; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 2; - constexpr unsigned HoPerBlock = 2; - constexpr unsigned WoPerBlock = 4; - - constexpr unsigned NPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - constexpr unsigned HoPerThread = 1; - constexpr unsigned WoPerThread = 1; - - constexpr unsigned WeiBlockCopyThreadPerDim0 = 2; - constexpr unsigned WeiBlockCopyThreadPerDim1 = 64; - - constexpr unsigned BlockSize = 128; -#elif 0 - // for 5x5 filter, 20x84 image, 1x1 padding - constexpr unsigned NPerBlock = 16; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 1; - constexpr unsigned HoPerBlock = 2; - constexpr unsigned WoPerBlock = 4; - - constexpr unsigned NPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - constexpr unsigned HoPerThread = 1; - constexpr unsigned WoPerThread = 1; - - constexpr unsigned BlockSize = 128; -#elif 0 - // 5x5 filter, 28x28 image, 2x2 padding - constexpr unsigned NPerBlock = 16; - constexpr unsigned KPerBlock = 32; - constexpr unsigned CPerBlock = 2; - constexpr unsigned HoPerBlock = 4; - constexpr unsigned WoPerBlock = 4; - - constexpr unsigned NPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 1; - constexpr unsigned HoPerThread = 1; - constexpr unsigned WoPerThread = 1; - - constexpr unsigned BlockSize = 128; -#elif 0 - // for 1x1, 28x28 - constexpr unsigned NPerBlock = 16; - constexpr unsigned KPerBlock = 128; - constexpr unsigned CPerBlock = 8; - constexpr unsigned HoPerBlock = 2; - constexpr unsigned WoPerBlock = 2; - - constexpr unsigned NPerThread = 4; - constexpr unsigned KPerThread = 16; - constexpr unsigned CPerThread = 2; - constexpr unsigned HoPerThread = 1; - constexpr unsigned WoPerThread = 1; - - constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; - constexpr unsigned WeiBlockCopyThreadPerDim1 = 32; - - constexpr unsigned BlockSize = 128; -#endif - - constexpr unsigned GridSize = - ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) * - ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - for(unsigned i = 0; i < nrepeat; ++i) - { - float time = launch_kernel( - gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded, - dim3(GridSize), - dim3(BlockSize), - - static_cast(in_chwn_device_buf.GetDeviceBuffer()), - static_cast(wei_csrk_device_buf.GetDeviceBuffer()), - static_cast(out_khwn_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms\n", time); - usleep(std::min(time * 1000, float(10000))); - } - - out_khwn_device_buf.FromDevice(out_khwn.mData.data()); - - // reorder output - auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) { - out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n); - }; - - make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)( - std::thread::hardware_concurrency()); -} diff --git a/driver/device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp deleted file mode 100644 index db7802902d..0000000000 --- a/driver/device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp +++ /dev/null @@ -1,259 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp" - -template -void device_implicit_gemm_convolution_2_chwn_csrk_khwn(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcsr, - OutDesc, - Tensor& out_nkhw, - unsigned nrepeat) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_desc = InDesc{}; - constexpr auto wei_kcsr_desc = WeiDesc{}; - constexpr auto out_nkhw_desc = OutDesc{}; - - constexpr unsigned N = in_nchw_desc.GetLength(I0); - constexpr unsigned Hi = in_nchw_desc.GetLength(I2); - constexpr unsigned Wi = in_nchw_desc.GetLength(I3); - - constexpr unsigned Ho = out_nkhw_desc.GetLength(I2); - constexpr unsigned Wo = out_nkhw_desc.GetLength(I3); - - constexpr unsigned K = wei_kcsr_desc.GetLength(I0); - constexpr unsigned C = wei_kcsr_desc.GetLength(I1); - constexpr unsigned Y = wei_kcsr_desc.GetLength(I2); - constexpr unsigned X = wei_kcsr_desc.GetLength(I3); - - constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1); - - // convert in_nchw to in_cnhw - auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: "); - - Tensor in_chwn(make_TensorDescriptor(in_chwn_desc)); - - make_ParallelTensorFunctor( - [&](auto n, auto c, auto hi, auto wi) { in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi); }, - N, - C, - Hi, - Wi)(std::thread::hardware_concurrency()); - - // convert wei_kcsr to wei_csrk - auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: "); - - Tensor wei_csrk(make_TensorDescriptor(wei_csrk_desc)); - - make_ParallelTensorFunctor( - [&](auto k, auto c, auto y, auto x) { wei_csrk(c, y, x, k) = wei_kcsr(k, c, y, x); }, - K, - C, - Y, - X)(std::thread::hardware_concurrency()); - - // conver out_nkhw to out_knhw - auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: "); - - Tensor out_khwn(make_TensorDescriptor(out_khwn_desc)); - -#if 0 - // 3x3, 34x34 - // need to use register double buffer for GEMM - constexpr unsigned BPerBlock = 128; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 4; - - constexpr unsigned BPerThread = 8; - constexpr unsigned KPerThread = 8; - - constexpr unsigned GemmMPerThreadSubC = 4; - constexpr unsigned GemmNPerThreadSubC = 4; - constexpr unsigned GemmMLevel0Cluster = 4; - constexpr unsigned GemmNLevel0Cluster = 2; - constexpr unsigned GemmMLevel1Cluster = 2; - constexpr unsigned GemmNLevel1Cluster = 8; - constexpr unsigned GemmKPerThreadLoop = 1; - - constexpr unsigned GemmThreadPerColumnPerCluster = 8; - constexpr unsigned GemmThreadPerRowPerCluster = 8; - - constexpr unsigned InBlockCopyThreadPerDim0 = 4; - constexpr unsigned InBlockCopyThreadPerDim1 = 16; - - constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; - constexpr unsigned WeiBlockCopyThreadPerDim1 = 16; - - constexpr unsigned InBlockCopyDataPerRead = 4; - constexpr unsigned WeiBlockCopyDataPerRead = 4; - - constexpr unsigned BlockSize = 128; -#elif 0 - // 1x1, 28x28, 64 threads - constexpr unsigned BPerBlock = 64; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 8; - - constexpr unsigned BPerThread = 8; - constexpr unsigned KPerThread = 8; - - constexpr unsigned GemmMPerThreadSubC = 4; - constexpr unsigned GemmNPerThreadSubC = 4; - constexpr unsigned GemmMLevel0Cluster = 4; - constexpr unsigned GemmNLevel0Cluster = 2; - constexpr unsigned GemmMLevel1Cluster = 2; - constexpr unsigned GemmNLevel1Cluster = 4; - constexpr unsigned GemmKPerThreadLoop = 1; - - constexpr unsigned GemmThreadPerColumnPerCluster = 8; - constexpr unsigned GemmThreadPerRowPerCluster = 8; - - constexpr unsigned InBlockCopyThreadPerDim0 = 4; - constexpr unsigned InBlockCopyThreadPerDim1 = 16; - - constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; - constexpr unsigned WeiBlockCopyThreadPerDim1 = 16; - - constexpr unsigned InBlockCopyDataPerRead = 4; - constexpr unsigned WeiBlockCopyDataPerRead = 4; - - constexpr unsigned BlockSize = 64; -#elif 1 - // 1x1, 28x28, 128 threads, no lds-double-buffer - // 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128 - constexpr unsigned BPerBlock = 64; - constexpr unsigned KPerBlock = 128; - constexpr unsigned CPerBlock = 8; - - constexpr unsigned BPerThread = 8; - constexpr unsigned KPerThread = 8; - - constexpr unsigned GemmMPerThreadSubC = 4; - constexpr unsigned GemmNPerThreadSubC = 4; - constexpr unsigned GemmMLevel0Cluster = 4; - constexpr unsigned GemmNLevel0Cluster = 2; - constexpr unsigned GemmMLevel1Cluster = 4; - constexpr unsigned GemmNLevel1Cluster = 4; - constexpr unsigned GemmKPerThreadLoop = 1; - - constexpr unsigned GemmThreadPerColumnPerCluster = 8; - constexpr unsigned GemmThreadPerRowPerCluster = 8; - - constexpr unsigned InBlockCopyThreadPerDim0 = 4; - constexpr unsigned InBlockCopyThreadPerDim1 = 16; - - constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; - constexpr unsigned WeiBlockCopyThreadPerDim1 = 16; - - constexpr unsigned InBlockCopyDataPerRead = 4; - constexpr unsigned WeiBlockCopyDataPerRead = 4; - - constexpr unsigned BlockSize = 128; -#elif 1 - // 1x1, 28x28, 256 thread - constexpr unsigned BPerBlock = 128; - constexpr unsigned KPerBlock = 128; - constexpr unsigned CPerBlock = 8; - - constexpr unsigned BPerThread = 8; - constexpr unsigned KPerThread = 8; - - constexpr unsigned GemmMPerThreadSubC = 4; - constexpr unsigned GemmNPerThreadSubC = 4; - constexpr unsigned GemmMLevel0Cluster = 4; - constexpr unsigned GemmNLevel0Cluster = 4; - constexpr unsigned GemmMLevel1Cluster = 4; - constexpr unsigned GemmNLevel1Cluster = 4; - constexpr unsigned GemmKPerThreadLoop = 1; - - constexpr unsigned GemmThreadPerColumnPerCluster = 8; - constexpr unsigned GemmThreadPerRowPerCluster = 8; - - constexpr unsigned InBlockCopyThreadPerDim0 = 4; - constexpr unsigned InBlockCopyThreadPerDim1 = 16; - - constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; - constexpr unsigned WeiBlockCopyThreadPerDim1 = 16; - - constexpr unsigned InBlockCopyDataPerRead = 4; - constexpr unsigned WeiBlockCopyDataPerRead = 4; - - constexpr unsigned BlockSize = 256; -#endif - - constexpr unsigned GridSize = - ((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - // mem - std::size_t data_sz = sizeof(T); - DeviceMem in_chwn_device_buf(data_sz * (in_chwn.mDesc.GetElementSpace() + BGhostRead + - BPerBlock)); // reserve extra space for BGhostRead - DeviceMem wei_csrk_device_buf(data_sz * wei_csrk.mDesc.GetElementSpace()); - DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace()); - - in_chwn_device_buf.ToDevice(in_chwn.mData.data()); - wei_csrk_device_buf.ToDevice(wei_csrk.mData.data()); - out_khwn_device_buf.ToDevice(out_khwn.mData.data()); - - for(unsigned i = 0; i < nrepeat; ++i) - { - float time = - launch_kernel(gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer< - GridSize, - BlockSize, - T, - decltype(in_chwn_desc), - decltype(wei_csrk_desc), - decltype(out_khwn_desc), - BPerBlock, - KPerBlock, - CPerBlock, - BPerThread, - KPerThread, - GemmThreadPerColumnPerCluster, - GemmThreadPerRowPerCluster, - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - InBlockCopyThreadPerDim0, - InBlockCopyThreadPerDim1, - WeiBlockCopyThreadPerDim0, - WeiBlockCopyThreadPerDim1, - InBlockCopyDataPerRead, - WeiBlockCopyDataPerRead>, - dim3(GridSize), - dim3(BlockSize), - static_cast(in_chwn_device_buf.GetDeviceBuffer()), - static_cast(wei_csrk_device_buf.GetDeviceBuffer()), - static_cast(out_khwn_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms\n", time); - usleep(std::min(time * 1000, float(10000))); - } - - out_khwn_device_buf.FromDevice(out_khwn.mData.data()); - - // convert out_khwn to out_nkhw - make_ParallelTensorFunctor( - [&](auto n, auto k, auto ho, auto wo) { out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n); }, - N, - K, - Ho, - Wo)(std::thread::hardware_concurrency()); -} diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index 3b18645c4b..2cd1ac4b24 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -9,9 +9,9 @@ #include "conv_common.hip.hpp" #include "device_direct_convolution_1.hpp" #include "device_direct_convolution_2.hpp" -#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp" -#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp" -#include "device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp" +#include "device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp" +#include "device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp" +#include "device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp" struct GeneratorTensor_1 { @@ -108,7 +108,7 @@ auto make_TensorDescriptor(TConstTensorDesc) template void host_direct_convolution( - const Tensor& in_nchw, const Tensor& wei_kcsr, Tensor& out, LowerPads, UpperPads) + const Tensor& in_nchw, const Tensor& wei_kcyx, Tensor& out, LowerPads, UpperPads) { unsigned h_pad_low = LowerPads{}.Get(Number<0>{}); unsigned w_pad_low = LowerPads{}.Get(Number<1>{}); @@ -118,18 +118,18 @@ void host_direct_convolution( auto f = [&](auto n, auto k, auto ho, auto wo) { double v = 0; - for(int c = 0; c < wei_kcsr.mDesc.GetLengths()[1]; ++c) + for(int c = 0; c < wei_kcyx.mDesc.GetLengths()[1]; ++c) { - for(int y = 0; y < wei_kcsr.mDesc.GetLengths()[2]; ++y) + for(int y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y) { int hi = ho + y - h_pad_low; - for(int x = 0; x < wei_kcsr.mDesc.GetLengths()[3]; ++x) + for(int x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x) { int wi = wo + x - w_pad_low; if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 && wi < in_nchw.mDesc.GetLengths()[3]) { - v += in_nchw(n, c, hi, wi) * wei_kcsr(k, c, y, x); + v += in_nchw(n, c, hi, wi) * wei_kcyx(k, c, y, x); } } } @@ -148,7 +148,7 @@ void host_direct_convolution( template void host_winograd_3x3_convolution( - const Tensor& in_nchw, const Tensor& wei_kcsr, Tensor& out, LowerPads, UpperPads) + const Tensor& in_nchw, const Tensor& wei_kcyx, Tensor& out, LowerPads, UpperPads) { constexpr std::size_t HoPerTile = 2; constexpr std::size_t WoPerTile = 2; @@ -158,9 +158,9 @@ void host_winograd_3x3_convolution( std::size_t HI = in_nchw.mDesc.GetLengths()[2]; std::size_t WI = in_nchw.mDesc.GetLengths()[3]; - std::size_t K = wei_kcsr.mDesc.GetLengths()[0]; - std::size_t Y = wei_kcsr.mDesc.GetLengths()[2]; - std::size_t X = wei_kcsr.mDesc.GetLengths()[3]; + std::size_t K = wei_kcyx.mDesc.GetLengths()[0]; + std::size_t Y = wei_kcyx.mDesc.GetLengths()[2]; + std::size_t X = wei_kcyx.mDesc.GetLengths()[3]; std::size_t HO = out.mDesc.GetLengths()[2]; std::size_t WO = out.mDesc.GetLengths()[3]; @@ -259,49 +259,49 @@ void host_winograd_3x3_convolution( }; auto f_wei_transform = [&](auto k, auto c) { - wei_transform(k, c, 0, 0) = wei_kcsr(k, c, 0, 0); + wei_transform(k, c, 0, 0) = wei_kcyx(k, c, 0, 0); wei_transform(k, c, 0, 1) = - 0.5 * wei_kcsr(k, c, 0, 0) + 0.5 * wei_kcsr(k, c, 0, 1) + 0.5 * wei_kcsr(k, c, 0, 2); + 0.5 * wei_kcyx(k, c, 0, 0) + 0.5 * wei_kcyx(k, c, 0, 1) + 0.5 * wei_kcyx(k, c, 0, 2); wei_transform(k, c, 0, 2) = - 0.5 * wei_kcsr(k, c, 0, 0) - 0.5 * wei_kcsr(k, c, 0, 1) + 0.5 * wei_kcsr(k, c, 0, 2); - wei_transform(k, c, 0, 3) = wei_kcsr(k, c, 0, 2); + 0.5 * wei_kcyx(k, c, 0, 0) - 0.5 * wei_kcyx(k, c, 0, 1) + 0.5 * wei_kcyx(k, c, 0, 2); + wei_transform(k, c, 0, 3) = wei_kcyx(k, c, 0, 2); wei_transform(k, c, 1, 0) = - 0.5 * wei_kcsr(k, c, 0, 0) + 0.5 * wei_kcsr(k, c, 1, 0) + 0.5 * wei_kcsr(k, c, 2, 0); - wei_transform(k, c, 1, 1) = 0.25 * wei_kcsr(k, c, 0, 0) + 0.25 * wei_kcsr(k, c, 0, 1) + - 0.25 * wei_kcsr(k, c, 0, 2) + 0.25 * wei_kcsr(k, c, 1, 0) + - 0.25 * wei_kcsr(k, c, 1, 1) + 0.25 * wei_kcsr(k, c, 1, 2) + - 0.25 * wei_kcsr(k, c, 2, 0) + 0.25 * wei_kcsr(k, c, 2, 1) + - 0.25 * wei_kcsr(k, c, 2, 2); - wei_transform(k, c, 1, 2) = 0.25 * wei_kcsr(k, c, 0, 0) - 0.25 * wei_kcsr(k, c, 0, 1) + - 0.25 * wei_kcsr(k, c, 0, 2) + 0.25 * wei_kcsr(k, c, 1, 0) - - 0.25 * wei_kcsr(k, c, 1, 1) + 0.25 * wei_kcsr(k, c, 1, 2) + - 0.25 * wei_kcsr(k, c, 2, 0) - 0.25 * wei_kcsr(k, c, 2, 1) + - 0.25 * wei_kcsr(k, c, 2, 2); + 0.5 * wei_kcyx(k, c, 0, 0) + 0.5 * wei_kcyx(k, c, 1, 0) + 0.5 * wei_kcyx(k, c, 2, 0); + wei_transform(k, c, 1, 1) = 0.25 * wei_kcyx(k, c, 0, 0) + 0.25 * wei_kcyx(k, c, 0, 1) + + 0.25 * wei_kcyx(k, c, 0, 2) + 0.25 * wei_kcyx(k, c, 1, 0) + + 0.25 * wei_kcyx(k, c, 1, 1) + 0.25 * wei_kcyx(k, c, 1, 2) + + 0.25 * wei_kcyx(k, c, 2, 0) + 0.25 * wei_kcyx(k, c, 2, 1) + + 0.25 * wei_kcyx(k, c, 2, 2); + wei_transform(k, c, 1, 2) = 0.25 * wei_kcyx(k, c, 0, 0) - 0.25 * wei_kcyx(k, c, 0, 1) + + 0.25 * wei_kcyx(k, c, 0, 2) + 0.25 * wei_kcyx(k, c, 1, 0) - + 0.25 * wei_kcyx(k, c, 1, 1) + 0.25 * wei_kcyx(k, c, 1, 2) + + 0.25 * wei_kcyx(k, c, 2, 0) - 0.25 * wei_kcyx(k, c, 2, 1) + + 0.25 * wei_kcyx(k, c, 2, 2); wei_transform(k, c, 1, 3) = - 0.5 * wei_kcsr(k, c, 0, 2) + 0.5 * wei_kcsr(k, c, 1, 2) + 0.5 * wei_kcsr(k, c, 2, 2); + 0.5 * wei_kcyx(k, c, 0, 2) + 0.5 * wei_kcyx(k, c, 1, 2) + 0.5 * wei_kcyx(k, c, 2, 2); wei_transform(k, c, 2, 0) = - 0.5 * wei_kcsr(k, c, 0, 0) - 0.5 * wei_kcsr(k, c, 1, 0) + 0.5 * wei_kcsr(k, c, 2, 0); - wei_transform(k, c, 2, 1) = 0.25 * wei_kcsr(k, c, 0, 0) + 0.25 * wei_kcsr(k, c, 0, 1) + - 0.25 * wei_kcsr(k, c, 0, 2) - 0.25 * wei_kcsr(k, c, 1, 0) - - 0.25 * wei_kcsr(k, c, 1, 1) - 0.25 * wei_kcsr(k, c, 1, 2) + - 0.25 * wei_kcsr(k, c, 2, 0) + 0.25 * wei_kcsr(k, c, 2, 1) + - 0.25 * wei_kcsr(k, c, 2, 2); - wei_transform(k, c, 2, 2) = 0.25 * wei_kcsr(k, c, 0, 0) - 0.25 * wei_kcsr(k, c, 0, 1) + - 0.25 * wei_kcsr(k, c, 0, 2) - 0.25 * wei_kcsr(k, c, 1, 0) + - 0.25 * wei_kcsr(k, c, 1, 1) - 0.25 * wei_kcsr(k, c, 1, 2) + - 0.25 * wei_kcsr(k, c, 2, 0) - 0.25 * wei_kcsr(k, c, 2, 1) + - 0.25 * wei_kcsr(k, c, 2, 2); + 0.5 * wei_kcyx(k, c, 0, 0) - 0.5 * wei_kcyx(k, c, 1, 0) + 0.5 * wei_kcyx(k, c, 2, 0); + wei_transform(k, c, 2, 1) = 0.25 * wei_kcyx(k, c, 0, 0) + 0.25 * wei_kcyx(k, c, 0, 1) + + 0.25 * wei_kcyx(k, c, 0, 2) - 0.25 * wei_kcyx(k, c, 1, 0) - + 0.25 * wei_kcyx(k, c, 1, 1) - 0.25 * wei_kcyx(k, c, 1, 2) + + 0.25 * wei_kcyx(k, c, 2, 0) + 0.25 * wei_kcyx(k, c, 2, 1) + + 0.25 * wei_kcyx(k, c, 2, 2); + wei_transform(k, c, 2, 2) = 0.25 * wei_kcyx(k, c, 0, 0) - 0.25 * wei_kcyx(k, c, 0, 1) + + 0.25 * wei_kcyx(k, c, 0, 2) - 0.25 * wei_kcyx(k, c, 1, 0) + + 0.25 * wei_kcyx(k, c, 1, 1) - 0.25 * wei_kcyx(k, c, 1, 2) + + 0.25 * wei_kcyx(k, c, 2, 0) - 0.25 * wei_kcyx(k, c, 2, 1) + + 0.25 * wei_kcyx(k, c, 2, 2); wei_transform(k, c, 2, 3) = - 0.5 * wei_kcsr(k, c, 0, 2) - 0.5 * wei_kcsr(k, c, 1, 2) + 0.5 * wei_kcsr(k, c, 2, 2); + 0.5 * wei_kcyx(k, c, 0, 2) - 0.5 * wei_kcyx(k, c, 1, 2) + 0.5 * wei_kcyx(k, c, 2, 2); - wei_transform(k, c, 3, 0) = wei_kcsr(k, c, 2, 0); + wei_transform(k, c, 3, 0) = wei_kcyx(k, c, 2, 0); wei_transform(k, c, 3, 1) = - 0.5 * wei_kcsr(k, c, 2, 0) + 0.5 * wei_kcsr(k, c, 2, 1) + 0.5 * wei_kcsr(k, c, 2, 2); + 0.5 * wei_kcyx(k, c, 2, 0) + 0.5 * wei_kcyx(k, c, 2, 1) + 0.5 * wei_kcyx(k, c, 2, 2); wei_transform(k, c, 3, 2) = - 0.5 * wei_kcsr(k, c, 2, 0) - 0.5 * wei_kcsr(k, c, 2, 1) + 0.5 * wei_kcsr(k, c, 2, 2); - wei_transform(k, c, 3, 3) = wei_kcsr(k, c, 2, 2); + 0.5 * wei_kcyx(k, c, 2, 0) - 0.5 * wei_kcyx(k, c, 2, 1) + 0.5 * wei_kcyx(k, c, 2, 2); + wei_transform(k, c, 3, 3) = wei_kcyx(k, c, 2, 2); }; auto f_out_transform = [&](auto n, auto k, auto htile, auto wtile) { @@ -569,16 +569,16 @@ int main(int argc, char* argv[]) auto upper_pads = Sequence{}; auto in_nchw_desc = make_ConstantTensorDescriptor(Sequence{}); - auto wei_kcsr_desc = make_ConstantTensorDescriptor(Sequence{}); + auto wei_kcyx_desc = make_ConstantTensorDescriptor(Sequence{}); auto out_nkhw_desc = get_convolution_with_padding_output_default_4d_tensor_descriptor( - in_nchw_desc, wei_kcsr_desc, lower_pads, upper_pads); + in_nchw_desc, wei_kcyx_desc, lower_pads, upper_pads); ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: "); - ostream_ConstantTensorDescriptor(wei_kcsr_desc, std::cout << "wei_kcsr_desc: "); + ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: "); ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: "); Tensor in_nchw(make_TensorDescriptor(in_nchw_desc)); - Tensor wei_kcsr(make_TensorDescriptor(wei_kcsr_desc)); + Tensor wei_kcyx(make_TensorDescriptor(wei_kcyx_desc)); Tensor out_nkhw_host(make_TensorDescriptor(out_nkhw_desc)); Tensor out_nkhw_device(make_TensorDescriptor(out_nkhw_desc)); @@ -597,13 +597,13 @@ int main(int argc, char* argv[]) { #if 0 in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); #elif 1 in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei_kcsr.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); #elif 1 in_nchw.GenerateTensorValue(GeneratorTensor_2{-2, 2}, num_thread); - wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); #endif } @@ -613,17 +613,17 @@ int main(int argc, char* argv[]) #elif 0 device_direct_convolution_2 #elif 1 - device_implicit_gemm_convolution_1_chwn_csrk_khwn + device_implicit_gemm_convolution_1_chwn_cyxk_khwn #elif 0 - device_implicit_gemm_convolution_2_chwn_csrk_khwn + device_implicit_gemm_convolution_2_chwn_cyxk_khwn #endif - (in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat); + (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); #elif 1 - device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(in_nchw_desc, + device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(in_nchw_desc, in_nchw, - wei_kcsr_desc, - wei_kcsr, + wei_kcyx_desc, + wei_kcyx, out_nkhw_desc, out_nkhw_device, lower_pads, @@ -636,18 +636,18 @@ int main(int argc, char* argv[]) #if 1 if(Y == 3 && X == 3) { - host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads); + host_winograd_3x3_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads); } else { - host_direct_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads); + host_direct_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads); } check_error(out_nkhw_host, out_nkhw_device); #endif #if 0 LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl; - LogRange(std::cout << "wei_kcsr: ", wei_kcsr.mData, ",") << std::endl; + LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl; LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl; LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl; #endif diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp deleted file mode 100644 index a6cb6f60e1..0000000000 --- a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp +++ /dev/null @@ -1,310 +0,0 @@ -#pragma once -#include "common.hip.hpp" -#include "ConstantTensorDescriptor.hip.hpp" -#include "ConstantMatrixDescriptor.hip.hpp" -#include "blockwise_4d_tensor_op.hip.hpp" -#include "blockwise_2d_tensor_op.hip.hpp" -#include "threadwise_nd_tensor_op.hip.hpp" -#include "threadwise_4d_tensor_op.hip.hpp" -#include "blockwise_gemm.hip.hpp" - -template -__global__ void -gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) -{ - // NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N] - // for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N" - // if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock - static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0"); - static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock, - "wrong!"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_chwn_global_desc = InGlobalDesc{}; - constexpr auto wei_csrk_global_desc = WeiGlobalDesc{}; - constexpr auto out_khwn_global_desc = OutGlobalDesc{}; - - constexpr unsigned C = in_chwn_global_desc.GetLength(I0); - - constexpr unsigned K = out_khwn_global_desc.GetLength(I0); - constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1); - constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2); - constexpr unsigned N = out_khwn_global_desc.GetLength(I3); - - constexpr unsigned Y = wei_csrk_global_desc.GetLength(I1); - constexpr unsigned X = wei_csrk_global_desc.GetLength(I2); - - constexpr unsigned HiPerBlock = HoPerBlock + Y - 1; - constexpr unsigned WiPerBlock = WoPerBlock + X - 1; - - // divide block work: [K, Ho, Wo, N] - constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr unsigned HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; - constexpr unsigned WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; - constexpr unsigned NBlockWork = (N + NPerBlock - 1) / NPerBlock; - - const unsigned k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); - unsigned itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); - const unsigned h_block_work_id = itmp / (WBlockWork * NBlockWork); - itmp -= h_block_work_id * (WBlockWork * NBlockWork); - const unsigned w_block_work_id = itmp / NBlockWork; - const unsigned n_block_work_id = itmp - w_block_work_id * NBlockWork; - - const unsigned k_block_data_begin = k_block_work_id * KPerBlock; - const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock; - const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock; - const unsigned n_block_data_begin = n_block_work_id * NPerBlock; - - const unsigned hi_block_data_begin = ho_block_data_begin; - const unsigned wi_block_data_begin = wo_block_data_begin; - - // flattend (2d) tensor view of gridwise weight - constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of blockwise input and weight in LDS - // be careful of alignment - constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - constexpr auto wei_csrk_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - // tensor view of threadwise output in register - constexpr auto out_khwn_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - - // blockwise copy - // input: format is [C, Hi, Wi, N] - const auto blockwise_in_copy = Blockwise4dTensorCopy3{}; - - // blockwise wei copy - // format is [CPerBlock*Y*X,KPerBlock] - const auto blockwise_wei_copy = Blockwise2dTensorCopy3{}; - - // a series of blockwise batched GEMM - // C_matrix += transpose(A_matrix) * B_matrix - // A_matrix and B_matrix saved in LDS, C_matrix saved in register - // A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K] - // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] - // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N] - constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_cxwn_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - constexpr auto c_kxwn_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - const auto blockwise_batch_gemm = BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< - BlockSize, - decltype(a_cxk_block_mtx_desc), - decltype(b_cxwn_block_mtx_desc), - decltype(c_kxwn_thread_mtx_desc), - 0, - in_chwn_block_desc.GetStride(I1), - out_khwn_thread_desc.GetStride(I1), - HoPerBlock, - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - HoPerThread>{}; - - // LDS: be careful of alignment - constexpr unsigned in_block_size = - in_chwn_block_desc.GetElementSpace(Number{}); - - constexpr unsigned wei_block_size = - wei_csrk_block_desc.GetElementSpace(Number{}); - - constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead - ? InBlockCopyDataPerRead - : WeiBlockCopyDataPerRead; - - __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)]; - __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)]; - - // register - Float p_out_thread[out_khwn_thread_desc.GetElementSpace()]; - - // set threadwise output tensor to 0 - threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread); - - const Float* p_in_global_block_begin = - p_in_global + in_chwn_global_desc.Get1dIndex( - 0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin); - - const Float* p_wei_global_block_begin = - p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); - - for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, - p_in_global_block_begin += CPerBlock * in_chwn_global_desc.GetStride(I0), - p_wei_global_block_begin += CPerBlock * wei_csrk_global_desc.GetStride(I0), - __syncthreads()) - { - // input: global mem to LDS - blockwise_in_copy.Run(p_in_global_block_begin, p_in_block); - - // weight: global mem to LDS - blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block); - - __syncthreads(); - - // a series of batched GEMM - for(unsigned y = 0; y < Y; ++y) - { - for(unsigned x = 0; x < X; ++x) - { - blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0), - p_out_thread, - [](auto& acc, const auto&& v) { acc += v; }); - } - } - } - - // output: register to global mem, -#if 0 - const auto c_thread_mtx_begin = - blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - for(unsigned k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k) - { - for(unsigned ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho) - { - for(unsigned wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo) - { - for(unsigned n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n) - { - const unsigned b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n); - - const auto c_thread_mtx_distance = - blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b); - - const unsigned ho_thread = - c_thread_mtx_begin.batch + c_thread_mtx_distance.batch; - const unsigned k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row; - const unsigned b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col; - - const unsigned wo_thread = b_thread / NPerBlock; - const unsigned n_thread = b_thread % NPerBlock; - - p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread, - ho_block_data_begin + ho_thread, - wo_block_data_begin + wo_thread, - n_block_data_begin + n_thread)] = - p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)]; - } - } - } - } -#elif 1 - const auto c_thread_mtx_begin = - blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const unsigned k_thread_data_begin = c_thread_mtx_begin.row; - const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch; - const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; - const unsigned n_thread_data_begin = c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin; - - // this is for v2 GEMM - // output is a 8d tensor - if(NPerThread < NPerBlock && WoPerThread == 1) - { - constexpr unsigned N1_ = GemmNPerThreadSubC; - constexpr unsigned W1_ = WoPerBlock / ((WoPerThread * NPerThread) / GemmNPerThreadSubC); - constexpr unsigned K2_ = GemmMPerThreadSubC; - constexpr unsigned K1_ = KPerBlock / KPerThread; - - constexpr auto out_8d_global_desc = make_ConstantTensorDescriptor( - Sequence{}); - - constexpr auto out_8d_thread_desc = make_ConstantTensorDescriptor( - Sequence{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc"); - print_ConstantTensorDescriptor(out_8d_thread_desc, "out_8d_thread_desc"); - - print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc"); - print_ConstantTensorDescriptor(out_8d_global_desc, "out_8d_global_desc"); - } -#endif - - threadwise_8d_tensor_copy(out_8d_thread_desc, - p_out_thread, - out_8d_global_desc, - p_out_global + out_khwn_global_desc.Get1dIndex( - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin), - out_8d_thread_desc.GetLengths(), - Number{}); - } - else if(NPerThread == NPerBlock) - { - // not implemented yet - assert(false); - } - else - { - assert(false); - } -#endif -} diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hip.hpp deleted file mode 100644 index 7f6d54143b..0000000000 --- a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hip.hpp +++ /dev/null @@ -1,292 +0,0 @@ -#pragma once -#include "common.hip.hpp" -#include "ConstantTensorDescriptor.hip.hpp" -#include "ConstantMatrixDescriptor.hip.hpp" -#include "blockwise_4d_tensor_op.hip.hpp" -#include "blockwise_2d_tensor_op.hip.hpp" -#include "threadwise_4d_tensor_op.hip.hpp" -#include "blockwise_gemm.hip.hpp" - -template -__global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded( - const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) -{ - // NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N] - // for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N" - // if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock - static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0"); - static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock, - "wrong!"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_chwn_global_desc = InGlobalDesc{}; - constexpr auto wei_csrk_global_desc = WeiGlobalDesc{}; - constexpr auto out_khwn_global_desc = OutGlobalDesc{}; - - constexpr unsigned C = in_chwn_global_desc.GetLength(I0); - - constexpr unsigned K = out_khwn_global_desc.GetLength(I0); - constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1); - constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2); - constexpr unsigned N = out_khwn_global_desc.GetLength(I3); - - constexpr unsigned Y = wei_csrk_global_desc.GetLength(I1); - constexpr unsigned X = wei_csrk_global_desc.GetLength(I2); - - constexpr unsigned HPadLow = LowerPads{}.Get(I0); - constexpr unsigned WPadLow = LowerPads{}.Get(I1); - - constexpr unsigned HPadUp = UpperPads{}.Get(I0); - constexpr unsigned WPadUp = UpperPads{}.Get(I1); - - constexpr unsigned HiPerBlock = HoPerBlock + Y - 1; - constexpr unsigned WiPerBlock = WoPerBlock + X - 1; - - // divide block work: [K, Ho, Wo, N] - constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr unsigned HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; - constexpr unsigned WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; - constexpr unsigned NBlockWork = (N + NPerBlock - 1) / NPerBlock; - - const unsigned k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); - unsigned itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); - const unsigned h_block_work_id = itmp / (WBlockWork * NBlockWork); - itmp -= h_block_work_id * (WBlockWork * NBlockWork); - const unsigned w_block_work_id = itmp / NBlockWork; - const unsigned n_block_work_id = itmp - w_block_work_id * NBlockWork; - - const unsigned k_block_data_begin = k_block_work_id * KPerBlock; - const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock; - const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock; - const unsigned n_block_data_begin = n_block_work_id * NPerBlock; - - // flattened (2d) tensor view of wei in global mem - constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of blockwise input and weight in LDS - constexpr auto in_chwn_block_desc = - make_ConstantTensorDescriptor(Sequence{}); - - constexpr auto wei_csrk_block_desc = - make_ConstantTensorDescriptor(Sequence{}); - - // flattened (2d) tensor view of wei in LDS - constexpr auto wei_ek_block_desc = - make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of threadwise output in register - constexpr auto out_hkwn_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc"); - print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc"); - print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc"); - } -#endif - - // blockwise copy - // input: format is [C, Hi, Wi, N] - const unsigned h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0; - const unsigned w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0; - - const unsigned h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0; - const unsigned w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0; - -#if 0 - if(get_thread_local_1d_id() == 0) - ; - { - printf( - "%u %u, h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u\n", - get_block_1d_id(), - get_thread_local_1d_id(), - h_block_pad_low, - w_block_pad_low, - h_block_pad_up, - w_block_pad_up); - } -#endif - - constexpr auto blockwise_in_copy = - BlockwiseChwnTensorCopyPadded{}; - -#if 0 - // weight: format is [C,Y,X,K] - constexpr auto blockwise_wei_copy = - Blockwise4dTensorCopy1{}; -#elif 0 - // weight: format is [C*Y*X,K] - constexpr auto blockwise_wei_copy = - Blockwise2dTensorCopy1{}; -#elif 1 - // weight: format is [C*Y*X,K] - const auto blockwise_wei_copy = Blockwise2dTensorCopy2{}; -#endif - - // a series of blockwise batched GEMM - // C_matrix += transpose(A_matrix) * B_matrix - // A_matrix and B_matrix saved in LDS, C_matrix saved in register - // A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K] - // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] - // C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N] - constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_cxwn_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - constexpr auto c_kxwn_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - const auto blockwise_batch_gemm = - Blockwise1dStridedBatchedGemmBlockABlockBThreadC{}; - - // LDS - constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace(); - constexpr unsigned wei_block_size = wei_csrk_block_desc.GetElementSpace(); - - __shared__ Float p_in_block[in_block_size]; - __shared__ Float p_wei_block[wei_block_size]; - - // register - Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()]; - - // set threadwise output tensor to 0 - threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread); - - const Float* p_wei_global_block_begin = - p_wei_global + wei_ek_global_desc.Get1dIndex(0, k_block_data_begin); - - for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, - p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0), - __syncthreads()) - { -#if 1 - // input: global mem to LDS, - blockwise_in_copy.Run(p_in_global, - c_block_data_begin, - ho_block_data_begin, - wo_block_data_begin, - n_block_data_begin, - p_in_block, - h_block_pad_low, - w_block_pad_low, - h_block_pad_up, - w_block_pad_up); -#endif - -#if 1 - // weight: global mem to LDS, - blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block); -#endif - - __syncthreads(); - - // a series of batched GEMM - for(unsigned y = 0; y < Y; ++y) - { - for(unsigned x = 0; x < X; ++x) - { - auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; - - blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0), - p_out_thread, - f_accum); - } - } - } - - const auto matrix_c_index = - blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const unsigned ho_thread_data_begin = matrix_c_index.batch; - const unsigned k_thread_data_begin = matrix_c_index.row; - const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock; - const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock; - -#if 0 - printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n", - get_block_1d_id(), get_thread_local_1d_id(), - ho_block_data_begin, k_block_data_begin, wo_block_data_begin, n_block_data_begin, - ho_thread_data_begin, k_thread_data_begin, wo_thread_data_begin, n_thread_data_begin, - p_out_thread[0]); -#endif - - // output: register to global mem, - // convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N] - constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{}; - - threadwise_4d_tensor_copy_reorder_by_get_dst_from_src( - out_hkwn_thread_desc, - p_out_thread, - out_khwn_global_desc, - p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin), - out_hkwn_thread_desc.GetLengths(), - reorder_khwn_from_hkwn); -} diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp deleted file mode 100644 index b6498f8175..0000000000 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp +++ /dev/null @@ -1,369 +0,0 @@ -#pragma once -#include "common.hip.hpp" -#include "ConstantTensorDescriptor.hip.hpp" -#include "ConstantMatrixDescriptor.hip.hpp" -#include "blockwise_4d_tensor_op.hip.hpp" -#include "blockwise_2d_tensor_op.hip.hpp" -#include "threadwise_2d_tensor_op.hip.hpp" -#include "blockwise_gemm.hip.hpp" - -// define B = flatten(N, Hi, Wi) -template -__global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer( - const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_chwn_global_desc = InGlobalDesc{}; - constexpr auto wei_csrk_global_desc = WeiGlobalDesc{}; - constexpr auto out_khwn_global_desc = OutGlobalDesc{}; - - constexpr unsigned C = in_chwn_global_desc.GetLength(I0); - constexpr unsigned Hi = in_chwn_global_desc.GetLength(I1); - constexpr unsigned Wi = in_chwn_global_desc.GetLength(I2); - constexpr unsigned N = in_chwn_global_desc.GetLength(I3); - - constexpr unsigned K = out_khwn_global_desc.GetLength(I0); - constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1); - constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2); - - constexpr unsigned Y = wei_csrk_global_desc.GetLength(I1); - constexpr unsigned X = wei_csrk_global_desc.GetLength(I2); - - constexpr unsigned B = N * Hi * Wi; - constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1); - - // divide block work by 2d: [K, B] - constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock; - - const unsigned k_block_work_id = get_block_1d_id() / BBlockWork; - const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; - - const unsigned k_block_data_begin = k_block_work_id * KPerBlock; - const unsigned b_block_data_begin = b_block_work_id * BPerBlock; - - // flattend (2d) tensor view of gridwise input - constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence{}); - constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of blockwise input and weight - // be careful of alignment - constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - constexpr auto wei_csrk_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - // tensor view of threadwise output in register - constexpr auto out_kb_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(in_chwn_global_desc, "in_chwn_global_desc"); - print_ConstantTensorDescriptor(wei_csrk_global_desc, "wei_csrk_global_desc"); - print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc"); - - print_ConstantTensorDescriptor(in_cb_global_desc, "in_cb_global_desc"); - print_ConstantTensorDescriptor(wei_ek_global_desc, "wei_ek_global_desc"); - - print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc"); - print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc"); - print_ConstantTensorDescriptor(wei_ek_block_desc, "wei_ek_block_desc"); - print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc"); - - printf("KPerBlock %u\n", KPerBlock); - } -#endif - -// blockwise in copy -// formmat is [CPerBlock,BPerBlock + BGhostRead] -#if 0 - const auto blockwise_in_copy = - Blockwise2dTensorCopy1{}; -#elif 0 - const auto blockwise_in_copy = Blockwise2dTensorCopy2{}; -#elif 1 - const auto blockwise_in_copy = Blockwise2dTensorCopy3{}; -#endif - -// blockwise wei copy -// format is [CPerBlock*Y*X,KPerBlock] -#if 0 - const auto blockwise_wei_copy = - Blockwise2dTensorCopy1{}; -#elif 0 - const auto blockwise_wei_copy = Blockwise2dTensorCopy2{}; -#elif 1 - const auto blockwise_wei_copy = Blockwise2dTensorCopy3{}; -#endif - - // a series of blockwise GEMM - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx and b_mtx saved in LDS, c_mtx saved in register - // a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K] - // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead] - // c_mtx[K,B] is out_block[K,B] - constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto c_kxb_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, Number{}); - -#if 0 - const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC{}; -#else - const auto blockwise_gemm = - BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2{}; -#endif - - // LDS: be careful of alignment - constexpr unsigned in_block_size = - in_cb_block_desc.GetElementSpace(Number{}); - - constexpr unsigned wei_block_size = - wei_csrk_block_desc.GetElementSpace(Number{}); - - constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead - ? InBlockCopyDataPerRead - : WeiBlockCopyDataPerRead; - - // LDS double buffer - __shared__ Float p_in_block_0[max_align * ((in_block_size + max_align - 1) / max_align)]; - __shared__ Float p_wei_block_0[max_align * ((wei_block_size + max_align - 1) / max_align)]; - - __shared__ Float p_in_block_1[max_align * ((in_block_size + max_align - 1) / max_align)]; - __shared__ Float p_wei_block_1[max_align * ((wei_block_size + max_align - 1) / max_align)]; - - const Float* p_in_global_block_offset = - p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin); - - const Float* p_wei_global_block_offset = - p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); - - // preload data into LDS - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_0); - blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_0); - - p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0); - p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I0); - - // register - Float p_out_thread[out_kb_thread_desc.GetElementSpace()]; - - // set threadwise output tensor to 0 - threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread); - - bool even_loop = true; - - for(unsigned c_block_data_begin = 0; c_block_data_begin + CPerBlock < C; - c_block_data_begin += CPerBlock, - p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), - p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I0), - even_loop = !even_loop) - { - Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1; - Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1; - - Float* p_in_block_next = even_loop ? p_in_block_1 : p_in_block_0; - Float* p_wei_block_next = even_loop ? p_wei_block_1 : p_wei_block_0; - - __syncthreads(); - -// load next data -#if 1 - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next); - blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next); -#elif 1 - Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; - Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; - - blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, - p_in_register_clipboard); - - blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, - p_wei_register_clipboard); -#endif - - // compute on current data - // a series of GEMM - for(unsigned y = 0; y < Y; ++y) - { - for(unsigned x = 0; x < X; ++x) - { - auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; -#if 1 - blockwise_gemm.Run -#else - blockwise_gemm.Run_RegisterDoubleBuffer -#endif - (p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block_now + y * Wi + x, - p_out_thread, - f_accum); - } - } - -#if 0 - blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next); - blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next); -#endif - } - - // last computation - { - Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1; - Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1; - - __syncthreads(); - - for(unsigned y = 0; y < Y; ++y) - { - for(unsigned x = 0; x < X; ++x) - { - auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; -#if 0 - blockwise_gemm.Run -#else - blockwise_gemm.Run_RegisterDoubleBuffer -#endif - (p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block_now + y * Wi + x, - p_out_thread, - f_accum); - } - } - } - - // output: register to global mem, - const auto c_thread_mtx_begin = - blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; - const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; - -#if 0 - if(get_block_1d_id() == 0) - { - printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", - get_block_1d_id(), - get_thread_local_1d_id(), - matrix_c_index.row, - matrix_c_index.col, - k_data_begin, - b_data_begin, - p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]); - } -#endif - - for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) - { - for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) - { - const auto c_thread_mtx_distance = - blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); - - unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row; - unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col; - - unsigned h_data = b_data / (Wi * N); - unsigned itmp = b_data - h_data * (Wi * N); - unsigned w_data = itmp / N; - unsigned n_data = itmp - w_data * N; - - if(n_data < N && h_data < Ho && w_data < Wo) - { - p_out_global[out_khwn_global_desc.Get1dIndex(k_data, h_data, w_data, n_data)] = - p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; - } - } - } -} From f54cad7d4f34787356c3e82f05d82b3f2e5db9ca Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sat, 9 Mar 2019 13:39:24 -0600 Subject: [PATCH 3/3] refactor --- ...icit_gemm_convolution_1_chwn_cyxk_khwn.hpp | 308 +++++++++++++++ ...mm_convolution_1_chwn_cyxk_khwn_padded.hpp | 293 ++++++++++++++ ...icit_gemm_convolution_2_chwn_cyxk_khwn.hpp | 259 ++++++++++++ ..._gemm_convolution_1_chwn_cyxk_khwn.hip.hpp | 310 +++++++++++++++ ...onvolution_1_chwn_cyxk_khwn_padded.hip.hpp | 292 ++++++++++++++ ...2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp | 369 ++++++++++++++++++ 6 files changed, 1831 insertions(+) create mode 100644 driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp create mode 100644 driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp create mode 100644 driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp create mode 100644 src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp create mode 100644 src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp create mode 100644 src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp diff --git a/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp new file mode 100644 index 0000000000..2c27080670 --- /dev/null +++ b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp @@ -0,0 +1,308 @@ +#pragma once +#include +#include "device.hpp" +#include "gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp" + +template +void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, + const Tensor& in_nchw, + WeiDesc, + const Tensor& wei_kcyx, + OutDesc, + Tensor& out_nkhw, + unsigned nrepeat) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_nchw_desc = InDesc{}; + constexpr auto wei_kcyx_desc = WeiDesc{}; + constexpr auto out_nkhw_desc = OutDesc{}; + + constexpr unsigned Hi = in_nchw_desc.GetLength(I2); + constexpr unsigned Wi = in_nchw_desc.GetLength(I3); + + constexpr unsigned N = out_nkhw_desc.GetLength(I0); + constexpr unsigned Ho = out_nkhw_desc.GetLength(I2); + constexpr unsigned Wo = out_nkhw_desc.GetLength(I3); + + constexpr unsigned K = wei_kcyx_desc.GetLength(I0); + constexpr unsigned C = wei_kcyx_desc.GetLength(I1); + constexpr unsigned Y = wei_kcyx_desc.GetLength(I2); + constexpr unsigned X = wei_kcyx_desc.GetLength(I3); + + // reorder weight + auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: "); + + Tensor wei_cyxk(make_TensorDescriptor(wei_cyxk_desc)); + + auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) { + wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x); + }; + + make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)( + std::thread::hardware_concurrency()); + + // reorder input + auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: "); + + Tensor in_chwn(make_TensorDescriptor(in_chwn_desc)); + + auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) { + in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi); + }; + + make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)( + std::thread::hardware_concurrency()); + + // output + auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: "); + + Tensor out_khwn(make_TensorDescriptor(out_khwn_desc)); + + std::size_t data_sz = sizeof(T); + DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace()); + DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace()); + DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace()); + + in_chwn_device_buf.ToDevice(in_chwn.mData.data()); + wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data()); + out_khwn_device_buf.ToDevice(out_khwn.mData.data()); + +#if 1 + // for 3x3, 34x34 + constexpr unsigned NPerBlock = 16; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 4; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 8; + constexpr unsigned KPerThread = 8; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned InBlockCopy_ThreadPerDimC = 4; + constexpr unsigned InBlockCopy_ThreadPerDimH = 4; + constexpr unsigned InBlockCopy_ThreadPerDimW = 2; + constexpr unsigned InBlockCopy_ThreadPerDimN = 4; + constexpr unsigned InBlockCopyDataPerRead = 4; + + constexpr unsigned WeiBlockCopyDataPerRead = 4; + + constexpr unsigned GemmMPerThreadSubC = 4; + constexpr unsigned GemmNPerThreadSubC = 4; + constexpr unsigned GemmMLevel0Cluster = 4; + constexpr unsigned GemmNLevel0Cluster = 2; + constexpr unsigned GemmMLevel1Cluster = 2; + constexpr unsigned GemmNLevel1Cluster = 4; + constexpr unsigned GemmKPerThreadLoop = 1; + + constexpr unsigned OutThreadCopyDataPerWrite = 2; + + constexpr unsigned BlockSize = 128; +#elif 0 + // for 5x5, 36x36 + constexpr unsigned NPerBlock = 16; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 2; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 8; + constexpr unsigned KPerThread = 8; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 32; + + constexpr unsigned InBlockCopy_ThreadPerDimC = 2; + constexpr unsigned InBlockCopy_ThreadPerDimH = 2; + constexpr unsigned InBlockCopy_ThreadPerDimW = 4; + constexpr unsigned InBlockCopy_ThreadPerDimN = 4; + constexpr unsigned InBlockCopyDataPerRead = 4; + + constexpr unsigned WeiBlockCopyDataPerRead = 2; + + constexpr unsigned GemmMPerThreadSubC = 4; + constexpr unsigned GemmNPerThreadSubC = 4; + constexpr unsigned GemmMLevel0Cluster = 4; + constexpr unsigned GemmNLevel0Cluster = 2; + constexpr unsigned GemmMLevel1Cluster = 2; + constexpr unsigned GemmNLevel1Cluster = 4; + constexpr unsigned GemmKPerThreadLoop = 1; + + constexpr unsigned OutThreadCopyDataPerWrite = 2; + + constexpr unsigned BlockSize = 128; +#elif 0 + // 3x3 58x58, NKC = 64, 64, 256 + constexpr unsigned NPerBlock = 16; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 4; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 32; + + constexpr unsigned InBlockCopyDataPerRead = 2; // not used, yet + constexpr unsigned WeiBlockCopyDataPerRead = 4; + + constexpr unsigned BlockSize = 128; +#elif 0 + // 3x3 58x58, NKC = 16,256,128 + constexpr unsigned NPerBlock = 8; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 2; + constexpr unsigned HoPerBlock = 4; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned BlockSize = 128; +#elif 0 + // for 7x7, 38x38 + constexpr unsigned NPerBlock = 8; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 1; + constexpr unsigned HoPerBlock = 4; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 32; + + constexpr unsigned InBlockCopyDataPerRead = 4; // not used, yet + constexpr unsigned WeiBlockCopyDataPerRead = 4; + + constexpr unsigned BlockSize = 128; +#elif 0 + // for 3x3, 56x56 + constexpr unsigned NPerBlock = 32; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 4; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 2; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned BlockSize = 128; +#elif 1 + // for 1x1, 28x28 + constexpr unsigned NPerBlock = 16; + constexpr unsigned KPerBlock = 128; + constexpr unsigned CPerBlock = 8; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 2; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned InBlockCopy_ThreadPerDimC = 8; + constexpr unsigned InBlockCopy_ThreadPerDimH = 2; + constexpr unsigned InBlockCopy_ThreadPerDimW = 2; + constexpr unsigned InBlockCopy_ThreadPerDimN = 4; + constexpr unsigned InBlockCopyDataPerRead = 4; + + constexpr unsigned WeiBlockCopyDataPerRead = 4; + + constexpr unsigned GemmMPerThreadSubC = 4; + constexpr unsigned GemmNPerThreadSubC = 4; + constexpr unsigned GemmMLevel0Cluster = 4; + constexpr unsigned GemmNLevel0Cluster = 2; + constexpr unsigned GemmMLevel1Cluster = 2; + constexpr unsigned GemmNLevel1Cluster = 4; + constexpr unsigned GemmKPerThreadLoop = 1; + + constexpr unsigned OutThreadCopyDataPerWrite = 2; + + constexpr unsigned BlockSize = 128; +#endif + + constexpr unsigned GridSize = + ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) * + ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock); + + printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); + + for(unsigned i = 0; i < nrepeat; ++i) + { + float time = launch_kernel( + gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn, + InBlockCopyDataPerRead, + WeiBlockCopyDataPerRead, + GemmMPerThreadSubC, + GemmNPerThreadSubC, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmKPerThreadLoop, + OutThreadCopyDataPerWrite>, + dim3(GridSize), + dim3(BlockSize), + static_cast(in_chwn_device_buf.GetDeviceBuffer()), + static_cast(wei_cyxk_device_buf.GetDeviceBuffer()), + static_cast(out_khwn_device_buf.GetDeviceBuffer())); + + printf("Elapsed time : %f ms\n", time); + usleep(std::min(time * 1000, float(10000))); + } + + out_khwn_device_buf.FromDevice(out_khwn.mData.data()); + + // reorder output + auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) { + out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n); + }; + + make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)( + std::thread::hardware_concurrency()); +} diff --git a/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp new file mode 100644 index 0000000000..1843061a7a --- /dev/null +++ b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp @@ -0,0 +1,293 @@ +#pragma once +#include +#include "device.hpp" +#include "gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp" + +template +void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc, + const Tensor& in_nchw, + WeiDesc, + const Tensor& wei_kcyx, + OutDesc, + Tensor& out_nkhw, + LowerPads, + UpperPads, + unsigned nrepeat) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_nchw_desc = InDesc{}; + constexpr auto wei_kcyx_desc = WeiDesc{}; + constexpr auto out_nkhw_desc = OutDesc{}; + + constexpr unsigned Hi = in_nchw_desc.GetLength(I2); + constexpr unsigned Wi = in_nchw_desc.GetLength(I3); + + constexpr unsigned N = out_nkhw_desc.GetLength(I0); + constexpr unsigned Ho = out_nkhw_desc.GetLength(I2); + constexpr unsigned Wo = out_nkhw_desc.GetLength(I3); + + constexpr unsigned K = wei_kcyx_desc.GetLength(I0); + constexpr unsigned C = wei_kcyx_desc.GetLength(I1); + constexpr unsigned Y = wei_kcyx_desc.GetLength(I2); + constexpr unsigned X = wei_kcyx_desc.GetLength(I3); + + // reorder weight + auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: "); + + Tensor wei_cyxk(make_TensorDescriptor(wei_cyxk_desc)); + + auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) { + wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x); + }; + + make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)( + std::thread::hardware_concurrency()); + + // reorder input + auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: "); + + Tensor in_chwn(make_TensorDescriptor(in_chwn_desc)); + + auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) { + in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi); + }; + + make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)( + std::thread::hardware_concurrency()); + + // output + auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: "); + + Tensor out_khwn(make_TensorDescriptor(out_khwn_desc)); + + std::size_t data_sz = sizeof(T); + DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace()); + DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace()); + DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace()); + + in_chwn_device_buf.ToDevice(in_chwn.mData.data()); + wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data()); + out_khwn_device_buf.ToDevice(out_khwn.mData.data()); + +#if 0 + constexpr unsigned NPerBlock = 1; + constexpr unsigned KPerBlock = 1; + constexpr unsigned CPerBlock = 1; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 1; + constexpr unsigned KPerThread = 1; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned WeiBlockCopyThreadPerDim0 = 1; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 1; + + constexpr unsigned BlockSize = 8; +#elif 1 + // for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256 + constexpr unsigned NPerBlock = 16; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 4; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 32; + + constexpr unsigned BlockSize = 128; +#elif 0 + // 3x3 58x58, NKC = 16,256,128 + constexpr unsigned NPerBlock = 8; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 2; + constexpr unsigned HoPerBlock = 4; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned BlockSize = 128; +#elif 0 + // for 5x5, 36x36 + constexpr unsigned NPerBlock = 16; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 2; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned BlockSize = 128; +#elif 0 + // for 7x7, 38x38 + constexpr unsigned NPerBlock = 8; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 2; + constexpr unsigned HoPerBlock = 4; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned BlockSize = 128; +#elif 0 + // for 3x3, 56x56 + constexpr unsigned NPerBlock = 32; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 4; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 2; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned BlockSize = 128; +#elif 1 + // 3x3 56x56, NKC = 16,256,128, with padding + // 3x3 28x28, NKC = 16,512,256, with padding + // 3x3 20x84, NKC = 16,256,256, with padding + constexpr unsigned NPerBlock = 16; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 2; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned WeiBlockCopyThreadPerDim0 = 2; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 64; + + constexpr unsigned BlockSize = 128; +#elif 0 + // for 5x5 filter, 20x84 image, 1x1 padding + constexpr unsigned NPerBlock = 16; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 1; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned BlockSize = 128; +#elif 0 + // 5x5 filter, 28x28 image, 2x2 padding + constexpr unsigned NPerBlock = 16; + constexpr unsigned KPerBlock = 32; + constexpr unsigned CPerBlock = 2; + constexpr unsigned HoPerBlock = 4; + constexpr unsigned WoPerBlock = 4; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned BlockSize = 128; +#elif 0 + // for 1x1, 28x28 + constexpr unsigned NPerBlock = 16; + constexpr unsigned KPerBlock = 128; + constexpr unsigned CPerBlock = 8; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 2; + + constexpr unsigned NPerThread = 4; + constexpr unsigned KPerThread = 16; + constexpr unsigned CPerThread = 2; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 1; + + constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 32; + + constexpr unsigned BlockSize = 128; +#endif + + constexpr unsigned GridSize = + ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) * + ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock); + + printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); + + for(unsigned i = 0; i < nrepeat; ++i) + { + float time = launch_kernel( + gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded, + dim3(GridSize), + dim3(BlockSize), + + static_cast(in_chwn_device_buf.GetDeviceBuffer()), + static_cast(wei_cyxk_device_buf.GetDeviceBuffer()), + static_cast(out_khwn_device_buf.GetDeviceBuffer())); + + printf("Elapsed time : %f ms\n", time); + usleep(std::min(time * 1000, float(10000))); + } + + out_khwn_device_buf.FromDevice(out_khwn.mData.data()); + + // reorder output + auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) { + out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n); + }; + + make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)( + std::thread::hardware_concurrency()); +} diff --git a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp new file mode 100644 index 0000000000..a657949f35 --- /dev/null +++ b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp @@ -0,0 +1,259 @@ +#pragma once +#include +#include "device.hpp" +#include "gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp" + +template +void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, + const Tensor& in_nchw, + WeiDesc, + const Tensor& wei_kcyx, + OutDesc, + Tensor& out_nkhw, + unsigned nrepeat) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_nchw_desc = InDesc{}; + constexpr auto wei_kcyx_desc = WeiDesc{}; + constexpr auto out_nkhw_desc = OutDesc{}; + + constexpr unsigned N = in_nchw_desc.GetLength(I0); + constexpr unsigned Hi = in_nchw_desc.GetLength(I2); + constexpr unsigned Wi = in_nchw_desc.GetLength(I3); + + constexpr unsigned Ho = out_nkhw_desc.GetLength(I2); + constexpr unsigned Wo = out_nkhw_desc.GetLength(I3); + + constexpr unsigned K = wei_kcyx_desc.GetLength(I0); + constexpr unsigned C = wei_kcyx_desc.GetLength(I1); + constexpr unsigned Y = wei_kcyx_desc.GetLength(I2); + constexpr unsigned X = wei_kcyx_desc.GetLength(I3); + + constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1); + + // convert in_nchw to in_cnhw + auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: "); + + Tensor in_chwn(make_TensorDescriptor(in_chwn_desc)); + + make_ParallelTensorFunctor( + [&](auto n, auto c, auto hi, auto wi) { in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi); }, + N, + C, + Hi, + Wi)(std::thread::hardware_concurrency()); + + // convert wei_kcyx to wei_cyxk + auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: "); + + Tensor wei_cyxk(make_TensorDescriptor(wei_cyxk_desc)); + + make_ParallelTensorFunctor( + [&](auto k, auto c, auto y, auto x) { wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x); }, + K, + C, + Y, + X)(std::thread::hardware_concurrency()); + + // conver out_nkhw to out_knhw + auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: "); + + Tensor out_khwn(make_TensorDescriptor(out_khwn_desc)); + +#if 0 + // 3x3, 34x34 + // need to use register double buffer for GEMM + constexpr unsigned BPerBlock = 128; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 4; + + constexpr unsigned BPerThread = 8; + constexpr unsigned KPerThread = 8; + + constexpr unsigned GemmMPerThreadSubC = 4; + constexpr unsigned GemmNPerThreadSubC = 4; + constexpr unsigned GemmMLevel0Cluster = 4; + constexpr unsigned GemmNLevel0Cluster = 2; + constexpr unsigned GemmMLevel1Cluster = 2; + constexpr unsigned GemmNLevel1Cluster = 8; + constexpr unsigned GemmKPerThreadLoop = 1; + + constexpr unsigned GemmThreadPerColumnPerCluster = 8; + constexpr unsigned GemmThreadPerRowPerCluster = 8; + + constexpr unsigned InBlockCopyThreadPerDim0 = 4; + constexpr unsigned InBlockCopyThreadPerDim1 = 16; + + constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 16; + + constexpr unsigned InBlockCopyDataPerRead = 4; + constexpr unsigned WeiBlockCopyDataPerRead = 4; + + constexpr unsigned BlockSize = 128; +#elif 0 + // 1x1, 28x28, 64 threads + constexpr unsigned BPerBlock = 64; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 8; + + constexpr unsigned BPerThread = 8; + constexpr unsigned KPerThread = 8; + + constexpr unsigned GemmMPerThreadSubC = 4; + constexpr unsigned GemmNPerThreadSubC = 4; + constexpr unsigned GemmMLevel0Cluster = 4; + constexpr unsigned GemmNLevel0Cluster = 2; + constexpr unsigned GemmMLevel1Cluster = 2; + constexpr unsigned GemmNLevel1Cluster = 4; + constexpr unsigned GemmKPerThreadLoop = 1; + + constexpr unsigned GemmThreadPerColumnPerCluster = 8; + constexpr unsigned GemmThreadPerRowPerCluster = 8; + + constexpr unsigned InBlockCopyThreadPerDim0 = 4; + constexpr unsigned InBlockCopyThreadPerDim1 = 16; + + constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 16; + + constexpr unsigned InBlockCopyDataPerRead = 4; + constexpr unsigned WeiBlockCopyDataPerRead = 4; + + constexpr unsigned BlockSize = 64; +#elif 1 + // 1x1, 28x28, 128 threads, no lds-double-buffer + // 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128 + constexpr unsigned BPerBlock = 64; + constexpr unsigned KPerBlock = 128; + constexpr unsigned CPerBlock = 8; + + constexpr unsigned BPerThread = 8; + constexpr unsigned KPerThread = 8; + + constexpr unsigned GemmMPerThreadSubC = 4; + constexpr unsigned GemmNPerThreadSubC = 4; + constexpr unsigned GemmMLevel0Cluster = 4; + constexpr unsigned GemmNLevel0Cluster = 2; + constexpr unsigned GemmMLevel1Cluster = 4; + constexpr unsigned GemmNLevel1Cluster = 4; + constexpr unsigned GemmKPerThreadLoop = 1; + + constexpr unsigned GemmThreadPerColumnPerCluster = 8; + constexpr unsigned GemmThreadPerRowPerCluster = 8; + + constexpr unsigned InBlockCopyThreadPerDim0 = 4; + constexpr unsigned InBlockCopyThreadPerDim1 = 16; + + constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 16; + + constexpr unsigned InBlockCopyDataPerRead = 4; + constexpr unsigned WeiBlockCopyDataPerRead = 4; + + constexpr unsigned BlockSize = 128; +#elif 1 + // 1x1, 28x28, 256 thread + constexpr unsigned BPerBlock = 128; + constexpr unsigned KPerBlock = 128; + constexpr unsigned CPerBlock = 8; + + constexpr unsigned BPerThread = 8; + constexpr unsigned KPerThread = 8; + + constexpr unsigned GemmMPerThreadSubC = 4; + constexpr unsigned GemmNPerThreadSubC = 4; + constexpr unsigned GemmMLevel0Cluster = 4; + constexpr unsigned GemmNLevel0Cluster = 4; + constexpr unsigned GemmMLevel1Cluster = 4; + constexpr unsigned GemmNLevel1Cluster = 4; + constexpr unsigned GemmKPerThreadLoop = 1; + + constexpr unsigned GemmThreadPerColumnPerCluster = 8; + constexpr unsigned GemmThreadPerRowPerCluster = 8; + + constexpr unsigned InBlockCopyThreadPerDim0 = 4; + constexpr unsigned InBlockCopyThreadPerDim1 = 16; + + constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 16; + + constexpr unsigned InBlockCopyDataPerRead = 4; + constexpr unsigned WeiBlockCopyDataPerRead = 4; + + constexpr unsigned BlockSize = 256; +#endif + + constexpr unsigned GridSize = + ((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); + + printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); + + // mem + std::size_t data_sz = sizeof(T); + DeviceMem in_chwn_device_buf(data_sz * (in_chwn.mDesc.GetElementSpace() + BGhostRead + + BPerBlock)); // reserve extra space for BGhostRead + DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace()); + DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace()); + + in_chwn_device_buf.ToDevice(in_chwn.mData.data()); + wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data()); + out_khwn_device_buf.ToDevice(out_khwn.mData.data()); + + for(unsigned i = 0; i < nrepeat; ++i) + { + float time = + launch_kernel(gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer< + GridSize, + BlockSize, + T, + decltype(in_chwn_desc), + decltype(wei_cyxk_desc), + decltype(out_khwn_desc), + BPerBlock, + KPerBlock, + CPerBlock, + BPerThread, + KPerThread, + GemmThreadPerColumnPerCluster, + GemmThreadPerRowPerCluster, + GemmMPerThreadSubC, + GemmNPerThreadSubC, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmKPerThreadLoop, + InBlockCopyThreadPerDim0, + InBlockCopyThreadPerDim1, + WeiBlockCopyThreadPerDim0, + WeiBlockCopyThreadPerDim1, + InBlockCopyDataPerRead, + WeiBlockCopyDataPerRead>, + dim3(GridSize), + dim3(BlockSize), + static_cast(in_chwn_device_buf.GetDeviceBuffer()), + static_cast(wei_cyxk_device_buf.GetDeviceBuffer()), + static_cast(out_khwn_device_buf.GetDeviceBuffer())); + + printf("Elapsed time : %f ms\n", time); + usleep(std::min(time * 1000, float(10000))); + } + + out_khwn_device_buf.FromDevice(out_khwn.mData.data()); + + // convert out_khwn to out_nkhw + make_ParallelTensorFunctor( + [&](auto n, auto k, auto ho, auto wo) { out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n); }, + N, + K, + Ho, + Wo)(std::thread::hardware_concurrency()); +} diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp new file mode 100644 index 0000000000..1caef669e9 --- /dev/null +++ b/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp @@ -0,0 +1,310 @@ +#pragma once +#include "common.hip.hpp" +#include "ConstantTensorDescriptor.hip.hpp" +#include "ConstantMatrixDescriptor.hip.hpp" +#include "blockwise_4d_tensor_op.hip.hpp" +#include "blockwise_2d_tensor_op.hip.hpp" +#include "threadwise_nd_tensor_op.hip.hpp" +#include "threadwise_4d_tensor_op.hip.hpp" +#include "blockwise_gemm.hip.hpp" + +template +__global__ void +gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) +{ + // NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N] + // for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N" + // if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock + static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0"); + static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock, + "wrong!"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_chwn_global_desc = InGlobalDesc{}; + constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; + constexpr auto out_khwn_global_desc = OutGlobalDesc{}; + + constexpr unsigned C = in_chwn_global_desc.GetLength(I0); + + constexpr unsigned K = out_khwn_global_desc.GetLength(I0); + constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1); + constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2); + constexpr unsigned N = out_khwn_global_desc.GetLength(I3); + + constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1); + constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2); + + constexpr unsigned HiPerBlock = HoPerBlock + Y - 1; + constexpr unsigned WiPerBlock = WoPerBlock + X - 1; + + // divide block work: [K, Ho, Wo, N] + constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; + constexpr unsigned HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; + constexpr unsigned WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; + constexpr unsigned NBlockWork = (N + NPerBlock - 1) / NPerBlock; + + const unsigned k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); + unsigned itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); + const unsigned h_block_work_id = itmp / (WBlockWork * NBlockWork); + itmp -= h_block_work_id * (WBlockWork * NBlockWork); + const unsigned w_block_work_id = itmp / NBlockWork; + const unsigned n_block_work_id = itmp - w_block_work_id * NBlockWork; + + const unsigned k_block_data_begin = k_block_work_id * KPerBlock; + const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock; + const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock; + const unsigned n_block_data_begin = n_block_work_id * NPerBlock; + + const unsigned hi_block_data_begin = ho_block_data_begin; + const unsigned wi_block_data_begin = wo_block_data_begin; + + // flattend (2d) tensor view of gridwise weight + constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); + + // tensor view of blockwise input and weight in LDS + // be careful of alignment + constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + // tensor view of threadwise output in register + constexpr auto out_khwn_thread_desc = + make_ConstantTensorDescriptor(Sequence{}); + + // blockwise copy + // input: format is [C, Hi, Wi, N] + const auto blockwise_in_copy = Blockwise4dTensorCopy3{}; + + // blockwise wei copy + // format is [CPerBlock*Y*X,KPerBlock] + const auto blockwise_wei_copy = Blockwise2dTensorCopy3{}; + + // a series of blockwise batched GEMM + // C_matrix += transpose(A_matrix) * B_matrix + // A_matrix and B_matrix saved in LDS, C_matrix saved in register + // A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K] + // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] + // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N] + constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); + + constexpr auto b_cxwn_block_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); + + constexpr auto c_kxwn_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); + + const auto blockwise_batch_gemm = BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< + BlockSize, + decltype(a_cxk_block_mtx_desc), + decltype(b_cxwn_block_mtx_desc), + decltype(c_kxwn_thread_mtx_desc), + 0, + in_chwn_block_desc.GetStride(I1), + out_khwn_thread_desc.GetStride(I1), + HoPerBlock, + GemmMPerThreadSubC, + GemmNPerThreadSubC, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmKPerThreadLoop, + HoPerThread>{}; + + // LDS: be careful of alignment + constexpr unsigned in_block_size = + in_chwn_block_desc.GetElementSpace(Number{}); + + constexpr unsigned wei_block_size = + wei_cyxk_block_desc.GetElementSpace(Number{}); + + constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead + ? InBlockCopyDataPerRead + : WeiBlockCopyDataPerRead; + + __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)]; + __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)]; + + // register + Float p_out_thread[out_khwn_thread_desc.GetElementSpace()]; + + // set threadwise output tensor to 0 + threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread); + + const Float* p_in_global_block_begin = + p_in_global + in_chwn_global_desc.Get1dIndex( + 0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin); + + const Float* p_wei_global_block_begin = + p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); + + for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, + p_in_global_block_begin += CPerBlock * in_chwn_global_desc.GetStride(I0), + p_wei_global_block_begin += CPerBlock * wei_cyxk_global_desc.GetStride(I0), + __syncthreads()) + { + // input: global mem to LDS + blockwise_in_copy.Run(p_in_global_block_begin, p_in_block); + + // weight: global mem to LDS + blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block); + + __syncthreads(); + + // a series of batched GEMM + for(unsigned y = 0; y < Y; ++y) + { + for(unsigned x = 0; x < X; ++x) + { + blockwise_batch_gemm.Run(p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0), + p_out_thread, + [](auto& acc, const auto&& v) { acc += v; }); + } + } + } + + // output: register to global mem, +#if 0 + const auto c_thread_mtx_begin = + blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + for(unsigned k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k) + { + for(unsigned ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho) + { + for(unsigned wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo) + { + for(unsigned n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n) + { + const unsigned b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n); + + const auto c_thread_mtx_distance = + blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b); + + const unsigned ho_thread = + c_thread_mtx_begin.batch + c_thread_mtx_distance.batch; + const unsigned k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row; + const unsigned b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col; + + const unsigned wo_thread = b_thread / NPerBlock; + const unsigned n_thread = b_thread % NPerBlock; + + p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread, + ho_block_data_begin + ho_thread, + wo_block_data_begin + wo_thread, + n_block_data_begin + n_thread)] = + p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)]; + } + } + } + } +#elif 1 + const auto c_thread_mtx_begin = + blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const unsigned k_thread_data_begin = c_thread_mtx_begin.row; + const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch; + const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; + const unsigned n_thread_data_begin = c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin; + + // this is for v2 GEMM + // output is a 8d tensor + if(NPerThread < NPerBlock && WoPerThread == 1) + { + constexpr unsigned N1_ = GemmNPerThreadSubC; + constexpr unsigned W1_ = WoPerBlock / ((WoPerThread * NPerThread) / GemmNPerThreadSubC); + constexpr unsigned K2_ = GemmMPerThreadSubC; + constexpr unsigned K1_ = KPerBlock / KPerThread; + + constexpr auto out_8d_global_desc = make_ConstantTensorDescriptor( + Sequence{}); + + constexpr auto out_8d_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc"); + print_ConstantTensorDescriptor(out_8d_thread_desc, "out_8d_thread_desc"); + + print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc"); + print_ConstantTensorDescriptor(out_8d_global_desc, "out_8d_global_desc"); + } +#endif + + threadwise_8d_tensor_copy(out_8d_thread_desc, + p_out_thread, + out_8d_global_desc, + p_out_global + out_khwn_global_desc.Get1dIndex( + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin, + n_block_data_begin + n_thread_data_begin), + out_8d_thread_desc.GetLengths(), + Number{}); + } + else if(NPerThread == NPerBlock) + { + // not implemented yet + assert(false); + } + else + { + assert(false); + } +#endif +} diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp new file mode 100644 index 0000000000..a4904cdf58 --- /dev/null +++ b/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp @@ -0,0 +1,292 @@ +#pragma once +#include "common.hip.hpp" +#include "ConstantTensorDescriptor.hip.hpp" +#include "ConstantMatrixDescriptor.hip.hpp" +#include "blockwise_4d_tensor_op.hip.hpp" +#include "blockwise_2d_tensor_op.hip.hpp" +#include "threadwise_4d_tensor_op.hip.hpp" +#include "blockwise_gemm.hip.hpp" + +template +__global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded( + const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) +{ + // NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N] + // for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N" + // if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock + static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0"); + static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock, + "wrong!"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_chwn_global_desc = InGlobalDesc{}; + constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; + constexpr auto out_khwn_global_desc = OutGlobalDesc{}; + + constexpr unsigned C = in_chwn_global_desc.GetLength(I0); + + constexpr unsigned K = out_khwn_global_desc.GetLength(I0); + constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1); + constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2); + constexpr unsigned N = out_khwn_global_desc.GetLength(I3); + + constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1); + constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2); + + constexpr unsigned HPadLow = LowerPads{}.Get(I0); + constexpr unsigned WPadLow = LowerPads{}.Get(I1); + + constexpr unsigned HPadUp = UpperPads{}.Get(I0); + constexpr unsigned WPadUp = UpperPads{}.Get(I1); + + constexpr unsigned HiPerBlock = HoPerBlock + Y - 1; + constexpr unsigned WiPerBlock = WoPerBlock + X - 1; + + // divide block work: [K, Ho, Wo, N] + constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; + constexpr unsigned HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; + constexpr unsigned WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; + constexpr unsigned NBlockWork = (N + NPerBlock - 1) / NPerBlock; + + const unsigned k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); + unsigned itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); + const unsigned h_block_work_id = itmp / (WBlockWork * NBlockWork); + itmp -= h_block_work_id * (WBlockWork * NBlockWork); + const unsigned w_block_work_id = itmp / NBlockWork; + const unsigned n_block_work_id = itmp - w_block_work_id * NBlockWork; + + const unsigned k_block_data_begin = k_block_work_id * KPerBlock; + const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock; + const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock; + const unsigned n_block_data_begin = n_block_work_id * NPerBlock; + + // flattened (2d) tensor view of wei in global mem + constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); + + // tensor view of blockwise input and weight in LDS + constexpr auto in_chwn_block_desc = + make_ConstantTensorDescriptor(Sequence{}); + + constexpr auto wei_cyxk_block_desc = + make_ConstantTensorDescriptor(Sequence{}); + + // flattened (2d) tensor view of wei in LDS + constexpr auto wei_ek_block_desc = + make_ConstantTensorDescriptor(Sequence{}); + + // tensor view of threadwise output in register + constexpr auto out_hkwn_thread_desc = + make_ConstantTensorDescriptor(Sequence{}); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc"); + print_ConstantTensorDescriptor(wei_cyxk_block_desc, "wei_cyxk_block_desc"); + print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc"); + } +#endif + + // blockwise copy + // input: format is [C, Hi, Wi, N] + const unsigned h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0; + const unsigned w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0; + + const unsigned h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0; + const unsigned w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0; + +#if 0 + if(get_thread_local_1d_id() == 0) + ; + { + printf( + "%u %u, h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u\n", + get_block_1d_id(), + get_thread_local_1d_id(), + h_block_pad_low, + w_block_pad_low, + h_block_pad_up, + w_block_pad_up); + } +#endif + + constexpr auto blockwise_in_copy = + BlockwiseChwnTensorCopyPadded{}; + +#if 0 + // weight: format is [C,Y,X,K] + constexpr auto blockwise_wei_copy = + Blockwise4dTensorCopy1{}; +#elif 0 + // weight: format is [C*Y*X,K] + constexpr auto blockwise_wei_copy = + Blockwise2dTensorCopy1{}; +#elif 1 + // weight: format is [C*Y*X,K] + const auto blockwise_wei_copy = Blockwise2dTensorCopy2{}; +#endif + + // a series of blockwise batched GEMM + // C_matrix += transpose(A_matrix) * B_matrix + // A_matrix and B_matrix saved in LDS, C_matrix saved in register + // A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K] + // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] + // C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N] + constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); + + constexpr auto b_cxwn_block_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); + + constexpr auto c_kxwn_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, Number{}); + + const auto blockwise_batch_gemm = + Blockwise1dStridedBatchedGemmBlockABlockBThreadC{}; + + // LDS + constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace(); + constexpr unsigned wei_block_size = wei_cyxk_block_desc.GetElementSpace(); + + __shared__ Float p_in_block[in_block_size]; + __shared__ Float p_wei_block[wei_block_size]; + + // register + Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()]; + + // set threadwise output tensor to 0 + threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread); + + const Float* p_wei_global_block_begin = + p_wei_global + wei_ek_global_desc.Get1dIndex(0, k_block_data_begin); + + for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, + p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0), + __syncthreads()) + { +#if 1 + // input: global mem to LDS, + blockwise_in_copy.Run(p_in_global, + c_block_data_begin, + ho_block_data_begin, + wo_block_data_begin, + n_block_data_begin, + p_in_block, + h_block_pad_low, + w_block_pad_low, + h_block_pad_up, + w_block_pad_up); +#endif + +#if 1 + // weight: global mem to LDS, + blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block); +#endif + + __syncthreads(); + + // a series of batched GEMM + for(unsigned y = 0; y < Y; ++y) + { + for(unsigned x = 0; x < X; ++x) + { + auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; + + blockwise_batch_gemm.Run(p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0), + p_out_thread, + f_accum); + } + } + } + + const auto matrix_c_index = + blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const unsigned ho_thread_data_begin = matrix_c_index.batch; + const unsigned k_thread_data_begin = matrix_c_index.row; + const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock; + const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock; + +#if 0 + printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n", + get_block_1d_id(), get_thread_local_1d_id(), + ho_block_data_begin, k_block_data_begin, wo_block_data_begin, n_block_data_begin, + ho_thread_data_begin, k_thread_data_begin, wo_thread_data_begin, n_thread_data_begin, + p_out_thread[0]); +#endif + + // output: register to global mem, + // convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N] + constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{}; + + threadwise_4d_tensor_copy_reorder_by_get_dst_from_src( + out_hkwn_thread_desc, + p_out_thread, + out_khwn_global_desc, + p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin, + n_block_data_begin + n_thread_data_begin), + out_hkwn_thread_desc.GetLengths(), + reorder_khwn_from_hkwn); +} diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp new file mode 100644 index 0000000000..7c802266d8 --- /dev/null +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp @@ -0,0 +1,369 @@ +#pragma once +#include "common.hip.hpp" +#include "ConstantTensorDescriptor.hip.hpp" +#include "ConstantMatrixDescriptor.hip.hpp" +#include "blockwise_4d_tensor_op.hip.hpp" +#include "blockwise_2d_tensor_op.hip.hpp" +#include "threadwise_2d_tensor_op.hip.hpp" +#include "blockwise_gemm.hip.hpp" + +// define B = flatten(N, Hi, Wi) +template +__global__ void gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( + const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_chwn_global_desc = InGlobalDesc{}; + constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; + constexpr auto out_khwn_global_desc = OutGlobalDesc{}; + + constexpr unsigned C = in_chwn_global_desc.GetLength(I0); + constexpr unsigned Hi = in_chwn_global_desc.GetLength(I1); + constexpr unsigned Wi = in_chwn_global_desc.GetLength(I2); + constexpr unsigned N = in_chwn_global_desc.GetLength(I3); + + constexpr unsigned K = out_khwn_global_desc.GetLength(I0); + constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1); + constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2); + + constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1); + constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2); + + constexpr unsigned B = N * Hi * Wi; + constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1); + + // divide block work by 2d: [K, B] + constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; + constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock; + + const unsigned k_block_work_id = get_block_1d_id() / BBlockWork; + const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; + + const unsigned k_block_data_begin = k_block_work_id * KPerBlock; + const unsigned b_block_data_begin = b_block_work_id * BPerBlock; + + // flattend (2d) tensor view of gridwise input + constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence{}); + constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); + + // tensor view of blockwise input and weight + // be careful of alignment + constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + // tensor view of threadwise output in register + constexpr auto out_kb_thread_desc = + make_ConstantTensorDescriptor(Sequence{}); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(in_chwn_global_desc, "in_chwn_global_desc"); + print_ConstantTensorDescriptor(wei_cyxk_global_desc, "wei_cyxk_global_desc"); + print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc"); + + print_ConstantTensorDescriptor(in_cb_global_desc, "in_cb_global_desc"); + print_ConstantTensorDescriptor(wei_ek_global_desc, "wei_ek_global_desc"); + + print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc"); + print_ConstantTensorDescriptor(wei_cyxk_block_desc, "wei_cyxk_block_desc"); + print_ConstantTensorDescriptor(wei_ek_block_desc, "wei_ek_block_desc"); + print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc"); + + printf("KPerBlock %u\n", KPerBlock); + } +#endif + +// blockwise in copy +// formmat is [CPerBlock,BPerBlock + BGhostRead] +#if 0 + const auto blockwise_in_copy = + Blockwise2dTensorCopy1{}; +#elif 0 + const auto blockwise_in_copy = Blockwise2dTensorCopy2{}; +#elif 1 + const auto blockwise_in_copy = Blockwise2dTensorCopy3{}; +#endif + +// blockwise wei copy +// format is [CPerBlock*Y*X,KPerBlock] +#if 0 + const auto blockwise_wei_copy = + Blockwise2dTensorCopy1{}; +#elif 0 + const auto blockwise_wei_copy = Blockwise2dTensorCopy2{}; +#elif 1 + const auto blockwise_wei_copy = Blockwise2dTensorCopy3{}; +#endif + + // a series of blockwise GEMM + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx and b_mtx saved in LDS, c_mtx saved in register + // a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K] + // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead] + // c_mtx[K,B] is out_block[K,B] + constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); + + constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); + + constexpr auto c_kxb_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, Number{}); + +#if 0 + const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC{}; +#else + const auto blockwise_gemm = + BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2{}; +#endif + + // LDS: be careful of alignment + constexpr unsigned in_block_size = + in_cb_block_desc.GetElementSpace(Number{}); + + constexpr unsigned wei_block_size = + wei_cyxk_block_desc.GetElementSpace(Number{}); + + constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead + ? InBlockCopyDataPerRead + : WeiBlockCopyDataPerRead; + + // LDS double buffer + __shared__ Float p_in_block_0[max_align * ((in_block_size + max_align - 1) / max_align)]; + __shared__ Float p_wei_block_0[max_align * ((wei_block_size + max_align - 1) / max_align)]; + + __shared__ Float p_in_block_1[max_align * ((in_block_size + max_align - 1) / max_align)]; + __shared__ Float p_wei_block_1[max_align * ((wei_block_size + max_align - 1) / max_align)]; + + const Float* p_in_global_block_offset = + p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin); + + const Float* p_wei_global_block_offset = + p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); + + // preload data into LDS + blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_0); + blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_0); + + p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0); + p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0); + + // register + Float p_out_thread[out_kb_thread_desc.GetElementSpace()]; + + // set threadwise output tensor to 0 + threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread); + + bool even_loop = true; + + for(unsigned c_block_data_begin = 0; c_block_data_begin + CPerBlock < C; + c_block_data_begin += CPerBlock, + p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), + p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), + even_loop = !even_loop) + { + Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1; + Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1; + + Float* p_in_block_next = even_loop ? p_in_block_1 : p_in_block_0; + Float* p_wei_block_next = even_loop ? p_wei_block_1 : p_wei_block_0; + + __syncthreads(); + +// load next data +#if 1 + blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next); + blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next); +#elif 1 + Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; + Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + + blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, + p_in_register_clipboard); + + blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, + p_wei_register_clipboard); +#endif + + // compute on current data + // a series of GEMM + for(unsigned y = 0; y < Y; ++y) + { + for(unsigned x = 0; x < X; ++x) + { + auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; +#if 1 + blockwise_gemm.Run +#else + blockwise_gemm.Run_RegisterDoubleBuffer +#endif + (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block_now + y * Wi + x, + p_out_thread, + f_accum); + } + } + +#if 0 + blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next); + blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next); +#endif + } + + // last computation + { + Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1; + Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1; + + __syncthreads(); + + for(unsigned y = 0; y < Y; ++y) + { + for(unsigned x = 0; x < X; ++x) + { + auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; +#if 0 + blockwise_gemm.Run +#else + blockwise_gemm.Run_RegisterDoubleBuffer +#endif + (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block_now + y * Wi + x, + p_out_thread, + f_accum); + } + } + } + + // output: register to global mem, + const auto c_thread_mtx_begin = + blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; + const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; + +#if 0 + if(get_block_1d_id() == 0) + { + printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", + get_block_1d_id(), + get_thread_local_1d_id(), + matrix_c_index.row, + matrix_c_index.col, + k_data_begin, + b_data_begin, + p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]); + } +#endif + + for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) + { + for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) + { + const auto c_thread_mtx_distance = + blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); + + unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row; + unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col; + + unsigned h_data = b_data / (Wi * N); + unsigned itmp = b_data - h_data * (Wi * N); + unsigned w_data = itmp / N; + unsigned n_data = itmp - w_data * N; + + if(n_data < N && h_data < Ho && w_data < Wo) + { + p_out_global[out_khwn_global_desc.Get1dIndex(k_data, h_data, w_data, n_data)] = + p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; + } + } + } +}