mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
Clip fp8 to +/-240 on all targets. (#1172)
* clip fp8 to +/-240 on all targets
* if inputs to fp8 conversion are +/-inf, they remain unaltered
* increase tolerance for test_elementwise_layernorm to prevent false errors
* change the input values for gemm examples to floats
* reduce gemm example float input values to prevent errors
* increase the tolerance for gemm examples
[ROCm/composable_kernel commit: d0c7b45150]
This commit is contained in:
@@ -107,11 +107,12 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
|
||||
template <>
|
||||
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
|
||||
{
|
||||
constexpr int seed = 42;
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
float max_fp8 = 240.0f;
|
||||
if(!std::isinf(x))
|
||||
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
|
||||
#if defined(__gfx94__)
|
||||
float max_fp8 = 240.0f;
|
||||
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
@@ -144,7 +145,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
constexpr int seed = 42;
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
return utils::
|
||||
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
@@ -156,7 +157,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
|
||||
{
|
||||
constexpr int seed = 42;
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#if defined(__gfx94__)
|
||||
union
|
||||
@@ -191,7 +192,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
constexpr int seed = 42;
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
return utils::
|
||||
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
@@ -207,9 +208,10 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
|
||||
template <>
|
||||
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
float max_fp8 = 240.0f;
|
||||
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
|
||||
if(!std::isinf(x))
|
||||
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
|
||||
#if defined(__gfx94__)
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
|
||||
Reference in New Issue
Block a user