From 6cacf6370f3b95812add7b26c258c9cdcc15df7d Mon Sep 17 00:00:00 2001 From: Lakhinder Walia <139581206+lakhinderwalia@users.noreply.github.com> Date: Wed, 7 Feb 2024 19:24:51 -0800 Subject: [PATCH] fast_gelu: minor code reorg to enhance ref & gpu performance (#1162) [ROCm/composable_kernel commit: 1f306024d01ed4ebf66f226c882fdcaa7ae207a7] --- .../element/unary_element_wise_operation.hpp | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index db89a79723..70c72bf768 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -458,27 +458,29 @@ struct FastGelu template <> __host__ void operator()(float& y, const float& x) const { - const float u = 2.f * x * (0.035677f * x * x + 0.797885f); - const float emu = exp(-u); - const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); - - y = x * cdf; + // const float u = -2.f * x * (0.035677f * x * x + 0.797885f); + const float c1 = -2.0 * 0.035677f; + const float c2 = -2.0 * 0.797885f; + const float u = x * (c1 * x * x + c2); + const float emu = exp(u); + y = x / (1.f + emu); } // device code, use lower precision "__expf" and "rcp" template <> __device__ void operator()(float& y, const float& x) const { - const float u = 2.f * x * (0.035677f * x * x + 0.797885f); - const float emu = __expf(-u); + // const float u = -2.f * x * (0.035677f * x * x + 0.797885f); + const float c1 = -2.0 * 0.035677f; + const float c2 = -2.0 * 0.797885f; + const float u = x * (c1 * x * x + c2); + const float emu = __expf(u); #if !CK_WORKAROUND_SWDEV_383542 - const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f); + y = x * __frcp_rn(1.f + emu); #else - const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f); + y = x * 
__ocml_native_recip_f32(1.f + emu); #endif - - y = x * cdf; } template <>