mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK][CONV] Support NCHW in class DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle (#2375)
1. When the conv specialization is 1x1, stride 1, pad 0, an NCHW input is equivalent to a column-major matrix A, so only a minor change in the conv transformer is needed to support it.
2. When the output is NKHW, it is equivalent to a column-major matrix C, so we need to swap A and B to get the best performance.
3. Add new instances, device_grouped_conv_fwd_xdl_f16_nchw_instances, for NCHW.
[ROCm/composable_kernel commit: 1749c0409e]
This commit is contained in:
@@ -77,7 +77,8 @@ template <typename GridwiseGemm,
|
||||
typename ComputePtrOffsetOfN,
|
||||
bool HasMainKBlockLoop,
|
||||
bool isMultiA,
|
||||
bool isMultiB>
|
||||
bool isMultiB,
|
||||
bool CTranspose>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
@@ -171,17 +172,22 @@ __global__ void
|
||||
}
|
||||
else
|
||||
{
|
||||
const long_index_t a_group_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_group_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
|
||||
|
||||
CTranspose
|
||||
? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx))
|
||||
: amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
|
||||
const long_index_t a_group_offset =
|
||||
CTranspose
|
||||
? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx))
|
||||
: amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_n_offset =
|
||||
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0;
|
||||
const long_index_t a_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
|
||||
CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
|
||||
p_as_grid + a_group_offset + a_n_offset,
|
||||
p_bs_grid + b_group_offset,
|
||||
p_bs_grid + b_group_offset + b_n_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_group_offset + e_n_offset,
|
||||
p_shared,
|
||||
@@ -335,12 +341,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
// Compile-time flag: true only when all of the following hold
//   - the forward specialization is Filter1x1Stride1Pad0,
//   - A is vectorized along source dimension 1 (ABlockTransferSrcVectorDim == 1),
//   - groups are not merged (NumGroupsToMerge == 1),
//   - the layouts are the NGCHW/NGKHW (2D) or NGCDHW/NGKDHW (3D) family.
// In that case the input tensor is consumed in its native layout instead of
// being transposed first (see NeedTransposeKernel below).
static constexpr bool isATensorColMajor =
|
||||
(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) &&
|
||||
(ABlockTransferSrcVectorDim == 1) && (NumGroupsToMerge == 1) &&
|
||||
(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>());
|
||||
|
||||
// Compile-time flag: the NGCHW/NGCDHW layout family is in use but the
// column-major A path above does not apply, so the separate transpose
// path (workspace + elementwise transpose kernels) is required.
static constexpr bool NeedTransposeKernel =
|
||||
(isATensorColMajor == false) && (is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>());
|
||||
|
||||
// Compile-time flag: the output layout is NGKHW or NGKDHW, no transpose kernel
// is needed, and neither A nor B is a multi-tensor input. When set, the GEMM is
// issued with the A and B operands swapped and a transposed C descriptor.
static constexpr bool CTranspose = (NeedTransposeKernel == false) && (isMultiAB == false) &&
|
||||
(is_same_v<ELayout, tensor_layout::convolution::NGKHW> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NGKDHW>);
|
||||
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
|
||||
ConvForwardSpecialization,
|
||||
true /*SplitN*/,
|
||||
ADataType,
|
||||
EDataType,
|
||||
NumGroupsToMerge>;
|
||||
NumGroupsToMerge,
|
||||
index_t,
|
||||
CTranspose>;
|
||||
|
||||
static constexpr index_t ClusterLengthNPerBlock =
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
|
||||
@@ -361,9 +383,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
|
||||
ctc::NHWGC,
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::NDHWGC, ALay>>;
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
|
||||
ctc::NDHWGC,
|
||||
ALay>>;
|
||||
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
|
||||
@@ -379,9 +403,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
|
||||
ctc::GKYXC,
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::GKZYXC, BLay>>;
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
|
||||
ctc::GKZYXC,
|
||||
BLay>>;
|
||||
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<Layout>();
|
||||
@@ -397,17 +423,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
|
||||
ctc::NHWGK,
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::NDHWGK, ELay>>;
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
|
||||
ctc::NDHWGK,
|
||||
ELay>>;
|
||||
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
|
||||
|
||||
const auto out_gemmm_gemmn_desc =
|
||||
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
|
||||
return out_gemmm_gemmn_desc;
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
constexpr auto matrix_padder_trans =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{NPerBlock, MPerBlock, KPerBlock};
|
||||
return matrix_padder_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
return matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
}
|
||||
}
|
||||
|
||||
// Shape of Ds and E must be aligned. Strides can be different.
|
||||
@@ -471,11 +504,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
BComputeDataType, DoElementwiseBeforeCShuffle
|
||||
|
||||
// Template-argument list for the gridwise GEMM used on the CTranspose path.
// B-side parameters are listed ahead of their A-side counterparts (GEMM data
// types, element-wise operations, N/M per-block and per-XDL tile sizes, BK1/AK1,
// and the block-transfer settings), so the weight tensor takes the GEMM "A"
// slot and the input tensor takes the "B" slot.
// NOTE(review): AComputeDataType / BComputeDataType and the
// CShuffle{M,N}XdlPerWavePerShuffle pair do not appear swapped here - confirm
// this is intended for the transposed configuration.
#define GridwiseGemmCTransposeTemplateParameters \
|
||||
GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
|
||||
EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \
|
||||
NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \
|
||||
MPerXDL, NXdlPerWave, MXdlPerWave, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
|
||||
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
|
||||
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
|
||||
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
|
||||
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
|
||||
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
|
||||
ABlockLdsExtraM, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
BComputeDataType, DoElementwiseBeforeCShuffle
|
||||
|
||||
// Use appropriate gridwise gemm
|
||||
using GridwiseGemm = std::conditional_t<
|
||||
isMultiA || isMultiB,
|
||||
GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmMultiABDTemplateParameters>,
|
||||
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;
|
||||
using GridwiseGemmCTranspose = std::conditional_t<
|
||||
CTranspose,
|
||||
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmCTransposeTemplateParameters>,
|
||||
GridwiseGemm>;
|
||||
|
||||
// If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers.
|
||||
using APointers =
|
||||
@@ -497,15 +551,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemmCTranspose::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}))>;
|
||||
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
remove_cvref_t<decltype(GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(
|
||||
EGridDesc_M_N{}))>;
|
||||
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, NPerBlock>;
|
||||
|
||||
using NGCHWTransposeDescType =
|
||||
@@ -612,16 +667,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e)},
|
||||
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
|
||||
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
|
||||
a_g_n_c_wis_strides_{NeedTransposeKernel
|
||||
? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)
|
||||
: a_g_n_c_wis_strides},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
b_g_k_c_xs_strides_{conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
|
||||
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
|
||||
b_g_k_c_xs_strides_{NeedTransposeKernel
|
||||
? conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
|
||||
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)
|
||||
: b_g_k_c_xs_strides},
|
||||
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
|
||||
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
|
||||
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
|
||||
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
|
||||
e_g_n_k_wos_strides_{NeedTransposeKernel
|
||||
? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)
|
||||
: e_g_n_k_wos_strides},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
@@ -651,7 +712,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
block_2_etile_map_{
|
||||
GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
compute_ptr_offset_of_groups_{},
|
||||
compute_ptr_offset_of_n_{},
|
||||
a_element_op_{a_element_op},
|
||||
@@ -783,24 +845,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
|
||||
b_grid_desc_n_k_,
|
||||
ds_grid_desc_m_n_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_))
|
||||
bool valid = false;
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n_);
|
||||
valid = GridwiseGemmCTranspose::CheckValidity(b_grid_desc_n_k_,
|
||||
a_grid_desc_m_k_,
|
||||
ds_grid_desc_m_n_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
valid = GridwiseGemmCTranspose::CheckValidity(a_grid_desc_m_k_,
|
||||
b_grid_desc_n_k_,
|
||||
ds_grid_desc_m_n_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_);
|
||||
}
|
||||
if(valid)
|
||||
{
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose::
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
|
||||
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n_);
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose::
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
// Use not modified base strides
|
||||
a_in_transpose_desc_ =
|
||||
@@ -835,8 +907,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
|
||||
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
@@ -851,8 +922,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
std::size_t GetWorkspaceBTensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const long_index_t b_acum = ck::accumulate_n<long_index_t>(
|
||||
b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
@@ -867,8 +937,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
|
||||
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
@@ -1007,7 +1076,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
|
||||
has_main_loop,
|
||||
isMultiA,
|
||||
isMultiB>;
|
||||
isMultiB,
|
||||
CTranspose>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
@@ -1035,68 +1105,118 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
const ADataType* p_a_grid = arg.p_as_grid_.At(I0);
|
||||
const BDataType* p_b_grid = arg.p_bs_grid_.At(I0);
|
||||
EDataType* p_e_grid = arg.p_e_grid_;
|
||||
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() +
|
||||
arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
}
|
||||
else if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() +
|
||||
arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() +
|
||||
arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
}
|
||||
else if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() +
|
||||
arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
}
|
||||
}
|
||||
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
const ADataType*,
|
||||
const BDataType*,
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
|
||||
has_main_loop,
|
||||
isMultiA,
|
||||
isMultiB>;
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
|
||||
GridwiseGemmCTranspose,
|
||||
const BDataType*,
|
||||
const ADataType*,
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
BElementwiseOperation,
|
||||
AElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
|
||||
has_main_loop,
|
||||
isMultiA,
|
||||
isMultiB,
|
||||
CTranspose>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_b_grid,
|
||||
p_a_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
arg.b_element_op_,
|
||||
arg.a_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
const ADataType*,
|
||||
const BDataType*,
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
|
||||
has_main_loop,
|
||||
isMultiA,
|
||||
isMultiB,
|
||||
CTranspose>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1114,8 +1234,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
float avg_time = 0.f;
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const index_t a_grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
|
||||
@@ -1166,8 +1285,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
avg_time += RunGemm(arg, stream_config);
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
|
||||
@@ -1215,9 +1333,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
|
||||
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
|
||||
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
|
||||
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
|
||||
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
|
||||
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
|
||||
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
// check device
|
||||
if(get_device_name() == "gfx908")
|
||||
@@ -1310,7 +1430,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
|
||||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
|
||||
is_same_v<ALayout, ctc::NDHWGC> || is_same_v<ALayout, ctc::NGCW> ||
|
||||
is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
|
||||
NeedTransposeKernel)
|
||||
{
|
||||
// Check access per C
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
|
||||
@@ -1326,6 +1446,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
|
||||
{
|
||||
static_assert(NeedTransposeKernel == false);
|
||||
static_assert(NumGroupsToMerge == 1);
|
||||
|
||||
if constexpr(ABlockTransferSrcScalarPerVector != 1)
|
||||
{
|
||||
if(ABlockTransferSrcVectorDim != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(input_spatial_acum % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
@@ -1350,7 +1487,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access of Ds
|
||||
bool valid = true;
|
||||
|
||||
@@ -1396,8 +1532,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
@@ -1409,8 +1544,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
@@ -1457,9 +1590,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
is_same_v<ELayout, ctc::NDHWGK> || is_same_v<ELayout, ctc::NGKW> ||
|
||||
is_same_v<ELayout, ctc::NGKHW> || is_same_v<ELayout, ctc::NGKDHW>)
|
||||
{
|
||||
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
if(CTranspose == false)
|
||||
{
|
||||
return false;
|
||||
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -1483,11 +1629,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
else
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
return GridwiseGemmCTranspose::CheckValidity(arg.b_grid_desc_n_k_,
|
||||
arg.a_grid_desc_m_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return GridwiseGemmCTranspose::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,8 @@ template <index_t NDimSpatial,
|
||||
typename ADataType = float,
|
||||
typename CDataType = float,
|
||||
index_t NumGroupsToMerge = 1,
|
||||
typename IndexType = index_t>
|
||||
typename IndexType = index_t,
|
||||
bool CTranspose = false>
|
||||
struct TransformConvFwdToGemm
|
||||
{
|
||||
private:
|
||||
@@ -1253,6 +1254,83 @@ struct TransformConvFwdToGemm
|
||||
}
|
||||
}
|
||||
|
||||
// Builds the GEMM A (input) descriptor for a 1D convolution whose input layout
// is NGCW, without transposing the tensor.
// Enabled only for NDimSpatial == 1 and ALayout == NGCW; statically requires
// NumGroupsToMerge == 1 and the Filter1x1Stride1Pad0 specialization.
// A naive 3-D descriptor of lengths (N_, Wo_, C_) with strides
// (NStrideTensorA_, I1, CStrideTensorA_) is created, then (N_, Wo_) are merged
// into output dimension 0 and C_ is passed through as output dimension 1.
// NOTE(review): the spatial length is taken from Wo_ (output width); this
// presumably relies on input and output widths being equal under
// Filter1x1Stride1Pad0 - confirm.
template <typename ALayout,
|
||||
typename ck::enable_if<NDimSpatial == 1 &&
|
||||
is_same_v<ALayout, tensor_layout::convolution::NGCW>,
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeADescriptor_M_K() const
|
||||
{
|
||||
static_assert(NumGroupsToMerge == 1);
|
||||
static_assert(ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0);
|
||||
|
||||
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wo_, C_), make_tuple(NStrideTensorA_, I1, CStrideTensorA_));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Wo_)), make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
// Builds the GEMM A (input) descriptor for a 2D convolution whose input layout
// is NGCHW, without transposing the tensor.
// Enabled only for NDimSpatial == 2 and ALayout == NGCHW; statically requires
// NumGroupsToMerge == 1 and the Filter1x1Stride1Pad0 specialization.
// The two spatial dimensions are flattened into one length Ho_ * Wo_ with
// stride I1; a naive descriptor (N_, Ho_ * Wo_, C_) with strides
// (NStrideTensorA_, I1, CStrideTensorA_) is created, then (N_, Ho_ * Wo_) are
// merged into output dimension 0 and C_ is passed through as output dimension 1.
// NOTE(review): flattening H and W with unit stride presumably assumes the
// spatial dimensions are densely packed and equal to the output sizes - confirm.
template <typename ALayout,
|
||||
typename ck::enable_if<NDimSpatial == 2 &&
|
||||
is_same_v<ALayout, tensor_layout::convolution::NGCHW>,
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeADescriptor_M_K() const
|
||||
{
|
||||
static_assert(NumGroupsToMerge == 1);
|
||||
static_assert(ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0);
|
||||
|
||||
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Ho_ * Wo_, C_), make_tuple(NStrideTensorA_, I1, CStrideTensorA_));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Ho_ * Wo_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
// Builds the GEMM A (input) descriptor for a 3D convolution whose input layout
// is NGCDHW, without transposing the tensor.
// Enabled only for NDimSpatial == 3 and ALayout == NGCDHW; statically requires
// NumGroupsToMerge == 1 and the Filter1x1Stride1Pad0 specialization.
// The three spatial dimensions are flattened into one length Do_ * Ho_ * Wo_
// with stride I1; a naive descriptor (N_, Do_ * Ho_ * Wo_, C_) with strides
// (NStrideTensorA_, I1, CStrideTensorA_) is created, then the first two
// dimensions are merged into output dimension 0 and C_ is passed through as
// output dimension 1.
// NOTE(review): flattening D, H and W with unit stride presumably assumes the
// spatial dimensions are densely packed and equal to the output sizes - confirm.
template <typename ALayout,
|
||||
typename ck::enable_if<NDimSpatial == 3 &&
|
||||
is_same_v<ALayout, tensor_layout::convolution::NGCDHW>,
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeADescriptor_M_K() const
|
||||
{
|
||||
static_assert(NumGroupsToMerge == 1);
|
||||
static_assert(ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0);
|
||||
|
||||
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Do_ * Ho_ * Wo_, C_), make_tuple(NStrideTensorA_, I1, CStrideTensorA_));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
// Builds the GEMM B (weight) descriptor for the GKCX / GKCYX / GKCZYX weight
// layouts.
// Statically requires a 1x1 filter specialization (Filter1x1Stride1Pad0 or
// Filter1x1Pad0) and NumGroupsToMerge == 1.
// Returns a packed 2-D descriptor of lengths (K_, C_); no filter spatial
// dimension appears in it.
// NOTE(review): dropping the spatial dimensions presumably relies on the 1x1
// filter making their extent 1 - confirm.
template <typename BLayout,
|
||||
typename ck::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKCX> ||
|
||||
is_same_v<BLayout, tensor_layout::convolution::GKCYX> ||
|
||||
is_same_v<BLayout, tensor_layout::convolution::GKCZYX>,
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeBDescriptor_N_K() const
|
||||
{
|
||||
static_assert(ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 ||
|
||||
ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter1x1Pad0);
|
||||
static_assert(NumGroupsToMerge == 1);
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(K_, C_));
|
||||
}
|
||||
|
||||
template <typename BLayout,
|
||||
typename ck::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKXC> ||
|
||||
is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
|
||||
@@ -1338,8 +1416,16 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_),
|
||||
make_tuple(I0, KStrideTensorC_));
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
|
||||
make_tuple(KStrideTensorC_, I0));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_),
|
||||
make_tuple(I0, KStrideTensorC_));
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
@@ -1350,8 +1436,16 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
|
||||
make_tuple(I0, KStrideTensorC_));
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Ho_ * Wo_),
|
||||
make_tuple(KStrideTensorC_, I0));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
|
||||
make_tuple(I0, KStrideTensorC_));
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
@@ -1362,12 +1456,21 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
|
||||
make_tuple(I0, KStrideTensorC_));
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Do_ * Ho_ * Wo_),
|
||||
make_tuple(KStrideTensorC_, I0));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
|
||||
make_tuple(I0, KStrideTensorC_));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CLayout,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
|
||||
typename ck::enable_if<NDimSp == 1 &&
|
||||
(is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
|
||||
@@ -1375,6 +1478,7 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
static_assert(CTranspose == false);
|
||||
const IndexType NDoHoWo = N_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
@@ -1429,6 +1533,7 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
static_assert(CTranspose == false);
|
||||
const IndexType NDoHoWo = N_ * Ho_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
@@ -1486,7 +1591,7 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
|
||||
static_assert(CTranspose == false);
|
||||
const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
@@ -1536,6 +1641,101 @@ struct TransformConvFwdToGemm
|
||||
}
|
||||
}
|
||||
|
||||
// Builds the GEMM C (output) descriptor for 1D convolutions whose output layout is
// GNKW or NGKW (selected through the enable_if condition below).
// The output is first viewed as (N_, K_, Wo_); N_ and Wo_ are then merged into a single
// dimension while K_ is passed through. The resulting 2D descriptor is ordered
// (K, N*Wo) when CTranspose is true and (N*Wo, K) otherwise.
template <typename CLayout,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
|
||||
typename ck::enable_if<NDimSp == 1 &&
|
||||
(is_same_v<CLayout, tensor_layout::convolution::GNKW> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NGKW>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
// Group merging is not handled by this overload.
static_assert(NumGroupsToMerge == 1);
|
||||
// 3D view (N_, K_, Wo_) with strides (NStrideTensorC_, KStrideTensorC_, I1).
auto n_k_wo_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, K_, Wo_), make_tuple(NStrideTensorC_, KStrideTensorC_, I1));
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
// Transposed output: source dim 1 (K_) -> dim 0, merged source dims {0, 2} (N_, Wo_) -> dim 1.
return transform_tensor_descriptor(
|
||||
n_k_wo_desc,
|
||||
make_tuple(make_pass_through_transform(K_),
|
||||
make_merge_transform(make_tuple(N_, Wo_))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Default output: merged source dims {0, 2} (N_, Wo_) -> dim 0, source dim 1 (K_) -> dim 1.
return transform_tensor_descriptor(n_k_wo_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
// Builds the GEMM C (output) descriptor for 2D convolutions whose output layout is
// GNKHW or NGKHW (selected through the enable_if condition below).
// The output is first viewed as (N_, K_, Ho_ * Wo_), i.e. the two spatial dimensions are
// treated as one flattened dimension; N_ and that spatial dimension are then merged while
// K_ is passed through. The resulting 2D descriptor is ordered (K, N*Ho*Wo) when
// CTranspose is true and (N*Ho*Wo, K) otherwise.
template <typename CLayout,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
|
||||
typename ck::enable_if<NDimSp == 2 &&
|
||||
(is_same_v<CLayout, tensor_layout::convolution::GNKHW> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NGKHW>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
// Group merging is not handled by this overload.
static_assert(NumGroupsToMerge == 1);
|
||||
// 3D view (N_, K_, Ho_ * Wo_) with strides (NStrideTensorC_, KStrideTensorC_, I1).
auto n_k_howo_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, K_, Ho_ * Wo_), make_tuple(NStrideTensorC_, KStrideTensorC_, I1));
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
// Transposed output: source dim 1 (K_) -> dim 0, merged source dims {0, 2} -> dim 1.
return transform_tensor_descriptor(
|
||||
n_k_howo_desc,
|
||||
make_tuple(make_pass_through_transform(K_),
|
||||
make_merge_transform(make_tuple(N_, Ho_ * Wo_))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Default output: merged source dims {0, 2} -> dim 0, source dim 1 (K_) -> dim 1.
return transform_tensor_descriptor(
|
||||
n_k_howo_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Ho_ * Wo_)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
// Builds the GEMM C (output) descriptor for 3D convolutions whose output layout is
// GNKDHW or NGKDHW (selected through the enable_if condition below).
// The output is first viewed as (N_, K_, Do_ * Ho_ * Wo_), i.e. the three spatial
// dimensions are treated as one flattened dimension; N_ and that spatial dimension are
// then merged while K_ is passed through. The resulting 2D descriptor is ordered
// (K, N*Do*Ho*Wo) when CTranspose is true and (N*Do*Ho*Wo, K) otherwise.
template <typename CLayout,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
|
||||
typename ck::enable_if<NDimSp == 3 &&
|
||||
(is_same_v<CLayout, tensor_layout::convolution::GNKDHW> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NGKDHW>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
// Group merging is not handled by this overload.
static_assert(NumGroupsToMerge == 1);
|
||||
// 3D view (N_, K_, Do_ * Ho_ * Wo_) with strides (NStrideTensorC_, KStrideTensorC_, I1).
auto n_k_dohowo_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, K_, Do_ * Ho_ * Wo_), make_tuple(NStrideTensorC_, KStrideTensorC_, I1));
|
||||
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
// Transposed output: source dim 1 (K_) -> dim 0, merged source dims {0, 2} -> dim 1.
return transform_tensor_descriptor(
|
||||
n_k_dohowo_desc,
|
||||
make_tuple(make_pass_through_transform(K_),
|
||||
make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Default output: merged source dims {0, 2} -> dim 0, source dim 1 (K_) -> dim 1.
return transform_tensor_descriptor(
|
||||
n_k_dohowo_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
IndexType N_;
|
||||
IndexType Di_, Hi_, Wi_;
|
||||
IndexType Do_, Ho_, Wo_;
|
||||
|
||||
@@ -179,6 +179,38 @@ using device_grouped_conv_fwd_xdl_f16_instances = std::tuple<
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// Tuple of DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instances with F16 A/B/E data,
// F32 accumulation and an F16 CShuffle type.
// The spatial dimensionality, the A/B/Ds/E layouts, the convolution specialization, the
// Ds data types and the output element-wise operation are supplied by the caller; every
// entry uses PassThrough for the A and B element-wise operations and GemmMNKPadding as
// the GEMM specialization.
// Entries differ in block size, tile sizes, XDL shape and vector widths. The
// "generic instance" group uses a CBlockTransfer ScalarPerVector of 1, the remaining
// groups use 4.
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionForwardSpecialization ConvSpec,
|
||||
typename DsDataTypes = Tuple<>,
|
||||
typename OutElementOp = PassThrough>
|
||||
using device_grouped_conv_fwd_xdl_f16_nchw_instances = std::tuple<
|
||||
// clang-format off
|
||||
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>,
|
||||
// 32x32 instance
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 4>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>,
|
||||
// 16x16 instance
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
|
||||
@@ -53,6 +53,15 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances_shard(
|
||||
ConvFwd1x1S1P0>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
add_device_operation_instances(instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f16_nchw_instances<2,
|
||||
NGCHW,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwd1x1S1P0>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace ck::tensor_operation::device::instance
|
||||
|
||||
@@ -31,6 +31,14 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f16_nchw_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwd1x1S1P0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -47,6 +47,15 @@ void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instances(
|
||||
Empty_Tuple,
|
||||
NGKDHW,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f16_nchw_instances<3,
|
||||
NGCDHW,
|
||||
GKCZYX,
|
||||
Empty_Tuple,
|
||||
NGKDHW,
|
||||
ConvFwd1x1S1P0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
Reference in New Issue
Block a user