MX GEMM - New GEMM pipeline for MX data types (#2059)

* Allow selection of mfma_scale instructions

* Read B tensor from LDS to VGPR in chunks of 16 in MFMA order

* Add constexpr and synchronize return type for `get_exponent_value`

* Pass scales by reference and add comments to `mfma_scale_f32_32x32x64`

* Add support for microscaling instructions in `XdlopsGemm`

* Fix `mfma_scale_f32_16x16x128f8f6f4` wrapper

* Remove software implementation of MX GEMM

* Make interface of `intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>` consistent with the other scale instruction

* Update README

* Updated CHANGELOG

* Remove unused static methods
This commit is contained in:
Andriy Roshchenko
2025-04-15 17:17:07 -06:00
committed by GitHub
parent d55c9cb313
commit 7106976a72
19 changed files with 1007 additions and 608 deletions

View File

@@ -159,16 +159,22 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma =
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma = false;
static constexpr auto is_scale_mfma = true;
//> KPack is at least the k_per_blk of selected mfma
//
// Should be a multiple of k_per_blk.
// TODO: Move this to blockwise pipeline base
static constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
selected_mfma.k_per_blk);
MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeB,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -1088,10 +1094,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
static_assert(KPerBlock % ScaleBlockSize == 0,
"KPerBlock should be multiple of ScaleBlockSize");
static_assert(KPerBlock / ScaleBlockSize == BlockwiseGemmPipe::KRepeat,
"Single call to xdlops_gemm::Run should process exactly ScaleBlockSize "
"elements in k dimension");
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
@@ -1476,61 +1478,63 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
static constexpr auto KPerThread = KPerBlock / K0PerXdlops;
// NXdlPerWave == NRepeat
// MXdlPerWave == MRepeat
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
// Initial thread mapping for MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MWaves=NWaves=2
// Initial thread mapping for:
// BlockSize = 256
// MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2
// For each [m0, n0] tile, there are 4 waves:
// tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0]
// tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1]
// tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0]
// tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1]
auto a_thread_offset_m =
MPerXdl * ((get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) / MWaves) +
mfma.selected_mfma.group_size *
((get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / MPerXdl);
auto a_thread_offset_k = KPerThread * (get_thread_local_1d_id() % MPerXdl) / MPerXdl;
// BlockSize = 128
// MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1
// For each [m0, n0] tile, there are 2 waves:
// tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0]
// tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0]
auto b_thread_offset_n =
get_thread_local_1d_id() % NPerXdl +
(get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) % NWaves * NPerXdl;
auto b_thread_offset_k = KPerThread * (get_thread_local_1d_id() % NPerXdl) / NPerXdl;
// TODO: Document initial thread mapping for more combinations of parameters
auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
AScaleDataType,
AScaleDataType,
decltype(a_scale_grid_desc_am_ak), // SrcDesc
decltype(BlockwiseGemmPipe::a_scale_thread_desc_group), // DstDesc
Sequence<mfma.selected_mfma.group_size, 1>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
0, // SrcVectorDim
1, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true>(a_scale_grid_desc_am_ak,
make_multi_index(block_m_id * MPerBlock + a_thread_offset_m,
a_thread_offset_k / ScaleBlockSize));
const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
BScaleDataType,
BScaleDataType,
decltype(b_scale_grid_desc_bn_ak),
decltype(BlockwiseGemmPipe::b_scale_thread_desc),
Sequence<1, BlockwiseGemmPipe::KRepeat>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
1, // SrcVectorDim
BlockwiseGemmPipe::KRepeat, // SrcScalarPerVector
1,
false>(b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock + b_thread_offset_n,
b_thread_offset_k / ScaleBlockSize));
static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) /
mfma.selected_mfma.num_threads_per_blk;
auto a_thread_offset_m = get_thread_local_1d_id() % MPerXdl + waveId_m * MPerXdl;
auto a_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<AScaleDataType,
AScaleDataType,
decltype(a_scale_grid_desc_am_ak),
decltype(BlockwiseGemmPipe::a_scale_thread_desc_copy),
Sequence<1, 1>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
1, // SrcVectorDim
1, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true>(
a_scale_grid_desc_am_ak,
make_multi_index(block_m_id * MPerBlock + a_thread_offset_m, thread_offset_k));
auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl;
auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleDataType,
BScaleDataType,
decltype(b_scale_grid_desc_bn_ak),
decltype(BlockwiseGemmPipe::b_scale_thread_desc_copy),
Sequence<1, 1>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
1, // SrcVectorDim
1, // SrcScalarPerVector
1,
true>(
b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,