mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
MX GEMM - FP6 Support in GEMM MX v3 Pipeline (#2481)
* Add GEMM MX BF6 example
* Fix BF6 type_convert
* Add type_convert for bf6x16
* Add compare operator to f4x2_pk_t
* Update README for 67_gemm_microscaling
* Fix host tensor initialization with integer values for FP8
This commit is contained in:
committed by
GitHub
parent
d239b91fd5
commit
518dc21ae8
@@ -1118,6 +1118,54 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf6x16x2_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
const bf6x16x2_t& reg_b,
|
||||
const int32_t scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
using arg_type = int32x8_t;
|
||||
arg_type arg_a{
|
||||
static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][0]),
|
||||
static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][1]),
|
||||
static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][2]),
|
||||
static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][0]),
|
||||
static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][1]),
|
||||
static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][2]),
|
||||
0,
|
||||
0};
|
||||
arg_type arg_b{
|
||||
static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][0]),
|
||||
static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][1]),
|
||||
static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][2]),
|
||||
static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][0]),
|
||||
static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][1]),
|
||||
static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][2]),
|
||||
0,
|
||||
0};
|
||||
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
arg_a,
|
||||
arg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
3, // blgp
|
||||
OpselA, // OPSEL
|
||||
scale_a,
|
||||
OpselB, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f4x32_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
|
||||
@@ -60,6 +60,17 @@ struct f4x2_pk_t
|
||||
{
|
||||
return (x0 << 4) | (x1 & 0b00001111);
|
||||
}
|
||||
|
||||
// Compare operator
|
||||
// Two packs are equal exactly when their `data` members compare equal.
__host__ __device__ friend bool operator==(const f4x2_pk_t& lhs, const f4x2_pk_t& rhs)
{
    const bool same_payload = (lhs.data == rhs.data);
    return same_payload;
}
|
||||
|
||||
// Inequality is defined as the negation of operator==.
__host__ __device__ friend bool operator!=(const f4x2_pk_t& lhs, const f4x2_pk_t& rhs)
{
    const bool are_equal = (lhs == rhs);
    return !are_equal;
}
|
||||
};
|
||||
|
||||
template <typename BitType, index_t pk_size>
|
||||
|
||||
@@ -2254,8 +2254,9 @@ using f6x16x2_t = typename vector_type<f6x16_pk_t, 2>::type;
|
||||
using f6x32_t = typename vector_type<f6x32_pk_t, 1>::type;

// bf6
// Each alias is declared exactly once; bf6x16x2_t groups two bf6x16_pk_t packs.
using bf6x16_t   = typename vector_type<bf6x16_pk_t, 1>::type;
using bf6x16x2_t = typename vector_type<bf6x16_pk_t, 2>::type;
using bf6x32_t   = typename vector_type<bf6x32_pk_t, 1>::type;

// e8m0
using e8m0x4_bexp_t = typename vector_type<e8m0_bexp_t, 4>::type;
|
||||
|
||||
@@ -2102,17 +2102,15 @@ inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1
|
||||
float float_array[32];
|
||||
} in{x};
|
||||
|
||||
union
|
||||
{
|
||||
bf6x32_t bf6_vector;
|
||||
bf6_t bf6_array[32];
|
||||
} out{};
|
||||
using array_type = uint8_t __attribute__((ext_vector_type(32)));
|
||||
array_type uint8_array;
|
||||
|
||||
// collect the 6-bit values into an array
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
out.bf6_array[i] = utils::sat_convert_to_type<bf6_t>(in.float_array[i] / scale);
|
||||
uint8_array[static_cast<index_t>(i)] =
|
||||
utils::sat_convert_to_type<bf6_t>(in.float_array[i] / scale);
|
||||
});
|
||||
|
||||
return out.bf6_vector;
|
||||
return bf6x32_t{bf6x32_pk_t{uint8_array}};
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -2257,6 +2255,37 @@ inline __host__ __device__ bf6x32_pk_t type_convert<bf6x32_pk_t, float32_t>(floa
|
||||
return static_cast<bf6x32_pk_t>(type_convert<bf6x32_t>(x));
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ bf6x16_t type_convert<bf6x16_t, float16_t>(float16_t x)
|
||||
{
|
||||
|
||||
union
|
||||
{
|
||||
float16_t v16x2[2];
|
||||
float32_t v32;
|
||||
} in{{x, x}};
|
||||
|
||||
union
|
||||
{
|
||||
bf6x32_t v32;
|
||||
bf6x16_t v16x2[2];
|
||||
} out{};
|
||||
|
||||
#if CK_USE_SR_F6_CONVERSION
|
||||
out.v32 = bf6_convert_sr(in.v32);
|
||||
#else
|
||||
out.v32 = bf6_convert_rne(in.v32);
|
||||
#endif
|
||||
|
||||
return out.v16x2[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ bf6x16_pk_t type_convert<bf6x16_pk_t, float16_t>(float16_t x)
|
||||
{
|
||||
return static_cast<bf6x16_pk_t>(type_convert<bf6x16_t>(x));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Specializes the type conversion template for converting a bf6_t value to float.
|
||||
*
|
||||
@@ -2329,6 +2358,32 @@ inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t
|
||||
return out.float_vector;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ float16_t type_convert<float16_t, bf6x16_t>(bf6x16_t x)
|
||||
{
|
||||
union
|
||||
{
|
||||
bf6x16_t v16x2[2];
|
||||
bf6x32_t v32;
|
||||
} in{{x, x}};
|
||||
|
||||
union
|
||||
{
|
||||
float16_t v16x2[2];
|
||||
float32_t v32;
|
||||
} out{};
|
||||
|
||||
out.v32 = type_convert<float32_t>(in.v32);
|
||||
return out.v16x2[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ float16_t type_convert<float16_t, bf6x16_pk_t>(bf6x16_pk_t x)
|
||||
{
|
||||
return type_convert<float16_t>(static_cast<bf6x16_t>(x));
|
||||
}
|
||||
|
||||
#endif
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
template <typename Y, typename X, size_t NumElems>
|
||||
|
||||
Reference in New Issue
Block a user