mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
fast tanh
This commit is contained in:
@@ -1380,6 +1380,32 @@ CK_TILE_DEVICE double exp<double>(double x)
|
||||
return exp(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T tanh_fast(T x)
|
||||
{
|
||||
return type_convert<T>((exp_fast_exp2<T>(2.0 * type_convert<float>(x)) - 1.0) / (exp_fast_exp2<T>(2.0 * type_convert<float>(x)) + 1.0));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float tanh_fast<float>(float x)
|
||||
{
|
||||
// return (exp_fast_exp2<float>(2.0f * x) - 1.0f) / (exp_fast_exp2<float>(2.0f * x) + 1.0f);
|
||||
float e, r, s, t, d;
|
||||
float a = x;
|
||||
s = abs(a);
|
||||
t = -log2e_v<float> * 2.0f * s;
|
||||
e = __builtin_amdgcn_exp2f(t);
|
||||
d = e + 1.0f;
|
||||
r = __builtin_amdgcn_rcpf(d);
|
||||
r = e * (-r) + r;
|
||||
if (s < 4.997253418e-3f) r = a;
|
||||
union fipnr {float f; unsigned int i;};
|
||||
fipnr r_; r_.f = r;
|
||||
fipnr a_; a_.f = a;
|
||||
{ r_.i = (r_.i|(a_.i&0x80000000)); r = r_.f; }
|
||||
return r;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T log(T x)
|
||||
{
|
||||
|
||||
@@ -174,7 +174,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
const float logits_cap = 0.3f;
|
||||
const float logits_cap = 30.0f;
|
||||
const float logits_cap_scale = scale_s / (logits_cap * log2e_v<float>);
|
||||
|
||||
// K tile in LDS
|
||||
|
||||
Reference in New Issue
Block a user