Simulate TF32 with BF16x3 (#3142)

* tf32:bf16x3: use bf16x3 to emulate tf32 gemm

* change blockwiseGemm to demo bf16x3

* temp push

* self review

* self review

* fix multi-device compile error

* bug fix

* code refactor

* limit to gfx950

* enhance gemm gfx942 threshold

* lower change from blockwise to warpwise

* refactor code

* refactor code

* error fix

* change threshold

* bug fix

* fix threshold error

* change host reference implementation to match the device implementation

* bug fix

* bug fix

* code refactor

* fix clang-format fail

* code refine
This commit is contained in:
yinglu
2025-11-14 08:21:09 +08:00
committed by GitHub
parent f2cfc6b94e
commit 2a73eb3bc0
16 changed files with 419 additions and 49 deletions

View File

@@ -80,8 +80,10 @@ enum struct MfmaInstr
mfma_f32_16x16x128f8f6f4,
mfma_scale_f32_32x32x64f8f6f4,
mfma_scale_f32_16x16x128f8f6f4,
mfma_f32_16x16x8xf32, // tf32
mfma_f32_32x32x4xf32,
mfma_f32_16x16x8xf32, // tf32 on gfx942
mfma_f32_32x32x4xf32, // tf32 on gfx942
mfma_f32_16x16x32xf32, // bf16x3 simulate tf32 on gfx950
mfma_f32_32x32x16xf32, // bf16x3 simulate tf32 on gfx950
// gfx11
wmma_f32_16x16x16_f16,
wmma_f32_16x16x16_bf16,
@@ -1015,6 +1017,51 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x4xf32>
}
};
// Descriptor for MfmaInstr::mfma_f32_32x32x16xf32.
// gfx950 specific: tf32 is simulated with bf16x3 on this target.
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16xf32>
{
    // Per-block tile extents.
    static constexpr index_t m_per_blk = 32;
    static constexpr index_t n_per_blk = 32;
    static constexpr index_t k_per_blk = 8;

    // Block counts. k_per_blk * num_input_blks == 16, the K extent in the
    // instruction name.
    static constexpr index_t num_input_blks  = 2;
    static constexpr index_t num_output_blks = 1;
    static constexpr bool is_k_reduction     = true;

    // Register grouping: group_size * num_groups_per_blk == num_regs_per_blk.
    static constexpr index_t group_size         = 4;
    static constexpr index_t num_groups_per_blk = 4;
    static constexpr index_t num_regs_per_blk   = 16;

    // Thread organisation.
    static constexpr index_t num_threads_per_blk = 32;
    static constexpr index_t wave_size           = 64;

    // Forwards the operands and the accumulator to the matching intrinsic
    // wrapper; reg_c is passed by non-const reference so the wrapper can
    // update it in place.
    template <index_t MPerXdlops,
              index_t NPerXdlops,
              typename FloatA,
              typename FloatB,
              typename FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        using Intrinsic = intrin_mfma_f32_32x32x16xf32<MPerXdlops, NPerXdlops>;
        Intrinsic::Run(a, b, reg_c);
    }
};
// Descriptor for MfmaInstr::mfma_f32_16x16x32xf32.
// gfx950 specific: tf32 is simulated with bf16x3 on this target.
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32xf32>
{
    // Per-block tile extents.
    static constexpr index_t m_per_blk = 16;
    static constexpr index_t n_per_blk = 16;
    static constexpr index_t k_per_blk = 8;

    // Block counts. k_per_blk * num_input_blks == 32, the K extent in the
    // instruction name.
    static constexpr index_t num_input_blks  = 4;
    static constexpr index_t num_output_blks = 1;
    static constexpr bool is_k_reduction     = true;

    // Register grouping: group_size * num_groups_per_blk == num_regs_per_blk.
    static constexpr index_t group_size         = 4;
    static constexpr index_t num_groups_per_blk = 1;
    static constexpr index_t num_regs_per_blk   = 4;

    // Thread organisation.
    static constexpr index_t num_threads_per_blk = 16;
    static constexpr index_t wave_size           = 64;

    // Forwards the operands and the accumulator to the matching intrinsic
    // wrapper; reg_c is passed by non-const reference so the wrapper can
    // update it in place.
    template <index_t MPerXdlops,
              index_t NPerXdlops,
              typename FloatA,
              typename FloatB,
              typename FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        using Intrinsic = intrin_mfma_f32_16x16x32xf32<MPerXdlops, NPerXdlops>;
        Intrinsic::Run(a, b, reg_c);
    }
};
// gfx11
struct mfma_type_gfx11_base
{
@@ -1275,12 +1322,14 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<tf32_t, 32, 32>()
constexpr auto GetMfma<tf32_t, 32, 32, tf32_t>()
{
#if defined(__gfx12__)
return MfmaInstr::wmma_unsupport_16x16_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_unsupport_16x16_gfx11;
#elif defined(__gfx950__)
return MfmaInstr::mfma_f32_32x32x16xf32;
#elif defined(__gfx942__)
return MfmaInstr::mfma_f32_32x32x4xf32;
#else
@@ -1289,12 +1338,14 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<tf32_t, 16, 16>()
constexpr auto GetMfma<tf32_t, 16, 16, tf32_t>()
{
#if defined(__gfx12__)
return MfmaInstr::wmma_unsupport_16x16_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_unsupport_16x16_gfx11;
#elif defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x32xf32;
#elif defined(__gfx942__)
return MfmaInstr::mfma_f32_16x16x8xf32;
#else
@@ -2185,6 +2236,10 @@ struct XdlopsGemm
(is_same<base_type, int8_t>::value && KPack <= 8) ||
((is_same<base_type, f8_t>::value || is_same<base_type, bf8_t>::value) && KPack < 32) ||
is_same<additional_type, pk_i4_t>::value)
#if defined(__gfx950__)
// tf32 on gfx950 is implemented as bf16x3, so it should be treated as bf16.
|| (is_same<base_type, tf32_t>::value && KPack <= 4)
#endif
? true
: false;
static constexpr auto mfma = MfmaSelector<base_type,