mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
putting gridwise convolution into its own class
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp"
|
||||
#include "gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp"
|
||||
#include "gridwise_convolution_wrapper.hip.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp"
|
||||
//#include "gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp"
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc>
|
||||
void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
|
||||
@@ -272,7 +273,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
|
||||
{
|
||||
constexpr auto gridwise_conv =
|
||||
#if 1
|
||||
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
|
||||
GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
|
||||
#else
|
||||
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
|
||||
#endif
|
||||
@@ -301,11 +302,12 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1,
|
||||
InBlockCopyDataPerRead,
|
||||
WeiBlockCopyDataPerRead>();
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
|
||||
float time = launch_kernel(gridwise_conv.Run,
|
||||
float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
gridwise_conv,
|
||||
static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
|
||||
|
||||
@@ -34,10 +34,11 @@ template <index_t GridSize,
|
||||
index_t WeiBlockCopyThreadPerDim1,
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead>
|
||||
class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
|
||||
struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
|
||||
{
|
||||
public:
|
||||
__host__ __device__ static index_t GetSharedMemorySize()
|
||||
__host__ __device__ constexpr GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn() {}
|
||||
|
||||
__host__ __device__ constexpr index_t GetSharedMemoryUsage() const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -46,7 +47,6 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
|
||||
|
||||
constexpr auto in_chwn_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t Hi = in_chwn_global_desc.GetLength(I1);
|
||||
constexpr index_t Wi = in_chwn_global_desc.GetLength(I2);
|
||||
@@ -64,10 +64,6 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
|
||||
constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_kb_thread_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});
|
||||
|
||||
constexpr index_t max_align =
|
||||
mod_conv::max(InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
|
||||
|
||||
@@ -81,9 +77,9 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
|
||||
return (in_block_element_space + wei_block_element_space) * sizeof(Float);
|
||||
}
|
||||
|
||||
__global__ static void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global)
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
10
src/include/gridwise_convolution_wrapper.hip.hpp
Normal file
10
src/include/gridwise_convolution_wrapper.hip.hpp
Normal file
@@ -0,0 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
// Kernel entry point (__global__) that forwards to a gridwise convolution
// object's Run() member.
//
// Template parameters:
//   GridwiseConvolution - type that is value-initialized inside the kernel and
//                         must provide a Run(const T*, const T*, T*) member
//                         callable on a temporary.
//   T                   - element type of the three pointer arguments.
//
// Parameters:
//   (unnamed GridwiseConvolution) - accepted by value but never read; the body
//                                   constructs its own instance instead.
//   p_in_global, p_wei_global     - const pointers forwarded unchanged as the
//                                   first and second arguments of Run().
//   p_out_global                  - non-const pointer forwarded unchanged as
//                                   the third argument of Run().
//
// NOTE(review): the pointer names suggest input, weight and output tensors in
// global memory -- confirm against the Run() implementation being used.
template <class GridwiseConvolution, class T>
|
||||
__global__ void run_gridwise_convolution(GridwiseConvolution,
|
||||
const T* const __restrict__ p_in_global,
|
||||
const T* const __restrict__ p_wei_global,
|
||||
T* const __restrict__ p_out_global)
|
||||
{
|
||||
// Build a fresh temporary rather than using the by-value argument, then
// delegate all work to its Run() member.
GridwiseConvolution{}.Run(p_in_global, p_wei_global, p_out_global);
|
||||
}
|
||||
Reference in New Issue
Block a user