Simulate TF32 with BF16x3 (#3142)

* tf32/bf16x3: use bf16x3 to emulate tf32 gemm

* change blockwiseGemm to demo bf16x3

* temp push

* self review

* self review

* fix multi-device compile error

* bug fix

* code refactor

* limit to gfx950

* enhance gemm gfx942 threshold

* lower change from blockwise to warpwise

* refactor code

* refactor code

* error fix

* change threshold

* bug fix

* fix threshold error

* change host reference implementation to match the device implementation

* bug fix

* bug fix

* code refactor

* fix clang-format fail

* code refine
This commit is contained in:
yinglu
2025-11-14 08:21:09 +08:00
committed by GitHub
parent f2cfc6b94e
commit 2a73eb3bc0
16 changed files with 419 additions and 49 deletions

View File

@@ -129,7 +129,10 @@ inline bool is_wmma_supported()
return is_gfx103_supported() || is_gfx11_supported() || is_gfx12_supported();
}
inline bool is_tf32_supported() { return (ck::get_device_name() == "gfx942") ? true : false; }
// Returns true when the current device supports TF32 GEMM:
// native xf32 MFMA on gfx942, bf16x3 emulation on gfx950.
inline bool is_tf32_supported()
{
    // Query the device name once instead of twice per call.
    const auto device_name = ck::get_device_name();
    return device_name == "gfx942" || device_name == "gfx950";
}
} // namespace ck
#endif

View File

@@ -168,8 +168,8 @@ typename std::enable_if<
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-5,
double atol = 3e-5)
double rtol = 5e-4,
double atol = 5e-4)
{
if(out.size() != ref.size())
{

View File

@@ -94,7 +94,8 @@ template <typename ALayout,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
typename CElementwiseOperation,
typename ComputeDataType = ADataType>
struct DeviceGroupedGemm : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();

View File

@@ -134,7 +134,8 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
LoopScheduler LoopSched = make_default_loop_scheduler(),
typename ComputeDataType = ADataType>
struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
@@ -145,7 +146,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
CDEElementwiseOperation,
ComputeDataType>
{
using DeviceOp = DeviceGroupedGemm_Xdl;
GET_NXDL_PER_WAVE_IMPL
@@ -233,8 +235,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
using ComputeDataType = ADataType;
// GridwiseGemm
template <index_t NXdlPerWave_>
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<

View File

@@ -80,8 +80,10 @@ enum struct MfmaInstr
mfma_f32_16x16x128f8f6f4,
mfma_scale_f32_32x32x64f8f6f4,
mfma_scale_f32_16x16x128f8f6f4,
mfma_f32_16x16x8xf32, // tf32
mfma_f32_32x32x4xf32,
mfma_f32_16x16x8xf32, // tf32 on gfx942
mfma_f32_32x32x4xf32, // tf32 on gfx942
mfma_f32_16x16x32xf32, // bf16x3 simulate tf32 on gfx950
mfma_f32_32x32x16xf32, // bf16x3 simulate tf32 on gfx950
// gfx11
wmma_f32_16x16x16_f16,
wmma_f32_16x16x16_bf16,
@@ -1015,6 +1017,51 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x4xf32>
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16xf32>
{
// gfx950 specific: tf32 is emulated with three bf16 MFMAs (bf16x3); these
// traits describe the shape of the underlying 32x32x16 bf16 instruction.
// One 32x32 output block per wave; K is reduced across the input blocks:
// num_input_blks * k_per_blk = 2 * 8 = 16, matching the x16 in the name.
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16; // accumulator regs per output block
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
// Forwards to the bf16x3 emulation intrinsic, accumulating into reg_c.
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x16xf32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32xf32>
{
// gfx950 specific: tf32 is emulated with three bf16 MFMAs (bf16x3); these
// traits describe the shape of the underlying 16x16x32 bf16 instruction.
// One 16x16 output block per wave; K is reduced across the input blocks:
// num_input_blks * k_per_blk = 4 * 8 = 32, matching the x32 in the name.
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4; // accumulator regs per output block
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
// Forwards to the bf16x3 emulation intrinsic, accumulating into reg_c.
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x32xf32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
// gfx11
struct mfma_type_gfx11_base
{
@@ -1275,12 +1322,14 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<tf32_t, 32, 32>()
constexpr auto GetMfma<tf32_t, 32, 32, tf32_t>()
{
#if defined(__gfx12__)
return MfmaInstr::wmma_unsupport_16x16_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_unsupport_16x16_gfx11;
#elif defined(__gfx950__)
return MfmaInstr::mfma_f32_32x32x16xf32;
#elif defined(__gfx942__)
return MfmaInstr::mfma_f32_32x32x4xf32;
#else
@@ -1289,12 +1338,14 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<tf32_t, 16, 16>()
constexpr auto GetMfma<tf32_t, 16, 16, tf32_t>()
{
#if defined(__gfx12__)
return MfmaInstr::wmma_unsupport_16x16_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_unsupport_16x16_gfx11;
#elif defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x32xf32;
#elif defined(__gfx942__)
return MfmaInstr::mfma_f32_16x16x8xf32;
#else
@@ -2185,6 +2236,10 @@ struct XdlopsGemm
(is_same<base_type, int8_t>::value && KPack <= 8) ||
((is_same<base_type, f8_t>::value || is_same<base_type, bf8_t>::value) && KPack < 32) ||
is_same<additional_type, pk_i4_t>::value)
#if defined(__gfx950__)
// tf32 on gfx950 is implemented as bf16x3, so it should be treated as bf16.
|| (is_same<base_type, tf32_t>::value && KPack <= 4)
#endif
? true
: false;
static constexpr auto mfma = MfmaSelector<base_type,

View File

@@ -10,6 +10,25 @@ namespace ck {
#define __gfx94__
#endif
// Helper function to convert a float vector to bf16 vectors (big and small parts).
// This is used by both tf32 and xf32 implementations.
//
// For each fp32 lane x it produces:
//   big   = bf16(x)                      — leading bits of x
//   small = bf16(x - float(big))         — rounding residual of the first conversion
// so float(big) + float(small) approximates x with roughly twice the mantissa
// precision of a single bf16. The bf16x3 MFMA emulation below consumes these
// pairs (it computes big*big + small*big + big*small).
//
// @param reg_f32        input fp32 vector
// @param reg_bf16_big   output: leading bf16 part of each lane
// @param reg_bf16_small output: residual bf16 part of each lane
template <index_t VecSize>
__device__ __forceinline__ void
convert_float_to_bf16_pairs(const vector_type<float, VecSize>& reg_f32,
vector_type<bhalf_t, VecSize>& reg_bf16_big,
vector_type<bhalf_t, VecSize>& reg_bf16_small)
{
static_for<0, VecSize, 1>{}([&](auto k) {
using IK = Number<k>;
// big part: plain fp32 -> bf16 conversion of the lane
reg_bf16_big.template AsType<bhalf_t>()(k) =
type_convert<bhalf_t, float>(reg_f32.template AsType<float>()[IK{}]);
// small part: convert what the big part lost (x - float(big))
reg_bf16_small.template AsType<bhalf_t>()(k) = type_convert<bhalf_t, float>(
reg_f32.template AsType<float>()[IK{}] -
type_convert<float, bhalf_t>(reg_bf16_big.template AsType<bhalf_t>()[IK{}]));
});
}
/* */
// fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x1f32;
@@ -1636,7 +1655,7 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
}
};
/******************* tf32 *************************************/
/******************* tf32 on gfx942 *************************************/
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x8xf32;
@@ -1646,7 +1665,7 @@ struct intrin_mfma_f32_16x16x8xf32<16, 16>
template <class FloatC>
__device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx94__)
#if defined(__gfx942__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
#else
@@ -1666,7 +1685,7 @@ struct intrin_mfma_f32_32x32x4xf32<32, 32>
template <class FloatC>
__device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx94__)
#if defined(__gfx942__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4_xf32(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
#else
@@ -1677,4 +1696,102 @@ struct intrin_mfma_f32_32x32x4xf32<32, 32>
}
};
/******************* tf32/xf32 on gfx950 ********************************/
/* bf16x3 simulate tf32/xf32: input/output/accumulator are all float; */
/* step: */
/* 1. separate one input to 2 bf16 registers: */
/* in_bf16_big = f32_to_bf16(in_f32) */
/* in_bf16_small = in_f32 - in_bf16_big */
/* 2. run 3 xdlops gemm: the accumulator of each gemm is the same. */
/* out_f32 = A_bf16_big * B_bf16_big */
/* out_f32 += A_bf16_small * B_bf16_big */
/* out_f32 += A_bf16_big * B_bf16_small */
/************************************************************************/
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32xf32;
template <>
struct intrin_mfma_f32_16x16x32xf32<16, 16>
{
// Emulates a 16x16x32 tf32/xf32 MFMA on gfx950 with three bf16 MFMAs.
// Each fp32 operand lane is split into a (big, small) bf16 pair, then the
// accumulator receives a_small*b_big + a_big*b_small + a_big*b_big.
// The small*small product is intentionally omitted (bf16x3, not bf16x4).
// NOTE(review): the small-term products are accumulated before big*big;
// float accumulation order affects rounding, so preserve this order to
// stay bit-identical with the host reference.
template <class FloatC>
__device__ static void Run(const float8_t& reg_a, const float8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
using I0 = Number<0>;
vector_type<float, 8> reg_a_v(reg_a);
vector_type<float, 8> reg_b_v(reg_b);
vector_type<bhalf_t, 8> v_reg_a_bf16_big;
vector_type<bhalf_t, 8> v_reg_a_bf16_small;
vector_type<bhalf_t, 8> v_reg_b_bf16_big;
vector_type<bhalf_t, 8> v_reg_b_bf16_small;
// Decompose both operands into big/small bf16 pairs.
convert_float_to_bf16_pairs(reg_a_v, v_reg_a_bf16_big, v_reg_a_bf16_small);
convert_float_to_bf16_pairs(reg_b_v, v_reg_b_bf16_big, v_reg_b_bf16_small);
// Run 3 times: small*big, big*small, big*big — all into the same accumulator.
intrin_mfma_f32_16x16x32bf16<16, 16>::Run(
v_reg_a_bf16_small.template AsType<bhalf8_t>()[I0{}],
v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
reg_c);
intrin_mfma_f32_16x16x32bf16<16, 16>::Run(
v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
v_reg_b_bf16_small.template AsType<bhalf8_t>()[I0{}],
reg_c);
intrin_mfma_f32_16x16x32bf16<16, 16>::Run(
v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
reg_c);
#else
// Non-gfx950 builds: intrinsic unavailable; silence unused-parameter warnings.
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16xf32;
template <>
struct intrin_mfma_f32_32x32x16xf32<32, 32>
{
// Emulates a 32x32x16 tf32/xf32 MFMA on gfx950 with three bf16 MFMAs.
// Each fp32 operand lane is split into a (big, small) bf16 pair, then the
// accumulator receives a_small*b_big + a_big*b_small + a_big*b_big.
// The small*small product is intentionally omitted (bf16x3, not bf16x4).
// NOTE(review): the small-term products are accumulated before big*big;
// float accumulation order affects rounding, so preserve this order to
// stay bit-identical with the host reference.
template <class FloatC>
__device__ static void Run(const float8_t& reg_a, const float8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
using I0 = Number<0>;
vector_type<float, 8> reg_a_v(reg_a);
vector_type<float, 8> reg_b_v(reg_b);
vector_type<bhalf_t, 8> v_reg_a_bf16_big;
vector_type<bhalf_t, 8> v_reg_a_bf16_small;
vector_type<bhalf_t, 8> v_reg_b_bf16_big;
vector_type<bhalf_t, 8> v_reg_b_bf16_small;
// Decompose both operands into big/small bf16 pairs.
convert_float_to_bf16_pairs(reg_a_v, v_reg_a_bf16_big, v_reg_a_bf16_small);
convert_float_to_bf16_pairs(reg_b_v, v_reg_b_bf16_big, v_reg_b_bf16_small);
// Run 3 times: small*big, big*small, big*big — all into the same accumulator.
intrin_mfma_f32_32x32x16bf16<32, 32>::Run(
v_reg_a_bf16_small.template AsType<bhalf8_t>()[I0{}],
v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
reg_c);
intrin_mfma_f32_32x32x16bf16<32, 32>::Run(
v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
v_reg_b_bf16_small.template AsType<bhalf8_t>()[I0{}],
reg_c);
intrin_mfma_f32_32x32x16bf16<32, 32>::Run(
v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
reg_c);
#else
// Non-gfx950 builds: intrinsic unavailable; silence unused-parameter warnings.
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
/******************* tf32/xf32 on gfx950 end ************************************/
} // namespace ck