mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Add fp8 GEMM and an example for it (#767)
* Add fp8 xdl gemm * Add example * Use int8 intrinsics for buffer load/store * Format * Update cmakelists
This commit is contained in:
@@ -29,7 +29,9 @@ enum struct MfmaInstr
|
||||
mfma_i32_16x16x16i8,
|
||||
mfma_i32_32x32x16i8,
|
||||
mfma_i32_16x16x32i8,
|
||||
mfma_f64_16x16x4f64
|
||||
mfma_f64_16x16x4f64,
|
||||
mfma_f32_32x32x16f8f8,
|
||||
mfma_f32_16x16x32f8f8
|
||||
};
|
||||
|
||||
template <MfmaInstr instr>
|
||||
@@ -454,6 +456,50 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 4;
|
||||
static constexpr index_t num_regs_per_blk = 16;
|
||||
static constexpr index_t num_threads_per_blk = 32;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 2;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 32;
|
||||
static constexpr index_t n_per_blk = 32;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_32x32x16f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
static constexpr index_t num_regs_per_blk = 4;
|
||||
static constexpr index_t num_threads_per_blk = 16;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 4;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 16;
|
||||
static constexpr index_t n_per_blk = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
|
||||
struct MfmaSelector
|
||||
{
|
||||
@@ -594,6 +640,18 @@ struct MfmaSelector
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<f8_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x16f8f8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<f8_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32f8f8;
|
||||
}
|
||||
|
||||
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
|
||||
|
||||
__host__ __device__ constexpr MfmaSelector()
|
||||
@@ -794,7 +852,7 @@ struct XdlopsGemm
|
||||
{
|
||||
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
|
||||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
|
||||
is_same<base_type, int8_t>::value,
|
||||
is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value,
|
||||
"base base_type must be double, float, half, bfloat16, and int8_t!");
|
||||
|
||||
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
|
||||
|
||||
Reference in New Issue
Block a user