mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
Refactor for MIOpen integration (#4)
Refactor, so can bring multi-index transformation and padding support into MIOpen
[ROCm/composable_kernel commit: 52c3fe05be]
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
#ifndef CONV_COMMON_HPP
|
||||
#define CONV_COMMON_HPP
|
||||
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <class InDesc, class WeiDesc>
|
||||
|
||||
@@ -3,14 +3,17 @@
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_convolution_kernel_wrapper.hpp"
|
||||
#include "convolution_common.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp"
|
||||
|
||||
template <class T,
|
||||
class InDesc,
|
||||
class WeiDesc,
|
||||
class OutDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations>
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
typename OutDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads>
|
||||
void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
@@ -19,6 +22,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
@@ -28,9 +33,12 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_nchw_desc = InDesc{};
|
||||
constexpr auto wei_kcyx_desc = WeiDesc{};
|
||||
constexpr auto out_nkhw_desc = OutDesc{};
|
||||
constexpr auto in_nchw_desc =
|
||||
make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
|
||||
constexpr auto wei_kcyx_desc =
|
||||
make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
|
||||
constexpr auto out_nkhw_desc =
|
||||
make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());
|
||||
|
||||
constexpr index_t N = out_nkhw_desc.GetLength(I0);
|
||||
constexpr index_t K = out_nkhw_desc.GetLength(I1);
|
||||
@@ -47,7 +55,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
// BlockSize = 256, blockwise-GEMM 128x128, each thread hold 64 data
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 16;
|
||||
@@ -84,7 +92,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
// BlockSize = 64, blockwise-GEMM 64x64, each thread hold 64 data
|
||||
// BlockSize = 64, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t BPerBlock = 8;
|
||||
@@ -120,7 +128,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// BlockSize = 256, blockwise-GEMM 64x128, each thread hold 32 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -174,11 +182,15 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
ConvolutionDirection::Forward,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
EPerBlock,
|
||||
|
||||
@@ -3,27 +3,23 @@
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_convolution_kernel_wrapper.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp"
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
typename OutDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads>
|
||||
void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
ck::index_t nrepeat)
|
||||
template <class T,
|
||||
class InDesc,
|
||||
class WeiDesc,
|
||||
class OutDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations>
|
||||
void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -32,12 +28,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc,
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_nchw_desc =
|
||||
make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
|
||||
constexpr auto wei_kcyx_desc =
|
||||
make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
|
||||
constexpr auto out_nkhw_desc =
|
||||
make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());
|
||||
constexpr auto in_nchw_desc = InDesc{};
|
||||
constexpr auto wei_kcyx_desc = WeiDesc{};
|
||||
constexpr auto out_nkhw_desc = OutDesc{};
|
||||
|
||||
constexpr index_t N = out_nkhw_desc.GetLength(I0);
|
||||
constexpr index_t K = out_nkhw_desc.GetLength(I1);
|
||||
@@ -54,7 +47,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc,
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
// BlockSize = 256, blockwise-GEMM 128x128, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 16;
|
||||
@@ -91,7 +84,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc,
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
// BlockSize = 64, each thread hold 64 data
|
||||
// BlockSize = 64, blockwise-GEMM 64x64, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t BPerBlock = 8;
|
||||
@@ -127,7 +120,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc,
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// BlockSize = 256, blockwise-GEMM 64x128, each thread hold 32 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -177,48 +170,44 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc,
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
constexpr auto gridwise_conv =
|
||||
#if 0
|
||||
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
|
||||
#else
|
||||
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buffer
|
||||
#endif
|
||||
<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
EPerBlock,
|
||||
GemmNRepeat,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2,
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>{};
|
||||
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
ConvolutionDirection::Forward,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
EPerBlock,
|
||||
GemmNRepeat,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2,
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>{};
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
@@ -5,14 +5,14 @@
|
||||
#include "gridwise_convolution_kernel_wrapper.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <class T,
|
||||
class InDesc,
|
||||
class WeiDesc,
|
||||
class OutDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations>
|
||||
class ConvDilations,
|
||||
class LeftPads,
|
||||
class RightPads>
|
||||
void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
@@ -21,8 +21,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
@@ -164,7 +168,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
|
||||
constexpr auto gridwise_conv =
|
||||
#if 0
|
||||
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded
|
||||
#else
|
||||
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
#endif
|
||||
@@ -176,6 +180,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
EPerBlock,
|
||||
|
||||
@@ -3,30 +3,26 @@
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_convolution_kernel_wrapper.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <class T,
|
||||
class InDesc,
|
||||
class WeiDesc,
|
||||
class OutDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
class LeftPads,
|
||||
class RightPads>
|
||||
void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
ck::index_t nrepeat)
|
||||
class ConvDilations>
|
||||
void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
@@ -168,9 +164,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded(InDesc,
|
||||
|
||||
constexpr auto gridwise_conv =
|
||||
#if 0
|
||||
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded
|
||||
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
#else
|
||||
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buffer
|
||||
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated
|
||||
#endif
|
||||
<GridSize,
|
||||
BlockSize,
|
||||
@@ -180,8 +176,6 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded(InDesc,
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
EPerBlock,
|
||||
@@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
#include "tensor.hpp"
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <class TConstTensorDesc>
|
||||
|
||||
Reference in New Issue
Block a user