From 1a8f2f21fbb219fbe80baa9e743afd794c10e7ff Mon Sep 17 00:00:00 2001
From: Qianfeng Zhang
Date: Wed, 15 Oct 2025 09:24:44 +0000
Subject: [PATCH] Move scaling by attn_scale to inside the main-loop

---
 .../18_hstu_attention/hstu_attention_fwd_pipeline.hpp | 6 +++---
 .../ck_tile/18_hstu_attention/reference_hstu_attention.hpp | 4 +---
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp
index dab1620fb7..fc3a2e8bd0 100644
--- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp
+++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp
@@ -450,6 +450,9 @@ struct HstuAttentionFwdPipelineQRKSVS
 
             tile_elementwise_inout(f_silu, pcomp_tile);
 
+            tile_elementwise_inout([&](auto& x) { x = x * type_convert(scale_p); },
+                                   pcomp_tile);
+
             seqlen_k_curr += kN0;
 
             if constexpr(kHasDropout)
@@ -507,9 +510,6 @@ struct HstuAttentionFwdPipelineQRKSVS
             };
         } while(seqlen_k_curr < seqlen_k_end);
 
-        tile_elementwise_inout([&](auto& x) { x = x * type_convert(scale_p); },
-                               o_acc);
-
         o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
 
         return o_acc;
diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp
index 08f561620a..4b35ef3884 100644
--- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp
+++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp
@@ -191,7 +191,7 @@ struct reference_hstu_attention
 
                 // SiLu element-wise
                 for(CompDataType& elem : locals)
-                    elem = silu(elem);
+                    elem = silu(elem) * ck_tile::type_convert(scale_p);
 
                 // second Gemm
                 for(int k = 0; k < hdim_v; k++)
@@ -221,8 +221,6 @@ struct reference_hstu_attention
                         };
                     };
 
-                    dot_prod = dot_prod * ck_tile::type_convert(scale_p);
-
                     if constexpr(kIsJagged)
                         o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) =
                             ck_tile::type_convert(dot_prod);