diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index f3ba54a35c..e1bef4715a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -434,7 +434,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch do { // STAGE 1, Gemm_0 ( S = Q@K ) - if constexpr(kPreloadWholeNextIterationK) + if constexpr(kPreloadWholeNextIterationK) // used when kM0 = 64 { if(seqlen_k_curr == seqlen_k_start) // at first iteration { @@ -525,18 +525,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch { v_tiles[i_n0] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); - - // prefetch k_tile for next iteration - k_tiles[i_n0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kN0Sub, 0}); }; - // prefetch other k_tiles for next iteration - if constexpr(i_n0 >= NumPrefetchV) - { - k_tiles[i_n0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kN0Sub, 0}); - }; + // prefetch k_tile for next iteration + k_tiles[i_n0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); block_sync_lds(); gemm_0(sacc_tile, @@ -579,7 +572,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch }; } } - else // only preload one unroll of K for next iteration + else // only preload one unroll of K for next iteration, used when kM0=128 { static_for<0, n0_loops, 1>{}([&](auto i_n0) { store_tile(k_lds_write_windows[number{}], @@ -765,7 +758,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch if(seqlen_k_curr < seqlen_k_end) { k_tiles[I0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kN0, 0}); + move_tile_window(k_dram_window, {kN0Sub, 0}); }; }