From 86d8d72008b80c1a5547faa6d0c5318eaecd0dfc Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 22 May 2026 15:08:30 +0000 Subject: [PATCH] Use partition_index parameter for all get_x_indices_from_distributed_indices() calls --- .../hstu_attention_with_softmax_fwd_pipeline.hpp | 4 +++- .../hstu_attention_with_softmax_fwd_trload_pipeline.hpp | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index 535c0ced81..c282f4946e 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -438,7 +438,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( - pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + pcomp_tile.get_tile_distribution(), + make_tuple(idx0, idx1), + partition_index); const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); constexpr auto i_j_idx = make_tuple(idx0, idx1); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index 3a0584b9f9..52201a74df 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -447,7 +447,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( - pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + pcomp_tile.get_tile_distribution(), + make_tuple(idx0, idx1), + partition_index); const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); constexpr auto i_j_idx = make_tuple(idx0, idx1);