Improve XDL to WMMA porting for grouped conv fwd (#3456)

Refactors how the number of XDL (matrix multiply-accumulate) instructions per wave is calculated and used in the grouped convolution forward implementations, in particular to better support WMMA (Wave Matrix Multiply-Accumulate) instructions and 16x16 tiles.
The changes use MXdlPerWave instead of NXdlPerWave to increase the number of waves along the M dimension.
This commit is contained in:
Bartłomiej Kocot
2025-12-19 23:58:51 +01:00
committed by GitHub
parent 2d9c962e2c
commit cbc8335964
13 changed files with 226 additions and 133 deletions

View File

@@ -33,7 +33,7 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
.with_thread_block(FwdThreadBlock_64_64x32x32)
.with_gemm_config(FwdGemmParams_Xdl_2x2_per_wave)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(FwdTransfer_4x16x1)
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);

View File

@@ -28,7 +28,7 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(FwdThreadBlock_256_128x128x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(FwdTransfer_4x64x1)
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)

View File

@@ -111,8 +111,8 @@ struct DefaultAlgorithm
.bk1 = 8,
.m_per_xdl = 16,
.n_per_xdl = 16,
.m_xdl_per_wave = 4,
.n_xdl_per_wave = 4};
.m_xdl_per_wave = 8,
.n_xdl_per_wave = 8};
ckb::test::TransferABC transfer{
.a =
@@ -188,7 +188,7 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
" ├─ Pipeline scheduler: INTRAWAVE\n"
" ├─ Warp Gemm parameters: \n"
" │ ├─ subtile size: 16×16\n"
" │ └─ Number of warp gemm iterations: 4×4\n"
" │ └─ Number of warp gemm iterations: 8×8\n"
" └─ Memory access:\n"
" ├─ A Tile transfer: \n"
" │ ├─ Tile dimensions: 4×256×8×\n"

View File

@@ -68,7 +68,7 @@ constexpr TransferABC FwdTransfer_4x64x1{
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.scalar_per_vector = 4},
},
};

View File

@@ -60,7 +60,7 @@ template <index_t BlockSize_,
index_t NPerXDL_,
index_t MXdlPerWave_,
bool IsWave64>
static constexpr auto GetNXdlPerWave2()
static constexpr auto GetXdlPerWave2()
{
constexpr index_t Waves = IsWave64 ? BlockSize_ / 64 : BlockSize_ / 32;
constexpr index_t MWaves = MPerBlock_ / (MXdlPerWave_ * MPerXDL_);
@@ -84,17 +84,33 @@ static constexpr auto GetNXdlPerWave2()
}
}
#define GET_NXDL_PER_WAVE_IMPL \
template <bool IsWave64> \
static constexpr auto GetNXdlPerWave() \
{ \
return GetNXdlPerWave2<BlockSize, \
MPerBlock, \
NPerBlock, \
MPerXDL, \
NPerXDL, \
MXdlPerWave, \
IsWave64>(); \
#define GET_NXDL_PER_WAVE_IMPL \
template <bool IsWave64> \
static constexpr auto GetNXdlPerWave() \
{ \
return GetXdlPerWave2<BlockSize, \
MPerBlock, \
NPerBlock, \
MPerXDL, \
NPerXDL, \
MXdlPerWave, \
IsWave64>(); \
}
#define GET_MXDL_PER_WAVE_IMPL \
template <bool IsWave64, \
index_t MPerXDLAligned = MPerXDL, \
index_t NPerXDLAligned = NPerXDL, \
index_t NXdlPerWaveAligned = NXdlPerWave> \
static constexpr auto GetMXdlPerWave() \
{ \
return GetXdlPerWave2<BlockSize, \
NPerBlock, \
MPerBlock, \
NPerXDLAligned, \
MPerXDLAligned, \
NXdlPerWaveAligned, \
IsWave64>(); \
}
template <index_t BlockSize_,
@@ -114,14 +130,14 @@ static constexpr auto GetWarpTileConfig()
constexpr auto NXdlPerWave =
IsWave64
? GetNXdlPerWave2<BlockSize_,
MPerBlock_,
NPerBlock_,
MPerXDL_,
NPerXDL_,
MXdlPerWave_,
true>()
: GetNXdlPerWave2<BlockSize_, MPerBlock_, NPerBlock_, 16, 16, MXdlPerWave32, false>();
? GetXdlPerWave2<BlockSize_,
MPerBlock_,
NPerBlock_,
MPerXDL_,
NPerXDL_,
MXdlPerWave_,
true>()
: GetXdlPerWave2<BlockSize_, MPerBlock_, NPerBlock_, 16, 16, MXdlPerWave32, false>();
if constexpr(IsWave64 == false && NXdlPerWave != 0)
{

View File

@@ -190,9 +190,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
using DeviceOp = DeviceBatchedGemmGemm_Xdl_CShuffle;
static constexpr auto MXdlPerWave64 =
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
static constexpr auto MXdlPerWave32 =
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};

View File

@@ -235,20 +235,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
{
using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle;
static constexpr auto Gemm0MXdlPerWave64 = GetNXdlPerWave2<BlockSize,
Gemm0NPerBlock,
Gemm0MPerBlock,
Gemm0NPerXdl,
Gemm0MPerXdl,
Gemm0NXdlPerWave,
true>();
static constexpr auto Gemm0MXdlPerWave32 = GetNXdlPerWave2<BlockSize,
Gemm0NPerBlock,
Gemm0MPerBlock,
Gemm0NPerXdl,
Gemm0MPerXdl,
Gemm0NXdlPerWave,
false>();
static constexpr auto Gemm0MXdlPerWave64 = GetXdlPerWave2<BlockSize,
Gemm0NPerBlock,
Gemm0MPerBlock,
Gemm0NPerXdl,
Gemm0MPerXdl,
Gemm0NXdlPerWave,
true>();
static constexpr auto Gemm0MXdlPerWave32 = GetXdlPerWave2<BlockSize,
Gemm0NPerBlock,
Gemm0MPerBlock,
Gemm0NPerXdl,
Gemm0MPerXdl,
Gemm0NXdlPerWave,
false>();
static constexpr index_t NumD0Tensor = D0sDataType::Size();
static constexpr index_t NumD1Tensor = D1sDataType::Size();

View File

@@ -223,9 +223,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
MaskingSpec>
{
static constexpr auto MXdlPerWave64 =
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
static constexpr auto MXdlPerWave32 =
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");

View File

@@ -211,9 +211,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle;
static constexpr auto MXdlPerWave64 =
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
static constexpr auto MXdlPerWave32 =
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};

View File

@@ -325,9 +325,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
BComputeDataType>
{
using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
GET_NXDL_PER_WAVE_IMPL
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
GET_MXDL_PER_WAVE_IMPL
// Force usage of 16x16 instruction for WMMA
static constexpr index_t Wave32MaxMNPerXDL = 16;
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
static constexpr auto MXdlPerWave32 =
GetMXdlPerWave<false,
Wave32MaxMNPerXDL,
Wave32MaxMNPerXDL,
NXdlPerWave*(NPerXDL / Wave32MaxMNPerXDL)>();
static_assert(NumGroupsToMerge >= 1);
@@ -486,35 +492,36 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
using GemmADataType = std::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \
GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, \
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \
GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
KPerBlock, AK1, BK1, MPerXDL_, NPerXDL_, MXdlPerWave_, NXdlPerWave*(NPerXDL / NPerXDL_), \
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle*(NPerXDL / NPerXDL_), \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
BComputeDataType
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \
GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \
NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \
BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL_, \
NPerXDL_, MXdlPerWave_, NXdlPerWave*(NPerXDL / NPerXDL_), \
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle*(NPerXDL / NPerXDL_), \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
BComputeDataType, DoElementwiseBeforeCShuffle
@@ -523,7 +530,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \
NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \
MPerXDL, NXdlPerWave_, MXdlPerWave, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
MPerXDL, NXdlPerWave, MXdlPerWave_, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
@@ -536,34 +543,35 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
BComputeDataType, DoElementwiseBeforeCShuffle
// Use appropriate gridwise gemm
template <index_t NXdlPerWave_>
template <index_t MXdlPerWave_, index_t MPerXDL_, index_t NPerXDL_>
using GridwiseGemmMultipleABDBase = GridwiseGemmMultipleABD_xdl_cshuffle<
CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>;
template <index_t NXdlPerWave_>
template <index_t MXdlPerWave_, index_t MPerXDL_, index_t NPerXDL_>
using GridwiseGemmMultipleDBase = GridwiseGemmMultipleD_xdl_cshuffle<
CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>;
template <index_t NXdlPerWave_>
template <index_t MXdlPerWave_>
using GridwiseGemmMultipleDCTransposeBase = GridwiseGemmMultipleD_xdl_cshuffle<
CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>;
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS
#undef CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS
using GridwiseGemm64 =
std::conditional_t<isMultiA || isMultiB,
GridwiseGemmMultipleABDBase<math::max(NXdlPerWave64, 1)>,
GridwiseGemmMultipleDBase<math::max(NXdlPerWave64, 1)>>;
using GridwiseGemm32 = std::conditional_t<isMultiA || isMultiB,
GridwiseGemmMultipleABDBase<NXdlPerWave32>,
GridwiseGemmMultipleDBase<NXdlPerWave32>>;
using GridwiseGemm64 = std::conditional_t<
isMultiA || isMultiB,
GridwiseGemmMultipleABDBase<math::max(MXdlPerWave64, 1), MPerXDL, NPerXDL>,
GridwiseGemmMultipleDBase<math::max(MXdlPerWave64, 1), MPerXDL, NPerXDL>>;
using GridwiseGemm32 = std::conditional_t<
isMultiA || isMultiB,
GridwiseGemmMultipleABDBase<MXdlPerWave32, Wave32MaxMNPerXDL, Wave32MaxMNPerXDL>,
GridwiseGemmMultipleDBase<MXdlPerWave32, Wave32MaxMNPerXDL, Wave32MaxMNPerXDL>>;
using GridwiseGemmCTranspose64 =
std::conditional_t<CTranspose,
GridwiseGemmMultipleDCTransposeBase<math::max(NXdlPerWave64, 1)>,
GridwiseGemmMultipleDCTransposeBase<math::max(MXdlPerWave64, 1)>,
GridwiseGemm64>;
using GridwiseGemmCTranspose32 =
std::conditional_t<CTranspose,
GridwiseGemmMultipleDCTransposeBase<NXdlPerWave32>,
GridwiseGemmMultipleDCTransposeBase<MXdlPerWave32>,
GridwiseGemm32>;
// If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers.
@@ -913,14 +921,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
if(get_warp_size() == 64)
{
if constexpr(NXdlPerWave64 > 0)
if constexpr(MXdlPerWave64 > 0)
{
InitGridDesc<GridwiseGemm64, GridwiseGemmCTranspose64>();
}
}
else
{
if constexpr(NXdlPerWave32 > 0)
if constexpr(MXdlPerWave32 > 0)
{
InitGridDesc<GridwiseGemm32, GridwiseGemmCTranspose32>();
}
@@ -1388,7 +1396,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
if(get_warp_size() == 64)
{
if constexpr(NXdlPerWave64 > 0)
if constexpr(MXdlPerWave64 > 0)
{
return RunImp<GridwiseGemm64, GridwiseGemmCTranspose64>(arg, stream_config);
}
@@ -1399,7 +1407,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
else
{
if constexpr(NXdlPerWave32 > 0)
if constexpr(MXdlPerWave32 > 0)
{
return RunImp<GridwiseGemm32, GridwiseGemmCTranspose32>(arg, stream_config);
}
@@ -1436,7 +1444,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
if(!ck::is_xdl_wmma_supported<AComputeDataType, BComputeDataType, MPerXDL, NPerXDL>())
if(!ck::is_xdl_wmma_supported<AComputeDataType,
BComputeDataType,
Wave32MaxMNPerXDL,
Wave32MaxMNPerXDL>())
{
return false;
}
@@ -1720,7 +1731,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// check Gridwise GEMM
if(get_warp_size() == 64)
{
if constexpr(NXdlPerWave64 > 0)
if constexpr(MXdlPerWave64 > 0)
{
if constexpr(isMultiA || isMultiB)
{
@@ -1759,7 +1770,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
else
{
if constexpr(NXdlPerWave32 > 0)
if constexpr(MXdlPerWave32 > 0)
{
if constexpr(isMultiA || isMultiB)
{
@@ -2047,8 +2058,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
<< "<"
str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle";
if(get_warp_size() != 64) {
str << "_WmmaPorted";
}
str << "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "

View File

@@ -400,9 +400,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
BComputeDataType>
{
using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
GET_NXDL_PER_WAVE_IMPL
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
GET_MXDL_PER_WAVE_IMPL
// Force usage of 16x16 instruction for WMMA
static constexpr index_t Wave32MaxMNPerXDL = 16;
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
static constexpr auto MXdlPerWave32 =
GetMXdlPerWave<false,
Wave32MaxMNPerXDL,
Wave32MaxMNPerXDL,
NXdlPerWave*(NPerXDL / Wave32MaxMNPerXDL)>();
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
@@ -563,7 +569,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
: BBlockTransferSrcScalarPerVector;
// Use appropriate gridwise gemm
template <index_t NXdlPerWave_>
template <index_t MXdlPerWave_, index_t MPerXDL_, index_t NPerXDL_>
using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3<
tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor,
@@ -585,10 +591,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave_,
MPerXDL_,
NPerXDL_,
MXdlPerWave_,
NXdlPerWave*(NPerXDL / NPerXDL_),
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
@@ -606,7 +612,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle*(NPerXDL / NPerXDL_),
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVectors,
BlkGemmPipeSched,
@@ -617,8 +623,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
BDataType,
DoElementwiseBeforeCShuffle,
DirectLoad>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(MXdlPerWave64, 1), MPerXDL, NPerXDL>;
using GridwiseGemm32 = GridwiseGemmBase<MXdlPerWave32, Wave32MaxMNPerXDL, Wave32MaxMNPerXDL>;
// #undef GridwiseGemmV3TemplateParams
@@ -1430,7 +1436,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return avg_time;
}
INVOKER_RUN_IMPL
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(get_warp_size() == 64)
{
if constexpr(MXdlPerWave64 > 0)
{
return RunImp<GridwiseGemm64>(arg, stream_config);
}
}
else
{
if constexpr(MXdlPerWave32 > 0)
{
return RunImp<GridwiseGemm32>(arg, stream_config);
}
}
return 0;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
@@ -1483,7 +1506,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
}
if(!ck::is_xdl_wmma_supported<AComputeDataType, BComputeDataType, MPerXDL, NPerXDL>())
if(!ck::is_xdl_wmma_supported<AComputeDataType,
BComputeDataType,
Wave32MaxMNPerXDL,
Wave32MaxMNPerXDL>())
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
@@ -1758,7 +1784,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
if(get_warp_size() == 64)
{
if constexpr(NXdlPerWave64 > 0)
if constexpr(MXdlPerWave64 > 0)
{
typename GridwiseGemm64::Argument gemm_arg{nullptr,
nullptr,
@@ -1780,7 +1806,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
else
{
if constexpr(NXdlPerWave32 > 0)
if constexpr(MXdlPerWave32 > 0)
{
typename GridwiseGemm32::Argument gemm_arg{nullptr,
nullptr,
@@ -2064,6 +2090,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// clang-format off
str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3";
if(get_warp_size() != 64) {
str << "_WmmaPorted";
}
if constexpr(DirectLoad) {
str << "_DirectLoad";
}

View File

@@ -206,9 +206,15 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
BComputeDataType>
{
using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor;
GET_NXDL_PER_WAVE_IMPL
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
GET_MXDL_PER_WAVE_IMPL
// Force usage of 16x16 instruction for WMMA
static constexpr index_t Wave32MaxMNPerXDL = 16;
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
static constexpr auto MXdlPerWave32 =
GetMXdlPerWave<false,
Wave32MaxMNPerXDL,
Wave32MaxMNPerXDL,
NXdlPerWave*(NPerXDL / Wave32MaxMNPerXDL)>();
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr index_t MaxGemmsNum = 32;
@@ -409,25 +415,26 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS \
ADataType, BDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \
NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \
BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL_, \
NPerXDL_, MXdlPerWave_, NXdlPerWave*(NPerXDL / NPerXDL_), \
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle*(NPerXDL / NPerXDL_), \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
AComputeDataType, DoElementwiseBeforeCShuffle
// Use appropriate gridwise gemm
template <index_t NXdlPerWave_>
template <index_t MXdlPerWave_, index_t MPerXDL_, index_t NPerXDL_>
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS>;
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(MXdlPerWave64, 1), MPerXDL, NPerXDL>;
using GridwiseGemm32 = GridwiseGemmBase<MXdlPerWave32, Wave32MaxMNPerXDL, Wave32MaxMNPerXDL>;
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 =
@@ -607,7 +614,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
if(get_warp_size() == 64)
{
if constexpr(NXdlPerWave64 > 0)
if constexpr(MXdlPerWave64 > 0)
{
init_gemm_args<GridwiseGemm64>(a_grid_ptrs[i],
static_cast<const BDataType*>(p_b),
@@ -624,7 +631,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
}
else
{
if constexpr(NXdlPerWave32 > 0)
if constexpr(MXdlPerWave32 > 0)
{
init_gemm_args<GridwiseGemm32>(a_grid_ptrs[i],
static_cast<const BDataType*>(p_b),
@@ -769,7 +776,24 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
}
}
INVOKER_RUN_IMPL
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(get_warp_size() == 64)
{
if constexpr(MXdlPerWave64 > 0)
{
return RunImp<GridwiseGemm64>(arg, stream_config);
}
}
else
{
if constexpr(MXdlPerWave32 > 0)
{
return RunImp<GridwiseGemm32>(arg, stream_config);
}
}
return 0;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
@@ -822,7 +846,10 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
return false;
}
}
if(!ck::is_xdl_wmma_supported<AComputeDataType, BComputeDataType, MPerXDL, NPerXDL>())
if(!ck::is_xdl_wmma_supported<AComputeDataType,
BComputeDataType,
Wave32MaxMNPerXDL,
Wave32MaxMNPerXDL>())
{
return false;
}
@@ -1205,8 +1232,12 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"
<< "<"
str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor";
if(get_warp_size() != 64) {
str << "_WmmaPorted";
}
str << "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "

View File

@@ -206,9 +206,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
MaskingSpec>
{
static constexpr auto MXdlPerWave64 =
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
static constexpr auto MXdlPerWave32 =
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");