Move the calling of mask.GetTileRangeAlongX() to the kernel

This commit is contained in:
Qianfeng Zhang
2026-03-28 14:19:22 +00:00
parent eefe426ef7
commit 423cc72bc4
5 changed files with 38 additions and 8 deletions

View File

@@ -889,10 +889,17 @@ struct HstuAttentionFwdKernel
};
}();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(i_m0,
number<HstuAttentionPipeline::kM0>{},
number<HstuAttentionPipeline::kN0>{});
return HstuAttentionPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
seqlen_k_start,
seqlen_k_end,
mask,
kargs.scale_s,
kargs.scale_p,
@@ -917,10 +924,17 @@ struct HstuAttentionFwdKernel
};
}();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(i_m0,
number<HstuAttentionPipeline::kM0>{},
number<HstuAttentionPipeline::kN0>{});
return HstuAttentionPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
seqlen_k_start,
seqlen_k_end,
mask,
kargs.scale_s,
kargs.scale_p,

View File

@@ -133,6 +133,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask& mask,
float scale_s, // scaling value exerted on the immediate Q@K result
float scale_p, // scaling value exerted on the SiLu result
@@ -186,8 +188,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
auto q_tile = load_tile(q_dram_window);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -467,6 +467,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask mask,
float scale_s, // scaling value exerted on the immediate Q@K result
float scale_p, // scaling value exerted on the SiLU result
@@ -482,6 +484,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
identity{},
identity{},
identity{},
seqlen_k_start,
seqlen_k_end,
mask,
scale_s,
scale_p,

View File

@@ -133,6 +133,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask& mask,
float scale_s, // scaling value exerted on the immediate Q@K result
float scale_p, // scaling value exerted on the SiLu result
@@ -186,8 +188,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
auto q_tile = load_tile(q_dram_window);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -444,6 +444,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask mask,
float scale_s, // scaling value exerted on the immediate Q@K result
float scale_p, // scaling value exerted on the SiLU result
@@ -459,6 +461,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
identity{},
identity{},
identity{},
seqlen_k_start,
seqlen_k_end,
mask,
scale_s,
scale_p,

View File

@@ -133,6 +133,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask& mask,
float scale_s, // scaling value exerted on the immediate Q@K result
float scale_p, // scaling value exerted on the SiLu result
@@ -199,8 +201,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
auto q_tile = load_tile(q_dram_window);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -593,6 +593,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask mask,
float scale_s, // scaling value exerted on the immediate Q@K result
float scale_p, // scaling value exerted on the SiLU result
@@ -608,6 +610,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
identity{},
identity{},
identity{},
seqlen_k_start,
seqlen_k_end,
mask,
scale_s,
scale_p,

View File

@@ -133,6 +133,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask& mask,
float scale_s, // scaling value exerted on the immediate Q@K result
float scale_p, // scaling value exerted on the SiLu result
@@ -199,8 +201,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
auto q_tile = load_tile(q_dram_window);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -591,6 +591,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask mask,
float scale_s, // scaling value exerted on the immediate Q@K result
float scale_p, // scaling value exerted on the SiLU result
@@ -606,6 +608,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
identity{},
identity{},
identity{},
seqlen_k_start,
seqlen_k_end,
mask,
scale_s,
scale_p,