MX GEMM - New GEMM pipeline for MX data types (#2059)

* Allow selection of mfma_scale instructions

* Read B tensor from LDS to VGPR in chunks of 16 in MFMA order

* Add constexpr and synchronize return type for `get_exponent_value`

* Pass scales by reference and add comments to `mfma_scale_f32_32x32x64`

* Add support for microscaling instructions in `XdlopsGemm`

* Fix `mfma_scale_f32_16x16x128f8f6f4` wrapper

* Remove software implementation of MX GEMM

* Make interface of `intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>` consistent with the other scale instruction

* Update README

* Update CHANGELOG

* Remove unused static methods
This commit is contained in:
Andriy Roshchenko
2025-04-15 17:17:07 -06:00
committed by GitHub
parent d55c9cb313
commit 7106976a72
19 changed files with 1007 additions and 608 deletions

View File

@@ -30,48 +30,69 @@ enum class MFMA_F8F6F4
};
template <typename AFragT, typename BFragT, typename AccumFragT, int32_t BLOCK_M, int32_t BLOCK_N>
template <int32_t BLOCK_M, int32_t BLOCK_N>
struct mfma_type_selector;
template <typename AFragT, typename BFragT, typename AccumFragT>
struct mfma_type_selector<AFragT, BFragT, AccumFragT, 16, 16>
template <>
struct mfma_type_selector<16, 16>
{
__device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
template <typename AFragT, typename BFragT, typename AccumFragT>
__device__ static void run(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
{
auto op = mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>{};
op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
}
__device__ void operator()(AFragT const& fragA,
const int32_t scale_a,
BFragT const& fragB,
const int32_t scale_b,
AccumFragT& fragAcc)
{
auto op = mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>{};
op.template run<16, 16, AFragT, BFragT, AccumFragT>(
fragA, scale_a, fragB, scale_b, fragAcc);
op.template run<16, 16>(fragA, fragB, fragAcc);
}
};
template <typename AFragT, typename BFragT, typename AccumFragT>
struct mfma_type_selector<AFragT, BFragT, AccumFragT, 32, 32>
template <>
struct mfma_type_selector<32, 32>
{
__device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
template <typename AFragT, typename BFragT, typename AccumFragT>
__device__ static void run(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
{
auto op = mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>{};
op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
op.template run<32, 32>(fragA, fragB, fragAcc);
}
};
__device__ void operator()(AFragT const& fragA,
const int32_t scale_a,
// Primary template for picking a scaled-MFMA wrapper by tile shape (BLOCK_M x BLOCK_N).
// Declared here without a definition; the <16, 16> and <32, 32> specializations that
// follow each provide a static `run` taking A/B fragments, their scale fragments and
// the accumulator fragment.
template <int32_t BLOCK_M, int32_t BLOCK_N>
struct mfma_scale_type_selector;
template <>
struct mfma_scale_type_selector<16, 16>
{
template <typename AFragT,
typename AScaleFragT,
typename BFragT,
typename BScaleFragT,
typename AccumFragT>
__device__ static void run(AFragT const& fragA,
AScaleFragT const& scale_a,
BFragT const& fragB,
const int32_t scale_b,
BScaleFragT const& scale_b,
AccumFragT& fragAcc)
{
auto op = mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>{};
op.template run<16, 16>(fragA, scale_a[Number<0>{}], fragB, scale_b[Number<0>{}], fragAcc);
}
};
template <>
struct mfma_scale_type_selector<32, 32>
{
template <typename AFragT,
typename AScaleFragT,
typename BFragT,
typename BScaleFragT,
typename AccumFragT>
__device__ static void run(AFragT const& fragA,
AScaleFragT const& scale_a,
BFragT const& fragB,
BScaleFragT const& scale_b,
AccumFragT& fragAcc)
{
auto op = mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>{};
op.template run<32, 32, AFragT, BFragT, AccumFragT>(
fragA, scale_a, fragB, scale_b, fragAcc);
op.template run<32, 32>(fragA, scale_a[Number<0>{}], fragB, scale_b[Number<0>{}], fragAcc);
}
};
@@ -334,8 +355,7 @@ __device__ AFragT load_mx_A_row_major(AType const* input_ptr,
// BLOCK_K / BLOCK_X is a stride in xA matrix
auto startOffset = row_major(startCoord2D, BLOCK_K / BLOCK_X);
// obtain 8-bit exponent
fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF;
fragX = scale_ptr[startOffset];
return load_A_row_major<AType, AFragT, BLOCK_M, BLOCK_K>(input_ptr);
}
@@ -502,7 +522,7 @@ __device__ BFragT load_mx_B_col_major(BType const* input_ptr,
auto startOffset = col_major(startCoord2D, BLOCK_K / BLOCK_X);
// obtain 8-bit exponent
fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF;
fragX = scale_ptr[startOffset];
return load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(input_ptr);
}
@@ -773,7 +793,8 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
// Matrix multiply-accumulate using MFMA units
// Accumulation intermediate = BLOCK_M x BLOCK_N
mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(fragA, fragB, fragAcc);
using mfma = mfma_type_selector<BLOCK_M, BLOCK_N>;
mfma::template run<>(fragA, fragB, fragAcc);
for(int i = 0; i < vectorSize(fragC); ++i)
{
@@ -805,29 +826,34 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using ScaleFragT = int32_t;
using AScaleFragT = vector_type<ScaleType, 1>::type;
using BScaleFragT = vector_type<ScaleType, 1>::type;
// Create frags
auto fragA = AFragT{};
auto fragB = BFragT{};
auto fragC = CFragT{};
auto fragAcc = AccumFragT{0};
auto fragXa = ScaleFragT{0};
auto fragXb = ScaleFragT{0};
auto fragXa = AScaleFragT{};
auto fragXb = BScaleFragT{};
// Load the inputs.
// A = row major, BLOCK_M x BLOCK_K
fragA = load_mx_A_row_major<AType, AFragT, ScaleType, ScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
fragA = load_mx_A_row_major<AType, AFragT, ScaleType, AScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
a, xa, fragXa);
// B = col major, BLOCK_K x BLOCK_N
fragB = load_mx_B_col_major<BType, BFragT, ScaleType, ScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
fragB = load_mx_B_col_major<BType, BFragT, ScaleType, BScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
b, xb, fragXb);
// Scaled Matrix multiply-accumulate using MFMA units
// Accumulation intermediate = BLOCK_M x BLOCK_N
mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(
fragA, fragXa, fragB, fragXb, fragAcc);
using mfma = mfma_scale_type_selector<BLOCK_M, BLOCK_N>;
mfma::template run<>(fragA,
fragXa.template AsType<ScaleType>(),
fragB,
fragXb.template AsType<ScaleType>(),
fragAcc);
for(int i = 0; i < vectorSize(fragC); ++i)
{