From 423cc72bc4ac03e04e9efd60ed64aab4659ab495 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sat, 28 Mar 2026 14:19:22 +0000 Subject: [PATCH] Move the calling of mask.GetTileRangeAlongX() to the kernel --- .../hstu_attention_fwd_kernel.hpp | 14 ++++++++++++++ .../hstu_attention_no_softmax_fwd_pipeline.hpp | 8 ++++++-- ...tu_attention_no_softmax_fwd_trload_pipeline.hpp | 8 ++++++-- .../hstu_attention_with_softmax_fwd_pipeline.hpp | 8 ++++++-- ..._attention_with_softmax_fwd_trload_pipeline.hpp | 8 ++++++-- 5 files changed, 38 insertions(+), 8 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index bf01aaf150..aa2bee442e 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -889,10 +889,17 @@ struct HstuAttentionFwdKernel }; }(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(i_m0, + number<HstuAttentionPipeline::kM0>{}, + number<HstuAttentionPipeline::kN0>{}); + return HstuAttentionPipeline{}(q_dram_window, k_dram_window, v_dram_window, bias_dram_window, + seqlen_k_start, + seqlen_k_end, mask, kargs.scale_s, kargs.scale_p, @@ -917,10 +924,17 @@ struct HstuAttentionFwdKernel }; }(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(i_m0, + number<HstuAttentionPipeline::kM0>{}, + number<HstuAttentionPipeline::kN0>{}); + return HstuAttentionPipeline{}(q_dram_window, k_dram_window, v_dram_window, bias_dram_window, + seqlen_k_start, + seqlen_k_end, mask, kargs.scale_s, kargs.scale_p, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp index e706693b85..14cdae8103 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp @@ -133,6 +133,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS const
SAccElementFunction& s_acc_element_func, const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, + index_t seqlen_k_start, + index_t seqlen_k_end, HstuMask& mask, float scale_s, // scaling value exerted on the immediate Q@K result float scale_p, // scaling value exerted on the SiLu result @@ -186,8 +188,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS auto q_tile = load_tile(q_dram_window); const auto q_origin = q_dram_window.get_window_origin(); - const auto [seqlen_k_start, seqlen_k_end] = - mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}); auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), @@ -467,6 +467,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + index_t seqlen_k_start, + index_t seqlen_k_end, HstuMask mask, float scale_s, // scaling value exerted on the immediate Q@K result float scale_p, // scaling value exerted on the SiLU result @@ -482,6 +484,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS identity{}, identity{}, identity{}, + seqlen_k_start, + seqlen_k_end, mask, scale_s, scale_p, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp index 10e02aaf36..a385443af8 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp @@ -133,6 +133,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad const SAccElementFunction& s_acc_element_func, const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, + index_t seqlen_k_start, + index_t
seqlen_k_end, HstuMask& mask, float scale_s, // scaling value exerted on the immediate Q@K result float scale_p, // scaling value exerted on the SiLu result @@ -186,8 +188,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad auto q_tile = load_tile(q_dram_window); const auto q_origin = q_dram_window.get_window_origin(); - const auto [seqlen_k_start, seqlen_k_end] = - mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}); auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), @@ -444,6 +444,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + index_t seqlen_k_start, + index_t seqlen_k_end, HstuMask mask, float scale_s, // scaling value exerted on the immediate Q@K result float scale_p, // scaling value exerted on the SiLU result @@ -459,6 +461,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad identity{}, identity{}, identity{}, + seqlen_k_start, + seqlen_k_end, mask, scale_s, scale_p, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index c77948e54c..049d08ea41 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -133,6 +133,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS const SAccElementFunction& s_acc_element_func, const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, + index_t seqlen_k_start, + index_t seqlen_k_end, HstuMask& mask, float scale_s, // scaling value exerted on the immediate Q@K result float scale_p, // scaling value exerted on the SiLu result @@ -199,8 +201,6 @@ struct
HstuAttentionWithSoftmaxFwdPipelineQRKSVS auto q_tile = load_tile(q_dram_window); const auto q_origin = q_dram_window.get_window_origin(); - const auto [seqlen_k_start, seqlen_k_end] = - mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}); auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), @@ -593,6 +593,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + index_t seqlen_k_start, + index_t seqlen_k_end, HstuMask mask, float scale_s, // scaling value exerted on the immediate Q@K result float scale_p, // scaling value exerted on the SiLU result @@ -608,6 +610,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS identity{}, identity{}, identity{}, + seqlen_k_start, + seqlen_k_end, mask, scale_s, scale_p, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index 1b3ed7ad71..2306f28b5a 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -133,6 +133,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad const SAccElementFunction& s_acc_element_func, const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, + index_t seqlen_k_start, + index_t seqlen_k_end, HstuMask& mask, float scale_s, // scaling value exerted on the immediate Q@K result float scale_p, // scaling value exerted on the SiLu result @@ -199,8 +201,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad auto q_tile = load_tile(q_dram_window); const auto q_origin = q_dram_window.get_window_origin(); - const auto
[seqlen_k_start, seqlen_k_end] = - mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}); auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), @@ -591,6 +591,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + index_t seqlen_k_start, + index_t seqlen_k_end, HstuMask mask, float scale_s, // scaling value exerted on the immediate Q@K result float scale_p, // scaling value exerted on the SiLU result @@ -606,6 +608,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad identity{}, identity{}, identity{}, + seqlen_k_start, + seqlen_k_end, mask, scale_s, scale_p,