diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index a6324bf6ee..7ea53b9346 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -677,17 +677,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch __builtin_amdgcn_sched_barrier(0x00000001); - if constexpr(!kPreloadWholeNextIterationK) - { - if(seqlen_k_curr < seqlen_k_end) - { - k_tiles[I0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kN0Sub, 0}); - }; - } - - __builtin_amdgcn_sched_barrier(0x00000001); - // STAGE 3, Gemm_1 ( O = P@V ) static_for<0, k1_loops, 1>{}([&](auto i_k1) { if constexpr(i_k1 < k1_loops - NumPrefetchV) @@ -696,6 +685,18 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch move_tile_window(v_dram_window, {0, kK1}); }; + if constexpr(i_k1 == k1_loops - NumPrefetchV) + { + if constexpr(!kPreloadWholeNextIterationK) + { + if(seqlen_k_curr < seqlen_k_end) + { + k_tiles[I0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + } + }; + block_sync_lds(); gemm_1( o_acc,