From 59a1c6464fc86078d14179f35b2e0d1ee7ca7253 Mon Sep 17 00:00:00 2001 From: Qianfeng Date: Thu, 18 Jul 2024 00:15:05 +0800 Subject: [PATCH] Replace the using of __expf by __ocml_exp_f32 to work-around the test_softmax_rank4 failure (#1394) [ROCm/composable_kernel commit: ee768148f0701262e17787067b965e4d5a850d89] --- .../gpu/element/unary_element_wise_operation.hpp | 6 +++--- include/ck/utility/math_v2.hpp | 4 ++-- include/ck_tile/core/numeric/bfloat16.hpp | 5 ++++- include/ck_tile/core/numeric/float8.hpp | 4 ++-- include/ck_tile/core/numeric/half.hpp | 2 +- include/ck_tile/core/numeric/math.hpp | 2 +- 6 files changed, 13 insertions(+), 10 deletions(-) diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index c9ca883744..bf4a1c800f 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -431,7 +431,7 @@ struct Relu // https://paperswithcode.com/method/gelu // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) // host code use higher accuracy "exp" and "div" -// gpu code use lower accuracy "__expf" and "rcp" function +// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function struct FastGelu { template @@ -451,7 +451,7 @@ struct FastGelu y = x / (1.f + emu); } - // device code, use lower precision "__expf" and "rcp" + // device code, use lower precision "__ocml_exp_f32" and "rcp" template <> __device__ void operator()(float& y, const float& x) const { @@ -459,7 +459,7 @@ struct FastGelu const float c1 = -2.0 * 0.035677f; const float c2 = -2.0 * 0.797885f; const float u = x * (c1 * x * x + c2); - const float emu = __expf(u); + const float emu = __ocml_exp_f32(u); y = x * ck::math::rcp(1.f + emu); } diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index 2b921cdc7c..d961cdb198 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -839,7 +839,7 @@ inline __device__ T rcp(T x) template inline __device__ T exp(T x) { - return ck::type_convert(__expf(ck::type_convert(x))); + return ck::type_convert(__ocml_exp_f32(ck::type_convert(x))); }; template <> @@ -851,7 +851,7 @@ inline __device__ half_t exp(half_t x) template <> inline __device__ float exp(float x) { - return __expf(x); + return __ocml_exp_f32(x); }; template <> diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index 071387163a..4fdf8f9dae 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -331,7 +331,10 @@ bfloat16_t sqrt(bfloat16_t x) }; CK_TILE_DEVICE -bfloat16_t exp(bfloat16_t x) { return static_cast(__expf(static_cast(x))); }; +bfloat16_t exp(bfloat16_t x) +{ + return static_cast(__ocml_exp_f32(static_cast(x))); +}; CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x) { return static_cast(exp2f(static_cast(x))); }; diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index 56ca44e720..b3b1a1f3fb 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -835,7 +835,7 @@ CK_TILE_DEVICE fp8_t sqrt(fp8_t x) { return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); }; CK_TILE_DEVICE -fp8_t exp(fp8_t x) { return static_cast(__expf(static_cast(x))); }; +fp8_t exp(fp8_t x) { return static_cast(__ocml_exp_f32(static_cast(x))); }; CK_TILE_DEVICE fp8_t exp2(fp8_t x) { return static_cast(exp2f(static_cast(x))); }; @@ -860,7 +860,7 @@ CK_TILE_DEVICE bf8_t sqrt(bf8_t x) { return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); }; CK_TILE_DEVICE -bf8_t exp(bf8_t x) { return static_cast(__expf(static_cast(x))); }; +bf8_t exp(bf8_t x) { return static_cast(__ocml_exp_f32(static_cast(x))); }; CK_TILE_DEVICE bf8_t exp2(bf8_t x) { return static_cast(exp2f(static_cast(x))); }; diff --git a/include/ck_tile/core/numeric/half.hpp b/include/ck_tile/core/numeric/half.hpp index 752145f711..acb6eb6c3e 100644 --- a/include/ck_tile/core/numeric/half.hpp +++ b/include/ck_tile/core/numeric/half.hpp @@ -374,7 +374,7 @@ half_t sqrt(half_t x) }; CK_TILE_DEVICE -half_t exp(half_t x) { return static_cast(__expf(static_cast(x))); }; +half_t exp(half_t x) { return static_cast(__ocml_exp_f32(static_cast(x))); }; CK_TILE_DEVICE half_t exp2(half_t x) { return static_cast(exp2f(static_cast(x))); }; diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index d4984363da..9970bb3693 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -519,7 +519,7 @@ CK_TILE_DEVICE double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }; CK_TILE_DEVICE -float exp(float x) { return __expf(x); }; +float exp(float x) { return __ocml_exp_f32(x); }; CK_TILE_HOST float exp(float x) { return std::expf(x); }