mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Optimize bf16 conversion (#664)
* Add TypeConvert class and start refactoring * Refactor TypeConvert as a struct * Get back to template functions type_convert * Add a type_convert_bf16_rtn, set rtz as default * Clean up * Add UnaryConvertPrecision struct for high-precision workloads * Format * Update type_convert to UnaryConvert on threadwise level * Update UnaryConvertPrecision * Format * Fix chmod * Add a flag to pick converion method * Format * Remove the added flag * Merge elementwise op with type conversion * Move type_convert to elemwise op, update the op * Update type_convert_precision -> bf16_convert_rtn * Clean up * Update comments * Update the CK_WORKAROUND_DENORM_FIX flag handling * Update the unneeded op to work but warn user * Remove the message * Use a PassThrough instead of ConvertBF16RTN to calcaulate reference * Format * Add missing include
This commit is contained in:
@@ -56,6 +56,12 @@ struct PassThrough
|
||||
y = type_convert<bhalf_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
|
||||
{
|
||||
y = type_convert<bhalf_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
|
||||
{
|
||||
@@ -86,6 +92,23 @@ struct UnaryConvert
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertBF16RTN
|
||||
{
|
||||
// convert to bf16 using round to nearest (rtn)
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const
|
||||
{
|
||||
// check Y datatype
|
||||
static_assert(is_same<Y, bhalf_t>::value, "Data type is not supported by this operation!");
|
||||
|
||||
// check X datatype
|
||||
static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = bf16_convert_rtn<Y>(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Scale
|
||||
{
|
||||
__host__ __device__ Scale(float scale) : scale_(scale) {}
|
||||
|
||||
Reference in New Issue
Block a user