mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Refactor f8_t, add bf8_t (#792)
* Refactor f8_t to add bf8_t * Add check_err impl for f8_t * Update fp8 test * Format * Revert the fix * Update vector_type implementation * Add bf8 test * Add bf8, use BitInt types * Add bf8 conversion methods * Update type_convert for fp8/bf8 * Add check_err fp8/bf8 support * Add subnorm fp8 tests * Add subnorm bf8 tests * Fix conversion * Add bf8 cmake bindings * Add macros to enable build with disabled fp8/bf8 * Remove is_native method * Update flag combination for mixed precision instances * Add more flag checks * Add another flag to a client example * Add type traits, decouple f8/bf8 casting * Clean up * Decouple fp8 and bf8 flags * Remove more redundant flags * Remove leftover comments
This commit is contained in:
@@ -89,6 +89,7 @@ struct PassThrough
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
|
||||
{
|
||||
@@ -118,6 +119,7 @@ struct PassThrough
|
||||
{
|
||||
y = type_convert<f8_t>(x);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
struct UnaryConvert
|
||||
@@ -146,6 +148,7 @@ struct ConvertBF16RTN
|
||||
}
|
||||
};
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
struct ConvertF8SR
|
||||
{
|
||||
// convert to fp8 using stochastic rounding (SR)
|
||||
@@ -162,6 +165,7 @@ struct ConvertF8SR
|
||||
y = f8_convert_sr<Y>(x);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
struct Scale
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user