From bd0098afb382eb9fa1519bbe03f5a288aeb13b4e Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sat, 6 Apr 2019 18:40:54 -0500 Subject: [PATCH] use dedicated threadwise_copy for 1x1, perf at 80% --- ...icit_gemm_convolution_2_chwn_cyxk_khwn.hpp | 6 +- ...on_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp | 22 +++---- ...2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp | 66 ++++++++++++++----- 3 files changed, 64 insertions(+), 30 deletions(-) diff --git a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp index d8e45bd3fe..8bd57049e2 100644 --- a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp @@ -192,7 +192,6 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, constexpr index_t BlockSize = 256; #elif 0 - // 1x1, 14x14, Vega 20, disable lds_double_buffer, enable register double buffer // 1x1, 14x14, Pascal, enable lds_double_buffer, disable register double buffer constexpr index_t BPerBlock = 64; constexpr index_t KPerBlock = 128; @@ -245,6 +244,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, constexpr index_t InBlockCopyDataPerRead = 4; constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t OutThreadCopyDataPerWrite = 4; + constexpr index_t BlockSize = 256; #endif @@ -295,7 +296,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, WeiBlockCopyThreadPerDim0, WeiBlockCopyThreadPerDim1, InBlockCopyDataPerRead, - WeiBlockCopyDataPerRead>{}; + WeiBlockCopyDataPerRead, + OutThreadCopyDataPerWrite>{}; float time = launch_kernel(run_gridwise_convolution, dim3(GridSize), diff --git a/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp index 64d3c03970..1e2e365f5f 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp @@ -206,11 +206,7 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), __syncthreads()) { -// load data -#if 1 - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); - blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); -#elif 0 + // load data Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; @@ -219,9 +215,13 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, p_wei_register_clipboard); - +#if 1 blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block); blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block); +#else + vmcnt(0); + blockwise_in_copy.RunStoreRegisterClipboard_asm(p_in_register_clipboard, p_in_block); + blockwise_wei_copy.RunStoreRegisterClipboard_asm(p_wei_register_clipboard, p_wei_block); #endif __syncthreads(); @@ -232,16 +232,16 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn { for(index_t x = 0; x < X; ++x) { -#if 1 +#if 0 blockwise_gemm.Run #elif 0 blockwise_gemm.Run_RegisterDoubleBuffer -#elif 0 +#elif 1 blockwise_gemm.Run_asm #endif - (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block + y * Wi + x, - p_out_thread); + (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block + y * Wi + x, + p_out_thread); } } } diff --git a/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp index 7c867b620f..234f750aca 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp @@ -5,6 +5,7 @@ #include "blockwise_4d_tensor_op.hip.hpp" #include "blockwise_2d_tensor_op.hip.hpp" #include "threadwise_2d_tensor_op.hip.hpp" +#include "threadwise_nd_tensor_op.hip.hpp" #include "blockwise_gemm.hip.hpp" // define B = flatten(N, Hi, Wi) @@ -31,7 +32,8 @@ template + index_t WeiBlockCopyDataPerRead, + index_t OutThreadCopyDataPerWrite> struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer { __device__ void Run(const Float* const __restrict__ p_in_global, @@ -369,25 +371,55 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; - for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) +#if 1 + if(Y == 1 && X == 1) + { // pure 1x1 conv + constexpr index_t K2_ = GemmMPerThreadSubC; + constexpr index_t K1_ = KPerBlock / KPerThread; + constexpr index_t B2_ = GemmNPerThreadSubC; + constexpr index_t B1_ = BPerBlock / BPerThread; + + constexpr auto out_6d_global_desc = make_ConstantTensorDescriptor( + Sequence{}); + + constexpr auto out_6d_thread_desc = make_ConstantTensorDescriptor( + Sequence{}); + + constexpr auto out_kb_global_desc = make_ConstantTensorDescriptor(Sequence{}); + + threadwise_6d_tensor_copy( + out_6d_thread_desc, + p_out_thread, + out_6d_global_desc, + p_out_global + + out_kb_global_desc.Get1dIndex(k_thread_data_begin, b_thread_data_begin), + out_6d_thread_desc.GetLengths(), + Number{}); + } + else +#endif { - for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) + for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) { - const auto c_thread_mtx_distance = - blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); - - index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row; - index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col; - - index_t h_data = b_data / (Wi * N); - index_t itmp = b_data - h_data * (Wi * N); - index_t w_data = itmp / N; - index_t n_data = itmp - w_data * N; - - if(n_data < N && h_data < Ho && w_data < Wo) + for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) { - p_out_global[out_khwn_global_desc.Get1dIndex(k_data, h_data, w_data, n_data)] = - p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; + const auto c_thread_mtx_distance = + blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); + + index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row; + index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col; + + index_t h_data = b_data / (Wi * N); + index_t itmp = b_data - h_data * (Wi * N); + index_t w_data = itmp / N; + index_t n_data = itmp - w_data * N; + + if(n_data < N && h_data < Ho && w_data < Wo) + { + p_out_global[out_khwn_global_desc.Get1dIndex( + k_data, h_data, w_data, n_data)] = + p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; + } } } }