diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp index e3ef73332b..f74634f1d2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp @@ -198,6 +198,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad constexpr bool kPreloadWholeNextIterationK = Policy::template IsPreloadWholeNextIterationK(); + // This path prefetches two k_tiles for next iteration, so it has the opportunity to + // prefetch two v_tiles during Gemm0 + if constexpr(!kPreloadWholeNextIterationK) + { + static_assert(NumPrefetchV >= 2); + }; + // Block GEMM constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); @@ -243,17 +250,22 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad using k_tile_type = decltype(load_tile(k_dram_window)); - // only prefetch two k tiles to save vgprs consumption auto k_tiles = [&]() { if constexpr(kPreloadWholeNextIterationK) return statically_indexed_array{}; else - return statically_indexed_array{}; + return statically_indexed_array{}; }(); k_tiles[I0] = load_tile(k_dram_window); move_tile_window(k_dram_window, {kN0Sub, 0}); + if constexpr(!kPreloadWholeNextIterationK) + { + k_tiles[I1] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + __builtin_amdgcn_sched_barrier(0x00000001); // provide partition_index for LDS tile window with so that warp_id is in vgpr @@ -495,20 +507,20 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad { static_for<0, n0_loops, 1>{}([&](auto i_n0) { store_tile(k_lds_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[I0]), + tile_elementwise_in(k_element_func, k_tiles[number{}]), partition_index); __builtin_amdgcn_sched_barrier(0x00000001); - if constexpr(i_n0 < n0_loops - 1) + if constexpr(i_n0 < n0_loops - 2) { - k_tiles[I0] = load_tile(k_dram_window); + k_tiles[number{}] = load_tile(k_dram_window); move_tile_window(k_dram_window, {kN0Sub, 0}); }; - if constexpr(i_n0 == n0_loops - 1) + if constexpr(i_n0 >= n0_loops - 2) { - v_tiles[I0] = load_tile(v_dram_window); + v_tiles[number{}] = load_tile(v_dram_window); move_tile_window(v_dram_window, {kK1, 0}); }; @@ -607,10 +619,20 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad __builtin_amdgcn_sched_barrier(0x00000001); - static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) { - v_tiles[i_k1] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {kK1, 0}); - }); + if constexpr(kPreloadWholeNextIterationK) + { + static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) { + v_tiles[i_k1] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }); + } + else + { + static_for<2, NumPrefetchV, 1>{}([&](auto i_k1) { + v_tiles[i_k1] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }); + }; __builtin_amdgcn_sched_barrier(0); @@ -698,6 +720,18 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad } }; + if constexpr(i_k1 == k1_loops - NumPrefetchV + 1) + { + if constexpr(!kPreloadWholeNextIterationK) + { + if(seqlen_k_curr < seqlen_k_end) + { + k_tiles[I1] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + } + }; + block_sync_lds(); gemm_1( o_acc,