MX GEMM - FP6 Support in GEMM MX v3 Pipeline (#2481)

* Add GEMM MX BF6 example

* Fix BF6 type_convert

* Add type_convert for bf6x16

* Add compare operator to f4x2_pk_t

* Update README for 67_gemm_microscaling

* Fix host tensor initialization with integer values for FP8
This commit is contained in:
Andriy Roshchenko
2025-07-11 13:07:05 -06:00
committed by GitHub
parent d239b91fd5
commit 518dc21ae8
11 changed files with 303 additions and 15 deletions

View File

@@ -1118,6 +1118,54 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
#endif
}
// Scaled MFMA overload for operands supplied as bf6x16x2_t registers.
//
// On gfx950 each operand is repacked into an int32x8_t: elements [0..2] of
// the first and of the second bf6x16x2_t::data_t chunk fill the first six
// lanes (each cast to int32_t) and the last two lanes are set to zero. The
// result of __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4 is written back
// into the first float4_t view of reg_c, with the previous content of that
// view passed in as the accumulator argument.
//
// On every other target the body only marks its arguments as used.
template <class FloatC>
__device__ static void Run(const bf6x16x2_t& reg_a,
                           const int32_t scale_a,
                           const bf6x16x2_t& reg_b,
                           const int32_t scale_b,
                           FloatC& reg_c)
{
#if defined(__gfx950__)
    using chunk_t   = bf6x16x2_t::data_t;
    using operand_t = int32x8_t;

    // Views over the two packed chunks of each input register.
    const auto& a_chunks = reg_a.template AsType<chunk_t>();
    const auto& b_chunks = reg_b.template AsType<chunk_t>();

    // Six payload lanes followed by two zero lanes.
    operand_t operand_a{static_cast<int32_t>(a_chunks[Number<0>{}][0]),
                        static_cast<int32_t>(a_chunks[Number<0>{}][1]),
                        static_cast<int32_t>(a_chunks[Number<0>{}][2]),
                        static_cast<int32_t>(a_chunks[Number<1>{}][0]),
                        static_cast<int32_t>(a_chunks[Number<1>{}][1]),
                        static_cast<int32_t>(a_chunks[Number<1>{}][2]),
                        0,
                        0};

    operand_t operand_b{static_cast<int32_t>(b_chunks[Number<0>{}][0]),
                        static_cast<int32_t>(b_chunks[Number<0>{}][1]),
                        static_cast<int32_t>(b_chunks[Number<0>{}][2]),
                        static_cast<int32_t>(b_chunks[Number<1>{}][0]),
                        static_cast<int32_t>(b_chunks[Number<1>{}][1]),
                        static_cast<int32_t>(b_chunks[Number<1>{}][2]),
                        0,
                        0};

    // Snapshot of the current accumulator, fed to the builtin as its C input.
    const float4_t acc_in = reg_c.template AsType<float4_t>()[Number<0>{}];

    const float4_t acc_out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
        operand_a,
        operand_b,
        acc_in,
        3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
        3, // blgp (same format code as cbsz above)
        OpselA, // OPSEL
        scale_a,
        OpselB, // OPSEL
        scale_b);

    reg_c.template AsType<float4_t>()(Number<0>{}) = acc_out;
#else
    // Not compiled for gfx950: silence unused-parameter diagnostics.
    ignore = reg_a;
    ignore = scale_a;
    ignore = reg_b;
    ignore = scale_b;
    ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f4x32_t& reg_a,
const int32_t scale_a,

View File

@@ -60,6 +60,17 @@ struct f4x2_pk_t
{
return (x0 << 4) | (x1 & 0b00001111);
}
// Equality: two packed values are equal exactly when their `data` members
// compare equal.
__host__ __device__ friend bool operator==(const f4x2_pk_t& lhs, const f4x2_pk_t& rhs)
{
    const bool same_storage = (lhs.data == rhs.data);
    return same_storage;
}
// Inequality: defined as the negation of operator== on the same operands.
__host__ __device__ friend bool operator!=(const f4x2_pk_t& lhs, const f4x2_pk_t& rhs)
{
    const bool are_equal = (lhs == rhs);
    return !are_equal;
}
};
template <typename BitType, index_t pk_size>

View File

@@ -2254,8 +2254,9 @@ using f6x16x2_t = typename vector_type<f6x16_pk_t, 2>::type;
using f6x32_t = typename vector_type<f6x32_pk_t, 1>::type;
// bf6
using bf6x16_t = typename vector_type<bf6x16_pk_t, 1>::type;
using bf6x32_t = typename vector_type<bf6x32_pk_t, 1>::type;
using bf6x16_t = typename vector_type<bf6x16_pk_t, 1>::type;
using bf6x16x2_t = typename vector_type<bf6x16_pk_t, 2>::type;
using bf6x32_t = typename vector_type<bf6x32_pk_t, 1>::type;
// e8m0
using e8m0x4_bexp_t = typename vector_type<e8m0_bexp_t, 4>::type;

View File

@@ -2102,17 +2102,15 @@ inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1
float float_array[32];
} in{x};
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} out{};
using array_type = uint8_t __attribute__((ext_vector_type(32)));
array_type uint8_array;
// collect the 6-bit values into an array
ck::static_for<0, 32, 1>{}([&](auto i) {
out.bf6_array[i] = utils::sat_convert_to_type<bf6_t>(in.float_array[i] / scale);
uint8_array[static_cast<index_t>(i)] =
utils::sat_convert_to_type<bf6_t>(in.float_array[i] / scale);
});
return out.bf6_vector;
return bf6x32_t{bf6x32_pk_t{uint8_array}};
#endif
}
@@ -2257,6 +2255,37 @@ inline __host__ __device__ bf6x32_pk_t type_convert<bf6x32_pk_t, float32_t>(floa
return static_cast<bf6x32_pk_t>(type_convert<bf6x32_t>(x));
}
/**
 * @brief Converts a float16_t value to bf6x16_t by going through the
 *        float32_t -> bf6x32_t conversion.
 *
 * The input is placed twice into a two-element float16_t array that shares
 * storage with a float32_t; that float32_t is converted (stochastic rounding
 * when CK_USE_SR_F6_CONVERSION is set, round-to-nearest-even otherwise) and
 * the first bf6x16_t of the result is returned.
 *
 * NOTE(review): this relies on two float16_t values exactly overlaying one
 * float32_t, and two bf6x16_t values overlaying one bf6x32_t - confirm against
 * the type definitions.
 *
 * @param x The value to convert.
 * @return The first bf6x16_t half of the converted bf6x32_t.
 */
template <>
inline __host__ __device__ bf6x16_t type_convert<bf6x16_t, float16_t>(float16_t x)
{
    // Source side: the input duplicated so the wider conversion can be reused.
    union
    {
        float16_t halves[2];
        float32_t whole;
    } widened{{x, x}};

    // Destination side: wide result viewed as two narrow results.
    union
    {
        bf6x32_t whole;
        bf6x16_t halves[2];
    } packed{};

#if CK_USE_SR_F6_CONVERSION
    packed.whole = bf6_convert_sr(widened.whole);
#else
    packed.whole = bf6_convert_rne(widened.whole);
#endif

    return packed.halves[0];
}
/**
 * @brief Converts a float16_t value to bf6x16_pk_t.
 *
 * Delegates to type_convert<bf6x16_t> and casts the result to bf6x16_pk_t.
 *
 * @param x The value to convert.
 * @return The converted value as bf6x16_pk_t.
 */
template <>
inline __host__ __device__ bf6x16_pk_t type_convert<bf6x16_pk_t, float16_t>(float16_t x)
{
    bf6x16_t converted = type_convert<bf6x16_t>(x);
    return static_cast<bf6x16_pk_t>(converted);
}
/**
* @brief Specializes the type conversion template for converting a bf6_t value to float.
*
@@ -2329,6 +2358,32 @@ inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t
return out.float_vector;
#endif
}
/**
 * @brief Converts a bf6x16_t value to float16_t by going through the
 *        bf6x32_t -> float32_t conversion.
 *
 * The input is placed twice into a two-element bf6x16_t array that shares
 * storage with a bf6x32_t; that bf6x32_t is converted with
 * type_convert<float32_t> and the first float16_t of the result is returned.
 *
 * NOTE(review): this relies on two bf6x16_t values exactly overlaying one
 * bf6x32_t, and two float16_t values overlaying one float32_t - confirm
 * against the type definitions.
 *
 * @param x The value to convert.
 * @return The first float16_t half of the converted float32_t.
 */
template <>
inline __host__ __device__ float16_t type_convert<float16_t, bf6x16_t>(bf6x16_t x)
{
    // Source side: the input duplicated so the wider conversion can be reused.
    union
    {
        bf6x16_t halves[2];
        bf6x32_t whole;
    } packed{{x, x}};

    // Destination side: wide result viewed as two narrow results.
    union
    {
        float16_t halves[2];
        float32_t whole;
    } widened{};

    widened.whole = type_convert<float32_t>(packed.whole);

    return widened.halves[0];
}
/**
 * @brief Converts a bf6x16_pk_t value to float16_t.
 *
 * Casts the input to bf6x16_t and delegates to type_convert<float16_t>.
 *
 * @param x The value to convert.
 * @return The converted float16_t value.
 */
template <>
inline __host__ __device__ float16_t type_convert<float16_t, bf6x16_pk_t>(bf6x16_pk_t x)
{
    bf6x16_t as_vector = static_cast<bf6x16_t>(x);
    return type_convert<float16_t>(as_vector);
}
#endif
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
template <typename Y, typename X, size_t NumElems>