From 238331dbb5730b20bc2a147fb6d5c5bc706b052a Mon Sep 17 00:00:00 2001 From: coderfeli Date: Wed, 16 Apr 2025 14:38:14 +0000 Subject: [PATCH] slightly better --- include/ck_tile/core/numeric/math.hpp | 36 +++++++++++-------- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 8 ++--- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 5fff2e2644..8ed918d471 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -1383,27 +1383,33 @@ CK_TILE_DEVICE double exp(double x) template CK_TILE_DEVICE T tanh_fast(T x) { - return type_convert((exp_fast_exp2(2.0 * type_convert(x)) - 1.0) / (exp_fast_exp2(2.0 * type_convert(x)) + 1.0)); + return x; + // return type_convert((exp_fast_exp2(2.0 * type_convert(x)) - 1.0) / (exp_fast_exp2(2.0 * type_convert(x)) + 1.0)); }; template <> CK_TILE_DEVICE float tanh_fast(float x) { // return (exp_fast_exp2(2.0f * x) - 1.0f) / (exp_fast_exp2(2.0f * x) + 1.0f); - float e, r, s, t, d; - float a = x; - s = abs(a); - t = -log2e_v * 2.0f * s; - e = __builtin_amdgcn_exp2f(t); - d = e + 1.0f; - r = __builtin_amdgcn_rcpf(d); - r = e * (-r) + r; - if (s < 4.997253418e-3f) r = a; - union fipnr {float f; unsigned int i;}; - fipnr r_; r_.f = r; - fipnr a_; a_.f = a; - { r_.i = (r_.i|(a_.i&0x80000000)); r = r_.f; } - return r; + // float a = __builtin_amdgcn_exp2f(2.88539f * x); + // float b = __builtin_amdgcn_rcpf(a + 1.0f); + // return 1.0f - 2.0f * b; + // float x2 = x * x; + // float x5 = x * 0.13333f - 0.33333f; + // float x3 = x2 * x5 + 1.0f; + // return x3 * x; + // 1.442695 + // float a = 2.0f * log2e_v * x; + // a = __builtin_amdgcn_exp2f(a); + // a = __builtin_amdgcn_rcpf(a + 1.0f); + // a = 2 * a; + // a = 1 - a; + // return a; + float a = __builtin_amdgcn_exp2f(2.88539f * x); + float b = __builtin_amdgcn_rcpf(a + 1.0f); + return 1.0f - 2.0f * b; + + }; template diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 8c832e5352..2af044d539 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -175,7 +175,7 @@ struct BlockFmhaPipelineQRKSVS "wrong!"); const float logits_cap = 30.0f; - const float logits_cap_scale = scale_s / (logits_cap * log2e_v); + const float logits_cap_rev = __builtin_amdgcn_rcpf(logits_cap); // K tile in LDS KDataType* k_lds_ptr = static_cast(static_cast( @@ -427,11 +427,11 @@ struct BlockFmhaPipelineQRKSVS else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + float scale_lo = scale_s * 0.6931472f; // #if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = (x * scale_s) * __builtin_amdgcn_rcpf(log2e_v<>); }, s_acc); tile_elementwise_inout( - [&logits_cap_scale, &logits_cap](auto& x) { - x = log2e_v * logits_cap * tanh_fast(x * __builtin_amdgcn_rcpf(logits_cap)); + [&scale_lo, &logits_cap, &logits_cap_rev](auto& x) { + x = log2e_v * logits_cap * tanh_fast(x * scale_lo * logits_cap_rev); }, s_acc );