[CK][CONV] Support NCHW in class DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 (#2459)

1. Port NCHW support from ConvFwd (#2375) to conv bwd data: NGCHW/NGCDHW tensors are handled either by transposing them to NHWGC-style layouts with helper kernels, or by running the GEMM with the A/B operands swapped (see the sketches below)
2. Add a new instance, device_grouped_conv_bwd_data_xdl_f16_nchw_instances, for NCHW
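
A note on the mechanics, since the diff is large. With an NGCHW/NGCDHW output
there are two paths: transpose the tensors with helper kernels
(NeedTransposeKernel), or skip the transpose and run the GEMM with the A and B
operands swapped (CTranspose), using the identity E = A*B <=> E^T = B^T*A^T;
reading a row-major matrix with swapped strides is its transpose, so E lands in
channels-first order with no data movement. A minimal host-side sketch of that
identity (illustrative only, not CK code; gemm and its stride arguments stand
in for CK's tensor descriptors):

    #include <cassert>
    #include <vector>

    // e[m,n] = sum_k a[m,k] * b[k,n], with explicit row/column strides so a
    // matrix can be read as its own transpose without moving data.
    static void gemm(const float* a, int a_rs, int a_cs,
                     const float* b, int b_rs, int b_cs,
                     float* e, int e_rs, int e_cs, int M, int N, int K)
    {
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int k = 0; k < K; ++k)
                    acc += a[m * a_rs + k * a_cs] * b[k * b_rs + n * b_cs];
                e[m * e_rs + n * e_cs] = acc;
            }
    }

    int main()
    {
        constexpr int M = 4, N = 3, K = 5;
        std::vector<float> a(M * K), b(K * N), e(M * N), et(N * M);
        for(int i = 0; i < M * K; ++i) a[i] = float(i % 7);
        for(int i = 0; i < K * N; ++i) b[i] = float(i % 5);

        // Direct product: row-major E, i.e. the NHWGC-style output.
        gemm(a.data(), K, 1, b.data(), N, 1, e.data(), N, 1, M, N, K);

        // Swapped product: E^T = B^T * A^T. Both transposes are pure stride
        // swaps, and the result is E in column-major order, which is what a
        // channels-first (NGCHW) output wants.
        gemm(b.data(), 1, N, a.data(), 1, K, et.data(), M, 1, N, M, K);

        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                assert(et[n * M + m] == e[m * N + n]);
        return 0;
    }

This is why, in the CTranspose branch of the launch code below, p_b_grid and
p_a_grid (and the B/A elementwise ops) are passed to the kernel in swapped
order.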
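
The new HasMainKBlockLoopInAllGemm / NoMainKBlockLoopInAllGemm kernel template
flags hoist the per-group HasMainKBlockLoop_ branch to compile time whenever
all grouped GEMMs in a launch agree on it. A simplified sketch of that dispatch
(hypothetical names: run_gemm, run_group and launch stand in for
GridwiseGemm::Run, the kernel body and the host-side loop):

    #include <vector>

    // Stand-in for GridwiseGemm::Run<HasMainKBlockLoop, ...>.
    template <bool HasMainKBlockLoop>
    void run_gemm()
    {
        // ... main K block loop compiled in or out ...
    }

    // Stand-in for the kernel body: if all (or none) of the grouped GEMMs
    // have a main K block loop, the branch resolves at compile time and only
    // one Run<> path is instantiated; otherwise keep the runtime branch.
    template <bool HasMainKBlockLoopInAllGemm, bool NoMainKBlockLoopInAllGemm>
    void run_group(bool this_gemm_has_main_loop)
    {
        if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm)
        {
            run_gemm<HasMainKBlockLoopInAllGemm>();
        }
        else
        {
            if(this_gemm_has_main_loop)
                run_gemm<true>();
            else
                run_gemm<false>();
        }
    }

    // Stand-in for the host side: AND-reduce the per-GEMM flags, then pick
    // the instantiation, mirroring has_loop_in_all_gemm /
    // no_loop_in_all_gemm in the diff.
    void launch(const std::vector<bool>& per_gemm_has_main_loop)
    {
        bool all_have = true, none_have = true;
        for(bool has : per_gemm_has_main_loop)
        {
            all_have &= has;
            none_have &= !has;
        }
        if(all_have)
            run_group<true, false>(true);
        else if(none_have)
            run_group<false, true>(false);
        else
            for(bool has : per_gemm_has_main_loop)
                run_group<false, false>(has);
    }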

Co-authored-by: azhuang <anzhong.huang@amd.com>
Author: linqunAMD
Date: 2025-07-17 08:19:57 +08:00
Committed by: GitHub
Parent: 6e76b82059
Commit: fbd9f32abe
6 changed files with 509 additions and 156 deletions


@@ -74,7 +74,10 @@ template <typename GridwiseGemm,
typename CDEElementwiseOp,
typename ComputePtrOffsetOfBatch,
typename ComputePtrOffsetOfN,
- InMemoryDataOperationEnum OutElementOp>
+ InMemoryDataOperationEnum OutElementOp,
+ bool HasMainKBlockLoopInAllGemm,
+ bool NoMainKBlockLoopInAllGemm,
+ bool CTranspose>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -101,16 +104,21 @@ __global__ void
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch);
const long_index_t a_batch_offset =
- amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
+ CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))
+ : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
- amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
+ CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))
+ : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const long_index_t a_n_offset =
- amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
+ CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
+ const long_index_t b_n_offset =
+ CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0;
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
@@ -141,11 +149,11 @@ __global__ void
group_id = index_t((left + right) / 2);
}
- if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
+ if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm)
{
- GridwiseGemm::template Run<true, OutElementOp>(
+ GridwiseGemm::template Run<HasMainKBlockLoopInAllGemm, OutElementOp>(
p_a_grid + a_batch_offset + a_n_offset,
- p_b_grid + b_batch_offset,
+ p_b_grid + b_batch_offset + b_n_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
@@ -162,22 +170,44 @@ __global__ void
}
else
{
- GridwiseGemm::template Run<false, OutElementOp>(
- p_a_grid + a_batch_offset + a_n_offset,
- p_b_grid + b_batch_offset,
- p_ds_grid_grp,
- p_e_grid + e_batch_offset + e_n_offset,
- p_shared,
- a_element_op,
- b_element_op,
- cde_element_op,
- gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
- gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
- gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
- gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
- gemm_kernel_args[group_id].block_2_ctile_map_,
- KBatch,
- k_idx);
+ if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
+ {
+ GridwiseGemm::template Run<true, OutElementOp>(
+ p_a_grid + a_batch_offset + a_n_offset,
+ p_b_grid + b_batch_offset + b_n_offset,
+ p_ds_grid_grp,
+ p_e_grid + e_batch_offset + e_n_offset,
+ p_shared,
+ a_element_op,
+ b_element_op,
+ cde_element_op,
+ gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
+ gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
+ gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
+ gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
+ gemm_kernel_args[group_id].block_2_ctile_map_,
+ KBatch,
+ k_idx);
+ }
+ else
+ {
+ GridwiseGemm::template Run<false, OutElementOp>(
+ p_a_grid + a_batch_offset + a_n_offset,
+ p_b_grid + b_batch_offset + b_n_offset,
+ p_ds_grid_grp,
+ p_e_grid + e_batch_offset + e_n_offset,
+ p_shared,
+ a_element_op,
+ b_element_op,
+ cde_element_op,
+ gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
+ gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
+ gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
+ gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
+ gemm_kernel_args[group_id].block_2_ctile_map_,
+ KBatch,
+ k_idx);
+ }
}
#else
ignore = p_a_grid;
@@ -278,7 +308,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// implementation we can avoid copy data to workspace before kernel launch since number of
// groups is runtime parameter. If number of groups is larger than MaxGroupedGemmGroupsNum then
// we run this kernel in the loop.
- static constexpr index_t MaxGroupedGemmGroupsNum = 32;
+ static constexpr index_t MaxGroupedGemmGroupsNum =
+ ConvBackwardDataSpecialization ==
+ ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0
+ ? 1
+ : 32;
using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
@@ -296,24 +330,40 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
- using ALayoutAfterTranspose =
- std::conditional_t<is_NGCHW_NGKHW<ELayout, BLayout, ALayout>(),
- tensor_layout::convolution::NHWGK,
- std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>(),
- tensor_layout::convolution::NDHWGK,
- ALayout>>;
- using BLayoutAfterTranspose =
- std::conditional_t<is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>(),
- tensor_layout::convolution::GKYXC,
- std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>(),
- tensor_layout::convolution::GKZYXC,
- BLayout>>;
- using ELayoutAfterTranspose =
- std::conditional_t<is_NGCHW_NGKHW<ELayout, BLayout, ALayout>(),
- tensor_layout::convolution::NHWGC,
- std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>(),
- tensor_layout::convolution::NDHWGC,
- ELayout>>;
+ static constexpr bool isATensorColMajor =
+ (ConvBackwardDataSpecialization ==
+ ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) &&
+ (ABlockTransferSrcVectorDim == 1) &&
+ (is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
+ is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>());
+ static constexpr bool NeedTransposeKernel =
+ (isATensorColMajor == false) && (is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
+ is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>());
+ static constexpr bool CTranspose =
+ (NeedTransposeKernel == false) && (is_same_v<ELayout, tensor_layout::convolution::NGCHW> ||
+ is_same_v<ELayout, tensor_layout::convolution::NGCDHW>);
+ using ALayoutAfterTranspose = std::conditional_t<
+ is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
+ tensor_layout::convolution::NHWGK,
+ std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
+ tensor_layout::convolution::NDHWGK,
+ ALayout>>;
+ using BLayoutAfterTranspose = std::conditional_t<
+ is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
+ tensor_layout::convolution::GKYXC,
+ std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>() &&
+ NeedTransposeKernel,
+ tensor_layout::convolution::GKZYXC,
+ BLayout>>;
+ using ELayoutAfterTranspose = std::conditional_t<
+ is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
+ tensor_layout::convolution::NHWGC,
+ std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
+ tensor_layout::convolution::NDHWGC,
+ ELayout>>;
using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1<NDimSpatial,
ConvBackwardDataSpecialization,
@@ -329,7 +379,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
ELayoutAfterTranspose,
true, /*SplitConvN*/
ABDataType,
- EDataType>;
+ EDataType,
+ 1,
+ index_t,
+ CTranspose>;
static auto
GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform)
@@ -357,15 +410,25 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
DLayout,
true, /*SplitConvN*/
ABDataType,
- DDataType>;
+ DDataType,
+ 1, /*index_t NumGroupsToMerge = 1,*/
+ index_t, /* typename IndexType = */
+ CTranspose>;
return ConvToGemmBwdDataTransformD{}.MakeCDescriptor_M_N();
},
Number<NumDTensor>{});
const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
- return make_tuple(
- a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
+ if constexpr(CTranspose)
+ {
+ return make_tuple(
+ b_grid_desc_bk0_n_bk1, a_grid_desc_ak0_m_ak1, ds_grid_desc_m_n, e_grid_desc_m_n);
+ }
+ else
+ {
+ return make_tuple(
+ a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
+ }
}
// GridwiseGemm
@@ -383,13 +446,34 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
+ #define GridwiseGemmCTransposeTemplateParameters \
+ ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
+ BElementwiseOp, AElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \
+ NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, MPerXDL, NXdlPerWave, MXdlPerWave, \
+ BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \
+ BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \
+ BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \
+ BBlockLdsExtraN, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
+ ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
+ ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
+ ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \
+ CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
+ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
+ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmMultiDTemplateParams>;
+ using GridwiseGemmCTranspose = std::conditional_t<
+ CTranspose,
+ GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmCTransposeTemplateParameters>,
+ GridwiseGemm>;
template <typename EGridDesc_M_N>
static auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N e_grid_desc_m_n)
{
- return GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
+ return GridwiseGemmCTranspose::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+ e_grid_desc_m_n);
}
template <typename Desc_K0_M_K1>
@@ -419,13 +503,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{}));
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
- decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+ decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}));
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
// block-to-e-tile map
- using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
+ using Block2ETileMap =
+ decltype(GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}));
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
@@ -630,14 +715,17 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
sizeof(EDataType);
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
- conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
- a_g_n_k_wos_strides);
+ NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
+ a_g_n_k_wos_lengths, a_g_n_k_wos_strides)
+ : a_g_n_k_wos_strides;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_transposed =
- conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(b_g_k_c_xs_lengths,
- b_g_k_c_xs_strides);
+ NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
+ b_g_k_c_xs_lengths, b_g_k_c_xs_strides)
+ : b_g_k_c_xs_strides;
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_transposed =
- conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(e_g_n_c_wis_lengths,
- e_g_n_c_wis_strides);
+ NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
+ e_g_n_c_wis_lengths, e_g_n_c_wis_strides)
+ : e_g_n_c_wis_strides;
// populate Ds pointer
static_for<0, NumDTensor, 1>{}([&](auto i) {
@@ -737,12 +825,27 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
conv_N_per_block_ = conv_to_gemm_transform_.N_;
- const auto a_grid_desc_ak0_m_ak1 =
- conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
- const auto b_grid_desc_bk0_n_bk1 =
- conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
+ const auto a_grid_desc_ak0_m_ak1 = [&]() {
+ if constexpr(CTranspose)
+ {
+ return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
+ }
+ else
+ {
+ return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
+ }
+ }();
+ const auto b_grid_desc_bk0_n_bk1 = [&]() {
+ if constexpr(CTranspose)
+ {
+ return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
+ }
+ else
+ {
+ return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
+ }
+ }();
DsGridDesc_M_N ds_grid_desc_m_n;
// populate Ds desc
@@ -764,7 +867,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
DLayout,
true, /*SplitConvN*/
ABDataType,
- DDataType>;
+ DDataType,
+ 1,
+ index_t,
+ CTranspose>;
ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides_transposed,
@@ -810,14 +916,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
const auto GemmK = a_grid_desc_m_k.GetLength(I1);
const bool HasMainKBlockLoop =
- GridwiseGemm::CalculateHasMainKBlockLoop(GemmK, k_batch_);
+ GridwiseGemmCTranspose::CalculateHasMainKBlockLoop(GemmK, k_batch_);
gemm_kernel_args_[gemms_count_ /
MaxGroupedGemmGroupsNum][gemms_count_ %
MaxGroupedGemmGroupsNum] =
GemmArgs{a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
- GridwiseGemm::
+ GridwiseGemmCTranspose::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n),
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -851,8 +957,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
num_workgroups_per_Conv_N_ = a_g_n_k_wos_lengths_[I1] / conv_N_per_block_;
- if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
- is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
+ if constexpr(NeedTransposeKernel)
{
// Use not modified base strides
a_in_transpose_desc_ =
@@ -892,8 +997,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
std::size_t GetWorkspaceATensorSizeBytes() const
{
- if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
- is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
+ if constexpr(NeedTransposeKernel)
{
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
a_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
@@ -908,8 +1012,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
std::size_t GetWorkspaceBTensorSizeBytes() const
{
- if constexpr(is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() ||
- is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>())
+ if constexpr(NeedTransposeKernel)
{
const long_index_t b_acum = ck::accumulate_n<long_index_t>(
b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
@@ -924,8 +1027,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
std::size_t GetWorkspaceETensorSizeBytes() const
{
- if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
- is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
+ if constexpr(NeedTransposeKernel)
{
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
@@ -1030,24 +1132,25 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
const ADataType* p_a_grid = arg.p_a_grid_;
const BDataType* p_b_grid = arg.p_b_grid_;
EDataType* p_e_grid = arg.p_e_grid_;
- if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
- is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
+ if constexpr(NeedTransposeKernel)
{
- p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
- p_e_grid =
- type_convert<EDataType*>(arg.p_workspace_) +
- (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
- sizeof(EDataType);
- }
+ if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
+ is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
+ {
+ p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
+ p_e_grid =
+ type_convert<EDataType*>(arg.p_workspace_) +
+ (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
+ sizeof(EDataType);
+ }
- if constexpr(is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() ||
- is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>())
- {
- p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
- arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
+ if constexpr(is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() ||
+ is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>())
+ {
+ p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
+ arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
+ }
+ }
for(std::size_t gemm_set_id = 0; gemm_set_id < arg.gemm_kernel_args_.size();
gemm_set_id++)
{
@@ -1067,42 +1170,111 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
};
- auto launch_kernel = [&]() {
- const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
- GridwiseGemm,
- ADataType, // TODO: distinguish A/B datatype
- typename GridwiseGemm::DsGridPointer,
- EDataType,
- MaxGroupedGemmGroupsNum,
- GemmArgs,
- AElementwiseOp,
- BElementwiseOp,
- CDEElementwiseOp,
- ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
- ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
- ElementOp>;
+ bool has_loop_in_all_gemm = true;
+ bool no_loop_in_all_gemm = true;
+ for(auto i = 0; i < gemms_count_for_set; i++)
+ {
+ has_loop_in_all_gemm &= gemm_kernel_args[i].HasMainKBlockLoop_;
+ no_loop_in_all_gemm &= !gemm_kernel_args[i].HasMainKBlockLoop_;
+ }
- return launch_and_time_kernel_with_preprocess(stream_config,
- clear_workspace,
- kernel,
- dim3(gdx, gdy, gdz),
- dim3(BlockSize),
- 0,
- p_a_grid,
- p_b_grid,
- arg.p_ds_grid_,
- p_e_grid,
- gemm_kernel_args,
- gemms_count_for_set,
- arg.a_element_op_,
- arg.b_element_op_,
- arg.cde_element_op_,
- arg.compute_ptr_offset_of_batch_,
- arg.compute_ptr_offset_of_n_,
- arg.k_batch_);
+ auto launch_kernel = [&](auto has_main_k_block_loop, auto no_main_k_block_loop) {
+ constexpr bool has_main_loop = has_main_k_block_loop.value;
+ constexpr bool no_main_loop = no_main_k_block_loop.value;
+ if constexpr(CTranspose)
+ {
+ const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
+ GridwiseGemmCTranspose,
+ ADataType, // TODO: distinguish A/B datatype
+ typename GridwiseGemm::DsGridPointer,
+ EDataType,
+ MaxGroupedGemmGroupsNum,
+ GemmArgs,
+ BElementwiseOp,
+ AElementwiseOp,
+ CDEElementwiseOp,
+ ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
+ ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
+ ElementOp,
+ has_main_loop,
+ no_main_loop,
+ CTranspose>;
+ return launch_and_time_kernel_with_preprocess(
+ stream_config,
+ clear_workspace,
+ kernel,
+ dim3(gdx, gdy, gdz),
+ dim3(BlockSize),
+ 0,
+ p_b_grid,
+ p_a_grid,
+ arg.p_ds_grid_,
+ p_e_grid,
+ gemm_kernel_args,
+ gemms_count_for_set,
+ arg.b_element_op_,
+ arg.a_element_op_,
+ arg.cde_element_op_,
+ arg.compute_ptr_offset_of_batch_,
+ arg.compute_ptr_offset_of_n_,
+ arg.k_batch_);
+ }
+ else
+ {
+ const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
+ GridwiseGemm,
+ ADataType, // TODO: distinguish A/B datatype
+ typename GridwiseGemm::DsGridPointer,
+ EDataType,
+ MaxGroupedGemmGroupsNum,
+ GemmArgs,
+ AElementwiseOp,
+ BElementwiseOp,
+ CDEElementwiseOp,
+ ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
+ ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
+ ElementOp,
+ has_main_loop,
+ no_main_loop,
+ CTranspose>;
+ return launch_and_time_kernel_with_preprocess(
+ stream_config,
+ clear_workspace,
+ kernel,
+ dim3(gdx, gdy, gdz),
+ dim3(BlockSize),
+ 0,
+ p_a_grid,
+ p_b_grid,
+ arg.p_ds_grid_,
+ p_e_grid,
+ gemm_kernel_args,
+ gemms_count_for_set,
+ arg.a_element_op_,
+ arg.b_element_op_,
+ arg.cde_element_op_,
+ arg.compute_ptr_offset_of_batch_,
+ arg.compute_ptr_offset_of_n_,
+ arg.k_batch_);
+ }
};
- ave_time += launch_kernel();
+ if(has_loop_in_all_gemm)
+ {
+ ave_time += launch_kernel(integral_constant<bool, true>{},
+ integral_constant<bool, false>{});
+ }
+ else if(no_loop_in_all_gemm)
+ {
+ ave_time += launch_kernel(integral_constant<bool, false>{},
+ integral_constant<bool, true>{});
+ }
+ else
+ {
+ ave_time += launch_kernel(integral_constant<bool, false>{},
+ integral_constant<bool, false>{});
+ }
}
return ave_time;
@@ -1116,9 +1288,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
{
arg.Print();
}
// Transpose from NGKHW to NHWGK
- if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
- is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
+ if constexpr(NeedTransposeKernel)
{
EDataType* p_e_in_grid =
type_convert<EDataType*>(arg.p_workspace_) +
@@ -1208,8 +1380,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
// Transpose from NHWGC to NGCHW
- if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
- is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
+ if constexpr(NeedTransposeKernel)
{
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
@@ -1284,10 +1455,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
}
- const index_t ConvG = arg.b_g_k_c_xs_lengths_[0];
- const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
- const index_t ConvC = arg.b_g_k_c_xs_lengths_[2];
+ const index_t ConvG = arg.b_g_k_c_xs_lengths_[0];
+ const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
+ const index_t ConvC = arg.b_g_k_c_xs_lengths_[2];
+ const index_t output_spatial_acum = ck::accumulate_n<index_t>(
+ arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
+ const index_t input_spatial_acum = ck::accumulate_n<index_t>(
+ arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
// Specialization
if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
@@ -1307,15 +1481,30 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
- is_same_v<ALayout, tensor_layout::convolution::NDHWGK> ||
- is_same_v<ALayout, tensor_layout::convolution::NGKHW> ||
- is_same_v<ALayout, tensor_layout::convolution::NGKDHW>)
+ is_same_v<ALayout, tensor_layout::convolution::NDHWGK> || NeedTransposeKernel)
{
if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
+ else if constexpr(is_same_v<ALayout, tensor_layout::convolution::NGKHW> ||
+ is_same_v<ALayout, tensor_layout::convolution::NGKDHW>)
+ {
+ static_assert(NeedTransposeKernel == false);
+ if constexpr(ABlockTransferSrcScalarPerVector != 1)
+ {
+ if(ABlockTransferSrcVectorDim != 1)
+ {
+ return false;
+ }
+ if(output_spatial_acum % ABlockTransferSrcScalarPerVector != 0)
+ {
+ return false;
+ }
+ }
+ }
else
{
return false;
@@ -1351,10 +1540,20 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
is_same_v<DLayout, tensor_layout::convolution::GC> ||
is_same_v<DLayout, tensor_layout::convolution::G_C>)
{
- // vector load D matrix from global memory
- if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
+ if(CTranspose == false)
{
- ds_valid = false;
+ // vector load D matrix from global memory
+ if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
+ {
+ ds_valid = false;
+ }
+ }
+ else
+ {
+ if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
+ {
+ ds_valid = false;
+ }
+ }
}
else
@@ -1376,10 +1575,20 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
is_same_v<ELayout, tensor_layout::convolution::NGCHW> ||
is_same_v<ELayout, tensor_layout::convolution::NGCDHW>)
{
- // vector store C matrix into global memory
- if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
+ if(CTranspose == false)
{
- return false;
+ // vector store C matrix into global memory
+ if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
+ {
+ return false;
+ }
+ }
+ else
+ {
+ if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
+ {
+ return false;
+ }
+ }
}
else
@@ -1390,7 +1599,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// Gridwise GEMM size
for(std::size_t i = 0; i < arg.a_grid_desc_m_k_container_.size(); i++)
{
- if(!GridwiseGemm::CheckValidity(
+ if(!GridwiseGemmCTranspose::CheckValidity(
arg.a_grid_desc_m_k_container_[i],
arg.b_grid_desc_n_k_container_[i],
arg.ds_grid_desc_m_n_container_[i],
@@ -1403,8 +1612,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
}
- if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
- is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
+ if constexpr(NeedTransposeKernel)
{
if((ConvG * ConvC) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{