From d2c653f177e2295f1000dadbe18fc37120855503 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Fri, 18 Apr 2025 10:45:53 +0000 Subject: [PATCH] fmha: fix logits_cap tanh argument scaling in qr_ks_vs_async pipeline --- .../fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 8f32915cff..c7fb168e43 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -193,6 +193,7 @@ struct BlockFmhaPipelineQRKSVSAsync constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); const float logits_cap = 30.0f; + const float logits_cap_rev = 0.0333333f; // const float logits_cap_scale = scale_s * rcp(logits_cap * log2e_v); // K tile in LDS @@ -439,10 +440,11 @@ struct BlockFmhaPipelineQRKSVSAsync { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); // #if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + float scale_lo = scale_s * 0.6931472f; +// #if !CK_TILE_FMHA_FWD_FAST_EXP2 tile_elementwise_inout( - [&logits_cap](auto& x) { - x = log2e_v * logits_cap * tanh_fast(x * rcp(logits_cap)); + [&scale_lo, &logits_cap, &logits_cap_rev](auto& x) { + x = log2e_v * logits_cap * tanh_fast(x * scale_lo * logits_cap_rev); }, s_acc );