From 569ad66e2a03789c4a1fa6659dc8296b4dfb868b Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 23 Apr 2019 17:51:14 -0500 Subject: [PATCH] added implicit gemm v1r3 lds_double_buffer NCHW * CYXK = KNHW, reworked static functionals --- ...lution_implicit_gemm_v1_chwn_cyxk_khwn.hpp | 135 ++--- ...lution_implicit_gemm_v1_nchw_cyxk_khwn.hpp | 88 +++- driver/driver.hip.cpp | 48 +- src/include/Array.hip.hpp | 16 +- src/include/ConstantTensorDescriptor.hip.hpp | 92 +--- src/include/Sequence.hip.hpp | 95 +++- src/include/blockwise_4d_tensor_op.hip.hpp | 19 +- src/include/blockwise_batched_gemm.hip.hpp | 20 +- src/include/common.hip.hpp | 1 + src/include/functional.hip.hpp | 66 ++- src/include/functional2.hip.hpp | 117 +++++ ..._implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp | 318 +++++++----- ...1_lds_double_buffer_chwn_cyxk_khwn.hip.hpp | 407 --------------- ..._implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp | 189 +++++-- ..._implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp | 68 ++- ..._implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp | 169 +++++-- ...3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp | 165 ++++-- ...3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp | 472 ++++++++++++++++++ ..._implicit_gemm_v1r3_nchw_cyxk_khwn.hip.hpp | 452 +++++++++++++++++ src/include/threadwise_2d_tensor_op.hip.hpp | 2 + src/include/threadwise_4d_tensor_op.hip.hpp | 2 + src/include/threadwise_nd_tensor_op.hip.hpp | 283 ++--------- 22 files changed, 2117 insertions(+), 1107 deletions(-) create mode 100644 src/include/functional2.hip.hpp delete mode 100644 src/include/gridwise_convolution_implicit_gemm_v1r1_lds_double_buffer_chwn_cyxk_khwn.hip.hpp create mode 100644 src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp create mode 100644 src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_khwn.hip.hpp diff --git a/driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp b/driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp index 613b55a81e..2fc66819d3 100644 --- a/driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp +++ b/driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp @@ -3,7 +3,6 @@ #include "device.hpp" #include "gridwise_convolution_wrapper.hip.hpp" #include "gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp" -#include "gridwise_convolution_implicit_gemm_v1r1_lds_double_buffer_chwn_cyxk_khwn.hip.hpp" #include "gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp" #include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp" #include "gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp" @@ -81,6 +80,8 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, #if 0 // for 3x3, 34x34, v1r1, Pascal + constexpr index_t BlockSize = 128; + constexpr index_t NPerBlock = 16; constexpr index_t KPerBlock = 64; constexpr index_t CPerBlock = 4; @@ -92,14 +93,6 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t HoPerThread = 1; constexpr index_t WoPerThread = 2; - constexpr index_t InBlockCopy_ThreadPerDimC = 4; - constexpr index_t InBlockCopy_ThreadPerDimH = 4; - constexpr index_t InBlockCopy_ThreadPerDimW = 2; - constexpr index_t InBlockCopy_ThreadPerDimN = 4; - constexpr index_t InBlockCopyDataPerRead_N = 4; - - constexpr index_t WeiBlockCopyDataPerRead_K = 4; - constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmMLevel0Cluster = 4; @@ -110,11 +103,16 @@ void 
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadB = 4; - constexpr index_t OutThreadCopyDataPerWrite_N = 2; + using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 4>; + constexpr index_t InBlockCopyDataPerRead_N = 4; - constexpr index_t BlockSize = 128; + constexpr index_t WeiBlockCopyDataPerRead_K = 4; + + constexpr index_t OutThreadCopyDataPerWrite_N = 2; #elif 0 // for 3x3, 34x34, v1r2, Pascal, in-block-copy1 + constexpr index_t BlockSize = 128; + constexpr index_t NPerBlock = 4; constexpr index_t KPerBlock = 64; constexpr index_t CPerBlock = 8; @@ -126,14 +124,6 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t HoPerThread = 1; constexpr index_t WoPerThread = 2; - constexpr index_t InBlockCopy_ThreadPerDimC = 4; - constexpr index_t InBlockCopy_ThreadPerDimH = 4; - constexpr index_t InBlockCopy_ThreadPerDimW = 2; - constexpr index_t InBlockCopy_ThreadPerDimN = 1; - constexpr index_t InBlockCopyDataPerRead_N = 4; - - constexpr index_t WeiBlockCopyDataPerRead_K = 4; - constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmMLevel0Cluster = 4; @@ -144,9 +134,76 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadB = 4; - constexpr index_t OutThreadCopyDataPerWrite_N = 2; + using InBlockCopyClusterLengths_CHWN = Sequence<0, 0, 0, 0>; // not used + constexpr index_t InBlockCopyDataPerRead_N = 4; + constexpr index_t WeiBlockCopyDataPerRead_K = 4; + + constexpr index_t OutThreadCopyDataPerWrite_N = 2; +#elif 1 + // for 3x3, 34x34, v1r3, Pascal + // for 3x3, 28x28, v1r3, Pascal + // for 3x3, 14x14, v1r3, Pascal constexpr index_t BlockSize = 128; + + constexpr index_t NPerBlock = 16; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 8; + constexpr index_t HoPerBlock = 2; + constexpr index_t WoPerBlock = 2; + + constexpr index_t NPerThread = 4; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 2; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 2; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>; + constexpr index_t InBlockCopyDataPerRead_N = 4; + + constexpr index_t WeiBlockCopyDataPerRead_K = 4; + + constexpr index_t OutThreadCopyDataPerWrite_N = 2; +#elif 0 + // for 3x3, 34x34, v1r3, Pascal, bad + constexpr index_t BlockSize = 128; + + constexpr index_t NPerBlock = 1; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 8; + constexpr index_t HoPerBlock = 2; + constexpr index_t WoPerBlock = 32; + + constexpr index_t NPerThread = 1; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 8; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 2; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr 
index_t GemmDataPerReadB = 4; + + using InBlockCopyClusterLengths_CHWN = Sequence<2, 2, 32, 1>; + constexpr index_t InBlockCopyDataPerRead_N = 1; + + constexpr index_t WeiBlockCopyDataPerRead_K = 2; + + constexpr index_t OutThreadCopyDataPerWrite_N = 1; #elif 0 // for 3x3, 34x34, v1r1, Vega 20 constexpr index_t NPerBlock = 16; @@ -309,38 +366,6 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr index_t WeiBlockCopyDataPerRead_K = 4; - constexpr index_t OutThreadCopyDataPerWrite_N = 2; -#elif 1 - // for 3x3, 28x28, v1r3, Pascal - // for 3x3, 14x14, v1r3, Pascal - constexpr index_t BlockSize = 128; - - constexpr index_t NPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 2; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>; - constexpr index_t InBlockCopyDataPerRead_N = 4; - - constexpr index_t WeiBlockCopyDataPerRead_K = 4; - constexpr index_t OutThreadCopyDataPerWrite_N = 2; #elif 0 // for 1x1, 28x28, v1r1, Pascal @@ -419,13 +444,11 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, constexpr auto gridwise_conv = #if 0 GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn -#elif 0 - GridwiseConvolutionImplicitGemm_v1r1_lds_double_buffer_chwn_cyxk_khwn #elif 0 GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn -#elif 0 - GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn #elif 1 + GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn +#elif 0 GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn #endif void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc, @@ -62,7 +64,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc, wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data()); out_khwn_device_buf.ToDevice(out_khwn.mData.data()); -#if 1 +#if 0 // for 3x3, 28x28, v1r2, Pascal constexpr index_t BlockSize = 128; @@ -93,8 +95,78 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc, constexpr index_t InBlockReorderDataPerRead_W = 2; constexpr index_t InBlockReorderDataPerWrite_N = 4; - using WeiBlockCopyClusterLengths_CXK = Sequence<4, 1, 32>; - constexpr index_t WeiBlockCopyDataPerRead_C = 4; + using WeiBlockCopyClusterLengths = Sequence<4, 1, 32>; + constexpr index_t WeiBlockCopyDataPerRead_K = 4; + + constexpr index_t OutThreadCopyDataPerWrite_N = 2; +#elif 0 + // for 3x3, 28x28, v1r3, Pascal, bad + constexpr index_t BlockSize = 128; + + constexpr index_t NPerBlock = 16; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 8; + constexpr index_t HoPerBlock = 2; + constexpr index_t WoPerBlock = 2; + + constexpr index_t NPerThread = 4; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 2; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t 
GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 2; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>; + using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>; + using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>; + constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW + constexpr index_t InBlockReorderDataPerWrite_N = 1; // not used yet + + using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used + constexpr index_t WeiBlockCopyDataPerRead_K = 4; + + constexpr index_t OutThreadCopyDataPerWrite_N = 2; +#elif 1 + // for 3x3, 34x34, v1r3, Pascal + constexpr index_t BlockSize = 128; + + constexpr index_t NPerBlock = 2; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 8; + constexpr index_t HoPerBlock = 2; + constexpr index_t WoPerBlock = 16; + + constexpr index_t NPerThread = 2; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 4; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 2; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockReorderSrcSubLengths_NCHW = Sequence<2, 1, 2, 1>; + using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 1, 16>; + using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>; + constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW + constexpr index_t InBlockReorderDataPerWrite_N = 1; // not used yet + + using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used + constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t OutThreadCopyDataPerWrite_N = 2; #endif @@ -108,8 +180,12 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc, for(index_t i = 0; i < nrepeat; ++i) { constexpr auto gridwise_conv = -#if 1 +#if 0 GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn +#elif 1 + GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn +#elif 1 + GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn #endif {}; float time = launch_kernel(run_gridwise_convolution, diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index acac1d09fa..a272a0e4a2 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -46,7 +46,7 @@ struct GeneratorTensor_3 #if 0 auto f_acc = std::plus{}; #else - auto f_acc = [](auto a, auto b) { return 10 * a + b; }; + auto f_acc = [](auto a, auto b) { return 100 * a + b; }; #endif return std::accumulate(dims.begin(), dims.end(), index_t(0), f_acc); @@ -390,8 +390,6 @@ void host_winograd_3x3_convolution(const Tensor& in_nchw, template void check_error(const Tensor& ref, const Tensor& result) { - // printf("\n"); - float error = 0; float max_diff = -1; float ref_value = 0, result_value = 0; @@ -405,10 +403,7 @@ void check_error(const Tensor& ref, const Tensor& result) ref_value = ref.mData[i]; result_value = result.mData[i]; } - - // printf("{%f, %f}", double(ref.mData[i]), double(result.mData[i])); } - // printf("\n"); std::cout << "error: " << error << std::endl; std::cout << "max_diff: " << max_diff << ", " << 
ref_value << ", " << result_value << std::endl; @@ -416,38 +411,27 @@ void check_error(const Tensor& ref, const Tensor& result) int main(int argc, char* argv[]) { -#if 0 - constexpr index_t N = 128; - constexpr index_t C = 8; - constexpr index_t HI = 28; - constexpr index_t WI = 28; +#if 1 + // 3x3, 34x34 + constexpr index_t N = 64; + constexpr index_t C = 256; + constexpr index_t HI = 34; + constexpr index_t WI = 34; constexpr index_t K = 128; constexpr index_t Y = 3; constexpr index_t X = 3; - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; -#elif 0 - // 3x3, 34x34 - constexpr index_t N = 64; - constexpr index_t C = 256; - constexpr index_t HI = 34; - constexpr index_t WI = 34; - constexpr index_t K = 128; - constexpr index_t Y = 3; - constexpr index_t X = 3; - constexpr index_t HPad = 0; constexpr index_t WPad = 0; #elif 0 // 3x3, 56x56 - constexpr index_t N = 64; - constexpr index_t C = 64; + constexpr index_t N = 64; + constexpr index_t C = 64; constexpr index_t HI = 56; constexpr index_t WI = 56; - constexpr index_t K = 128; - constexpr index_t Y = 3; - constexpr index_t X = 3; + constexpr index_t K = 128; + constexpr index_t Y = 3; + constexpr index_t X = 3; constexpr index_t HPad = 0; constexpr index_t WPad = 0; @@ -499,7 +483,7 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 1; constexpr index_t WPad = 1; -#elif 1 +#elif 0 // 5x5 filter, 20x86 image constexpr index_t N = 16; constexpr index_t C = 256; @@ -547,7 +531,7 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; -#elif 10 +#elif 0 // 1x1 filter, 14x14 image constexpr index_t N = 128; constexpr index_t C = 512; @@ -619,9 +603,9 @@ int main(int argc, char* argv[]) device_direct_convolution_2_nchw_kcyx_nkhw #elif 0 device_direct_convolution_2_vectorized_nchw_kcyx_nkhw -#elif 1 - device_convolution_implicit_gemm_v1_chwn_cyxk_khwn #elif 0 + device_convolution_implicit_gemm_v1_chwn_cyxk_khwn +#elif 1 device_convolution_implicit_gemm_v1_nchw_cyxk_khwn #elif 0 device_convolution_implicit_gemm_v2_chwn_cyxk_khwn diff --git a/src/include/Array.hip.hpp b/src/include/Array.hip.hpp index c386542d22..2801a426f5 100644 --- a/src/include/Array.hip.hpp +++ b/src/include/Array.hip.hpp @@ -19,6 +19,20 @@ struct Array __host__ __device__ const TData& operator[](index_t i) const { return mData[i]; } __host__ __device__ TData& operator[](index_t i) { return mData[i]; } + + __host__ __device__ auto PushBack(TData x) const + { + Array new_array; + + static_for<0, NSize, 1>{}([=](auto I) { + constexpr index_t i = I.Get(); + new_array[i] = mData[i]; + }); + + new_array[NSize] = x; + + return new_array; + } }; template @@ -51,4 +65,4 @@ __host__ __device__ auto reorder_array_given_old2new(const Array& }); return new_array; -} \ No newline at end of file +} diff --git a/src/include/ConstantTensorDescriptor.hip.hpp b/src/include/ConstantTensorDescriptor.hip.hpp index 990d1724f4..5c3f0d8132 100644 --- a/src/include/ConstantTensorDescriptor.hip.hpp +++ b/src/include/ConstantTensorDescriptor.hip.hpp @@ -1,80 +1,30 @@ #pragma once #include "common.hip.hpp" -// this is ugly, only for 2d -template -__host__ __device__ constexpr auto calculate_default_strides(Sequence) +template +__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, RemainLengths) { - return Sequence{}; + constexpr index_t previous_stride = PreviousStrides{}.Front(); + constexpr index_t current_length = RemainLengths{}.Back(); + constexpr index_t current_stride = current_length * previous_stride; + + 
return calculate_default_strides_impl(PreviousStrides{}.PushFront(Number{}), + RemainLengths{}.PopBack()); } -// this is ugly, only for 3d -template -__host__ __device__ constexpr auto calculate_default_strides(Sequence) +template +__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, Sequence) { - return Sequence{}; + constexpr index_t previous_stride = PreviousStrides{}.Front(); + constexpr index_t current_stride = L1 * previous_stride; + + return PreviousStrides{}.PushFront(Number{}); } -// this is ugly, only for 4d -template -__host__ __device__ constexpr auto calculate_default_strides(Sequence) +template +__host__ __device__ constexpr auto calculate_default_strides(Lengths) { - return Sequence{}; -} - -// this is ugly, only for 6d -template -__host__ __device__ constexpr auto calculate_default_strides(Sequence) -{ - return Sequence{}; -} - -// this is ugly, only for 8d -template -__host__ __device__ constexpr auto - calculate_default_strides(Sequence) -{ - return Sequence{}; -} - -// this is ugly, only for 8d -template -__host__ __device__ constexpr auto - calculate_default_strides(Sequence) -{ - return Sequence{}; + return calculate_default_strides_impl(Sequence<1>{}, Lengths{}); } // this is ugly, only for 2d @@ -186,6 +136,14 @@ struct ConstantTensorDescriptor return Get1dIndex(multi_id); } + template + __host__ __device__ static constexpr index_t Get1dIndex(Sequence multi_id) + { + static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent"); + + return Get1dIndex(Is...); + } + __host__ __device__ static Array GetMultiIndex(index_t id) { Array multi_id; diff --git a/src/include/Sequence.hip.hpp b/src/include/Sequence.hip.hpp index 48a86505e7..632626392e 100644 --- a/src/include/Sequence.hip.hpp +++ b/src/include/Sequence.hip.hpp @@ -34,11 +34,15 @@ struct Sequence template __host__ __device__ constexpr auto ReorderGivenOld2New(Sequence /*old2new*/) const { - // don't know how to implement this + // TODO: don't know how to implement this printf("Sequence::ReorderGivenOld2New not implemented"); assert(false); } + __host__ __device__ constexpr index_t Front() const { return mData[0]; } + + __host__ __device__ constexpr index_t Back() const { return mData[mSize - 1]; } + template __host__ __device__ constexpr auto PushFront(Number) const { @@ -69,15 +73,98 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence) return Sequence{}; } -template +#if 0 +// TODO: for some reason, compiler cannot instantiate this template +template __host__ __device__ constexpr auto sequence_pop_back(Sequence) { static_assert(sizeof...(Is) > 0, "empty Sequence!"); return Sequence{}; } +#else +// TODO: delete these very ugly mess +template +__host__ __device__ constexpr auto sequence_pop_back(Sequence) +{ + return Sequence{}; +} + +template +__host__ __device__ constexpr auto sequence_pop_back(Sequence) +{ + return Sequence{}; +} + +template +__host__ __device__ constexpr auto sequence_pop_back(Sequence) +{ + return Sequence{}; +} + +template +__host__ __device__ constexpr auto sequence_pop_back(Sequence) +{ + return Sequence{}; +} + +template +__host__ __device__ constexpr auto sequence_pop_back(Sequence) +{ + return Sequence{}; +} + +template +__host__ __device__ constexpr auto sequence_pop_back(Sequence) +{ + return Sequence{}; +} + +template +__host__ __device__ constexpr auto sequence_pop_back(Sequence) +{ + return Sequence{}; +} + +template +__host__ __device__ constexpr auto sequence_pop_back(Sequence) +{ + return Sequence{}; +} + +template +__host__ 
__device__ constexpr auto + sequence_pop_back(Sequence) +{ + return Sequence{}; +} +#endif #if 1 -// this is ugly, only for 2 sequences +// TODO: fix these mess template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence) { @@ -86,7 +173,6 @@ __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Seq return Sequence{}; } -// this is ugly, only for 3 sequences template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence, Sequence) @@ -98,6 +184,7 @@ transform_sequences(F f, Sequence, Sequence, Sequence) return Sequence{}; } #else +// TODO:: these doesn't compile template struct transform_sequences_impl { diff --git a/src/include/blockwise_4d_tensor_op.hip.hpp b/src/include/blockwise_4d_tensor_op.hip.hpp index e301631e46..e513eb1e81 100644 --- a/src/include/blockwise_4d_tensor_op.hip.hpp +++ b/src/include/blockwise_4d_tensor_op.hip.hpp @@ -1,5 +1,6 @@ #pragma once #include "ConstantTensorDescriptor.hip.hpp" +#include "threadwise_nd_tensor_op.hip.hpp" template __device__ void @@ -957,6 +958,7 @@ struct Blockwise4dTensorCopyReorder3 constexpr auto thread_sub_tensor_desc = make_ConstantTensorDescriptor(SrcClusterLengths{}, thread_tensor_desc.GetStrides()); +#if 1 for(index_t icluster_d0 = 0; icluster_d0 < cluster_per_dims.Get(I0); ++icluster_d0) { for(index_t icluster_d1 = 0; icluster_d1 < cluster_per_dims.Get(I1); ++icluster_d1) @@ -978,16 +980,21 @@ struct Blockwise4dTensorCopyReorder3 icluster_d2 * thread_sub_tensor_lengths.Get(I2), icluster_d3 * thread_sub_tensor_lengths.Get(I3)); - threadwise_4d_tensor_copy_v2(SrcDesc{}, - p_src + src_offset + mSrcMyThreadOffset, - thread_tensor_desc, - p_clipboard + clipboard_offset, - thread_sub_tensor_lengths, - Number{}); + threadwise_nd_tensor_copy(SrcDesc{}, + p_src + src_offset + mSrcMyThreadOffset, + thread_tensor_desc, + p_clipboard + clipboard_offset, + thread_sub_tensor_lengths, + Number{}); } } } } +#else + static_ford{}([=](auto cluster_ids) { + + }); +#endif #if 0 if(get_block_1d_id() == 0) diff --git a/src/include/blockwise_batched_gemm.hip.hpp b/src/include/blockwise_batched_gemm.hip.hpp index 364d3646d0..cdde25efa0 100644 --- a/src/include/blockwise_batched_gemm.hip.hpp +++ b/src/include/blockwise_batched_gemm.hip.hpp @@ -253,9 +253,23 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 #if 0 if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) { - printf("a: %f %f %f %f %f %f %f %f, b: %f %f %f %f %f %f %f %f\n", - p_a_thread[0], p_a_thread[1], p_a_thread[2], p_a_thread[3], p_a_thread[4], p_a_thread[5], p_a_thread[6], p_a_thread[7], - p_b_thread[0], p_b_thread[1], p_b_thread[2], p_b_thread[3], p_b_thread[4], p_b_thread[5], p_b_thread[6], p_b_thread[7]); + printf("a: %f %f %f %f %f %f %f %f, b: %f %f %f %f %f %f %f %f\n", + p_a_thread[0], + p_a_thread[1], + p_a_thread[2], + p_a_thread[3], + p_a_thread[4], + p_a_thread[5], + p_a_thread[6], + p_a_thread[7], + p_b_thread[0], + p_b_thread[1], + p_b_thread[2], + p_b_thread[3], + p_b_thread[4], + p_b_thread[5], + p_b_thread[6], + p_b_thread[7]); } #endif diff --git a/src/include/common.hip.hpp b/src/include/common.hip.hpp index a3d5596c7f..5f99872d8e 100644 --- a/src/include/common.hip.hpp +++ b/src/include/common.hip.hpp @@ -4,6 +4,7 @@ #include "Sequence.hip.hpp" #include "Array.hip.hpp" #include "functional.hip.hpp" +#include "functional2.hip.hpp" #if DEVICE_BACKEND_HIP #include "amd_inline_asm.hip.hpp" diff --git a/src/include/functional.hip.hpp b/src/include/functional.hip.hpp index 
3db890f46b..90c976b2a1 100644 --- a/src/include/functional.hip.hpp +++ b/src/include/functional.hip.hpp @@ -21,7 +21,7 @@ struct static_for_impl template __host__ __device__ void operator()(F) const { - // do nothing + // no work left, just return return; } }; @@ -48,7 +48,7 @@ struct static_const_reduce_n static_assert(NLoop > 1, "out-of-range"); constexpr auto a = f(Number{}); - auto b = static_const_reduce_n{}(f, r); // cannot use constexpr here, weird + auto b = static_const_reduce_n{}(f, r); // TODO: cannot use constexpr here, weird return r(a, b); } }; @@ -70,3 +70,65 @@ __host__ __device__ constexpr auto unpacker(F f) return [=](auto xs_array){ f(xs...); }; } #endif + +struct forwarder +{ + template + __host__ __device__ constexpr T operator()(T&& x) const + { + return std::forward(x); + } +}; + +// Emulate compile time if statement for C++14 +// Get the idea from +// "https://baptiste-wicht.com/posts/2015/07/simulate-static_if-with-c11c14.html" +// TODO: use if constexpr, when C++17 is supported +template +struct static_if +{ +}; + +template <> +struct static_if +{ + using Type = static_if; + + template + __host__ __device__ constexpr auto operator()(F f) const + { + // This is a trick for compiler: + // Pass forwarder to lambda "f" as "auto" argument, and maks sure "f" will use it, + // this will make "f" a generic lambda, so that "f" won't be compiled until here + f(forwarder{}); + return Type{}; + } + + template + __host__ __device__ static constexpr auto else_(F) + { + return Type{}; + } +}; + +template <> +struct static_if +{ + using Type = static_if; + + template + __host__ __device__ constexpr auto operator()(F) const + { + return Type{}; + } + + template + __host__ __device__ static constexpr auto else_(F f) + { + // This is a trick for compiler: + // Pass forwarder to lambda "f" as "auto" argument, and maks sure "f" will use it, + // this will make "f" a generic lambda, so that "f" won't be compiled until here + f(forwarder{}); + return Type{}; + } +}; diff --git a/src/include/functional2.hip.hpp b/src/include/functional2.hip.hpp new file mode 100644 index 0000000000..ac83d8dcf2 --- /dev/null +++ b/src/include/functional2.hip.hpp @@ -0,0 +1,117 @@ +#pragma once +#include "Sequence.hip.hpp" + +template +struct static_ford_impl +{ + // F signature: F(Sequence<...> multi_id) + // CurrentMultiIndex: Sequence<...> + // RemainLengths: Sequence<...> + template + __host__ __device__ void operator()(F f, CurrentMultiIndex, RemainLengths) const + { + static_assert(RemainLengths::GetSize() == RemainDim, "wrong!"); + static_assert(RemainDim > 1, "wrong!"); + + constexpr auto next_length = RemainLengths{}.Front(); + + static_for<0, next_length, 1>{}([=](auto I) { + static_ford_impl{}( + f, CurrentMultiIndex{}.PushBack(I), RemainLengths{}.PopFront()); + }); + } +}; + +template <> +struct static_ford_impl<1> +{ + // F signature: F(Sequence multi_id) + // CurrentMultiIndex: Sequence<...> + // RemainLengths: Sequence<...> + template + __host__ __device__ void operator()(F f, CurrentMultiIndex, RemainLengths) const + { + static_assert(RemainLengths::GetSize() == 1, "wrong!"); + + constexpr index_t last_length = RemainLengths{}.Front(); + + static_for<0, last_length, 1>{}([=](auto I) { f(CurrentMultiIndex{}.PushBack(I)); }); + } +}; + +// Lengths is Sequence<...> +template +struct static_ford +{ + // F signature: F(Sequence multi_id) + template + __host__ __device__ void operator()(F f) const + { + constexpr index_t first_length = Lengths{}.Front(); + + static_for<0, first_length, 
1>{}([=](auto I) { + static_ford_impl{}( + f, Sequence{}, Lengths{}.PopFront()); + }); + } +}; + +template +struct ford_impl +{ + // F signature: F(Array<...> multi_id) + // CurrentMultiIndex: Array<...> + // RemainLengths: Sequence<...> + template + __host__ __device__ void + operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const + { + static_assert(RemainLengths::GetSize() == RemainDim, "wrong!"); + static_assert(RemainDim > 1, "wrong!"); + + constexpr auto next_length = RemainLengths{}.Front(); + + for(index_t i = 0; i < next_length; ++i) + { + ford_impl{}(f, current_multi_id.PushBack(i), RemainLengths{}.PopFront()); + } + } +}; + +template <> +struct ford_impl<1> +{ + // F signature: F(Array<...> multi_id) + // CurrentMultiIndex: Array<...> + // RemainLengths: Sequence<...> + template + __host__ __device__ void + operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const + { + static_assert(RemainLengths::GetSize() == 1, "wrong!"); + + constexpr index_t last_length = RemainLengths{}.Front(); + + for(index_t i = 0; i < last_length; ++i) + { + f(current_multi_id.PushBack(i)); + } + } +}; + +// Lengths is Sequence<...> +template +struct ford +{ + // F signature: F(Array<...> multi_id) + template + __host__ __device__ void operator()(F f) const + { + constexpr index_t first_length = Lengths{}.Front(); + + for(index_t i = 0; i < first_length; ++i) + { + ford_impl{}(f, Array{i}, Lengths{}.PopFront()); + } + } +}; diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp index 2a26255c32..97b2984d1d 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp @@ -32,10 +32,10 @@ template + class InBlockCopyClusterLengths_CHWN, + index_t InBlockCopyDataPerRead_N, + index_t WeiBlockCopyDataPerRead_K, + index_t OutThreadCopyDataPerWrite_N> struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn { __device__ void Run(const Float* const __restrict__ p_in_global, @@ -43,28 +43,30 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn Float* const __restrict__ p_out_global) const { // be careful of this assertion - static_assert( - NPerThread <= NPerBlock && NPerBlock % NPerThread == 0, - "wrong! 
should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0"); + static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock && + NPerBlock % GemmNPerThreadSubC == 0) || + (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock && + GemmNPerThreadSubC % NPerThread == 0), + "wrong!"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - constexpr auto in_chwn_global_desc = InGlobalDesc{}; - constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; - constexpr auto out_khwn_global_desc = OutGlobalDesc{}; + constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{}; + constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{}; + constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{}; - constexpr index_t C = in_chwn_global_desc.GetLength(I0); + constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0); - constexpr index_t K = out_khwn_global_desc.GetLength(I0); - constexpr index_t Ho = out_khwn_global_desc.GetLength(I1); - constexpr index_t Wo = out_khwn_global_desc.GetLength(I2); - constexpr index_t N = out_khwn_global_desc.GetLength(I3); + constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0); + constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1); + constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2); + constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3); - constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1); - constexpr index_t X = wei_cyxk_global_desc.GetLength(I2); + constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1); + constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2); constexpr index_t HiPerBlock = HoPerBlock + Y - 1; constexpr index_t WiPerBlock = WoPerBlock + X - 1; @@ -95,24 +97,35 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn const index_t wi_block_data_begin = wo_block_data_begin; // flattend (2d) tensor view of gridwise weight - constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); + constexpr auto wei_cyx_k_global_desc = + make_ConstantTensorDescriptor(Sequence{}); // tensor view of blockwise input and weight in LDS // be careful of alignment - constexpr index_t max_align = mod_conv::max( - InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB); + constexpr index_t max_align = mod_conv::max(InBlockCopyDataPerRead_N, + WeiBlockCopyDataPerRead_K, + GemmDataPerReadA, + GemmDataPerReadB); - constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); + constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, + Number{}); - constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); + // this check is ad-hoc + // TODO: need to properly implement tensor descriptor with alignment + static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0, + "GemmDataPerReadB alignment requirement is not meet"); - constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); + constexpr auto wei_cyx_k_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, + Number{}); + + constexpr auto wei_c_y_x_k_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, + Number{}); // tensor view of threadwise output in register - constexpr auto out_khwn_thread_desc = make_ConstantTensorDescriptor( + constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor( Sequence{}); // blockwise copy 
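// --- illustrative usage sketch (not part of this patch) --------------------
// The static_ford / ford helpers added in functional2.hip.hpp above walk an
// N-d index space whose lengths are given as a Sequence<...>. static_ford
// unrolls the loop nest at compile time and hands the visitor a Sequence<...>
// multi-index, which the new ConstantTensorDescriptor::Get1dIndex(Sequence<...>)
// overload accepts directly; ford is the run-time counterpart and hands an
// Array<...> instead. The function and buffer names below are made up purely
// for illustration and do not appear in the patch.
template <class ThreadTensorDesc>
__device__ void threadwise_set_zero_sketch(ThreadTensorDesc, float* p_thread)
{
    constexpr auto desc = ThreadTensorDesc{};

    // visit every multi-index of the thread tensor, fully unrolled;
    // multi_id is a Sequence<i0, i1, ...> mapped to a linear offset
    static_ford<decltype(desc.GetLengths())>{}(
        [=](auto multi_id) { p_thread[desc.Get1dIndex(multi_id)] = 0; });
}
// ----------------------------------------------------------------------------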
@@ -121,18 +134,18 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn #if 0 Blockwise4dTensorCopy1{}; + decltype(in_c_h_w_n_global_desc), + decltype(in_c_h_w_n_block_desc), + decltype(in_c_h_w_n_block_desc.GetLengths()), + InBlockCopyDataPerRead_N>{}; #else Blockwise4dTensorCopy3{}; + decltype(in_c_h_w_n_global_desc), + decltype(in_c_h_w_n_block_desc), + decltype(in_c_h_w_n_block_desc.GetLengths()), + InBlockCopyClusterLengths_CHWN, + InBlockCopyDataPerRead_N>{}; #endif // blockwise wei copy @@ -140,10 +153,10 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn const auto blockwise_wei_copy = Blockwise2dTensorCopy3{}; + decltype(wei_cyx_k_global_desc), + decltype(wei_cyx_k_block_desc), + decltype(wei_cyx_k_block_desc.GetLengths()), + WeiBlockCopyDataPerRead_K>{}; // a series of blockwise batched GEMM // C_matrix += transpose(A_matrix) * B_matrix @@ -151,28 +164,30 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn // A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K] // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N] - constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); + constexpr auto a_c_k_block_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); - constexpr auto b_cxwn_block_mtx_desc = + constexpr auto b_c_wn_block_mtx_desc = make_ConstantMatrixDescriptor(Number{}, Number{}, - Number{}); + Number{}); - constexpr auto c_kxwn_thread_mtx_desc = + constexpr auto c_k_wn_thread_mtx_desc = make_ConstantMatrixDescriptor(Number{}, Number{}, - Number{}); + Number{}); const auto blockwise_batch_gemm = BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< BlockSize, - decltype(a_cxk_block_mtx_desc), - decltype(b_cxwn_block_mtx_desc), - decltype(c_kxwn_thread_mtx_desc), + decltype(a_c_k_block_mtx_desc), + decltype(b_c_wn_block_mtx_desc), + decltype(c_k_wn_thread_mtx_desc), 0, - in_chwn_block_desc.GetStride(I1), - out_khwn_thread_desc.GetStride(I1), + in_c_h_w_n_block_desc.GetStride(I1), + out_k_h_w_n_thread_desc.GetStride(I1), HoPerBlock, GemmMPerThreadSubC, GemmNPerThreadSubC, @@ -186,30 +201,33 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn GemmDataPerReadB>{}; // LDS: be careful of alignment - constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number{}); + constexpr index_t in_block_space = + in_c_h_w_n_block_desc.GetElementSpace(Number{}); constexpr index_t wei_block_space = - wei_cyxk_block_desc.GetElementSpace(Number{}); + wei_c_y_x_k_block_desc.GetElementSpace(Number{}); __shared__ Float p_in_block[in_block_space]; __shared__ Float p_wei_block[wei_block_space]; // register - Float p_out_thread[out_khwn_thread_desc.GetElementSpace()]; + // C++ lambda doesn't capture array, use pointer instead + Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()]; + Float* const p_out_thread = p_out_thread_data; // set threadwise output tensor to 0 - threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread); + threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread); const Float* p_in_global_block_offset = - p_in_global + in_chwn_global_desc.Get1dIndex( + p_in_global + in_c_h_w_n_global_desc.Get1dIndex( 0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin); const Float* p_wei_global_block_offset = - p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); + p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, 
k_block_data_begin); for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, - p_in_global_block_offset += CPerBlock * in_chwn_global_desc.GetStride(I0), - p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), + p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0), + p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0), __syncthreads()) { #if 1 @@ -241,96 +259,140 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn #else blockwise_batch_gemm.Run_asm #endif - (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0), + (p_wei_block + wei_c_y_x_k_block_desc.Get1dIndex(0, y, x, 0), + p_in_block + in_c_h_w_n_block_desc.Get1dIndex(0, y, x, 0), p_out_thread); } } } -// output: register to global mem, -#if 0 - const auto c_thread_mtx_begin = - blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k) - { - for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho) - { - for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo) - { - for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n) - { - const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n); - - const auto c_thread_mtx_distance = - blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b); - - const index_t ho_thread = - c_thread_mtx_begin.batch + c_thread_mtx_distance.batch; - const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row; - const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col; - - const index_t wo_thread = b_thread / NPerBlock; - const index_t n_thread = b_thread % NPerBlock; - - p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread, - ho_block_data_begin + ho_thread, - wo_block_data_begin + wo_thread, - n_block_data_begin + n_thread)] = - p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)]; - } - } - } - } -#elif 1 + // output: register to global mem, const auto c_thread_mtx_begin = blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); const index_t k_thread_data_begin = c_thread_mtx_begin.row; const index_t ho_thread_data_begin = c_thread_mtx_begin.batch; const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; - const index_t n_thread_data_begin = - c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin; + const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock; - // output is a 10d tensor - constexpr index_t N2 = GemmNPerThreadSubC; - constexpr index_t N1 = NPerBlock / N2; + static_if{}( + [&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. 
Using this trick to + // make this lambda a generic lambda, so it won't be compiled until + // instantiated + static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock && + NPerBlock % GemmNPerThreadSubC == 0), + "wrong!"); - constexpr index_t W2 = - (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC); - constexpr index_t W1 = WoPerBlock / W2; + // output is a 10d tensor + constexpr index_t N2 = GemmNPerThreadSubC; + constexpr index_t N1 = NPerBlock / N2; - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; + constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) / + f_dummy(NPerBlock / GemmNPerThreadSubC); + constexpr index_t W1 = WoPerBlock / W2; - constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor( - Sequence{}); + constexpr index_t K2 = GemmMPerThreadSubC; + constexpr index_t K1 = KPerBlock / KPerThread; - constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( - Sequence{}); + constexpr auto out_10d_global_desc = + make_ConstantTensorDescriptor(Sequence{}); + + constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); #if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, + "out_k_h_w_n_thread_desc"); + print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); - print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); - } + print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, + "out_k_h_w_n_global_desc"); + print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); + } #endif - threadwise_10d_tensor_copy(out_10d_thread_desc, - p_out_thread, - out_10d_global_desc, - p_out_global + out_khwn_global_desc.Get1dIndex( - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin), - out_10d_thread_desc.GetLengths(), - Number{}); + threadwise_nd_tensor_copy(out_10d_thread_desc, + p_out_thread, + out_10d_global_desc, + p_out_global + + out_k_h_w_n_global_desc.Get1dIndex( + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin, + n_block_data_begin + n_thread_data_begin), + out_10d_thread_desc.GetLengths(), + Number{}); + }) + .else_([&](auto f_dummy) { + static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && + GemmNPerThreadSubC % NPerThread == 0, + "wrong!"); + + // output is a 10d tensor + constexpr index_t N1 = NPerBlock; + + constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock; + constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster; + constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3); + + constexpr index_t K2 = GemmMPerThreadSubC; + constexpr index_t K1 = KPerBlock / KPerThread; + + constexpr auto out_10d_global_desc = + make_ConstantTensorDescriptor(Sequence{}); + + constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, + 
"out_k_h_w_n_thread_desc"); + print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); + + print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, + "out_k_h_w_n_global_desc"); + print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); + + for(index_t i = 0; i < 64; ++i) + { + printf("out %f, ", p_out_thread[i]); + } + } #endif + + threadwise_nd_tensor_copy(out_10d_thread_desc, + p_out_thread, + out_10d_global_desc, + p_out_global + + out_k_h_w_n_global_desc.Get1dIndex( + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin, + n_block_data_begin + n_thread_data_begin), + out_10d_thread_desc.GetLengths(), + Number{}); + }); } }; diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r1_lds_double_buffer_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r1_lds_double_buffer_chwn_cyxk_khwn.hip.hpp deleted file mode 100644 index 34cb38822e..0000000000 --- a/src/include/gridwise_convolution_implicit_gemm_v1r1_lds_double_buffer_chwn_cyxk_khwn.hip.hpp +++ /dev/null @@ -1,407 +0,0 @@ -#pragma once -#include "common.hip.hpp" -#include "ConstantTensorDescriptor.hip.hpp" -#include "ConstantMatrixDescriptor.hip.hpp" -#include "blockwise_4d_tensor_op.hip.hpp" -#include "blockwise_2d_tensor_op.hip.hpp" -#include "threadwise_nd_tensor_op.hip.hpp" -#include "threadwise_4d_tensor_op.hip.hpp" -#include "blockwise_batched_gemm.hip.hpp" - -template -struct GridwiseConvolutionImplicitGemm_v1r1_lds_double_buffer_chwn_cyxk_khwn -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - // be careful of this assertion - static_assert( - NPerThread <= NPerBlock && NPerBlock % NPerThread == 0, - "wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_chwn_global_desc = InGlobalDesc{}; - constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; - constexpr auto out_khwn_global_desc = OutGlobalDesc{}; - - constexpr index_t C = in_chwn_global_desc.GetLength(I0); - - constexpr index_t K = out_khwn_global_desc.GetLength(I0); - constexpr index_t Ho = out_khwn_global_desc.GetLength(I1); - constexpr index_t Wo = out_khwn_global_desc.GetLength(I2); - constexpr index_t N = out_khwn_global_desc.GetLength(I3); - - constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1); - constexpr index_t X = wei_cyxk_global_desc.GetLength(I2); - - constexpr index_t HiPerBlock = HoPerBlock + Y - 1; - constexpr index_t WiPerBlock = WoPerBlock + X - 1; - - // assert for LDS double buffer - static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided"); - - // divide block work: [K, Ho, Wo, N] - static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 && - Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0, - "wrong! 
cannot evenly divide work for workgroup "); - - constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; - constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; - constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock; - - const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); - index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); - const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork); - itmp -= h_block_work_id * (WBlockWork * NBlockWork); - const index_t w_block_work_id = itmp / NBlockWork; - const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork; - - const index_t k_block_data_begin = k_block_work_id * KPerBlock; - const index_t ho_block_data_begin = h_block_work_id * HoPerBlock; - const index_t wo_block_data_begin = w_block_work_id * WoPerBlock; - const index_t n_block_data_begin = n_block_work_id * NPerBlock; - - const index_t hi_block_data_begin = ho_block_data_begin; - const index_t wi_block_data_begin = wo_block_data_begin; - - // flattend (2d) tensor view of gridwise weight - constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of blockwise input and weight in LDS - // be careful of alignment - constexpr index_t max_align = - mod_conv::max(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead); - - constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - // tensor view of threadwise output in register - constexpr auto out_khwn_thread_desc = make_ConstantTensorDescriptor( - Sequence{}); - - // blockwise copy - // input: format is [C, Hi, Wi, N] - const auto blockwise_in_copy = - Blockwise4dTensorCopy3{}; - - // blockwise wei copy - // format is [CPerBlock*Y*X,KPerBlock] - const auto blockwise_wei_copy = - Blockwise2dTensorCopy3{}; - - // a series of blockwise batched GEMM - // C_matrix += transpose(A_matrix) * B_matrix - // A_matrix and B_matrix saved in LDS, C_matrix saved in register - // A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K] - // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] - // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N] - constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_cxwn_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - constexpr auto c_kxwn_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - const auto blockwise_batch_gemm = - BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< - BlockSize, - decltype(a_cxk_block_mtx_desc), - decltype(b_cxwn_block_mtx_desc), - decltype(c_kxwn_thread_mtx_desc), - 0, - in_chwn_block_desc.GetStride(I1), - out_khwn_thread_desc.GetStride(I1), - HoPerBlock, - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - HoPerThread, - GemmDataPerReadA, - GemmDataPerReadB>{}; - - // LDS: be careful of alignment - constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number{}); - - constexpr index_t wei_block_space = - 
wei_cyxk_block_desc.GetElementSpace(Number{}); - - // LDS double buffer - __shared__ Float p_in_block_double[2 * in_block_space]; - __shared__ Float p_wei_block_double[2 * wei_block_space]; - - const Float* p_in_global_block_offset = - p_in_global + in_chwn_global_desc.Get1dIndex( - 0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin); - - const Float* p_wei_global_block_offset = - p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); - - // preload data into LDS - { - Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; - Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; - - blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, - p_in_register_clipboard); - blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, - p_wei_register_clipboard); - - blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_double); - blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, - p_wei_block_double); - } - - // register - Float p_out_thread[out_khwn_thread_desc.GetElementSpace()]; - - // set threadwise output tensor to 0 - threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread); - - for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C; - c_block_data_begin += 2 * CPerBlock) - { -#pragma unroll - for(index_t iloop = 0; iloop < 2; ++iloop) - { - const bool even_loop = (iloop % 2 == 0); - - Float* p_in_block_now = - even_loop ? p_in_block_double : p_in_block_double + in_block_space; - Float* p_wei_block_now = - even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space; - - Float* p_in_block_next = - even_loop ? p_in_block_double + in_block_space : p_in_block_double; - Float* p_wei_block_next = - even_loop ? 
p_wei_block_double + wei_block_space : p_wei_block_double; - - // load next data - Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; - Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; - - p_in_global_block_offset += CPerBlock * in_chwn_global_desc.GetStride(I0); - p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0); - - __syncthreads(); - - blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, - p_in_register_clipboard); - - blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, - p_wei_register_clipboard); - - // a series of batched GEMM - for(index_t y = 0; y < Y; ++y) - { - for(index_t x = 0; x < X; ++x) - { - blockwise_batch_gemm.Run( - p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block_now + in_chwn_block_desc.Get1dIndex(0, y, x, 0), - p_out_thread); - } - } - - blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, - p_in_block_next); - blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, - p_wei_block_next); - } - } - - // tail - { - // even - p_in_global_block_offset += CPerBlock * in_chwn_global_desc.GetStride(I0); - p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0); - - __syncthreads(); - - Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; - Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; - - blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, - p_in_register_clipboard); - - blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, - p_wei_register_clipboard); - - for(index_t y = 0; y < Y; ++y) - { - for(index_t x = 0; x < X; ++x) - { - blockwise_batch_gemm.Run( - p_wei_block_double + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block_double + in_chwn_block_desc.Get1dIndex(0, y, x, 0), - p_out_thread); - } - } - - blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, - p_in_block_double + in_block_space); - - blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, - p_wei_block_double + wei_block_space); - - // odd - __syncthreads(); - - for(index_t y = 0; y < Y; ++y) - { - for(index_t x = 0; x < X; ++x) - { - blockwise_batch_gemm.Run(p_wei_block_double + wei_block_space + - wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block_double + in_block_space + - in_chwn_block_desc.Get1dIndex(0, y, x, 0), - p_out_thread); - } - } - } - -// output: register to global mem, -#if 0 - const auto c_thread_mtx_begin = - blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k) - { - for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho) - { - for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo) - { - for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n) - { - const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n); - - const auto c_thread_mtx_distance = - blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b); - - const index_t ho_thread = - c_thread_mtx_begin.batch + c_thread_mtx_distance.batch; - const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row; - const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col; - - const index_t wo_thread = b_thread / NPerBlock; - const index_t n_thread = b_thread % NPerBlock; - - p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread, - 
ho_block_data_begin + ho_thread, - wo_block_data_begin + wo_thread, - n_block_data_begin + n_thread)] = - p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)]; - } - } - } - } -#elif 1 - const auto c_thread_mtx_begin = - blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_begin = c_thread_mtx_begin.row; - const index_t ho_thread_data_begin = c_thread_mtx_begin.batch; - const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; - const index_t n_thread_data_begin = - c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin; - - // output is a 10d tensor - constexpr index_t N2 = GemmNPerThreadSubC; - constexpr index_t N1 = NPerBlock / N2; - - constexpr index_t W2 = - (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC); - constexpr index_t W1 = WoPerBlock / W2; - - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; - - constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor( - Sequence{}); - - constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( - Sequence{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); - - print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); - } -#endif - - threadwise_10d_tensor_copy(out_10d_thread_desc, - p_out_thread, - out_10d_global_desc, - p_out_global + out_khwn_global_desc.Get1dIndex( - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin), - out_10d_thread_desc.GetLengths(), - Number{}); -#endif - } -}; diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp index 74c0e5b4b6..2a6e985cdb 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp @@ -33,10 +33,10 @@ template + class InBlockCopyClusterLengths_CHWN, + index_t InBlockCopyDataPerRead_N, + index_t WeiBlockCopyDataPerRead_K, + index_t OutThreadCopyDataPerWrite_N> struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn { __device__ void Run(const Float* const __restrict__ p_in_global, @@ -44,9 +44,11 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn Float* const __restrict__ p_out_global) const { // be careful of this assertion - static_assert( - NPerThread <= NPerBlock && NPerBlock % NPerThread == 0, - "wrong! 
should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0"); + static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock && + NPerBlock % GemmNPerThreadSubC == 0) || + (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock && + GemmNPerThreadSubC % NPerThread == 0), + "wrong!"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -101,14 +103,23 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn // LDS tensor view // be careful of alignment - constexpr index_t max_align = mod_conv::max( - InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB); + constexpr index_t max_align = mod_conv::max(InBlockCopyDataPerRead_N, + WeiBlockCopyDataPerRead_K, + GemmDataPerReadA, + GemmDataPerReadB); constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); + Sequence{}, + Number{}); + + // this check is ad-hoc + // TODO: need to properly implement tensor descriptor with alignment + static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0, + "GemmDataPerReadB alignment requirement is not meet"); constexpr auto wei_c_x_k_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); + Sequence{}, + Number{}); // tensor view of threadwise output in register constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor( @@ -116,14 +127,14 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn // blockwise copy // input: format is [C, Hi, Wi, N] -#if 0 +#if 1 const auto blockwise_in_copy = Blockwise4dTensorCopy1{}; + InBlockCopyDataPerRead_N>{}; #else const auto blockwise_in_copy = Blockwise4dTensorCopy3{}; + InBlockCopyClusterLengths_CHWN, + InBlockCopyDataPerRead_N>{}; #endif // blockwise wei copy @@ -143,7 +154,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn decltype(wei_c_x_k_global_desc), decltype(wei_c_x_k_block_desc), decltype(wei_c_x_k_block_desc.GetLengths()), - WeiBlockCopyDataPerRead>{}; + WeiBlockCopyDataPerRead_K>{}; // a series of blockwise batched GEMM // C_matrix += transpose(A_matrix) * B_matrix @@ -195,7 +206,9 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn __shared__ Float p_wei_block[wei_block_space]; // register - Float p_out_thread[out_k_h_w_n_thread_desc.GetElementSpace()]; + // C++ lambda doesn't capture array, use pointer instead + Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()]; + Float* const p_out_thread = p_out_thread_data; #if 0 if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) @@ -293,46 +306,126 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn const index_t k_thread_data_begin = c_thread_mtx_begin.row; const index_t ho_thread_data_begin = c_thread_mtx_begin.batch; const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; - const index_t n_thread_data_begin = - c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin; + const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock; - // output is a 10d tensor - constexpr index_t N2 = GemmNPerThreadSubC; - constexpr index_t N1 = NPerBlock / N2; + static_if{}( + [&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. 
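// The static_if + generic-lambda ("f_dummy") pattern used in this hunk selects one of
// two tile layouts at compile time without ever instantiating the rejected branch.
// Below is a minimal, host-side stand-in sketch of that idea; StaticIf, pick_n1 and the
// tile sizes are illustrative names/values only, and the repository's own static_if in
// functional2.hip.hpp may be implemented differently.
#include <cstdio>

template <bool Cond>
struct StaticIf
{
    // Calls f only when Cond holds; f is handed an identity functor (the "f_dummy"
    // role) purely so that f has to be a generic lambda.
    template <class F>
    StaticIf operator()(F f) const
    {
        if constexpr(Cond)
            f([](auto v) { return v; });
        return *this;
    }

    template <class F>
    void else_(F f) const
    {
        if constexpr(!Cond)
            f([](auto v) { return v; });
    }
};

template <int NPerBlock, int GemmNPerThreadSubC>
int pick_n1()
{
    int n1 = 0;

    // Because each lambda is generic (auto f_dummy), its body is a member template and
    // is only compiled when that branch is actually invoked, so the static_assert of
    // the rejected branch can never fire.
    StaticIf<(GemmNPerThreadSubC <= NPerBlock)>{}([&](auto f_dummy) {
        static_assert(NPerBlock % GemmNPerThreadSubC == 0, "demo: sub-tile must divide the block");
        n1 = NPerBlock / GemmNPerThreadSubC;
        (void)f_dummy;
    }).else_([&](auto f_dummy) {
        static_assert(GemmNPerThreadSubC % NPerBlock == 0, "demo: block must divide the sub-tile");
        n1 = 1;
        (void)f_dummy;
    });

    return n1;
}

int main()
{
    // first call takes the "small sub-tile" branch, second call the "wide sub-tile" one
    std::printf("first N1 = %d, second N1 = %d\n", pick_n1<16, 4>(), pick_n1<2, 8>());
    return 0;
}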
Using this trick to + // make this lambda a generic lambda, so it won't be compiled until + // instantiated + static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock && + NPerBlock % GemmNPerThreadSubC == 0), + "wrong!"); - constexpr index_t W2 = - (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC); - constexpr index_t W1 = WoPerBlock / W2; + // output is a 10d tensor + constexpr index_t N2 = GemmNPerThreadSubC; + constexpr index_t N1 = NPerBlock / N2; - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; + constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) / + f_dummy(NPerBlock / GemmNPerThreadSubC); + constexpr index_t W1 = WoPerBlock / W2; - constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor( - Sequence{}); + constexpr index_t K2 = GemmMPerThreadSubC; + constexpr index_t K1 = KPerBlock / KPerThread; - constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( - Sequence{}); + constexpr auto out_10d_global_desc = + make_ConstantTensorDescriptor(Sequence{}); + + constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); #if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, + "out_k_h_w_n_thread_desc"); + print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); - print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); - } + print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, + "out_k_h_w_n_global_desc"); + print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); + } #endif - threadwise_10d_tensor_copy(out_10d_thread_desc, - p_out_thread, - out_10d_global_desc, - p_out_global + out_k_h_w_n_global_desc.Get1dIndex( - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin), - out_10d_thread_desc.GetLengths(), - Number{}); + threadwise_nd_tensor_copy(out_10d_thread_desc, + p_out_thread, + out_10d_global_desc, + p_out_global + + out_k_h_w_n_global_desc.Get1dIndex( + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin, + n_block_data_begin + n_thread_data_begin), + out_10d_thread_desc.GetLengths(), + Number{}); + }) + .else_([&](auto f_dummy) { + static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && + GemmNPerThreadSubC % NPerThread == 0, + "wrong!"); + + // output is a 10d tensor + constexpr index_t N1 = NPerBlock; + + constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock; + constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster; + constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3); + + constexpr index_t K2 = GemmMPerThreadSubC; + constexpr index_t K1 = KPerBlock / KPerThread; + + constexpr auto out_10d_global_desc = + make_ConstantTensorDescriptor(Sequence{}); + + constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, + 
"out_k_h_w_n_thread_desc"); + print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); + + print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, + "out_k_h_w_n_global_desc"); + print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); + + for(index_t i = 0; i < 64; ++i) + { + printf("out %f, ", p_out_thread[i]); + } + } +#endif + + threadwise_nd_tensor_copy(out_10d_thread_desc, + p_out_thread, + out_10d_global_desc, + p_out_global + + out_k_h_w_n_global_desc.Get1dIndex( + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin, + n_block_data_begin + n_thread_data_begin), + out_10d_thread_desc.GetLengths(), + Number{}); + }); } }; diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp index e7d8dee565..a1aea70cec 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp @@ -39,7 +39,7 @@ template struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn { @@ -106,7 +106,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn // LDS tensor view // be careful of alignment constexpr index_t max_align = mod_conv::max(InBlockReorderDataPerWrite_N, - WeiBlockCopyDataPerRead_C, + WeiBlockCopyDataPerRead_K, GemmDataPerReadA, GemmDataPerReadB); @@ -146,7 +146,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn decltype(wei_c_x_k_block_desc), decltype(wei_c_x_k_block_desc.GetLengths()), WeiBlockCopyClusterLengths_CXK, - WeiBlockCopyDataPerRead_C>{}; + WeiBlockCopyDataPerRead_K>{}; // a series of blockwise batched GEMM // C_matrix += transpose(A_matrix) * B_matrix @@ -216,6 +216,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn // set threadwise output tensor to 0 threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread); +#if 0 const Float* p_in_global_block_offset = p_in_global + in_n_c_h_w_global_desc.Get1dIndex( n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin); @@ -229,7 +230,6 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn { for(index_t y = 0; y < Y; ++y) { -#if 1 blockwise_in_copy_reorder.Run(p_in_global_block_offset + in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, 0), p_in_block); @@ -237,23 +237,6 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn blockwise_wei_copy.Run(p_wei_global_block_offset + wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, 0), p_wei_block); -#else - Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()]; - Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; - - blockwise_in_copy_reorder.RunLoadRegisterClipboard( - p_in_global_block_offset + in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, 0), - p_in_clipboard); - - blockwise_wei_copy.RunLoadRegisterClipboard( - p_wei_global_block_offset + wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, 0), - p_wei_clipboard); - - blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block); - - blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_clipboard, p_in_block); - -#endif __syncthreads(); @@ -268,6 +251,49 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn __syncthreads(); } } +#else + for(index_t y = 0; y < Y; ++y) + { + const Float* p_in_global_block_offset = + p_in_global + + in_n_c_h_w_global_desc.Get1dIndex( + n_block_data_begin, 0, hi_block_data_begin + y, 
wi_block_data_begin); + + const Float* p_wei_global_block_offset = + p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, k_block_data_begin); + + for(index_t c_block_data_begin = 0; c_block_data_begin < C; + c_block_data_begin += CPerBlock, + p_in_global_block_offset += + CPerBlock * in_n_c_h_w_global_desc.GetStride(I1), + p_wei_global_block_offset += + CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0)) + { + Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()]; + Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + + blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset, + p_in_clipboard); + blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, + p_wei_clipboard); + + blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block); + blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_clipboard, p_in_block); + + __syncthreads(); + + for(index_t x = 0; x < X; ++x) + { + blockwise_batch_gemm.Run(p_wei_block + wei_c_x_k_block_desc.Get1dIndex(0, x, 0), + p_in_block + + in_c_h_w_n_block_desc.Get1dIndex(0, 0, x, 0), + p_out_thread); + } + + __syncthreads(); + } + } +#endif // output: register to global mem, const auto c_thread_mtx_begin = diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp index ff1d024346..d71ee639a2 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp @@ -43,9 +43,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn Float* const __restrict__ p_out_global) const { // be careful of this assertion - static_assert( - NPerThread <= NPerBlock && NPerBlock % NPerThread == 0, - "wrong! 
should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0"); + static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock && + NPerBlock % GemmNPerThreadSubC == 0) || + (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock && + GemmNPerThreadSubC % NPerThread == 0), + "wrong!"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -66,9 +68,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1); constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2); - constexpr index_t HiPerBlock = HoPerBlock + Y - 1; - constexpr index_t WiPerBlock = WoPerBlock + X - 1; - // divide block work: [K, Ho, Wo, N] static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 && Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0, @@ -106,10 +105,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn GemmDataPerReadB); constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); + Sequence{}, + Number{}); + + // this check is ad-hoc + // TODO: need to properly implement tensor descriptor with alignment + static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0, + "GemmDataPerReadB alignment requirement is not meet"); constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); + Sequence{}, + Number{}); // tensor view of threadwise output in register constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor( @@ -177,6 +183,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn GemmDataPerReadB>{}; // LDS: be careful of alignment + // TODO:: need to properly implement tensor descriptor with alignment constexpr index_t in_block_space = in_c_h_w_n_block_desc.GetElementSpace(Number{}); constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number{}); @@ -185,7 +192,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn __shared__ Float p_wei_block[wei_block_space]; // register - Float p_out_thread[out_k_h_w_n_thread_desc.GetElementSpace()]; + // C++ lambda doesn't capture array, use pointer instead + Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()]; + Float* const p_out_thread = p_out_thread_data; #if 0 if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) @@ -276,46 +285,126 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn const index_t k_thread_data_begin = c_thread_mtx_begin.row; const index_t ho_thread_data_begin = c_thread_mtx_begin.batch; const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; - const index_t n_thread_data_begin = - c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin; + const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock; - // output is a 10d tensor - constexpr index_t N2 = GemmNPerThreadSubC; - constexpr index_t N1 = NPerBlock / N2; + static_if{}( + [&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. 
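// A small constexpr sanity check of the sub-tile factors that the 10-d output view
// below is built from, evaluated for assumed tile sizes; the numbers here are
// illustrative only and are not taken from the driver configuration.
namespace out_10d_factor_check
{
constexpr int NPerBlock          = 16;
constexpr int GemmNPerThreadSubC = 4; // <= NPerBlock, so the first static_if branch applies
constexpr int GemmNLevel0Cluster = 2;
constexpr int GemmNLevel1Cluster = 2;
constexpr int WoPerBlock         = 4;
constexpr int KPerBlock          = 64;
constexpr int KPerThread         = 8;
constexpr int GemmMPerThreadSubC = 4;

constexpr int N2 = GemmNPerThreadSubC; // 4
constexpr int N1 = NPerBlock / N2;     // 4

constexpr int W2 =
    (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC); // 4 / 4 = 1
constexpr int W1 = WoPerBlock / W2;                                               // 4

constexpr int K2 = GemmMPerThreadSubC;     // 4
constexpr int K1 = KPerBlock / KPerThread; // 8

static_assert(N1 * N2 == NPerBlock && W1 * W2 == WoPerBlock && KPerBlock % (K1 * K2) == 0,
              "the 10-d factors must tile the block exactly");
} // namespace out_10d_factor_check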
Using this trick to + // make this lambda a generic lambda, so it won't be compiled until + // instantiated + static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock && + NPerBlock % GemmNPerThreadSubC == 0), + "wrong!"); - constexpr index_t W2 = - (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC); - constexpr index_t W1 = WoPerBlock / W2; + // output is a 10d tensor + constexpr index_t N2 = GemmNPerThreadSubC; + constexpr index_t N1 = NPerBlock / N2; - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; + constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) / + f_dummy(NPerBlock / GemmNPerThreadSubC); + constexpr index_t W1 = WoPerBlock / W2; - constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor( - Sequence{}); + constexpr index_t K2 = GemmMPerThreadSubC; + constexpr index_t K1 = KPerBlock / KPerThread; - constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( - Sequence{}); + constexpr auto out_10d_global_desc = + make_ConstantTensorDescriptor(Sequence{}); + + constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); #if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, + "out_k_h_w_n_thread_desc"); + print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); - print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); - } + print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, + "out_k_h_w_n_global_desc"); + print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); + } #endif - threadwise_10d_tensor_copy(out_10d_thread_desc, - p_out_thread, - out_10d_global_desc, - p_out_global + out_k_h_w_n_global_desc.Get1dIndex( - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin), - out_10d_thread_desc.GetLengths(), - Number{}); + threadwise_nd_tensor_copy(out_10d_thread_desc, + p_out_thread, + out_10d_global_desc, + p_out_global + + out_k_h_w_n_global_desc.Get1dIndex( + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin, + n_block_data_begin + n_thread_data_begin), + out_10d_thread_desc.GetLengths(), + Number{}); + }) + .else_([&](auto f_dummy) { + static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && + GemmNPerThreadSubC % NPerThread == 0, + "wrong!"); + + // output is a 10d tensor + constexpr index_t N1 = NPerBlock; + + constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock; + constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster; + constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3); + + constexpr index_t K2 = GemmMPerThreadSubC; + constexpr index_t K1 = KPerBlock / KPerThread; + + constexpr auto out_10d_global_desc = + make_ConstantTensorDescriptor(Sequence{}); + + constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, + 
"out_k_h_w_n_thread_desc"); + print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); + + print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, + "out_k_h_w_n_global_desc"); + print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); + + for(index_t i = 0; i < 64; ++i) + { + printf("out %f, ", p_out_thread[i]); + } + } +#endif + + threadwise_nd_tensor_copy(out_10d_thread_desc, + p_out_thread, + out_10d_global_desc, + p_out_global + + out_k_h_w_n_global_desc.Get1dIndex( + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin, + n_block_data_begin + n_thread_data_begin), + out_10d_thread_desc.GetLengths(), + Number{}); + }); } }; diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp index 27704b60dc..7a757a6673 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp @@ -43,9 +43,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn Float* const __restrict__ p_out_global) const { // be careful of this assertion - static_assert( - NPerThread <= NPerBlock && NPerBlock % NPerThread == 0, - "wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0"); + static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock && + NPerBlock % GemmNPerThreadSubC == 0) || + (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock && + GemmNPerThreadSubC % NPerThread == 0), + "wrong!"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -109,10 +111,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn GemmDataPerReadB); constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); + Sequence{}, + Number{}); + + // this check is ad-hoc + // TODO: need to properly implement tensor descriptor with alignment + static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0, + "GemmDataPerReadB alignment requirement is not meet"); constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); + Sequence{}, + Number{}); // tensor view of threadwise output in register constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor( @@ -199,7 +208,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn __shared__ Float p_wei_block_double[2 * wei_block_space]; // register - Float p_out_thread[out_k_h_w_n_thread_desc.GetElementSpace()]; + // C++ lambda doesn't capture array, use pointer instead + Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()]; + Float* const p_out_thread = p_out_thread_data; #if 0 if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) @@ -336,46 +347,126 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn const index_t k_thread_data_begin = c_thread_mtx_begin.row; const index_t ho_thread_data_begin = c_thread_mtx_begin.batch; const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; - const index_t n_thread_data_begin = - c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin; + const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock; - // output is a 10d tensor - constexpr 
index_t N2 = GemmNPerThreadSubC; - constexpr index_t N1 = NPerBlock / N2; + static_if{}( + [&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. Using this trick to + // make this lambda a generic lambda, so it won't be compiled until + // instantiated + static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock && + NPerBlock % GemmNPerThreadSubC == 0), + "wrong!"); - constexpr index_t W2 = - (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC); - constexpr index_t W1 = WoPerBlock / W2; + // output is a 10d tensor + constexpr index_t N2 = GemmNPerThreadSubC; + constexpr index_t N1 = NPerBlock / N2; - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; + constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) / + f_dummy(NPerBlock / GemmNPerThreadSubC); + constexpr index_t W1 = WoPerBlock / W2; - constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor( - Sequence{}); + constexpr index_t K2 = GemmMPerThreadSubC; + constexpr index_t K1 = KPerBlock / KPerThread; - constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( - Sequence{}); + constexpr auto out_10d_global_desc = + make_ConstantTensorDescriptor(Sequence{}); + + constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); #if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, + "out_k_h_w_n_thread_desc"); + print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); - print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); - } + print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, + "out_k_h_w_n_global_desc"); + print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); + } #endif - threadwise_10d_tensor_copy(out_10d_thread_desc, - p_out_thread, - out_10d_global_desc, - p_out_global + out_k_h_w_n_global_desc.Get1dIndex( - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin), - out_10d_thread_desc.GetLengths(), - Number{}); + threadwise_nd_tensor_copy(out_10d_thread_desc, + p_out_thread, + out_10d_global_desc, + p_out_global + + out_k_h_w_n_global_desc.Get1dIndex( + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin, + n_block_data_begin + n_thread_data_begin), + out_10d_thread_desc.GetLengths(), + Number{}); + }) + .else_([&](auto f_dummy) { + static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && + GemmNPerThreadSubC % NPerThread == 0, + "wrong!"); + + // output is a 10d tensor + constexpr index_t N1 = NPerBlock; + + constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock; + constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster; + constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3); + + constexpr index_t K2 = GemmMPerThreadSubC; + constexpr index_t K1 = KPerBlock / KPerThread; + + constexpr auto out_10d_global_desc = + make_ConstantTensorDescriptor(Sequence{}); + + constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( + 
Sequence{}); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, + "out_k_h_w_n_thread_desc"); + print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); + + print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, + "out_k_h_w_n_global_desc"); + print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); + + for(index_t i = 0; i < 64; ++i) + { + printf("out %f, ", p_out_thread[i]); + } + } +#endif + + threadwise_nd_tensor_copy(out_10d_thread_desc, + p_out_thread, + out_10d_global_desc, + p_out_global + + out_k_h_w_n_global_desc.Get1dIndex( + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin, + n_block_data_begin + n_thread_data_begin), + out_10d_thread_desc.GetLengths(), + Number{}); + }); } }; diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp new file mode 100644 index 0000000000..d713b8cf2e --- /dev/null +++ b/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp @@ -0,0 +1,472 @@ +#pragma once +#include "common.hip.hpp" +#include "ConstantTensorDescriptor.hip.hpp" +#include "ConstantMatrixDescriptor.hip.hpp" +#include "blockwise_2d_tensor_op.hip.hpp" +#include "blockwise_4d_tensor_op.hip.hpp" +#include "threadwise_nd_tensor_op.hip.hpp" +#include "threadwise_4d_tensor_op.hip.hpp" +#include "blockwise_batched_gemm.hip.hpp" + +template +struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn +{ + __device__ void Run(const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) const + { + // be careful of this assertion + static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock && + NPerBlock % GemmNPerThreadSubC == 0) || + (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock && + GemmNPerThreadSubC % NPerThread == 0), + "wrong!"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; + constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{}; + constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{}; + + constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1); + + constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0); + constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1); + constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2); + constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3); + + constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1); + constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2); + + // assert for LDS double buffer + static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided"); + + // divide block work: [K, Ho, Wo, N] + static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 && + Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0, + "wrong! 
cannot evenly divide work for workgroup "); + + constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; + constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; + constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; + constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock; + + const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); + index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); + const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork); + itmp -= h_block_work_id * (WBlockWork * NBlockWork); + const index_t w_block_work_id = itmp / NBlockWork; + const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork; + + const index_t k_block_data_begin = k_block_work_id * KPerBlock; + const index_t ho_block_data_begin = h_block_work_id * HoPerBlock; + const index_t wo_block_data_begin = w_block_work_id * WoPerBlock; + const index_t n_block_data_begin = n_block_work_id * NPerBlock; + + const index_t hi_block_data_begin = ho_block_data_begin; + const index_t wi_block_data_begin = wo_block_data_begin; + + // global tensor view + constexpr auto wei_c_k_global_desc = + make_ConstantTensorDescriptor(Sequence{}, Sequence{}); + + // LDS tensor view + // be careful of alignment + constexpr index_t max_align = mod_conv::max(InBlockReorderDataPerWrite_N, + WeiBlockCopyDataPerRead_K, + GemmDataPerReadA, + GemmDataPerReadB); + + constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, + Number{}); + + // this check is ad-hoc + // TODO: need to properly implement tensor descriptor with alignment + static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0, + "GemmDataPerReadB alignment requirement is not meet"); + + constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, + Number{}); + + // tensor view of threadwise output in register + constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); + + // blockwise copy + // input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N] + constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{}; + + const auto blockwise_in_copy_reorder = + Blockwise4dTensorCopyReorder3, + InBlockReorderSrcSubLengths_NCHW, + InBlockReorderSrcClusterLengths_NCHW, + decltype(map_chwn2nchw), + InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW, + InBlockReorderDataPerRead_W, + InBlockReorderDataPerWrite_N>{}; + + // blockwise wei copy + // format is [CPerBlock, KPerBlock] + const auto blockwise_wei_copy = + Blockwise2dTensorCopy3{}; + + // a series of blockwise batched GEMM + // C_matrix += transpose(A_matrix) * B_matrix + // A_matrix and B_matrix saved in LDS, C_matrix saved in register + // A_matrix[C,K] is a sub-matrix of wei_block[C,K] + // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] + // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N] + constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); + + constexpr auto b_c_wn_block_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); + + constexpr auto c_k_wn_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); + + const auto blockwise_batch_gemm = + BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< + BlockSize, + decltype(a_c_k_block_mtx_desc), + decltype(b_c_wn_block_mtx_desc), + decltype(c_k_wn_thread_mtx_desc), + 0, + in_c_h_w_n_block_desc.GetStride(I1), + 
out_k_h_w_n_thread_desc.GetStride(I1), + HoPerBlock, + GemmMPerThreadSubC, + GemmNPerThreadSubC, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmKPerThreadLoop, + HoPerThread, + GemmDataPerReadA, + GemmDataPerReadB>{}; + + // LDS: be careful of alignment + constexpr index_t in_block_space = + in_c_h_w_n_block_desc.GetElementSpace(Number{}); + constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number{}); + + // LDS double buffer + __shared__ Float p_in_block_double[2 * in_block_space]; + __shared__ Float p_wei_block_double[2 * wei_block_space]; + + // register + // C++ lambda doesn't capture array, use pointer instead + Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()]; + Float* const p_out_thread = p_out_thread_data; + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc"); + print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc"); + + print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc"); + print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc"); + + printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space); + } +#endif + + // set threadwise output tensor to 0 + threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread); + + for(index_t y = 0; y < Y; ++y) + { + for(index_t x = 0; x < X; ++x) + { + const Float* p_in_global_block_offset = + p_in_global + + in_n_c_h_w_global_desc.Get1dIndex( + n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x); + + const Float* p_wei_global_block_offset = + p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, k_block_data_begin); + + // LDS double buffer: preload data into LDS + { + Float p_in_register_clipboard[blockwise_in_copy_reorder + .GetRegisterClipboardSize()]; + Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + + blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset, + p_in_register_clipboard); + blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, + p_wei_register_clipboard); + + blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard, + p_in_block_double); + blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, + p_wei_block_double); + } + + // LDS double buffer: main body + for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C; + c_block_data_begin += 2 * CPerBlock) + { +#pragma unroll + for(index_t iloop = 0; iloop < 2; ++iloop) + { + const bool even_loop = (iloop % 2 == 0); + + Float* p_in_block_now = + even_loop ? p_in_block_double : p_in_block_double + in_block_space; + Float* p_wei_block_now = + even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space; + + Float* p_in_block_next = + even_loop ? p_in_block_double + in_block_space : p_in_block_double; + Float* p_wei_block_next = + even_loop ? 
p_wei_block_double + wei_block_space : p_wei_block_double; + + Float p_in_register_clipboard[blockwise_in_copy_reorder + .GetRegisterClipboardSize()]; + Float + p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + + p_in_global_block_offset += + CPerBlock * in_n_c_h_w_global_desc.GetStride(I1); + p_wei_global_block_offset += + CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset, + p_in_register_clipboard); + blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, + p_wei_register_clipboard); + + // LDS double buffer: GEMM on current data + blockwise_batch_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); + + // LDS double buffer: store next data to LDS + blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard, + p_in_block_next); + blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, + p_wei_block_next); + } + } + + // LDS double buffer: tail + { + Float p_in_register_clipboard[blockwise_in_copy_reorder + .GetRegisterClipboardSize()]; + Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + + // even iteration + p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1); + p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset, + p_in_register_clipboard); + blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, + p_wei_register_clipboard); + + // LDS double buffer: GEMM on current data + blockwise_batch_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); + + // LDS double buffer: store next data to LDS + blockwise_in_copy_reorder.RunStoreRegisterClipboard( + p_in_register_clipboard, p_in_block_double + in_block_space); + blockwise_wei_copy.RunStoreRegisterClipboard( + p_wei_register_clipboard, p_wei_block_double + wei_block_space); + + // odd iteration + __syncthreads(); + + // LDS double buffer: GEMM on current data + blockwise_batch_gemm.Run(p_wei_block_double + wei_block_space, + p_in_block_double + in_block_space, + p_out_thread); + } + } + } + + // output: register to global mem, + const auto c_thread_mtx_begin = + blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const index_t k_thread_data_begin = c_thread_mtx_begin.row; + const index_t ho_thread_data_begin = c_thread_mtx_begin.batch; + const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; + const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock; + + static_if{}( + [&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. 
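// A host-side sketch of the ping-pong schedule that the double-buffer code above
// implements: while the GEMM consumes one buffer, the next chunk is fetched into the
// other buffer, and a tail step consumes whatever was loaded last. Buffer and chunk
// sizes here are made-up toy values, a scalar sum stands in for
// blockwise_batch_gemm.Run, and the real kernel advances two CPerBlock slices per
// main-loop iteration with __syncthreads() between the stages.
#include <cstdio>
#include <vector>

int main()
{
    constexpr int chunk_size = 4;
    constexpr int num_chunks = 6; // assumed even, mirroring C % (2 * CPerBlock) == 0
    std::vector<float> global_data(chunk_size * num_chunks, 1.0f);

    float buffer[2][chunk_size]; // the "LDS double buffer"
    double acc = 0.0;

    auto load = [&](int chunk, float* dst) { // stand-in for the global -> LDS copy
        for(int i = 0; i < chunk_size; ++i)
            dst[i] = global_data[chunk * chunk_size + i];
    };
    auto compute = [&](const float* src) { // stand-in for blockwise_batch_gemm.Run
        for(int i = 0; i < chunk_size; ++i)
            acc += src[i];
    };

    // preload the first chunk
    load(0, buffer[0]);

    // main body: load the next chunk into one buffer while computing on the other
    for(int chunk = 0; chunk + 1 < num_chunks; ++chunk)
    {
        load(chunk + 1, buffer[(chunk + 1) % 2]);
        compute(buffer[chunk % 2]);
    }

    // tail: the last chunk was loaded by the final main-body iteration
    compute(buffer[(num_chunks - 1) % 2]);

    std::printf("acc = %.1f (expected %d)\n", acc, chunk_size * num_chunks);
    return 0;
}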
Using this trick to + // make this lambda a generic lambda, so it won't be compiled until + // instantiated + static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock && + NPerBlock % GemmNPerThreadSubC == 0), + "wrong!"); + + // output is a 10d tensor + constexpr index_t N2 = GemmNPerThreadSubC; + constexpr index_t N1 = NPerBlock / N2; + + constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) / + f_dummy(NPerBlock / GemmNPerThreadSubC); + constexpr index_t W1 = WoPerBlock / W2; + + constexpr index_t K2 = GemmMPerThreadSubC; + constexpr index_t K1 = KPerBlock / KPerThread; + + constexpr auto out_10d_global_desc = + make_ConstantTensorDescriptor(Sequence{}); + + constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, + "out_k_h_w_n_thread_desc"); + print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); + + print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, + "out_k_h_w_n_global_desc"); + print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); + } +#endif + + threadwise_nd_tensor_copy(out_10d_thread_desc, + p_out_thread, + out_10d_global_desc, + p_out_global + + out_k_h_w_n_global_desc.Get1dIndex( + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin, + n_block_data_begin + n_thread_data_begin), + out_10d_thread_desc.GetLengths(), + Number{}); + }) + .else_([&](auto f_dummy) { + static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && + GemmNPerThreadSubC % NPerThread == 0, + "wrong!"); + + // output is a 10d tensor + constexpr index_t N1 = NPerBlock; + + constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock; + constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster; + constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3); + + constexpr index_t K2 = GemmMPerThreadSubC; + constexpr index_t K1 = KPerBlock / KPerThread; + + constexpr auto out_10d_global_desc = + make_ConstantTensorDescriptor(Sequence{}); + + constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, + "out_k_h_w_n_thread_desc"); + print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); + + print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, + "out_k_h_w_n_global_desc"); + print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); + + for(index_t i = 0; i < 64; ++i) + { + printf("out %f, ", p_out_thread[i]); + } + } +#endif + + threadwise_nd_tensor_copy(out_10d_thread_desc, + p_out_thread, + out_10d_global_desc, + p_out_global + + out_k_h_w_n_global_desc.Get1dIndex( + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin, + n_block_data_begin + n_thread_data_begin), + out_10d_thread_desc.GetLengths(), + Number{}); + }); + } +}; diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_khwn.hip.hpp new file mode 100644 index 0000000000..b6008df44e --- /dev/null +++ b/src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_khwn.hip.hpp @@ -0,0 +1,452 @@ +#pragma once +#include "common.hip.hpp" +#include 
"ConstantTensorDescriptor.hip.hpp" +#include "ConstantMatrixDescriptor.hip.hpp" +#include "blockwise_2d_tensor_op.hip.hpp" +#include "blockwise_4d_tensor_op.hip.hpp" +#include "threadwise_nd_tensor_op.hip.hpp" +#include "threadwise_4d_tensor_op.hip.hpp" +#include "blockwise_batched_gemm.hip.hpp" + +template +struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn +{ + __device__ void Run(const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) const + { + // be careful of this assertion + static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock && + NPerBlock % GemmNPerThreadSubC == 0) || + (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock && + GemmNPerThreadSubC % NPerThread == 0), + "wrong!"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; + constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{}; + constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{}; + + constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1); + + constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0); + constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1); + constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2); + constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3); + + constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1); + constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2); + + // divide block work: [K, Ho, Wo, N] + static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 && + Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0, + "wrong! cannot evenly divide work for workgroup "); + + constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; + constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; + constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; + constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock; + + const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); + index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); + const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork); + itmp -= h_block_work_id * (WBlockWork * NBlockWork); + const index_t w_block_work_id = itmp / NBlockWork; + const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork; + + const index_t k_block_data_begin = k_block_work_id * KPerBlock; + const index_t ho_block_data_begin = h_block_work_id * HoPerBlock; + const index_t wo_block_data_begin = w_block_work_id * WoPerBlock; + const index_t n_block_data_begin = n_block_work_id * NPerBlock; + + const index_t hi_block_data_begin = ho_block_data_begin; + const index_t wi_block_data_begin = wo_block_data_begin; + + // global tensor view + constexpr auto wei_c_k_global_desc = + make_ConstantTensorDescriptor(Sequence{}, Sequence{}); + + // LDS tensor view + // be careful of alignment + constexpr index_t max_align = mod_conv::max(InBlockReorderDataPerWrite_N, + WeiBlockCopyDataPerRead_K, + GemmDataPerReadA, + GemmDataPerReadB); + + constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, + Number{}); + + // this check is ad-hoc + // TODO: need to properly implement tensor descriptor with alignment + static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0, + "GemmDataPerReadB alignment 
requirement is not meet"); + + constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, + Number{}); + + // tensor view of threadwise output in register + constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); + + // blockwise copy + // input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N] + constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{}; + + const auto blockwise_in_copy_reorder = + Blockwise4dTensorCopyReorder3, + InBlockReorderSrcSubLengths_NCHW, + InBlockReorderSrcClusterLengths_NCHW, + decltype(map_chwn2nchw), + InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW, + InBlockReorderDataPerRead_W, + InBlockReorderDataPerWrite_N>{}; + + // blockwise wei copy + // format is [CPerBlock, KPerBlock] + const auto blockwise_wei_copy = + Blockwise2dTensorCopy3{}; + + // a series of blockwise batched GEMM + // C_matrix += transpose(A_matrix) * B_matrix + // A_matrix and B_matrix saved in LDS, C_matrix saved in register + // A_matrix[C,K] is a sub-matrix of wei_block[C,K] + // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] + // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N] + constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); + + constexpr auto b_c_wn_block_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); + + constexpr auto c_k_wn_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); + + const auto blockwise_batch_gemm = + BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< + BlockSize, + decltype(a_c_k_block_mtx_desc), + decltype(b_c_wn_block_mtx_desc), + decltype(c_k_wn_thread_mtx_desc), + 0, + in_c_h_w_n_block_desc.GetStride(I1), + out_k_h_w_n_thread_desc.GetStride(I1), + HoPerBlock, + GemmMPerThreadSubC, + GemmNPerThreadSubC, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmKPerThreadLoop, + HoPerThread, + GemmDataPerReadA, + GemmDataPerReadB>{}; + + // LDS: be careful of alignment + constexpr index_t in_block_space = + in_c_h_w_n_block_desc.GetElementSpace(Number{}); + constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number{}); + + __shared__ Float p_in_block[in_block_space]; + __shared__ Float p_wei_block[wei_block_space]; + + // register + // C++ lambda doesn't capture array, use pointer instead + Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()]; + Float* const p_out_thread = p_out_thread_data; + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc"); + print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc"); + + print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc"); + print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc"); + + printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space); + } +#endif + + // set threadwise output tensor to 0 + threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread); + +#if 1 + const Float* p_in_global_block_offset = + p_in_global + in_n_c_h_w_global_desc.Get1dIndex( + n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin); + + const Float* p_wei_global_block_offset = + p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); + + for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, + 
p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1), + p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0)) + { + for(index_t y = 0; y < Y; ++y) + { + for(index_t x = 0; x < X; ++x) + { +#if 1 + blockwise_in_copy_reorder.Run(p_in_global_block_offset + + in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, x), + p_in_block); + + blockwise_wei_copy.Run(p_wei_global_block_offset + + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, 0), + p_wei_block); +#else + Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()]; + Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + + blockwise_in_copy_reorder.RunLoadRegisterClipboard( + p_in_global_block_offset + in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, x), + p_in_clipboard); + + blockwise_wei_copy.RunLoadRegisterClipboard( + p_wei_global_block_offset + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, 0), + p_wei_clipboard); + + blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block); + + blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_clipboard, p_in_block); + +#endif + + __syncthreads(); + + blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread); + + __syncthreads(); + } + } + } +#else + for(index_t y = 0; y < Y; ++y) + { + for(index_t x = 0; x < X; ++x) + { + const Float* p_in_global_block_offset = + p_in_global + + in_n_c_h_w_global_desc.Get1dIndex( + n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x); + + const Float* p_wei_global_block_offset = + p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, k_block_data_begin); + + for(index_t c_block_data_begin = 0; c_block_data_begin < C; + c_block_data_begin += CPerBlock, + p_in_global_block_offset += + CPerBlock * in_n_c_h_w_global_desc.GetStride(I1), + p_wei_global_block_offset += + CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0)) + { +#if 0 + blockwise_in_copy_reorder.Run(p_in_global_block_offset, + p_in_block); + + blockwise_wei_copy.Run(p_wei_global_block_offset, + p_wei_block); +#else + Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()]; + Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + + blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset, + p_in_clipboard); + blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, + p_wei_clipboard); + + blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block); + blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_clipboard, p_in_block); +#endif + + __syncthreads(); + + blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread); + + __syncthreads(); + } + } + } +#endif + + // output: register to global mem, + const auto c_thread_mtx_begin = + blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const index_t k_thread_data_begin = c_thread_mtx_begin.row; + const index_t ho_thread_data_begin = c_thread_mtx_begin.batch; + const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; + const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock; + + static_if{}( + [&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. 
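// A tiny host-side reference for the "convolution = Y * X batched GEMMs" view that the
// loops above implement: for every filter tap (y, x), one GEMM over C accumulates
// Wei[c][k] * In[c][ho + y][wo + x][n] into Out[k][ho][wo][n]. All sizes below are
// made-up toy values; the layouts follow the kernel's CHWN / CYXK / KHWN convention.
#include <cstdio>
#include <vector>

int main()
{
    constexpr int C = 3, K = 2, Y = 3, X = 3, N = 2, Ho = 4, Wo = 4;
    constexpr int Hi = Ho + Y - 1, Wi = Wo + X - 1; // stride-1, no padding

    std::vector<float> in(C * Hi * Wi * N), wei(C * Y * X * K), out(K * Ho * Wo * N, 0.f);
    for(int i = 0; i < (int)in.size(); ++i)  in[i]  = float(i % 7) - 3.f;
    for(int i = 0; i < (int)wei.size(); ++i) wei[i] = float(i % 5) - 2.f;

    auto in_at  = [&](int c, int h, int w, int n) { return in[((c * Hi + h) * Wi + w) * N + n]; };
    auto wei_at = [&](int c, int y, int x, int k) { return wei[((c * Y + y) * X + x) * K + k]; };
    auto out_at = [&](int k, int h, int w, int n) -> float& { return out[((k * Ho + h) * Wo + w) * N + n]; };

    // sum over filter taps; for fixed (y, x) the inner loops are one transposed-A GEMM:
    // Out[k][(ho, wo, n)] += sum_c Wei[c][k] * In[c][(ho + y, wo + x, n)]
    for(int y = 0; y < Y; ++y)
        for(int x = 0; x < X; ++x)
            for(int k = 0; k < K; ++k)
                for(int ho = 0; ho < Ho; ++ho)
                    for(int wo = 0; wo < Wo; ++wo)
                        for(int n = 0; n < N; ++n)
                            for(int c = 0; c < C; ++c)
                                out_at(k, ho, wo, n) += wei_at(c, y, x, k) * in_at(c, ho + y, wo + x, n);

    std::printf("out[0] = %f\n", out[0]);
    return 0;
}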
Using this trick to + // make this lambda a generic lambda, so it won't be compiled until + // instantiated + static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock && + NPerBlock % GemmNPerThreadSubC == 0), + "wrong!"); + + // output is a 10d tensor + constexpr index_t N2 = GemmNPerThreadSubC; + constexpr index_t N1 = NPerBlock / N2; + + constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) / + f_dummy(NPerBlock / GemmNPerThreadSubC); + constexpr index_t W1 = WoPerBlock / W2; + + constexpr index_t K2 = GemmMPerThreadSubC; + constexpr index_t K1 = KPerBlock / KPerThread; + + constexpr auto out_10d_global_desc = + make_ConstantTensorDescriptor(Sequence{}); + + constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, + "out_k_h_w_n_thread_desc"); + print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); + + print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, + "out_k_h_w_n_global_desc"); + print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); + } +#endif + + threadwise_nd_tensor_copy(out_10d_thread_desc, + p_out_thread, + out_10d_global_desc, + p_out_global + + out_k_h_w_n_global_desc.Get1dIndex( + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin, + n_block_data_begin + n_thread_data_begin), + out_10d_thread_desc.GetLengths(), + Number{}); + }) + .else_([&](auto f_dummy) { + static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && + GemmNPerThreadSubC % NPerThread == 0, + "wrong!"); + + // output is a 10d tensor + constexpr index_t N1 = NPerBlock; + + constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock; + constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster; + constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3); + + constexpr index_t K2 = GemmMPerThreadSubC; + constexpr index_t K1 = KPerBlock / KPerThread; + + constexpr auto out_10d_global_desc = + make_ConstantTensorDescriptor(Sequence{}); + + constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, + "out_k_h_w_n_thread_desc"); + print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); + + print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, + "out_k_h_w_n_global_desc"); + print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); + + for(index_t i = 0; i < 64; ++i) + { + printf("out %f, ", p_out_thread[i]); + } + } +#endif + + threadwise_nd_tensor_copy(out_10d_thread_desc, + p_out_thread, + out_10d_global_desc, + p_out_global + + out_k_h_w_n_global_desc.Get1dIndex( + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin, + n_block_data_begin + n_thread_data_begin), + out_10d_thread_desc.GetLengths(), + Number{}); + }); + } +}; diff --git a/src/include/threadwise_2d_tensor_op.hip.hpp b/src/include/threadwise_2d_tensor_op.hip.hpp index 34f34db086..9121bb9e76 100644 --- a/src/include/threadwise_2d_tensor_op.hip.hpp +++ b/src/include/threadwise_2d_tensor_op.hip.hpp @@ -88,6 +88,7 @@ threadwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc, SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy); } +#if 0 // 
replaced threadwise_nd_tensor_copy template __device__ void threadwise_2d_tensor_copy( SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths) @@ -97,6 +98,7 @@ __device__ void threadwise_2d_tensor_copy( threadwise_2d_tensor_copy_reorder_by_get_dst_from_src( SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder); } +#endif template __device__ void threadwise_2d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift) diff --git a/src/include/threadwise_4d_tensor_op.hip.hpp b/src/include/threadwise_4d_tensor_op.hip.hpp index 37427c0b8b..cdd27199c2 100644 --- a/src/include/threadwise_4d_tensor_op.hip.hpp +++ b/src/include/threadwise_4d_tensor_op.hip.hpp @@ -139,6 +139,7 @@ __device__ void threadwise_4d_tensor_copy_reorder_given_dst2src(SrcDesc, SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy); } +#if 0 // replaced threadwise_nd_tensor_copy template __device__ void threadwise_4d_tensor_copy( SrcDesc, const SrcData* __restrict__ p_src, DstDesc, DstData* __restrict__ p_dst, SrcOpLengths) @@ -210,6 +211,7 @@ __device__ void threadwise_4d_tensor_copy_v2(SrcDesc, } } } +#endif template -__device__ void threadwise_6d_tensor_copy(SrcDesc, +__device__ void threadwise_nd_tensor_copy(SrcDesc, const Float* __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, @@ -12,268 +12,53 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc, { using vector_t = typename vector_type::MemoryType; - static_assert(SrcDesc{}.GetDimension() == 6 && DstDesc{}.GetDimension() == 6 && - SrcOpLengths::nDim == 6, - "wrong! should be 6 dimension"); + constexpr index_t nDim = SrcOpLengths::GetSize(); - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; + static_assert(SrcDesc{}.GetDimension() == nDim && DstDesc{}.GetDimension() == nDim, + "wrong! dimension not consistent"); constexpr auto src_desc = SrcDesc{}; constexpr auto dst_desc = DstDesc{}; constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{}); - static_assert(SrcDesc{}.GetStride(I5) == 1 && DstDesc{}.GetStride(I5) == 1, - "wrong! only support stride5 == 1!\n"); +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(src_desc, "src_desc"); + print_ConstantTensorDescriptor(dst_desc, "dst_desc"); + print_ConstantTensorDescriptor(ref_desc, "ref_desc"); + } +#endif + + static_assert(DataPerRead == 1 || (SrcDesc{}.GetStride(Number{}) == 1 && + DstDesc{}.GetStride(Number{}) == 1), + "wrong! only support stride[nDim-1] == 1!\n"); static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4, "wrong! only support DataPerRead == 1, 2 or 4!\n"); - static_assert(SrcDesc{}.GetStride(I4) % DataPerRead == 0 && - DstDesc{}.GetStride(I4) % DataPerRead == 0, - "wrong! src and dst stride should be multiple of DataPerRead to keep alignment"); + static_assert( + SrcDesc{}.GetStride(Number{}) % DataPerRead == 0 && + DstDesc{}.GetStride(Number{}) % DataPerRead == 0, + "wrong! src and dst stride[nDim-2] should be multiple of DataPerRead to keep alignment"); - constexpr index_t L5 = SrcOpLengths{}.Get(I5); + constexpr index_t L_Back = SrcOpLengths{}.Back(); - static_assert(L5 % DataPerRead == 0, "wrong! L5 should be evenly divided by DataPerRead"); + static_assert(L_Back % DataPerRead == 0, + "wrong! 
+    static_assert(L_Back % DataPerRead == 0,
+                  "wrong! lengths[nDim-1] should be evenly divided by DataPerRead");
 
-    constexpr index_t nloop_d5 = L5 / DataPerRead;
+    constexpr index_t nRead = L_Back / DataPerRead;
 
-    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
-    {
-        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
-        {
-            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
-            {
-                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
-                {
-                    for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
-                    {
-                        for(index_t iloop_d5 = 0; iloop_d5 < nloop_d5; ++iloop_d5)
-                        {
-                            const index_t src_index = src_desc.Get1dIndex(
-                                did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);
+    static_ford{}([=](auto Ids) {
+        static_for<0, nRead, 1>{}([=](auto IRead) {
+            constexpr auto multi_id = decltype(Ids){}.PushBack(Number{});
 
-                            const index_t dst_index = dst_desc.Get1dIndex(
-                                did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);
+            const index_t src_index = src_desc.Get1dIndex(multi_id);
 
-                            *(reinterpret_cast(p_dst + dst_index)) =
-                                *(reinterpret_cast(p_src + src_index));
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-// need to assume src and dst is aligned
-template
-__device__ void threadwise_8d_tensor_copy(SrcDesc,
-                                          const Float* __restrict__ p_src,
-                                          DstDesc,
-                                          Float* __restrict__ p_dst,
-                                          SrcOpLengths,
-                                          Number)
-{
-    using vector_t = typename vector_type::MemoryType;
-
-    static_assert(SrcDesc{}.GetDimension() == 8 && DstDesc{}.GetDimension() == 8 &&
-                      SrcOpLengths::nDim == 8,
-                  "wrong! should be 8 dimension");
-
-    constexpr auto I0 = Number<0>{};
-    constexpr auto I1 = Number<1>{};
-    constexpr auto I2 = Number<2>{};
-    constexpr auto I3 = Number<3>{};
-    constexpr auto I4 = Number<4>{};
-    constexpr auto I5 = Number<5>{};
-    constexpr auto I6 = Number<6>{};
-    constexpr auto I7 = Number<7>{};
-
-    constexpr auto src_desc = SrcDesc{};
-    constexpr auto dst_desc = DstDesc{};
-    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
-
-    static_assert(SrcDesc{}.GetStride(I7) == 1 && DstDesc{}.GetStride(I7) == 1,
-                  "wrong! only support stride7 == 1!\n");
-
-    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
-                  "wrong! only support DataPerRead == 1, 2 or 4!\n");
-
-    static_assert(SrcDesc{}.GetStride(I6) % DataPerRead == 0 &&
-                      DstDesc{}.GetStride(I6) % DataPerRead == 0,
-                  "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");
-
-    constexpr index_t L7 = SrcOpLengths{}.Get(I7);
-
-    static_assert(L7 % DataPerRead == 0, "wrong! L7 should be evenly divided by DataPerRead");
-
-    constexpr index_t nloop_d7 = L7 / DataPerRead;
-
-    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
-    {
-        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
-        {
-            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
-            {
-                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
-                {
-                    for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
-                    {
-                        for(index_t did5 = 0; did5 < ref_desc.GetLength(I5); ++did5)
-                        {
-                            for(index_t did6 = 0; did6 < ref_desc.GetLength(I6); ++did6)
-                            {
-                                for(index_t iloop_d7 = 0; iloop_d7 < nloop_d7; ++iloop_d7)
-                                {
-                                    const index_t src_index =
-                                        src_desc.Get1dIndex(did0,
-                                                            did1,
-                                                            did2,
-                                                            did3,
-                                                            did4,
-                                                            did5,
-                                                            did6,
-                                                            iloop_d7 * DataPerRead);
-
-                                    const index_t dst_index =
-                                        dst_desc.Get1dIndex(did0,
-                                                            did1,
-                                                            did2,
-                                                            did3,
-                                                            did4,
-                                                            did5,
-                                                            did6,
-                                                            iloop_d7 * DataPerRead);
-
-                                    *(reinterpret_cast(p_dst + dst_index)) =
-                                        *(reinterpret_cast(p_src + src_index));
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-// need to assume src and dst is aligned
-template
-__device__ void threadwise_10d_tensor_copy(SrcDesc,
-                                           const Float* __restrict__ p_src,
-                                           DstDesc,
-                                           Float* __restrict__ p_dst,
-                                           SrcOpLengths,
-                                           Number)
-{
-    using vector_t = typename vector_type::MemoryType;
-
-    static_assert(SrcDesc{}.GetDimension() == 10 && DstDesc{}.GetDimension() == 10 &&
-                      SrcOpLengths::GetSize() == 10,
-                  "wrong! should be 10 dimension");
-
-    constexpr auto I0 = Number<0>{};
-    constexpr auto I1 = Number<1>{};
-    constexpr auto I2 = Number<2>{};
-    constexpr auto I3 = Number<3>{};
-    constexpr auto I4 = Number<4>{};
-    constexpr auto I5 = Number<5>{};
-    constexpr auto I6 = Number<6>{};
-    constexpr auto I7 = Number<7>{};
-    constexpr auto I8 = Number<8>{};
-    constexpr auto I9 = Number<9>{};
-
-    constexpr auto src_desc = SrcDesc{};
-    constexpr auto dst_desc = DstDesc{};
-    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
-
-    static_assert(SrcDesc{}.GetStride(I9) == 1 && DstDesc{}.GetStride(I9) == 1,
-                  "wrong! only support stride7 == 1!\n");
-
-    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
-                  "wrong! only support DataPerRead == 1, 2 or 4!\n");
-
-    static_assert(SrcDesc{}.GetStride(I8) % DataPerRead == 0 &&
-                      DstDesc{}.GetStride(I8) % DataPerRead == 0,
-                  "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");
-
-    constexpr index_t L9 = SrcOpLengths{}.Get(I9);
-
-    static_assert(L9 % DataPerRead == 0, "wrong! L9 should be evenly divided by DataPerRead");
-
-    constexpr index_t nloop_d9 = L9 / DataPerRead;
-
-#pragma unroll
-    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
-    {
-#pragma unroll
-        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
-        {
-#pragma unroll
-            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
-            {
-#pragma unroll
-                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
-                {
-#pragma unroll
-                    for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
-                    {
-#pragma unroll
-                        for(index_t did5 = 0; did5 < ref_desc.GetLength(I5); ++did5)
-                        {
-#pragma unroll
-                            for(index_t did6 = 0; did6 < ref_desc.GetLength(I6); ++did6)
-                            {
-#pragma unroll
-                                for(index_t did7 = 0; did7 < ref_desc.GetLength(I7); ++did7)
-                                {
-#pragma unroll
-                                    for(index_t did8 = 0; did8 < ref_desc.GetLength(I8); ++did8)
-                                    {
-#pragma unroll
-                                        for(index_t iloop_d9 = 0; iloop_d9 < nloop_d9; ++iloop_d9)
-                                        {
-                                            const index_t src_index =
-                                                src_desc.Get1dIndex(did0,
-                                                                    did1,
-                                                                    did2,
-                                                                    did3,
-                                                                    did4,
-                                                                    did5,
-                                                                    did6,
-                                                                    did7,
-                                                                    did8,
-                                                                    iloop_d9 * DataPerRead);
-
-                                            const index_t dst_index =
-                                                dst_desc.Get1dIndex(did0,
-                                                                    did1,
-                                                                    did2,
-                                                                    did3,
-                                                                    did4,
-                                                                    did5,
-                                                                    did6,
-                                                                    did7,
-                                                                    did8,
-                                                                    iloop_d9 * DataPerRead);
-
-                                            *(reinterpret_cast(p_dst + dst_index)) =
-                                                *(reinterpret_cast(p_src +
-                                                                   src_index));
-                                        }
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
+            const index_t dst_index = dst_desc.Get1dIndex(multi_id);
+
+            *(reinterpret_cast(&p_dst[dst_index])) =
+                *(reinterpret_cast(&p_src[src_index]));
+        });
+    });
 }
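For reference, threadwise_nd_tensor_copy above replaces the hand-written 6-, 8- and 10-dimensional loop nests with a single rank-generic traversal: static_ford enumerates every multi-index of the copy lengths at compile time, static_for steps the last dimension in chunks of DataPerRead, and each chunk is moved through the wider vector_t. The sketch below shows the same pattern in self-contained C++17 under simplifying assumptions: std::integer_sequence stands in for the repository's Sequence type, the inner static_for is folded into the last extent, flatten() stands in for Get1dIndex, and vec4 is a hypothetical stand-in for the vector type (aliasing and alignment caveats are glossed over, as in the device code).

#include <cstdint>
#include <cstdio>
#include <utility>

using index_t = std::int32_t;

// Compile-time nested loop: calls f(...) with one integral_constant index per
// dimension, for every multi-index in Extents..., fully unrolled by the compiler.
template <index_t... Extents>
struct static_ford;

template <>
struct static_ford<>
{
    template <class F, index_t... Ids>
    static void Run(F f, std::integer_sequence<index_t, Ids...>)
    {
        f(std::integral_constant<index_t, Ids>{}...);
    }
};

template <index_t E0, index_t... Rest>
struct static_ford<E0, Rest...>
{
    template <class F, index_t... Ids, index_t... Is>
    static void Step(F f, std::integer_sequence<index_t, Ids...>, std::integer_sequence<index_t, Is...>)
    {
        // one recursive expansion per value of the current (leftmost remaining) extent
        (static_ford<Rest...>::Run(f, std::integer_sequence<index_t, Ids..., Is>{}), ...);
    }

    template <class F, index_t... Ids>
    static void Run(F f, std::integer_sequence<index_t, Ids...> ids)
    {
        Step(f, ids, std::make_integer_sequence<index_t, E0>{});
    }

    template <class F>
    static void Run(F f)
    {
        Run(f, std::integer_sequence<index_t>{});
    }
};

// flatten(strides, i0, i1, ...): dot product of a compile-time multi-index with the
// strides, standing in for the descriptor's Get1dIndex
template <index_t... Strides, index_t... Ids>
constexpr index_t flatten(std::integer_sequence<index_t, Strides...>,
                          std::integral_constant<index_t, Ids>...)
{
    return ((Ids * Strides) + ... + 0);
}

struct vec4
{
    float v[4]; // hypothetical stand-in for the DataPerRead == 4 memory type
};

int main()
{
    constexpr index_t DataPerRead = 4;

    // copy a 2 x 3 x 8 block: padded source (strides 64, 16, 1), packed destination (24, 8, 1)
    using SrcStrides = std::integer_sequence<index_t, 64, 16, 1>;
    using DstStrides = std::integer_sequence<index_t, 24, 8, 1>;

    float src[2 * 64];
    float dst[2 * 24] = {};
    for(index_t i = 0; i < 2 * 64; ++i)
        src[i] = static_cast<float>(i);

    // outer dimensions fully unrolled; the last dimension advances DataPerRead elements per step
    static_ford<2, 3, 8 / DataPerRead>::Run([&](auto i0, auto i1, auto iRead) {
        constexpr index_t i2 = decltype(iRead)::value * DataPerRead;
        constexpr auto ic2   = std::integral_constant<index_t, i2>{};

        const index_t src_index = flatten(SrcStrides{}, i0, i1, ic2);
        const index_t dst_index = flatten(DstStrides{}, i0, i1, ic2);

        // move DataPerRead contiguous elements through the wider type,
        // mirroring the reinterpret_cast copy in threadwise_nd_tensor_copy
        *reinterpret_cast<vec4*>(&dst[dst_index]) =
            *reinterpret_cast<const vec4*>(&src[src_index]);
    });

    std::printf("dst[0] = %f, dst[47] = %f\n", dst[0], dst[47]); // expect 0 and 103
}

Because every index is a compile-time constant, the compiler can fully unroll the traversal and fold the offset arithmetic, which is what the static_ford/static_for formulation buys over runtime loops while staying independent of the tensor's rank.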
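The `multi_id = decltype(Ids){}.PushBack(Number{...})` step above appends the innermost element offset to the compile-time multi-index before it is flattened by Get1dIndex. A minimal sketch of a PushBack of this kind (illustrative only; the actual Sequence and Number types live elsewhere in the repository and differ in detail):

#include <cstdint>
#include <type_traits>

using index_t = std::int32_t;

template <index_t I>
struct Number
{
    static constexpr index_t value = I;
};

template <index_t... Is>
struct Sequence
{
    // append one more compile-time index, e.g. the innermost offset IRead * DataPerRead
    template <index_t I>
    constexpr Sequence<Is..., I> PushBack(Number<I>) const
    {
        return {};
    }
};

// Sequence<1, 2>{}.PushBack(Number<4>{}) has type Sequence<1, 2, 4>
static_assert(
    std::is_same<decltype(Sequence<1, 2>{}.PushBack(Number<4>{})), Sequence<1, 2, 4>>::value, "");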
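The vector_t obtained from vector_type's MemoryType is what turns the innermost copy into a single wide load/store, and it is why DataPerRead is restricted to 1, 2 or 4: only those widths have a memory type defined. A hypothetical sketch of such a mapping for float (an assumption for illustration; the actual definition in this codebase may instead use the built-in HIP vector types such as float2/float4):

#include <cstdint>

using index_t = std::int32_t;

template <class T, index_t N>
struct vector_type; // no general definition: an unsupported width fails to compile

template <>
struct vector_type<float, 1>
{
    using MemoryType = float;
};

template <>
struct vector_type<float, 2>
{
    struct alignas(8) MemoryType { float x[2]; };
};

template <>
struct vector_type<float, 4>
{
    struct alignas(16) MemoryType { float x[4]; };
};

// e.g. a DataPerRead == 4 copy moves 16 bytes per load/store:
// using vector_t = typename vector_type<float, 4>::MemoryType;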