diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index aa0c610e6d..acac1d09fa 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -451,7 +451,7 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; -#elif 1 +#elif 0 // 3x3 filter, 28x28 image constexpr index_t N = 128; constexpr index_t C = 256; @@ -499,6 +499,18 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 1; constexpr index_t WPad = 1; +#elif 1 + // 5x5 filter, 20x86 image + constexpr index_t N = 16; + constexpr index_t C = 256; + constexpr index_t HI = 20; + constexpr index_t WI = 86; + constexpr index_t K = 512; + constexpr index_t Y = 5; + constexpr index_t X = 5; + + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; #elif 0 // 5x5 filter, 20x86 image, 1x1 padding constexpr index_t N = 16; @@ -535,7 +547,7 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; -#elif 0 +#elif 10 // 1x1 filter, 14x14 image constexpr index_t N = 128; constexpr index_t C = 512; diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp index 7ce9039a20..ff1d024346 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp @@ -117,15 +117,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn // blockwise copy // input: format is [C, Hi, Wi, N] -#if 0 - const auto blockwise_in_copy = - Blockwise4dTensorCopy1{}; -#else const auto blockwise_in_copy = Blockwise4dTensorCopy3{}; -#endif // blockwise wei copy // format is [CPerBlock, X * KPerBlock] const auto blockwise_wei_copy = - Blockwise2dTensorCopy1{}); constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number{}); - __shared__ Float p_in_block[in_block_space]; - __shared__ Float p_wei_block[wei_block_space]; + // LDS double buffer + __shared__ Float p_in_block_double[2 * in_block_space]; + __shared__ Float p_wei_block_double[2 * wei_block_space]; // register Float p_out_thread[out_k_h_w_n_thread_desc.GetElementSpace()]; @@ -225,22 +229,102 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn const Float* p_wei_global_block_offset = p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, k_block_data_begin); - for(index_t c_block_data_begin = 0; c_block_data_begin < C; - c_block_data_begin += CPerBlock, - p_in_global_block_offset += - CPerBlock * in_c_h_w_n_global_desc.GetStride(I0), - p_wei_global_block_offset += - CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0)) + // LDS double buffer: preload data into LDS { - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); + Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; + Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; - blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); + blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, + p_in_register_clipboard); + blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, + p_wei_register_clipboard); + + blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, + p_in_block_double); + blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, + p_wei_block_double); + } + + // LDS double buffer: main body + for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C; + c_block_data_begin += 2 * CPerBlock) + { +#pragma unroll + for(index_t iloop = 0; iloop < 2; ++iloop) + { + const bool even_loop = (iloop % 2 == 0); + + Float* p_in_block_now = + even_loop ? p_in_block_double : p_in_block_double + in_block_space; + Float* p_wei_block_now = + even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space; + + Float* p_in_block_next = + even_loop ? p_in_block_double + in_block_space : p_in_block_double; + Float* p_wei_block_next = + even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; + + Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; + Float + p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + + p_in_global_block_offset += + CPerBlock * in_c_h_w_n_global_desc.GetStride(I0); + p_wei_global_block_offset += + CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, + p_in_register_clipboard); + blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, + p_wei_register_clipboard); + + // LDS double buffer: GEMM on current data + blockwise_batch_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); + + // LDS double buffer: store next data to LDS + blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, + p_in_block_next); + blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, + p_wei_block_next); + } + } + + // LDS double buffer: tail + { + Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; + Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + + // even iteration + p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0); + p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0); __syncthreads(); - blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread); + // LDS doubel buffer: load next data from device mem + blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, + p_in_register_clipboard); + blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, + p_wei_register_clipboard); + // LDS double buffer: GEMM on current data + blockwise_batch_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); + + // LDS double buffer: store next data to LDS + blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, + p_in_block_double + in_block_space); + blockwise_wei_copy.RunStoreRegisterClipboard( + p_wei_register_clipboard, p_wei_block_double + wei_block_space); + + // odd iteration __syncthreads(); + + // LDS double buffer: GEMM on current data + blockwise_batch_gemm.Run(p_wei_block_double + wei_block_space, + p_in_block_double + in_block_space, + p_out_thread); } } }