From 1f2e2a272e4d72372d3a63b0be9a928e816455ef Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 24 Apr 2026 09:54:12 +0000 Subject: [PATCH] Implement conditional softmax rescale in trload with_softmax pipeline --- .../hstu_attention_with_softmax_fwd_trload_pipeline.hpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index ee0be7b7b4..2dd879231b 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -476,6 +476,11 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad } else { + // use the m_old[i] as the m-for-stablization if m_old[i] - m[i] >= -8.0f + // and still keep the m-for-stablization in m[] + if(m_old[i_idx] - m[i_idx] >= -8.0f) + m(i_idx) = m_old[i_idx]; + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]); @@ -497,7 +502,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad { l(i_idx) = rowsum_p[i_idx]; } - else + else if(m[i_idx] > m_old[i_idx]) { const auto tmp = f_exp(m_old[i_idx] - m[i_idx]); l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; @@ -506,6 +511,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad o_acc(i_j_idx) *= tmp; }); } + else + l(i_idx) = l[i_idx] + rowsum_p[i_idx]; }); seqlen_k_curr += kN0;