mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Improve XDL to WMMA porting for grouped conv fwd (#3456)
Refactors the way the number of XDL (matrix multiply-accumulate) instructions per wave is calculated and used in the grouped convolution forward implementations, especially to better support WMMA (Wave Matrix Multiply-Accumulate) instructions and 16x16 tiles. The changes use MXdlPerWave instead of NXdlPerWave to increase number of waves per M dim.
This commit is contained in:
@@ -33,7 +33,7 @@ TEST(FwdConvInstances,
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_64_64x32x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x2_per_wave)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(FwdTransfer_4x16x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
|
||||
|
||||
@@ -28,7 +28,7 @@ TEST(FwdConvInstances,
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_128x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(FwdTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
|
||||
@@ -111,8 +111,8 @@ struct DefaultAlgorithm
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 16,
|
||||
.n_per_xdl = 16,
|
||||
.m_xdl_per_wave = 4,
|
||||
.n_xdl_per_wave = 4};
|
||||
.m_xdl_per_wave = 8,
|
||||
.n_xdl_per_wave = 8};
|
||||
|
||||
ckb::test::TransferABC transfer{
|
||||
.a =
|
||||
@@ -188,7 +188,7 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
|
||||
" ├─ Pipeline scheduler: INTRAWAVE\n"
|
||||
" ├─ Warp Gemm parameters: \n"
|
||||
" │ ├─ subtile size: 16×16\n"
|
||||
" │ └─ Number of warp gemm iterations: 4×4\n"
|
||||
" │ └─ Number of warp gemm iterations: 8×8\n"
|
||||
" └─ Memory access:\n"
|
||||
" ├─ A Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
|
||||
@@ -68,7 +68,7 @@ constexpr TransferABC FwdTransfer_4x64x1{
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.scalar_per_vector = 4},
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ template <index_t BlockSize_,
|
||||
index_t NPerXDL_,
|
||||
index_t MXdlPerWave_,
|
||||
bool IsWave64>
|
||||
static constexpr auto GetNXdlPerWave2()
|
||||
static constexpr auto GetXdlPerWave2()
|
||||
{
|
||||
constexpr index_t Waves = IsWave64 ? BlockSize_ / 64 : BlockSize_ / 32;
|
||||
constexpr index_t MWaves = MPerBlock_ / (MXdlPerWave_ * MPerXDL_);
|
||||
@@ -84,17 +84,33 @@ static constexpr auto GetNXdlPerWave2()
|
||||
}
|
||||
}
|
||||
|
||||
#define GET_NXDL_PER_WAVE_IMPL \
|
||||
template <bool IsWave64> \
|
||||
static constexpr auto GetNXdlPerWave() \
|
||||
{ \
|
||||
return GetNXdlPerWave2<BlockSize, \
|
||||
MPerBlock, \
|
||||
NPerBlock, \
|
||||
MPerXDL, \
|
||||
NPerXDL, \
|
||||
MXdlPerWave, \
|
||||
IsWave64>(); \
|
||||
#define GET_NXDL_PER_WAVE_IMPL \
|
||||
template <bool IsWave64> \
|
||||
static constexpr auto GetNXdlPerWave() \
|
||||
{ \
|
||||
return GetXdlPerWave2<BlockSize, \
|
||||
MPerBlock, \
|
||||
NPerBlock, \
|
||||
MPerXDL, \
|
||||
NPerXDL, \
|
||||
MXdlPerWave, \
|
||||
IsWave64>(); \
|
||||
}
|
||||
|
||||
#define GET_MXDL_PER_WAVE_IMPL \
|
||||
template <bool IsWave64, \
|
||||
index_t MPerXDLAligned = MPerXDL, \
|
||||
index_t NPerXDLAligned = NPerXDL, \
|
||||
index_t NXdlPerWaveAligned = NXdlPerWave> \
|
||||
static constexpr auto GetMXdlPerWave() \
|
||||
{ \
|
||||
return GetXdlPerWave2<BlockSize, \
|
||||
NPerBlock, \
|
||||
MPerBlock, \
|
||||
NPerXDLAligned, \
|
||||
MPerXDLAligned, \
|
||||
NXdlPerWaveAligned, \
|
||||
IsWave64>(); \
|
||||
}
|
||||
|
||||
template <index_t BlockSize_,
|
||||
@@ -114,14 +130,14 @@ static constexpr auto GetWarpTileConfig()
|
||||
|
||||
constexpr auto NXdlPerWave =
|
||||
IsWave64
|
||||
? GetNXdlPerWave2<BlockSize_,
|
||||
MPerBlock_,
|
||||
NPerBlock_,
|
||||
MPerXDL_,
|
||||
NPerXDL_,
|
||||
MXdlPerWave_,
|
||||
true>()
|
||||
: GetNXdlPerWave2<BlockSize_, MPerBlock_, NPerBlock_, 16, 16, MXdlPerWave32, false>();
|
||||
? GetXdlPerWave2<BlockSize_,
|
||||
MPerBlock_,
|
||||
NPerBlock_,
|
||||
MPerXDL_,
|
||||
NPerXDL_,
|
||||
MXdlPerWave_,
|
||||
true>()
|
||||
: GetXdlPerWave2<BlockSize_, MPerBlock_, NPerBlock_, 16, 16, MXdlPerWave32, false>();
|
||||
|
||||
if constexpr(IsWave64 == false && NXdlPerWave != 0)
|
||||
{
|
||||
|
||||
@@ -190,9 +190,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
using DeviceOp = DeviceBatchedGemmGemm_Xdl_CShuffle;
|
||||
|
||||
static constexpr auto MXdlPerWave64 =
|
||||
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
|
||||
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
|
||||
static constexpr auto MXdlPerWave32 =
|
||||
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
|
||||
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
@@ -235,20 +235,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
{
|
||||
using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle;
|
||||
|
||||
static constexpr auto Gemm0MXdlPerWave64 = GetNXdlPerWave2<BlockSize,
|
||||
Gemm0NPerBlock,
|
||||
Gemm0MPerBlock,
|
||||
Gemm0NPerXdl,
|
||||
Gemm0MPerXdl,
|
||||
Gemm0NXdlPerWave,
|
||||
true>();
|
||||
static constexpr auto Gemm0MXdlPerWave32 = GetNXdlPerWave2<BlockSize,
|
||||
Gemm0NPerBlock,
|
||||
Gemm0MPerBlock,
|
||||
Gemm0NPerXdl,
|
||||
Gemm0MPerXdl,
|
||||
Gemm0NXdlPerWave,
|
||||
false>();
|
||||
static constexpr auto Gemm0MXdlPerWave64 = GetXdlPerWave2<BlockSize,
|
||||
Gemm0NPerBlock,
|
||||
Gemm0MPerBlock,
|
||||
Gemm0NPerXdl,
|
||||
Gemm0MPerXdl,
|
||||
Gemm0NXdlPerWave,
|
||||
true>();
|
||||
static constexpr auto Gemm0MXdlPerWave32 = GetXdlPerWave2<BlockSize,
|
||||
Gemm0NPerBlock,
|
||||
Gemm0MPerBlock,
|
||||
Gemm0NPerXdl,
|
||||
Gemm0MPerXdl,
|
||||
Gemm0NXdlPerWave,
|
||||
false>();
|
||||
|
||||
static constexpr index_t NumD0Tensor = D0sDataType::Size();
|
||||
static constexpr index_t NumD1Tensor = D1sDataType::Size();
|
||||
|
||||
@@ -223,9 +223,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
MaskingSpec>
|
||||
{
|
||||
static constexpr auto MXdlPerWave64 =
|
||||
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
|
||||
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
|
||||
static constexpr auto MXdlPerWave32 =
|
||||
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
|
||||
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
|
||||
|
||||
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
|
||||
"Number of dimension must be greater than 0");
|
||||
|
||||
@@ -211,9 +211,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
|
||||
using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle;
|
||||
static constexpr auto MXdlPerWave64 =
|
||||
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
|
||||
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
|
||||
static constexpr auto MXdlPerWave32 =
|
||||
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
|
||||
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
@@ -325,9 +325,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
BComputeDataType>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
|
||||
GET_NXDL_PER_WAVE_IMPL
|
||||
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
|
||||
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
|
||||
GET_MXDL_PER_WAVE_IMPL
|
||||
// Force usage of 16x16 instruction for WMMA
|
||||
static constexpr index_t Wave32MaxMNPerXDL = 16;
|
||||
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
|
||||
static constexpr auto MXdlPerWave32 =
|
||||
GetMXdlPerWave<false,
|
||||
Wave32MaxMNPerXDL,
|
||||
Wave32MaxMNPerXDL,
|
||||
NXdlPerWave*(NPerXDL / Wave32MaxMNPerXDL)>();
|
||||
|
||||
static_assert(NumGroupsToMerge >= 1);
|
||||
|
||||
@@ -486,35 +492,36 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
using GemmADataType = std::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
|
||||
using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
|
||||
|
||||
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \
|
||||
GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
|
||||
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
|
||||
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
|
||||
KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, \
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
|
||||
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
|
||||
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
|
||||
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
|
||||
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
|
||||
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
|
||||
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
|
||||
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \
|
||||
GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
|
||||
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
|
||||
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
|
||||
KPerBlock, AK1, BK1, MPerXDL_, NPerXDL_, MXdlPerWave_, NXdlPerWave*(NPerXDL / NPerXDL_), \
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
|
||||
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
|
||||
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
|
||||
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
|
||||
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
|
||||
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
|
||||
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
|
||||
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle*(NPerXDL / NPerXDL_), \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
BComputeDataType
|
||||
|
||||
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \
|
||||
GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
|
||||
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
|
||||
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \
|
||||
NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
|
||||
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
|
||||
ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
|
||||
ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \
|
||||
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \
|
||||
BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \
|
||||
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL_, \
|
||||
NPerXDL_, MXdlPerWave_, NXdlPerWave*(NPerXDL / NPerXDL_), \
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
|
||||
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
|
||||
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
|
||||
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
|
||||
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
|
||||
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
|
||||
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
|
||||
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle*(NPerXDL / NPerXDL_), \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
BComputeDataType, DoElementwiseBeforeCShuffle
|
||||
@@ -523,7 +530,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
|
||||
EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \
|
||||
NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \
|
||||
MPerXDL, NXdlPerWave_, MXdlPerWave, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
|
||||
MPerXDL, NXdlPerWave, MXdlPerWave_, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
|
||||
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
|
||||
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
|
||||
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
|
||||
@@ -536,34 +543,35 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
BComputeDataType, DoElementwiseBeforeCShuffle
|
||||
|
||||
// Use appropriate gridwise gemm
|
||||
template <index_t NXdlPerWave_>
|
||||
template <index_t MXdlPerWave_, index_t MPerXDL_, index_t NPerXDL_>
|
||||
using GridwiseGemmMultipleABDBase = GridwiseGemmMultipleABD_xdl_cshuffle<
|
||||
CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>;
|
||||
template <index_t NXdlPerWave_>
|
||||
template <index_t MXdlPerWave_, index_t MPerXDL_, index_t NPerXDL_>
|
||||
using GridwiseGemmMultipleDBase = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>;
|
||||
template <index_t NXdlPerWave_>
|
||||
template <index_t MXdlPerWave_>
|
||||
using GridwiseGemmMultipleDCTransposeBase = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>;
|
||||
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS
|
||||
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS
|
||||
#undef CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS
|
||||
|
||||
using GridwiseGemm64 =
|
||||
std::conditional_t<isMultiA || isMultiB,
|
||||
GridwiseGemmMultipleABDBase<math::max(NXdlPerWave64, 1)>,
|
||||
GridwiseGemmMultipleDBase<math::max(NXdlPerWave64, 1)>>;
|
||||
using GridwiseGemm32 = std::conditional_t<isMultiA || isMultiB,
|
||||
GridwiseGemmMultipleABDBase<NXdlPerWave32>,
|
||||
GridwiseGemmMultipleDBase<NXdlPerWave32>>;
|
||||
using GridwiseGemm64 = std::conditional_t<
|
||||
isMultiA || isMultiB,
|
||||
GridwiseGemmMultipleABDBase<math::max(MXdlPerWave64, 1), MPerXDL, NPerXDL>,
|
||||
GridwiseGemmMultipleDBase<math::max(MXdlPerWave64, 1), MPerXDL, NPerXDL>>;
|
||||
using GridwiseGemm32 = std::conditional_t<
|
||||
isMultiA || isMultiB,
|
||||
GridwiseGemmMultipleABDBase<MXdlPerWave32, Wave32MaxMNPerXDL, Wave32MaxMNPerXDL>,
|
||||
GridwiseGemmMultipleDBase<MXdlPerWave32, Wave32MaxMNPerXDL, Wave32MaxMNPerXDL>>;
|
||||
|
||||
using GridwiseGemmCTranspose64 =
|
||||
std::conditional_t<CTranspose,
|
||||
GridwiseGemmMultipleDCTransposeBase<math::max(NXdlPerWave64, 1)>,
|
||||
GridwiseGemmMultipleDCTransposeBase<math::max(MXdlPerWave64, 1)>,
|
||||
GridwiseGemm64>;
|
||||
using GridwiseGemmCTranspose32 =
|
||||
std::conditional_t<CTranspose,
|
||||
GridwiseGemmMultipleDCTransposeBase<NXdlPerWave32>,
|
||||
GridwiseGemmMultipleDCTransposeBase<MXdlPerWave32>,
|
||||
GridwiseGemm32>;
|
||||
|
||||
// If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers.
|
||||
@@ -913,14 +921,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
if constexpr(MXdlPerWave64 > 0)
|
||||
{
|
||||
InitGridDesc<GridwiseGemm64, GridwiseGemmCTranspose64>();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(NXdlPerWave32 > 0)
|
||||
if constexpr(MXdlPerWave32 > 0)
|
||||
{
|
||||
InitGridDesc<GridwiseGemm32, GridwiseGemmCTranspose32>();
|
||||
}
|
||||
@@ -1388,7 +1396,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
if constexpr(MXdlPerWave64 > 0)
|
||||
{
|
||||
return RunImp<GridwiseGemm64, GridwiseGemmCTranspose64>(arg, stream_config);
|
||||
}
|
||||
@@ -1399,7 +1407,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(NXdlPerWave32 > 0)
|
||||
if constexpr(MXdlPerWave32 > 0)
|
||||
{
|
||||
return RunImp<GridwiseGemm32, GridwiseGemmCTranspose32>(arg, stream_config);
|
||||
}
|
||||
@@ -1436,7 +1444,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
if(!ck::is_xdl_wmma_supported<AComputeDataType, BComputeDataType, MPerXDL, NPerXDL>())
|
||||
if(!ck::is_xdl_wmma_supported<AComputeDataType,
|
||||
BComputeDataType,
|
||||
Wave32MaxMNPerXDL,
|
||||
Wave32MaxMNPerXDL>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -1720,7 +1731,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// check Gridwise GEMM
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
if constexpr(MXdlPerWave64 > 0)
|
||||
{
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
{
|
||||
@@ -1759,7 +1770,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
else
|
||||
{
|
||||
|
||||
if constexpr(NXdlPerWave32 > 0)
|
||||
if constexpr(MXdlPerWave32 > 0)
|
||||
{
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
{
|
||||
@@ -2047,8 +2058,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
|
||||
<< "<"
|
||||
str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle";
|
||||
|
||||
if(get_warp_size() != 64) {
|
||||
str << "_WmmaPorted";
|
||||
}
|
||||
|
||||
str << "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
|
||||
@@ -400,9 +400,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
BComputeDataType>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
|
||||
GET_NXDL_PER_WAVE_IMPL
|
||||
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
|
||||
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
|
||||
GET_MXDL_PER_WAVE_IMPL
|
||||
// Force usage of 16x16 instruction for WMMA
|
||||
static constexpr index_t Wave32MaxMNPerXDL = 16;
|
||||
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
|
||||
static constexpr auto MXdlPerWave32 =
|
||||
GetMXdlPerWave<false,
|
||||
Wave32MaxMNPerXDL,
|
||||
Wave32MaxMNPerXDL,
|
||||
NXdlPerWave*(NPerXDL / Wave32MaxMNPerXDL)>();
|
||||
|
||||
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
|
||||
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
|
||||
@@ -563,7 +569,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
: BBlockTransferSrcScalarPerVector;
|
||||
|
||||
// Use appropriate gridwise gemm
|
||||
template <index_t NXdlPerWave_>
|
||||
template <index_t MXdlPerWave_, index_t MPerXDL_, index_t NPerXDL_>
|
||||
using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3<
|
||||
tensor_layout::gemm::RowMajor,
|
||||
tensor_layout::gemm::ColumnMajor,
|
||||
@@ -585,10 +591,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave_,
|
||||
MPerXDL_,
|
||||
NPerXDL_,
|
||||
MXdlPerWave_,
|
||||
NXdlPerWave*(NPerXDL / NPerXDL_),
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -606,7 +612,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle*(NPerXDL / NPerXDL_),
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
@@ -617,8 +623,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
BDataType,
|
||||
DoElementwiseBeforeCShuffle,
|
||||
DirectLoad>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(MXdlPerWave64, 1), MPerXDL, NPerXDL>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<MXdlPerWave32, Wave32MaxMNPerXDL, Wave32MaxMNPerXDL>;
|
||||
|
||||
// #undef GridwiseGemmV3TemplateParams
|
||||
|
||||
@@ -1430,7 +1436,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
return avg_time;
|
||||
}
|
||||
|
||||
INVOKER_RUN_IMPL
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(MXdlPerWave64 > 0)
|
||||
{
|
||||
return RunImp<GridwiseGemm64>(arg, stream_config);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(MXdlPerWave32 > 0)
|
||||
{
|
||||
return RunImp<GridwiseGemm32>(arg, stream_config);
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
@@ -1483,7 +1506,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
}
|
||||
}
|
||||
|
||||
if(!ck::is_xdl_wmma_supported<AComputeDataType, BComputeDataType, MPerXDL, NPerXDL>())
|
||||
if(!ck::is_xdl_wmma_supported<AComputeDataType,
|
||||
BComputeDataType,
|
||||
Wave32MaxMNPerXDL,
|
||||
Wave32MaxMNPerXDL>())
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
@@ -1758,7 +1784,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
if constexpr(MXdlPerWave64 > 0)
|
||||
{
|
||||
typename GridwiseGemm64::Argument gemm_arg{nullptr,
|
||||
nullptr,
|
||||
@@ -1780,7 +1806,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(NXdlPerWave32 > 0)
|
||||
if constexpr(MXdlPerWave32 > 0)
|
||||
{
|
||||
typename GridwiseGemm32::Argument gemm_arg{nullptr,
|
||||
nullptr,
|
||||
@@ -2064,6 +2090,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
// clang-format off
|
||||
str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3";
|
||||
|
||||
if(get_warp_size() != 64) {
|
||||
str << "_WmmaPorted";
|
||||
}
|
||||
|
||||
if constexpr(DirectLoad) {
|
||||
str << "_DirectLoad";
|
||||
}
|
||||
|
||||
@@ -206,9 +206,15 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
BComputeDataType>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor;
|
||||
GET_NXDL_PER_WAVE_IMPL
|
||||
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
|
||||
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
|
||||
GET_MXDL_PER_WAVE_IMPL
|
||||
// Force usage of 16x16 instruction for WMMA
|
||||
static constexpr index_t Wave32MaxMNPerXDL = 16;
|
||||
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
|
||||
static constexpr auto MXdlPerWave32 =
|
||||
GetMXdlPerWave<false,
|
||||
Wave32MaxMNPerXDL,
|
||||
Wave32MaxMNPerXDL,
|
||||
NXdlPerWave*(NPerXDL / Wave32MaxMNPerXDL)>();
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
static constexpr index_t MaxGemmsNum = 32;
|
||||
@@ -409,25 +415,26 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS \
|
||||
ADataType, BDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
|
||||
AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
|
||||
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \
|
||||
NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
|
||||
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
|
||||
ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
|
||||
ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \
|
||||
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \
|
||||
BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \
|
||||
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL_, \
|
||||
NPerXDL_, MXdlPerWave_, NXdlPerWave*(NPerXDL / NPerXDL_), \
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
|
||||
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
|
||||
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
|
||||
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
|
||||
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
|
||||
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
|
||||
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
|
||||
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle*(NPerXDL / NPerXDL_), \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
AComputeDataType, DoElementwiseBeforeCShuffle
|
||||
// Use appropriate gridwise gemm
|
||||
template <index_t NXdlPerWave_>
|
||||
template <index_t MXdlPerWave_, index_t MPerXDL_, index_t NPerXDL_>
|
||||
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS>;
|
||||
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(MXdlPerWave64, 1), MPerXDL, NPerXDL>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<MXdlPerWave32, Wave32MaxMNPerXDL, Wave32MaxMNPerXDL>;
|
||||
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
@@ -607,7 +614,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
if constexpr(MXdlPerWave64 > 0)
|
||||
{
|
||||
init_gemm_args<GridwiseGemm64>(a_grid_ptrs[i],
|
||||
static_cast<const BDataType*>(p_b),
|
||||
@@ -624,7 +631,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(NXdlPerWave32 > 0)
|
||||
if constexpr(MXdlPerWave32 > 0)
|
||||
{
|
||||
init_gemm_args<GridwiseGemm32>(a_grid_ptrs[i],
|
||||
static_cast<const BDataType*>(p_b),
|
||||
@@ -769,7 +776,24 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
}
|
||||
}
|
||||
|
||||
INVOKER_RUN_IMPL
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(MXdlPerWave64 > 0)
|
||||
{
|
||||
return RunImp<GridwiseGemm64>(arg, stream_config);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(MXdlPerWave32 > 0)
|
||||
{
|
||||
return RunImp<GridwiseGemm32>(arg, stream_config);
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
@@ -822,7 +846,10 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if(!ck::is_xdl_wmma_supported<AComputeDataType, BComputeDataType, MPerXDL, NPerXDL>())
|
||||
if(!ck::is_xdl_wmma_supported<AComputeDataType,
|
||||
BComputeDataType,
|
||||
Wave32MaxMNPerXDL,
|
||||
Wave32MaxMNPerXDL>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -1205,8 +1232,12 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"
|
||||
<< "<"
|
||||
str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor";
|
||||
if(get_warp_size() != 64) {
|
||||
str << "_WmmaPorted";
|
||||
}
|
||||
|
||||
str << "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
|
||||
@@ -206,9 +206,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
MaskingSpec>
|
||||
{
|
||||
static constexpr auto MXdlPerWave64 =
|
||||
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
|
||||
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
|
||||
static constexpr auto MXdlPerWave32 =
|
||||
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
|
||||
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
|
||||
|
||||
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
|
||||
"Number of dimension must be greater than 0");
|
||||
|
||||
Reference in New Issue
Block a user