mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Simulate TF32 with BF16x3 (#3142)
* tf32:bf16x3:use bf16x3 emulate tf32 gemm * change blockwiseGemm to demo bf16x3 * temp push * self review * self review * fix multi-device compile error * bug fix * code refactor * limit to gfx950 * enhance gemm gfx942 threshold * lower change from blockwise to warpwise * refact codes * refact codes * error fix * change threshold * bug fix * fix threshold error * change host reference implement to same as device * bug fix * bug fix * code refact * fix clang-format fail * code refine
This commit is contained in:
@@ -129,7 +129,10 @@ inline bool is_wmma_supported()
|
||||
return is_gfx103_supported() || is_gfx11_supported() || is_gfx12_supported();
|
||||
}
|
||||
|
||||
inline bool is_tf32_supported() { return (ck::get_device_name() == "gfx942") ? true : false; }
|
||||
inline bool is_tf32_supported()
|
||||
{
|
||||
return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -168,8 +168,8 @@ typename std::enable_if<
|
||||
check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
double rtol = 1e-5,
|
||||
double atol = 3e-5)
|
||||
double rtol = 5e-4,
|
||||
double atol = 5e-4)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
|
||||
@@ -94,7 +94,8 @@ template <typename ALayout,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
typename CElementwiseOperation,
|
||||
typename ComputeDataType = ADataType>
|
||||
struct DeviceGroupedGemm : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
@@ -134,7 +134,8 @@ template <typename ALayout,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
typename ComputeDataType = ADataType>
|
||||
struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
@@ -145,7 +146,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
CDEElementwiseOperation,
|
||||
ComputeDataType>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedGemm_Xdl;
|
||||
GET_NXDL_PER_WAVE_IMPL
|
||||
@@ -233,8 +235,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
|
||||
// GridwiseGemm
|
||||
template <index_t NXdlPerWave_>
|
||||
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
|
||||
@@ -80,8 +80,10 @@ enum struct MfmaInstr
|
||||
mfma_f32_16x16x128f8f6f4,
|
||||
mfma_scale_f32_32x32x64f8f6f4,
|
||||
mfma_scale_f32_16x16x128f8f6f4,
|
||||
mfma_f32_16x16x8xf32, // tf32
|
||||
mfma_f32_32x32x4xf32,
|
||||
mfma_f32_16x16x8xf32, // tf32 on gfx942
|
||||
mfma_f32_32x32x4xf32, // tf32 on gfx942
|
||||
mfma_f32_16x16x32xf32, // bf16x3 simulate tf32 on gfx950
|
||||
mfma_f32_32x32x16xf32, // bf16x3 simulate tf32 on gfx950
|
||||
// gfx11
|
||||
wmma_f32_16x16x16_f16,
|
||||
wmma_f32_16x16x16_bf16,
|
||||
@@ -1015,6 +1017,51 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x4xf32>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x16xf32>
|
||||
{
|
||||
// gfx950 specific: use bf16x3 simulate tf32
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 4;
|
||||
static constexpr index_t num_regs_per_blk = 16;
|
||||
static constexpr index_t num_threads_per_blk = 32;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 2;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 32;
|
||||
static constexpr index_t n_per_blk = 32;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_32x32x16xf32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x32xf32>
|
||||
{
|
||||
// gfx950 specific: use bf16x3 simulate tf32
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
static constexpr index_t num_regs_per_blk = 4;
|
||||
static constexpr index_t num_threads_per_blk = 16;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 4;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 16;
|
||||
static constexpr index_t n_per_blk = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_16x16x32xf32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
// gfx11
|
||||
struct mfma_type_gfx11_base
|
||||
{
|
||||
@@ -1275,12 +1322,14 @@ struct MfmaSelector
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<tf32_t, 32, 32>()
|
||||
constexpr auto GetMfma<tf32_t, 32, 32, tf32_t>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#elif defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_32x32x16xf32;
|
||||
#elif defined(__gfx942__)
|
||||
return MfmaInstr::mfma_f32_32x32x4xf32;
|
||||
#else
|
||||
@@ -1289,12 +1338,14 @@ struct MfmaSelector
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<tf32_t, 16, 16>()
|
||||
constexpr auto GetMfma<tf32_t, 16, 16, tf32_t>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#elif defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_16x16x32xf32;
|
||||
#elif defined(__gfx942__)
|
||||
return MfmaInstr::mfma_f32_16x16x8xf32;
|
||||
#else
|
||||
@@ -2185,6 +2236,10 @@ struct XdlopsGemm
|
||||
(is_same<base_type, int8_t>::value && KPack <= 8) ||
|
||||
((is_same<base_type, f8_t>::value || is_same<base_type, bf8_t>::value) && KPack < 32) ||
|
||||
is_same<additional_type, pk_i4_t>::value)
|
||||
#if defined(__gfx950__)
|
||||
// tf32 on gfx950 is implemented as bf16x3, so it should be treated as bf16.
|
||||
|| (is_same<base_type, tf32_t>::value && KPack <= 4)
|
||||
#endif
|
||||
? true
|
||||
: false;
|
||||
static constexpr auto mfma = MfmaSelector<base_type,
|
||||
|
||||
@@ -10,6 +10,25 @@ namespace ck {
|
||||
#define __gfx94__
|
||||
#endif
|
||||
|
||||
// Helper function to convert float vector to bf16 vectors (big and small parts)
|
||||
// This is used by both tf32 and xf32 implementations
|
||||
template <index_t VecSize>
|
||||
__device__ __forceinline__ void
|
||||
convert_float_to_bf16_pairs(const vector_type<float, VecSize>& reg_f32,
|
||||
vector_type<bhalf_t, VecSize>& reg_bf16_big,
|
||||
vector_type<bhalf_t, VecSize>& reg_bf16_small)
|
||||
{
|
||||
static_for<0, VecSize, 1>{}([&](auto k) {
|
||||
using IK = Number<k>;
|
||||
reg_bf16_big.template AsType<bhalf_t>()(k) =
|
||||
type_convert<bhalf_t, float>(reg_f32.template AsType<float>()[IK{}]);
|
||||
reg_bf16_small.template AsType<bhalf_t>()(k) = type_convert<bhalf_t, float>(
|
||||
reg_f32.template AsType<float>()[IK{}] -
|
||||
type_convert<float, bhalf_t>(reg_bf16_big.template AsType<bhalf_t>()[IK{}]));
|
||||
});
|
||||
}
|
||||
/* */
|
||||
|
||||
// fp32
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x1f32;
|
||||
@@ -1636,7 +1655,7 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
|
||||
}
|
||||
};
|
||||
|
||||
/******************* tf32 *************************************/
|
||||
/******************* tf32 on gfx942 *************************************/
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x8xf32;
|
||||
|
||||
@@ -1646,7 +1665,7 @@ struct intrin_mfma_f32_16x16x8xf32<16, 16>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
#if defined(__gfx942__)
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
#else
|
||||
@@ -1666,7 +1685,7 @@ struct intrin_mfma_f32_32x32x4xf32<32, 32>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
#if defined(__gfx942__)
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4_xf32(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
#else
|
||||
@@ -1677,4 +1696,102 @@ struct intrin_mfma_f32_32x32x4xf32<32, 32>
|
||||
}
|
||||
};
|
||||
|
||||
/******************* tf32/xf32 on gfx950 ********************************/
|
||||
/* bf16x3 simulate tf32/xf32: input/output/accumulator are all float; */
|
||||
/* step: */
|
||||
/* 1. separate one input to 2 bf16 registers: */
|
||||
/* in_bf16_big = f32_to_bf16(in_f32) */
|
||||
/* in_bf16_small = in_f32 - in_bf16_big */
|
||||
/* 2. run 3 xdlops gemm: the accumulator of each gemm is the same. */
|
||||
/* out_f32 = A_bf16_big * B_bf16_big */
|
||||
/* out_f32 += A_bf16_small * B_bf16_big */
|
||||
/* out_f32 += A_bf16_big * B_bf16_small */
|
||||
/************************************************************************/
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x32xf32;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_16x16x32xf32<16, 16>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float8_t& reg_a, const float8_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
using I0 = Number<0>;
|
||||
vector_type<float, 8> reg_a_v(reg_a);
|
||||
vector_type<float, 8> reg_b_v(reg_b);
|
||||
|
||||
vector_type<bhalf_t, 8> v_reg_a_bf16_big;
|
||||
vector_type<bhalf_t, 8> v_reg_a_bf16_small;
|
||||
vector_type<bhalf_t, 8> v_reg_b_bf16_big;
|
||||
vector_type<bhalf_t, 8> v_reg_b_bf16_small;
|
||||
|
||||
convert_float_to_bf16_pairs(reg_a_v, v_reg_a_bf16_big, v_reg_a_bf16_small);
|
||||
convert_float_to_bf16_pairs(reg_b_v, v_reg_b_bf16_big, v_reg_b_bf16_small);
|
||||
|
||||
// Run 3 times: big*big, small*big, big*small
|
||||
intrin_mfma_f32_16x16x32bf16<16, 16>::Run(
|
||||
v_reg_a_bf16_small.template AsType<bhalf8_t>()[I0{}],
|
||||
v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
|
||||
reg_c);
|
||||
intrin_mfma_f32_16x16x32bf16<16, 16>::Run(
|
||||
v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
|
||||
v_reg_b_bf16_small.template AsType<bhalf8_t>()[I0{}],
|
||||
reg_c);
|
||||
intrin_mfma_f32_16x16x32bf16<16, 16>::Run(
|
||||
v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
|
||||
v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
|
||||
reg_c);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif // defined(__gfx950__)
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x16xf32;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x16xf32<32, 32>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float8_t& reg_a, const float8_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
using I0 = Number<0>;
|
||||
vector_type<float, 8> reg_a_v(reg_a);
|
||||
vector_type<float, 8> reg_b_v(reg_b);
|
||||
|
||||
vector_type<bhalf_t, 8> v_reg_a_bf16_big;
|
||||
vector_type<bhalf_t, 8> v_reg_a_bf16_small;
|
||||
vector_type<bhalf_t, 8> v_reg_b_bf16_big;
|
||||
vector_type<bhalf_t, 8> v_reg_b_bf16_small;
|
||||
|
||||
convert_float_to_bf16_pairs(reg_a_v, v_reg_a_bf16_big, v_reg_a_bf16_small);
|
||||
convert_float_to_bf16_pairs(reg_b_v, v_reg_b_bf16_big, v_reg_b_bf16_small);
|
||||
|
||||
// Run 3 times: big*big, small*big, big*small
|
||||
intrin_mfma_f32_32x32x16bf16<32, 32>::Run(
|
||||
v_reg_a_bf16_small.template AsType<bhalf8_t>()[I0{}],
|
||||
v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
|
||||
reg_c);
|
||||
intrin_mfma_f32_32x32x16bf16<32, 32>::Run(
|
||||
v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
|
||||
v_reg_b_bf16_small.template AsType<bhalf8_t>()[I0{}],
|
||||
reg_c);
|
||||
intrin_mfma_f32_32x32x16bf16<32, 32>::Run(
|
||||
v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
|
||||
v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
|
||||
reg_c);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif // defined(__gfx950__)
|
||||
}
|
||||
};
|
||||
|
||||
/******************* tf32/xf32 on gfx950 end ************************************/
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user