mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Improve XDL to WMMA porting for grouped conv fwd (#3456)
Refactors the way the number of XDL (matrix multiply-accumulate) instructions per wave is calculated and used in the grouped convolution forward implementations, especially to better support WMMA (Wave Matrix Multiply-Accumulate) instructions and 16x16 tiles. The changes use MXdlPerWave instead of NXdlPerWave to increase number of waves per M dim.
This commit is contained in:
@@ -190,9 +190,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
using DeviceOp = DeviceBatchedGemmGemm_Xdl_CShuffle;
|
||||
|
||||
static constexpr auto MXdlPerWave64 =
|
||||
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
|
||||
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
|
||||
static constexpr auto MXdlPerWave32 =
|
||||
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
|
||||
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
@@ -235,20 +235,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
{
|
||||
using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle;
|
||||
|
||||
static constexpr auto Gemm0MXdlPerWave64 = GetNXdlPerWave2<BlockSize,
|
||||
Gemm0NPerBlock,
|
||||
Gemm0MPerBlock,
|
||||
Gemm0NPerXdl,
|
||||
Gemm0MPerXdl,
|
||||
Gemm0NXdlPerWave,
|
||||
true>();
|
||||
static constexpr auto Gemm0MXdlPerWave32 = GetNXdlPerWave2<BlockSize,
|
||||
Gemm0NPerBlock,
|
||||
Gemm0MPerBlock,
|
||||
Gemm0NPerXdl,
|
||||
Gemm0MPerXdl,
|
||||
Gemm0NXdlPerWave,
|
||||
false>();
|
||||
static constexpr auto Gemm0MXdlPerWave64 = GetXdlPerWave2<BlockSize,
|
||||
Gemm0NPerBlock,
|
||||
Gemm0MPerBlock,
|
||||
Gemm0NPerXdl,
|
||||
Gemm0MPerXdl,
|
||||
Gemm0NXdlPerWave,
|
||||
true>();
|
||||
static constexpr auto Gemm0MXdlPerWave32 = GetXdlPerWave2<BlockSize,
|
||||
Gemm0NPerBlock,
|
||||
Gemm0MPerBlock,
|
||||
Gemm0NPerXdl,
|
||||
Gemm0MPerXdl,
|
||||
Gemm0NXdlPerWave,
|
||||
false>();
|
||||
|
||||
static constexpr index_t NumD0Tensor = D0sDataType::Size();
|
||||
static constexpr index_t NumD1Tensor = D1sDataType::Size();
|
||||
|
||||
@@ -223,9 +223,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
MaskingSpec>
|
||||
{
|
||||
static constexpr auto MXdlPerWave64 =
|
||||
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
|
||||
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
|
||||
static constexpr auto MXdlPerWave32 =
|
||||
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
|
||||
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
|
||||
|
||||
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
|
||||
"Number of dimension must be greater than 0");
|
||||
|
||||
@@ -211,9 +211,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
|
||||
using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle;
|
||||
static constexpr auto MXdlPerWave64 =
|
||||
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
|
||||
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
|
||||
static constexpr auto MXdlPerWave32 =
|
||||
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
|
||||
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
@@ -325,9 +325,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
BComputeDataType>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
|
||||
GET_NXDL_PER_WAVE_IMPL
|
||||
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
|
||||
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
|
||||
GET_MXDL_PER_WAVE_IMPL
|
||||
// Force usage of 16x16 instruction for WMMA
|
||||
static constexpr index_t Wave32MaxMNPerXDL = 16;
|
||||
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
|
||||
static constexpr auto MXdlPerWave32 =
|
||||
GetMXdlPerWave<false,
|
||||
Wave32MaxMNPerXDL,
|
||||
Wave32MaxMNPerXDL,
|
||||
NXdlPerWave*(NPerXDL / Wave32MaxMNPerXDL)>();
|
||||
|
||||
static_assert(NumGroupsToMerge >= 1);
|
||||
|
||||
@@ -486,35 +492,36 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
using GemmADataType = std::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
|
||||
using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
|
||||
|
||||
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \
|
||||
GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
|
||||
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
|
||||
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
|
||||
KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, \
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
|
||||
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
|
||||
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
|
||||
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
|
||||
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
|
||||
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
|
||||
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
|
||||
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \
|
||||
GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
|
||||
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
|
||||
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
|
||||
KPerBlock, AK1, BK1, MPerXDL_, NPerXDL_, MXdlPerWave_, NXdlPerWave*(NPerXDL / NPerXDL_), \
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
|
||||
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
|
||||
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
|
||||
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
|
||||
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
|
||||
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
|
||||
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
|
||||
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle*(NPerXDL / NPerXDL_), \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
BComputeDataType
|
||||
|
||||
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \
|
||||
GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
|
||||
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
|
||||
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \
|
||||
NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
|
||||
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
|
||||
ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
|
||||
ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \
|
||||
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \
|
||||
BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \
|
||||
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL_, \
|
||||
NPerXDL_, MXdlPerWave_, NXdlPerWave*(NPerXDL / NPerXDL_), \
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
|
||||
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
|
||||
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
|
||||
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
|
||||
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
|
||||
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
|
||||
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
|
||||
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle*(NPerXDL / NPerXDL_), \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
BComputeDataType, DoElementwiseBeforeCShuffle
|
||||
@@ -523,7 +530,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
|
||||
EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \
|
||||
NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \
|
||||
MPerXDL, NXdlPerWave_, MXdlPerWave, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
|
||||
MPerXDL, NXdlPerWave, MXdlPerWave_, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
|
||||
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
|
||||
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
|
||||
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
|
||||
@@ -536,34 +543,35 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
BComputeDataType, DoElementwiseBeforeCShuffle
|
||||
|
||||
// Use appropriate gridwise gemm
|
||||
template <index_t NXdlPerWave_>
|
||||
template <index_t MXdlPerWave_, index_t MPerXDL_, index_t NPerXDL_>
|
||||
using GridwiseGemmMultipleABDBase = GridwiseGemmMultipleABD_xdl_cshuffle<
|
||||
CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>;
|
||||
template <index_t NXdlPerWave_>
|
||||
template <index_t MXdlPerWave_, index_t MPerXDL_, index_t NPerXDL_>
|
||||
using GridwiseGemmMultipleDBase = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>;
|
||||
template <index_t NXdlPerWave_>
|
||||
template <index_t MXdlPerWave_>
|
||||
using GridwiseGemmMultipleDCTransposeBase = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>;
|
||||
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS
|
||||
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS
|
||||
#undef CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS
|
||||
|
||||
using GridwiseGemm64 =
|
||||
std::conditional_t<isMultiA || isMultiB,
|
||||
GridwiseGemmMultipleABDBase<math::max(NXdlPerWave64, 1)>,
|
||||
GridwiseGemmMultipleDBase<math::max(NXdlPerWave64, 1)>>;
|
||||
using GridwiseGemm32 = std::conditional_t<isMultiA || isMultiB,
|
||||
GridwiseGemmMultipleABDBase<NXdlPerWave32>,
|
||||
GridwiseGemmMultipleDBase<NXdlPerWave32>>;
|
||||
using GridwiseGemm64 = std::conditional_t<
|
||||
isMultiA || isMultiB,
|
||||
GridwiseGemmMultipleABDBase<math::max(MXdlPerWave64, 1), MPerXDL, NPerXDL>,
|
||||
GridwiseGemmMultipleDBase<math::max(MXdlPerWave64, 1), MPerXDL, NPerXDL>>;
|
||||
using GridwiseGemm32 = std::conditional_t<
|
||||
isMultiA || isMultiB,
|
||||
GridwiseGemmMultipleABDBase<MXdlPerWave32, Wave32MaxMNPerXDL, Wave32MaxMNPerXDL>,
|
||||
GridwiseGemmMultipleDBase<MXdlPerWave32, Wave32MaxMNPerXDL, Wave32MaxMNPerXDL>>;
|
||||
|
||||
using GridwiseGemmCTranspose64 =
|
||||
std::conditional_t<CTranspose,
|
||||
GridwiseGemmMultipleDCTransposeBase<math::max(NXdlPerWave64, 1)>,
|
||||
GridwiseGemmMultipleDCTransposeBase<math::max(MXdlPerWave64, 1)>,
|
||||
GridwiseGemm64>;
|
||||
using GridwiseGemmCTranspose32 =
|
||||
std::conditional_t<CTranspose,
|
||||
GridwiseGemmMultipleDCTransposeBase<NXdlPerWave32>,
|
||||
GridwiseGemmMultipleDCTransposeBase<MXdlPerWave32>,
|
||||
GridwiseGemm32>;
|
||||
|
||||
// If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers.
|
||||
@@ -913,14 +921,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
if constexpr(MXdlPerWave64 > 0)
|
||||
{
|
||||
InitGridDesc<GridwiseGemm64, GridwiseGemmCTranspose64>();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(NXdlPerWave32 > 0)
|
||||
if constexpr(MXdlPerWave32 > 0)
|
||||
{
|
||||
InitGridDesc<GridwiseGemm32, GridwiseGemmCTranspose32>();
|
||||
}
|
||||
@@ -1388,7 +1396,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
if constexpr(MXdlPerWave64 > 0)
|
||||
{
|
||||
return RunImp<GridwiseGemm64, GridwiseGemmCTranspose64>(arg, stream_config);
|
||||
}
|
||||
@@ -1399,7 +1407,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(NXdlPerWave32 > 0)
|
||||
if constexpr(MXdlPerWave32 > 0)
|
||||
{
|
||||
return RunImp<GridwiseGemm32, GridwiseGemmCTranspose32>(arg, stream_config);
|
||||
}
|
||||
@@ -1436,7 +1444,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
if(!ck::is_xdl_wmma_supported<AComputeDataType, BComputeDataType, MPerXDL, NPerXDL>())
|
||||
if(!ck::is_xdl_wmma_supported<AComputeDataType,
|
||||
BComputeDataType,
|
||||
Wave32MaxMNPerXDL,
|
||||
Wave32MaxMNPerXDL>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -1720,7 +1731,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// check Gridwise GEMM
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
if constexpr(MXdlPerWave64 > 0)
|
||||
{
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
{
|
||||
@@ -1759,7 +1770,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
else
|
||||
{
|
||||
|
||||
if constexpr(NXdlPerWave32 > 0)
|
||||
if constexpr(MXdlPerWave32 > 0)
|
||||
{
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
{
|
||||
@@ -2047,8 +2058,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
|
||||
<< "<"
|
||||
str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle";
|
||||
|
||||
if(get_warp_size() != 64) {
|
||||
str << "_WmmaPorted";
|
||||
}
|
||||
|
||||
str << "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
|
||||
@@ -400,9 +400,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
BComputeDataType>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
|
||||
GET_NXDL_PER_WAVE_IMPL
|
||||
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
|
||||
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
|
||||
GET_MXDL_PER_WAVE_IMPL
|
||||
// Force usage of 16x16 instruction for WMMA
|
||||
static constexpr index_t Wave32MaxMNPerXDL = 16;
|
||||
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
|
||||
static constexpr auto MXdlPerWave32 =
|
||||
GetMXdlPerWave<false,
|
||||
Wave32MaxMNPerXDL,
|
||||
Wave32MaxMNPerXDL,
|
||||
NXdlPerWave*(NPerXDL / Wave32MaxMNPerXDL)>();
|
||||
|
||||
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
|
||||
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
|
||||
@@ -563,7 +569,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
: BBlockTransferSrcScalarPerVector;
|
||||
|
||||
// Use appropriate gridwise gemm
|
||||
template <index_t NXdlPerWave_>
|
||||
template <index_t MXdlPerWave_, index_t MPerXDL_, index_t NPerXDL_>
|
||||
using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3<
|
||||
tensor_layout::gemm::RowMajor,
|
||||
tensor_layout::gemm::ColumnMajor,
|
||||
@@ -585,10 +591,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave_,
|
||||
MPerXDL_,
|
||||
NPerXDL_,
|
||||
MXdlPerWave_,
|
||||
NXdlPerWave*(NPerXDL / NPerXDL_),
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -606,7 +612,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle*(NPerXDL / NPerXDL_),
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
@@ -617,8 +623,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
BDataType,
|
||||
DoElementwiseBeforeCShuffle,
|
||||
DirectLoad>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(MXdlPerWave64, 1), MPerXDL, NPerXDL>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<MXdlPerWave32, Wave32MaxMNPerXDL, Wave32MaxMNPerXDL>;
|
||||
|
||||
// #undef GridwiseGemmV3TemplateParams
|
||||
|
||||
@@ -1430,7 +1436,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
return avg_time;
|
||||
}
|
||||
|
||||
INVOKER_RUN_IMPL
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(MXdlPerWave64 > 0)
|
||||
{
|
||||
return RunImp<GridwiseGemm64>(arg, stream_config);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(MXdlPerWave32 > 0)
|
||||
{
|
||||
return RunImp<GridwiseGemm32>(arg, stream_config);
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
@@ -1483,7 +1506,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
}
|
||||
}
|
||||
|
||||
if(!ck::is_xdl_wmma_supported<AComputeDataType, BComputeDataType, MPerXDL, NPerXDL>())
|
||||
if(!ck::is_xdl_wmma_supported<AComputeDataType,
|
||||
BComputeDataType,
|
||||
Wave32MaxMNPerXDL,
|
||||
Wave32MaxMNPerXDL>())
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
@@ -1758,7 +1784,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
if constexpr(MXdlPerWave64 > 0)
|
||||
{
|
||||
typename GridwiseGemm64::Argument gemm_arg{nullptr,
|
||||
nullptr,
|
||||
@@ -1780,7 +1806,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(NXdlPerWave32 > 0)
|
||||
if constexpr(MXdlPerWave32 > 0)
|
||||
{
|
||||
typename GridwiseGemm32::Argument gemm_arg{nullptr,
|
||||
nullptr,
|
||||
@@ -2064,6 +2090,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
// clang-format off
|
||||
str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3";
|
||||
|
||||
if(get_warp_size() != 64) {
|
||||
str << "_WmmaPorted";
|
||||
}
|
||||
|
||||
if constexpr(DirectLoad) {
|
||||
str << "_DirectLoad";
|
||||
}
|
||||
|
||||
@@ -206,9 +206,15 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
BComputeDataType>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor;
|
||||
GET_NXDL_PER_WAVE_IMPL
|
||||
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
|
||||
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
|
||||
GET_MXDL_PER_WAVE_IMPL
|
||||
// Force usage of 16x16 instruction for WMMA
|
||||
static constexpr index_t Wave32MaxMNPerXDL = 16;
|
||||
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
|
||||
static constexpr auto MXdlPerWave32 =
|
||||
GetMXdlPerWave<false,
|
||||
Wave32MaxMNPerXDL,
|
||||
Wave32MaxMNPerXDL,
|
||||
NXdlPerWave*(NPerXDL / Wave32MaxMNPerXDL)>();
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
static constexpr index_t MaxGemmsNum = 32;
|
||||
@@ -409,25 +415,26 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS \
|
||||
ADataType, BDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
|
||||
AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
|
||||
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \
|
||||
NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
|
||||
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
|
||||
ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
|
||||
ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \
|
||||
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \
|
||||
BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \
|
||||
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL_, \
|
||||
NPerXDL_, MXdlPerWave_, NXdlPerWave*(NPerXDL / NPerXDL_), \
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
|
||||
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
|
||||
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
|
||||
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
|
||||
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
|
||||
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
|
||||
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
|
||||
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle*(NPerXDL / NPerXDL_), \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
AComputeDataType, DoElementwiseBeforeCShuffle
|
||||
// Use appropriate gridwise gemm
|
||||
template <index_t NXdlPerWave_>
|
||||
template <index_t MXdlPerWave_, index_t MPerXDL_, index_t NPerXDL_>
|
||||
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS>;
|
||||
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(MXdlPerWave64, 1), MPerXDL, NPerXDL>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<MXdlPerWave32, Wave32MaxMNPerXDL, Wave32MaxMNPerXDL>;
|
||||
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
@@ -607,7 +614,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
if constexpr(MXdlPerWave64 > 0)
|
||||
{
|
||||
init_gemm_args<GridwiseGemm64>(a_grid_ptrs[i],
|
||||
static_cast<const BDataType*>(p_b),
|
||||
@@ -624,7 +631,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(NXdlPerWave32 > 0)
|
||||
if constexpr(MXdlPerWave32 > 0)
|
||||
{
|
||||
init_gemm_args<GridwiseGemm32>(a_grid_ptrs[i],
|
||||
static_cast<const BDataType*>(p_b),
|
||||
@@ -769,7 +776,24 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
}
|
||||
}
|
||||
|
||||
INVOKER_RUN_IMPL
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(MXdlPerWave64 > 0)
|
||||
{
|
||||
return RunImp<GridwiseGemm64>(arg, stream_config);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(MXdlPerWave32 > 0)
|
||||
{
|
||||
return RunImp<GridwiseGemm32>(arg, stream_config);
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
@@ -822,7 +846,10 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if(!ck::is_xdl_wmma_supported<AComputeDataType, BComputeDataType, MPerXDL, NPerXDL>())
|
||||
if(!ck::is_xdl_wmma_supported<AComputeDataType,
|
||||
BComputeDataType,
|
||||
Wave32MaxMNPerXDL,
|
||||
Wave32MaxMNPerXDL>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -1205,8 +1232,12 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"
|
||||
<< "<"
|
||||
str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor";
|
||||
if(get_warp_size() != 64) {
|
||||
str << "_WmmaPorted";
|
||||
}
|
||||
|
||||
str << "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
|
||||
@@ -206,9 +206,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
MaskingSpec>
|
||||
{
|
||||
static constexpr auto MXdlPerWave64 =
|
||||
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
|
||||
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
|
||||
static constexpr auto MXdlPerWave32 =
|
||||
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
|
||||
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
|
||||
|
||||
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
|
||||
"Number of dimension must be greater than 0");
|
||||
|
||||
Reference in New Issue
Block a user