diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 7fd15b2833..eb281af7bb 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -49,7 +49,7 @@ struct ProblemSizeStreamK final struct ExecutionConfig final { bool do_verification = true; - int init_method = 1; + int init_method = 2; bool time_kernel = false; }; diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 7be2539d90..49743a9c43 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -69,8 +69,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); break; default: - ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); - ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); + ck::utils::FillUniformDistribution{-0.1f, 0.1f}(a_m_k); + ck::utils::FillUniformDistribution{-0.1f, 0.1f}(b_k_n); } Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); @@ -240,7 +240,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) #else c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + return ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 1e-1, 1e-1); #endif } diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 6bbff98312..b989094c0e 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -107,11 +107,12 @@ __host__ __device__ constexpr Y f8_convert_sr(X x); template <> inline __host__ __device__ f8_t f8_convert_sr(float x) { - constexpr int seed = 42; + constexpr int seed = 1254739; uint32_t rng = prand_generator(reinterpret_cast(&x), x); + float max_fp8 = 240.0f; + if(!std::isinf(x)) + x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); #if defined(__gfx94__) - float max_fp8 = 240.0f; - x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); union { float fval; @@ -144,7 +145,7 @@ inline __host__ __device__ f8_t f8_convert_sr(half_t x) constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; - constexpr int seed = 42; + constexpr int seed = 1254739; uint32_t rng = prand_generator(reinterpret_cast(&x), x); return utils:: cast_to_f8( @@ -156,7 +157,7 @@ inline __host__ __device__ f8_t f8_convert_sr(half_t x) template <> inline __host__ __device__ bf8_t f8_convert_sr(float x) { - constexpr int seed = 42; + constexpr int seed = 1254739; uint32_t rng = prand_generator(reinterpret_cast(&x), x); #if defined(__gfx94__) union @@ -191,7 +192,7 @@ inline __host__ __device__ bf8_t f8_convert_sr(half_t x) constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; - constexpr int seed = 42; + constexpr int seed = 1254739; uint32_t rng = prand_generator(reinterpret_cast(&x), x); return utils:: cast_to_f8( @@ -207,9 +208,10 @@ __host__ __device__ constexpr Y f8_convert_rne(X x); template <> inline __host__ __device__ f8_t f8_convert_rne(float x) { -#if defined(__gfx94__) float max_fp8 = 240.0f; - x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); + if(!std::isinf(x)) + x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); +#if defined(__gfx94__) union { float fval; diff --git a/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp b/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp index ae42919db6..220076465d 100644 --- a/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp +++ b/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp @@ -233,7 +233,7 @@ bool profile_elementwise_layernorm_impl(int do_verification, y_dev.FromDevice(y.mData.data()); bool pass = - ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3); + ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 5e-3, 5e-3); if(do_log) {