MX GEMM - New GEMM pipeline for MX data types (#2059)

* Allow selection of mfma_scale instructions

* Read B tensor from LDS to VGPR in chunks of 16 in MFMA order

* Add constexpr and synchronize return type for `get_exponent_value`

* Pass scales by reference and add comments to `mfma_scale_f32_32x32x64`

* Add support for microscaling instructions in `XdlopsGemm`

* Fix `mfma_scale_f32_16x16x128f8f6f4` wrapper

* Remove software implementation of MX GEMM

* Make interface of `intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>` consistent with the other scale instruction

* Update README

* Update CHANGELOG

* Remove unused static methods

Author: Andriy Roshchenko (committed by GitHub)
Date: 2025-04-15 17:17:07 -06:00
Parent: d55c9cb313
Commit: 7106976a72
19 changed files with 1007 additions and 608 deletions

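For orientation before the diff: the new `is_scale_mfma` flag threads from `XdlopsGemm` through `MfmaSelector::GetMfma`, and when set it resolves to the scaled MFMA instructions instead of the plain f8 ones. A minimal sketch of the resulting selection, using only names that appear in the hunks below (the instantiation itself is illustrative, not taken from the PR):

    // Sketch only -- illustrates the selection logic added in this diff.
    //
    //   GetMfma<f8_t, 32, 32>()                    -> MfmaInstr::mfma_f32_32x32x16f8f8         (existing path)
    //   GetMfma<f8_t, 32, 32, f8_t, false, true>() -> MfmaInstr::mfma_scale_f32_32x32x64f8f6f4  (new MX path)
    //   GetMfma<f8_t, 16, 16, f8_t, false, true>() -> MfmaInstr::mfma_scale_f32_16x16x128f8f6f4 (new MX path)

    using ScaleSelector = MfmaSelector<f8_t,   // base_type
                                       32,     // MPerXdlops
                                       32,     // NPerXdlops
                                       f8_t,   // additional_type
                                       false,  // is_single_rate_mfma
                                       true>;  // is_scale_mfma (new in this PR)

    // selected_mfma is now built from the six-parameter GetMfma, so this picks the
    // mfma_scale_f32_32x32x64f8f6f4 wrapper shown in the first hunk.
    constexpr auto scale_mfma = ScaleSelector::selected_mfma;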

@@ -845,15 +845,24 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
     static constexpr bool is_k_reduction = true; // ???
     // clang-format on

-    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    template <index_t MPerXdlops,
+              index_t NPerXdlops,
+              class FloatA,
+              class ScaleA,
+              class FloatB,
+              class ScaleB,
+              class FloatC>
     __device__ void run(const FloatA& a,
-                        const int32_t scale_a,
+                        const ScaleA& scale_a,
                         const FloatB& b,
-                        const int32_t scale_b,
+                        const ScaleB& scale_b,
                         FloatC& reg_c) const
     {
+        static_assert(scalar_type<ScaleA>::vector_size == 1, "Expect single scale at this point.");
+        static_assert(scalar_type<ScaleB>::vector_size == 1, "Expect single scale at this point.");
         intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(
-            a, scale_a, b, scale_b, reg_c);
+            a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c);
     }
 };
@@ -874,15 +883,24 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
     static constexpr bool is_k_reduction = true; // ???
     // clang-format on

-    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    template <index_t MPerXdlops,
+              index_t NPerXdlops,
+              class FloatA,
+              class ScaleA,
+              class FloatB,
+              class ScaleB,
+              class FloatC>
     __device__ void run(const FloatA& a,
-                        const int32_t scale_a,
+                        const ScaleA& scale_a,
                         const FloatB& b,
-                        const int32_t scale_b,
+                        const ScaleB& scale_b,
                         FloatC& reg_c) const
     {
+        static_assert(scalar_type<ScaleA>::vector_size == 1, "Expect single scale at this point.");
+        static_assert(scalar_type<ScaleB>::vector_size == 1, "Expect single scale at this point.");
         intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(
-            a, scale_a, b, scale_b, reg_c);
+            a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c);
     }
 };
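Both wrappers above now take the scales by reference and only convert them at the intrinsic boundary via `utils::get_exponent_value`. Under the usual MX convention a block scale is an E8M0 value, i.e. a single biased power-of-two exponent byte; a standalone illustration of that idea (an assumption for clarity, not CK's actual e8m0 type or helper):

    #include <cstdint>

    // Illustration only: an E8M0 scale stores just a biased exponent, value = 2^(bits - 127).
    struct e8m0_scale_illustration
    {
        uint8_t bits;
    };

    // A get_exponent_value-style helper would hand the raw exponent bits to the
    // instruction wrapper as an integer operand.
    constexpr int32_t get_exponent_value_illustration(e8m0_scale_illustration s)
    {
        return static_cast<int32_t>(s.bits);
    }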
@@ -890,14 +908,16 @@ template <typename base_type,
           index_t MPerXdlops,
           index_t NPerXdlops,
           typename additional_type = base_type,
-          bool is_single_rate_mfma = false>
+          bool is_single_rate_mfma = false,
+          bool is_scale_mfma = false>
 struct MfmaSelector
 {
     template <typename base_type_,
               index_t MPerXdlops_,
               index_t NPerXdlops_,
               typename additional_type_ = base_type_,
-              bool is_single_rate_mfma_ = false>
+              bool is_single_rate_mfma_ = false,
+              bool is_scale_mfma_ = false>
     static constexpr auto GetMfma();

     template <>
@@ -1103,12 +1123,24 @@ struct MfmaSelector
         return MfmaInstr::mfma_f32_32x32x16f8f8;
     }

+    template <>
+    constexpr auto GetMfma<f8_t, 32, 32, f8_t, false, true>()
+    {
+        return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
+    }
+
     template <>
     constexpr auto GetMfma<f8_t, 16, 16>()
     {
         return MfmaInstr::mfma_f32_16x16x32f8f8;
     }

+    template <>
+    constexpr auto GetMfma<f8_t, 16, 16, f8_t, false, true>()
+    {
+        return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
+    }
+
     template <>
     constexpr auto GetMfma<bf8_t, 32, 32>()
     {
@@ -1145,8 +1177,12 @@ struct MfmaSelector
         return MfmaInstr::mfma_f32_16x16x32bf8f8;
     }

-    static constexpr auto selected_mfma = mfma_type<
-        GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type, is_single_rate_mfma>()>{};
+    static constexpr auto selected_mfma = mfma_type<GetMfma<base_type,
+                                                            MPerXdlops,
+                                                            NPerXdlops,
+                                                            additional_type,
+                                                            is_single_rate_mfma,
+                                                            is_scale_mfma>()>{};

     __host__ __device__ constexpr MfmaSelector()
     {
@@ -1194,7 +1230,8 @@ template <typename base_type,
           index_t NPerXdlops,
           index_t KPack,
           typename additional_type = base_type,
-          bool TransposeC = false>
+          bool TransposeC = false,
+          bool is_scale_mfma = false>
 struct XdlopsGemm
 {
     static constexpr auto I0 = Number<0>{};
@@ -1225,7 +1262,7 @@ struct XdlopsGemm
                           MPerXdlops == 64,
                       "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");

-        static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
+        static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk");
     }

     // XDL output supporting C = A * B
@@ -1368,6 +1405,27 @@ struct XdlopsGemm
         });
     }

+    template <class FloatA, class ScaleA, class FloatB, class ScaleB, class FloatC>
+    __device__ void Run(const FloatA& p_a_wave,
+                        const ScaleA& a_scale_thread,
+                        const FloatB& p_b_wave,
+                        const ScaleB& b_scale_thread,
+                        FloatC& p_c_thread) const
+    {
+        static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
+            if constexpr(!TransposeC)
+            {
+                mfma_instr.template run<MPerXdlops, NPerXdlops>(
+                    p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread);
+            }
+            else
+            {
+                mfma_instr.template run<MPerXdlops, NPerXdlops>(
+                    p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread);
+            }
+        });
+    }
+
     __device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; }

     __device__ static auto GetBlkIdx()
@@ -1455,7 +1513,8 @@ struct XdlopsGemm
                              KPack <= 4) ||
                             (is_same<base_type, int8_t>::value && KPack <= 8))
                                 ? true
-                                : false > {};
+                                : false,
+                            is_scale_mfma > {};

     static constexpr auto mfma_instr = mfma.selected_mfma;
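
Taken together, the hunks above expose the scaled path through the new `Run` overload. A rough caller-side sketch, with hypothetical thread-buffer names and an illustrative 32x32 f8 configuration (the real plumbing lives in the blockwise MX GEMM pipeline contained in the other changed files of this PR):

    // Sketch only -- buffer types are placeholders for whatever the blockwise
    // pipeline stages into registers.
    template <index_t KPack, class AVecs, class AScales, class BVecs, class BScales, class CBuf>
    __device__ void run_scaled_xdlops(const AVecs& a_thread_vecs,
                                      const AScales& a_scale_thread_vecs,
                                      const BVecs& b_thread_vecs,
                                      const BScales& b_scale_thread_vecs,
                                      CBuf& c_thread_buf)
    {
        // is_scale_mfma = true routes MfmaSelector to mfma_scale_f32_32x32x64f8f6f4.
        constexpr auto xdlops_gemm =
            XdlopsGemm<f8_t, 32, 32, KPack, f8_t, /*TransposeC=*/false, /*is_scale_mfma=*/true>{};

        // New overload from this PR: one scale element accompanies each k_per_blk chunk,
        // and mfma_type::run converts it with utils::get_exponent_value before the intrinsic.
        xdlops_gemm.Run(a_thread_vecs,
                        a_scale_thread_vecs,
                        b_thread_vecs,
                        b_scale_thread_vecs,
                        c_thread_buf);
    }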