mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[CK][CONV] Support NCHW in class DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle (#2375)
1. When conv spec is 1x1 stride1 pad0, nchw is equal with matrix A + column major, we only need minor change in conv transformer to support it. 2. when out is NKHW, it is equal with matrix C with column major. we need swap A & B to get best performance. 3. Add new instance device_grouped_conv_fwd_xdl_f16_nchw_instances for nchw.
This commit is contained in:
@@ -77,7 +77,8 @@ template <typename GridwiseGemm,
|
||||
typename ComputePtrOffsetOfN,
|
||||
bool HasMainKBlockLoop,
|
||||
bool isMultiA,
|
||||
bool isMultiB>
|
||||
bool isMultiB,
|
||||
bool CTranspose>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
@@ -171,17 +172,22 @@ __global__ void
|
||||
}
|
||||
else
|
||||
{
|
||||
const long_index_t a_group_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_group_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
|
||||
|
||||
CTranspose
|
||||
? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx))
|
||||
: amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
|
||||
const long_index_t a_group_offset =
|
||||
CTranspose
|
||||
? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx))
|
||||
: amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_n_offset =
|
||||
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0;
|
||||
const long_index_t a_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
|
||||
CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
|
||||
p_as_grid + a_group_offset + a_n_offset,
|
||||
p_bs_grid + b_group_offset,
|
||||
p_bs_grid + b_group_offset + b_n_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_group_offset + e_n_offset,
|
||||
p_shared,
|
||||
@@ -335,12 +341,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
static constexpr bool isATensorColMajor =
|
||||
(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) &&
|
||||
(ABlockTransferSrcVectorDim == 1) && (NumGroupsToMerge == 1) &&
|
||||
(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>());
|
||||
|
||||
static constexpr bool NeedTransposeKernel =
|
||||
(isATensorColMajor == false) && (is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>());
|
||||
|
||||
static constexpr bool CTranspose = (NeedTransposeKernel == false) && (isMultiAB == false) &&
|
||||
(is_same_v<ELayout, tensor_layout::convolution::NGKHW> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NGKDHW>);
|
||||
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
|
||||
ConvForwardSpecialization,
|
||||
true /*SplitN*/,
|
||||
ADataType,
|
||||
EDataType,
|
||||
NumGroupsToMerge>;
|
||||
NumGroupsToMerge,
|
||||
index_t,
|
||||
CTranspose>;
|
||||
|
||||
static constexpr index_t ClusterLengthNPerBlock =
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
|
||||
@@ -361,9 +383,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
|
||||
ctc::NHWGC,
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::NDHWGC, ALay>>;
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
|
||||
ctc::NDHWGC,
|
||||
ALay>>;
|
||||
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
|
||||
@@ -379,9 +403,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
|
||||
ctc::GKYXC,
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::GKZYXC, BLay>>;
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
|
||||
ctc::GKZYXC,
|
||||
BLay>>;
|
||||
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<Layout>();
|
||||
@@ -397,17 +423,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
|
||||
ctc::NHWGK,
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::NDHWGK, ELay>>;
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
|
||||
ctc::NDHWGK,
|
||||
ELay>>;
|
||||
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
|
||||
|
||||
const auto out_gemmm_gemmn_desc =
|
||||
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
|
||||
return out_gemmm_gemmn_desc;
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
constexpr auto matrix_padder_trans =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{NPerBlock, MPerBlock, KPerBlock};
|
||||
return matrix_padder_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
return matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
}
|
||||
}
|
||||
|
||||
// Shape of Ds and E must be aligned. Strides can be different.
|
||||
@@ -471,11 +504,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
BComputeDataType, DoElementwiseBeforeCShuffle
|
||||
|
||||
#define GridwiseGemmCTransposeTemplateParameters \
|
||||
GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
|
||||
EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \
|
||||
NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \
|
||||
MPerXDL, NXdlPerWave, MXdlPerWave, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
|
||||
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
|
||||
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
|
||||
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
|
||||
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
|
||||
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
|
||||
ABlockLdsExtraM, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
BComputeDataType, DoElementwiseBeforeCShuffle
|
||||
|
||||
// Use appropriate gridwise gemm
|
||||
using GridwiseGemm = std::conditional_t<
|
||||
isMultiA || isMultiB,
|
||||
GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmMultiABDTemplateParameters>,
|
||||
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;
|
||||
using GridwiseGemmCTranspose = std::conditional_t<
|
||||
CTranspose,
|
||||
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmCTransposeTemplateParameters>,
|
||||
GridwiseGemm>;
|
||||
|
||||
// If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers.
|
||||
using APointers =
|
||||
@@ -497,15 +551,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemmCTranspose::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}))>;
|
||||
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
remove_cvref_t<decltype(GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(
|
||||
EGridDesc_M_N{}))>;
|
||||
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, NPerBlock>;
|
||||
|
||||
using NGCHWTransposeDescType =
|
||||
@@ -612,16 +667,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e)},
|
||||
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
|
||||
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
|
||||
a_g_n_c_wis_strides_{NeedTransposeKernel
|
||||
? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)
|
||||
: a_g_n_c_wis_strides},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
b_g_k_c_xs_strides_{conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
|
||||
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
|
||||
b_g_k_c_xs_strides_{NeedTransposeKernel
|
||||
? conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
|
||||
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)
|
||||
: b_g_k_c_xs_strides},
|
||||
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
|
||||
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
|
||||
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
|
||||
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
|
||||
e_g_n_k_wos_strides_{NeedTransposeKernel
|
||||
? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)
|
||||
: e_g_n_k_wos_strides},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
@@ -651,7 +712,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
block_2_etile_map_{
|
||||
GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
compute_ptr_offset_of_groups_{},
|
||||
compute_ptr_offset_of_n_{},
|
||||
a_element_op_{a_element_op},
|
||||
@@ -783,24 +845,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
|
||||
b_grid_desc_n_k_,
|
||||
ds_grid_desc_m_n_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_))
|
||||
bool valid = false;
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n_);
|
||||
valid = GridwiseGemmCTranspose::CheckValidity(b_grid_desc_n_k_,
|
||||
a_grid_desc_m_k_,
|
||||
ds_grid_desc_m_n_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
valid = GridwiseGemmCTranspose::CheckValidity(a_grid_desc_m_k_,
|
||||
b_grid_desc_n_k_,
|
||||
ds_grid_desc_m_n_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_);
|
||||
}
|
||||
if(valid)
|
||||
{
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose::
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
|
||||
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n_);
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose::
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
// Use not modified base strides
|
||||
a_in_transpose_desc_ =
|
||||
@@ -835,8 +907,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
|
||||
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
@@ -851,8 +922,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
std::size_t GetWorkspaceBTensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const long_index_t b_acum = ck::accumulate_n<long_index_t>(
|
||||
b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
@@ -867,8 +937,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
|
||||
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
@@ -1007,7 +1076,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
|
||||
has_main_loop,
|
||||
isMultiA,
|
||||
isMultiB>;
|
||||
isMultiB,
|
||||
CTranspose>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
@@ -1035,68 +1105,118 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
const ADataType* p_a_grid = arg.p_as_grid_.At(I0);
|
||||
const BDataType* p_b_grid = arg.p_bs_grid_.At(I0);
|
||||
EDataType* p_e_grid = arg.p_e_grid_;
|
||||
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() +
|
||||
arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
}
|
||||
else if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() +
|
||||
arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() +
|
||||
arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
}
|
||||
else if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() +
|
||||
arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
}
|
||||
}
|
||||
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
const ADataType*,
|
||||
const BDataType*,
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
|
||||
has_main_loop,
|
||||
isMultiA,
|
||||
isMultiB>;
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
|
||||
GridwiseGemmCTranspose,
|
||||
const BDataType*,
|
||||
const ADataType*,
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
BElementwiseOperation,
|
||||
AElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
|
||||
has_main_loop,
|
||||
isMultiA,
|
||||
isMultiB,
|
||||
CTranspose>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_b_grid,
|
||||
p_a_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
arg.b_element_op_,
|
||||
arg.a_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
const ADataType*,
|
||||
const BDataType*,
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
|
||||
has_main_loop,
|
||||
isMultiA,
|
||||
isMultiB,
|
||||
CTranspose>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1114,8 +1234,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
float avg_time = 0.f;
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const index_t a_grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
|
||||
@@ -1166,8 +1285,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
avg_time += RunGemm(arg, stream_config);
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
|
||||
@@ -1215,9 +1333,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
|
||||
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
|
||||
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
|
||||
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
|
||||
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
|
||||
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
|
||||
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
// check device
|
||||
if(get_device_name() == "gfx908")
|
||||
@@ -1310,7 +1430,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
|
||||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
|
||||
is_same_v<ALayout, ctc::NDHWGC> || is_same_v<ALayout, ctc::NGCW> ||
|
||||
is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
|
||||
NeedTransposeKernel)
|
||||
{
|
||||
// Check access per C
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
|
||||
@@ -1326,6 +1446,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
|
||||
{
|
||||
static_assert(NeedTransposeKernel == false);
|
||||
static_assert(NumGroupsToMerge == 1);
|
||||
|
||||
if constexpr(ABlockTransferSrcScalarPerVector != 1)
|
||||
{
|
||||
if(ABlockTransferSrcVectorDim != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(input_spatial_acum % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
@@ -1350,7 +1487,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access of Ds
|
||||
bool valid = true;
|
||||
|
||||
@@ -1396,8 +1532,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
@@ -1409,8 +1544,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
@@ -1457,9 +1590,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
is_same_v<ELayout, ctc::NDHWGK> || is_same_v<ELayout, ctc::NGKW> ||
|
||||
is_same_v<ELayout, ctc::NGKHW> || is_same_v<ELayout, ctc::NGKDHW>)
|
||||
{
|
||||
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
if(CTranspose == false)
|
||||
{
|
||||
return false;
|
||||
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -1483,11 +1629,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
else
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
return GridwiseGemmCTranspose::CheckValidity(arg.b_grid_desc_n_k_,
|
||||
arg.a_grid_desc_m_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return GridwiseGemmCTranspose::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user