mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Implement conditional softmax rescale in non-trload with_softmax pipeline
This commit is contained in:
@@ -470,6 +470,11 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
}
|
||||
else
|
||||
{
|
||||
// use the m_old[i] as the m-for-stablization if m_old[i] - m[i] >= -8.0f
|
||||
// and still keep the m-for-stablization in m[]
|
||||
if(m_old[i_idx] - m[i_idx] >= -8.0f)
|
||||
m(i_idx) = m_old[i_idx];
|
||||
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]);
|
||||
@@ -491,7 +496,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
{
|
||||
l(i_idx) = rowsum_p[i_idx];
|
||||
}
|
||||
else
|
||||
else if(m[i_idx] > m_old[i_idx])
|
||||
{
|
||||
const auto tmp = f_exp(m_old[i_idx] - m[i_idx]);
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
@@ -500,6 +505,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
}
|
||||
else
|
||||
l(i_idx) = l[i_idx] + rowsum_p[i_idx];
|
||||
});
|
||||
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
Reference in New Issue
Block a user