From 9aa77620b80324de2b3ada92b581426ea70efcdd Mon Sep 17 00:00:00 2001 From: coderfeli Date: Mon, 21 Apr 2025 05:43:57 +0000 Subject: [PATCH] fix perf --- .../fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 2bc374fbe9..e23f381470 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -448,13 +448,12 @@ struct BlockFmhaPipelineQRKSVSAsync #else if constexpr(kHasLogitsSoftCap) { - float scale_lo = scale_s * 0.6931472f; + float scale_lo = scale_s * 0.6931472f * logits_soft_cap_params.logits_soft_cap_rcp; + float logits_cap = log2e_v * logits_soft_cap_params.logits_soft_cap; tile_elementwise_inout( [&scale_lo, - &logits_cap = logits_soft_cap_params.logits_soft_cap, - &logits_cap_rev = logits_soft_cap_params.logits_soft_cap_rcp](auto& x) { - x = log2e_v * logits_cap * - tanh_fast(x * scale_lo * logits_cap_rev); + &logits_cap](auto& x) { + x = logits_cap * tanh_fast(x * scale_lo); }, s_acc); }