From 50b96745c68d17c3c03b4492d23867eb5e859aa7 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sun, 17 Feb 2019 02:28:20 -0600 Subject: [PATCH] gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn use khwn for thread C data now --- ...icit_gemm_convolution_1_chwn_csrk_khwn.hpp | 2 +- ..._gemm_convolution_1_chwn_csrk_khwn.hip.hpp | 41 +++++++++---------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp index 3e92a157eb..246d331fb5 100644 --- a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp @@ -200,7 +200,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, constexpr unsigned WoPerThread = 1; constexpr unsigned BlockSize = 128; -#elif 1 +#elif 0 // for 1x1, 28x28 constexpr unsigned NPerBlock = 16; constexpr unsigned KPerBlock = 128; diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp index edde97b893..9d12f2b1ec 100644 --- a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp @@ -104,8 +104,8 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric Sequence{}, Number{}); // tensor view of threadwise output in register - constexpr auto out_hkwn_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); + constexpr auto out_khwn_thread_desc = + make_ConstantTensorDescriptor(Sequence{}); #if 0 if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) @@ -179,7 +179,9 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric Number{}); constexpr auto c_kxwn_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, Number{}); + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); #if 0 const auto blockwise_batch_gemm = @@ -192,7 +194,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric false, 0, in_chwn_block_desc.GetStride(I1), - out_hkwn_thread_desc.GetStride(I0), + out_khwn_thread_desc.GetStride(I1), HoPerBlock, HoPerThread, GemmKPerThreadLoop, @@ -205,7 +207,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric decltype(c_kxwn_thread_mtx_desc), 0, in_chwn_block_desc.GetStride(I1), - out_hkwn_thread_desc.GetStride(I0), + out_khwn_thread_desc.GetStride(I1), HoPerBlock, GemmMPerThreadSubC, GemmNPerThreadSubC, @@ -230,10 +232,10 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)]; // register - Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()]; + Float p_out_thread[out_khwn_thread_desc.GetElementSpace()]; // set threadwise output tensor to 0 - threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread); + threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread); const Float* p_in_global_block_begin = p_in_global + in_chwn_global_desc.Get1dIndex( @@ -275,33 +277,30 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric // convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N] #if 0 // for v1 batch-gemm - const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch; const unsigned k_thread_data_begin = c_thread_mtx_begin.row; + const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch; const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; const unsigned n_thread_data_begin = c_thread_mtx_begin.col - wo_thread_data_begin * NPerBlock; - constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{}; - - threadwise_4d_tensor_copy_reorder_by_get_dst_from_src( - out_hkwn_thread_desc, + threadwise_4d_tensor_copy( + out_khwn_thread_desc, p_out_thread, out_khwn_global_desc, p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, ho_block_data_begin + ho_thread_data_begin, wo_block_data_begin + wo_thread_data_begin, n_block_data_begin + n_thread_data_begin), - out_hkwn_thread_desc.GetLengths(), - reorder_khwn_from_hkwn); + out_khwn_thread_desc.GetLengths()); #else - for(unsigned ho = 0; ho < out_hkwn_thread_desc.GetLength(I0); ++ho) + for(unsigned k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k) { - for(unsigned k = 0; k < out_hkwn_thread_desc.GetLength(I1); ++k) + for(unsigned ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho) { - for(unsigned wo = 0; wo < out_hkwn_thread_desc.GetLength(I2); ++wo) + for(unsigned wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo) { - for(unsigned n = 0; n < out_hkwn_thread_desc.GetLength(I3); ++n) + for(unsigned n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n) { - const unsigned b = out_hkwn_thread_desc.Get1dIndex(0, 0, wo, n); + const unsigned b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n); const auto c_thread_mtx_distance = blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b); @@ -312,13 +311,13 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric const unsigned b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col; const unsigned wo_thread = b_thread / NPerBlock; - const unsigned n_thread = b_thread - NPerBlock * wo_thread; + const unsigned n_thread = b_thread % NPerBlock; p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread, ho_block_data_begin + ho_thread, wo_block_data_begin + wo_thread, n_block_data_begin + n_thread)] = - p_out_thread[out_hkwn_thread_desc.Get1dIndex(ho, k, wo, n)]; + p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)]; } } }