Optimized GEMMs for MX FP4/8 (#2294)

Adds V3 GEMM pipeline for MX FP4 and MX FP8 
Adds V3 GEMM pipeline for MX FP4 with preshuffling
Adds MXFP4 GEMM tests (#2275)
Adds MXFP4 GEMM examples
Adds MXFP4 GEMMs to ckProfiler




Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>
Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: OscarXu <huaiguxu@amd.com>
Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: Ding, Yi <yi.ding@amd.com>
Co-authored-by: feifei14119 <feiw@amd.com>
Co-authored-by: Lin, Qun <qlin@amd.com>
Co-authored-by: joye <joye@amd.com>
Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
This commit is contained in:
Andriy Roshchenko
2025-06-05 13:54:15 -06:00
committed by GitHub
parent 233e274077
commit 00247e3c29
83 changed files with 8193 additions and 2165 deletions

View File

@@ -8,6 +8,35 @@
#include "ck/utility/amd_xdlops.hpp"
namespace ck {
/**
* @brief Define matrix data types that have hardware support for MX GEMMs
*/
template <typename T>
static constexpr bool is_scale_mfma_data_type()
{
using U = element_type_t<T>;
return is_same_v<U, f8_ocp_t> || is_same_v<U, bf8_ocp_t> || is_same_v<U, f6_t> ||
is_same_v<U, bf6_t> || is_same_v<U, f4_t>;
}
/**
* @brief Define scale data types that have hardware support for MX GEMMs
*/
template <typename T>
static constexpr bool is_scale_mfma_scale_type()
{
return is_same_v<T, e8m0_bexp_t>;
}
/**
* @brief Combination of data types that have hardware support for MX GEMMs
*/
template <typename ADataType, typename BDataType, typename AScaleDataType, typename BScaleDataType>
static constexpr bool scale_mfma_hw_support()
{
return is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>() &&
is_scale_mfma_scale_type<AScaleDataType>() && is_scale_mfma_scale_type<BScaleDataType>();
}
enum struct MfmaInstr
{
@@ -847,6 +876,8 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t OpselA,
index_t OpselB,
class FloatA,
class ScaleA,
class FloatB,
@@ -858,11 +889,9 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
const ScaleB& scale_b,
FloatC& reg_c) const
{
static_assert(scalar_type<ScaleA>::vector_size == 1, "Expect single scale at this point.");
static_assert(scalar_type<ScaleB>::vector_size == 1, "Expect single scale at this point.");
intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(
a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c);
intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops, OpselA, OpselB>::Run(
a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
}
};
@@ -885,6 +914,8 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t OpselA,
index_t OpselB,
class FloatA,
class ScaleA,
class FloatB,
@@ -896,11 +927,9 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
const ScaleB& scale_b,
FloatC& reg_c) const
{
static_assert(scalar_type<ScaleA>::vector_size == 1, "Expect single scale at this point.");
static_assert(scalar_type<ScaleB>::vector_size == 1, "Expect single scale at this point.");
intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(
a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c);
intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops, OpselA, OpselB>::Run(
a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
}
};
@@ -1117,7 +1146,7 @@ struct MfmaSelector
#endif
}
// Use singal rate mfma instruction for this special case A (f8_t) * B (pk_i4_t)
// Use single rate mfma instruction for this special case A (f8_t) * B (pk_i4_t)
// See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3
// TODO: explore optimization opportunity by using new mfma instructions on gfx950
template <>
@@ -1153,6 +1182,16 @@ struct MfmaSelector
{
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
}
template <>
constexpr auto GetMfma<f4_t, 32, 32, f4_t, false, true>()
{
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
}
template <>
constexpr auto GetMfma<f4_t, 16, 16, f4_t, false, true>()
{
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
}
template <>
constexpr auto GetMfma<f8_t, 16, 16, f8_t, true, false>()
@@ -1290,10 +1329,10 @@ struct MfmaSelector
#endif
}
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type,
static constexpr auto selected_mfma = mfma_type<GetMfma<element_type_t<base_type>,
MPerXdlops,
NPerXdlops,
additional_type,
element_type_t<additional_type>,
is_single_rate_mfma,
is_scale_mfma>()>{};
@@ -1375,7 +1414,8 @@ struct XdlopsGemm
MPerXdlops == 64,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk");
static_assert(KPack * 2 % mfma_instr.k_per_blk == 0,
"KPack should be a multiple of k_per_blk");
}
// XDL output supporting C = A * B
@@ -1413,6 +1453,49 @@ struct XdlopsGemm
Sequence<7>{}));
}
// XDL output supporting C = A * B
// M3_N3 -> M3_M4_M5_N3
template <typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
{
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2,
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_pass_through_transform(M2),
make_pass_through_transform(N2),
make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
Number<mfma_instr.num_input_blks>{},
Number<mfma_instr.group_size>{})),
make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6, 7, 8>{},
Sequence<9>{}));
}
// transposed XDL output supporting C' = B' * A'
// M2_N2 -> M2_N2_N3_N4
template <typename CDesc_M0_N0_M1_N1_M2_N2>
@@ -1518,7 +1601,13 @@ struct XdlopsGemm
});
}
template <class FloatA, class ScaleA, class FloatB, class ScaleB, class FloatC>
template <index_t OpselA,
index_t OpselB,
class FloatA,
class ScaleA,
class FloatB,
class ScaleB,
class FloatC>
__device__ void Run(const FloatA& p_a_wave,
const ScaleA& a_scale_thread,
const FloatB& p_b_wave,
@@ -1528,12 +1617,12 @@ struct XdlopsGemm
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
if constexpr(!TransposeC)
{
mfma_instr.template run<MPerXdlops, NPerXdlops>(
mfma_instr.template run<MPerXdlops, NPerXdlops, OpselA, OpselB>(
p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread);
}
else
{
mfma_instr.template run<MPerXdlops, NPerXdlops>(
mfma_instr.template run<MPerXdlops, NPerXdlops, OpselB, OpselA>(
p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread);
}
});