mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Simulate TF32 with BF16x3 (#3142)
* tf32:bf16x3:use bf16x3 emulate tf32 gemm * change blockwiseGemm to demo bf16x3 * temp push * self review * self review * fix multi-device compile error * bug fix * code refactor * limit to gfx950 * enhance gemm gfx942 threshold * lower change from blockwise to warpwise * refact codes * refact codes * error fix * change threshold * bug fix * fix threshold error * change host reference implement to same as device * bug fix * bug fix * code refact * fix clang-format fail * code refine
This commit is contained in:
@@ -10,6 +10,25 @@ namespace ck {
|
||||
#define __gfx94__
|
||||
#endif
|
||||
|
||||
// Helper function to convert float vector to bf16 vectors (big and small parts)
|
||||
// This is used by both tf32 and xf32 implementations
|
||||
template <index_t VecSize>
|
||||
__device__ __forceinline__ void
|
||||
convert_float_to_bf16_pairs(const vector_type<float, VecSize>& reg_f32,
|
||||
vector_type<bhalf_t, VecSize>& reg_bf16_big,
|
||||
vector_type<bhalf_t, VecSize>& reg_bf16_small)
|
||||
{
|
||||
static_for<0, VecSize, 1>{}([&](auto k) {
|
||||
using IK = Number<k>;
|
||||
reg_bf16_big.template AsType<bhalf_t>()(k) =
|
||||
type_convert<bhalf_t, float>(reg_f32.template AsType<float>()[IK{}]);
|
||||
reg_bf16_small.template AsType<bhalf_t>()(k) = type_convert<bhalf_t, float>(
|
||||
reg_f32.template AsType<float>()[IK{}] -
|
||||
type_convert<float, bhalf_t>(reg_bf16_big.template AsType<bhalf_t>()[IK{}]));
|
||||
});
|
||||
}
|
||||
/* */
|
||||
|
||||
// fp32
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x1f32;
|
||||
@@ -1636,7 +1655,7 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
|
||||
}
|
||||
};
|
||||
|
||||
/******************* tf32 *************************************/
|
||||
/******************* tf32 on gfx942 *************************************/
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x8xf32;
|
||||
|
||||
@@ -1646,7 +1665,7 @@ struct intrin_mfma_f32_16x16x8xf32<16, 16>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
#if defined(__gfx942__)
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
#else
|
||||
@@ -1666,7 +1685,7 @@ struct intrin_mfma_f32_32x32x4xf32<32, 32>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
#if defined(__gfx942__)
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4_xf32(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
#else
|
||||
@@ -1677,4 +1696,102 @@ struct intrin_mfma_f32_32x32x4xf32<32, 32>
|
||||
}
|
||||
};
|
||||
|
||||
/******************* tf32/xf32 on gfx950 ********************************/
/* bf16x3 emulation of tf32/xf32: inputs, outputs and the accumulator   */
/* are all float.                                                       */
/* Steps:                                                               */
/* 1. Split each float input into two bf16 registers:                   */
/*      in_bf16_big   = f32_to_bf16(in_f32)                             */
/*      in_bf16_small = f32_to_bf16(in_f32 - bf16_to_f32(in_bf16_big))  */
/* 2. Run 3 xdlops GEMMs that all accumulate into the same register,    */
/*    issued in this order:                                             */
/*      out_f32 += A_bf16_small * B_bf16_big                            */
/*      out_f32 += A_bf16_big   * B_bf16_small                          */
/*      out_f32 += A_bf16_big   * B_bf16_big                            */
/************************************************************************/
|
||||
// Emulated float MFMA built from three bf16 MFMA calls (bf16x3 scheme).
// Only the <16, 16> specialization below is provided.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32xf32;

template <>
struct intrin_mfma_f32_16x16x32xf32<16, 16>
{
    // reg_a, reg_b : float8_t operands; each is split into a bf16 "big" part
    //                and a bf16 "small" residual before the bf16 MFMA calls.
    // reg_c        : accumulator, passed unchanged to all three bf16 MFMA calls.
    // The body is compiled only when __gfx950__ is defined; otherwise the
    // arguments are merely marked as used.
    template <class FloatC>
    __device__ static void Run(const float8_t& reg_a, const float8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        constexpr auto I0 = Number<0>{};
        using bf16_mfma   = intrin_mfma_f32_16x16x32bf16<16, 16>;

        const vector_type<float, 8> a_f32(reg_a);
        const vector_type<float, 8> b_f32(reg_b);

        vector_type<bhalf_t, 8> a_big;
        vector_type<bhalf_t, 8> a_small;
        vector_type<bhalf_t, 8> b_big;
        vector_type<bhalf_t, 8> b_small;

        convert_float_to_bf16_pairs(a_f32, a_big, a_small);
        convert_float_to_bf16_pairs(b_f32, b_big, b_small);

        const auto& a_big_v   = a_big.template AsType<bhalf8_t>()[I0];
        const auto& a_small_v = a_small.template AsType<bhalf8_t>()[I0];
        const auto& b_big_v   = b_big.template AsType<bhalf8_t>()[I0];
        const auto& b_small_v = b_small.template AsType<bhalf8_t>()[I0];

        // Three passes into the same accumulator, issued in this order:
        // small*big, big*small, then big*big.
        // NOTE(review): the cross terms are deliberately kept ahead of the
        // leading term here; confirm numerical impact before reordering.
        bf16_mfma::Run(a_small_v, b_big_v, reg_c);
        bf16_mfma::Run(a_big_v, b_small_v, reg_c);
        bf16_mfma::Run(a_big_v, b_big_v, reg_c);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
|
||||
|
||||
// Emulated float MFMA built from three bf16 MFMA calls (bf16x3 scheme).
// Only the <32, 32> specialization below is provided.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16xf32;

template <>
struct intrin_mfma_f32_32x32x16xf32<32, 32>
{
    // reg_a, reg_b : float8_t operands; each is split into a bf16 "big" part
    //                and a bf16 "small" residual before the bf16 MFMA calls.
    // reg_c        : accumulator, passed unchanged to all three bf16 MFMA calls.
    // The body is compiled only when __gfx950__ is defined; otherwise the
    // arguments are merely marked as used.
    template <class FloatC>
    __device__ static void Run(const float8_t& reg_a, const float8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        constexpr auto I0 = Number<0>{};
        using bf16_mfma   = intrin_mfma_f32_32x32x16bf16<32, 32>;

        const vector_type<float, 8> a_f32(reg_a);
        const vector_type<float, 8> b_f32(reg_b);

        vector_type<bhalf_t, 8> a_big;
        vector_type<bhalf_t, 8> a_small;
        vector_type<bhalf_t, 8> b_big;
        vector_type<bhalf_t, 8> b_small;

        convert_float_to_bf16_pairs(a_f32, a_big, a_small);
        convert_float_to_bf16_pairs(b_f32, b_big, b_small);

        const auto& a_big_v   = a_big.template AsType<bhalf8_t>()[I0];
        const auto& a_small_v = a_small.template AsType<bhalf8_t>()[I0];
        const auto& b_big_v   = b_big.template AsType<bhalf8_t>()[I0];
        const auto& b_small_v = b_small.template AsType<bhalf8_t>()[I0];

        // Three passes into the same accumulator, issued in this order:
        // small*big, big*small, then big*big.
        // NOTE(review): the cross terms are deliberately kept ahead of the
        // leading term here; confirm numerical impact before reordering.
        bf16_mfma::Run(a_small_v, b_big_v, reg_c);
        bf16_mfma::Run(a_big_v, b_small_v, reg_c);
        bf16_mfma::Run(a_big_v, b_big_v, reg_c);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
|
||||
|
||||
/******************* tf32/xf32 on gfx950 end ************************************/
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user