mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Refactor f8_t, add bf8_t (#792)
* Refactor f8_t to add bf8_t * Add check_err impl for f8_t * Update fp8 test * Format * Revert the fix * Update vector_type implementation * Add bf8 test * Add bf8, use BitInt types * Add bf8 conversion methods * Update type_convert for fp8/bf8 * Add check_err fp8/bf8 support * Add subnorm fp8 tests * Add subnorm bf8 tests * Fix conversion * Add bf8 cmake bindings * Add macros to enable build with disabled fp8/bf8 * Remove is_native method * Update flag combination for mixed precision instances * Add more flag checks * Add another flag to a client example * Add type traits, decouple f8/bf8 casting * Clean up * Decouple fp8 and bf8 flags * Remove more redundant flags * Remove leftover comments
This commit is contained in:
@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
|
||||
}
|
||||
};
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
|
||||
{
|
||||
@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
|
||||
intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
|
||||
struct MfmaSelector
|
||||
@@ -640,6 +642,7 @@ struct MfmaSelector
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
template <>
|
||||
static constexpr auto GetMfma<f8_t, 32, 32>()
|
||||
{
|
||||
@@ -651,6 +654,7 @@ struct MfmaSelector
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32f8f8;
|
||||
}
|
||||
#endif
|
||||
|
||||
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
|
||||
|
||||
@@ -852,7 +856,11 @@ struct XdlopsGemm
|
||||
{
|
||||
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
|
||||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
|
||||
is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value,
|
||||
is_same<base_type, int8_t>::value
|
||||
#if defined CK_ENABLE_FP8
|
||||
|| is_same<base_type, f8_t>::value
|
||||
#endif
|
||||
,
|
||||
"base base_type must be double, float, half, bfloat16, and int8_t!");
|
||||
|
||||
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
|
||||
|
||||
Reference in New Issue
Block a user