From 90e718f73dca51bae01dacf4d270ed54b95d2940 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 24 Apr 2026 09:22:44 +0000 Subject: [PATCH] Implement conditional softmax rescale in non-trload with_softmax pipeline --- .../hstu_attention_with_softmax_fwd_pipeline.hpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index fb4cee3648..26c1164d32 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -470,6 +470,11 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS } else { + // use the m_old[i] as the m-for-stabilization if m_old[i] - m[i] >= -8.0f + // and still keep the m-for-stabilization in m[] + if(m_old[i_idx] - m[i_idx] >= -8.0f) + m(i_idx) = m_old[i_idx]; + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]); @@ -491,7 +496,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS { l(i_idx) = rowsum_p[i_idx]; } - else + else if(m[i_idx] > m_old[i_idx]) { const auto tmp = f_exp(m_old[i_idx] - m[i_idx]); l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; @@ -500,6 +505,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS o_acc(i_j_idx) *= tmp; }); } + else + l(i_idx) = l[i_idx] + rowsum_p[i_idx]; }); seqlen_k_curr += kN0;