diff --git a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp index 497aa3e9c1..bf7cdc8c5a 100644 --- a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp @@ -220,7 +220,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t BlockSize = 128; -#elif 0 +#elif 1 // 1x1, 14x14, Vega 20, hack CPerBlock = 1 constexpr index_t BPerBlock = 64; constexpr index_t KPerBlock = 128; @@ -306,6 +306,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, float time = launch_kernel(gridwise_conv.Run, dim3(GridSize), dim3(BlockSize), + gridwise_conv.GetSharedMemoryUsage(), static_cast(in_chwn_device_buf.GetDeviceBuffer()), static_cast(wei_cyxk_device_buf.GetDeviceBuffer()), static_cast(out_khwn_device_buf.GetDeviceBuffer())); diff --git a/src/include/device.hpp b/src/include/device.hpp index eec7dd5395..066866858b 100644 --- a/src/include/device.hpp +++ b/src/include/device.hpp @@ -29,14 +29,14 @@ struct KernelTimer }; template -float launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, Args... args) +float launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) { KernelTimer timer; #if DEVICE_BACKEND_HIP timer.Start(); - hipLaunchKernelGGL(kernel, grid_dim, block_dim, 0, 0, args...); + hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, 0, args...); timer.End(); @@ -47,7 +47,7 @@ float launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, Args... args) timer.Start(); - cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, 0, 0); + cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, 0); timer.End(); diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp index da689bc6b9..fd223dfad4 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp @@ -37,7 +37,7 @@ template {}; constexpr auto I1 = Number<1>{}; @@ -46,7 +46,6 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn constexpr auto in_chwn_global_desc = InGlobalDesc{}; constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; - constexpr auto out_khwn_global_desc = OutGlobalDesc{}; constexpr index_t Hi = in_chwn_global_desc.GetLength(I1); constexpr index_t Wi = in_chwn_global_desc.GetLength(I2); @@ -56,29 +55,59 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1); - // tensor view of blockwise input and weight + // tensor view of blockwise input // be careful of alignment constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned( Sequence{}, Number{}); - constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - // tensor view of threadwise output in register - constexpr auto out_kb_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - + // LDS: be careful of alignment constexpr index_t max_align = mod_conv::max(InBlockCopyDataPerRead, WeiBlockCopyDataPerRead); + return in_cb_block_desc.GetElementSpace(Number{}); + } + + __host__ __device__ constexpr index_t GetWeightBlockElementSpace() const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_chwn_global_desc = InGlobalDesc{}; + constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; + + constexpr index_t Hi = in_chwn_global_desc.GetLength(I1); + constexpr index_t Wi = in_chwn_global_desc.GetLength(I2); + + constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1); + constexpr index_t X = wei_cyxk_global_desc.GetLength(I2); + + constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1); + + // tensor view of blockwise weight + // be careful of alignment + constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + // LDS: be careful of alignment - constexpr index_t in_block_element_space = - in_cb_block_desc.GetElementSpace(Number{}); + constexpr index_t max_align = + mod_conv::max(InBlockCopyDataPerRead, WeiBlockCopyDataPerRead); - constexpr index_t wei_block_element_space = - wei_cyxk_block_desc.GetElementSpace(Number{}); + return wei_cyxk_block_desc.GetElementSpace(Number{}); + } - return (in_block_element_space + wei_block_element_space) * sizeof(Float); + __host__ __device__ constexpr index_t GetSharedMemoryUsage() const + { + + return (GetInputBlockElementSpace() + GetWeightBlockElementSpace()) * sizeof(Float); + } + + __device__ constexpr static Float* GetSharedMemoryBegin() + { + extern __shared__ Float s[]; + + return s; } __global__ static void Run(const Float* const __restrict__ p_in_global, @@ -251,8 +280,8 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn constexpr index_t wei_block_element_space = wei_cyxk_block_desc.GetElementSpace(Number{}); - __shared__ Float p_in_block[in_block_element_space]; - __shared__ Float p_wei_block[wei_block_element_space]; + Float* const p_in_block = GetSharedMemoryBegin(); + Float* const p_wei_block = GetSharedMemoryBegin() + in_block_element_space; const Float* p_in_global_block_offset = p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin); @@ -288,7 +317,7 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn blockwise_gemm.Run #elif 1 blockwise_gemm.Run_RegisterDoubleBuffer -#elif 0 +#elif 1 blockwise_gemm.Run_asm #endif (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),