mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Update clipping for fp8/bf8 conversion (#1182)
* Update clipping for fp8 conversion
* Add clipping for bf8 conversion
* Format
[ROCm/composable_kernel commit: acfb339238]
This commit is contained in:
@@ -109,9 +109,6 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
float max_fp8 = 240.0f;
|
||||
if(!std::isinf(x))
|
||||
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
|
||||
#if defined(__gfx94__)
|
||||
union
|
||||
{
|
||||
@@ -119,10 +116,15 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // not endian independent
|
||||
} val;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
|
||||
val.i32val = ival;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
const float max_fp8 = 240.0f;
|
||||
// if x is not +/- infinity or nan
|
||||
if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
|
||||
// clip float value
|
||||
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
|
||||
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
|
||||
val.i32val = ival;
|
||||
return val.i8val[0]; // little endian
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
@@ -166,10 +168,15 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // not endian independent
|
||||
} val;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
|
||||
val.i32val = ival;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
const float max_bf8 = 57344.0f;
|
||||
// if x is not +/- infinity or nan
|
||||
if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
|
||||
// clip float value
|
||||
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
|
||||
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
|
||||
val.i32val = ival;
|
||||
return val.i8val[0]; // little endian
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
@@ -208,9 +215,6 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
|
||||
template <>
|
||||
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
|
||||
{
|
||||
float max_fp8 = 240.0f;
|
||||
if(!std::isinf(x))
|
||||
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
|
||||
#if defined(__gfx94__)
|
||||
union
|
||||
{
|
||||
@@ -218,8 +222,13 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // not endian independent
|
||||
} val;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
const float max_fp8 = 240.0f;
|
||||
// if x is not +/- infinity or nan
|
||||
if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
|
||||
// clip float value
|
||||
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
|
||||
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
|
||||
val.i32val = ival;
|
||||
return val.i8val[0];
|
||||
@@ -263,8 +272,13 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // not endian independent
|
||||
} val;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
const float max_bf8 = 57344.0f;
|
||||
// if x is not +/- infinity or nan
|
||||
if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
|
||||
// clip float value
|
||||
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
|
||||
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
|
||||
val.i32val = ival;
|
||||
return val.i8val[0];
|
||||
|
||||
Reference in New Issue
Block a user