diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp index a6056bcc99..e3ef73332b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp @@ -678,17 +678,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad __builtin_amdgcn_sched_barrier(0x00000001); - if constexpr(!kPreloadWholeNextIterationK) - { - if(seqlen_k_curr < seqlen_k_end) - { - k_tiles[I0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kN0Sub, 0}); - }; - } - - __builtin_amdgcn_sched_barrier(0x00000001); - // STAGE 3, Gemm_1 ( O = P@V ) static_for<0, k1_loops, 1>{}([&](auto i_k1) { if constexpr(i_k1 < k1_loops - NumPrefetchV) @@ -697,6 +686,18 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad move_tile_window(v_dram_window, {kK1, 0}); }; + if constexpr(i_k1 == k1_loops - NumPrefetchV) + { + if constexpr(!kPreloadWholeNextIterationK) + { + if(seqlen_k_curr < seqlen_k_end) + { + k_tiles[I0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + } + }; + block_sync_lds(); gemm_1( o_acc,