diff --git a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp index a3489bc8cc..3aae266e4c 100644 --- a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp @@ -1,8 +1,9 @@ #pragma once #include #include "device.hpp" -#include "gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp" -#include "gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp" +#include "gridwise_convolution_wrapper.hip.hpp" +#include "gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp" +//#include "gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp" template void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, @@ -272,7 +273,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, { constexpr auto gridwise_conv = #if 1 - gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn + GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn #else gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer #endif @@ -301,11 +302,12 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, WeiBlockCopyThreadPerDim0, WeiBlockCopyThreadPerDim1, InBlockCopyDataPerRead, - WeiBlockCopyDataPerRead>(); + WeiBlockCopyDataPerRead>{}; - float time = launch_kernel(gridwise_conv.Run, + float time = launch_kernel(run_gridwise_convolution, dim3(GridSize), dim3(BlockSize), + gridwise_conv, static_cast(in_chwn_device_buf.GetDeviceBuffer()), static_cast(wei_cyxk_device_buf.GetDeviceBuffer()), static_cast(out_khwn_device_buf.GetDeviceBuffer())); diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp similarity index 95% rename from src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp rename to src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp index da689bc6b9..5f0d353465 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp @@ -34,10 +34,11 @@ template -class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn +struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn { - public: - __host__ __device__ static index_t GetSharedMemorySize() + __host__ __device__ constexpr GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn() {} + + __host__ __device__ constexpr index_t GetSharedMemoryUsage() const { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -46,7 +47,6 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn constexpr auto in_chwn_global_desc = InGlobalDesc{}; constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; - constexpr auto out_khwn_global_desc = OutGlobalDesc{}; constexpr index_t Hi = in_chwn_global_desc.GetLength(I1); constexpr index_t Wi = in_chwn_global_desc.GetLength(I2); @@ -64,10 +64,6 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned( Sequence{}, Number{}); - // tensor view of threadwise output in register - constexpr auto out_kb_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - constexpr index_t max_align = mod_conv::max(InBlockCopyDataPerRead, WeiBlockCopyDataPerRead); @@ -81,9 +77,9 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn return (in_block_element_space + wei_block_element_space) * sizeof(Float); } - __global__ static void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) + __device__ void Run(const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) const { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; diff --git a/src/include/gridwise_convolution_wrapper.hip.hpp b/src/include/gridwise_convolution_wrapper.hip.hpp new file mode 100644 index 0000000000..e0abfda3b6 --- /dev/null +++ b/src/include/gridwise_convolution_wrapper.hip.hpp @@ -0,0 +1,10 @@ +#pragma once + +template +__global__ void run_gridwise_convolution(GridwiseConvolution, + const T* const __restrict__ p_in_global, + const T* const __restrict__ p_wei_global, + T* const __restrict__ p_out_global) +{ + GridwiseConvolution{}.Run(p_in_global, p_wei_global, p_out_global); +}