mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
[CK][CONV] Support NCHW in class DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 (#2459)
1. Port NCHW support from ConvFwd (#2375) to conv bwd data 2. Add new instance device_grouped_conv_bwd_data_xdl_f16_nchw_instances for nchw Co-authored-by: azhuang <anzhong.huang@amd.com>
This commit is contained in:
@@ -74,7 +74,10 @@ template <typename GridwiseGemm,
|
||||
typename CDEElementwiseOp,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
typename ComputePtrOffsetOfN,
|
||||
InMemoryDataOperationEnum OutElementOp>
|
||||
InMemoryDataOperationEnum OutElementOp,
|
||||
bool HasMainKBlockLoopInAllGemm,
|
||||
bool NoMainKBlockLoopInAllGemm,
|
||||
bool CTranspose>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
@@ -101,16 +104,21 @@ __global__ void
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))
|
||||
: amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))
|
||||
: amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
|
||||
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
|
||||
const long_index_t a_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
|
||||
CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
|
||||
const long_index_t b_n_offset =
|
||||
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0;
|
||||
|
||||
const long_index_t e_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
|
||||
|
||||
@@ -141,11 +149,11 @@ __global__ void
|
||||
group_id = index_t((left + right) / 2);
|
||||
}
|
||||
|
||||
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
|
||||
if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm)
|
||||
{
|
||||
GridwiseGemm::template Run<true, OutElementOp>(
|
||||
GridwiseGemm::template Run<HasMainKBlockLoopInAllGemm, OutElementOp>(
|
||||
p_a_grid + a_batch_offset + a_n_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_b_grid + b_batch_offset + b_n_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
@@ -162,22 +170,44 @@ __global__ void
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<false, OutElementOp>(
|
||||
p_a_grid + a_batch_offset + a_n_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
gemm_kernel_args[group_id].block_2_ctile_map_,
|
||||
KBatch,
|
||||
k_idx);
|
||||
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
|
||||
{
|
||||
GridwiseGemm::template Run<true, OutElementOp>(
|
||||
p_a_grid + a_batch_offset + a_n_offset,
|
||||
p_b_grid + b_batch_offset + b_n_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
gemm_kernel_args[group_id].block_2_ctile_map_,
|
||||
KBatch,
|
||||
k_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<false, OutElementOp>(
|
||||
p_a_grid + a_batch_offset + a_n_offset,
|
||||
p_b_grid + b_batch_offset + b_n_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
gemm_kernel_args[group_id].block_2_ctile_map_,
|
||||
KBatch,
|
||||
k_idx);
|
||||
}
|
||||
}
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
@@ -278,7 +308,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
// implementation we can avoid copy data to workspace before kernel launch since number of
|
||||
// groups is runtime parameter. If number of groups is larger than MaxGroupedGemmGroupsNum then
|
||||
// we run this kernel in the loop.
|
||||
static constexpr index_t MaxGroupedGemmGroupsNum = 32;
|
||||
static constexpr index_t MaxGroupedGemmGroupsNum =
|
||||
ConvBackwardDataSpecialization ==
|
||||
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0
|
||||
? 1
|
||||
: 32;
|
||||
|
||||
using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
|
||||
|
||||
@@ -296,24 +330,40 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using ALayoutAfterTranspose =
|
||||
std::conditional_t<is_NGCHW_NGKHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::NHWGK,
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::NDHWGK,
|
||||
ALayout>>;
|
||||
using BLayoutAfterTranspose =
|
||||
std::conditional_t<is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::GKYXC,
|
||||
std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::GKZYXC,
|
||||
BLayout>>;
|
||||
using ELayoutAfterTranspose =
|
||||
std::conditional_t<is_NGCHW_NGKHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::NHWGC,
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::NDHWGC,
|
||||
ELayout>>;
|
||||
static constexpr bool isATensorColMajor =
|
||||
(ConvBackwardDataSpecialization ==
|
||||
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) &&
|
||||
(ABlockTransferSrcVectorDim == 1) &&
|
||||
(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>());
|
||||
|
||||
static constexpr bool NeedTransposeKernel =
|
||||
(isATensorColMajor == false) && (is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>());
|
||||
|
||||
static constexpr bool CTranspose =
|
||||
(NeedTransposeKernel == false) && (is_same_v<ELayout, tensor_layout::convolution::NGCHW> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NGCDHW>);
|
||||
|
||||
using ALayoutAfterTranspose = std::conditional_t<
|
||||
is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
|
||||
tensor_layout::convolution::NHWGK,
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
|
||||
tensor_layout::convolution::NDHWGK,
|
||||
ALayout>>;
|
||||
using BLayoutAfterTranspose = std::conditional_t<
|
||||
is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
|
||||
tensor_layout::convolution::GKYXC,
|
||||
std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>() &&
|
||||
NeedTransposeKernel,
|
||||
tensor_layout::convolution::GKZYXC,
|
||||
BLayout>>;
|
||||
using ELayoutAfterTranspose = std::conditional_t<
|
||||
is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
|
||||
tensor_layout::convolution::NHWGC,
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
|
||||
tensor_layout::convolution::NDHWGC,
|
||||
ELayout>>;
|
||||
|
||||
using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1<NDimSpatial,
|
||||
ConvBackwardDataSpecialization,
|
||||
@@ -329,7 +379,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
ELayoutAfterTranspose,
|
||||
true, /*SplitConvN*/
|
||||
ABDataType,
|
||||
EDataType>;
|
||||
EDataType,
|
||||
1,
|
||||
index_t,
|
||||
CTranspose>;
|
||||
|
||||
static auto
|
||||
GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform)
|
||||
@@ -357,15 +410,25 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
DLayout,
|
||||
true, /*SplitConvN*/
|
||||
ABDataType,
|
||||
DDataType>;
|
||||
DDataType,
|
||||
1, /*index_t NumGroupsToMerge = 1,*/
|
||||
index_t, /* typename IndexType = */
|
||||
CTranspose>;
|
||||
return ConvToGemmBwdDataTransformD{}.MakeCDescriptor_M_N();
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
|
||||
|
||||
return make_tuple(
|
||||
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
return make_tuple(
|
||||
b_grid_desc_bk0_n_bk1, a_grid_desc_ak0_m_ak1, ds_grid_desc_m_n, e_grid_desc_m_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(
|
||||
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
|
||||
}
|
||||
}
|
||||
|
||||
// GridwiseGemm
|
||||
@@ -383,13 +446,34 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
|
||||
|
||||
#define GridwiseGemmCTransposeTemplateParameters \
|
||||
ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
|
||||
BElementwiseOp, AElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \
|
||||
NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, MPerXDL, NXdlPerWave, MXdlPerWave, \
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \
|
||||
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \
|
||||
BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \
|
||||
BBlockLdsExtraN, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
|
||||
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
|
||||
ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
|
||||
ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \
|
||||
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
|
||||
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmMultiDTemplateParams>;
|
||||
using GridwiseGemmCTranspose = std::conditional_t<
|
||||
CTranspose,
|
||||
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmCTransposeTemplateParameters>,
|
||||
GridwiseGemm>;
|
||||
|
||||
template <typename EGridDesc_M_N>
|
||||
static auto
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N e_grid_desc_m_n)
|
||||
{
|
||||
return GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
|
||||
return GridwiseGemmCTranspose::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n);
|
||||
}
|
||||
|
||||
template <typename Desc_K0_M_K1>
|
||||
@@ -419,13 +503,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{}));
|
||||
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}));
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
|
||||
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
using Block2ETileMap =
|
||||
decltype(GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}));
|
||||
|
||||
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
|
||||
|
||||
@@ -630,14 +715,17 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
sizeof(EDataType);
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides);
|
||||
NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
a_g_n_k_wos_lengths, a_g_n_k_wos_strides)
|
||||
: a_g_n_k_wos_strides;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides);
|
||||
NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
|
||||
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)
|
||||
: b_g_k_c_xs_strides;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides);
|
||||
NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
e_g_n_c_wis_lengths, e_g_n_c_wis_strides)
|
||||
: e_g_n_c_wis_strides;
|
||||
|
||||
// populate Ds pointer
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
@@ -737,12 +825,27 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
conv_N_per_block_ = conv_to_gemm_transform_.N_;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
|
||||
const auto a_grid_desc_ak0_m_ak1 = [&]() {
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
|
||||
}
|
||||
else
|
||||
{
|
||||
return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
|
||||
}
|
||||
}();
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 = [&]() {
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
|
||||
}
|
||||
else
|
||||
{
|
||||
return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
|
||||
}
|
||||
}();
|
||||
DsGridDesc_M_N ds_grid_desc_m_n;
|
||||
|
||||
// populate Ds desc
|
||||
@@ -764,7 +867,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
DLayout,
|
||||
true, /*SplitConvN*/
|
||||
ABDataType,
|
||||
DDataType>;
|
||||
DDataType,
|
||||
1,
|
||||
index_t,
|
||||
CTranspose>;
|
||||
ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides_transposed,
|
||||
@@ -810,14 +916,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
const auto GemmK = a_grid_desc_m_k.GetLength(I1);
|
||||
const bool HasMainKBlockLoop =
|
||||
GridwiseGemm::CalculateHasMainKBlockLoop(GemmK, k_batch_);
|
||||
GridwiseGemmCTranspose::CalculateHasMainKBlockLoop(GemmK, k_batch_);
|
||||
|
||||
gemm_kernel_args_[gemms_count_ /
|
||||
MaxGroupedGemmGroupsNum][gemms_count_ %
|
||||
MaxGroupedGemmGroupsNum] =
|
||||
GemmArgs{a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
GridwiseGemm::
|
||||
GridwiseGemmCTranspose::
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n),
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
@@ -851,8 +957,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
num_workgroups_per_Conv_N_ = a_g_n_k_wos_lengths_[I1] / conv_N_per_block_;
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
// Use not modified base strides
|
||||
a_in_transpose_desc_ =
|
||||
@@ -892,8 +997,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
|
||||
a_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
@@ -908,8 +1012,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
std::size_t GetWorkspaceBTensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const long_index_t b_acum = ck::accumulate_n<long_index_t>(
|
||||
b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
@@ -924,8 +1027,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
|
||||
e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
@@ -1030,24 +1132,25 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
const ADataType* p_a_grid = arg.p_a_grid_;
|
||||
const BDataType* p_b_grid = arg.p_b_grid_;
|
||||
EDataType* p_e_grid = arg.p_e_grid_;
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_e_grid =
|
||||
type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
}
|
||||
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_e_grid =
|
||||
type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
{
|
||||
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
{
|
||||
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
|
||||
}
|
||||
}
|
||||
|
||||
for(std::size_t gemm_set_id = 0; gemm_set_id < arg.gemm_kernel_args_.size();
|
||||
gemm_set_id++)
|
||||
{
|
||||
@@ -1067,42 +1170,111 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
};
|
||||
|
||||
auto launch_kernel = [&]() {
|
||||
const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
MaxGroupedGemmGroupsNum,
|
||||
GemmArgs,
|
||||
AElementwiseOp,
|
||||
BElementwiseOp,
|
||||
CDEElementwiseOp,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
ElementOp>;
|
||||
bool has_loop_in_all_gemm = true;
|
||||
bool no_loop_in_all_gemm = true;
|
||||
for(auto i = 0; i < gemms_count_for_set; i++)
|
||||
{
|
||||
has_loop_in_all_gemm &= gemm_kernel_args[i].HasMainKBlockLoop_;
|
||||
no_loop_in_all_gemm &= !gemm_kernel_args[i].HasMainKBlockLoop_;
|
||||
}
|
||||
|
||||
return launch_and_time_kernel_with_preprocess(stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
gemm_kernel_args,
|
||||
gemms_count_for_set,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.compute_ptr_offset_of_n_,
|
||||
arg.k_batch_);
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop, auto no_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
constexpr bool no_main_loop = no_main_k_block_loop.value;
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
|
||||
GridwiseGemmCTranspose,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
MaxGroupedGemmGroupsNum,
|
||||
GemmArgs,
|
||||
BElementwiseOp,
|
||||
AElementwiseOp,
|
||||
CDEElementwiseOp,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
ElementOp,
|
||||
has_main_loop,
|
||||
no_main_loop,
|
||||
CTranspose>;
|
||||
|
||||
return launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_b_grid,
|
||||
p_a_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
gemm_kernel_args,
|
||||
gemms_count_for_set,
|
||||
arg.b_element_op_,
|
||||
arg.a_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.compute_ptr_offset_of_n_,
|
||||
arg.k_batch_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
MaxGroupedGemmGroupsNum,
|
||||
GemmArgs,
|
||||
AElementwiseOp,
|
||||
BElementwiseOp,
|
||||
CDEElementwiseOp,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
ElementOp,
|
||||
has_main_loop,
|
||||
no_main_loop,
|
||||
CTranspose>;
|
||||
|
||||
return launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
gemm_kernel_args,
|
||||
gemms_count_for_set,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.compute_ptr_offset_of_n_,
|
||||
arg.k_batch_);
|
||||
}
|
||||
};
|
||||
|
||||
ave_time += launch_kernel();
|
||||
if(has_loop_in_all_gemm)
|
||||
{
|
||||
ave_time += launch_kernel(integral_constant<bool, true>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
else if(no_loop_in_all_gemm)
|
||||
{
|
||||
ave_time += launch_kernel(integral_constant<bool, false>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time += launch_kernel(integral_constant<bool, false>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
@@ -1116,9 +1288,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
// Transpose from NGKHW to NHWGK
|
||||
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
EDataType* p_e_in_grid =
|
||||
type_convert<EDataType*>(arg.p_workspace_) +
|
||||
@@ -1208,8 +1380,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
|
||||
// Transpose from NHWGC to NGCHW
|
||||
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
|
||||
@@ -1284,10 +1455,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
}
|
||||
|
||||
const index_t ConvG = arg.b_g_k_c_xs_lengths_[0];
|
||||
const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
|
||||
const index_t ConvC = arg.b_g_k_c_xs_lengths_[2];
|
||||
|
||||
const index_t ConvG = arg.b_g_k_c_xs_lengths_[0];
|
||||
const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
|
||||
const index_t ConvC = arg.b_g_k_c_xs_lengths_[2];
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
// Specifialization
|
||||
if constexpr(ConvBackwardDataSpecialization ==
|
||||
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
|
||||
@@ -1307,15 +1481,30 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NGKHW> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NGKDHW>)
|
||||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK> || NeedTransposeKernel)
|
||||
{
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if(is_same_v<ALayout, tensor_layout::convolution::NGKHW> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NGKDHW>)
|
||||
{
|
||||
static_assert(NeedTransposeKernel == false);
|
||||
|
||||
if constexpr(ABlockTransferSrcScalarPerVector != 1)
|
||||
{
|
||||
if(ABlockTransferSrcVectorDim != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(output_spatial_acum % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
@@ -1351,10 +1540,20 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
is_same_v<DLayout, tensor_layout::convolution::GC> ||
|
||||
is_same_v<DLayout, tensor_layout::convolution::G_C>)
|
||||
{
|
||||
// vector load D matrix from global memory
|
||||
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
if(CTranspose == false)
|
||||
{
|
||||
ds_valid = false;
|
||||
// vector load D matrix from global memory
|
||||
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
ds_valid = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
ds_valid = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -1376,10 +1575,20 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
is_same_v<ELayout, tensor_layout::convolution::NGCHW> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NGCDHW>)
|
||||
{
|
||||
// vector store C matrix into global memory
|
||||
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
if(CTranspose == false)
|
||||
{
|
||||
return false;
|
||||
// vector store C matrix into global memory
|
||||
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -1390,7 +1599,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
// Gridwise GEMM size
|
||||
for(std::size_t i = 0; i < arg.a_grid_desc_m_k_container_.size(); i++)
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(
|
||||
if(!GridwiseGemmCTranspose::CheckValidity(
|
||||
arg.a_grid_desc_m_k_container_[i],
|
||||
arg.b_grid_desc_n_k_container_[i],
|
||||
arg.ds_grid_desc_m_n_container_[i],
|
||||
@@ -1403,8 +1612,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
if((ConvG * ConvC) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user