Use partition_index parameter for all get_x_indices_from_distributed_indices() calls

This commit is contained in:
Qianfeng Zhang
2026-05-22 15:08:30 +00:00
parent 65992be728
commit 86d8d72008
2 changed files with 6 additions and 2 deletions

View File

@@ -438,7 +438,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
-pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
+pcomp_tile.get_tile_distribution(),
+make_tuple(idx0, idx1),
+partition_index);
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);

View File

@@ -447,7 +447,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
-pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
+pcomp_tile.get_tile_distribution(),
+make_tuple(idx0, idx1),
+partition_index);
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);