From 2ea8d8313c2578390d4f1e868dc5d8362e2945f0 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Mon, 8 Dec 2025 04:54:22 +0000 Subject: [PATCH] Using explicit vgpr-saved partition_index with store_tile(lds_window, ...) --- ...mha_pipeline_qr_ks_vs_whole_k_prefetch.hpp | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index 8490d3d15e..23c84cba9e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -359,6 +359,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch statically_indexed_array v_tiles; + // provide partition_index for LDS tile window with so that warp_id is in vgpr + array partition_index{get_warp_id(), get_lane_id()}; + do { // STAGE 1, Gemm_0 ( S = Q@K ) @@ -369,9 +372,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration { static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile( - k_lds_write_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[number{}])); + store_tile(k_lds_write_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); if constexpr(i_k1 < k1_loops - 1) { @@ -410,9 +413,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch else // the iteration is also the last iteration { static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile( - k_lds_write_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[number{}])); + store_tile(k_lds_write_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); if constexpr(i_k1 < k1_loops - 1) { @@ -445,9 +448,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration { static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile( - k_lds_write_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[number{}])); + store_tile(k_lds_write_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); if constexpr(i_k1 < NumPrefetchV) { @@ -482,9 +485,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch else // last iteration { static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile( - k_lds_write_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[number{}])); + store_tile(k_lds_write_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); if constexpr(i_k1 < NumPrefetchV) { @@ -511,7 +514,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch { static_for<0, k1_loops, 1>{}([&](auto i_k1) { store_tile(k_lds_write_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[I0])); + tile_elementwise_in(k_element_func, k_tiles[I0]), + partition_index); __builtin_amdgcn_sched_barrier(0x00000001); @@ -678,7 +682,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch __builtin_amdgcn_s_barrier(); }; - store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile); + store_tile( + v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index); __builtin_amdgcn_sched_barrier(0x00000001); @@ -717,7 +722,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch tile_elementwise_in(v_element_func, v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}], - v_shuffled_tile); + v_shuffled_tile, + partition_index); }; });