mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
fast_gelu: minor code reorg to enhance ref & gpu performance (#1162)
This commit is contained in:
@@ -458,27 +458,29 @@ struct FastGelu
|
||||
template <>
|
||||
__host__ void operator()<float, float>(float& y, const float& x) const
|
||||
{
|
||||
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
|
||||
const float emu = exp(-u);
|
||||
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
|
||||
|
||||
y = x * cdf;
|
||||
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
|
||||
const float c1 = -2.0 * 0.035677f;
|
||||
const float c2 = -2.0 * 0.797885f;
|
||||
const float u = x * (c1 * x * x + c2);
|
||||
const float emu = exp(u);
|
||||
y = x / (1.f + emu);
|
||||
}
|
||||
|
||||
// device code, use lower precision "__expf" and "rcp"
|
||||
template <>
|
||||
__device__ void operator()<float, float>(float& y, const float& x) const
|
||||
{
|
||||
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
|
||||
const float emu = __expf(-u);
|
||||
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
|
||||
const float c1 = -2.0 * 0.035677f;
|
||||
const float c2 = -2.0 * 0.797885f;
|
||||
const float u = x * (c1 * x * x + c2);
|
||||
const float emu = __expf(u);
|
||||
|
||||
#if !CK_WORKAROUND_SWDEV_383542
|
||||
const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f);
|
||||
y = x * __frcp_rn(1.f + emu);
|
||||
#else
|
||||
const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
|
||||
y = x * __ocml_native_recip_f32(1.f + emu);
|
||||
#endif
|
||||
|
||||
y = x * cdf;
|
||||
}
|
||||
|
||||
template <>
|
||||
|
||||
Reference in New Issue
Block a user