From f0f3b65e2a7894d5a5c4e1e09afe103290c77fca Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 27 Feb 2024 12:31:05 -0800 Subject: [PATCH] Clip fp8 to +/-240 on all targets. (#1172) * clip fp8 to +/-240 on all targets * if inputs to fp8 conversion are +/-inf, they remain unaltered * increase tolerance for test_elementwise_layernorm to prevent false errors * change the input values for gemm examples to floats * reduce gemm example float input values to prevent errors * increase the tolerance for gemm examples [ROCm/composable_kernel commit: d0c7b45150695c2dff205c4ddc9cce2a2e6a2950] --- example/01_gemm/common.hpp | 2 +- example/01_gemm/run_gemm_example.inc | 7 ++++--- include/ck/utility/type_convert.hpp | 18 ++++++++++-------- .../profile_elementwise_layernorm_impl.hpp | 2 +- 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 7fd15b2833..eb281af7bb 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -49,7 +49,7 @@ struct ProblemSizeStreamK final struct ExecutionConfig final { bool do_verification = true; - int init_method = 1; + int init_method = 2; bool time_kernel = false; }; diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 7be2539d90..49743a9c43 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -69,8 +69,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); break; default: - ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); - ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); + ck::utils::FillUniformDistribution{-0.1f, 0.1f}(a_m_k); + ck::utils::FillUniformDistribution{-0.1f, 0.1f}(b_k_n); } Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); @@ -240,7 +240,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) #else c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + return ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 1e-1, 1e-1); #endif } diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 6bbff98312..b989094c0e 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -107,11 +107,12 @@ __host__ __device__ constexpr Y f8_convert_sr(X x); template <> inline __host__ __device__ f8_t f8_convert_sr(float x) { - constexpr int seed = 42; + constexpr int seed = 1254739; uint32_t rng = prand_generator(reinterpret_cast(&x), x); + float max_fp8 = 240.0f; + if(!std::isinf(x)) + x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); #if defined(__gfx94__) - float max_fp8 = 240.0f; - x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); union { float fval; @@ -144,7 +145,7 @@ inline __host__ __device__ f8_t f8_convert_sr(half_t x) constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; - constexpr int seed = 42; + constexpr int seed = 1254739; uint32_t rng = prand_generator(reinterpret_cast(&x), x); return utils:: cast_to_f8( @@ -156,7 +157,7 @@ inline __host__ __device__ f8_t f8_convert_sr(half_t x) template <> inline __host__ __device__ bf8_t f8_convert_sr(float x) { - constexpr int seed = 42; + constexpr int seed = 1254739; uint32_t rng = prand_generator(reinterpret_cast(&x), x); #if defined(__gfx94__) union @@ -191,7 +192,7 @@ inline __host__ __device__ bf8_t f8_convert_sr(half_t x) constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; - constexpr int seed = 42; + constexpr int seed = 1254739; uint32_t rng = prand_generator(reinterpret_cast(&x), x); return utils:: cast_to_f8( @@ -207,9 +208,10 @@ __host__ __device__ constexpr Y f8_convert_rne(X x); template <> inline __host__ __device__ f8_t f8_convert_rne(float x) { -#if defined(__gfx94__) float max_fp8 = 240.0f; - x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); + if(!std::isinf(x)) + x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); +#if defined(__gfx94__) union { float fval; diff --git a/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp b/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp index ae42919db6..220076465d 100644 --- a/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp +++ b/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp @@ -233,7 +233,7 @@ bool profile_elementwise_layernorm_impl(int do_verification, y_dev.FromDevice(y.mData.data()); bool pass = - ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3); + ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 5e-3, 5e-3); if(do_log) {