[ck][gfx12] support contraction on gfx12 (#3421)

* support contraction on gfx12

* increase tolerance for gfx11 in example contraction

the precsion of gfx11 wmma is less than others.
This commit is contained in:
linqunAMD
2025-12-15 23:16:01 +08:00
committed by GitHub
parent 6d7299ff78
commit 7e93eed878
5 changed files with 125 additions and 19 deletions

View File

@@ -166,11 +166,27 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
{
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
GET_NXDL_PER_WAVE_IMPL
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto WarpTileConfig64 = GetWarpTileConfig<BlockSize,
MPerBlock,
NPerBlock,
MPerXDL,
NPerXDL,
MXdlPerWave,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
true>();
static constexpr auto WarpTileConfig32 = GetWarpTileConfig<BlockSize,
MPerBlock,
NPerBlock,
MPerXDL,
NPerXDL,
MXdlPerWave,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
false>();
static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3);
static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3);
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -321,7 +337,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
// GridwiseGemm
template <index_t NXdlPerWave_>
template <typename WarpTileConfig>
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
BDataType,
@@ -340,10 +356,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave_,
WarpTileConfig::At(0),
WarpTileConfig::At(1),
WarpTileConfig::At(2),
WarpTileConfig::At(3),
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
@@ -360,13 +376,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
WarpTileConfig::At(4),
WarpTileConfig::At(5),
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
using GridwiseGemm64 = GridwiseGemmBase<decltype(WarpTileConfig64)>;
using GridwiseGemm32 = GridwiseGemmBase<decltype(WarpTileConfig32)>;
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 =
@@ -588,7 +604,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_wmma_supported<ComputeDataType, ComputeDataType, MPerXDL, NPerXDL>())
if(!ck::is_xdl_wmma_supported<ComputeDataType,
ComputeDataType,
MPerXDL,
NPerXDL,
WarpTileConfig32.At(0),
WarpTileConfig32.At(1)>())
{
return false;
}
@@ -783,6 +804,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< AK1 << ", "
<< BK1 << ", "
<< ABlockTransferSrcVectorDim << ", "