[CK][CONV] Support NCHW in class DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle (#2375)

1. When the conv spec is Filter1x1Stride1Pad0, an NCHW input is already equivalent to GEMM matrix A in column-major layout, so only a minor change in the conv-to-GEMM transformer is needed to support it (a minimal sketch of this equivalence follows the list).
2. When the output is NKHW, it is equivalent to GEMM matrix C in column-major layout; we swap A and B to get the best performance (see the host-side sketch after the first file's diff).
3. Add the new instance list device_grouped_conv_fwd_xdl_f16_nchw_instances for NCHW.
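
A minimal sketch of the equivalence in item 1, assuming a single group (G = 1) and Ho = Hi, Wo = Wi, which hold for 1x1, stride-1, pad-0 convolutions. The struct NchwShape and the helper a_nchw_at are hypothetical illustration code, not composable_kernel APIs; they show that element (m, k) of the implicit GEMM A matrix, with m = n * Ho * Wo + spatial index and k = c, is addressed with stride 1 along M (within one image) and stride Hi * Wi along K, i.e. a column-major A, which is exactly what the new NGCHW MakeADescriptor_M_K overloads in TransformConvFwdToGemm describe.

// Hypothetical host-side sketch (not library code): addressing an NCHW input
// as the GEMM A matrix for a 1x1, stride-1, pad-0 convolution with G = 1.
#include <cstddef>

struct NchwShape
{
    std::size_t N, C, H, W; // Ho = H, Wo = W for 1x1, stride 1, pad 0
};

// Element (m, k) of the implicit A matrix, with m = n * H * W + h * W + w and k = c.
inline float a_nchw_at(const float* p, NchwShape s, std::size_t m, std::size_t k)
{
    const std::size_t n  = m / (s.H * s.W); // image index
    const std::size_t hw = m % (s.H * s.W); // flattened spatial index h * W + w
    // NCHW offset: n * (C*H*W) + k * (H*W) + hw
    // -> stride 1 along M within one image and stride H*W along K, i.e. column-major A,
    //    matching make_naive_tensor_descriptor(make_tuple(N_, Ho_*Wo_, C_),
    //                                          make_tuple(NStrideTensorA_, I1, CStrideTensorA_)).
    return p[n * s.C * s.H * s.W + k * s.H * s.W + hw];
}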


[ROCm/composable_kernel commit: 1749c0409e]
Author: linqunAMD
Date: 2025-06-26 08:32:39 +08:00 (committed by GitHub)
Parent: 207baa02bb
Commit: c7c24bb10d
6 changed files with 552 additions and 137 deletions

View File

@@ -77,7 +77,8 @@ template <typename GridwiseGemm,
typename ComputePtrOffsetOfN,
bool HasMainKBlockLoop,
bool isMultiA,
bool isMultiB>
bool isMultiB,
bool CTranspose>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -171,17 +172,22 @@ __global__ void
}
else
{
const long_index_t a_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
const long_index_t b_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
CTranspose
? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx))
: amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
const long_index_t a_group_offset =
CTranspose
? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx))
: amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
const long_index_t b_n_offset =
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0;
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
p_as_grid + a_group_offset + a_n_offset,
p_bs_grid + b_group_offset,
p_bs_grid + b_group_offset + b_n_offset,
p_ds_grid_grp,
p_e_grid + e_group_offset + e_n_offset,
p_shared,
@@ -335,12 +341,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr bool isATensorColMajor =
(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) &&
(ABlockTransferSrcVectorDim == 1) && (NumGroupsToMerge == 1) &&
(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>());
static constexpr bool NeedTransposeKernel =
(isATensorColMajor == false) && (is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>());
static constexpr bool CTranspose = (NeedTransposeKernel == false) && (isMultiAB == false) &&
(is_same_v<ELayout, tensor_layout::convolution::NGKHW> ||
is_same_v<ELayout, tensor_layout::convolution::NGKDHW>);
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
true /*SplitN*/,
ADataType,
EDataType,
NumGroupsToMerge>;
NumGroupsToMerge,
index_t,
CTranspose>;
static constexpr index_t ClusterLengthNPerBlock =
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
@@ -361,9 +383,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
ctc::NHWGC,
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::NDHWGC, ALay>>;
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
ctc::NDHWGC,
ALay>>;
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
@@ -379,9 +403,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
ctc::GKYXC,
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::GKZYXC, BLay>>;
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
ctc::GKZYXC,
BLay>>;
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<Layout>();
@@ -397,17 +423,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
ctc::NHWGK,
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::NDHWGK, ELay>>;
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
ctc::NDHWGK,
ELay>>;
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
return out_gemmm_gemmn_desc;
if constexpr(CTranspose)
{
constexpr auto matrix_padder_trans =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{NPerBlock, MPerBlock, KPerBlock};
return matrix_padder_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
}
else
{
return matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
}
}
// Shape of Ds and E must be aligned. Strides can be different.
@@ -471,11 +504,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
BComputeDataType, DoElementwiseBeforeCShuffle
#define GridwiseGemmCTransposeTemplateParameters \
GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \
NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \
MPerXDL, NXdlPerWave, MXdlPerWave, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
ABlockLdsExtraM, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
BComputeDataType, DoElementwiseBeforeCShuffle
// Use appropriate gridwise gemm
using GridwiseGemm = std::conditional_t<
isMultiA || isMultiB,
GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmMultiABDTemplateParameters>,
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;
using GridwiseGemmCTranspose = std::conditional_t<
CTranspose,
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmCTransposeTemplateParameters>,
GridwiseGemm>;
// If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers.
using APointers =
@@ -497,15 +551,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemmCTranspose::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
remove_cvref_t<decltype(GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(
EGridDesc_M_N{}))>;
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, NPerBlock>;
using NGCHWTransposeDescType =
@@ -612,16 +667,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
a_g_n_c_wis_strides_{NeedTransposeKernel
? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)
: a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
b_g_k_c_xs_strides_{NeedTransposeKernel
? conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)
: b_g_k_c_xs_strides},
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
e_g_n_k_wos_strides_{NeedTransposeKernel
? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)
: e_g_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
@@ -651,7 +712,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
block_2_etile_map_{
GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
compute_ptr_offset_of_groups_{},
compute_ptr_offset_of_n_{},
a_element_op_{a_element_op},
@@ -783,24 +845,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
else
{
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
bool valid = false;
if constexpr(CTranspose)
{
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
valid = GridwiseGemmCTranspose::CheckValidity(b_grid_desc_n_k_,
a_grid_desc_m_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_);
}
else
{
valid = GridwiseGemmCTranspose::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_);
}
if(valid)
{
e_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n_);
}
}
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
// Use not modified base strides
a_in_transpose_desc_ =
@@ -835,8 +907,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
std::size_t GetWorkspaceATensorSizeBytes() const
{
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
@@ -851,8 +922,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
std::size_t GetWorkspaceBTensorSizeBytes() const
{
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
const long_index_t b_acum = ck::accumulate_n<long_index_t>(
b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
@@ -867,8 +937,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
std::size_t GetWorkspaceETensorSizeBytes() const
{
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
@@ -1007,7 +1076,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
has_main_loop,
isMultiA,
isMultiB>;
isMultiB,
CTranspose>;
return launch_and_time_kernel(
stream_config,
@@ -1035,68 +1105,118 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
const ADataType* p_a_grid = arg.p_as_grid_.At(I0);
const BDataType* p_b_grid = arg.p_bs_grid_.At(I0);
EDataType* p_e_grid = arg.p_e_grid_;
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() +
arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
}
else if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() +
arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() +
arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
}
else if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() +
arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
}
}
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
GridwiseGemm,
const ADataType*,
const BDataType*,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
has_main_loop,
isMultiA,
isMultiB>;
if constexpr(CTranspose)
{
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
GridwiseGemmCTranspose,
const BDataType*,
const ADataType*,
typename GridwiseGemm::DsGridPointer,
EDataType,
BElementwiseOperation,
AElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
has_main_loop,
isMultiA,
isMultiB,
CTranspose>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
arg.p_ds_grid_,
p_e_grid,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_,
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_);
return launch_and_time_kernel(
stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_b_grid,
p_a_grid,
arg.p_ds_grid_,
p_e_grid,
arg.b_element_op_,
arg.a_element_op_,
arg.cde_element_op_,
arg.b_grid_desc_bk0_n_bk1_,
arg.a_grid_desc_ak0_m_ak1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_,
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_);
}
else
{
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
GridwiseGemm,
const ADataType*,
const BDataType*,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
has_main_loop,
isMultiA,
isMultiB,
CTranspose>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
arg.p_ds_grid_,
p_e_grid,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_,
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_);
}
}
};
@@ -1114,8 +1234,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
float avg_time = 0.f;
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
const index_t a_grid_size =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
@@ -1166,8 +1285,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
avg_time += RunGemm(arg, stream_config);
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
@@ -1215,9 +1333,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
namespace ctc = tensor_layout::convolution;
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
// check device
if(get_device_name() == "gfx908")
@@ -1310,7 +1430,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
is_same_v<ALayout, ctc::NDHWGC> || is_same_v<ALayout, ctc::NGCW> ||
is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
NeedTransposeKernel)
{
// Check access per C
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
@@ -1326,6 +1446,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
}
else if constexpr(is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
{
static_assert(NeedTransposeKernel == false);
static_assert(NumGroupsToMerge == 1);
if constexpr(ABlockTransferSrcScalarPerVector != 1)
{
if(ABlockTransferSrcVectorDim != 1)
{
return false;
}
if(input_spatial_acum % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
}
else
{
return false;
@@ -1350,7 +1487,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
return false;
}
// check vector access of Ds
bool valid = true;
@@ -1396,8 +1532,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
});
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
@@ -1409,8 +1544,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return false;
}
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
@@ -1457,9 +1590,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<ELayout, ctc::NDHWGK> || is_same_v<ELayout, ctc::NGKW> ||
is_same_v<ELayout, ctc::NGKHW> || is_same_v<ELayout, ctc::NGKDHW>)
{
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
if(CTranspose == false)
{
return false;
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
}
else
{
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
}
else
@@ -1483,11 +1629,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
else
{
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
if constexpr(CTranspose)
{
return GridwiseGemmCTranspose::CheckValidity(arg.b_grid_desc_n_k_,
arg.a_grid_desc_m_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
else
{
return GridwiseGemmCTranspose::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
}
}
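
Why the CTranspose branch above may legally swap the A/B pointers, grid descriptors, and elementwise operations at kernel launch: for E = A x B, the column-major view of E is E^T = B^T x A^T, so running the GEMM with the operands (and the M/N tile roles) exchanged writes the NKHW/NGKHW output directly, with no separate output-transpose kernel. The following self-contained host-side sketch of this identity is illustration only, not part of composable_kernel.

// Hypothetical sketch: the "swapped" GEMM produces the transposed (column-major) output.
#include <cassert>
#include <vector>

int main()
{
    const int M = 3, N = 4, K = 5; // arbitrary small GEMM sizes for illustration
    std::vector<float> a(M * K), b(K * N);
    for(int i = 0; i < M * K; ++i) a[i] = static_cast<float>(i % 7);
    for(int i = 0; i < K * N; ++i) b[i] = static_cast<float>(i % 5);

    // Reference: row-major E[m][n] = sum_k A[m][k] * B[k][n]
    std::vector<float> e(M * N, 0.f);
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                e[m * N + n] += a[m * K + k] * b[k * N + n];

    // Swapped GEMM: B^T as the first operand, A^T as the second. The result is E^T in
    // row-major, i.e. E in column-major - the layout an NKHW/NGKHW output expects.
    std::vector<float> et(N * M, 0.f);
    for(int n = 0; n < N; ++n)
        for(int m = 0; m < M; ++m)
            for(int k = 0; k < K; ++k)
                et[n * M + m] += b[k * N + n] * a[m * K + k];

    // Same products accumulated in the same k order, so the results match exactly.
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            assert(e[m * N + n] == et[n * M + m]);
    return 0;
}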

View File

@@ -19,7 +19,8 @@ template <index_t NDimSpatial,
typename ADataType = float,
typename CDataType = float,
index_t NumGroupsToMerge = 1,
typename IndexType = index_t>
typename IndexType = index_t,
bool CTranspose = false>
struct TransformConvFwdToGemm
{
private:
@@ -1253,6 +1254,83 @@ struct TransformConvFwdToGemm
}
}
template <typename ALayout,
typename ck::enable_if<NDimSpatial == 1 &&
is_same_v<ALayout, tensor_layout::convolution::NGCW>,
bool>::type = false>
__host__ __device__ auto MakeADescriptor_M_K() const
{
static_assert(NumGroupsToMerge == 1);
static_assert(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0);
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wo_, C_), make_tuple(NStrideTensorA_, I1, CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(N_, Wo_)), make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
template <typename ALayout,
typename ck::enable_if<NDimSpatial == 2 &&
is_same_v<ALayout, tensor_layout::convolution::NGCHW>,
bool>::type = false>
__host__ __device__ auto MakeADescriptor_M_K() const
{
static_assert(NumGroupsToMerge == 1);
static_assert(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0);
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(N_, Ho_ * Wo_, C_), make_tuple(NStrideTensorA_, I1, CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(N_, Ho_ * Wo_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
template <typename ALayout,
typename ck::enable_if<NDimSpatial == 3 &&
is_same_v<ALayout, tensor_layout::convolution::NGCDHW>,
bool>::type = false>
__host__ __device__ auto MakeADescriptor_M_K() const
{
static_assert(NumGroupsToMerge == 1);
static_assert(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0);
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(N_, Do_ * Ho_ * Wo_, C_), make_tuple(NStrideTensorA_, I1, CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
template <typename BLayout,
typename ck::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKCX> ||
is_same_v<BLayout, tensor_layout::convolution::GKCYX> ||
is_same_v<BLayout, tensor_layout::convolution::GKCZYX>,
bool>::type = false>
__host__ __device__ auto MakeBDescriptor_N_K() const
{
static_assert(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 ||
ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Pad0);
static_assert(NumGroupsToMerge == 1);
return make_naive_tensor_descriptor_packed(make_tuple(K_, C_));
}
template <typename BLayout,
typename ck::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
@@ -1338,8 +1416,16 @@ struct TransformConvFwdToGemm
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_));
if constexpr(CTranspose)
{
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
make_tuple(KStrideTensorC_, I0));
}
else
{
return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_));
}
}
template <
@@ -1350,8 +1436,16 @@ struct TransformConvFwdToGemm
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_));
if constexpr(CTranspose)
{
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Ho_ * Wo_),
make_tuple(KStrideTensorC_, I0));
}
else
{
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_));
}
}
template <
@@ -1362,12 +1456,21 @@ struct TransformConvFwdToGemm
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_));
if constexpr(CTranspose)
{
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Do_ * Ho_ * Wo_),
make_tuple(KStrideTensorC_, I0));
}
else
{
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_));
}
}
template <typename CLayout,
index_t NDimSp = NDimSpatial,
index_t NDimSp = NDimSpatial,
typename ck::enable_if<NDimSp == 1 &&
(is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
@@ -1375,6 +1478,7 @@ struct TransformConvFwdToGemm
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
static_assert(CTranspose == false);
const IndexType NDoHoWo = N_ * Wo_;
if constexpr(NumGroupsToMerge == 1)
{
@@ -1429,6 +1533,7 @@ struct TransformConvFwdToGemm
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
static_assert(CTranspose == false);
const IndexType NDoHoWo = N_ * Ho_ * Wo_;
if constexpr(NumGroupsToMerge == 1)
{
@@ -1486,7 +1591,7 @@ struct TransformConvFwdToGemm
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
static_assert(CTranspose == false);
const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_;
if constexpr(NumGroupsToMerge == 1)
{
@@ -1536,6 +1641,101 @@ struct TransformConvFwdToGemm
}
}
template <typename CLayout,
index_t NDimSp = NDimSpatial,
typename ck::enable_if<NDimSp == 1 &&
(is_same_v<CLayout, tensor_layout::convolution::GNKW> ||
is_same_v<CLayout, tensor_layout::convolution::NGKW>),
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
static_assert(NumGroupsToMerge == 1);
auto n_k_wo_desc = make_naive_tensor_descriptor(
make_tuple(N_, K_, Wo_), make_tuple(NStrideTensorC_, KStrideTensorC_, I1));
if constexpr(CTranspose)
{
return transform_tensor_descriptor(
n_k_wo_desc,
make_tuple(make_pass_through_transform(K_),
make_merge_transform(make_tuple(N_, Wo_))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(n_k_wo_desc,
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
make_pass_through_transform(K_)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
template <typename CLayout,
index_t NDimSp = NDimSpatial,
typename ck::enable_if<NDimSp == 2 &&
(is_same_v<CLayout, tensor_layout::convolution::GNKHW> ||
is_same_v<CLayout, tensor_layout::convolution::NGKHW>),
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
static_assert(NumGroupsToMerge == 1);
auto n_k_howo_desc = make_naive_tensor_descriptor(
make_tuple(N_, K_, Ho_ * Wo_), make_tuple(NStrideTensorC_, KStrideTensorC_, I1));
if constexpr(CTranspose)
{
return transform_tensor_descriptor(
n_k_howo_desc,
make_tuple(make_pass_through_transform(K_),
make_merge_transform(make_tuple(N_, Ho_ * Wo_))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
n_k_howo_desc,
make_tuple(make_merge_transform(make_tuple(N_, Ho_ * Wo_)),
make_pass_through_transform(K_)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
template <typename CLayout,
index_t NDimSp = NDimSpatial,
typename ck::enable_if<NDimSp == 3 &&
(is_same_v<CLayout, tensor_layout::convolution::GNKDHW> ||
is_same_v<CLayout, tensor_layout::convolution::NGKDHW>),
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
static_assert(NumGroupsToMerge == 1);
auto n_k_dohowo_desc = make_naive_tensor_descriptor(
make_tuple(N_, K_, Do_ * Ho_ * Wo_), make_tuple(NStrideTensorC_, KStrideTensorC_, I1));
if constexpr(CTranspose)
{
return transform_tensor_descriptor(
n_k_dohowo_desc,
make_tuple(make_pass_through_transform(K_),
make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
n_k_dohowo_desc,
make_tuple(make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_)),
make_pass_through_transform(K_)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
IndexType N_;
IndexType Di_, Hi_, Wi_;
IndexType Do_, Ho_, Wo_;

View File

@@ -179,6 +179,38 @@ using device_grouped_conv_fwd_xdl_f16_instances = std::tuple<
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_xdl_f16_nchw_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>,
// 32x32 instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>,
// 16x16 instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,

View File

@@ -53,6 +53,15 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances_shard(
ConvFwd1x1S1P0>,
Shards,
ShardIndex>{});
add_device_operation_instances(instances,
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f16_nchw_instances<2,
NGCHW,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwd1x1S1P0>,
Shards,
ShardIndex>{});
}
} // namespace ck::tensor_operation::device::instance

View File

@@ -31,6 +31,14 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_f16_nchw_instances<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwd1x1S1P0>{});
}
} // namespace instance

View File

@@ -47,6 +47,15 @@ void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instances(
Empty_Tuple,
NGKDHW,
ConvFwd1x1S1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_f16_nchw_instances<3,
NGCDHW,
GKCZYX,
Empty_Tuple,
NGKDHW,
ConvFwd1x1S1P0>{});
}
} // namespace instance