[CK][CONV] Support NCHW in class DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 (#2459)

1. Port NCHW support from ConvFwd (#2375) to conv bwd data
2. Add new instance device_grouped_conv_bwd_data_xdl_f16_nchw_instances for nchw

Co-authored-by: azhuang <anzhong.huang@amd.com>

[ROCm/composable_kernel commit: fbd9f32abe]
This commit is contained in:
linqunAMD
2025-07-17 08:19:57 +08:00
committed by GitHub
parent 348dec0d0c
commit 7b5207d652
6 changed files with 509 additions and 156 deletions

View File

@@ -92,7 +92,7 @@ inline bool parse_cmd_args(int argc,
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
conv_params = ck::utils::conv::parse_conv_param(
num_dim_spatial, threshold_to_catch_partial_args, argv);
num_dim_spatial, threshold_to_catch_partial_args + 1, argv);
}
else
{

View File

@@ -74,7 +74,10 @@ template <typename GridwiseGemm,
typename CDEElementwiseOp,
typename ComputePtrOffsetOfBatch,
typename ComputePtrOffsetOfN,
InMemoryDataOperationEnum OutElementOp>
InMemoryDataOperationEnum OutElementOp,
bool HasMainKBlockLoopInAllGemm,
bool NoMainKBlockLoopInAllGemm,
bool CTranspose>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -101,16 +104,21 @@ __global__ void
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))
: amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))
: amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const long_index_t b_n_offset =
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0;
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
@@ -141,11 +149,11 @@ __global__ void
group_id = index_t((left + right) / 2);
}
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm)
{
GridwiseGemm::template Run<true, OutElementOp>(
GridwiseGemm::template Run<HasMainKBlockLoopInAllGemm, OutElementOp>(
p_a_grid + a_batch_offset + a_n_offset,
p_b_grid + b_batch_offset,
p_b_grid + b_batch_offset + b_n_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
@@ -162,22 +170,44 @@ __global__ void
}
else
{
GridwiseGemm::template Run<false, OutElementOp>(
p_a_grid + a_batch_offset + a_n_offset,
p_b_grid + b_batch_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].block_2_ctile_map_,
KBatch,
k_idx);
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
{
GridwiseGemm::template Run<true, OutElementOp>(
p_a_grid + a_batch_offset + a_n_offset,
p_b_grid + b_batch_offset + b_n_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].block_2_ctile_map_,
KBatch,
k_idx);
}
else
{
GridwiseGemm::template Run<false, OutElementOp>(
p_a_grid + a_batch_offset + a_n_offset,
p_b_grid + b_batch_offset + b_n_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].block_2_ctile_map_,
KBatch,
k_idx);
}
}
#else
ignore = p_a_grid;
@@ -278,7 +308,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// implementation we can avoid copy data to workspace before kernel launch since number of
// groups is runtime parameter. If number of groups is larger than MaxGroupedGemmGroupsNum then
// we run this kernel in the loop.
static constexpr index_t MaxGroupedGemmGroupsNum = 32;
static constexpr index_t MaxGroupedGemmGroupsNum =
ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0
? 1
: 32;
using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
@@ -296,24 +330,40 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ALayoutAfterTranspose =
std::conditional_t<is_NGCHW_NGKHW<ELayout, BLayout, ALayout>(),
tensor_layout::convolution::NHWGK,
std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>(),
tensor_layout::convolution::NDHWGK,
ALayout>>;
using BLayoutAfterTranspose =
std::conditional_t<is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>(),
tensor_layout::convolution::GKYXC,
std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>(),
tensor_layout::convolution::GKZYXC,
BLayout>>;
using ELayoutAfterTranspose =
std::conditional_t<is_NGCHW_NGKHW<ELayout, BLayout, ALayout>(),
tensor_layout::convolution::NHWGC,
std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>(),
tensor_layout::convolution::NDHWGC,
ELayout>>;
static constexpr bool isATensorColMajor =
(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) &&
(ABlockTransferSrcVectorDim == 1) &&
(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>());
static constexpr bool NeedTransposeKernel =
(isATensorColMajor == false) && (is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>());
static constexpr bool CTranspose =
(NeedTransposeKernel == false) && (is_same_v<ELayout, tensor_layout::convolution::NGCHW> ||
is_same_v<ELayout, tensor_layout::convolution::NGCDHW>);
using ALayoutAfterTranspose = std::conditional_t<
is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
tensor_layout::convolution::NHWGK,
std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
tensor_layout::convolution::NDHWGK,
ALayout>>;
using BLayoutAfterTranspose = std::conditional_t<
is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
tensor_layout::convolution::GKYXC,
std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>() &&
NeedTransposeKernel,
tensor_layout::convolution::GKZYXC,
BLayout>>;
using ELayoutAfterTranspose = std::conditional_t<
is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
tensor_layout::convolution::NHWGC,
std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
tensor_layout::convolution::NDHWGC,
ELayout>>;
using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1<NDimSpatial,
ConvBackwardDataSpecialization,
@@ -329,7 +379,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
ELayoutAfterTranspose,
true, /*SplitConvN*/
ABDataType,
EDataType>;
EDataType,
1,
index_t,
CTranspose>;
static auto
GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform)
@@ -357,15 +410,25 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
DLayout,
true, /*SplitConvN*/
ABDataType,
DDataType>;
DDataType,
1, /*index_t NumGroupsToMerge = 1,*/
index_t, /* typename IndexType = */
CTranspose>;
return ConvToGemmBwdDataTransformD{}.MakeCDescriptor_M_N();
},
Number<NumDTensor>{});
const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
return make_tuple(
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
if constexpr(CTranspose)
{
return make_tuple(
b_grid_desc_bk0_n_bk1, a_grid_desc_ak0_m_ak1, ds_grid_desc_m_n, e_grid_desc_m_n);
}
else
{
return make_tuple(
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
}
}
// GridwiseGemm
@@ -383,13 +446,34 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
#define GridwiseGemmCTransposeTemplateParameters \
ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
BElementwiseOp, AElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \
NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, MPerXDL, NXdlPerWave, MXdlPerWave, \
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \
BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \
BBlockLdsExtraN, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmMultiDTemplateParams>;
using GridwiseGemmCTranspose = std::conditional_t<
CTranspose,
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmCTransposeTemplateParameters>,
GridwiseGemm>;
template <typename EGridDesc_M_N>
static auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N e_grid_desc_m_n)
{
return GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
return GridwiseGemmCTranspose::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n);
}
template <typename Desc_K0_M_K1>
@@ -419,13 +503,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{}));
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}));
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
// block-to-e-tile map
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
using Block2ETileMap =
decltype(GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}));
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
@@ -630,14 +715,17 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
sizeof(EDataType);
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
a_g_n_k_wos_lengths, a_g_n_k_wos_strides)
: a_g_n_k_wos_strides;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)
: b_g_k_c_xs_strides;
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(e_g_n_c_wis_lengths,
e_g_n_c_wis_strides);
NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
e_g_n_c_wis_lengths, e_g_n_c_wis_strides)
: e_g_n_c_wis_strides;
// populate Ds pointer
static_for<0, NumDTensor, 1>{}([&](auto i) {
@@ -737,12 +825,27 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
conv_N_per_block_ = conv_to_gemm_transform_.N_;
const auto a_grid_desc_ak0_m_ak1 =
conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
const auto b_grid_desc_bk0_n_bk1 =
conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
const auto a_grid_desc_ak0_m_ak1 = [&]() {
if constexpr(CTranspose)
{
return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
}
else
{
return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
}
}();
const auto b_grid_desc_bk0_n_bk1 = [&]() {
if constexpr(CTranspose)
{
return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
}
else
{
return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
}
}();
DsGridDesc_M_N ds_grid_desc_m_n;
// populate Ds desc
@@ -764,7 +867,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
DLayout,
true, /*SplitConvN*/
ABDataType,
DDataType>;
DDataType,
1,
index_t,
CTranspose>;
ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides_transposed,
@@ -810,14 +916,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
const auto GemmK = a_grid_desc_m_k.GetLength(I1);
const bool HasMainKBlockLoop =
GridwiseGemm::CalculateHasMainKBlockLoop(GemmK, k_batch_);
GridwiseGemmCTranspose::CalculateHasMainKBlockLoop(GemmK, k_batch_);
gemm_kernel_args_[gemms_count_ /
MaxGroupedGemmGroupsNum][gemms_count_ %
MaxGroupedGemmGroupsNum] =
GemmArgs{a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
GridwiseGemm::
GridwiseGemmCTranspose::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n),
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -851,8 +957,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
num_workgroups_per_Conv_N_ = a_g_n_k_wos_lengths_[I1] / conv_N_per_block_;
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
if constexpr(NeedTransposeKernel)
{
// Use not modified base strides
a_in_transpose_desc_ =
@@ -892,8 +997,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
std::size_t GetWorkspaceATensorSizeBytes() const
{
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
if constexpr(NeedTransposeKernel)
{
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
a_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
@@ -908,8 +1012,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
std::size_t GetWorkspaceBTensorSizeBytes() const
{
if constexpr(is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>())
if constexpr(NeedTransposeKernel)
{
const long_index_t b_acum = ck::accumulate_n<long_index_t>(
b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
@@ -924,8 +1027,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
std::size_t GetWorkspaceETensorSizeBytes() const
{
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
if constexpr(NeedTransposeKernel)
{
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
@@ -1030,24 +1132,25 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
const ADataType* p_a_grid = arg.p_a_grid_;
const BDataType* p_b_grid = arg.p_b_grid_;
EDataType* p_e_grid = arg.p_e_grid_;
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
if constexpr(NeedTransposeKernel)
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_e_grid =
type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
}
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_e_grid =
type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
}
if constexpr(is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>())
{
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
if constexpr(is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>())
{
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
}
}
for(std::size_t gemm_set_id = 0; gemm_set_id < arg.gemm_kernel_args_.size();
gemm_set_id++)
{
@@ -1067,42 +1170,111 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
};
auto launch_kernel = [&]() {
const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
MaxGroupedGemmGroupsNum,
GemmArgs,
AElementwiseOp,
BElementwiseOp,
CDEElementwiseOp,
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
ElementOp>;
bool has_loop_in_all_gemm = true;
bool no_loop_in_all_gemm = true;
for(auto i = 0; i < gemms_count_for_set; i++)
{
has_loop_in_all_gemm &= gemm_kernel_args[i].HasMainKBlockLoop_;
no_loop_in_all_gemm &= !gemm_kernel_args[i].HasMainKBlockLoop_;
}
return launch_and_time_kernel_with_preprocess(stream_config,
clear_workspace,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
arg.p_ds_grid_,
p_e_grid,
gemm_kernel_args,
gemms_count_for_set,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.compute_ptr_offset_of_n_,
arg.k_batch_);
auto launch_kernel = [&](auto has_main_k_block_loop, auto no_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
constexpr bool no_main_loop = no_main_k_block_loop.value;
if constexpr(CTranspose)
{
const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
GridwiseGemmCTranspose,
ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
MaxGroupedGemmGroupsNum,
GemmArgs,
BElementwiseOp,
AElementwiseOp,
CDEElementwiseOp,
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
ElementOp,
has_main_loop,
no_main_loop,
CTranspose>;
return launch_and_time_kernel_with_preprocess(
stream_config,
clear_workspace,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_b_grid,
p_a_grid,
arg.p_ds_grid_,
p_e_grid,
gemm_kernel_args,
gemms_count_for_set,
arg.b_element_op_,
arg.a_element_op_,
arg.cde_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.compute_ptr_offset_of_n_,
arg.k_batch_);
}
else
{
const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
MaxGroupedGemmGroupsNum,
GemmArgs,
AElementwiseOp,
BElementwiseOp,
CDEElementwiseOp,
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
ElementOp,
has_main_loop,
no_main_loop,
CTranspose>;
return launch_and_time_kernel_with_preprocess(
stream_config,
clear_workspace,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
arg.p_ds_grid_,
p_e_grid,
gemm_kernel_args,
gemms_count_for_set,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.compute_ptr_offset_of_n_,
arg.k_batch_);
}
};
ave_time += launch_kernel();
if(has_loop_in_all_gemm)
{
ave_time += launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(no_loop_in_all_gemm)
{
ave_time += launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time += launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
return ave_time;
@@ -1116,9 +1288,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
{
arg.Print();
}
// Transpose from NGKHW to NHWGK
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
if constexpr(NeedTransposeKernel)
{
EDataType* p_e_in_grid =
type_convert<EDataType*>(arg.p_workspace_) +
@@ -1208,8 +1380,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
// Transpose from NHWGC to NGCHW
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
if constexpr(NeedTransposeKernel)
{
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
@@ -1284,10 +1455,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
}
const index_t ConvG = arg.b_g_k_c_xs_lengths_[0];
const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
const index_t ConvC = arg.b_g_k_c_xs_lengths_[2];
const index_t ConvG = arg.b_g_k_c_xs_lengths_[0];
const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
const index_t ConvC = arg.b_g_k_c_xs_lengths_[2];
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
// Specifialization
if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
@@ -1307,15 +1481,30 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK> ||
is_same_v<ALayout, tensor_layout::convolution::NGKHW> ||
is_same_v<ALayout, tensor_layout::convolution::NGKDHW>)
is_same_v<ALayout, tensor_layout::convolution::NDHWGK> || NeedTransposeKernel)
{
if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else if(is_same_v<ALayout, tensor_layout::convolution::NGKHW> ||
is_same_v<ALayout, tensor_layout::convolution::NGKDHW>)
{
static_assert(NeedTransposeKernel == false);
if constexpr(ABlockTransferSrcScalarPerVector != 1)
{
if(ABlockTransferSrcVectorDim != 1)
{
return false;
}
if(output_spatial_acum % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
}
else
{
return false;
@@ -1351,10 +1540,20 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
is_same_v<DLayout, tensor_layout::convolution::GC> ||
is_same_v<DLayout, tensor_layout::convolution::G_C>)
{
// vector load D matrix from global memory
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
if(CTranspose == false)
{
ds_valid = false;
// vector load D matrix from global memory
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
ds_valid = false;
}
}
else
{
if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
ds_valid = false;
}
}
}
else
@@ -1376,10 +1575,20 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
is_same_v<ELayout, tensor_layout::convolution::NGCHW> ||
is_same_v<ELayout, tensor_layout::convolution::NGCDHW>)
{
// vector store C matrix into global memory
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
if(CTranspose == false)
{
return false;
// vector store C matrix into global memory
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
}
else
{
if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
}
else
@@ -1390,7 +1599,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// Gridwise GEMM size
for(std::size_t i = 0; i < arg.a_grid_desc_m_k_container_.size(); i++)
{
if(!GridwiseGemm::CheckValidity(
if(!GridwiseGemmCTranspose::CheckValidity(
arg.a_grid_desc_m_k_container_[i],
arg.b_grid_desc_n_k_container_[i],
arg.ds_grid_desc_m_n_container_[i],
@@ -1403,8 +1612,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
}
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
if constexpr(NeedTransposeKernel)
{
if((ConvG * ConvC) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{

View File

@@ -30,7 +30,8 @@ template <
typename ADataType = float,
typename CDataType = float,
index_t NumGroupsToMerge = 1,
typename IndexType = index_t>
typename IndexType = index_t,
bool CTranspose = false>
struct TransformConvBwdDataToGemm_v1
{
private:
@@ -555,6 +556,41 @@ struct TransformConvBwdDataToGemm_v1
return make_naive_tensor_descriptor_packed(make_tuple(N_, Do_, Ho_, Wo_, K_));
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::NGKHW>)
{
// assume packed
static_assert(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0);
const auto out_gemm_raw_grid_desc = make_naive_tensor_descriptor(
make_tuple(N_, Ho_ * Wo_, K_), make_tuple(NStrideTensorA_, I1, KStrideTensorA_));
return transform_tensor_descriptor(
out_gemm_raw_grid_desc,
make_tuple(make_merge_transform(make_tuple(N_, Ho_ * Wo_)),
make_pass_through_transform(K_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::NGKDHW>)
{
// assume packed
static_assert(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0);
const auto out_gemm_raw_grid_desc =
make_naive_tensor_descriptor(make_tuple(N_, Do_ * Ho_ * Wo_, K_),
make_tuple(NStrideTensorA_, I1, KStrideTensorA_));
return transform_tensor_descriptor(
out_gemm_raw_grid_desc,
make_tuple(make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_)),
make_pass_through_transform(K_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + ALayout::name());
@@ -608,7 +644,9 @@ struct TransformConvBwdDataToGemm_v1
(is_same_v<ALayout_, tensor_layout::convolution::GNHWK> ||
is_same_v<ALayout_, tensor_layout::convolution::GNDHWK> ||
is_same_v<ALayout_, tensor_layout::convolution::NHWGK> ||
is_same_v<ALayout_, tensor_layout::convolution::NDHWGK>),
is_same_v<ALayout_, tensor_layout::convolution::NDHWGK> ||
is_same_v<ALayout_, tensor_layout::convolution::NGKHW> ||
is_same_v<ALayout_, tensor_layout::convolution::NGKDHW>),
bool>::type = false>
__host__ __device__ auto MakeADescriptor_AK0_M_AK1() const
{
@@ -848,16 +886,16 @@ struct TransformConvBwdDataToGemm_v1
}
}
template <typename BLayout_ = BLayout,
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<BLayout_, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout_, tensor_layout::convolution::GKZYXC>),
bool>::type = false>
template <
typename BLayout_ = BLayout,
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<BLayout_, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout_, tensor_layout::convolution::GKZYXC> ||
is_same_v<BLayout_, tensor_layout::convolution::GKCYX> ||
is_same_v<BLayout_, tensor_layout::convolution::GKCZYX>),
bool>::type = false>
__host__ __device__ auto MakeBDescriptor_BK0_N_BK1() const
{
// assume packed
// k_y_x_c for 2d or k_z_y_x_c for 3d
const auto wei_grid_desc = MakeWeiGridDesc();
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
@@ -886,6 +924,12 @@ struct TransformConvBwdDataToGemm_v1
}
else
{
// assume packed
// k_y_x_c for 2d or k_z_y_x_c for 3d
static_assert(is_same_v<BLayout_, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout_, tensor_layout::convolution::GKZYXC>);
const auto wei_grid_desc = MakeWeiGridDesc();
// GemmK is different for each GEMM
const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_);
const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_);
@@ -1059,6 +1103,7 @@ struct TransformConvBwdDataToGemm_v1
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
static_assert(CTranspose == false);
// assume strided
// n_hi_wi_c for 2d n_di_hi_wi_c for 3d
const auto in_grid_desc = MakeInGridDesc();
@@ -1314,6 +1359,48 @@ struct TransformConvBwdDataToGemm_v1
}
}
template <typename CLayout_ = CLayout,
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<CLayout_, tensor_layout::convolution::NGCHW> ||
is_same_v<CLayout_, tensor_layout::convolution::NGCDHW>),
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
const auto in_grid_desc = make_naive_tensor_descriptor(
make_tuple(N_, C_, Di_ * Hi_ * Wi_), make_tuple(NStrideTensorC_, CStrideTensorC_, I1));
static_assert(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0);
if constexpr(CTranspose)
{
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(C_),
make_merge_transform(make_tuple(N_, Di_ * Hi_ * Wi_))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return ck::tensor_operation::device::PadTensorDescriptor(
in_gemmmraw_gemmnraw_grid_desc,
make_tuple(GemmNPerBlock, GemmMPerBlock),
Sequence<DoPadGemmN, DoPadGemmM>{});
}
else
{
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_merge_transform(make_tuple(N_, Di_ * Hi_ * Wi_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return ck::tensor_operation::device::PadTensorDescriptor(
in_gemmmraw_gemmnraw_grid_desc,
make_tuple(GemmMPerBlock, GemmNPerBlock),
Sequence<DoPadGemmM, DoPadGemmN>{});
}
}
// for input bias
template <typename CLayout_ = CLayout,
typename std::enable_if<NDimSpatial == 2 &&
@@ -1326,14 +1413,26 @@ struct TransformConvBwdDataToGemm_v1
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const auto in_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, C_), make_tuple(I0, I1));
if constexpr(CTranspose)
{
const auto in_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(C_, N_ * Ho_ * Wo_), make_tuple(I1, I0));
return in_gemmm_gemmn_grid_desc;
return in_gemmm_gemmn_grid_desc;
}
else
{
const auto in_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(N_ * Ho_ * Wo_, C_), make_tuple(I0, I1));
return in_gemmm_gemmn_grid_desc;
}
}
else
{
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
static_assert(CTranspose == false);
// only work on HTilde and WTilde that contribute to non-padding area of input
// tensor
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_);
const auto IWTildeSliceBegin = math::integer_divide_floor(

View File

@@ -112,6 +112,36 @@ using device_grouped_conv_bwd_data_xdl_f16_instances =
// clang-format on
>;
// Instance list for grouped convolution backward-data, f16 in/out with f32
// accumulation, targeting the NCHW-family layouts (instantiated below with
// NGKHW/GKCYX/NGCHW for 2D and NGKDHW/GKCZYX/NGCDHW for 3D, under the
// Filter1x1Stride1Pad0 specialization at those call sites).
//
// Each tuple element is one DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// tuning configuration; the meaning of every template argument column is given
// by the aligned table comment inside the tuple. All instances here enable
// padding on both GemmM and GemmN (the two `true` flags after ConvSpec), so
// they remain applicable for problem sizes that are not block-size multiples.
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_f16_nchw_instances =
    std::tuple<
        // clang-format off
        // ##############################################|        NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise|    ConvolutionBackward|        DoPad|        DoPad|      NumGemmK| Block|  MPer|  NPer|  KPer| AK1| BK1| MPer| NPer| MXdl| NXdl|  ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds|  BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds|    CShuffle|    CShuffle| CBlockTransferClusterLengths|  CBlockTransfer
        // ##############################################|     Spatial|        |        |         |        |  Type|  Type|    Type| DataType|   Type|  Type|    Operation|    Operation|      Operation|     DataSpecialization|        GemmM|        GemmN| PrefetchStage|  Size| Block| Block| Block|    |    |  XDL|  XDL|  PerWave|  PerWave|   ThreadCluster|  ThreadCluster| SrcAccessOrder|   SrcVectorDim|      SrcScalar|      DstScalar| AddExtraM|   ThreadCluster|  ThreadCluster| SrcAccessOrder|  SrcVectorDim|      SrcScalar|      DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave|         _MBlock_MWaveMPerXdl| ScalarPerVector
        // ##############################################|            |        |        |         |        |      |      |        |         |       |      |             |             |               |                       |             |             |              |      |      |      |      |    |    |     |     |         |         | Lengths_K0_M_K1|   ArrangeOrder|               |               |      PerVector|   PerVector_K1|          | Lengths_K0_N_K1|   ArrangeOrder|               |              |      PerVector|   PerVector_K1|          |  PerShuffle|  PerShuffle|         _NBlock_NWaveNPerXdl|      _NWaveNPerXdl
        // ##############################################|            |        |        |         |        |      |      |        |         |       |      |             |             |               |                       |             |             |              |      |      |      |      |    |    |     |     |         |         |                |               |               |               |               |               |          |                |               |               |              |               |               |          |            |            |                             |
        // generic instance
        // Fallback configurations: ScalarPerVector is 1 in the last column, so
        // they impose no vector-width requirement on the tensor lengths.
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial,  ALayout,  BLayout, DsLayout, ELayout,   F16,   F16,     F32,      F16, Empty_Tuple,   F16,  PassThrough, PassThrough,  PassThrough, ConvSpec,  true,  true,        1,   256,    64,   128,    32,   8,   8,   32,   32,    1,    2,     S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,              1,              1,              8,         1,     S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              2,              8,         1,           1,           1,               S<1, 32, 1, 8>,               1>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial,  ALayout,  BLayout, DsLayout, ELayout,   F16,   F16,     F32,      F16, Empty_Tuple,   F16,  PassThrough, PassThrough,  PassThrough, ConvSpec,  true,  true,        1,    64,    64,    32,    32,   8,   8,   32,   32,    2,    1,     S<4, 16, 1>,     S<0, 2, 1>,     S<0, 2, 1>,              1,              1,              8,         1,     S<4, 16, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              8,         1,           1,           1,                S<1, 8, 1, 8>,               1>,
        // Vectorized instances: larger ScalarPerVector (2/4/8) for higher
        // bandwidth when the contiguous dimension length permits.
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial,  ALayout,  BLayout, DsLayout, ELayout,   F16,   F16,     F32,      F16, Empty_Tuple,   F16,  PassThrough, PassThrough,  PassThrough, ConvSpec,  true,  true,        1,   256,   128,   128,    32,   8,   8,   32,   32,    2,    2,     S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,              1,              2,              8,         1,     S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              2,              8,         1,           1,           1,               S<1, 32, 1, 8>,               4>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial,  ALayout,  BLayout, DsLayout, ELayout,   F16,   F16,     F32,      F16, Empty_Tuple,   F16,  PassThrough, PassThrough,  PassThrough, ConvSpec,  true,  true,        1,   256,   128,   128,    32,   8,   8,   32,   32,    2,    2,     S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,              1,              2,              8,         1,     S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              2,              8,         1,           1,           1,               S<1, 32, 1, 8>,               8>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial,  ALayout,  BLayout, DsLayout, ELayout,   F16,   F16,     F32,      F16, Empty_Tuple,   F16,  PassThrough, PassThrough,  PassThrough, ConvSpec,  true,  true,        1,   128,   128,    64,    32,   8,   8,   32,   32,    2,    2,     S<4, 32, 1>,     S<0, 2, 1>,     S<0, 2, 1>,              1,              4,              8,         1,     S<4, 32, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              2,              8,         1,           1,           1,               S<1, 16, 1, 8>,               4>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial,  ALayout,  BLayout, DsLayout, ELayout,   F16,   F16,     F32,      F16, Empty_Tuple,   F16,  PassThrough, PassThrough,  PassThrough, ConvSpec,  true,  true,        1,   256,   128,    64,    32,   8,   8,   32,   32,    2,    1,     S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,              1,              2,              8,         1,     S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              8,         1,           1,           1,               S<1, 32, 1, 8>,               8>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial,  ALayout,  BLayout, DsLayout, ELayout,   F16,   F16,     F32,      F16, Empty_Tuple,   F16,  PassThrough, PassThrough,  PassThrough, ConvSpec,  true,  true,        1,   128,    64,   128,    32,   8,   8,   32,   32,    2,    2,     S<4, 32, 1>,     S<0, 2, 1>,     S<0, 2, 1>,              1,              2,              8,         1,     S<4, 32, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              4,              8,         1,           1,           1,               S<1, 16, 1, 8>,               4>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial,  ALayout,  BLayout, DsLayout, ELayout,   F16,   F16,     F32,      F16, Empty_Tuple,   F16,  PassThrough, PassThrough,  PassThrough, ConvSpec,  true,  true,        1,   256,   128,    64,    32,   8,   8,   32,   32,    2,    1,     S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,              1,              2,              8,         1,     S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              8,         1,           1,           1,               S<1, 32, 1, 8>,               4>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial,  ALayout,  BLayout, DsLayout, ELayout,   F16,   F16,     F32,      F16, Empty_Tuple,   F16,  PassThrough, PassThrough,  PassThrough, ConvSpec,  true,  true,        1,   256,    64,   128,    32,   8,   8,   32,   32,    1,    2,     S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,              1,              1,              8,         1,     S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              2,              8,         1,           1,           1,               S<1, 32, 1, 8>,               4>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial,  ALayout,  BLayout, DsLayout, ELayout,   F16,   F16,     F32,      F16, Empty_Tuple,   F16,  PassThrough, PassThrough,  PassThrough, ConvSpec,  true,  true,        1,   128,   128,    32,    32,   8,   8,   32,   32,    2,    1,     S<4, 32, 1>,     S<0, 2, 1>,     S<0, 2, 1>,              1,              4,              8,         1,     S<4, 32, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              8,         1,           1,           1,               S<1, 16, 1, 8>,               4>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial,  ALayout,  BLayout, DsLayout, ELayout,   F16,   F16,     F32,      F16, Empty_Tuple,   F16,  PassThrough, PassThrough,  PassThrough, ConvSpec,  true,  true,        1,   128,   128,    32,    32,   8,   8,   32,   32,    2,    1,     S<4, 32, 1>,     S<0, 2, 1>,     S<0, 2, 1>,              1,              4,              8,         1,     S<4, 32, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              8,         1,           1,           1,               S<1, 16, 1, 8>,               8>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial,  ALayout,  BLayout, DsLayout, ELayout,   F16,   F16,     F32,      F16, Empty_Tuple,   F16,  PassThrough, PassThrough,  PassThrough, ConvSpec,  true,  true,        1,   256,   128,    64,    32,   8,   8,   16,   16,    4,    2,     S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,              1,              2,              8,         1,     S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              8,         1,           1,           1,               S<1, 32, 1, 8>,               4>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial,  ALayout,  BLayout, DsLayout, ELayout,   F16,   F16,     F32,      F16, Empty_Tuple,   F16,  PassThrough, PassThrough,  PassThrough, ConvSpec,  true,  true,        1,   256,   128,   128,    32,   8,   8,   32,   32,    2,    2,     S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,              1,              2,              8,         1,     S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              2,              8,         1,           1,           1,               S<1, 8, 1, 32>,               2>
        // clang-format on
        >;
// bf16_bf16_f32_bf16
template <index_t NDimSpatial,
typename ALayout,

View File

@@ -32,6 +32,14 @@ void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances(
Empty_Tuple,
NGCHW,
ConvBwdDataDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f16_nchw_instances<2,
NGKHW,
GKCYX,
Empty_Tuple,
NGCHW,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance

View File

@@ -32,6 +32,14 @@ void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances(
Empty_Tuple,
NGCDHW,
ConvBwdDataDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f16_nchw_instances<3,
NGKDHW,
GKCZYX,
Empty_Tuple,
NGCDHW,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance