mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Simulate TF32 with BF16x3 (#3142)
* tf32:bf16x3:use bf16x3 emulate tf32 gemm * change blockwiseGemm to demo bf16x3 * temp push * self review * self review * fix multi-device compile error * bug fix * code refactor * limit to gfx950 * enhance gemm gfx942 threshold * lower change from blockwise to warpwise * refact codes * refact codes * error fix * change threshold * bug fix * fix threshold error * change host reference implement to same as device * bug fix * bug fix * code refact * fix clang-format fail * code refine
This commit is contained in:
@@ -80,8 +80,10 @@ enum struct MfmaInstr
|
||||
mfma_f32_16x16x128f8f6f4,
|
||||
mfma_scale_f32_32x32x64f8f6f4,
|
||||
mfma_scale_f32_16x16x128f8f6f4,
|
||||
mfma_f32_16x16x8xf32, // tf32
|
||||
mfma_f32_32x32x4xf32,
|
||||
mfma_f32_16x16x8xf32, // tf32 on gfx942
|
||||
mfma_f32_32x32x4xf32, // tf32 on gfx942
|
||||
mfma_f32_16x16x32xf32, // bf16x3 simulate tf32 on gfx950
|
||||
mfma_f32_32x32x16xf32, // bf16x3 simulate tf32 on gfx950
|
||||
// gfx11
|
||||
wmma_f32_16x16x16_f16,
|
||||
wmma_f32_16x16x16_bf16,
|
||||
@@ -1015,6 +1017,51 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x4xf32>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x16xf32>
|
||||
{
|
||||
// gfx950 specific: use bf16x3 simulate tf32
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 4;
|
||||
static constexpr index_t num_regs_per_blk = 16;
|
||||
static constexpr index_t num_threads_per_blk = 32;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 2;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 32;
|
||||
static constexpr index_t n_per_blk = 32;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_32x32x16xf32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x32xf32>
|
||||
{
|
||||
// gfx950 specific: use bf16x3 simulate tf32
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
static constexpr index_t num_regs_per_blk = 4;
|
||||
static constexpr index_t num_threads_per_blk = 16;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 4;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 16;
|
||||
static constexpr index_t n_per_blk = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_16x16x32xf32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
// gfx11
|
||||
struct mfma_type_gfx11_base
|
||||
{
|
||||
@@ -1275,12 +1322,14 @@ struct MfmaSelector
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<tf32_t, 32, 32>()
|
||||
constexpr auto GetMfma<tf32_t, 32, 32, tf32_t>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#elif defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_32x32x16xf32;
|
||||
#elif defined(__gfx942__)
|
||||
return MfmaInstr::mfma_f32_32x32x4xf32;
|
||||
#else
|
||||
@@ -1289,12 +1338,14 @@ struct MfmaSelector
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<tf32_t, 16, 16>()
|
||||
constexpr auto GetMfma<tf32_t, 16, 16, tf32_t>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#elif defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_16x16x32xf32;
|
||||
#elif defined(__gfx942__)
|
||||
return MfmaInstr::mfma_f32_16x16x8xf32;
|
||||
#else
|
||||
@@ -2185,6 +2236,10 @@ struct XdlopsGemm
|
||||
(is_same<base_type, int8_t>::value && KPack <= 8) ||
|
||||
((is_same<base_type, f8_t>::value || is_same<base_type, bf8_t>::value) && KPack < 32) ||
|
||||
is_same<additional_type, pk_i4_t>::value)
|
||||
#if defined(__gfx950__)
|
||||
// tf32 on gfx950 is implemented as bf16x3, so it should be treated as bf16.
|
||||
|| (is_same<base_type, tf32_t>::value && KPack <= 4)
|
||||
#endif
|
||||
? true
|
||||
: false;
|
||||
static constexpr auto mfma = MfmaSelector<base_type,
|
||||
|
||||
Reference in New Issue
Block a user