From a65ef9030880d51dd159e4d23f1dc6093b17651c Mon Sep 17 00:00:00 2001
From: Chao Liu
Date: Tue, 19 Feb 2019 11:47:46 -0600
Subject: [PATCH] device_implicit_gemm_convolution_1_chwn_csrk_khwn: use
 tensor copy (instead of pointwise) for writing output, 3x3 increased from
 78% to 84%, 5x5 from 80% to 84%

---
 ...icit_gemm_convolution_1_chwn_csrk_khwn.hpp |  79 +++--
 driver/driver.hip.cpp                         |  46 +--
 src/include/ConstantTensorDescriptor.hip.hpp  | 329 +++++++++++++++++-
 src/include/blockwise_gemm.hip.hpp            |  44 +++
 ..._gemm_convolution_1_chwn_csrk_khwn.hip.hpp |  74 +++-
 src/include/threadwise_4d_tensor_op.hip.hpp   |  85 ++++-
 src/include/threadwise_nd_tensor_op.hip.hpp   | 198 +++++++++++
 7 files changed, 795 insertions(+), 60 deletions(-)
 create mode 100644 src/include/threadwise_nd_tensor_op.hip.hpp

diff --git a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp
index 246d331fb5..83e2aa2642 100644
--- a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp
+++ b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp
@@ -75,7 +75,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
     out_khwn_device_buf.ToDevice(out_khwn.mData.data());
 
 #if 1
-    // for 3x3, 34x34, try
+    // for 3x3, 34x34
     constexpr unsigned NPerBlock = 16;
     constexpr unsigned KPerBlock = 64;
     constexpr unsigned CPerBlock = 4;
@@ -106,9 +106,46 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
     constexpr unsigned GemmNLevel1Cluster = 4;
     constexpr unsigned GemmKPerThreadLoop = 1;
 
+    constexpr unsigned OutThreadCopyDataPerWrite = 2;
+
     constexpr unsigned BlockSize = 128;
 #elif 0
-    // for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
+    // for 5x5, 36x36
+    constexpr unsigned NPerBlock  = 16;
+    constexpr unsigned KPerBlock  = 64;
+    constexpr unsigned CPerBlock  = 2;
+    constexpr unsigned HoPerBlock = 2;
+    constexpr unsigned WoPerBlock = 4;
+
+    constexpr unsigned NPerThread  = 8;
+    constexpr unsigned KPerThread  = 8;
+    constexpr unsigned HoPerThread = 1;
+    constexpr unsigned WoPerThread = 1;
+
+    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
+    constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
+
+    constexpr unsigned InBlockCopy_ThreadPerDimC = 2;
+    constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
+    constexpr unsigned InBlockCopy_ThreadPerDimW = 4;
+    constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
+    constexpr unsigned InBlockCopyDataPerRead = 4;
+
+    constexpr unsigned WeiBlockCopyDataPerRead = 2;
+
+    constexpr unsigned GemmMPerThreadSubC = 4;
+    constexpr unsigned GemmNPerThreadSubC = 4;
+    constexpr unsigned GemmMLevel0Cluster = 4;
+    constexpr unsigned GemmNLevel0Cluster = 2;
+    constexpr unsigned GemmMLevel1Cluster = 2;
+    constexpr unsigned GemmNLevel1Cluster = 4;
+    constexpr unsigned GemmKPerThreadLoop = 1;
+
+    constexpr unsigned OutThreadCopyDataPerWrite = 2;
+
+    constexpr unsigned BlockSize = 128;
+#elif 0
+    // 3x3 58x58, NKC = 64, 64, 256
     constexpr unsigned NPerBlock = 16;
     constexpr unsigned KPerBlock = 64;
     constexpr unsigned CPerBlock = 4;
@@ -142,27 +179,6 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
     constexpr unsigned HoPerThread = 1;
     constexpr unsigned WoPerThread = 1;
 
-    constexpr unsigned BlockSize = 128;
-#elif 0
-    // for 5x5, 36x36
-    constexpr unsigned NPerBlock  = 16;
-    constexpr unsigned KPerBlock  = 64;
-    constexpr unsigned CPerBlock  = 2;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 4;
-
-    constexpr unsigned NPerThread  = 4;
-    constexpr unsigned KPerThread  = 16;
-    constexpr unsigned CPerThread  = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
-
-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
-
-    constexpr unsigned InBlockCopyDataPerRead  = 4; // not used, yet
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
-
     constexpr unsigned BlockSize = 128;
 #elif 0
     // for 7x7, 38x38
@@ -200,7 +216,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
     constexpr unsigned WoPerThread = 1;
 
     constexpr unsigned BlockSize = 128;
-#elif 0
+#elif 1
     // for 1x1, 28x28
     constexpr unsigned NPerBlock = 16;
     constexpr unsigned KPerBlock = 128;
@@ -210,7 +226,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
 
     constexpr unsigned NPerThread = 4;
     constexpr unsigned KPerThread = 16;
-    constexpr unsigned CPerThread = 2;
+    constexpr unsigned CPerThread = 1;
     constexpr unsigned HoPerThread = 1;
     constexpr unsigned WoPerThread = 1;
 
@@ -225,6 +241,16 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
 
     constexpr unsigned WeiBlockCopyDataPerRead = 4;
 
+    constexpr unsigned GemmMPerThreadSubC = 4;
+    constexpr unsigned GemmNPerThreadSubC = 4;
+    constexpr unsigned GemmMLevel0Cluster = 4;
+    constexpr unsigned GemmNLevel0Cluster = 2;
+    constexpr unsigned GemmMLevel1Cluster = 2;
+    constexpr unsigned GemmNLevel1Cluster = 4;
+    constexpr unsigned GemmKPerThreadLoop = 1;
+
+    constexpr unsigned OutThreadCopyDataPerWrite = 2;
+
     constexpr unsigned BlockSize = 128;
 #endif
@@ -266,7 +292,8 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
                        GemmNLevel0Cluster,
                        GemmMLevel1Cluster,
                        GemmNLevel1Cluster,
-                       GemmKPerThreadLoop>,
+                       GemmKPerThreadLoop,
+                       OutThreadCopyDataPerWrite>,
         dim3(GridSize),
         dim3(BlockSize),
         static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp
index 4fc78491bb..5b4e08d0bf 100644
--- a/driver/driver.hip.cpp
+++ b/driver/driver.hip.cpp
@@ -571,16 +571,21 @@ int main()
 
     std::size_t num_thread = std::thread::hardware_concurrency();
 
+    bool do_verification = true;
+
+    if(do_verification)
+    {
 #if 0
-    in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
-    wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+        in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+        wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
 #elif 1
-    in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-    wei_kcsr.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+        in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+        wei_kcsr.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
 #elif 1
-    in_nchw.GenerateTensorValue(GeneratorTensor_2{-2, 2}, num_thread);
-    wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+        in_nchw.GenerateTensorValue(GeneratorTensor_2{-2, 2}, num_thread);
+        wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
 #endif
+    }
 
     unsigned nrepeat = 200;
 
@@ -614,22 +619,23 @@ int main()
                                nrepeat);
 #endif
 
-#if 1
-    if(S == 3 && R == 3)
+    if(do_verification)
     {
-        host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
-    }
-    else
-    {
-        host_direct_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
-    }
-    check_error(out_nkhw_host, out_nkhw_device);
-#endif
+        if(S == 3 && R == 3)
+        {
+            host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
+        }
+        else
+        {
+            host_direct_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
+        }
+        check_error(out_nkhw_host, out_nkhw_device);
 
 #if 0
-    LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
-    LogRange(std::cout << "wei_kcsr: ", wei_kcsr.mData, ",") << std::endl;
-    LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
-    LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
+        LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
+        LogRange(std::cout << "wei_kcsr: ", wei_kcsr.mData, ",") << std::endl;
+        LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
+        LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
 #endif
+    }
 }
diff --git a/src/include/ConstantTensorDescriptor.hip.hpp b/src/include/ConstantTensorDescriptor.hip.hpp
index c3157653d2..2352b0f50c 100644
--- a/src/include/ConstantTensorDescriptor.hip.hpp
+++ b/src/include/ConstantTensorDescriptor.hip.hpp
@@ -15,6 +15,35 @@ __host__ __device__ constexpr auto calculate_default_strides(Sequence<...>)
     return Sequence<...>{};
 }
 
+// this is ugly, only for 6d
+template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned L4, unsigned L5>
+__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5>)
+{
+    return Sequence<L1 * L2 * L3 * L4 * L5, L2 * L3 * L4 * L5, L3 * L4 * L5, L4 * L5, L5, 1>{};
+}
+
+// this is ugly, only for 8d
+template <unsigned L0,
+          unsigned L1,
+          unsigned L2,
+          unsigned L3,
+          unsigned L4,
+          unsigned L5,
+          unsigned L6,
+          unsigned L7>
+__host__ __device__ constexpr auto
+    calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7>)
+{
+    return Sequence<L1 * L2 * L3 * L4 * L5 * L6 * L7,
+                    L2 * L3 * L4 * L5 * L6 * L7,
+                    L3 * L4 * L5 * L6 * L7,
+                    L4 * L5 * L6 * L7,
+                    L5 * L6 * L7,
+                    L6 * L7,
+                    L7,
+                    1>{};
+}
+
 // this is ugly, only for 2d
 template <...>
 __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<...>,
@@ -64,7 +93,7 @@ struct ConstantTensorDescriptor
 
     __host__ __device__ constexpr unsigned GetElementSize() const
     {
-        static_assert(nDim >= 2 && nDim <= 4, "nDim");
+        static_assert(nDim >= 2 && nDim <= 8, "nDim");
 
         if(nDim == 2)
         {
@@ -90,12 +119,65 @@ struct ConstantTensorDescriptor
 
             return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3);
         }
+        else if(nDim == 5)
+        {
+            constexpr auto I0 = Number<0>{};
+            constexpr auto I1 = Number<1>{};
+            constexpr auto I2 = Number<2>{};
+            constexpr auto I3 = Number<3>{};
+            constexpr auto I4 = Number<4>{};
+
+            return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4);
+        }
+        else if(nDim == 6)
+        {
+            constexpr auto I0 = Number<0>{};
+            constexpr auto I1 = Number<1>{};
+            constexpr auto I2 = Number<2>{};
+            constexpr auto I3 = Number<3>{};
+            constexpr auto I4 = Number<4>{};
+            constexpr auto I5 = Number<5>{};
+
+            return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4) *
+                   GetLength(I5);
+        }
+        else if(nDim == 7)
+        {
+            constexpr auto I0 = Number<0>{};
+            constexpr auto I1 = Number<1>{};
+            constexpr auto I2 = Number<2>{};
+            constexpr auto I3 = Number<3>{};
+            constexpr auto I4 = Number<4>{};
+            constexpr auto I5 = Number<5>{};
+            constexpr auto I6 = Number<6>{};
+
+            return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4) *
+                   GetLength(I5) * GetLength(I6);
+        }
+        else if(nDim == 8)
+        {
+            constexpr auto I0 = Number<0>{};
+            constexpr auto I1 = Number<1>{};
+            constexpr auto I2 = Number<2>{};
+            constexpr auto I3 = Number<3>{};
+            constexpr auto I4 = Number<4>{};
+            constexpr auto I5 = Number<5>{};
+            constexpr auto I6 = Number<6>{};
+            constexpr auto I7 = Number<7>{};
+
+            return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4) *
+                   GetLength(I5) * GetLength(I6) * GetLength(I7);
+        }
+        else
+        {
+            assert(false);
+        }
     }
 
     template <class Align = Number<1>>
     __host__ __device__ constexpr unsigned GetElementSpace(Align align = Align{}) const
     {
-        static_assert(nDim >= 2 && nDim <= 4, "nDim");
+        static_assert(nDim >= 2 && nDim <= 8, "nDim");
 
         constexpr unsigned align_size = align.Get();
 
@@ -127,6 +209,64 @@ struct ConstantTensorDescriptor
                (GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
                align_size;
         }
+        else if(nDim == 5)
+        {
+            constexpr auto I0 = Number<0>{};
+            constexpr auto I1 = Number<1>{};
+            constexpr auto I2 = Number<2>{};
+            constexpr auto I3 = Number<3>{};
+            constexpr auto I4 = Number<4>{};
+
+            return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
+                   (GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
+                   (GetLength(I4) - 1) * GetStride(I4) + align_size;
+        }
+        else if(nDim == 6)
+        {
+            constexpr auto I0 = Number<0>{};
+            constexpr auto I1 = Number<1>{};
+            constexpr auto I2 = Number<2>{};
+            constexpr auto I3 = Number<3>{};
+            constexpr auto I4 = Number<4>{};
+            constexpr auto I5 = Number<5>{};
+
+            return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
+                   (GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
+                   (GetLength(I4) - 1) * GetStride(I4) + (GetLength(I5) - 1) * GetStride(I5) +
+                   align_size;
+        }
+        else if(nDim == 7)
+        {
+            constexpr auto I0 = Number<0>{};
+            constexpr auto I1 = Number<1>{};
+            constexpr auto I2 = Number<2>{};
+            constexpr auto I3 = Number<3>{};
+            constexpr auto I4 = Number<4>{};
+            constexpr auto I5 = Number<5>{};
+            constexpr auto I6 = Number<6>{};
+
+            return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
+                   (GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
+                   (GetLength(I4) - 1) * GetStride(I4) + (GetLength(I5) - 1) * GetStride(I5) +
+                   (GetLength(I6) - 1) * GetStride(I6) + align_size;
+        }
+        else if(nDim == 8)
+        {
+            constexpr auto I0 = Number<0>{};
+            constexpr auto I1 = Number<1>{};
+            constexpr auto I2 = Number<2>{};
+            constexpr auto I3 = Number<3>{};
+            constexpr auto I4 = Number<4>{};
+            constexpr auto I5 = Number<5>{};
+            constexpr auto I6 = Number<6>{};
+            constexpr auto I7 = Number<7>{};
+
+            return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
+                   (GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
+                   (GetLength(I4) - 1) * GetStride(I4) + (GetLength(I5) - 1) * GetStride(I5) +
+                   (GetLength(I6) - 1) * GetStride(I6) + (GetLength(I7) - 1) * GetStride(I7) +
+                   align_size;
+        }
     }
 
     // this is ugly, only for 2d
@@ -163,6 +303,83 @@ struct ConstantTensorDescriptor
         return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3);
     }
 
+    // this is ugly, only for 5d
+    __host__ __device__ unsigned
+    Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3, unsigned i4) const
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+        constexpr auto I4 = Number<4>{};
+
+        static_assert(nDim == 5, "nDim is not 5");
+        return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
+               i4 * GetStride(I4);
+    }
+
+    // this is ugly, only for 6d
+    __host__ __device__ unsigned
+    Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3, unsigned i4, unsigned i5) const
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+        constexpr auto I4 = Number<4>{};
+        constexpr auto I5 = Number<5>{};
+
+        static_assert(nDim == 6, "nDim is not 6");
+        return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
+               i4 * GetStride(I4) + i5 * GetStride(I5);
+    }
+
+    // this is ugly, only for 7d
+    __host__ __device__ unsigned Get1dIndex(unsigned i0,
+                                            unsigned i1,
+                                            unsigned i2,
+                                            unsigned i3,
+                                            unsigned i4,
+                                            unsigned i5,
+                                            unsigned i6) const
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+        constexpr auto I4 = Number<4>{};
+        constexpr auto I5 = Number<5>{};
+        constexpr auto I6 = Number<6>{};
+
+        static_assert(nDim == 7, "nDim is not 7");
+        return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
+               i4 * GetStride(I4) + i5 * GetStride(I5) + i6 * GetStride(I6);
+    }
+
+    // this is ugly, only for 8d
+    __host__ __device__ unsigned Get1dIndex(unsigned i0,
+                                            unsigned i1,
+                                            unsigned i2,
+                                            unsigned i3,
+                                            unsigned i4,
+                                            unsigned i5,
+                                            unsigned i6,
+                                            unsigned i7) const
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+        constexpr auto I4 = Number<4>{};
+        constexpr auto I5 = Number<5>{};
+        constexpr auto I6 = Number<6>{};
+        constexpr auto I7 = Number<7>{};
+
+        static_assert(nDim == 8, "nDim is not 8");
+        return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
+               i4 * GetStride(I4) + i5 * GetStride(I5) + i6 * GetStride(I6) + i7 * GetStride(I7);
+    }
+
     __host__ __device__ constexpr auto Condense() const
     {
         constexpr auto default_strides = calculate_default_strides(Lengths{});
@@ -196,7 +413,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
     constexpr auto desc = TDesc{};
     constexpr unsigned ndim = desc.GetDimension();
 
-    static_assert(ndim >= 2 && ndim <= 4, "wrong!");
+    static_assert(ndim >= 2 && ndim <= 8, "wrong!");
 
     if(ndim == 2)
     {
@@ -230,4 +447,110 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
               desc.GetStride(I2),
               desc.GetStride(I3));
     }
+    else if(ndim == 5)
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+        constexpr auto I4 = Number<4>{};
+
+        printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
+               s,
+               desc.GetDimension(),
+               desc.GetLength(I0),
+               desc.GetLength(I1),
+               desc.GetLength(I2),
+               desc.GetLength(I3),
+               desc.GetLength(I4),
+               desc.GetStride(I0),
+               desc.GetStride(I1),
+               desc.GetStride(I2),
+               desc.GetStride(I3),
+               desc.GetStride(I4));
+    }
+    else if(ndim == 6)
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+        constexpr auto I4 = Number<4>{};
+        constexpr auto I5 = Number<5>{};
+
+        printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
+               s,
+               desc.GetDimension(),
+               desc.GetLength(I0),
+               desc.GetLength(I1),
+               desc.GetLength(I2),
+               desc.GetLength(I3),
+               desc.GetLength(I4),
+               desc.GetLength(I5),
+               desc.GetStride(I0),
+               desc.GetStride(I1),
+               desc.GetStride(I2),
+               desc.GetStride(I3),
+               desc.GetStride(I4),
+               desc.GetStride(I5));
+    }
+    else if(ndim == 7)
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+        constexpr auto I4 = Number<4>{};
+        constexpr auto I5 = Number<5>{};
+        constexpr auto I6 = Number<6>{};
+
+        printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
+               s,
+               desc.GetDimension(),
+               desc.GetLength(I0),
+               desc.GetLength(I1),
+               desc.GetLength(I2),
+               desc.GetLength(I3),
+               desc.GetLength(I4),
+               desc.GetLength(I5),
+               desc.GetLength(I6),
+               desc.GetStride(I0),
+               desc.GetStride(I1),
+               desc.GetStride(I2),
+               desc.GetStride(I3),
+               desc.GetStride(I4),
+               desc.GetStride(I5),
+               desc.GetStride(I6));
+    }
+    else if(ndim == 8)
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+        constexpr auto I4 = Number<4>{};
+        constexpr auto I5 = Number<5>{};
+        constexpr auto I6 = Number<6>{};
+        constexpr auto I7 = Number<7>{};
+
+        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
+               s,
+               desc.GetDimension(),
+               desc.GetLength(I0),
+               desc.GetLength(I1),
+               desc.GetLength(I2),
+               desc.GetLength(I3),
+               desc.GetLength(I4),
+               desc.GetLength(I5),
+               desc.GetLength(I6),
+               desc.GetLength(I7),
+               desc.GetStride(I0),
+               desc.GetStride(I1),
+               desc.GetStride(I2),
+               desc.GetStride(I3),
+               desc.GetStride(I4),
+               desc.GetStride(I5),
+               desc.GetStride(I6),
+               desc.GetStride(I7));
+    }
 }
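The 5d-8d overloads added above extend the existing 2d-4d ones mechanically: the default strides are the row-major suffix products of the lengths, and Get1dIndex is the dot product of an index tuple with those strides. A small standalone C++ illustration of the arithmetic (not part of the patch):

#include <cstdio>

int main()
{
    // suffix-product strides for lengths {8, 4, 4, 16}, i.e. what
    // calculate_default_strides(Sequence<8, 4, 4, 16>{}) produces
    const unsigned len[4] = {8, 4, 4, 16};
    unsigned stride[4];
    stride[3] = 1;
    for(int i = 2; i >= 0; --i)
        stride[i] = stride[i + 1] * len[i + 1];

    // Get1dIndex(1, 2, 3, 5) is the dot product with those strides
    const unsigned idx[4] = {1, 2, 3, 5};
    unsigned offset = 0;
    for(int i = 0; i < 4; ++i)
        offset += idx[i] * stride[i];

    // prints: strides {256 64 16 1}, offset 437
    printf("strides {%u %u %u %u}, offset %u\n",
           stride[0], stride[1], stride[2], stride[3], offset);
    return 0;
}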
diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp
index 46773388ed..33556dde25 100644
--- a/src/include/blockwise_gemm.hip.hpp
+++ b/src/include/blockwise_gemm.hip.hpp
@@ -525,7 +525,51 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                 f_accum);
         }
     }
+
+    template <class BlockMatrixC, class FloatC>
+    __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
+                                                    FloatC* __restrict__ p_c_block) const
+    {
+        constexpr auto c_block_mtx  = BlockMatrixC{};
+        constexpr auto c_thread_mtx = ThreadMatrixC{};
+
+        constexpr unsigned MPerThread = c_thread_mtx.NRow();
+        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+
+        constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
+            Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
+
+        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
+        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+
+        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
+        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+
+        const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+
+        const unsigned c_thread_offset =
+            c_thread_mtx_begin.batch * BlockMatrixStrideC +
+            c_block_mtx.Get1dIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col);
+
+        for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
+        {
+            for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+            {
+                threadwise_matrix_copy(
+                    c_thread_sub_mtx,
+                    p_c_thread + c_thread_mtx.Get1dIndex(m_repeat * MPerThreadSubC,
+                                                         n_repeat * NPerThreadSubC),
+                    c_block_mtx,
+                    p_c_block +
+                        c_block_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
+                                               n_repeat * NPerLevel1Cluster) +
+                        c_thread_offset,
+                    c_thread_sub_mtx.GetLengths());
+            }
+        }
+    }
 };
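A concrete reading of CopyThreadMatrixCToBlockMatrixC, using the 5x5 configuration from the driver above (KPerThread = 8, GemmMPerThreadSubC = 4, GemmMLevel0Cluster = 4, GemmMLevel1Cluster = 2): MPerLevel1Cluster = 4 * 4 * 2 = 32 and MRepeat = 8 / 4 = 2, so a thread owns two 4-row sub-tiles that sit 4 rows apart in its compact MPerThread x NPerThread register tile but 32 rows apart (one level-1 cluster) in the block-level C matrix. That is why the loop advances the source offset by MPerThreadSubC rows per repeat while the destination offset advances by MPerLevel1Cluster rows.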
diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp
--- a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp
+++ b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp
@@ ... @@
 template <class Float,
           ...
-          unsigned GemmKPerThreadLoop>
+          unsigned GemmKPerThreadLoop,
+          unsigned OutThreadCopyDataPerWrite>
 __global__ void
 gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restrict__ p_in_global,
                                                     const Float* const __restrict__ p_wei_global,
@@ -270,19 +272,18 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
         }
     }
 
+    // output: register to global mem,
+#if 0
     const auto c_thread_mtx_begin =
         blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
 
-    // output: register to global mem,
-    // convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N]
-#if 0 // for v1 batch-gemm
     const unsigned k_thread_data_begin  = c_thread_mtx_begin.row;
     const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch;
     const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
-    const unsigned n_thread_data_begin = c_thread_mtx_begin.col - wo_thread_data_begin * NPerBlock;
+    const unsigned n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
 
-    threadwise_4d_tensor_copy(
+    threadwise_4d_tensor_copy_v2(
         out_khwn_thread_desc,
         p_out_thread,
         out_khwn_global_desc,
@@ -290,8 +291,12 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
                                        ho_block_data_begin + ho_thread_data_begin,
                                        wo_block_data_begin + wo_thread_data_begin,
                                        n_block_data_begin + n_thread_data_begin),
-        out_khwn_thread_desc.GetLengths());
-#else
+        out_khwn_thread_desc.GetLengths(),
+        Number<OutThreadCopyDataPerWrite>{});
+#elif 0
+    const auto c_thread_mtx_begin =
+        blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+
     for(unsigned k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
     {
         for(unsigned ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
@@ -322,5 +327,58 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
             }
         }
     }
+#elif 1
+    const auto c_thread_mtx_begin =
+        blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+
+    const unsigned k_thread_data_begin  = c_thread_mtx_begin.row;
+    const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch;
+    const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
+    const unsigned n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;
+
+    // this is for v2 GEMM
+    // output is an 8d tensor
+    if(NPerThread < NPerBlock && WoPerThread == 1)
+    {
+        constexpr unsigned N1_ = GemmNPerThreadSubC;
+        constexpr unsigned W1_ = WoPerBlock / ((WoPerThread * NPerThread) / GemmNPerThreadSubC);
+        constexpr unsigned K2_ = GemmMPerThreadSubC;
+        constexpr unsigned K1_ = KPerBlock / KPerThread;
+
+        constexpr auto out_8d_global_desc = make_ConstantTensorDescriptor(
+            Sequence<...>{});
+
+        constexpr auto out_8d_thread_desc = make_ConstantTensorDescriptor(
+            Sequence<...>{});
+
+#if 0
+        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+        {
+            print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
+            print_ConstantTensorDescriptor(out_8d_thread_desc, "out_8d_thread_desc");
+
+            print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
+            print_ConstantTensorDescriptor(out_8d_global_desc, "out_8d_global_desc");
+        }
+#endif
+
+        threadwise_8d_tensor_copy(out_8d_thread_desc,
+                                  p_out_thread,
+                                  out_8d_global_desc,
+                                  p_out_global + out_khwn_global_desc.Get1dIndex(
+                                                     k_block_data_begin + k_thread_data_begin,
+                                                     ho_block_data_begin + ho_thread_data_begin,
+                                                     wo_block_data_begin + wo_thread_data_begin,
+                                                     n_block_data_begin + n_thread_data_begin),
+                                  out_8d_thread_desc.GetLengths(),
+                                  Number<OutThreadCopyDataPerWrite>{});
+    }
+    else if(NPerThread == NPerBlock)
+    {
+    }
+    else
+    {
+        assert(false);
+    }
 #endif
 }
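Reading the constants in the #elif 1 branch (the Sequence arguments to the two descriptors are elided above): a thread's C tile is dense in registers, but in the global KHWN tensor its elements are scattered — along K in runs of K2_ = GemmMPerThreadSubC spaced one level-1 cluster apart (K1_ = KPerBlock / KPerThread), and along the merged Wo/N gemm column in runs of N1_ = GemmNPerThreadSubC. Splitting K into K0 x K1_ x K2_, Wo into W0 x W1_, and N into N0 x N1_ turns both sides of the transfer into plain 8d tensors whose innermost dimension (length N1_) is contiguous, so the whole write-out becomes the single threadwise_8d_tensor_copy with OutThreadCopyDataPerWrite-wide stores, replacing the per-element loop kept in the #elif 0 branch.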
diff --git a/src/include/threadwise_4d_tensor_op.hip.hpp b/src/include/threadwise_4d_tensor_op.hip.hpp
index 6cf413187f..3d13ae2aa6 100644
--- a/src/include/threadwise_4d_tensor_op.hip.hpp
+++ b/src/include/threadwise_4d_tensor_op.hip.hpp
@@ -45,7 +45,7 @@ template <...>
 __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
     SrcDesc,
-    Float* const __restrict__ p_src,
+    const Float* __restrict__ p_src,
     DstDesc,
     Float* __restrict__ p_dst,
     SrcOpLengths,
@@ -100,7 +100,7 @@ __device__ void threadwise_4d_tensor_set_zero(Desc, Float* __restrict__ p)
 template <...>
 __device__ void threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
-                                                                      Float* const __restrict__ p_src,
+                                                                      const Float* __restrict__ p_src,
                                                                       DstDesc,
                                                                       Float* __restrict__ p_dst,
                                                                       SrcOpLengths,
@@ -114,7 +114,7 @@ threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
 
 template <...>
 __device__ void threadwise_4d_tensor_copy(
-    SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths)
+    SrcDesc, const Float* __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths)
 {
     auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{};
@@ -122,6 +122,85 @@ __device__ void threadwise_4d_tensor_copy(
         SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder);
 }
 
+// need to assume src and dst are aligned
+template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, unsigned DataPerRead>
+__device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
+                                             const Float* __restrict__ p_src,
+                                             DstDesc,
+                                             Float* __restrict__ p_dst,
+                                             SrcOpLengths,
+                                             Number<DataPerRead>)
+{
+    using Float2 = float2;
+    using Float4 = float4;
+
+    static_assert(SrcDesc{}.GetDimension() == 4 && DstDesc{}.GetDimension() == 4 &&
+                      SrcOpLengths::nDim == 4,
+                  "wrong! should be 4 dimension");
+
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};
+
+    constexpr auto src_desc = SrcDesc{};
+    constexpr auto dst_desc = DstDesc{};
+    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
+
+    static_assert(SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1,
+                  "wrong! only support stride3 == 1!\n");
+
+    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
+                  "wrong! only support DataPerRead == 1, 2 or 4!\n");
+
+    static_assert(SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
+                      DstDesc{}.GetStride(I2) % DataPerRead == 0,
+                  "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");
+
+    constexpr unsigned L3 = SrcOpLengths{}.Get(I3);
+
+    static_assert(L3 % DataPerRead == 0, "wrong! L3 should be evenly divided by DataPerRead");
+
+    constexpr unsigned nloop_d3 = L3 / DataPerRead;
+
+    for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
+    {
+        for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
+        {
+            for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
+            {
+                for(unsigned iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
+                {
+                    const unsigned src_index =
+                        src_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);
+
+                    const unsigned dst_index =
+                        dst_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);
+
+                    if(DataPerRead == 1)
+                    {
+                        p_dst[dst_index] = p_src[src_index];
+                    }
+                    else if(DataPerRead == 2)
+                    {
+                        *(reinterpret_cast<Float2*>(p_dst + dst_index)) =
+                            *(reinterpret_cast<const Float2*>(p_src + src_index));
+                    }
+                    else if(DataPerRead == 4)
+                    {
+                        *(reinterpret_cast<Float4*>(p_dst + dst_index)) =
+                            *(reinterpret_cast<const Float4*>(p_src + src_index));
+                    }
+                    else
+                    {
+                        assert(false);
+                    }
+                }
+            }
+        }
+    }
+}
+
 template <...>
 __device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
 {
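threadwise_4d_tensor_copy_v2 vectorizes only the innermost, stride-1 dimension: each run of DataPerRead floats moves as one float2/float4 access, which is what the stride and divisibility static_asserts guarantee is safe. A minimal host-side sketch of the same pattern (stand-in vector types and invented names; an illustration, not part of the patch):

#include <cassert>
#include <cstdint>

// host-side stand-ins for HIP's built-in float2/float4
struct Vec2 { float x, y; };
struct Vec4 { float x, y, z, w; };

// move n floats, DataPerRead of them per load/store, mirroring the
// DataPerRead == 1/2/4 branches of threadwise_4d_tensor_copy_v2
template <unsigned DataPerRead>
void vector_copy(const float* p_src, float* p_dst, unsigned n)
{
    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                  "only support DataPerRead == 1, 2 or 4");

    assert(n % DataPerRead == 0);
    // both pointers must be aligned to the vector width
    assert(reinterpret_cast<std::uintptr_t>(p_src) % (sizeof(float) * DataPerRead) == 0);
    assert(reinterpret_cast<std::uintptr_t>(p_dst) % (sizeof(float) * DataPerRead) == 0);

    for(unsigned i = 0; i < n; i += DataPerRead)
    {
        if(DataPerRead == 1)
            p_dst[i] = p_src[i];
        else if(DataPerRead == 2)
            *reinterpret_cast<Vec2*>(p_dst + i) = *reinterpret_cast<const Vec2*>(p_src + i);
        else
            *reinterpret_cast<Vec4*>(p_dst + i) = *reinterpret_cast<const Vec4*>(p_src + i);
    }
}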
diff --git a/src/include/threadwise_nd_tensor_op.hip.hpp b/src/include/threadwise_nd_tensor_op.hip.hpp
new file mode 100644
index 0000000000..97206e88f5
--- /dev/null
+++ b/src/include/threadwise_nd_tensor_op.hip.hpp
@@ -0,0 +1,198 @@
+#pragma once
+#include "ConstantTensorDescriptor.hip.hpp"
+
+// need to assume src and dst are aligned
+template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, unsigned DataPerRead>
+__device__ void threadwise_6d_tensor_copy(SrcDesc,
+                                          const Float* __restrict__ p_src,
+                                          DstDesc,
+                                          Float* __restrict__ p_dst,
+                                          SrcOpLengths,
+                                          Number<DataPerRead>)
+{
+    using Float2 = float2;
+    using Float4 = float4;
+
+    static_assert(SrcDesc{}.GetDimension() == 6 && DstDesc{}.GetDimension() == 6 &&
+                      SrcOpLengths::nDim == 6,
+                  "wrong! should be 6 dimension");
+
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};
+    constexpr auto I4 = Number<4>{};
+    constexpr auto I5 = Number<5>{};
+
+    constexpr auto src_desc = SrcDesc{};
+    constexpr auto dst_desc = DstDesc{};
+    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
+
+    static_assert(SrcDesc{}.GetStride(I5) == 1 && DstDesc{}.GetStride(I5) == 1,
+                  "wrong! only support stride5 == 1!\n");
+
+    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
+                  "wrong! only support DataPerRead == 1, 2 or 4!\n");
+
+    static_assert(SrcDesc{}.GetStride(I4) % DataPerRead == 0 &&
+                      DstDesc{}.GetStride(I4) % DataPerRead == 0,
+                  "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");
+
+    constexpr unsigned L5 = SrcOpLengths{}.Get(I5);
+
+    static_assert(L5 % DataPerRead == 0, "wrong! L5 should be evenly divided by DataPerRead");
+
+    constexpr unsigned nloop_d5 = L5 / DataPerRead;
+
+    for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
+    {
+        for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
+        {
+            for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
+            {
+                for(unsigned did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
+                {
+                    for(unsigned did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
+                    {
+                        for(unsigned iloop_d5 = 0; iloop_d5 < nloop_d5; ++iloop_d5)
+                        {
+                            const unsigned src_index = src_desc.Get1dIndex(
+                                did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);
+
+                            const unsigned dst_index = dst_desc.Get1dIndex(
+                                did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);
+
+                            if(DataPerRead == 1)
+                            {
+                                p_dst[dst_index] = p_src[src_index];
+                            }
+                            else if(DataPerRead == 2)
+                            {
+                                *(reinterpret_cast<Float2*>(p_dst + dst_index)) =
+                                    *(reinterpret_cast<const Float2*>(p_src + src_index));
+                            }
+                            else if(DataPerRead == 4)
+                            {
+                                *(reinterpret_cast<Float4*>(p_dst + dst_index)) =
+                                    *(reinterpret_cast<const Float4*>(p_src + src_index));
+                            }
+                            else
+                            {
+                                assert(false);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+// need to assume src and dst are aligned
+template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, unsigned DataPerRead>
+__device__ void threadwise_8d_tensor_copy(SrcDesc,
+                                          const Float* __restrict__ p_src,
+                                          DstDesc,
+                                          Float* __restrict__ p_dst,
+                                          SrcOpLengths,
+                                          Number<DataPerRead>)
+{
+    using Float2 = float2;
+    using Float4 = float4;
+
+    static_assert(SrcDesc{}.GetDimension() == 8 && DstDesc{}.GetDimension() == 8 &&
+                      SrcOpLengths::nDim == 8,
+                  "wrong! should be 8 dimension");
+
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};
+    constexpr auto I4 = Number<4>{};
+    constexpr auto I5 = Number<5>{};
+    constexpr auto I6 = Number<6>{};
+    constexpr auto I7 = Number<7>{};
+
+    constexpr auto src_desc = SrcDesc{};
+    constexpr auto dst_desc = DstDesc{};
+    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
+
+    static_assert(SrcDesc{}.GetStride(I7) == 1 && DstDesc{}.GetStride(I7) == 1,
+                  "wrong! only support stride7 == 1!\n");
+
+    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
+                  "wrong! only support DataPerRead == 1, 2 or 4!\n");
+
+    static_assert(SrcDesc{}.GetStride(I6) % DataPerRead == 0 &&
+                      DstDesc{}.GetStride(I6) % DataPerRead == 0,
+                  "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");
+
+    constexpr unsigned L7 = SrcOpLengths{}.Get(I7);
+
+    static_assert(L7 % DataPerRead == 0, "wrong! L7 should be evenly divided by DataPerRead");
+
+    constexpr unsigned nloop_d7 = L7 / DataPerRead;
+
+    for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
+    {
+        for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
+        {
+            for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
+            {
+                for(unsigned did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
+                {
+                    for(unsigned did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
+                    {
+                        for(unsigned did5 = 0; did5 < ref_desc.GetLength(I5); ++did5)
+                        {
+                            for(unsigned did6 = 0; did6 < ref_desc.GetLength(I6); ++did6)
+                            {
+                                for(unsigned iloop_d7 = 0; iloop_d7 < nloop_d7; ++iloop_d7)
+                                {
+                                    const unsigned src_index =
+                                        src_desc.Get1dIndex(did0,
+                                                            did1,
+                                                            did2,
+                                                            did3,
+                                                            did4,
+                                                            did5,
+                                                            did6,
+                                                            iloop_d7 * DataPerRead);
+
+                                    const unsigned dst_index =
+                                        dst_desc.Get1dIndex(did0,
+                                                            did1,
+                                                            did2,
+                                                            did3,
+                                                            did4,
+                                                            did5,
+                                                            did6,
+                                                            iloop_d7 * DataPerRead);
+
+                                    if(DataPerRead == 1)
+                                    {
+                                        p_dst[dst_index] = p_src[src_index];
+                                    }
+                                    else if(DataPerRead == 2)
+                                    {
+                                        *(reinterpret_cast<Float2*>(p_dst + dst_index)) =
+                                            *(reinterpret_cast<const Float2*>(p_src + src_index));
+                                    }
+                                    else if(DataPerRead == 4)
+                                    {
+                                        *(reinterpret_cast<Float4*>(p_dst + dst_index)) =
+                                            *(reinterpret_cast<const Float4*>(p_src + src_index));
+                                    }
+                                    else
+                                    {
+                                        assert(false);
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
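Taken together, the static_asserts in these copies spell out the contract behind OutThreadCopyDataPerWrite: the innermost dimension must have stride 1 on both source and destination, its length must divide evenly by DataPerRead, and the next-outer strides must be multiples of DataPerRead so every vector access stays aligned. A hypothetical call site (names and shapes invented for illustration):

// constexpr auto src_desc = make_ConstantTensorDescriptor(Sequence<2, 1, 4, 1, 2, 4>{});
// constexpr auto dst_desc = ...; // same lengths; innermost stride 1, outer strides multiples of 4
// threadwise_6d_tensor_copy(src_desc, p_src, dst_desc, p_dst, src_desc.GetLengths(), Number<4>{});

On GCN hardware the float2/float4 accesses compile to single dwordx2/dwordx4 memory instructions, which is presumably where the 78% -> 84% (3x3) and 80% -> 84% (5x5) figures in the subject line come from.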