diff --git a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp index 438408ca7e..d8e45bd3fe 100644 --- a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp @@ -193,6 +193,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, constexpr index_t BlockSize = 256; #elif 0 // 1x1, 14x14, Vega 20, disable lds_double_buffer, enable register double buffer + // 1x1, 14x14, Pascal, enable lds_double_buffer, disable register double buffer constexpr index_t BPerBlock = 64; constexpr index_t KPerBlock = 128; constexpr index_t CPerBlock = 8; @@ -266,7 +267,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, for(index_t i = 0; i < nrepeat; ++i) { constexpr auto gridwise_conv = -#if 1 +#if 0 GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn #else GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index a83e4082c7..0ea091e607 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -580,7 +580,7 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; -#elif 0 +#elif 1 // 1x1 filter, 14x14 image, C = 2048 constexpr index_t N = 128; constexpr index_t C = 2048; diff --git a/src/include/data_type.hip.hpp b/src/include/data_type.hip.hpp index 54bed9ec5a..ee44eebbe7 100644 --- a/src/include/data_type.hip.hpp +++ b/src/include/data_type.hip.hpp @@ -15,9 +15,13 @@ struct vector_type template <> struct vector_type { -#if 1 +#if DEVICE_BACKEND_HIP + // For some reason, the HIP compiler needs this definition to generate optimal load and store instructions typedef float MemoryType __attribute__((ext_vector_type(2))); -#else +#elif DEVICE_BACKEND_CUDA + // For some reason, CUDA needs this definition; otherwise the + // compiler won't generate optimal load and store instructions, and the + // kernel would produce 
wrong results, indicating the compiler fails to generate correct instructions. using MemoryType = float2; #endif @@ -38,9 +42,13 @@ struct vector_type template <> struct vector_type { -#if 1 +#if DEVICE_BACKEND_HIP + // For some reason, the HIP compiler needs this definition to generate optimal load and store instructions typedef float MemoryType __attribute__((ext_vector_type(4))); -#else +#elif DEVICE_BACKEND_CUDA + // For some reason, CUDA needs this definition; otherwise the + // compiler won't generate optimal load and store instructions, and the + // kernel would produce wrong results, indicating the compiler fails to generate correct instructions. using MemoryType = float4; #endif }; diff --git a/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp index b919036df8..f1cd81b32b 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp @@ -204,8 +204,18 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer // preload data into LDS { #if 1 - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_double); - blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_double); + Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; + Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + + blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, + p_in_register_clipboard); + blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, + p_wei_register_clipboard); + + blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, + p_in_block_double); + blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, + p_wei_block_double); #elif 0 Float 
p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; @@ -363,9 +373,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer #elif 1 blockwise_gemm.Run_asm #endif - (p_wei_block_double + in_block_space + + (p_wei_block_double + wei_block_space + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block_double + wei_block_space + y * Wi + x, + p_in_block_double + in_block_space + y * Wi + x, p_out_thread); } }