From d0803f263dbf1888ddab1db63e76ac815f9d8c64 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Mon, 27 Apr 2026 06:39:32 +0000 Subject: [PATCH] Mark low probability branch as unlikely in the softmax pipelines --- .../hstu_attention_with_softmax_fwd_pipeline.hpp | 8 +++++--- .../hstu_attention_with_softmax_fwd_trload_pipeline.hpp | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index a5e5830180..2590e1b5ad 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -461,7 +461,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); - if(m[i_idx] == -numeric::infinity()) + if(__builtin_expect(m[i_idx] == -numeric::infinity(), 0)) // unlikely { sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -492,7 +492,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); - if(m[i_idx] == -numeric::infinity()) + if(__builtin_expect(m[i_idx] == -numeric::infinity(), 0)) // unlikely { l(i_idx) = rowsum_p[i_idx]; } @@ -576,8 +576,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); - if(m[i_idx] == -numeric::infinity()) + if(__builtin_expect(m[i_idx] == -numeric::infinity(), 0)) // unlikely + { o_acc(i_j_idx) = 0.0f; + } else o_acc(i_j_idx) *= 1.0f / l[i_idx]; }); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index 2b4cd4b4b4..0ed95270a6 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -467,7 +467,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); - if(m[i_idx] == -numeric::infinity()) + if(__builtin_expect(m[i_idx] == -numeric::infinity(), 0)) // unlikely { sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -498,7 +498,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); - if(m[i_idx] == -numeric::infinity()) + if(__builtin_expect(m[i_idx] == -numeric::infinity(), 0)) // unlikely { l(i_idx) = rowsum_p[i_idx]; } @@ -582,8 +582,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); - if(m[i_idx] == -numeric::infinity()) + if(__builtin_expect(m[i_idx] == -numeric::infinity(), 0)) // unlikely + { o_acc(i_j_idx) = 0.0f; + } else o_acc(i_j_idx) *= 1.0f / l[i_idx]; });