Implement conditional softmax rescale in trload with_softmax pipeline

This commit is contained in:
Qianfeng Zhang
2026-04-24 09:54:12 +00:00
parent 90e718f73d
commit 1f2e2a272e

View File

@@ -476,6 +476,11 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
}
else
{
// Use m_old[i] as the max-for-stabilization when m_old[i] - m[i] >= -8.0f,
// and keep that max-for-stabilization stored in m[].
if(m_old[i_idx] - m[i_idx] >= -8.0f)
m(i_idx) = m_old[i_idx];
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]);
@@ -497,7 +502,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
{
l(i_idx) = rowsum_p[i_idx];
}
else
else if(m[i_idx] > m_old[i_idx])
{
const auto tmp = f_exp(m_old[i_idx] - m[i_idx]);
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
@@ -506,6 +511,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
o_acc(i_j_idx) *= tmp;
});
}
else
l(i_idx) = l[i_idx] + rowsum_p[i_idx];
});
seqlen_k_curr += kN0;