diff --git a/driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp b/driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp new file mode 100644 index 0000000000..86ffc58e77 --- /dev/null +++ b/driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp @@ -0,0 +1,452 @@ +#pragma once +#include +#include "device.hpp" +#include "gridwise_convolution_wrapper.hip.hpp" +#include "gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp" +#include "gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp" +#include "gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp" + +template +void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, + const Tensor& in_nchw, + WeiDesc, + const Tensor& wei_kcyx, + OutDesc, + Tensor& out_nkhw, + index_t nrepeat) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_nchw_desc = InDesc{}; + constexpr auto wei_kcyx_desc = WeiDesc{}; + constexpr auto out_nkhw_desc = OutDesc{}; + + constexpr index_t Hi = in_nchw_desc.GetLength(I2); + constexpr index_t Wi = in_nchw_desc.GetLength(I3); + + constexpr index_t N = out_nkhw_desc.GetLength(I0); + constexpr index_t Ho = out_nkhw_desc.GetLength(I2); + constexpr index_t Wo = out_nkhw_desc.GetLength(I3); + + constexpr index_t K = wei_kcyx_desc.GetLength(I0); + constexpr index_t C = wei_kcyx_desc.GetLength(I1); + constexpr index_t Y = wei_kcyx_desc.GetLength(I2); + constexpr index_t X = wei_kcyx_desc.GetLength(I3); + + // reorder weight: wei_kcyx(k, c, y, x) -> wei_cyxk(c, y, x, k) + auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: "); + + Tensor wei_cyxk(make_TensorDescriptor(wei_cyxk_desc)); + + auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) { + wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x); + }; + + make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)( 
std::thread::hardware_concurrency()); + + // reorder input: in_nchw(n, c, hi, wi) -> in_chwn(c, hi, wi, n) + auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: "); + + Tensor in_chwn(make_TensorDescriptor(in_chwn_desc)); + + auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) { + in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi); + }; + + make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)( + std::thread::hardware_concurrency()); + + // output tensor out_khwn; it is reordered back into out_nkhw after the kernel runs + auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: "); + + Tensor out_khwn(make_TensorDescriptor(out_khwn_desc)); + + std::size_t data_sz = sizeof(T); + DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace()); + DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace()); + DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace()); + + in_chwn_device_buf.ToDevice(in_chwn.mData.data()); + wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data()); + out_khwn_device_buf.ToDevice(out_khwn.mData.data()); + +#if 0 + // for 3x3, 34x34, v1r1, Pascal + constexpr index_t NPerBlock = 16; + constexpr index_t KPerBlock = 64; + constexpr index_t CPerBlock = 4; + constexpr index_t HoPerBlock = 2; + constexpr index_t WoPerBlock = 4; + + constexpr index_t NPerThread = 4; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 2; + + constexpr index_t InBlockCopy_ThreadPerDimC = 4; + constexpr index_t InBlockCopy_ThreadPerDimH = 4; + constexpr index_t InBlockCopy_ThreadPerDimW = 2; + constexpr index_t InBlockCopy_ThreadPerDimN = 4; + constexpr index_t InBlockCopyDataPerRead = 4; + + constexpr index_t WeiBlockCopyDataPerRead = 4; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; 
+ constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + constexpr index_t OutThreadCopyDataPerWrite = 2; + + constexpr index_t BlockSize = 128; +#elif 0 + // for 3x3, 34x34, v1r2, Pascal, in-block-copy1 + constexpr index_t NPerBlock = 4; + constexpr index_t KPerBlock = 64; + constexpr index_t CPerBlock = 8; + constexpr index_t HoPerBlock = 4; + constexpr index_t WoPerBlock = 8; + + constexpr index_t NPerThread = 4; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 2; + + constexpr index_t InBlockCopy_ThreadPerDimC = 4; + constexpr index_t InBlockCopy_ThreadPerDimH = 4; + constexpr index_t InBlockCopy_ThreadPerDimW = 2; + constexpr index_t InBlockCopy_ThreadPerDimN = 1; + constexpr index_t InBlockCopyDataPerRead = 4; + + constexpr index_t WeiBlockCopyDataPerRead = 4; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 2; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + constexpr index_t OutThreadCopyDataPerWrite = 2; + + constexpr index_t BlockSize = 128; +#elif 0 + // for 3x3, 34x34, v1r1, Vega 20 + constexpr index_t NPerBlock = 16; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 4; + constexpr index_t HoPerBlock = 2; + constexpr index_t WoPerBlock = 4; + + constexpr index_t NPerThread = 4; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 2; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + 
constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 2; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + constexpr index_t InBlockCopy_ThreadPerDimC = 4; + constexpr index_t InBlockCopy_ThreadPerDimH = 4; + constexpr index_t InBlockCopy_ThreadPerDimW = 2; + constexpr index_t InBlockCopy_ThreadPerDimN = 8; + constexpr index_t InBlockCopyDataPerRead = 2; + + constexpr index_t WeiBlockCopyDataPerRead = 2; + constexpr index_t OutThreadCopyDataPerWrite = 4; + + constexpr index_t BlockSize = 256; +#elif 0 + // for 3x3, 56x56, v1, Pascal + constexpr index_t NPerBlock = 32; + constexpr index_t KPerBlock = 64; + constexpr index_t CPerBlock = 4; + constexpr index_t HoPerBlock = 2; + constexpr index_t WoPerBlock = 2; + + constexpr index_t NPerThread = 4; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 2; + + constexpr index_t InBlockCopy_ThreadPerDimC = 1; + constexpr index_t InBlockCopy_ThreadPerDimH = 4; + constexpr index_t InBlockCopy_ThreadPerDimW = 4; + constexpr index_t InBlockCopy_ThreadPerDimN = 8; + constexpr index_t InBlockCopyDataPerRead = 4; + + constexpr index_t WeiBlockCopyDataPerRead = 4; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + + constexpr index_t OutThreadCopyDataPerWrite = 2; + + constexpr index_t BlockSize = 128; +#elif 0 + // for 3x3, 56x56, v1r2, Pascal + constexpr index_t NPerBlock = 16; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 8; + constexpr index_t HoPerBlock = 2; + constexpr index_t WoPerBlock = 2; + + constexpr index_t NPerThread = 4; + 
constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 2; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 2; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 1; + constexpr index_t GemmDataPerReadB = 1; + + constexpr index_t InBlockCopy_ThreadPerDimC = 1; + constexpr index_t InBlockCopy_ThreadPerDimH = 2; + constexpr index_t InBlockCopy_ThreadPerDimW = 4; + constexpr index_t InBlockCopy_ThreadPerDimN = 4; + constexpr index_t InBlockCopyDataPerRead = 4; + + constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t OutThreadCopyDataPerWrite = 4; + + constexpr index_t BlockSize = 128; +#elif 0 + // for 3x3, 28x28, v1r1, Pascal + constexpr index_t NPerBlock = 32; + constexpr index_t KPerBlock = 64; + constexpr index_t CPerBlock = 4; + constexpr index_t HoPerBlock = 2; + constexpr index_t WoPerBlock = 2; + + constexpr index_t NPerThread = 4; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 2; + + constexpr index_t InBlockCopy_ThreadPerDimC = 1; + constexpr index_t InBlockCopy_ThreadPerDimH = 4; + constexpr index_t InBlockCopy_ThreadPerDimW = 4; + constexpr index_t InBlockCopy_ThreadPerDimN = 8; + constexpr index_t InBlockCopyDataPerRead = 4; + + constexpr index_t WeiBlockCopyDataPerRead = 4; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + constexpr index_t 
OutThreadCopyDataPerWrite = 2; + + constexpr index_t BlockSize = 128; +#elif 1 + // for 3x3, 28x28, v1r2, Pascal (this is the first "#elif 1", i.e. the active configuration) + constexpr index_t NPerBlock = 16; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 8; + constexpr index_t HoPerBlock = 2; + constexpr index_t WoPerBlock = 2; + + constexpr index_t NPerThread = 4; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 2; + + constexpr index_t InBlockCopy_ThreadPerDimC = 4; + constexpr index_t InBlockCopy_ThreadPerDimH = 2; + constexpr index_t InBlockCopy_ThreadPerDimW = 4; + constexpr index_t InBlockCopy_ThreadPerDimN = 4; + constexpr index_t InBlockCopyDataPerRead = 4; + + constexpr index_t WeiBlockCopyDataPerRead = 4; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 2; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + constexpr index_t OutThreadCopyDataPerWrite = 2; + + constexpr index_t BlockSize = 128; +#elif 0 + // for 1x1, 28x28 + constexpr index_t NPerBlock = 16; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 8; + constexpr index_t HoPerBlock = 2; + constexpr index_t WoPerBlock = 2; + + constexpr index_t NPerThread = 4; + constexpr index_t KPerThread = 16; + constexpr index_t CPerThread = 1; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 1; + + constexpr index_t InBlockCopy_ThreadPerDimC = 8; + constexpr index_t InBlockCopy_ThreadPerDimH = 2; + constexpr index_t InBlockCopy_ThreadPerDimW = 2; + constexpr index_t InBlockCopy_ThreadPerDimN = 4; + constexpr index_t InBlockCopyDataPerRead = 4; + + constexpr index_t WeiBlockCopyDataPerRead = 4; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t 
GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + + constexpr index_t OutThreadCopyDataPerWrite = 2; + + constexpr index_t BlockSize = 128; +#elif 1 + // for 1x1, 14x14, Pascal -- NOTE(review): unreachable as written, an earlier "#elif 1" branch is already taken + constexpr index_t NPerBlock = 16; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 8; + constexpr index_t HoPerBlock = 2; + constexpr index_t WoPerBlock = 2; + + constexpr index_t NPerThread = 8; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 1; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 2; + constexpr index_t GemmKPerThreadLoop = 1; + + constexpr index_t InBlockCopy_ThreadPerDimC = 8; + constexpr index_t InBlockCopy_ThreadPerDimH = 2; + constexpr index_t InBlockCopy_ThreadPerDimW = 2; + constexpr index_t InBlockCopy_ThreadPerDimN = 4; + constexpr index_t InBlockCopyDataPerRead = 4; + + constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t OutThreadCopyDataPerWrite = 2; + + constexpr index_t BlockSize = 128; +#endif + + constexpr index_t GridSize = + ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) * + ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock); + + printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); + + for(index_t i = 0; i < nrepeat; ++i) + { + constexpr auto gridwise_conv = +#if 0 + GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn +#elif 0 + GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer +#elif 1 + GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn +#endif + , + InBlockCopyDataPerRead, 
+ WeiBlockCopyDataPerRead, + OutThreadCopyDataPerWrite>{}; + + float time = launch_kernel(run_gridwise_convolution, + dim3(GridSize), + dim3(BlockSize), + 0, + static_cast(in_chwn_device_buf.GetDeviceBuffer()), + static_cast(wei_cyxk_device_buf.GetDeviceBuffer()), + static_cast(out_khwn_device_buf.GetDeviceBuffer())); + + printf("Elapsed time : %f ms, %f TFlop/s\n", + time, + (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / + (std::size_t(1000) * 1000 * 1000) / time); + usleep(std::min(time * 1000, float(10000))); + } + + out_khwn_device_buf.FromDevice(out_khwn.mData.data()); + + // reorder output: out_khwn(k, ho, wo, n) -> out_nkhw(n, k, ho, wo) + auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) { + out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n); + }; + + make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)( + std::thread::hardware_concurrency()); +} diff --git a/driver/device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp b/driver/device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp new file mode 100644 index 0000000000..e8a893957b --- /dev/null +++ b/driver/device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp @@ -0,0 +1,330 @@ +#pragma once +#include +#include "device.hpp" +#include "gridwise_convolution_wrapper.hip.hpp" +#include "gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp" +#include "gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp" + +template +void device_convolution_implicit_gemm_v2_chwn_cyxk_khwn(InDesc, + const Tensor& in_nchw, + WeiDesc, + const Tensor& wei_kcyx, + OutDesc, + Tensor& out_nkhw, + index_t nrepeat) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_nchw_desc = InDesc{}; + constexpr auto wei_kcyx_desc = WeiDesc{}; + constexpr auto out_nkhw_desc = OutDesc{}; + + constexpr index_t N = in_nchw_desc.GetLength(I0); + constexpr index_t Hi = in_nchw_desc.GetLength(I2); + constexpr index_t 
Wi = in_nchw_desc.GetLength(I3); + + constexpr index_t Ho = out_nkhw_desc.GetLength(I2); + constexpr index_t Wo = out_nkhw_desc.GetLength(I3); + + constexpr index_t K = wei_kcyx_desc.GetLength(I0); + constexpr index_t C = wei_kcyx_desc.GetLength(I1); + constexpr index_t Y = wei_kcyx_desc.GetLength(I2); + constexpr index_t X = wei_kcyx_desc.GetLength(I3); + + constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1); + + // convert in_nchw to in_chwn + auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: "); + + Tensor in_chwn(make_TensorDescriptor(in_chwn_desc)); + + make_ParallelTensorFunctor( + [&](auto n, auto c, auto hi, auto wi) { in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi); }, + N, + C, + Hi, + Wi)(std::thread::hardware_concurrency()); + + // convert wei_kcyx to wei_cyxk + auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: "); + + Tensor wei_cyxk(make_TensorDescriptor(wei_cyxk_desc)); + + make_ParallelTensorFunctor( + [&](auto k, auto c, auto y, auto x) { wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x); }, + K, + C, + Y, + X)(std::thread::hardware_concurrency()); + + // create out_khwn; it is converted back to out_nkhw after the kernel runs + auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: "); + + Tensor out_khwn(make_TensorDescriptor(out_khwn_desc)); + +#if 0 + // 3x3, 34x34 + // need to use register double buffer for GEMM + constexpr index_t BPerBlock = 128; + constexpr index_t KPerBlock = 64; + constexpr index_t CPerBlock = 4; + + constexpr index_t BPerThread = 8; + constexpr index_t KPerThread = 8; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t 
GemmNLevel1Cluster = 8; + constexpr index_t GemmKPerThreadLoop = 1; + + constexpr index_t InBlockCopyThreadPerDim0 = 4; + constexpr index_t InBlockCopyThreadPerDim1 = 16; + + constexpr index_t WeiBlockCopyThreadPerDim0 = 4; + constexpr index_t WeiBlockCopyThreadPerDim1 = 16; + + constexpr index_t InBlockCopyDataPerRead = 4; + constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t OutThreadCopyDataPerWrite = 4; + + constexpr index_t BlockSize = 128; +#elif 0 + // 1x1, 28x28, 64 threads + constexpr index_t BPerBlock = 64; + constexpr index_t KPerBlock = 64; + constexpr index_t CPerBlock = 8; + + constexpr index_t BPerThread = 8; + constexpr index_t KPerThread = 8; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + + constexpr index_t GemmThreadPerColumnPerCluster = 8; + constexpr index_t GemmThreadPerRowPerCluster = 8; + + constexpr index_t InBlockCopyThreadPerDim0 = 4; + constexpr index_t InBlockCopyThreadPerDim1 = 16; + + constexpr index_t WeiBlockCopyThreadPerDim0 = 4; + constexpr index_t WeiBlockCopyThreadPerDim1 = 16; + + constexpr index_t InBlockCopyDataPerRead = 4; + constexpr index_t WeiBlockCopyDataPerRead = 4; + + constexpr index_t BlockSize = 64; +#elif 0 + // 1x1, 28x28, 128 threads, no lds-double-buffer + // 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128 + constexpr index_t BPerBlock = 64; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 8; + + constexpr index_t BPerThread = 8; + constexpr index_t KPerThread = 8; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr 
index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + + constexpr index_t GemmThreadPerColumnPerCluster = 8; + constexpr index_t GemmThreadPerRowPerCluster = 8; + + constexpr index_t InBlockCopyThreadPerDim0 = 4; + constexpr index_t InBlockCopyThreadPerDim1 = 16; + + constexpr index_t WeiBlockCopyThreadPerDim0 = 4; + constexpr index_t WeiBlockCopyThreadPerDim1 = 16; + + constexpr index_t InBlockCopyDataPerRead = 4; + constexpr index_t WeiBlockCopyDataPerRead = 4; + + constexpr index_t BlockSize = 128; +#elif 0 + // 1x1, 28x28, 256 threads + constexpr index_t BPerBlock = 128; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 8; + + constexpr index_t BPerThread = 8; + constexpr index_t KPerThread = 8; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + + constexpr index_t GemmThreadPerColumnPerCluster = 8; + constexpr index_t GemmThreadPerRowPerCluster = 8; + + constexpr index_t InBlockCopyThreadPerDim0 = 4; + constexpr index_t InBlockCopyThreadPerDim1 = 16; + + constexpr index_t WeiBlockCopyThreadPerDim0 = 4; + constexpr index_t WeiBlockCopyThreadPerDim1 = 16; + + constexpr index_t InBlockCopyDataPerRead = 4; + constexpr index_t WeiBlockCopyDataPerRead = 4; + + constexpr index_t BlockSize = 256; +#elif 0 + // 1x1, 14x14, Pascal, enable lds_double_buffer, disable register double buffer + constexpr index_t BPerBlock = 64; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 8; + + constexpr index_t BPerThread = 8; + constexpr index_t KPerThread = 8; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t 
GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + constexpr index_t InBlockCopyThreadPerDim0 = 4; + constexpr index_t InBlockCopyThreadPerDim1 = 16; + + constexpr index_t WeiBlockCopyThreadPerDim0 = 4; + constexpr index_t WeiBlockCopyThreadPerDim1 = 16; + + constexpr index_t InBlockCopyDataPerRead = 4; + constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t OutThreadCopyDataPerWrite = 4; + + constexpr index_t BlockSize = 128; +#elif 1 + // 1x1, 14x14, Vega 20, enable lds_double_buffer, disable register_double_buffer (the only "#elif 1", i.e. the active configuration) + constexpr index_t BPerBlock = 128; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 8; + + constexpr index_t BPerThread = 8; + constexpr index_t KPerThread = 8; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + constexpr index_t InBlockCopyThreadPerDim0 = 4; + constexpr index_t InBlockCopyThreadPerDim1 = 16; + + constexpr index_t WeiBlockCopyThreadPerDim0 = 4; + constexpr index_t WeiBlockCopyThreadPerDim1 = 16; + + constexpr index_t InBlockCopyDataPerRead = 4; + constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t OutThreadCopyDataPerWrite = 4; + + constexpr index_t BlockSize = 256; +#endif + + constexpr index_t GridSize = + ((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); + + printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); + + // device memory buffers + std::size_t data_sz = sizeof(T); + DeviceMem in_chwn_device_buf(data_sz * (in_chwn.mDesc.GetElementSpace() + BGhostRead + + 
BPerBlock)); // reserve extra space for BGhostRead -- NOTE(review): buffer exceeds in_chwn's element space; confirm ToDevice() below does not read that many bytes from the host pointer + DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace()); + DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace()); + + in_chwn_device_buf.ToDevice(in_chwn.mData.data()); + wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data()); + out_khwn_device_buf.ToDevice(out_khwn.mData.data()); + + for(index_t i = 0; i < nrepeat; ++i) + { + constexpr auto gridwise_conv = +#if 0 + GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn +#else + GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer +#endif + {}; + + float time = launch_kernel(run_gridwise_convolution, + dim3(GridSize), + dim3(BlockSize), + 0, + static_cast(in_chwn_device_buf.GetDeviceBuffer()), + static_cast(wei_cyxk_device_buf.GetDeviceBuffer()), + static_cast(out_khwn_device_buf.GetDeviceBuffer())); + + printf("Elapsed time : %f ms, %f TFlop/s\n", + time, + (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / + (std::size_t(1000) * 1000 * 1000) / time); + usleep(std::min(time * 1000, float(10000))); + } + + out_khwn_device_buf.FromDevice(out_khwn.mData.data()); + + // convert out_khwn to out_nkhw + make_ParallelTensorFunctor( + [&](auto n, auto k, auto ho, auto wo) { out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n); }, + N, + K, + Ho, + Wo)(std::thread::hardware_concurrency()); +}