diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp
index 6bdcb509b0..5fff2e2644 100644
--- a/include/ck_tile/core/numeric/math.hpp
+++ b/include/ck_tile/core/numeric/math.hpp
@@ -1380,6 +1380,41 @@ CK_TILE_DEVICE double exp(double x)
     return exp(x);
 };
 
+// Fast hyperbolic tangent for a generic element type T.
+// x is converted to float and tanh is evaluated as
+//     tanh(x) = 1 - 2 / (exp(2x) + 1),
+// which is algebraically equal to (exp(2x) - 1) / (exp(2x) + 1) but calls the
+// exponential once and yields +1 when exp(2x) overflows to +inf, where the
+// quotient form would evaluate inf / inf = NaN. The result is converted to T.
+template <typename T>
+CK_TILE_DEVICE T tanh_fast(T x)
+{
+    // NOTE(review): the 2.0 / 1.0 literals are double, so this arithmetic is done in
+    // double; float literals would presumably suffice -- confirm the exp_fast_exp2 overloads.
+    const auto exp_2x = exp_fast_exp2(2.0 * type_convert<float>(x));
+    return type_convert<T>(1.0 - 2.0 / (exp_2x + 1.0));
+}
+
+// float specialization built on the AMDGCN exp2 and reciprocal builtins.
+// With e = 2^(-2 * log2e * |x|) (i.e. exp(-2|x|), assuming log2e_v<float> is log2(e)),
+// tanh(|x|) = (1 - e) / (1 + e). e never exceeds 1, so nothing overflows; the sign
+// of x is put back on the result at the end.
+template <>
+CK_TILE_DEVICE float tanh_fast<float>(float x)
+{
+    const float abs_x = abs(x);
+    const float e     = __builtin_amdgcn_exp2f(-log2e_v<float> * 2.0f * abs_x);
+    const float rcp   = __builtin_amdgcn_rcpf(e + 1.0f);
+    // (1 - e) / (1 + e), written as e * (-rcp) + rcp.
+    float result = e * (-rcp) + rcp;
+    // For very small |x| return x itself (tanh(x) ~= x there).
+    if(abs_x < 4.997253418e-3f)
+        result = x;
+    // Copy the sign bit of x onto the result. A builtin is used rather than punning
+    // through a union, because reading an inactive union member is undefined in C++.
+    return __builtin_copysignf(result, x);
+}
+
 template <typename T>
 CK_TILE_DEVICE T log(T x)
 {
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
index aa25d75176..cbb2fe5504 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
@@ -174,7 +174,7 @@ struct BlockFmhaPipelineQRKSVS
                           kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                       "wrong!");
 
-        const float logits_cap = 0.3f;
+        const float logits_cap = 30.0f;
         const float logits_cap_scale = scale_s / (logits_cap * log2e_v<float>);
 
         // K tile in LDS