diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index db89a79723..70c72bf768 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -458,27 +458,29 @@ struct FastGelu template <> __host__ void operator()(float& y, const float& x) const { - const float u = 2.f * x * (0.035677f * x * x + 0.797885f); - const float emu = exp(-u); - const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); - - y = x * cdf; + // const float u = -2.f * x * (0.035677f * x * x + 0.797885f); + const float c1 = -2.0 * 0.035677f; + const float c2 = -2.0 * 0.797885f; + const float u = x * (c1 * x * x + c2); + const float emu = exp(u); + y = x / (1.f + emu); } // device code, use lower precision "__expf" and "rcp" template <> __device__ void operator()(float& y, const float& x) const { - const float u = 2.f * x * (0.035677f * x * x + 0.797885f); - const float emu = __expf(-u); + // const float u = 2.f * x * (0.035677f * x * x + 0.797885f); + const float c1 = -2.0 * 0.035677f; + const float c2 = -2.0 * 0.797885f; + const float u = x * (c1 * x * x + c2); + const float emu = __expf(u); #if !CK_WORKAROUND_SWDEV_383542 - const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f); + y = x * __frcp_rn(1.f + emu); #else - const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f); + y = x * __ocml_native_recip_f32(1.f + emu); #endif - - y = x * cdf; } template <>