From 409ec3b56e723c9eb613fa144657ae4c35bfabd1 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Mon, 15 Dec 2025 09:54:49 +0000 Subject: [PATCH] Fix move_tile_window(k_dram_window, ..) step in the pipeline --- ...mha_pipeline_qr_ks_vs_whole_k_prefetch.hpp | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index f3ba54a35c..e1bef4715a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -434,7 +434,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch do { // STAGE 1, Gemm_0 ( S = Q@K ) - if constexpr(kPreloadWholeNextIterationK) + if constexpr(kPreloadWholeNextIterationK) // used when kM0 = 64 { if(seqlen_k_curr == seqlen_k_start) // at first iteration { @@ -525,18 +525,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch { v_tiles[i_n0] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); - - // prefetch k_tile for next iteration - k_tiles[i_n0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kN0Sub, 0}); }; - // prefetch other k_tiles for next iteration - if constexpr(i_n0 >= NumPrefetchV) - { - k_tiles[i_n0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kN0Sub, 0}); - }; + // prefetch k_tile for next iteration + k_tiles[i_n0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); block_sync_lds(); gemm_0(sacc_tile, @@ -579,7 +572,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch }; } } - else // only preload one unroll of K for next iteration + else // only preload one unroll of K for next iteration, used when kM0=128 { static_for<0, n0_loops, 1>{}([&](auto i_n0) { store_tile(k_lds_write_windows[number{}], @@ -765,7 +758,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch if(seqlen_k_curr < seqlen_k_end) { k_tiles[I0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kN0, 0}); + move_tile_window(k_dram_window, {kN0Sub, 0}); }; }