mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Mark low-probability branches as unlikely in the softmax pipelines
This commit is contained in:
@@ -461,7 +461,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
if(__builtin_expect(m[i_idx] == -numeric<CompDataType>::infinity(), 0)) // unlikely
|
||||
{
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
@@ -492,7 +492,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
if(__builtin_expect(m[i_idx] == -numeric<CompDataType>::infinity(), 0)) // unlikely
|
||||
{
|
||||
l(i_idx) = rowsum_p[i_idx];
|
||||
}
|
||||
@@ -576,8 +576,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
if(__builtin_expect(m[i_idx] == -numeric<CompDataType>::infinity(), 0)) // unlikely
|
||||
{
|
||||
o_acc(i_j_idx) = 0.0f;
|
||||
}
|
||||
else
|
||||
o_acc(i_j_idx) *= 1.0f / l[i_idx];
|
||||
});
|
||||
|
||||
@@ -467,7 +467,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
if(__builtin_expect(m[i_idx] == -numeric<CompDataType>::infinity(), 0)) // unlikely
|
||||
{
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
@@ -498,7 +498,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
if(__builtin_expect(m[i_idx] == -numeric<CompDataType>::infinity(), 0)) // unlikely
|
||||
{
|
||||
l(i_idx) = rowsum_p[i_idx];
|
||||
}
|
||||
@@ -582,8 +582,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
if(__builtin_expect(m[i_idx] == -numeric<CompDataType>::infinity(), 0)) // unlikely
|
||||
{
|
||||
o_acc(i_j_idx) = 0.0f;
|
||||
}
|
||||
else
|
||||
o_acc(i_j_idx) *= 1.0f / l[i_idx];
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user