fast tanh

This commit is contained in:
Bernard
2025-04-16 09:51:18 +00:00
parent 44584c831a
commit 895666fe2d
2 changed files with 27 additions and 1 deletions

View File

@@ -1380,6 +1380,32 @@ CK_TILE_DEVICE double exp<double>(double x)
return exp(x);
};
template <typename T>
CK_TILE_DEVICE T tanh_fast(T x)
{
return type_convert<T>((exp_fast_exp2<T>(2.0 * type_convert<float>(x)) - 1.0) / (exp_fast_exp2<T>(2.0 * type_convert<float>(x)) + 1.0));
};
template <>
CK_TILE_DEVICE float tanh_fast<float>(float x)
{
// return (exp_fast_exp2<float>(2.0f * x) - 1.0f) / (exp_fast_exp2<float>(2.0f * x) + 1.0f);
float e, r, s, t, d;
float a = x;
s = abs(a);
t = -log2e_v<float> * 2.0f * s;
e = __builtin_amdgcn_exp2f(t);
d = e + 1.0f;
r = __builtin_amdgcn_rcpf(d);
r = e * (-r) + r;
if (s < 4.997253418e-3f) r = a;
union fipnr {float f; unsigned int i;};
fipnr r_; r_.f = r;
fipnr a_; a_.f = a;
{ r_.i = (r_.i|(a_.i&0x80000000)); r = r_.f; }
return r;
};
template <typename T>
CK_TILE_DEVICE T log(T x)
{

View File

@@ -174,7 +174,7 @@ struct BlockFmhaPipelineQRKSVS
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
const float logits_cap = 0.3f;
const float logits_cap = 30.0f;
const float logits_cap_scale = scale_s / (logits_cap * log2e_v<float>);
// K tile in LDS