Implement conditional softmax rescale in non-trload with_softmax pipeline

This commit is contained in:
Qianfeng Zhang
2026-04-24 09:22:44 +00:00
parent d099819657
commit 90e718f73d

View File

@@ -470,6 +470,11 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
}
else
{
// Use m_old[i] as the max-for-stabilization when m_old[i] - m[i] >= -8.0f (i.e. the new
// row max exceeds the old one by at most 8), and keep the max-for-stabilization stored in m[]
if(m_old[i_idx] - m[i_idx] >= -8.0f)
m(i_idx) = m_old[i_idx];
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]);
@@ -491,7 +496,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
{
l(i_idx) = rowsum_p[i_idx];
}
else
else if(m[i_idx] > m_old[i_idx])
{
const auto tmp = f_exp(m_old[i_idx] - m[i_idx]);
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
@@ -500,6 +505,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
o_acc(i_j_idx) *= tmp;
});
}
else
l(i_idx) = l[i_idx] + rowsum_p[i_idx];
});
seqlen_k_curr += kN0;