Pass partition_index to get_x_indices_from_distributed_indices() to reduce calls to __builtin_amdgcn_readfirstlane()

This commit is contained in:
Qianfeng Zhang
2026-02-21 14:46:31 +00:00
parent f2a555dac7
commit 2be2c3cd11
5 changed files with 14 additions and 5 deletions

View File

@@ -519,7 +519,8 @@ struct HstuAttentionFwdKernel
seqlen_in_first_split = kargs.seqlen_q - num_target - kargs.min_full_attn_seqlen;
index_t num_tile_in_first_split =
ck_tile::integer_divide_ceil(seqlen_in_first_split, HstuAttentionPipeline::kM0);
__builtin_amdgcn_readfirstlane(ck_tile::integer_divide_ceil(
seqlen_in_first_split, HstuAttentionPipeline::kM0));
is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);

View File

@@ -361,7 +361,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
pcomp_tile.get_tile_distribution(),
make_tuple(idx0, idx1),
partition_index);
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});

View File

@@ -363,7 +363,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
pcomp_tile.get_tile_distribution(),
make_tuple(idx0, idx1),
partition_index);
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});

View File

@@ -392,7 +392,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
pcomp_tile.get_tile_distribution(),
make_tuple(idx0, idx1),
partition_index);
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});

View File

@@ -394,7 +394,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
pcomp_tile.get_tile_distribution(),
make_tuple(idx0, idx1),
partition_index);
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});