diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index b989094c0e..dbac1f0c85 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -109,9 +109,6 @@ inline __host__ __device__ f8_t f8_convert_sr(float x) { constexpr int seed = 1254739; uint32_t rng = prand_generator(reinterpret_cast(&x), x); - float max_fp8 = 240.0f; - if(!std::isinf(x)) - x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); #if defined(__gfx94__) union { @@ -119,10 +116,15 @@ inline __host__ __device__ f8_t f8_convert_sr(float x) uint32_t i32val; uint8_t i8val[4]; // not endian independent } val; - val.fval = x; - uint32_t ival = 0; - ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos - val.i32val = ival; + val.fval = x; + uint32_t ival = 0; + const float max_fp8 = 240.0f; + // if x is not +/- infinity or nan + if((val.i32val & NumericUtils::nan_mask) != NumericUtils::Inf) + // clip float value + val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8); + ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos + val.i32val = ival; return val.i8val[0]; // little endian #else constexpr bool negative_zero_nan = true; @@ -166,10 +168,15 @@ inline __host__ __device__ bf8_t f8_convert_sr(float x) uint32_t i32val; uint8_t i8val[4]; // not endian independent } val; - val.fval = x; - uint32_t ival = 0; - ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos - val.i32val = ival; + val.fval = x; + uint32_t ival = 0; + const float max_bf8 = 57344.0f; + // if x is not +/- infinity or nan + if((val.i32val & NumericUtils::nan_mask) != NumericUtils::Inf) + // clip float value + val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8); + ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos + val.i32val = ival; return val.i8val[0]; // little endian #else constexpr bool negative_zero_nan = true; @@ -208,9 +215,6 @@ __host__ __device__ constexpr Y f8_convert_rne(X x); template <> inline __host__ __device__ f8_t f8_convert_rne(float x) { - float max_fp8 = 240.0f; - if(!std::isinf(x)) - x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); #if defined(__gfx94__) union { @@ -218,8 +222,13 @@ inline __host__ __device__ f8_t f8_convert_rne(float x) uint32_t i32val; uint8_t i8val[4]; // not endian independent } val; - val.fval = x; - uint32_t ival = 0; + val.fval = x; + uint32_t ival = 0; + const float max_fp8 = 240.0f; + // if x is not +/- infinity or nan + if((val.i32val & NumericUtils::nan_mask) != NumericUtils::Inf) + // clip float value + val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8); ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0 val.i32val = ival; return val.i8val[0]; @@ -263,8 +272,13 @@ inline __host__ __device__ bf8_t f8_convert_rne(float x) uint32_t i32val; uint8_t i8val[4]; // not endian independent } val; - val.fval = x; - uint32_t ival = 0; + val.fval = x; + uint32_t ival = 0; + const float max_bf8 = 57344.0f; + // if x is not +/- infinity or nan + if((val.i32val & NumericUtils::nan_mask) != NumericUtils::Inf) + // clip float value + val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8); ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0 val.i32val = ival; return val.i8val[0];