From c9fa46af0bf70701e73a6d2cd9741759d179e5ee Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 8 Apr 2019 10:27:32 -0500 Subject: [PATCH] debugging implicit gemm v1: use 10d tensor output --- ...onvolution_2_vectorized_nchw_kcyx_nkhw.hpp | 30 ++--- ...icit_gemm_convolution_1_chwn_cyxk_khwn.hpp | 42 ++++++- ...icit_gemm_convolution_2_chwn_cyxk_khwn.hpp | 14 +-- driver/driver.hip.cpp | 10 +- src/include/ConstantTensorDescriptor.hip.hpp | 26 ++++ src/include/blockwise_4d_tensor_op.hip.hpp | 9 +- src/include/blockwise_batched_gemm.hip.hpp | 5 +- .../blockwise_direct_convolution.hip.hpp | 27 ++-- ...on_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp | 74 ++++++----- ...on_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp | 7 +- ...2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp | 35 +++--- .../gridwise_direct_convolution_1.hip.hpp | 19 ++- ...irect_convolution_2_nchw_kcyx_nkhw.hip.hpp | 43 +++---- ...lution_2_vectorized_nchw_kcyx_nkhw.hip.hpp | 34 +++--- ...onvolution_1_chwn_cyxk_khwn_padded.hip.hpp | 9 +- src/include/tensor.hpp | 3 +- src/include/threadwise_nd_tensor_op.hip.hpp | 115 ++++++++++++++++++ 17 files changed, 324 insertions(+), 178 deletions(-) diff --git a/driver/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp b/driver/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp index 7790900f83..938bc4cd30 100644 --- a/driver/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp +++ b/driver/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp @@ -52,7 +52,7 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc, in_nchw_vec(n, c, h, w) = vector_t::Pack(in_nchw(n, 2 * c, h, w), in_nchw(n, 2 * c + 1, h, w)); #elif 1 - in_nchw_vec(n, c, h, w) = vector_t::Pack(in_nchw(n, 4 * c, h, w), + in_nchw_vec(n, c, h, w) = vector_t::Pack(in_nchw(n, 4 * c, h, w), in_nchw(n, 4 * c + 1, h, w), in_nchw(n, 4 * c + 2, h, w), in_nchw(n, 4 * c + 3, h, w)); @@ -114,37 +114,37 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc, constexpr index_t BlockSize = 128; #elif 0 // 3x3, 34x34, 128 thread, fp32, vector = 2 - constexpr index_t NPerBlock = 2; - constexpr index_t KPerBlock = 32; - constexpr index_t CPerBlock = 2; + constexpr index_t NPerBlock = 2; + constexpr index_t KPerBlock = 32; + constexpr index_t CPerBlock = 2; constexpr index_t HoPerBlock = 2; constexpr index_t WoPerBlock = 32; - constexpr index_t NPerThread = 2; - constexpr index_t KPerThread = 4; - constexpr index_t CPerThread = 1; + constexpr index_t NPerThread = 2; + constexpr index_t KPerThread = 4; + constexpr index_t CPerThread = 1; constexpr index_t HoPerThread = 2; constexpr index_t WoPerThread = 2; - constexpr index_t InBlockCopyDataPerRead = 2; + constexpr index_t InBlockCopyDataPerRead = 2; constexpr index_t WeiBlockCopyDataPerRead = 2; constexpr index_t BlockSize = 128; #elif 0 // 3x3, 34x34, 128 thread, int8, vector = 4 - constexpr index_t NPerBlock = 2; - constexpr index_t KPerBlock = 32; - constexpr index_t CPerBlock = 8; + constexpr index_t NPerBlock = 2; + constexpr index_t KPerBlock = 32; + constexpr index_t CPerBlock = 8; constexpr index_t HoPerBlock = 4; constexpr index_t WoPerBlock = 32; - constexpr index_t NPerThread = 1; - constexpr index_t KPerThread = 8; - constexpr index_t CPerThread = 2; + constexpr index_t NPerThread = 1; + constexpr index_t KPerThread = 8; + constexpr index_t CPerThread = 2; constexpr index_t HoPerThread = 4; constexpr index_t WoPerThread = 2; - constexpr index_t InBlockCopyDataPerRead = 2; + constexpr index_t InBlockCopyDataPerRead = 2; constexpr index_t WeiBlockCopyDataPerRead = 2; constexpr index_t BlockSize = 128; diff --git a/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp index 08fd5ee44d..c6323581fd 100644 --- a/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp @@ -248,16 +248,15 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, constexpr index_t BlockSize = 128; #elif 1 - // for 1x1, 14x14 + // for 1x1, 14x14, Pascal constexpr index_t NPerBlock = 16; constexpr index_t KPerBlock = 128; constexpr index_t CPerBlock = 8; constexpr index_t HoPerBlock = 2; constexpr index_t WoPerBlock = 2; - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 16; - constexpr index_t CPerThread = 1; + constexpr index_t NPerThread = 8; + constexpr index_t KPerThread = 8; constexpr index_t HoPerThread = 1; constexpr index_t WoPerThread = 1; @@ -265,8 +264,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 2; constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t InBlockCopy_ThreadPerDimC = 8; @@ -278,6 +277,37 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t OutThreadCopyDataPerWrite = 2; + constexpr index_t BlockSize = 128; +#elif 1 + // for 1x1, 14x14, Pascal, try + constexpr index_t NPerBlock = 16; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 8; + constexpr index_t HoPerBlock = 1; + constexpr index_t WoPerBlock = 4; + + constexpr index_t NPerThread = 4; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 2; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + + constexpr index_t InBlockCopy_ThreadPerDimC = 8; + constexpr index_t InBlockCopy_ThreadPerDimH = 1; + constexpr index_t InBlockCopy_ThreadPerDimW = 4; + constexpr index_t InBlockCopy_ThreadPerDimN = 4; + constexpr index_t InBlockCopyDataPerRead = 4; + + constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t OutThreadCopyDataPerWrite = 4; + constexpr index_t BlockSize = 128; #endif diff --git a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp index 1bf7921abe..4e9c147186 100644 --- a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp @@ -69,7 +69,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, Tensor out_khwn(make_TensorDescriptor(out_khwn_desc)); -#if 0 +#if 1 // 3x3, 34x34 // need to use register double buffer for GEMM constexpr index_t BPerBlock = 128; @@ -87,9 +87,6 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, constexpr index_t GemmNLevel1Cluster = 8; constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmThreadPerColumnPerCluster = 8; - constexpr index_t GemmThreadPerRowPerCluster = 8; - constexpr index_t InBlockCopyThreadPerDim0 = 4; constexpr index_t InBlockCopyThreadPerDim1 = 16; @@ -98,6 +95,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, constexpr index_t InBlockCopyDataPerRead = 4; constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t OutThreadCopyDataPerWrite = 4; constexpr index_t BlockSize = 128; #elif 0 @@ -214,8 +212,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, constexpr index_t WeiBlockCopyThreadPerDim0 = 4; constexpr index_t WeiBlockCopyThreadPerDim1 = 16; - constexpr index_t InBlockCopyDataPerRead = 4; - constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t InBlockCopyDataPerRead = 4; + constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t OutThreadCopyDataPerWrite = 4; constexpr index_t BlockSize = 128; @@ -242,8 +240,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, constexpr index_t WeiBlockCopyThreadPerDim0 = 4; constexpr index_t WeiBlockCopyThreadPerDim1 = 16; - constexpr index_t InBlockCopyDataPerRead = 4; - constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t InBlockCopyDataPerRead = 4; + constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t OutThreadCopyDataPerWrite = 4; constexpr index_t BlockSize = 256; diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index 53256a14ff..7bc8b0897e 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -353,7 +353,7 @@ void host_winograd_3x3_convolution(const Tensor& in_nchw, std::size_t ho = HoPerTile * htile + j; for(int i = 0; i < WoPerTile; ++i) { - std::size_t wo = WoPerTile * wtile + i; + std::size_t wo = WoPerTile * wtile + i; out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i); } } @@ -409,7 +409,7 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; -#elif 0 +#elif 1 // 3x3, 34x34 constexpr index_t N = 64; constexpr index_t C = 256; @@ -580,7 +580,7 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; -#elif 1 +#elif 0 // 1x1 filter, 14x14 image, C = 2048 constexpr index_t N = 128; constexpr index_t C = 2048; @@ -592,7 +592,7 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; -#elif 0 +#elif 1 // 1x1 filter, 14x14 image, C = 512 constexpr index_t N = 128; constexpr index_t C = 512; @@ -663,7 +663,7 @@ int main(int argc, char* argv[]) device_direct_convolution_2_vectorized_nchw_kcyx_nkhw #elif 1 device_implicit_gemm_convolution_1_chwn_cyxk_khwn -#elif 0 +#elif 1 device_implicit_gemm_convolution_2_chwn_cyxk_khwn #endif (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); diff --git a/src/include/ConstantTensorDescriptor.hip.hpp b/src/include/ConstantTensorDescriptor.hip.hpp index d36a752eb4..2a0b430566 100644 --- a/src/include/ConstantTensorDescriptor.hip.hpp +++ b/src/include/ConstantTensorDescriptor.hip.hpp @@ -44,6 +44,32 @@ __host__ __device__ constexpr auto 1>{}; } +// this is ugly, only for 8d +template +__host__ __device__ constexpr auto + calculate_default_strides(Sequence) +{ + return Sequence{}; +} + // this is ugly, only for 2d template __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence, diff --git a/src/include/blockwise_4d_tensor_op.hip.hpp b/src/include/blockwise_4d_tensor_op.hip.hpp index 685bc67eea..8dc0f3a107 100644 --- a/src/include/blockwise_4d_tensor_op.hip.hpp +++ b/src/include/blockwise_4d_tensor_op.hip.hpp @@ -340,11 +340,10 @@ struct BlockwiseChwnTensorCopyPadded constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize; const Float* p_src_tmp = - p_src + - src_desc.Get1dIndex(c_block_data_begin, - (ho_block_data_begin + h_block_pad_low) - h_global_pad_low, - (wo_block_data_begin + w_block_pad_low) - w_global_pad_low, - n_block_data_begin); + p_src + src_desc.Get1dIndex(c_block_data_begin, + (ho_block_data_begin + h_block_pad_low) - h_global_pad_low, + (wo_block_data_begin + w_block_pad_low) - w_global_pad_low, + n_block_data_begin); #if 0 if(get_thread_local_1d_id() == 0) diff --git a/src/include/blockwise_batched_gemm.hip.hpp b/src/include/blockwise_batched_gemm.hip.hpp index 30746fb82c..3ae67a2062 100644 --- a/src/include/blockwise_batched_gemm.hip.hpp +++ b/src/include/blockwise_batched_gemm.hip.hpp @@ -329,9 +329,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 { threadwise_matrix_copy( c_thread_sub_mtx, - p_c_thread + - c_thread_sub_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster, - n_repeat * NPerLevel1Cluster), + p_c_thread + c_thread_sub_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster, + n_repeat * NPerLevel1Cluster), c_block_mtx, p_c_block + c_block_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster, diff --git a/src/include/blockwise_direct_convolution.hip.hpp b/src/include/blockwise_direct_convolution.hip.hpp index 3aff3b7936..d731e2258a 100644 --- a/src/include/blockwise_direct_convolution.hip.hpp +++ b/src/include/blockwise_direct_convolution.hip.hpp @@ -93,11 +93,10 @@ __device__ void blockwise_direct_convolution(InBlockDesc, Float p_out_thread[out_thread_desc.GetElementSpace()]; threadwise_4d_tensor_copy(out_block_desc, - p_out_block + - out_block_desc.Get1dIndex(n_thread_data_begin, - k_thread_data_begin, - ho_thread_data_begin, - wo_thread_data_begin), + p_out_block + out_block_desc.Get1dIndex(n_thread_data_begin, + k_thread_data_begin, + ho_thread_data_begin, + wo_thread_data_begin), out_thread_desc, p_out_thread, out_thread_desc.GetLengths()); @@ -108,11 +107,10 @@ __device__ void blockwise_direct_convolution(InBlockDesc, // threadwise convolution threadwise_direct_convolution_2( in_thread_block_desc, - p_in_block + - in_block_desc.Get1dIndex(n_thread_data_begin, - c_thread_data_begin, - hi_thread_data_begin, - wi_thread_data_begin), + p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin, + c_thread_data_begin, + hi_thread_data_begin, + wi_thread_data_begin), wei_thread_block_desc, p_wei_block + wei_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data_begin, 0, 0), @@ -124,11 +122,10 @@ __device__ void blockwise_direct_convolution(InBlockDesc, threadwise_4d_tensor_copy(out_thread_desc, p_out_thread, out_block_desc, - p_out_block + - out_block_desc.Get1dIndex(n_thread_data_begin, - k_thread_data_begin, - ho_thread_data_begin, - wo_thread_data_begin), + p_out_block + out_block_desc.Get1dIndex(n_thread_data_begin, + k_thread_data_begin, + ho_thread_data_begin, + wo_thread_data_begin), out_thread_desc.GetLengths()); } } diff --git a/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp index 64490c765b..e6f41207d6 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp @@ -40,12 +40,8 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn const Float* const __restrict__ p_wei_global, Float* const __restrict__ p_out_global) const { - // NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N] - // for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N" - // if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock + // be careful of this assertion static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0"); - static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock, - "wrong!"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -172,16 +168,13 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn constexpr index_t max_align = mod_conv::max(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead); - constexpr index_t in_block_space = - in_chwn_block_desc.GetElementSpace(Number{}); + constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number{}); constexpr index_t wei_block_space = wei_cyxk_block_desc.GetElementSpace(Number{}); - __shared__ Float - p_in_block[in_block_space]; - __shared__ Float - p_wei_block[wei_block_space]; + __shared__ Float p_in_block[in_block_space]; + __shared__ Float p_wei_block[wei_block_space]; // register Float p_out_thread[out_khwn_thread_desc.GetElementSpace()]; @@ -190,9 +183,8 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread); const Float* p_in_global_block_begin = - p_in_global + - in_chwn_global_desc.Get1dIndex( - 0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin); + p_in_global + in_chwn_global_desc.Get1dIndex( + 0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin); const Float* p_wei_global_block_begin = p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); @@ -269,26 +261,32 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin; // this is for v2 GEMM - // output is a 8d tensor - if(NPerThread < NPerBlock && WoPerThread == 1) + // output is a 10d tensor + if(NPerThread <= NPerBlock) { - constexpr index_t N1_ = GemmNPerThreadSubC; - constexpr index_t W1_ = WoPerBlock / ((WoPerThread * NPerThread) / GemmNPerThreadSubC); - constexpr index_t K2_ = GemmMPerThreadSubC; - constexpr index_t K1_ = KPerBlock / KPerThread; + constexpr index_t N2 = GemmNPerThreadSubC; + constexpr index_t N1 = NPerBlock / N2; - constexpr auto out_8d_global_desc = make_ConstantTensorDescriptor( - Sequence{}); + constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC); + constexpr index_t W1 = WoPerBlock / W2; - constexpr auto out_8d_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); + constexpr index_t K2 = GemmMPerThreadSubC; + constexpr index_t K1 = KPerBlock / KPerThread; + + constexpr auto out_10d_global_desc = + make_ConstantTensorDescriptor(Sequence{}); + + constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); #if 0 if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) @@ -301,25 +299,21 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn } #endif - threadwise_8d_tensor_copy( - out_8d_thread_desc, + threadwise_10d_tensor_copy( + out_10d_thread_desc, p_out_thread, - out_8d_global_desc, + out_10d_global_desc, p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, ho_block_data_begin + ho_thread_data_begin, wo_block_data_begin + wo_thread_data_begin, n_block_data_begin + n_thread_data_begin), - out_8d_thread_desc.GetLengths(), + out_10d_thread_desc.GetLengths(), Number{}); } - else if(NPerThread == NPerBlock) - { - // not implemented yet - assert(false); - } else { + // no implemented yet assert(false); } #endif diff --git a/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp index 1e2e365f5f..040ba1ec2f 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp @@ -31,11 +31,10 @@ template + index_t WeiBlockCopyDataPerRead, + index_t OutThreadCopyDataPerWrite> struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn { - __host__ __device__ constexpr GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn() {} - __device__ void Run(const Float* const __restrict__ p_in_global, const Float* const __restrict__ p_wei_global, Float* const __restrict__ p_out_global) const @@ -232,7 +231,7 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn { for(index_t x = 0; x < X; ++x) { -#if 0 +#if 1 blockwise_gemm.Run #elif 0 blockwise_gemm.Run_RegisterDoubleBuffer diff --git a/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp index 554624bf48..623049d589 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp @@ -273,9 +273,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer #elif 0 blockwise_gemm.Run_asm #endif - (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block_now + y * Wi + x, - p_out_thread); + (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block_now + y * Wi + x, + p_out_thread); } } @@ -322,9 +322,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer #elif 0 blockwise_gemm.Run_asm #endif - (p_wei_block_double + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block_double + y * Wi + x, - p_out_thread); + (p_wei_block_double + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block_double + y * Wi + x, + p_out_thread); } } @@ -356,10 +356,10 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer #elif 0 blockwise_gemm.Run_asm #endif - (p_wei_block_double + wei_block_space + - wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block_double + in_block_space + y * Wi + x, - p_out_thread); + (p_wei_block_double + wei_block_space + + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block_double + in_block_space + y * Wi + x, + p_out_thread); } } } @@ -387,14 +387,13 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer constexpr auto out_kb_global_desc = make_ConstantTensorDescriptor(Sequence{}); - threadwise_6d_tensor_copy( - out_6d_thread_desc, - p_out_thread, - out_6d_global_desc, - p_out_global + - out_kb_global_desc.Get1dIndex(k_thread_data_begin, b_thread_data_begin), - out_6d_thread_desc.GetLengths(), - Number{}); + threadwise_6d_tensor_copy(out_6d_thread_desc, + p_out_thread, + out_6d_global_desc, + p_out_global + out_kb_global_desc.Get1dIndex( + k_thread_data_begin, b_thread_data_begin), + out_6d_thread_desc.GetLengths(), + Number{}); } else #endif diff --git a/src/include/gridwise_direct_convolution_1.hip.hpp b/src/include/gridwise_direct_convolution_1.hip.hpp index 7723fb78b4..29c7e86b37 100644 --- a/src/include/gridwise_direct_convolution_1.hip.hpp +++ b/src/include/gridwise_direct_convolution_1.hip.hpp @@ -113,11 +113,10 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_ c_block_work_begin += CPerBlock) { // copy input tensor to LDS - blockwise_in_copy.Run(p_in_global + - in_global_desc.Get1dIndex(n_block_work_begin, - c_block_work_begin, - hi_block_work_begin, - wi_block_work_begin), + blockwise_in_copy.Run(p_in_global + in_global_desc.Get1dIndex(n_block_work_begin, + c_block_work_begin, + hi_block_work_begin, + wi_block_work_begin), p_in_block); // copy weight tensor to LDS @@ -144,9 +143,9 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_ } // copy output tensor from LDS to device mem - blockwise_out_copy.Run( - p_out_block, - p_out_global + - out_global_desc.Get1dIndex( - n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin)); + blockwise_out_copy.Run(p_out_block, + p_out_global + out_global_desc.Get1dIndex(n_block_work_begin, + k_block_work_begin, + ho_block_work_begin, + wo_block_work_begin)); } diff --git a/src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp b/src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp index b301fc1e52..3cb3216917 100644 --- a/src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp +++ b/src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp @@ -175,18 +175,16 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i c_block_data_begin += CPerBlock, __syncthreads()) { // copy input tensor to LDS - blockwise_in_copy.Run(p_in_global + - in_nchw_global_desc.Get1dIndex(n_block_data_begin, - c_block_data_begin, - hi_block_data_begin, - wi_block_data_begin), + blockwise_in_copy.Run(p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin, + c_block_data_begin, + hi_block_data_begin, + wi_block_data_begin), p_in_block); // copy weight tensor to LDS - blockwise_wei_copy.Run( - p_wei_global + - wei_kcyx_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0), - p_wei_block); + blockwise_wei_copy.Run(p_wei_global + wei_kcyx_global_desc.Get1dIndex( + k_block_data_begin, c_block_data_begin, 0, 0), + p_wei_block); __syncthreads(); @@ -196,11 +194,10 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i #if 1 threadwise_direct_convolution_2( in_nchw_thread_block_desc, - p_in_block + - in_nchw_block_desc.Get1dIndex(n_thread_data_begin, - c_thread_data, - hi_thread_data_begin, - wi_thread_data_begin), + p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin, + c_thread_data, + hi_thread_data_begin, + wi_thread_data_begin), wei_kcyx_thread_block_desc, p_wei_block + wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), @@ -209,11 +206,10 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i #elif 0 threadwise_direct_convolution_3( in_nchw_thread_block_desc, - p_in_block + - in_nchw_block_desc.Get1dIndex(n_thread_data_begin, - c_thread_data, - hi_thread_data_begin, - wi_thread_data_begin), + p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin, + c_thread_data, + hi_thread_data_begin, + wi_thread_data_begin), wei_kcyx_thread_block_desc, p_wei_block + wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), @@ -228,10 +224,9 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i out_nkhw_thread_desc, p_out_thread, out_nkhw_global_desc, - p_out_global + - out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin), + p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin), out_nkhw_thread_desc.GetLengths()); } diff --git a/src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp b/src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp index 250253f2ff..4dafaa055e 100644 --- a/src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp +++ b/src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp @@ -198,10 +198,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( p_in_vec_block); // copy weight tensor to LDS - blockwise_wei_copy.Run( - p_wei_vec_global + - wei_kcyx_vec_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0), - p_wei_vec_block); + blockwise_wei_copy.Run(p_wei_vec_global + wei_kcyx_vec_global_desc.Get1dIndex( + k_block_data_begin, c_block_data_begin, 0, 0), + p_wei_vec_block); __syncthreads(); @@ -211,11 +210,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( #if 1 threadwise_direct_convolution_2( in_nchw_vec_thread_block_desc, - p_in_vec_block + - in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin, - c_thread_data, - hi_thread_data_begin, - wi_thread_data_begin), + p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin, + c_thread_data, + hi_thread_data_begin, + wi_thread_data_begin), wei_kcyx_vec_thread_block_desc, p_wei_vec_block + wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), @@ -224,11 +222,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( #elif 0 threadwise_direct_convolution_3( in_nchw_vec_thread_block_desc, - p_in_vec_block + - in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin, - c_thread_data, - hi_thread_data_begin, - wi_thread_data_begin), + p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin, + c_thread_data, + hi_thread_data_begin, + wi_thread_data_begin), wei_kcyx_vec_thread_block_desc, p_wei_vec_block + wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), @@ -243,10 +240,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( out_nkhw_thread_desc, p_out_thread, out_nkhw_global_desc, - p_out_global + - out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin), + p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin), out_nkhw_thread_desc.GetLengths()); } diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp index f04a283fcf..fe1ee2191f 100644 --- a/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp @@ -283,11 +283,10 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded( out_hkwn_thread_desc, p_out_thread, out_khwn_global_desc, - p_out_global + - out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin), + p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin, + n_block_data_begin + n_thread_data_begin), out_hkwn_thread_desc.GetLengths(), reorder_khwn_from_hkwn); } diff --git a/src/include/tensor.hpp b/src/include/tensor.hpp index 1ebfef0c5d..d0c785c16e 100644 --- a/src/include/tensor.hpp +++ b/src/include/tensor.hpp @@ -22,7 +22,8 @@ std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) return os; } -typedef enum { +typedef enum +{ Half = 0, Float = 1, } DataType_t; diff --git a/src/include/threadwise_nd_tensor_op.hip.hpp b/src/include/threadwise_nd_tensor_op.hip.hpp index 2138c9ec07..42e5d1660c 100644 --- a/src/include/threadwise_nd_tensor_op.hip.hpp +++ b/src/include/threadwise_nd_tensor_op.hip.hpp @@ -162,3 +162,118 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc, } } } + +// need to assume src and dst is aligned +template +__device__ void threadwise_10d_tensor_copy(SrcDesc, + const Float* __restrict__ p_src, + DstDesc, + Float* __restrict__ p_dst, + SrcOpLengths, + Number) +{ + using vector_t = typename vector_type::MemoryType; + + static_assert(SrcDesc{}.GetDimension() == 10 && DstDesc{}.GetDimension() == 10 && + SrcOpLengths::nDim == 10, + "wrong! should be 10 dimension"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + constexpr auto I9 = Number<9>{}; + + constexpr auto src_desc = SrcDesc{}; + constexpr auto dst_desc = DstDesc{}; + constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{}); + + static_assert(SrcDesc{}.GetStride(I9) == 1 && DstDesc{}.GetStride(I9) == 1, + "wrong! only support stride7 == 1!\n"); + + static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4, + "wrong! only support DataPerRead == 1, 2 or 4!\n"); + + static_assert(SrcDesc{}.GetStride(I8) % DataPerRead == 0 && + DstDesc{}.GetStride(I8) % DataPerRead == 0, + "wrong! src and dst stride should be multiple of DataPerRead to keep alignment"); + + constexpr index_t L9 = SrcOpLengths{}.Get(I9); + + static_assert(L9 % DataPerRead == 0, "wrong! L9 should be evenly divided by DataPerRead"); + + constexpr index_t nloop_d9 = L9 / DataPerRead; + +#pragma unroll + for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0) + { +#pragma unroll + for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1) + { +#pragma unroll + for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2) + { +#pragma unroll + for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3) + { +#pragma unroll + for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4) + { +#pragma unroll + for(index_t did5 = 0; did5 < ref_desc.GetLength(I5); ++did5) + { +#pragma unroll + for(index_t did6 = 0; did6 < ref_desc.GetLength(I6); ++did6) + { +#pragma unroll + for(index_t did7 = 0; did7 < ref_desc.GetLength(I7); ++did7) + { +#pragma unroll + for(index_t did8 = 0; did8 < ref_desc.GetLength(I8); ++did8) + { +#pragma unroll + for(index_t iloop_d9 = 0; iloop_d9 < nloop_d9; ++iloop_d9) + { + const index_t src_index = + src_desc.Get1dIndex(did0, + did1, + did2, + did3, + did4, + did5, + did6, + did7, + did8, + iloop_d9 * DataPerRead); + + const index_t dst_index = + dst_desc.Get1dIndex(did0, + did1, + did2, + did3, + did4, + did5, + did6, + did7, + did8, + iloop_d9 * DataPerRead); + + *(reinterpret_cast(p_dst + dst_index)) = + *(reinterpret_cast(p_src + + src_index)); + } + } + } + } + } + } + } + } + } + } +}