From 19f17df47a2d814cab40b75027cbcac0c544932f Mon Sep 17 00:00:00 2001
From: Chao Liu
Date: Thu, 18 Apr 2019 11:49:09 -0500
Subject: [PATCH] implicit gemm v1r2: adding support for nchw

---
 ...lution_implicit_gemm_v1_chwn_cyxk_khwn.hpp |   4 +-
 ...lution_implicit_gemm_v1_nchw_cyxk_khwn.hpp | 433 ++++++++++++++++++
 driver/driver.hip.cpp                         |  37 +-
 src/include/Array.hip.hpp                     |  34 ++
 src/include/ConstantTensorDescriptor.hip.hpp  |  22 +-
 src/include/Sequence.hip.hpp                  | 121 +++--
 src/include/blockwise_2d_tensor_op.hip.hpp    |  18 +-
 src/include/blockwise_3d_tensor_op.hip.hpp    | 272 ++++++++++-
 src/include/blockwise_4d_tensor_op.hip.hpp    | 370 ++++++++++++---
 src/include/common.hip.hpp                    |  30 +-
 src/include/functional.hip.hpp                |  15 -
 ..._implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp |  38 +-
 ..._implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp | 362 +++++++++++++++
 src/include/threadwise_2d_tensor_op.hip.hpp   |  19 +-
 src/include/threadwise_4d_tensor_op.hip.hpp   |  67 +--
 src/include/threadwise_nd_tensor_op.hip.hpp   |   2 +-
 16 files changed, 1624 insertions(+), 220 deletions(-)
 create mode 100644 driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp
 create mode 100644 src/include/gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp

diff --git a/driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp b/driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
index 3532a4d4ce..6f5e29410c 100644
--- a/driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
+++ b/driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
@@ -243,7 +243,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
     constexpr index_t OutThreadCopyDataPerWrite = 4;

     constexpr index_t BlockSize = 128;
-#elif 0
+#elif 1
     // for 3x3, 28x28, v1r1, Pascal
     constexpr index_t NPerBlock = 32;
     constexpr index_t KPerBlock = 64;
@@ -386,7 +386,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
     for(index_t i = 0; i < nrepeat; ++i)
     {
         constexpr auto gridwise_conv =
-#if 0
+#if 1
             GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
 #elif 0
             GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer
diff --git a/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp b/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp
new file mode 100644
index 0000000000..6762bb1d2a
--- /dev/null
+++ b/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp
@@ -0,0 +1,433 @@
+#pragma once
+#include <unistd.h>
+#include "device.hpp"
+#include "gridwise_convolution_wrapper.hip.hpp"
+#include "gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp"
+
+template <class T, class InDesc, class WeiDesc, class OutDesc>
+void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
+                                                        const Tensor<T>& in_nchw,
+                                                        WeiDesc,
+                                                        const Tensor<T>& wei_kcyx,
+                                                        OutDesc,
+                                                        Tensor<T>& out_nkhw,
+                                                        index_t nrepeat)
+{
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};
+
+    constexpr auto in_nchw_desc = InDesc{};
+    constexpr auto wei_kcyx_desc = WeiDesc{};
+    constexpr auto out_nkhw_desc = OutDesc{};
+
+    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+    constexpr index_t Wi = in_nchw_desc.GetLength(I3);
+
+    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
+    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
+
+    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+    constexpr index_t X = wei_kcyx_desc.GetLength(I3);
+
+    // reorder weight
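+    // The v1r2 NCHW kernel consumes weights in CYXK order, i.e. with K as the
+    // innermost (stride-1) dimension; presumably this is what lets the blockwise
+    // weight copy issue WeiBlockCopyDataPerRead-wide vector loads along K (the
+    // blockwise copies below assert stride(last dim) == 1 whenever DataPerRead > 1).
+    // The reorder itself is a plain host-side transpose, run in parallel:
+    //     wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x)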
+    auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
+    ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
+
+    Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
+
+    auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
+        wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
+    };
+
+    make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
+        std::thread::hardware_concurrency());
+
+    // output
+    auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence<K, Ho, Wo, N>{});
+    ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");
+
+    Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));
+
+    std::size_t data_sz = sizeof(T);
+    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
+    DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
+    DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());
+
+    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
+    wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
+    out_khwn_device_buf.ToDevice(out_khwn.mData.data());
+
+#if 0
+    // for 3x3, 34x34, v1r1, Pascal
+    constexpr index_t NPerBlock = 16;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 4;
+
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 2;
+
+    constexpr index_t InBlockCopy_ThreadPerDimC = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimH = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+    constexpr index_t InBlockCopyDataPerRead = 4;
+
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA = 4;
+    constexpr index_t GemmDataPerReadB = 4;
+
+    constexpr index_t OutThreadCopyDataPerWrite = 2;
+
+    constexpr index_t BlockSize = 128;
+#elif 0
+    // for 3x3, 34x34, v1r2, Pascal, in-block-copy1
+    constexpr index_t NPerBlock = 4;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 8;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 8;
+
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 2;
+
+    constexpr index_t InBlockCopy_ThreadPerDimC = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimH = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimN = 1;
+    constexpr index_t InBlockCopyDataPerRead = 4;
+
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 2;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA = 4;
+    constexpr index_t GemmDataPerReadB = 4;
+
+    constexpr index_t OutThreadCopyDataPerWrite = 2;
+
+    constexpr index_t BlockSize = 128;
+#elif 0
+    // for 3x3, 34x34, v1r1, Vega 20
+    constexpr index_t NPerBlock = 16;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 4;
+
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 2;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 2;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA = 4;
+    constexpr index_t GemmDataPerReadB = 4;
+
+    constexpr index_t InBlockCopy_ThreadPerDimC = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimH = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimN = 8;
+    constexpr index_t InBlockCopyDataPerRead = 2;
+
+    constexpr index_t WeiBlockCopyDataPerRead = 2;
+    constexpr index_t OutThreadCopyDataPerWrite = 4;
+
+    constexpr index_t BlockSize = 256;
+#elif 0
+    // for 3x3, 56x56, v1, Pascal
+    constexpr index_t NPerBlock = 32;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;
+
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 2;
+
+    constexpr index_t InBlockCopy_ThreadPerDimC = 1;
+    constexpr index_t InBlockCopy_ThreadPerDimH = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimW = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimN = 8;
+    constexpr index_t InBlockCopyDataPerRead = 4;
+
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+
+    constexpr index_t OutThreadCopyDataPerWrite = 2;
+
+    constexpr index_t BlockSize = 128;
+#elif 0
+    // for 3x3, 56x56, v1r2, Pascal
+    constexpr index_t NPerBlock = 16;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 8;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;
+
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 2;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 2;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA = 1;
+    constexpr index_t GemmDataPerReadB = 1;
+
+    constexpr index_t InBlockCopy_ThreadPerDimC = 1;
+    constexpr index_t InBlockCopy_ThreadPerDimH = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimW = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+    constexpr index_t InBlockCopyDataPerRead = 4;
+
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
+    constexpr index_t OutThreadCopyDataPerWrite = 4;
+
+    constexpr index_t BlockSize = 128;
+#elif 0
+    // for 3x3, 28x28, v1r1, Pascal
+    constexpr index_t NPerBlock = 32;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;
+
+    constexpr index_t NPerThread = 4;
+    constexpr index_t
KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 2; + + constexpr index_t InBlockCopy_ThreadPerDimC = 1; + constexpr index_t InBlockCopy_ThreadPerDimH = 4; + constexpr index_t InBlockCopy_ThreadPerDimW = 4; + constexpr index_t InBlockCopy_ThreadPerDimN = 8; + constexpr index_t InBlockCopyDataPerRead = 4; + + constexpr index_t WeiBlockCopyDataPerRead = 4; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + constexpr index_t OutThreadCopyDataPerWrite = 2; + + constexpr index_t BlockSize = 128; +#elif 1 + // for 3x3, 28x28, v1r2, Pascal + constexpr index_t NPerBlock = 16; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 8; + constexpr index_t HoPerBlock = 2; + constexpr index_t WoPerBlock = 2; + + constexpr index_t NPerThread = 4; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 2; + + constexpr index_t InBlockCopy_ThreadPerDimN = 4; + constexpr index_t InBlockCopy_ThreadPerDimC = 8; + constexpr index_t InBlockCopy_ThreadPerDimH = 2; + constexpr index_t InBlockCopy_ThreadPerDimW = 2; + constexpr index_t InBlockCopyDataPerRead = 2; + + constexpr index_t WeiBlockCopyDataPerRead = 4; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 2; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + constexpr index_t OutThreadCopyDataPerWrite = 2; + + constexpr index_t BlockSize = 128; +#elif 0 + // for 1x1, 28x28 + constexpr index_t NPerBlock = 16; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 8; + constexpr index_t HoPerBlock = 2; + constexpr index_t WoPerBlock = 2; + + constexpr index_t NPerThread = 4; + constexpr index_t KPerThread = 16; + constexpr index_t CPerThread = 1; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 1; + + constexpr index_t InBlockCopy_ThreadPerDimC = 8; + constexpr index_t InBlockCopy_ThreadPerDimH = 2; + constexpr index_t InBlockCopy_ThreadPerDimW = 2; + constexpr index_t InBlockCopy_ThreadPerDimN = 4; + constexpr index_t InBlockCopyDataPerRead = 4; + + constexpr index_t WeiBlockCopyDataPerRead = 4; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + + constexpr index_t OutThreadCopyDataPerWrite = 2; + + constexpr index_t BlockSize = 128; +#elif 1 + // for 1x1, 14x14, Pascal + constexpr index_t NPerBlock = 16; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 8; + constexpr index_t HoPerBlock = 2; + constexpr index_t WoPerBlock = 2; + + constexpr index_t NPerThread = 8; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 1; + + constexpr index_t GemmMPerThreadSubC = 4; 
+ constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 2; + constexpr index_t GemmKPerThreadLoop = 1; + + constexpr index_t InBlockCopy_ThreadPerDimC = 8; + constexpr index_t InBlockCopy_ThreadPerDimH = 2; + constexpr index_t InBlockCopy_ThreadPerDimW = 2; + constexpr index_t InBlockCopy_ThreadPerDimN = 4; + constexpr index_t InBlockCopyDataPerRead = 4; + + constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t OutThreadCopyDataPerWrite = 2; + + constexpr index_t BlockSize = 128; +#endif + + constexpr index_t GridSize = + ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) * + ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock); + + printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); + + for(index_t i = 0; i < nrepeat; ++i) + { + constexpr auto gridwise_conv = +#if 1 + GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn +#endif + , + InBlockCopyDataPerRead, + WeiBlockCopyDataPerRead, + OutThreadCopyDataPerWrite>{}; + + float time = launch_kernel(run_gridwise_convolution, + dim3(GridSize), + dim3(BlockSize), + 0, + static_cast(in_nchw_device_buf.GetDeviceBuffer()), + static_cast(wei_cyxk_device_buf.GetDeviceBuffer()), + static_cast(out_khwn_device_buf.GetDeviceBuffer())); + + printf("Elapsed time : %f ms, %f TFlop/s\n", + time, + (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / + (std::size_t(1000) * 1000 * 1000) / time); + usleep(std::min(time * 1000, float(10000))); + } + + out_khwn_device_buf.FromDevice(out_khwn.mData.data()); + + // reorder output + auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) { + out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n); + }; + + make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)( + std::thread::hardware_concurrency()); +} diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index 5eaf42a8b7..89106fe85f 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -38,9 +38,6 @@ struct GeneratorTensor_2 struct GeneratorTensor_3 { - int min_value = 0; - int max_value = 9; - template double operator()(Is... 
is) { @@ -420,11 +417,10 @@ void check_error(const Tensor& ref, const Tensor& result) int main(int argc, char* argv[]) { #if 0 - // 3x3, 34x34 - constexpr index_t N = 64; - constexpr index_t C = 256; - constexpr index_t HI = 34; - constexpr index_t WI = 34; + constexpr index_t N = 128; + constexpr index_t C = 8; + constexpr index_t HI = 28; + constexpr index_t WI = 28; constexpr index_t K = 128; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -432,15 +428,27 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; #elif 0 - // 3x3, 56x56 + // 3x3, 34x34 constexpr index_t N = 64; - constexpr index_t C = 64; - constexpr index_t HI = 56; - constexpr index_t WI = 56; + constexpr index_t C = 256; + constexpr index_t HI = 34; + constexpr index_t WI = 34; constexpr index_t K = 128; constexpr index_t Y = 3; constexpr index_t X = 3; + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; +#elif 0 + // 3x3, 56x56 + constexpr index_t N = 64; + constexpr index_t C = 64; + constexpr index_t HI = 56; + constexpr index_t WI = 56; + constexpr index_t K = 128; + constexpr index_t Y = 3; + constexpr index_t X = 3; + constexpr index_t HPad = 0; constexpr index_t WPad = 0; #elif 0 @@ -642,6 +650,9 @@ int main(int argc, char* argv[]) #if 0 in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); +#elif 0 + in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread); + wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); #elif 1 in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); @@ -664,7 +675,7 @@ int main(int argc, char* argv[]) device_direct_convolution_2_vectorized_nchw_kcyx_nkhw #elif 1 device_convolution_implicit_gemm_v1_chwn_cyxk_khwn -#elif 0 +#elif 1 device_convolution_implicit_gemm_v1_nchw_cyxk_khwn #elif 0 device_convolution_implicit_gemm_v2_chwn_cyxk_khwn diff --git a/src/include/Array.hip.hpp b/src/include/Array.hip.hpp index 65762c82a1..c386542d22 100644 --- a/src/include/Array.hip.hpp +++ b/src/include/Array.hip.hpp @@ -1,4 +1,6 @@ #pragma once +#include "Sequence.hip.hpp" +#include "functional.hip.hpp" template struct Array @@ -18,3 +20,35 @@ struct Array __host__ __device__ TData& operator[](index_t i) { return mData[i]; } }; + +template +__host__ __device__ auto reorder_array_given_new2old(const Array& old_array, + Sequence new2old) +{ + Array new_array; + + static_assert(NSize == sizeof...(IRs), "NSize not consistent"); + + static_for<0, NSize, 1>{}([&](auto IDim) { + constexpr index_t idim = IDim.Get(); + new_array[idim] = old_array[new2old.Get(IDim)]; + }); + + return new_array; +} + +template +__host__ __device__ auto reorder_array_given_old2new(const Array& old_array, + Sequence old2new) +{ + Array new_array; + + static_assert(NSize == sizeof...(IRs), "NSize not consistent"); + + static_for<0, NSize, 1>{}([&](auto IDim) { + constexpr index_t idim = IDim.Get(); + new_array[old2new.Get(IDim)] = old_array[idim]; + }); + + return new_array; +} \ No newline at end of file diff --git a/src/include/ConstantTensorDescriptor.hip.hpp b/src/include/ConstantTensorDescriptor.hip.hpp index d204cba9ce..990d1724f4 100644 --- a/src/include/ConstantTensorDescriptor.hip.hpp +++ b/src/include/ConstantTensorDescriptor.hip.hpp @@ -108,11 +108,11 @@ template struct ConstantTensorDescriptor { using Type = ConstantTensorDescriptor; - static constexpr index_t nDim = Lengths::nDim; + static constexpr index_t 
nDim = Lengths::GetSize(); __host__ __device__ constexpr ConstantTensorDescriptor() { - static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent"); + static_assert(Lengths::GetSize() == Strides::GetSize(), "nDim not consistent"); } __host__ __device__ static constexpr index_t GetDimension() { return nDim; } @@ -157,12 +157,10 @@ struct ConstantTensorDescriptor return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get()); } - template - __host__ __device__ static index_t Get1dIndex(Is... is) + template + __host__ __device__ static index_t Get1dIndex(Array multi_id) { - static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong"); - - const auto multi_id = Array(is...); + static_assert(NSize == nDim, "wrong! Dimension not consistent"); index_t id = 0; @@ -178,6 +176,16 @@ struct ConstantTensorDescriptor return id; } + template + __host__ __device__ static index_t Get1dIndex(Is... is) + { + static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong"); + + const auto multi_id = Array(is...); + + return Get1dIndex(multi_id); + } + __host__ __device__ static Array GetMultiIndex(index_t id) { Array multi_id; diff --git a/src/include/Sequence.hip.hpp b/src/include/Sequence.hip.hpp index 4ea641a47a..48a86505e7 100644 --- a/src/include/Sequence.hip.hpp +++ b/src/include/Sequence.hip.hpp @@ -7,9 +7,11 @@ struct Sequence { using Type = Sequence; - static constexpr index_t nDim = sizeof...(Is); + static constexpr index_t mSize = sizeof...(Is); - const index_t mData[nDim] = {Is...}; + const index_t mData[mSize] = {Is...}; + + __host__ __device__ static constexpr index_t GetSize() { return mSize; } template __host__ __device__ constexpr index_t Get(Number) const @@ -19,36 +21,38 @@ struct Sequence __host__ __device__ index_t operator[](index_t i) const { return mData[i]; } - // this is ugly, only for nDIm = 4 - template - __host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence) const + template + __host__ __device__ constexpr auto ReorderGivenNew2Old(Sequence /*new2old*/) const { - static_assert(nDim == 4, "nDim != 4"); + static_assert(mSize == sizeof...(IRs), "mSize not consistent"); - constexpr auto old_sequence = Type{}; + constexpr auto old = Type{}; - constexpr index_t NR0 = old_sequence.mData[I0]; - constexpr index_t NR1 = old_sequence.mData[I1]; - constexpr index_t NR2 = old_sequence.mData[I2]; - constexpr index_t NR3 = old_sequence.mData[I3]; - - return Sequence{}; + return Sequence{})...>{}; } - template - __host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence) const + template + __host__ __device__ constexpr auto ReorderGivenOld2New(Sequence /*old2new*/) const { // don't know how to implement this - printf("Sequence::ReorderByPutOldToNew not implemented"); + printf("Sequence::ReorderGivenOld2New not implemented"); assert(false); } + template + __host__ __device__ constexpr auto PushFront(Number) const + { + return Sequence{}; + } + template __host__ __device__ constexpr auto PushBack(Number) const { return Sequence{}; } + __host__ __device__ constexpr auto PopFront() const; + __host__ __device__ constexpr auto PopBack() const; template @@ -58,33 +62,84 @@ struct Sequence } }; -template -__host__ __device__ constexpr auto sequence_pop_back(Sequence) +template +__host__ __device__ constexpr auto sequence_pop_front(Sequence) { - static_assert(sizeof...(Is) >= 1, "empty Sequence!"); + static_assert(sizeof...(Is) > 0, "empty Sequence!"); return Sequence{}; } -template -__host__ __device__ constexpr auto 
sequence_sequence_op(Sequence, Sequence, F f) +template +__host__ __device__ constexpr auto sequence_pop_back(Sequence) { - static_assert(Sequence::nDim == Sequence::nDim, "Dim not the same"); + static_assert(sizeof...(Is) > 0, "empty Sequence!"); + return Sequence{}; +} + +#if 1 +// this is ugly, only for 2 sequences +template +__host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence) +{ + static_assert(Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } -template -__host__ __device__ constexpr auto sequence_sequence_add(Sequence, Sequence) +// this is ugly, only for 3 sequences +template +__host__ __device__ constexpr auto +transform_sequences(F f, Sequence, Sequence, Sequence) { - struct add - { - __host__ __device__ constexpr index_t operator()(index_t x, index_t y) const - { - return x + y; - } - }; + static_assert(Sequence::mSize == Sequence::mSize && + Sequence::mSize == Sequence::mSize, + "Dim not the same"); - return sequence_sequence_op(Sequence{}, Sequence{}, add{}); + return Sequence{}; +} +#else +template +struct transform_sequences_impl +{ + template + __host__ __device__ constexpr auto operator()(F f, Y y, Xs... xs) const + { + static_assert(NRemain > 1, "wrong! should have NRemain > 1"); + + constexpr index_t N = f(Xs{}.Get(Number<0>{})...); + constexpr auto y_new = y.PushBack(Number{}); + + return transform_sequences_impl{}(f, y_new, xs.PopFront()...); + } +}; + +template <> +struct transform_sequences_impl<1> +{ + template + __host__ __device__ constexpr auto operator()(F f, Y, Xs...) const + { + constexpr index_t N = f(Xs{}.Get(Number<0>{})...); + return Y{}.PushBack(Number{}); + } +}; + +template +__host__ __device__ constexpr auto transform_sequences(F f, X x, Xs... xs) +{ + constexpr index_t nSize = X::GetSize(); + constexpr auto I0 = Number<0>{}; + + constexpr auto y0 = Sequence{}; + + return transform_sequences_impl{}(f, y0, x.PopFront(), xs.PopFront()...); +} +#endif + +template +__host__ __device__ constexpr auto Sequence::PopFront() const +{ + return sequence_pop_front(Type{}); } template @@ -107,6 +162,6 @@ template __host__ __device__ constexpr index_t accumulate_on_sequence(Seq, Reduce, Number) { constexpr index_t a = - static_const_reduce_n{}(accumulate_on_sequence_f{}, Reduce{}); + static_const_reduce_n{}(accumulate_on_sequence_f{}, Reduce{}); return Reduce{}(a, I); } diff --git a/src/include/blockwise_2d_tensor_op.hip.hpp b/src/include/blockwise_2d_tensor_op.hip.hpp index cfbcce2a86..da7f9b9037 100644 --- a/src/include/blockwise_2d_tensor_op.hip.hpp +++ b/src/include/blockwise_2d_tensor_op.hip.hpp @@ -67,7 +67,7 @@ template __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src( SrcDesc, @@ -75,14 +75,14 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds DstDesc, Float* __restrict__ p_dst, SrcOpLengths, - DstFromSrcReorder, + MapDst2Src, F f) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; - constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0); - constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1); + constexpr index_t IR0 = MapDst2Src{}.Get(I0); + constexpr index_t IR1 = MapDst2Src{}.Get(I1); constexpr auto src_desc = SrcDesc{}; constexpr auto dst_desc = DstDesc{}; @@ -147,19 +147,19 @@ template + class MapDst2Src> __device__ void blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc, const Float* __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths, - DstFromSrcReorder) + MapDst2Src) { auto f_copy = 
[](const Float& src, Float& dst) { dst = src; };

     blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
-        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
+        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
 }

 template {});
diff --git a/src/include/blockwise_3d_tensor_op.hip.hpp b/src/include/blockwise_3d_tensor_op.hip.hpp
index a6fe257e55..6a88757075 100644
--- a/src/include/blockwise_3d_tensor_op.hip.hpp
+++ b/src/include/blockwise_3d_tensor_op.hip.hpp
@@ -33,7 +33,7 @@ struct Blockwise3dTensorCopy1
     // but we need to make sure dst stride2 is big enough,
     // so that the out-of-bound write won't contaminate next line in dst
     constexpr index_t L2 = CopyLengths{}.Get(I2);
-    constexpr index_t read_per_d2 = integer_divide_ceil(L2, DataPerRead);
+    constexpr index_t read_per_d2 = mod_conv::integer_divide_ceil(L2, DataPerRead);

     static_assert(read_per_d2 * DataPerRead <= DstDesc{}.GetStride(I1),
                   "wrong! out-of-bound write will contaminate next line!\n");
@@ -52,7 +52,7 @@ struct Blockwise3dTensorCopy1
     constexpr index_t L1 = CopyLengths{}.Get(I1);
     constexpr index_t L2 = CopyLengths{}.Get(I2);

-    constexpr index_t read_per_d2 = integer_divide_ceil(L2, DataPerRead);
+    constexpr index_t read_per_d2 = mod_conv::integer_divide_ceil(L2, DataPerRead);

     constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence{});
@@ -98,3 +98,271 @@
         }
     }
 };
+
+// starting point need to be aligned to float4 or float2 or float
+// stride3 need to be 1 for both source and destination
+template <index_t BlockSize,
+          class Float,
+          class SrcDesc,
+          class DstDesc,
+          class CopyLengths,
+          class ThreadPerDims,
+          index_t DataPerRead>
+struct Blockwise3dTensorCopy3
+{
+    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
+
+    index_t mSrcMyThreadOffset;
+    index_t mDstMyThreadOffset;
+
+    __device__ Blockwise3dTensorCopy3()
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+
+        static_assert(DataPerRead == 1 ||
+                          (SrcDesc{}.GetStride(I2) == 1 && DstDesc{}.GetStride(I2) == 1),
+                      "wrong! only support stride3 == 1 if DataPerRead > 1!\n");
+
+        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
+                      "wrong! only support DataPerRead == 1, 2 or 4!\n");
+
+        static_assert(
+            SrcDesc{}.GetStride(I1) % DataPerRead == 0 &&
+                DstDesc{}.GetStride(I1) % DataPerRead == 0,
+            "wrong! src and dst stride1 should be multiple of DataPerRead to keep alignment");
+
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L2 = CopyLengths{}.Get(I2);
+
+        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
+        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
+        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
+
+        // we allow out-of-bound read from src in D2 dimension,
+        // but we need to make sure dst stride is big enough,
+        // so that the out-of-bound write won't contaminate next line in dst
+        constexpr index_t nloop_d2 = mod_conv::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
+
+        static_assert(nloop_d2 * thread_per_d2 * DataPerRead <= DstDesc{}.GetStride(I1),
+                      "wrong! out-of-bound write will contaminate next line!\n");
+
+        static_assert(L0 % thread_per_d0 == 0 && L1 % thread_per_d1 == 0,
+                      "wrong! L0 and L1 should be divided evenly!\n");
+
+        static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2,
+                      "wrong!
BlockSize is not big enough for ThreadPerDims!"); + + constexpr index_t num_active_thread = + accumulate_on_sequence(ThreadPerDims{}, mod_conv::multiplies{}, Number<1>{}); + + if(BlockSize > num_active_thread) + { + if(get_thread_local_1d_id() >= num_active_thread) + { + return; + } + } + + constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor(ThreadPerDims{}); + const auto thread_multi_id = thread_cluster_desc.GetMultiIndex(get_thread_local_1d_id()); + + mSrcMyThreadOffset = SrcDesc{}.Get1dIndex( + thread_multi_id[0], thread_multi_id[1], thread_multi_id[2] * DataPerRead); + + mDstMyThreadOffset = DstDesc{}.Get1dIndex( + thread_multi_id[0], thread_multi_id[1], thread_multi_id[2] * DataPerRead); + } + + __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + constexpr index_t L0 = CopyLengths{}.Get(I0); + constexpr index_t L1 = CopyLengths{}.Get(I1); + constexpr index_t L2 = CopyLengths{}.Get(I2); + + constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0); + constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1); + constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2); + + constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2; + + if(BlockSize > num_active_thread) + { + if(get_thread_local_1d_id() >= num_active_thread) + { + return; + } + } + + constexpr index_t nloop_d0 = L0 / thread_per_d0; + constexpr index_t nloop_d1 = L1 / thread_per_d1; + constexpr index_t nloop_d2 = mod_conv::integer_divide_ceil(L2, thread_per_d2 * DataPerRead); + +#pragma unroll + for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0) + { +#pragma unroll + for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1) + { +#pragma unroll + for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2) + { +#pragma unroll + const index_t src_offset = + SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0, + iloop_d1 * thread_per_d1, + iloop_d2 * thread_per_d2 * DataPerRead); + + const index_t dst_offset = + DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0, + iloop_d1 * thread_per_d1, + iloop_d2 * thread_per_d2 * DataPerRead); + + *(reinterpret_cast(&p_dst[dst_offset + mDstMyThreadOffset])) = *( + reinterpret_cast(&p_src[src_offset + mSrcMyThreadOffset])); + } + } + } + } + + __device__ constexpr index_t GetRegisterClipboardSize() const + { + static_assert(is_same::value, "wrong! 
only support float!\n"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + constexpr index_t L0 = CopyLengths{}.Get(I0); + constexpr index_t L1 = CopyLengths{}.Get(I1); + constexpr index_t L2 = CopyLengths{}.Get(I2); + + constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0); + constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1); + constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2); + + constexpr index_t nloop_d0 = L0 / thread_per_d0; + constexpr index_t nloop_d1 = L1 / thread_per_d1; + constexpr index_t nloop_d2 = mod_conv::integer_divide_ceil(L2, thread_per_d2 * DataPerRead); + + return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2; + } + + __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src, + Float* __restrict__ p_clipboard) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + constexpr index_t L0 = CopyLengths{}.Get(I0); + constexpr index_t L1 = CopyLengths{}.Get(I1); + constexpr index_t L2 = CopyLengths{}.Get(I2); + + constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0); + constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1); + constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2); + + constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2; + + if(BlockSize > num_active_thread) + { + if(get_thread_local_1d_id() >= num_active_thread) + { + return; + } + } + + constexpr index_t nloop_d0 = L0 / thread_per_d0; + constexpr index_t nloop_d1 = L1 / thread_per_d1; + constexpr index_t nloop_d2 = mod_conv::integer_divide_ceil(L2, thread_per_d2 * DataPerRead); + + constexpr auto clipboard_desc = + make_ConstantTensorDescriptor(Sequence{}); + +#pragma unroll + for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0) + { +#pragma unroll + for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1) + { +#pragma unroll + for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2) + { + const index_t src_offset = + SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0, + iloop_d1 * thread_per_d1, + iloop_d2 * thread_per_d2 * DataPerRead); + + const index_t clipboard_offset = + clipboard_desc.Get1dIndex(iloop_d0, iloop_d1, iloop_d2 * DataPerRead); + + *(reinterpret_cast(&p_clipboard[clipboard_offset])) = *( + reinterpret_cast(&p_src[src_offset + mSrcMyThreadOffset])); + } + } + } + } + + __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard, + Float* __restrict__ p_dst) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + constexpr index_t L0 = CopyLengths{}.Get(I0); + constexpr index_t L1 = CopyLengths{}.Get(I1); + constexpr index_t L2 = CopyLengths{}.Get(I2); + + constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0); + constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1); + constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2); + + constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2; + + if(BlockSize > num_active_thread) + { + if(get_thread_local_1d_id() >= num_active_thread) + { + return; + } + } + + constexpr index_t nloop_d0 = L0 / thread_per_d0; + constexpr index_t nloop_d1 = L1 / thread_per_d1; + constexpr index_t nloop_d2 = mod_conv::integer_divide_ceil(L2, thread_per_d2 * DataPerRead); + + constexpr auto clipboard_desc = + make_ConstantTensorDescriptor(Sequence{}); + +#pragma unroll + for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0) + { +#pragma unroll 
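+                // store path mirrors RunLoadRegisterClipboard: walk the same
+                // nloop_d0 x nloop_d1 x nloop_d2 grid of sub-tiles, but read each
+                // sub-tile from the per-thread register clipboard and write it out
+                // to dst with vector_t-wide stores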
+ for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1) + { +#pragma unroll + for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2) + { + const index_t clipboard_offset = + clipboard_desc.Get1dIndex(iloop_d0, iloop_d1, iloop_d2 * DataPerRead); + + const index_t dst_offset = + DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0, + iloop_d1 * thread_per_d1, + iloop_d2 * thread_per_d2 * DataPerRead); + + *(reinterpret_cast(&p_dst[dst_offset + mDstMyThreadOffset])) = + *(reinterpret_cast(&p_clipboard[clipboard_offset])); + } + } + } + } +}; diff --git a/src/include/blockwise_4d_tensor_op.hip.hpp b/src/include/blockwise_4d_tensor_op.hip.hpp index bd4124de57..eea37a2b2e 100644 --- a/src/include/blockwise_4d_tensor_op.hip.hpp +++ b/src/include/blockwise_4d_tensor_op.hip.hpp @@ -84,7 +84,7 @@ template __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src( SrcDesc, @@ -92,7 +92,7 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds DstDesc, Float* __restrict__ p_dst, SrcOpLengths, - DstFromSrcReorder, + MapDst2Src, F f) { constexpr auto I0 = Number<0>{}; @@ -100,10 +100,10 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0); - constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1); - constexpr index_t IR2 = DstFromSrcReorder{}.Get(I2); - constexpr index_t IR3 = DstFromSrcReorder{}.Get(I3); + constexpr index_t IR0 = MapDst2Src{}.Get(I0); + constexpr index_t IR1 = MapDst2Src{}.Get(I1); + constexpr index_t IR2 = MapDst2Src{}.Get(I2); + constexpr index_t IR3 = MapDst2Src{}.Get(I3); constexpr auto src_desc = SrcDesc{}; constexpr auto dst_desc = DstDesc{}; @@ -184,19 +184,19 @@ template + class MapDst2Src> __device__ void blockwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc, const Float* __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths, - DstFromSrcReorder) + MapDst2Src) { auto f_copy = [](const Float& src, Float& dst) { dst = src; }; blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src( - SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy); + SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy); } template {}); @@ -481,7 +481,7 @@ struct Blockwise4dTensorCopy3 // we allow out-of-bound read from src in D3 dimension, // but we need to make sure dst stride is big enough, // so that the out-of-bound write won't contaminate next line in dst - constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead); + constexpr index_t nloop_d3 = mod_conv::integer_divide_ceil(L3, thread_per_d3 * DataPerRead); static_assert(nloop_d3 * thread_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2), "wrong! 
out-of-bound write will contaminate next line!\n"); @@ -548,7 +548,7 @@ struct Blockwise4dTensorCopy3 constexpr index_t nloop_d0 = L0 / thread_per_d0; constexpr index_t nloop_d1 = L1 / thread_per_d1; constexpr index_t nloop_d2 = L2 / thread_per_d2; - constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead); + constexpr index_t nloop_d3 = mod_conv::integer_divide_ceil(L3, thread_per_d3 * DataPerRead); #pragma unroll for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0) @@ -605,7 +605,7 @@ struct Blockwise4dTensorCopy3 constexpr index_t nloop_d0 = L0 / thread_per_d0; constexpr index_t nloop_d1 = L1 / thread_per_d1; constexpr index_t nloop_d2 = L2 / thread_per_d2; - constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead); + constexpr index_t nloop_d3 = mod_conv::integer_divide_ceil(L3, thread_per_d3 * DataPerRead); return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2 * nloop_d3; } @@ -642,7 +642,7 @@ struct Blockwise4dTensorCopy3 constexpr index_t nloop_d0 = L0 / thread_per_d0; constexpr index_t nloop_d1 = L1 / thread_per_d1; constexpr index_t nloop_d2 = L2 / thread_per_d2; - constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead); + constexpr index_t nloop_d3 = mod_conv::integer_divide_ceil(L3, thread_per_d3 * DataPerRead); constexpr auto clipboard_desc = make_ConstantTensorDescriptor( Sequence{}); @@ -709,7 +709,7 @@ struct Blockwise4dTensorCopy3 constexpr index_t nloop_d0 = L0 / thread_per_d0; constexpr index_t nloop_d1 = L1 / thread_per_d1; constexpr index_t nloop_d2 = L2 / thread_per_d2; - constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead); + constexpr index_t nloop_d3 = mod_conv::integer_divide_ceil(L3, thread_per_d3 * DataPerRead); constexpr auto clipboard_desc = make_ConstantTensorDescriptor( Sequence{}); @@ -749,7 +749,7 @@ template + class MapDst2Src> struct Blockwise4dTensorCopyReorder1 { __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const @@ -757,60 +757,104 @@ struct Blockwise4dTensorCopyReorder1 auto f_copy = [](const Float& src, Float& dst) { dst = src; }; blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src( - SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy); + SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy); } }; -#if 0 +#if 1 template + class SrcClusterLengths, + class MapDst2Src, + class MapThreadCluster2SrcCluster, + index_t SrcDataPerRead, + index_t DstDataPerWrite> struct Blockwise4dTensorCopyReorder3 { + static constexpr index_t nDim = SrcLengths::GetSize(); + index_t mSrcMyThreadOffset; index_t mDstMyThreadOffset; __device__ Blockwise4dTensorCopyReorder3() { - constexpr index_t nDim = SrcDesc{}.GetDimension(); + constexpr auto src_desc = SrcDesc{}; + constexpr auto dst_desc = DstDesc{}; - static_assert(DstDesc{}.GetDimension() == nDim && SrcOpLengths::nDim == nDim && - SrcOpThreadPerDims::nDim == nDim && DstFromSrcReorder::nDim == nDim, - "wrong! nDim is not consistent\n"); + constexpr auto src_lengths = SrcLengths{}; - // Src - static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4, - "wrong! only support DataPerRead == 1, 2 or 4!\n"); + constexpr auto map_dst2src = MapDst2Src{}; - static_assert(DataPerRead == 1 || SrcDesc{}.GetStride(Number{}) == 1, - "wrong! 
only support src.stride(nDim-1) == 1 if DataPerRead > 1!\n"); + constexpr auto src_sub_lengths = SrcSubLengths{}; + constexpr auto dst_sub_lengths = src_sub_lengths.ReorderGivenNew2Old(map_dst2src); - static_assert( - SrcDesc{}.GetStride(Number{}) % DataPerRead == 0, - "wrong! src.stride(nDim-2) should be multiple of DataPerRead to keep alignment"); + constexpr auto map_thread_cluster_2_src_cluster = MapThreadCluster2SrcCluster{}; - static_assert(SrcSubLengths{}.Get(Number{}) % DataPerRead == 0, "wrong! SrcSubLengths[nDim-1] % DataPerRead != 0\n"); + constexpr auto src_cluster_lengths = SrcClusterLengths{}; + constexpr auto thread_cluster_lengths = + src_cluster_lengths.ReorderGivenNew2Old(map_thread_cluster_2_src_cluster); - static_loop([](auto I){ - constexpr index_t src_len = SrcLengths{}.Get(I); - constexpr index_t src_sub_len = SrcSubLengths{}.Get(I); - constexpr index_t thread_per_dim = SrcThreadPerDims{}.Get(I); - static_assert(src_len % (src_sub_len * thread_per_dim) == 0, - "wrong! cannot evenly divide tensor lengths"); - }); + constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor(thread_cluster_lengths); - constexpr index_t num_active_thread = accumulate_on_sequence(SrcOpThreadPerDims{}, mod_conv::multiplies{}, Number<1>{}); + // sanity check: data type + static_assert(is_same::value, "wrong! only support float for now!\n"); + + // sanity check: nDim + static_assert(SrcDesc::GetDimension() == nDim && DstDesc::GetDimension() == nDim && + SrcLengths::GetSize() == nDim && SrcSubLengths::GetSize() == nDim && + SrcClusterLengths::GetSize() == nDim && MapDst2Src::GetSize() == nDim && + MapThreadCluster2SrcCluster::GetSize() == nDim, + "wrong! nDim is not consistent\n"); + + // sanity check: BlockSize + constexpr index_t num_active_thread = thread_cluster_desc.GetElementSize(); static_assert(BlockSize >= num_active_thread, "wrong! BlockSize is not big enough for ThreadPerDims!"); + // sanity check: work division + static_for<0, nDim, 1>{}([](auto IDim) { + constexpr auto I = decltype(IDim){}; + constexpr index_t src_len = src_lengths.Get(I); + constexpr index_t src_sub_len = src_sub_lengths.Get(I); + constexpr index_t src_cluster_len = src_cluster_lengths.Get(I); + static_assert(src_len % (src_sub_len * src_cluster_len) == 0, + "wrong! cannot evenly divide Src tensor lengths"); + }); + + // sanity check: src read + static_assert(SrcDataPerRead == 1 || SrcDataPerRead == 2 || SrcDataPerRead == 4, + "wrong! only support SrcDataPerRead == 1, 2 or 4!\n"); + + static_assert(SrcDataPerRead == 1 || src_desc.GetStride(Number{}) == 1, + "wrong! only support src.stride(nDim-1) == 1 if SrcDataPerRead > 1!\n"); + + static_assert(src_sub_lengths.Get(Number{}) % SrcDataPerRead == 0, + "wrong! src_sub_lengths[nDim-1] % SrcDataPerRead != 0\n"); + + static_assert(src_desc.GetStride(Number{}) % SrcDataPerRead == 0, + "wrong! should satisfy src_desc.stride(nDim-2) % SrcDataPerRead == 0, to " + "keep alignment"); + + // sanity check: dst write + static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4, + "wrong! only support DstDataPerWrite == 1, 2 or 4!\n"); + + static_assert(DstDataPerWrite == 1 || dst_desc.GetStride(Number{}) == 1, + "wrong! only support dst.stride(nDim-1) == 1 if DstDataPerWrite > 1!\n"); + + static_assert(dst_sub_lengths.Get(Number{}) % DstDataPerWrite == 0, + "wrong! dst_sub_lengths[nDim-1] % DstDataPerWrite != 0\n"); + + static_assert(dst_desc.GetStride(Number{}) % DstDataPerWrite == 0, + "wrong! 
should satisfy dst_desc.stride(nDim-2) % DstDataPerWrite == 0, to "
+                      "keep alignment");
+
+        // start dividing work
+        if(BlockSize > num_active_thread)
+        {
+            if(get_thread_local_1d_id() >= num_active_thread)
+            {
@@ -819,37 +863,251 @@ struct Blockwise4dTensorCopyReorder3
             }
         }

-        const auto thread_multi_id = SrcOpThreadPerDims::GetMultiIndex(get_thread_local_1d_id());
+        const auto thread_multi_id = thread_cluster_desc.GetMultiIndex(get_thread_local_1d_id());
+        // compiler: thread_multi_id, src_data_multi_id, dst_data_multi_id, will use separate
+        // registers, or only one copy???
+        auto src_data_multi_id =
+            reorder_array_given_old2new(thread_multi_id, map_thread_cluster_2_src_cluster);
-
-        const index_t thread_id_d0 =
-            get_thread_local_1d_id() / (thread_per_d1 * thread_per_d2 * thread_per_d3);
-        index_t itmp = get_thread_local_1d_id() -
-                       thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
-        const index_t thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
-        itmp -= thread_id_d1 * (thread_per_d2 * thread_per_d3);
-        const index_t thread_id_d2 = itmp / thread_per_d3;
-        const index_t thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;
+        static_for<0, nDim, 1>{}([&](auto IDim) {
+            constexpr auto I = decltype(IDim){};
+            constexpr index_t i = I.Get();
+            // compiler: will it really compute index here, or be associated with Get1dIndex and
+            // optimized away???
+            src_data_multi_id[i] *= src_sub_lengths.Get(I);
+        });
+        // compiler: will it really compute index here, or be associated with Get1dIndex and
+        // optimized away???
+        const auto dst_data_multi_id = reorder_array_given_new2old(src_data_multi_id, map_dst2src);

-        mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(
-            thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
+        mSrcMyThreadOffset = src_desc.Get1dIndex(src_data_multi_id);
+        mDstMyThreadOffset = dst_desc.Get1dIndex(dst_data_multi_id);
+
+#if 0
+        if(get_block_1d_id() == 0)
+        {
+            printf("tid %5u, "
+                   "thread_multi_id %5u %5u %5u %5u, "
+                   "src_data_multi_id %5u %5u %5u %5u, "
+                   "dst_data_multi_id %5u %5u %5u %5u, "
+                   "mSrcMyThreadOffset %u, mDstMyThreadOffset %u\n",
+                   get_thread_local_1d_id(),
+                   thread_multi_id[0],
+                   thread_multi_id[1],
+                   thread_multi_id[2],
+                   thread_multi_id[3],
+                   src_data_multi_id[0],
+                   src_data_multi_id[1],
+                   src_data_multi_id[2],
+                   src_data_multi_id[3],
+                   dst_data_multi_id[0],
+                   dst_data_multi_id[1],
+                   dst_data_multi_id[2],
+                   dst_data_multi_id[3],
+                   mSrcMyThreadOffset,
+                   mDstMyThreadOffset);
+        }
+#endif
     }

     __device__ static constexpr index_t GetRegisterClipboardSize()
     {
-        static_assert(is_same<Float, float>::value, "wrong!
only support float!\n"); + constexpr auto thread_sub_tensor_lengths = SrcSubLengths{}; + + constexpr auto src_data_per_cluster_per_dims = transform_sequences( + mod_conv::multiplies{}, thread_sub_tensor_lengths, SrcClusterLengths{}); + + constexpr auto cluster_per_dims = + transform_sequences(mod_conv::integer_divide_ceiler{}, + SrcLengths{}, + src_data_per_cluster_per_dims); + + constexpr auto thread_tensor_lengths = transform_sequences( + mod_conv::multiplies{}, thread_sub_tensor_lengths, cluster_per_dims); + + constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths); + + return thread_tensor_desc.GetElementSpace(); } __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src, Float* __restrict__ p_clipboard) const { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto thread_sub_tensor_lengths = SrcSubLengths{}; + + constexpr auto src_data_per_cluster_per_dims = transform_sequences( + mod_conv::multiplies{}, thread_sub_tensor_lengths, SrcClusterLengths{}); + + constexpr auto cluster_per_dims = + transform_sequences(mod_conv::integer_divide_ceiler{}, + SrcLengths{}, + src_data_per_cluster_per_dims); + + constexpr auto thread_tensor_lengths = transform_sequences( + mod_conv::multiplies{}, thread_sub_tensor_lengths, cluster_per_dims); + + constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths); + + constexpr auto thread_sub_tensor_desc = + make_ConstantTensorDescriptor(SrcClusterLengths{}, thread_tensor_desc.GetStrides()); + + for(index_t icluster_d0 = 0; icluster_d0 < cluster_per_dims.Get(I0); ++icluster_d0) + { + for(index_t icluster_d1 = 0; icluster_d1 < cluster_per_dims.Get(I1); ++icluster_d1) + { + for(index_t icluster_d2 = 0; icluster_d2 < cluster_per_dims.Get(I2); ++icluster_d2) + { + for(index_t icluster_d3 = 0; icluster_d3 < cluster_per_dims.Get(I3); + ++icluster_d3) + { + const index_t src_offset = SrcDesc{}.Get1dIndex( + icluster_d0 * src_data_per_cluster_per_dims.Get(I0), + icluster_d1 * src_data_per_cluster_per_dims.Get(I1), + icluster_d2 * src_data_per_cluster_per_dims.Get(I2), + icluster_d3 * src_data_per_cluster_per_dims.Get(I3)); + + const index_t clipboard_offset = thread_tensor_desc.Get1dIndex( + icluster_d0 * thread_sub_tensor_lengths.Get(I0), + icluster_d1 * thread_sub_tensor_lengths.Get(I1), + icluster_d2 * thread_sub_tensor_lengths.Get(I2), + icluster_d3 * thread_sub_tensor_lengths.Get(I3)); + + threadwise_4d_tensor_copy_v2(SrcDesc{}, + p_src + src_offset + mSrcMyThreadOffset, + thread_tensor_desc, + p_clipboard + clipboard_offset, + thread_sub_tensor_lengths, + Number{}); + } + } + } + } + +#if 0 + if(get_block_1d_id() == 0) + { + printf("tid %5u, " + "data: %f %f %f %f %f %f %f %f\n", + get_thread_local_1d_id(), + p_clipboard[0], + p_clipboard[1], + p_clipboard[2], + p_clipboard[3], + p_clipboard[4], + p_clipboard[5], + p_clipboard[6], + p_clipboard[7]); + } +#endif } __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard, Float* __restrict__ p_dst) const { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto thread_sub_tensor_lengths = SrcSubLengths{}; + + constexpr auto src_data_per_cluster_per_dims = transform_sequences( + mod_conv::multiplies{}, thread_sub_tensor_lengths, SrcClusterLengths{}); + + constexpr auto cluster_per_dims = + 
transform_sequences(mod_conv::integer_divide_ceiler{}, + SrcLengths{}, + src_data_per_cluster_per_dims); + + constexpr auto thread_tensor_lengths = transform_sequences( + mod_conv::multiplies{}, thread_sub_tensor_lengths, cluster_per_dims); + + constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths); + + constexpr auto thread_sub_tensor_desc = + make_ConstantTensorDescriptor(SrcClusterLengths{}, thread_tensor_desc.GetStrides()); + + for(index_t icluster_d0 = 0; icluster_d0 < cluster_per_dims.Get(I0); ++icluster_d0) + { + for(index_t icluster_d1 = 0; icluster_d1 < cluster_per_dims.Get(I1); ++icluster_d1) + { + for(index_t icluster_d2 = 0; icluster_d2 < cluster_per_dims.Get(I2); ++icluster_d2) + { + for(index_t icluster_d3 = 0; icluster_d3 < cluster_per_dims.Get(I3); + ++icluster_d3) + { + const index_t clipboard_offset = thread_tensor_desc.Get1dIndex( + icluster_d0 * thread_sub_tensor_lengths.Get(I0), + icluster_d1 * thread_sub_tensor_lengths.Get(I1), + icluster_d2 * thread_sub_tensor_lengths.Get(I2), + icluster_d3 * thread_sub_tensor_lengths.Get(I3)); + + const auto dst_multi_id = reorder_array_given_new2old( + Array{ + icluster_d0 * src_data_per_cluster_per_dims.Get(I0), + icluster_d1 * src_data_per_cluster_per_dims.Get(I1), + icluster_d2 * src_data_per_cluster_per_dims.Get(I2), + icluster_d3 * src_data_per_cluster_per_dims.Get(I3)}, + MapDst2Src{}); + + const index_t dst_offset = DstDesc{}.Get1dIndex(dst_multi_id); + +#if 0 + if(get_block_1d_id() == 0) + { + printf("tid %5u, " + "clipboard_offsetm %5u, dst_offset %5u\n", + get_thread_local_1d_id(), + clipboard_offset, + dst_offset); + } +#endif + +#if 1 + threadwise_4d_tensor_copy_reorder_by_get_dst_from_src( + thread_tensor_desc, + p_clipboard + clipboard_offset, + DstDesc{}, + p_dst + dst_offset + mDstMyThreadOffset, + thread_sub_tensor_lengths, + MapDst2Src{}); +#endif + } + } + } + } + +#if 0 + if(get_block_1d_id() == 0) + { + printf("tid %5u, " + "data: %f %f %f %f %f %f %f %f\n", + get_thread_local_1d_id(), + p_clipboard[0], + p_clipboard[1], + p_clipboard[2], + p_clipboard[3], + p_clipboard[4], + p_clipboard[5], + p_clipboard[6], + p_clipboard[7]); + } +#endif + } + + __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const + { + Float p_clipboard[GetRegisterClipboardSize()]; + + RunLoadRegisterClipboard(p_src, p_clipboard); + RunStoreRegisterClipboard(p_clipboard, p_dst); } }; #endif diff --git a/src/include/common.hip.hpp b/src/include/common.hip.hpp index feb9060be7..a3d5596c7f 100644 --- a/src/include/common.hip.hpp +++ b/src/include/common.hip.hpp @@ -25,12 +25,38 @@ struct is_same static const bool value = true; }; -__host__ __device__ constexpr index_t integer_divide_ceil(index_t a, index_t b) +namespace mod_conv { // namespace mod_conv +template +struct multiplies { + __host__ __device__ constexpr T operator()(T a, T b) const { return a * b; } +}; + +template +struct plus +{ + __host__ __device__ constexpr T operator()(T a, T b) const { return a + b; } +}; + +template +struct integer_divide_ceiler +{ + __host__ __device__ constexpr T operator()(T a, T b) const + { + static_assert(is_same::value || is_same::value, "wrong type"); + + return (a + b - 1) / b; + } +}; + +template +__host__ __device__ constexpr T integer_divide_ceil(T a, T b) +{ + static_assert(is_same::value || is_same::value, "wrong type"); + return (a + b - 1) / b; } -namespace mod_conv { // namespace mod_conv template __host__ __device__ constexpr T max(T x, T y) { diff --git 
a/src/include/functional.hip.hpp b/src/include/functional.hip.hpp index 2cb91c1922..3db890f46b 100644 --- a/src/include/functional.hip.hpp +++ b/src/include/functional.hip.hpp @@ -70,18 +70,3 @@ __host__ __device__ constexpr auto unpacker(F f) return [=](auto xs_array){ f(xs...); }; } #endif - -namespace mod_conv { -template -struct multiplies -{ - __host__ __device__ constexpr T operator()(T a, T b) const { return a * b; } -}; - -template -struct plus -{ - __host__ __device__ constexpr T operator()(T a, T b) const { return a + b; } -}; - -} // namespace mod_conv diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp index 454ed30392..0c7a455fc9 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp @@ -248,42 +248,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn } } -// output: register to global mem, -#if 0 - const auto c_thread_mtx_begin = - blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k) - { - for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho) - { - for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo) - { - for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n) - { - const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n); - - const auto c_thread_mtx_distance = - blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b); - - const index_t ho_thread = - c_thread_mtx_begin.batch + c_thread_mtx_distance.batch; - const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row; - const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col; - - const index_t wo_thread = b_thread / NPerBlock; - const index_t n_thread = b_thread % NPerBlock; - - p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread, - ho_block_data_begin + ho_thread, - wo_block_data_begin + wo_thread, - n_block_data_begin + n_thread)] = - p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)]; - } - } - } - } -#elif 1 + // output: register to global mem, const auto c_thread_mtx_begin = blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); @@ -331,6 +296,5 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn n_block_data_begin + n_thread_data_begin), out_10d_thread_desc.GetLengths(), Number{}); -#endif } }; diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp new file mode 100644 index 0000000000..587435ad99 --- /dev/null +++ b/src/include/gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp @@ -0,0 +1,362 @@ +#pragma once +#include "common.hip.hpp" +#include "ConstantTensorDescriptor.hip.hpp" +#include "ConstantMatrixDescriptor.hip.hpp" +#include "blockwise_2d_tensor_op.hip.hpp" +#include "blockwise_3d_tensor_op.hip.hpp" +#include "blockwise_4d_tensor_op.hip.hpp" +#include "threadwise_nd_tensor_op.hip.hpp" +#include "threadwise_4d_tensor_op.hip.hpp" +#include "blockwise_batched_gemm.hip.hpp" + +template +struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn +{ + __device__ void Run(const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) const + { + // be careful of 
+        // be careful of this assertion
+        static_assert(
+            NPerThread <= NPerBlock && NPerBlock % NPerThread == 0,
+            "wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0");
+
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+
+        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
+        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
+        constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};
+
+        constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
+
+        constexpr index_t K  = out_k_h_w_n_global_desc.GetLength(I0);
+        constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
+        constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
+        constexpr index_t N  = out_k_h_w_n_global_desc.GetLength(I3);
+
+        constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
+        constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
+
+        constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
+        constexpr index_t WiPerBlock = WoPerBlock + X - 1;
+
+        // divide block work: [K, Ho, Wo, N]
+        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
+                          Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
+                      "wrong! cannot evenly divide work among workgroups");
+
+        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
+        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
+        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
+        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
+
+        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
+        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
+        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
+        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
+        const index_t w_block_work_id = itmp / NBlockWork;
+        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
+
+        const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
+        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
+        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
+        const index_t n_block_data_begin  = n_block_work_id * NPerBlock;
+
+        const index_t hi_block_data_begin = ho_block_data_begin;
+        const index_t wi_block_data_begin = wo_block_data_begin;
+
+        // global tensor view
+        constexpr auto wei_c_x_k_global_desc =
+            make_ConstantTensorDescriptor(Sequence{}, Sequence{});
+
+        // LDS tensor view
+        // be careful of alignment
+        constexpr index_t max_align = mod_conv::max(
+            InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB);
+
+        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
+            Sequence{}, Number{});
+
+        constexpr auto wei_c_x_k_block_desc = make_ConstantTensorDescriptor_aligned(
+            Sequence{}, Number{});
+
+        // tensor view of threadwise output in register
+        constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
+            Sequence{});
+
+        // blockwise copy
+        // input: reorder from [N, C, Hi, Wi] to [C, Hi, Wi, N]
+        auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};
+#if 0
+        const auto blockwise_in_copy_reorder =
+            Blockwise4dTensorCopyReorder1,
+                                          decltype(map_chwn2nchw)>{};
+#else
+        auto map_thread_cluster_2_src_cluster = Sequence<1, 2, 0, 3>{};
+
+        const auto blockwise_in_copy_reorder =
+            Blockwise4dTensorCopyReorder3,
+                                          Sequence<4, 1, 1, 2>,
+                                          Sequence<4, 8, 2, 2>,
+                                          decltype(map_chwn2nchw),
+                                          decltype(map_thread_cluster_2_src_cluster),
+                                          2,
+                                          4>{};
+
+#if 0
+        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+        {
+            printf("size %u\n", blockwise_in_copy_reorder.GetRegisterClipboardSize());
+        }
+#endif
+#endif
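For illustration only (not part of the patch): the divide/modulo chain above unflattens get_block_1d_id() into [K, Ho, Wo, N] block-work coordinates. The same decode as a host-testable sketch; the work counts and the helper name decode_block_id are hypothetical:

#include <cassert>

struct BlockWorkId { unsigned k, h, w, n; };

// id = ((k * HBlockWork + h) * WBlockWork + w) * NBlockWork + n, inverted step by step
BlockWorkId decode_block_id(unsigned id, unsigned HBlockWork, unsigned WBlockWork, unsigned NBlockWork)
{
    const unsigned k = id / (HBlockWork * WBlockWork * NBlockWork);
    unsigned itmp    = id - k * (HBlockWork * WBlockWork * NBlockWork);
    const unsigned h = itmp / (WBlockWork * NBlockWork);
    itmp -= h * (WBlockWork * NBlockWork);
    const unsigned w = itmp / NBlockWork;
    const unsigned n = itmp - w * NBlockWork;
    return {k, h, w, n};
}

int main()
{
    // e.g. 2 x 3 x 4 x 5 blocks of work: block 47 is (k, h, w, n) = (0, 2, 1, 2)
    const BlockWorkId b = decode_block_id(47, 3, 4, 5);
    assert(b.k == 0 && b.h == 2 && b.w == 1 && b.n == 2);
    return 0;
}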
+
+        // blockwise wei copy
+        // format is [CPerBlock, X * KPerBlock]
+        const auto blockwise_wei_copy =
+#if 0
+            Blockwise3dTensorCopy1{};
+#else
+            Blockwise3dTensorCopy3,
+                                   WeiBlockCopyDataPerRead>{};
+#endif
+
+        // a series of blockwise batched GEMMs
+        // C_matrix += transpose(A_matrix) * B_matrix
+        // A_matrix and B_matrix saved in LDS, C_matrix saved in register
+        // A_matrix[C,K] is a sub-matrix of wei_block[C,K]
+        // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
+        // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
+        constexpr auto a_c_k_block_mtx_desc =
+            make_ConstantMatrixDescriptor(Number{},
+                                          Number{},
+                                          Number{});
+
+        constexpr auto b_c_wn_block_mtx_desc =
+            make_ConstantMatrixDescriptor(Number{},
+                                          Number{},
+                                          Number{});
+
+        constexpr auto c_k_wn_thread_mtx_desc =
+            make_ConstantMatrixDescriptor(Number{},
+                                          Number{},
+                                          Number{});
+
+        const auto blockwise_batch_gemm =
+            BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
+                BlockSize,
+                decltype(a_c_k_block_mtx_desc),
+                decltype(b_c_wn_block_mtx_desc),
+                decltype(c_k_wn_thread_mtx_desc),
+                0,
+                in_c_h_w_n_block_desc.GetStride(I1),
+                out_k_h_w_n_thread_desc.GetStride(I1),
+                HoPerBlock,
+                GemmMPerThreadSubC,
+                GemmNPerThreadSubC,
+                GemmMLevel0Cluster,
+                GemmNLevel0Cluster,
+                GemmMLevel1Cluster,
+                GemmNLevel1Cluster,
+                GemmKPerThreadLoop,
+                HoPerThread,
+                GemmDataPerReadA,
+                GemmDataPerReadB>{};
+
+        // LDS: be careful of alignment
+        constexpr index_t in_block_space =
+            in_c_h_w_n_block_desc.GetElementSpace(Number{});
+        constexpr index_t wei_block_space =
+            wei_c_x_k_block_desc.GetElementSpace(Number{});
+
+        __shared__ Float p_in_block[in_block_space];
+        __shared__ Float p_wei_block[wei_block_space];
+
+        // register
+        Float p_out_thread[out_k_h_w_n_thread_desc.GetElementSpace()];
+
+#if 0
+        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+        {
+            print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
+            print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
+
+            print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
+            print_ConstantTensorDescriptor(wei_c_x_k_block_desc, "wei_c_x_k_block_desc");
+
+            printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
+        }
+#endif
+
+        // set threadwise output tensor to 0
+        threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
+
+        const Float* p_in_global_block_offset =
+            p_in_global + in_n_c_h_w_global_desc.Get1dIndex(
+                              n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);
+
+        const Float* p_wei_global_block_offset =
+            p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
+
+        for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
+                    p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
+                    p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
+        {
+            for(index_t y = 0; y < Y; ++y)
+            {
+                blockwise_in_copy_reorder.Run(p_in_global_block_offset +
+                                                  in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, 0),
+                                              p_in_block);
+
+                blockwise_wei_copy.Run(p_wei_global_block_offset +
+                                           wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, 0),
+                                       p_wei_block);
+
+                __syncthreads();
+
+                for(index_t x = 0; x < X; ++x)
+                {
+                    blockwise_batch_gemm.Run(p_wei_block +
+                                                 wei_c_x_k_block_desc.Get1dIndex(0, x, 0),
+                                             p_in_block +
+                                                 in_c_h_w_n_block_desc.Get1dIndex(0, 0, x, 0),
+                                             p_out_thread);
+                }
+
+                __syncthreads();
+            }
+        }
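For illustration only (not part of the patch): per (y, x) filter tap, blockwise_batch_gemm accumulates, for every ho slice of the batch, C[k][wn] += sum over c of A[c][k] * B[c][wn], i.e. C += transpose(A) * B as the comment block above states. A naive host reference of one such slice; dimensions and the function name are illustrative:

#include <vector>

// A stored [C][K], B stored [C][WN], C stored [K][WN], matching the
// A_matrix/B_matrix/C_matrix roles described above.
void gemm_transA_reference(const std::vector<std::vector<float>>& A,
                           const std::vector<std::vector<float>>& B,
                           std::vector<std::vector<float>>& C)
{
    const std::size_t Cdim = A.size(), K = A[0].size(), WN = B[0].size();
    for(std::size_t c = 0; c < Cdim; ++c)          // reduction over input channels
        for(std::size_t k = 0; k < K; ++k)         // rows of C, columns of A
            for(std::size_t wn = 0; wn < WN; ++wn) // columns of C and B
                C[k][wn] += A[c][k] * B[c][wn];    // C += A^T * B
}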
+
+// output: register to global mem
+#if 0
+        const auto c_thread_mtx_begin =
+            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+
+        for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
+        {
+            for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
+            {
+                for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
+                {
+                    for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
+                    {
+                        const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
+
+                        const auto c_thread_mtx_distance =
+                            blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);
+
+                        const index_t ho_thread =
+                            c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
+                        const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
+                        const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
+
+                        const index_t wo_thread = b_thread / NPerBlock;
+                        const index_t n_thread = b_thread % NPerBlock;
+
+                        p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
+                                                                     ho_block_data_begin + ho_thread,
+                                                                     wo_block_data_begin + wo_thread,
+                                                                     n_block_data_begin + n_thread)] =
+                            p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
+                    }
+                }
+            }
+        }
+#elif 1
+        const auto c_thread_mtx_begin =
+            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+
+        const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
+        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
+        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
+        const index_t n_thread_data_begin =
+            c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
+
+        // output is a 10d tensor
+        constexpr index_t N2 = GemmNPerThreadSubC;
+        constexpr index_t N1 = NPerBlock / N2;
+
+        constexpr index_t W2 =
+            (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
+        constexpr index_t W1 = WoPerBlock / W2;
+
+        constexpr index_t K2 = GemmMPerThreadSubC;
+        constexpr index_t K1 = KPerBlock / KPerThread;
+
+        constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
+            Sequence{});
+
+        constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
+            Sequence{});
+
+#if 0
+        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+        {
+            print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
+            print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
+
+            print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
+            print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
+        }
+#endif
+
+        threadwise_10d_tensor_copy(out_10d_thread_desc,
+                                   p_out_thread,
+                                   out_10d_global_desc,
+                                   p_out_global + out_k_h_w_n_global_desc.Get1dIndex(
+                                                      k_block_data_begin + k_thread_data_begin,
+                                                      ho_block_data_begin + ho_thread_data_begin,
+                                                      wo_block_data_begin + wo_thread_data_begin,
+                                                      n_block_data_begin + n_thread_data_begin),
+                                   out_10d_thread_desc.GetLengths(),
+                                   Number{});
+#endif
+    }
+};
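For illustration only (not part of the patch): the 10d view above works because the block tile factors exactly, NPerBlock = N1 * N2 and WoPerBlock = W1 * W2, so each thread's C sub-tiles map onto one strided copy. Worked numbers with hypothetical parameters:

#include <cstdio>

int main()
{
    // hypothetical tile parameters (any real config must keep the divisions exact)
    constexpr unsigned NPerBlock = 16, GemmNPerThreadSubC = 4;
    constexpr unsigned GemmNLevel0Cluster = 4, GemmNLevel1Cluster = 4, WoPerBlock = 8;

    constexpr unsigned N2 = GemmNPerThreadSubC;                        // 4: contiguous N per thread
    constexpr unsigned N1 = NPerBlock / N2;                            // 4: thread groups along N
    constexpr unsigned W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
                            (NPerBlock / GemmNPerThreadSubC);          // 4: GEMM-N columns left for Wo
    constexpr unsigned W1 = WoPerBlock / W2;                           // 2

    static_assert(N1 * N2 == NPerBlock && W1 * W2 == WoPerBlock,
                  "the 10d view must tile the block exactly");
    std::printf("N = N1 x N2 = %u x %u, Wo = W1 x W2 = %u x %u\n", N1, N2, W1, W2);
    return 0;
}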
diff --git a/src/include/threadwise_2d_tensor_op.hip.hpp b/src/include/threadwise_2d_tensor_op.hip.hpp
index f8b8f722e3..34f34db086 100644
--- a/src/include/threadwise_2d_tensor_op.hip.hpp
+++ b/src/include/threadwise_2d_tensor_op.hip.hpp
@@ -29,26 +29,21 @@ __device__ void threadwise_2d_tensor_pointwise_operation_unary(Desc, Float* __re
 
 // TODO: in order to optimize mem access for different mem type,
 // need to write specialized version
-template
+template
 __device__ void threadwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
     SrcDesc,
     Float* const __restrict__ p_src,
     DstDesc,
     Float* __restrict__ p_dst,
     SrcOpLengths,
-    DstFromSrcReorder,
+    MapDst2Src,
     F f)
 {
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
 
-    constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
-    constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
+    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
+    constexpr index_t IR1 = MapDst2Src{}.Get(I1);
 
     constexpr auto src_desc = SrcDesc{};
     constexpr auto dst_desc = DstDesc{};
@@ -78,19 +73,19 @@ __device__ void threadwise_2d_tensor_set_zero(Desc, Float* __restrict__ p)
         Desc{}, p, f_set_zero);
 }
 
-template
+template
 __device__ void threadwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
                                                                       Float* const __restrict__ p_src,
                                                                       DstDesc,
                                                                       Float* __restrict__ p_dst,
                                                                       SrcOpLengths,
-                                                                      DstFromSrcReorder)
+                                                                      MapDst2Src)
 {
     auto f_copy = [](const Float& src, Float& dst) { dst = src; };
 
     threadwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
-        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
+        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
 }
 
 template
diff --git a/src/include/threadwise_4d_tensor_op.hip.hpp b/src/include/threadwise_4d_tensor_op.hip.hpp
index 21fed4d286..05894d434f 100644
--- a/src/include/threadwise_4d_tensor_op.hip.hpp
+++ b/src/include/threadwise_4d_tensor_op.hip.hpp
@@ -42,7 +42,7 @@ template
 __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
     SrcDesc,
@@ -50,7 +50,7 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d
     DstDesc,
     DstData* __restrict__ p_dst,
     SrcOpLengths,
-    DstFromSrcReorder,
+    MapDst2Src,
     F f)
 {
     constexpr auto I0 = Number<0>{};
@@ -58,10 +58,10 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d
     constexpr auto I2 = Number<2>{};
     constexpr auto I3 = Number<3>{};
 
-    constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
-    constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
-    constexpr index_t IR2 = DstFromSrcReorder{}.Get(I2);
-    constexpr index_t IR3 = DstFromSrcReorder{}.Get(I3);
+    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
+    constexpr index_t IR1 = MapDst2Src{}.Get(I1);
+    constexpr index_t IR2 = MapDst2Src{}.Get(I2);
+    constexpr index_t IR3 = MapDst2Src{}.Get(I3);
 
     constexpr auto src_desc = SrcDesc{};
     constexpr auto dst_desc = DstDesc{};
@@ -82,7 +82,29 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d
                     const index_t bindex =
                         dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
 
+#if 1
                     f(p_src[aindex], p_dst[bindex]);
+#else
+                    if(get_block_1d_id() == 0)
+                    {
+                        printf("tid %5u, "
+                               "src did %u %u %u %u, "
+                               "dst did %u %u %u %u, "
+                               "aindex %5u, "
+                               "bindex %5u\n",
+                               get_thread_local_1d_id(),
+                               did0,
+                               did1,
+                               did2,
+                               did3,
+                               did[IR0],
+                               did[IR1],
+                               did[IR2],
+                               did[IR3],
+                               aindex,
+                               bindex);
+                    }
+#endif
                 }
             }
         }
@@ -103,19 +125,19 @@ template
+          class MapDst2Src>
 __device__ void threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
                                                                       const SrcData* __restrict__ p_src,
                                                                       DstDesc,
                                                                       DstData* __restrict__ p_dst,
                                                                       SrcOpLengths,
-                                                                      DstFromSrcReorder)
+                                                                      MapDst2Src)
 {
     auto f_copy = [](const SrcData& src, DstData& dst) { dst = static_cast<DstData>(src); };
 
     threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
-        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
+        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
 }
 
 template
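For illustration only (not part of the patch): in the reorder routines above, MapDst2Src{}.Get(Ij) names, for destination dimension j, the source loop index that feeds it, exactly as in bindex = dst_desc.Get1dIndex(did[IR0], did[IR1], ...). A 2d host sketch where MapDst2Src = {1, 0} is a transpose:

#include <cassert>

int main()
{
    constexpr int D0 = 2, D1 = 3;
    const int map_dst2src[2] = {1, 0}; // dst dim 0 <- src dim 1, dst dim 1 <- src dim 0
    float src[D0][D1] = {{0, 1, 2}, {3, 4, 5}};
    float dst[D1][D0] = {};

    for(int did0 = 0; did0 < D0; ++did0)
        for(int did1 = 0; did1 < D1; ++did1)
        {
            const int did[2] = {did0, did1};
            // dst coordinate j is did[map_dst2src[j]], mirroring the kernel's bindex
            dst[did[map_dst2src[0]]][did[map_dst2src[1]]] = src[did0][did1];
        }

    assert(dst[2][1] == src[1][2]); // transposed
    return 0;
}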
@@ -137,13 +159,12 @@ __device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
                                              SrcOpLengths,
                                              Number)
 {
-    using Float2 = float2;
-    using Float4 = float4;
-
     static_assert(SrcDesc{}.GetDimension() == 4 && DstDesc{}.GetDimension() == 4 &&
-                      SrcOpLengths::nDim == 4,
+                      SrcOpLengths::GetSize() == 4,
                   "wrong! should be 4 dimension");
 
+    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
+
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
     constexpr auto I2 = Number<2>{};
@@ -183,24 +204,8 @@ __device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
                 const index_t dst_index =
                     dst_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);
 
-                if(DataPerRead == 1)
-                {
-                    p_dst[dst_index] = p_src[src_index];
-                }
-                else if(DataPerRead == 2)
-                {
-                    *(reinterpret_cast<Float2*>(p_dst + dst_index)) =
-                        *(reinterpret_cast<const Float2*>(p_src + src_index));
-                }
-                else if(DataPerRead == 4)
-                {
-                    *(reinterpret_cast<Float4*>(p_dst + dst_index)) =
-                        *(reinterpret_cast<const Float4*>(p_src + src_index));
-                }
-                else
-                {
-                    assert(false);
-                }
+                *(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
+                    *(reinterpret_cast<const vector_t*>(&p_src[src_index]));
             }
         }
     }
diff --git a/src/include/threadwise_nd_tensor_op.hip.hpp b/src/include/threadwise_nd_tensor_op.hip.hpp
index 42e5d1660c..fa8ada1fb1 100644
--- a/src/include/threadwise_nd_tensor_op.hip.hpp
+++ b/src/include/threadwise_nd_tensor_op.hip.hpp
@@ -175,7 +175,7 @@ __device__ void threadwise_10d_tensor_copy(SrcDesc,
     using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
 
     static_assert(SrcDesc{}.GetDimension() == 10 && DstDesc{}.GetDimension() == 10 &&
-                      SrcOpLengths::nDim == 10,
+                      SrcOpLengths::GetSize() == 10,
                   "wrong! should be 10 dimension");
 
     constexpr auto I0 = Number<0>{};
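For illustration only (not part of the patch): the copy_v2 and 10d-copy changes above fold the old DataPerRead == 1/2/4 branch into a single cast through vector_type<Float, DataPerRead>::MemoryType. A hypothetical stand-in for that trait (the real definitions live elsewhere in the repository), showing the pattern:

// Hypothetical trait in the spirit of vector_type<T, N>::MemoryType used above.
template <class T, unsigned N> struct vector_type;
template <> struct vector_type<float, 1> { using MemoryType = float; };
template <> struct vector_type<float, 2> { struct MemoryType { float x, y; }; };
template <> struct vector_type<float, 4> { struct MemoryType { float x, y, z, w; }; };

// One generic copy replaces the old if(DataPerRead == 1/2/4) chain.
template <unsigned DataPerRead>
void copy_elements(const float* src, float* dst)
{
    using vector_t = typename vector_type<float, DataPerRead>::MemoryType;
    *reinterpret_cast<vector_t*>(dst) = *reinterpret_cast<const vector_t*>(src);
}

int main()
{
    float a[4] = {0, 1, 2, 3}, b[4] = {};
    copy_elements<4>(a, b); // moves 4 floats with one 16-byte load/store
    return b[3] == 3 ? 0 : 1;
}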