[CK][CONV] Support NCHW in class DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle (#2375)

1. When the conv specialization is a 1x1 filter with stride 1 and pad 0 (Filter1x1Stride1Pad0), an NCHW input is equivalent to a column-major matrix A, so only a minor change in the conv transformer is needed to support it.
2. When the output is NKHW, it is equivalent to a column-major matrix C, so we need to swap A and B to get the best performance.
3. Add a new instance, device_grouped_conv_fwd_xdl_f16_nchw_instances, for NCHW.
This commit is contained in:
linqunAMD
2025-06-26 08:32:39 +08:00
committed by GitHub
parent a14753b86f
commit 1749c0409e
6 changed files with 552 additions and 137 deletions

View File

@@ -77,7 +77,8 @@ template <typename GridwiseGemm,
typename ComputePtrOffsetOfN,
bool HasMainKBlockLoop,
bool isMultiA,
bool isMultiB>
bool isMultiB,
bool CTranspose>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -171,17 +172,22 @@ __global__ void
}
else
{
const long_index_t a_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
const long_index_t b_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
CTranspose
? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx))
: amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
const long_index_t a_group_offset =
CTranspose
? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx))
: amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
const long_index_t b_n_offset =
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0;
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
p_as_grid + a_group_offset + a_n_offset,
p_bs_grid + b_group_offset,
p_bs_grid + b_group_offset + b_n_offset,
p_ds_grid_grp,
p_e_grid + e_group_offset + e_n_offset,
p_shared,
@@ -335,12 +341,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr bool isATensorColMajor =
(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) &&
(ABlockTransferSrcVectorDim == 1) && (NumGroupsToMerge == 1) &&
(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>());
static constexpr bool NeedTransposeKernel =
(isATensorColMajor == false) && (is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>());
static constexpr bool CTranspose = (NeedTransposeKernel == false) && (isMultiAB == false) &&
(is_same_v<ELayout, tensor_layout::convolution::NGKHW> ||
is_same_v<ELayout, tensor_layout::convolution::NGKDHW>);
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
true /*SplitN*/,
ADataType,
EDataType,
NumGroupsToMerge>;
NumGroupsToMerge,
index_t,
CTranspose>;
static constexpr index_t ClusterLengthNPerBlock =
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
@@ -361,9 +383,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
ctc::NHWGC,
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::NDHWGC, ALay>>;
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
ctc::NDHWGC,
ALay>>;
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
@@ -379,9 +403,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
ctc::GKYXC,
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::GKZYXC, BLay>>;
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
ctc::GKZYXC,
BLay>>;
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<Layout>();
@@ -397,17 +423,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
ctc::NHWGK,
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::NDHWGK, ELay>>;
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
ctc::NDHWGK,
ELay>>;
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
return out_gemmm_gemmn_desc;
if constexpr(CTranspose)
{
constexpr auto matrix_padder_trans =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{NPerBlock, MPerBlock, KPerBlock};
return matrix_padder_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
}
else
{
return matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
}
}
// Shape of Ds and E must be aligned. Strides can be different.
@@ -471,11 +504,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
BComputeDataType, DoElementwiseBeforeCShuffle
#define GridwiseGemmCTransposeTemplateParameters \
GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \
NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \
MPerXDL, NXdlPerWave, MXdlPerWave, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
ABlockLdsExtraM, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
BComputeDataType, DoElementwiseBeforeCShuffle
// Use appropriate gridwise gemm
using GridwiseGemm = std::conditional_t<
isMultiA || isMultiB,
GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmMultiABDTemplateParameters>,
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;
using GridwiseGemmCTranspose = std::conditional_t<
CTranspose,
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmCTransposeTemplateParameters>,
GridwiseGemm>;
// If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers.
using APointers =
@@ -497,15 +551,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemmCTranspose::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
remove_cvref_t<decltype(GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(
EGridDesc_M_N{}))>;
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, NPerBlock>;
using NGCHWTransposeDescType =
@@ -612,16 +667,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
a_g_n_c_wis_strides_{NeedTransposeKernel
? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)
: a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
b_g_k_c_xs_strides_{NeedTransposeKernel
? conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)
: b_g_k_c_xs_strides},
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
e_g_n_k_wos_strides_{NeedTransposeKernel
? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)
: e_g_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
@@ -651,7 +712,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
block_2_etile_map_{
GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
compute_ptr_offset_of_groups_{},
compute_ptr_offset_of_n_{},
a_element_op_{a_element_op},
@@ -783,24 +845,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
else
{
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
bool valid = false;
if constexpr(CTranspose)
{
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
valid = GridwiseGemmCTranspose::CheckValidity(b_grid_desc_n_k_,
a_grid_desc_m_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_);
}
else
{
valid = GridwiseGemmCTranspose::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_);
}
if(valid)
{
e_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n_);
}
}
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
// Use not modified base strides
a_in_transpose_desc_ =
@@ -835,8 +907,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
std::size_t GetWorkspaceATensorSizeBytes() const
{
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
@@ -851,8 +922,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
std::size_t GetWorkspaceBTensorSizeBytes() const
{
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
const long_index_t b_acum = ck::accumulate_n<long_index_t>(
b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
@@ -867,8 +937,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
std::size_t GetWorkspaceETensorSizeBytes() const
{
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
@@ -1007,7 +1076,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
has_main_loop,
isMultiA,
isMultiB>;
isMultiB,
CTranspose>;
return launch_and_time_kernel(
stream_config,
@@ -1035,68 +1105,118 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
const ADataType* p_a_grid = arg.p_as_grid_.At(I0);
const BDataType* p_b_grid = arg.p_bs_grid_.At(I0);
EDataType* p_e_grid = arg.p_e_grid_;
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() +
arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
}
else if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() +
arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() +
arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
}
else if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() +
arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
}
}
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
GridwiseGemm,
const ADataType*,
const BDataType*,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
has_main_loop,
isMultiA,
isMultiB>;
if constexpr(CTranspose)
{
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
GridwiseGemmCTranspose,
const BDataType*,
const ADataType*,
typename GridwiseGemm::DsGridPointer,
EDataType,
BElementwiseOperation,
AElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
has_main_loop,
isMultiA,
isMultiB,
CTranspose>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
arg.p_ds_grid_,
p_e_grid,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_,
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_);
return launch_and_time_kernel(
stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_b_grid,
p_a_grid,
arg.p_ds_grid_,
p_e_grid,
arg.b_element_op_,
arg.a_element_op_,
arg.cde_element_op_,
arg.b_grid_desc_bk0_n_bk1_,
arg.a_grid_desc_ak0_m_ak1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_,
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_);
}
else
{
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
GridwiseGemm,
const ADataType*,
const BDataType*,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
has_main_loop,
isMultiA,
isMultiB,
CTranspose>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
arg.p_ds_grid_,
p_e_grid,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_,
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_);
}
}
};
@@ -1114,8 +1234,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
float avg_time = 0.f;
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
const index_t a_grid_size =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
@@ -1166,8 +1285,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
avg_time += RunGemm(arg, stream_config);
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
@@ -1215,9 +1333,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
namespace ctc = tensor_layout::convolution;
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
// check device
if(get_device_name() == "gfx908")
@@ -1310,7 +1430,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
is_same_v<ALayout, ctc::NDHWGC> || is_same_v<ALayout, ctc::NGCW> ||
is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
NeedTransposeKernel)
{
// Check access per C
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
@@ -1326,6 +1446,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
}
else if constexpr(is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
{
static_assert(NeedTransposeKernel == false);
static_assert(NumGroupsToMerge == 1);
if constexpr(ABlockTransferSrcScalarPerVector != 1)
{
if(ABlockTransferSrcVectorDim != 1)
{
return false;
}
if(input_spatial_acum % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
}
else
{
return false;
@@ -1350,7 +1487,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
return false;
}
// check vector access of Ds
bool valid = true;
@@ -1396,8 +1532,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
});
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
@@ -1409,8 +1544,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return false;
}
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
@@ -1457,9 +1590,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<ELayout, ctc::NDHWGK> || is_same_v<ELayout, ctc::NGKW> ||
is_same_v<ELayout, ctc::NGKHW> || is_same_v<ELayout, ctc::NGKDHW>)
{
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
if(CTranspose == false)
{
return false;
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
}
else
{
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
}
else
@@ -1483,11 +1629,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
else
{
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
if constexpr(CTranspose)
{
return GridwiseGemmCTranspose::CheckValidity(arg.b_grid_desc_n_k_,
arg.a_grid_desc_m_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
else
{
return GridwiseGemmCTranspose::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
}
}