Refactor f8_t, add bf8_t (#792)

* Refactor f8_t to add bf8_t

* Add check_err impl for f8_t

* Update fp8 test

* Format

* Revert the fix

* Update vector_type implementation

* Add bf8 test

* Add bf8, use BitInt types

* Add bf8 conversion methods

* Update type_convert for fp8/bf8

* Add check_err fp8/bf8 support

* Add subnorm fp8 tests

* Add subnorm bf8 tests

* Fix conversion

* Add bf8 cmake bindings

* Add macros to enable build with disabled fp8/bf8

* Remove is_native method

* Update flag combination for mixed precision instances

* Add more flag checks

* Add another flag to a client example

* Add type traits, decouple f8/bf8 casting

* Clean up

* Decouple fp8 and bf8 flags

* Remove more redundant flags

* Remove leftover comments

[ROCm/composable_kernel commit: 62d4af7449]
This commit is contained in:
Rostyslav Geyyer
2023-09-12 17:04:27 -05:00
committed by GitHub
parent e885110f62
commit 2e227b8581
23 changed files with 739 additions and 172 deletions

View File

@@ -89,6 +89,7 @@ struct PassThrough
}
#endif
#if defined CK_ENABLE_FP8
template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
{
@@ -118,6 +119,7 @@ struct PassThrough
{
y = type_convert<f8_t>(x);
}
#endif
};
struct UnaryConvert
@@ -146,6 +148,7 @@ struct ConvertBF16RTN
}
};
#if defined CK_ENABLE_FP8
struct ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
@@ -162,6 +165,7 @@ struct ConvertF8SR
y = f8_convert_sr<Y>(x);
}
};
#endif
struct Scale
{

View File

@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
}
};
#if defined CK_ENABLE_FP8
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
{
@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
#endif
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
struct MfmaSelector
@@ -640,6 +642,7 @@ struct MfmaSelector
}
#endif
#if defined CK_ENABLE_FP8
template <>
static constexpr auto GetMfma<f8_t, 32, 32>()
{
@@ -651,6 +654,7 @@ struct MfmaSelector
{
return MfmaInstr::mfma_f32_16x16x32f8f8;
}
#endif
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
@@ -852,7 +856,11 @@ struct XdlopsGemm
{
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value,
is_same<base_type, int8_t>::value
#if defined CK_ENABLE_FP8
|| is_same<base_type, f8_t>::value
#endif
,
"base base_type must be double, float, half, bfloat16, and int8_t!");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {