mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Add fp8 @ bf8 gemm support and example (#933)
* Add f8 bf8 gemm example * Add element-wise ops * Add intrinsics * Update reference calculation * Add an additional type option for xdlops gemm * Fix build process * Add bf8 to buffer addressing * Update blockwise op, split typeA and typeB * Update for compatibility * Uppdate naming to f8->fp8 * Update naming * Format
This commit is contained in:
@@ -31,7 +31,9 @@ enum struct MfmaInstr
|
||||
mfma_i32_16x16x32i8,
|
||||
mfma_f64_16x16x4f64,
|
||||
mfma_f32_32x32x16f8f8,
|
||||
mfma_f32_16x16x32f8f8
|
||||
mfma_f32_16x16x32f8f8,
|
||||
mfma_f32_32x32x16f8bf8,
|
||||
mfma_f32_16x16x32f8bf8
|
||||
};
|
||||
|
||||
template <MfmaInstr instr>
|
||||
@@ -502,10 +504,62 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8bf8>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 4;
|
||||
static constexpr index_t num_regs_per_blk = 16;
|
||||
static constexpr index_t num_threads_per_blk = 32;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 2;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 32;
|
||||
static constexpr index_t n_per_blk = 32;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_32x32x16f8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
static constexpr index_t num_regs_per_blk = 4;
|
||||
static constexpr index_t num_threads_per_blk = 16;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 4;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 16;
|
||||
static constexpr index_t n_per_blk = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_16x16x32f8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename base_type,
|
||||
index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
typename additional_type = base_type>
|
||||
struct MfmaSelector
|
||||
{
|
||||
template <typename base_type_, index_t MPerXdlops_, index_t NPerXdlops_>
|
||||
template <typename base_type_,
|
||||
index_t MPerXdlops_,
|
||||
index_t NPerXdlops_,
|
||||
typename additional_type_ = base_type_>
|
||||
static constexpr auto GetMfma();
|
||||
|
||||
template <>
|
||||
@@ -656,7 +710,22 @@ struct MfmaSelector
|
||||
}
|
||||
#endif
|
||||
|
||||
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
template <>
|
||||
static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x16f8bf8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32f8bf8;
|
||||
}
|
||||
#endif
|
||||
|
||||
static constexpr auto selected_mfma =
|
||||
mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
|
||||
|
||||
__host__ __device__ constexpr MfmaSelector()
|
||||
{
|
||||
@@ -703,7 +772,8 @@ template <typename base_type,
|
||||
index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t KPack,
|
||||
bool TransposeC = false>
|
||||
typename additional_type = base_type,
|
||||
bool TransposeC = false>
|
||||
struct XdlopsGemm
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -854,14 +924,18 @@ struct XdlopsGemm
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
|
||||
{
|
||||
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
|
||||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
|
||||
is_same<base_type, int8_t>::value
|
||||
static_assert(
|
||||
is_same<base_type, double>::value || is_same<base_type, float>::value ||
|
||||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
|
||||
is_same<base_type, int8_t>::value
|
||||
#if defined CK_ENABLE_FP8
|
||||
|| is_same<base_type, f8_t>::value
|
||||
|| is_same<base_type, f8_t>::value
|
||||
#endif
|
||||
,
|
||||
"base base_type must be double, float, half, bfloat16, and int8_t!");
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
|| (is_same<base_type, f8_t>::value && is_same<additional_type, bf8_t>::value)
|
||||
#endif
|
||||
,
|
||||
"base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
|
||||
|
||||
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
|
||||
if constexpr(!TransposeC)
|
||||
@@ -957,7 +1031,7 @@ struct XdlopsGemm
|
||||
return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
|
||||
}
|
||||
|
||||
static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops>{};
|
||||
static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops, additional_type>{};
|
||||
|
||||
static constexpr auto mfma_instr = mfma.selected_mfma;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user