From c075d3f7d91079d28340cda89d51e15117493968 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 8 Apr 2019 12:02:56 -0500 Subject: [PATCH] add more assertion --- ...icit_gemm_convolution_1_chwn_cyxk_khwn.hpp | 31 ----- driver/driver.hip.cpp | 6 +- src/include/ConstantTensorDescriptor.hip.hpp | 73 +++++++++- src/include/blockwise_batched_gemm.hip.hpp | 2 +- ...on_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp | 125 +++++++++--------- 5 files changed, 136 insertions(+), 101 deletions(-) diff --git a/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp index c6323581fd..83962a6b3a 100644 --- a/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp @@ -277,37 +277,6 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t OutThreadCopyDataPerWrite = 2; - constexpr index_t BlockSize = 128; -#elif 1 - // for 1x1, 14x14, Pascal, try - constexpr index_t NPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 1; - constexpr index_t WoPerBlock = 4; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - - constexpr index_t InBlockCopy_ThreadPerDimC = 8; - constexpr index_t InBlockCopy_ThreadPerDimH = 1; - constexpr index_t InBlockCopy_ThreadPerDimW = 4; - constexpr index_t InBlockCopy_ThreadPerDimN = 4; - constexpr index_t InBlockCopyDataPerRead = 4; - - constexpr index_t WeiBlockCopyDataPerRead = 4; - constexpr index_t OutThreadCopyDataPerWrite = 4; - constexpr index_t BlockSize = 128; #endif diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index 0ea091e607..5061ab85fe 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -580,7 +580,7 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; -#elif 1 +#elif 0 // 1x1 filter, 14x14 image, C = 2048 constexpr index_t N = 128; constexpr index_t C = 2048; @@ -661,9 +661,9 @@ int main(int argc, char* argv[]) device_direct_convolution_2_nchw_kcyx_nkhw #elif 0 device_direct_convolution_2_vectorized_nchw_kcyx_nkhw -#elif 0 - device_implicit_gemm_convolution_1_chwn_cyxk_khwn #elif 1 + device_implicit_gemm_convolution_1_chwn_cyxk_khwn +#elif 0 device_implicit_gemm_convolution_2_chwn_cyxk_khwn #endif (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); diff --git a/src/include/ConstantTensorDescriptor.hip.hpp b/src/include/ConstantTensorDescriptor.hip.hpp index 2a0b430566..411e46f83f 100644 --- a/src/include/ConstantTensorDescriptor.hip.hpp +++ b/src/include/ConstantTensorDescriptor.hip.hpp @@ -229,7 +229,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) constexpr auto desc = TDesc{}; constexpr index_t ndim = desc.GetDimension(); - static_assert(ndim >= 2 && ndim <= 8, "wrong!"); + static_assert(ndim >= 2 && ndim <= 10, "wrong!"); if(ndim == 2) { @@ -369,4 +369,75 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) desc.GetStride(I6), desc.GetStride(I7)); } + else if(ndim == 9) + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u %u}\n", + s, + desc.GetDimension(), + desc.GetLength(I0), + desc.GetLength(I1), + desc.GetLength(I2), + desc.GetLength(I3), + desc.GetLength(I4), + desc.GetLength(I5), + desc.GetLength(I6), + desc.GetLength(I7), + desc.GetLength(I8), + desc.GetStride(I0), + desc.GetStride(I1), + desc.GetStride(I2), + desc.GetStride(I3), + desc.GetStride(I4), + desc.GetStride(I5), + desc.GetStride(I6), + desc.GetStride(I7), + desc.GetStride(I8)); + } + else if(ndim == 10) + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + constexpr auto I9 = Number<9>{}; + + printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u %u %u}\n", + s, + desc.GetDimension(), + desc.GetLength(I0), + desc.GetLength(I1), + desc.GetLength(I2), + desc.GetLength(I3), + desc.GetLength(I4), + desc.GetLength(I5), + desc.GetLength(I6), + desc.GetLength(I7), + desc.GetLength(I8), + desc.GetLength(I9), + desc.GetStride(I0), + desc.GetStride(I1), + desc.GetStride(I2), + desc.GetStride(I3), + desc.GetStride(I4), + desc.GetStride(I5), + desc.GetStride(I6), + desc.GetStride(I7), + desc.GetStride(I8), + desc.GetStride(I9)); + } } diff --git a/src/include/blockwise_batched_gemm.hip.hpp b/src/include/blockwise_batched_gemm.hip.hpp index 30746fb82c..645f159eb6 100644 --- a/src/include/blockwise_batched_gemm.hip.hpp +++ b/src/include/blockwise_batched_gemm.hip.hpp @@ -138,7 +138,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC}; } - // this should be optimized away if input is known + // this should be optimized away because input will be known at compile time __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c) { diff --git a/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp index c976f4d8b2..02a9745211 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp @@ -41,7 +41,8 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn Float* const __restrict__ p_out_global) const { // be careful of this assertion - static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0"); + static_assert(NPerThread <= NPerBlock && NPerBlock % NPerThread == 0, + "wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -66,6 +67,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn constexpr index_t WiPerBlock = WoPerBlock + X - 1; // divide block work: [K, Ho, Wo, N] + static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 && Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0, + "wrong! cannot evenly divide work for workgroup "); + constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; @@ -218,39 +222,39 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn // output: register to global mem, #if 0 - const auto c_thread_mtx_begin = - blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + const auto c_thread_mtx_begin = + blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k) - { - for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho) + for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k) { - for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo) + for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho) { - for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n) + for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo) { - const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n); + for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n) + { + const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n); - const auto c_thread_mtx_distance = - blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b); + const auto c_thread_mtx_distance = + blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b); - const index_t ho_thread = - c_thread_mtx_begin.batch + c_thread_mtx_distance.batch; - const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row; - const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col; + const index_t ho_thread = + c_thread_mtx_begin.batch + c_thread_mtx_distance.batch; + const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row; + const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col; - const index_t wo_thread = b_thread / NPerBlock; - const index_t n_thread = b_thread % NPerBlock; + const index_t wo_thread = b_thread / NPerBlock; + const index_t n_thread = b_thread % NPerBlock; - p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread, - ho_block_data_begin + ho_thread, - wo_block_data_begin + wo_thread, - n_block_data_begin + n_thread)] = - p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)]; + p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread, + ho_block_data_begin + ho_thread, + wo_block_data_begin + wo_thread, + n_block_data_begin + n_thread)] = + p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)]; + } } } } - } #elif 1 const auto c_thread_mtx_begin = blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); @@ -261,63 +265,54 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn const index_t n_thread_data_begin = c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin; - // this is for v2 GEMM // output is a 10d tensor - if(NPerThread <= NPerBlock) - { - constexpr index_t N2 = GemmNPerThreadSubC; - constexpr index_t N1 = NPerBlock / N2; + constexpr index_t N2 = GemmNPerThreadSubC; + constexpr index_t N1 = NPerBlock / N2; - constexpr index_t W2 = - (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC); - constexpr index_t W1 = WoPerBlock / W2; + constexpr index_t W2 = + (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC); + constexpr index_t W1 = WoPerBlock / W2; - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; + constexpr index_t K2 = GemmMPerThreadSubC; + constexpr index_t K1 = KPerBlock / KPerThread; - constexpr auto out_10d_global_desc = - make_ConstantTensorDescriptor(Sequence{}); + constexpr auto out_10d_global_desc = + make_ConstantTensorDescriptor(Sequence{}); - constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( - Sequence{}); + constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); #if 0 if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) { print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc"); - print_ConstantTensorDescriptor(out_8d_thread_desc, "out_8d_thread_desc"); + print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc"); - print_ConstantTensorDescriptor(out_8d_global_desc, "out_8d_global_desc"); + print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); } #endif - threadwise_10d_tensor_copy( - out_10d_thread_desc, - p_out_thread, - out_10d_global_desc, - p_out_global + - out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin), - out_10d_thread_desc.GetLengths(), - Number{}); - } - else - { - // no implemented yet - assert(false); - } + threadwise_10d_tensor_copy( + out_10d_thread_desc, + p_out_thread, + out_10d_global_desc, + p_out_global + + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin, + n_block_data_begin + n_thread_data_begin), + out_10d_thread_desc.GetLengths(), + Number{}); #endif } };