mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
MX GEMM - New GEMM pipeline for MX data types (#2059)
* Allow selection of mfma_scale instructions * Read B tensor from LDS to VGPR in chunks of 16 in MFMA order * Add constexpr and synchronize return type for `get_exponent_value` * Pass scales by reference and add comments to `mfma_scale_f32_32x32x64` * Add support for microscaling instructions in `XdlopsGemm` * Fix `mfma_scale_f32_16x16x128f8f6f4` wrapper * Remove software implementation of MX GEMM * Make interface of `intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>` consistent with the other scale instruction * Update README * Updated CHANGELOG * Remove unused static methods
This commit is contained in:
committed by
GitHub
parent
d55c9cb313
commit
7106976a72
@@ -845,15 +845,24 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
|
||||
static constexpr bool is_k_reduction = true; // ???
|
||||
// clang-format on
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
class FloatA,
|
||||
class ScaleA,
|
||||
class FloatB,
|
||||
class ScaleB,
|
||||
class FloatC>
|
||||
__device__ void run(const FloatA& a,
|
||||
const int32_t scale_a,
|
||||
const ScaleA& scale_a,
|
||||
const FloatB& b,
|
||||
const int32_t scale_b,
|
||||
const ScaleB& scale_b,
|
||||
FloatC& reg_c) const
|
||||
{
|
||||
static_assert(scalar_type<ScaleA>::vector_size == 1, "Expect single scale at this point.");
|
||||
static_assert(scalar_type<ScaleB>::vector_size == 1, "Expect single scale at this point.");
|
||||
|
||||
intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(
|
||||
a, scale_a, b, scale_b, reg_c);
|
||||
a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -874,15 +883,24 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
|
||||
static constexpr bool is_k_reduction = true; // ???
|
||||
// clang-format on
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
class FloatA,
|
||||
class ScaleA,
|
||||
class FloatB,
|
||||
class ScaleB,
|
||||
class FloatC>
|
||||
__device__ void run(const FloatA& a,
|
||||
const int32_t scale_a,
|
||||
const ScaleA& scale_a,
|
||||
const FloatB& b,
|
||||
const int32_t scale_b,
|
||||
const ScaleB& scale_b,
|
||||
FloatC& reg_c) const
|
||||
{
|
||||
static_assert(scalar_type<ScaleA>::vector_size == 1, "Expect single scale at this point.");
|
||||
static_assert(scalar_type<ScaleB>::vector_size == 1, "Expect single scale at this point.");
|
||||
|
||||
intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(
|
||||
a, scale_a, b, scale_b, reg_c);
|
||||
a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -890,14 +908,16 @@ template <typename base_type,
|
||||
index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
typename additional_type = base_type,
|
||||
bool is_single_rate_mfma = false>
|
||||
bool is_single_rate_mfma = false,
|
||||
bool is_scale_mfma = false>
|
||||
struct MfmaSelector
|
||||
{
|
||||
template <typename base_type_,
|
||||
index_t MPerXdlops_,
|
||||
index_t NPerXdlops_,
|
||||
typename additional_type_ = base_type_,
|
||||
bool is_single_rate_mfma_ = false>
|
||||
bool is_single_rate_mfma_ = false,
|
||||
bool is_scale_mfma_ = false>
|
||||
static constexpr auto GetMfma();
|
||||
|
||||
template <>
|
||||
@@ -1103,12 +1123,24 @@ struct MfmaSelector
|
||||
return MfmaInstr::mfma_f32_32x32x16f8f8;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 32, 32, f8_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32f8f8;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16, f8_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 32, 32>()
|
||||
{
|
||||
@@ -1145,8 +1177,12 @@ struct MfmaSelector
|
||||
return MfmaInstr::mfma_f32_16x16x32bf8f8;
|
||||
}
|
||||
|
||||
static constexpr auto selected_mfma = mfma_type<
|
||||
GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type, is_single_rate_mfma>()>{};
|
||||
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type,
|
||||
MPerXdlops,
|
||||
NPerXdlops,
|
||||
additional_type,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>()>{};
|
||||
|
||||
__host__ __device__ constexpr MfmaSelector()
|
||||
{
|
||||
@@ -1194,7 +1230,8 @@ template <typename base_type,
|
||||
index_t NPerXdlops,
|
||||
index_t KPack,
|
||||
typename additional_type = base_type,
|
||||
bool TransposeC = false>
|
||||
bool TransposeC = false,
|
||||
bool is_scale_mfma = false>
|
||||
struct XdlopsGemm
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -1225,7 +1262,7 @@ struct XdlopsGemm
|
||||
MPerXdlops == 64,
|
||||
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
|
||||
|
||||
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
|
||||
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk");
|
||||
}
|
||||
|
||||
// XDL output supporting C = A * B
|
||||
@@ -1368,6 +1405,27 @@ struct XdlopsGemm
|
||||
});
|
||||
}
|
||||
|
||||
template <class FloatA, class ScaleA, class FloatB, class ScaleB, class FloatC>
|
||||
__device__ void Run(const FloatA& p_a_wave,
|
||||
const ScaleA& a_scale_thread,
|
||||
const FloatB& p_b_wave,
|
||||
const ScaleB& b_scale_thread,
|
||||
FloatC& p_c_thread) const
|
||||
{
|
||||
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
|
||||
if constexpr(!TransposeC)
|
||||
{
|
||||
mfma_instr.template run<MPerXdlops, NPerXdlops>(
|
||||
p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread);
|
||||
}
|
||||
else
|
||||
{
|
||||
mfma_instr.template run<MPerXdlops, NPerXdlops>(
|
||||
p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; }
|
||||
|
||||
__device__ static auto GetBlkIdx()
|
||||
@@ -1455,7 +1513,8 @@ struct XdlopsGemm
|
||||
KPack <= 4) ||
|
||||
(is_same<base_type, int8_t>::value && KPack <= 8))
|
||||
? true
|
||||
: false > {};
|
||||
: false,
|
||||
is_scale_mfma > {};
|
||||
|
||||
static constexpr auto mfma_instr = mfma.selected_mfma;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user