mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Optimized GEMMs for MX FP4/8 (#2294)
Adds V3 GEMM pipeline for MX FP4 and MX FP8 Adds V3 GEMM pipeline for MX FP4 with preshuffling Adds MXFP4 GEMM tests (#2275) Adds MXFP4 GEMM examples Adds MXFP4 GEMMs to ckProfiler Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com> Co-authored-by: aska-0096 <haocwang@amd.com> Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com> Co-authored-by: OscarXu <huaiguxu@amd.com> Co-authored-by: mtgu0705 <mtgu@amd.com> Co-authored-by: Ding, Yi <yi.ding@amd.com> Co-authored-by: feifei14119 <feiw@amd.com> Co-authored-by: Lin, Qun <qlin@amd.com> Co-authored-by: joye <joye@amd.com> Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
233e274077
commit
00247e3c29
@@ -8,6 +8,35 @@
|
||||
#include "ck/utility/amd_xdlops.hpp"
|
||||
|
||||
namespace ck {
|
||||
/**
|
||||
* @brief Define matrix data types that have hardware support for MX GEMMs
|
||||
*/
|
||||
template <typename T>
|
||||
static constexpr bool is_scale_mfma_data_type()
|
||||
{
|
||||
using U = element_type_t<T>;
|
||||
return is_same_v<U, f8_ocp_t> || is_same_v<U, bf8_ocp_t> || is_same_v<U, f6_t> ||
|
||||
is_same_v<U, bf6_t> || is_same_v<U, f4_t>;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Define scale data types that have hardware support for MX GEMMs
|
||||
*/
|
||||
template <typename T>
|
||||
static constexpr bool is_scale_mfma_scale_type()
|
||||
{
|
||||
return is_same_v<T, e8m0_bexp_t>;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Combination of data types that have hardware support for MX GEMMs
|
||||
*/
|
||||
template <typename ADataType, typename BDataType, typename AScaleDataType, typename BScaleDataType>
|
||||
static constexpr bool scale_mfma_hw_support()
|
||||
{
|
||||
return is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>() &&
|
||||
is_scale_mfma_scale_type<AScaleDataType>() && is_scale_mfma_scale_type<BScaleDataType>();
|
||||
}
|
||||
|
||||
enum struct MfmaInstr
|
||||
{
|
||||
@@ -847,6 +876,8 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t OpselA,
|
||||
index_t OpselB,
|
||||
class FloatA,
|
||||
class ScaleA,
|
||||
class FloatB,
|
||||
@@ -858,11 +889,9 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
|
||||
const ScaleB& scale_b,
|
||||
FloatC& reg_c) const
|
||||
{
|
||||
static_assert(scalar_type<ScaleA>::vector_size == 1, "Expect single scale at this point.");
|
||||
static_assert(scalar_type<ScaleB>::vector_size == 1, "Expect single scale at this point.");
|
||||
|
||||
intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(
|
||||
a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c);
|
||||
intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops, OpselA, OpselB>::Run(
|
||||
a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -885,6 +914,8 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t OpselA,
|
||||
index_t OpselB,
|
||||
class FloatA,
|
||||
class ScaleA,
|
||||
class FloatB,
|
||||
@@ -896,11 +927,9 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
|
||||
const ScaleB& scale_b,
|
||||
FloatC& reg_c) const
|
||||
{
|
||||
static_assert(scalar_type<ScaleA>::vector_size == 1, "Expect single scale at this point.");
|
||||
static_assert(scalar_type<ScaleB>::vector_size == 1, "Expect single scale at this point.");
|
||||
|
||||
intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(
|
||||
a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c);
|
||||
intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops, OpselA, OpselB>::Run(
|
||||
a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1117,7 +1146,7 @@ struct MfmaSelector
|
||||
#endif
|
||||
}
|
||||
|
||||
// Use singal rate mfma instruction for this special case A (f8_t) * B (pk_i4_t)
|
||||
// Use single rate mfma instruction for this special case A (f8_t) * B (pk_i4_t)
|
||||
// See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3
|
||||
// TODO: explore optimization opportunity by using new mfma instructions on gfx950
|
||||
template <>
|
||||
@@ -1153,6 +1182,16 @@ struct MfmaSelector
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
|
||||
}
|
||||
template <>
|
||||
constexpr auto GetMfma<f4_t, 32, 32, f4_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
|
||||
}
|
||||
template <>
|
||||
constexpr auto GetMfma<f4_t, 16, 16, f4_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16, f8_t, true, false>()
|
||||
@@ -1290,10 +1329,10 @@ struct MfmaSelector
|
||||
#endif
|
||||
}
|
||||
|
||||
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type,
|
||||
static constexpr auto selected_mfma = mfma_type<GetMfma<element_type_t<base_type>,
|
||||
MPerXdlops,
|
||||
NPerXdlops,
|
||||
additional_type,
|
||||
element_type_t<additional_type>,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>()>{};
|
||||
|
||||
@@ -1375,7 +1414,8 @@ struct XdlopsGemm
|
||||
MPerXdlops == 64,
|
||||
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
|
||||
|
||||
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk");
|
||||
static_assert(KPack * 2 % mfma_instr.k_per_blk == 0,
|
||||
"KPack should be a multiple of k_per_blk");
|
||||
}
|
||||
|
||||
// XDL output supporting C = A * B
|
||||
@@ -1413,6 +1453,49 @@ struct XdlopsGemm
|
||||
Sequence<7>{}));
|
||||
}
|
||||
|
||||
// XDL output supporting C = A * B
|
||||
// M3_N3 -> M3_M4_M5_N3
|
||||
template <typename CDesc_M0_N0_M1_N1_M2_N2>
|
||||
__host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
|
||||
const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
|
||||
{
|
||||
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
|
||||
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
|
||||
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
|
||||
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
|
||||
const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
|
||||
const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_desc_m0_n0_m1_n1_m2_n2,
|
||||
make_tuple(make_pass_through_transform(M0),
|
||||
make_pass_through_transform(N0),
|
||||
make_pass_through_transform(M1),
|
||||
make_pass_through_transform(N1),
|
||||
make_pass_through_transform(M2),
|
||||
make_pass_through_transform(N2),
|
||||
make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
|
||||
Number<mfma_instr.num_input_blks>{},
|
||||
Number<mfma_instr.group_size>{})),
|
||||
make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{},
|
||||
Sequence<6>{},
|
||||
Sequence<7>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{},
|
||||
Sequence<6, 7, 8>{},
|
||||
Sequence<9>{}));
|
||||
}
|
||||
|
||||
// transposed XDL output supporting C' = B' * A'
|
||||
// M2_N2 -> M2_N2_N3_N4
|
||||
template <typename CDesc_M0_N0_M1_N1_M2_N2>
|
||||
@@ -1518,7 +1601,13 @@ struct XdlopsGemm
|
||||
});
|
||||
}
|
||||
|
||||
template <class FloatA, class ScaleA, class FloatB, class ScaleB, class FloatC>
|
||||
template <index_t OpselA,
|
||||
index_t OpselB,
|
||||
class FloatA,
|
||||
class ScaleA,
|
||||
class FloatB,
|
||||
class ScaleB,
|
||||
class FloatC>
|
||||
__device__ void Run(const FloatA& p_a_wave,
|
||||
const ScaleA& a_scale_thread,
|
||||
const FloatB& p_b_wave,
|
||||
@@ -1528,12 +1617,12 @@ struct XdlopsGemm
|
||||
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
|
||||
if constexpr(!TransposeC)
|
||||
{
|
||||
mfma_instr.template run<MPerXdlops, NPerXdlops>(
|
||||
mfma_instr.template run<MPerXdlops, NPerXdlops, OpselA, OpselB>(
|
||||
p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread);
|
||||
}
|
||||
else
|
||||
{
|
||||
mfma_instr.template run<MPerXdlops, NPerXdlops>(
|
||||
mfma_instr.template run<MPerXdlops, NPerXdlops, OpselB, OpselA>(
|
||||
p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread);
|
||||
}
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user