From 044f554bf7596ef4ff66fb893ad586b84a70907c Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sun, 7 Dec 2025 12:24:45 +0000 Subject: [PATCH] Refine the interleaving in the loop of Gemm0 --- ...mha_pipeline_qr_ks_vs_whole_k_prefetch.hpp | 66 +++++++++---------- ..._ks_vs_whole_k_prefetch_default_policy.hpp | 11 ++++ 2 files changed, 41 insertions(+), 36 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index d159b550a4..8490d3d15e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -179,12 +179,12 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch constexpr index_t k1_loops = kN0 / kK1; - static_assert(k1_loops >= 2, - "k1_loops >= 2 required due to pre-storing two v_tiles to Lds"); + // usually kN0 is 128, kK1 is 32/16 + static_assert(k1_loops >= 2, "k1_loops >= 2 required to use this pipeline"); constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); - constexpr index_t NumPrefetchV = 2; + constexpr index_t NumPrefetchV = Policy::template GetNumPrefetchV(); static_assert(k1_loops >= NumPrefetchV, "Check failed!"); constexpr bool kPreloadWholeNextIterationK = @@ -377,18 +377,22 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch { k_tiles[number{}] = load_tile(k_dram_window); move_tile_window(k_dram_window, {kK1, 0}); - } - else - { - v_tiles[I0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); + }; + if constexpr(i_k1 < NumPrefetchV) + { + v_tiles[i_k1] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }; + + if constexpr(i_k1 == k1_loops - 1) + { // prefetch all k_tiles for next iteration static_for<0, k1_loops, 1>{}([&](auto ii_k1) { k_tiles[number{}] = load_tile(k_dram_window); move_tile_window(k_dram_window, {kK1, 0}); 
}); - } + }; block_sync_lds(); gemm_0(sacc_tile, @@ -414,12 +418,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch { k_tiles[number{}] = load_tile(k_dram_window); move_tile_window(k_dram_window, {kK1, 0}); - } - else + }; + + if constexpr(i_k1 < NumPrefetchV) { - v_tiles[I0] = load_tile(v_dram_window); + v_tiles[i_k1] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); - } + }; block_sync_lds(); gemm_0(sacc_tile, @@ -444,27 +449,20 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch k_lds_write_windows[number{}], tile_elementwise_in(k_element_func, k_tiles[number{}])); - if constexpr(i_k1 == 0) + if constexpr(i_k1 < NumPrefetchV) { - // prefetch first v_tile - v_tiles[I0] = load_tile(v_dram_window); + v_tiles[i_k1] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); - }; - // prefetch first two k_tiles for next iteration - if constexpr(i_k1 == 1) - { - k_tiles[I0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); - - k_tiles[I1] = load_tile(k_dram_window); + // prefetch k_tile for next iteration + k_tiles[i_k1] = load_tile(k_dram_window); move_tile_window(k_dram_window, {kK1, 0}); }; // prefetch other k_tiles for next iteration - if constexpr(i_k1 >= 2) + if constexpr(i_k1 >= NumPrefetchV) { - k_tiles[number{}] = load_tile(k_dram_window); + k_tiles[i_k1] = load_tile(k_dram_window); move_tile_window(k_dram_window, {kK1, 0}); }; @@ -488,9 +486,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch k_lds_write_windows[number{}], tile_elementwise_in(k_element_func, k_tiles[number{}])); - if constexpr(i_k1 == 0) + if constexpr(i_k1 < NumPrefetchV) { - v_tiles[I0] = load_tile(v_dram_window); + v_tiles[i_k1] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); }; @@ -521,10 +519,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch { k_tiles[I0] = load_tile(k_dram_window); move_tile_window(k_dram_window, {kK1, 0}); - } - else + }; + + if constexpr(i_k1 < NumPrefetchV) { - v_tiles[number{}] = 
load_tile(v_dram_window); + v_tiles[i_k1] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); }; @@ -545,11 +544,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch __builtin_amdgcn_sched_barrier(0x000000001); - static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) { - v_tiles[i_buf] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - }); - const auto bias_tile = load_tile(bias_dram_window); // load bias tile // STAGE 2, scale_s, add bias, mask, softmax diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp index f15ac2f08d..0ac65ec094 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp @@ -23,6 +23,17 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy return Problem::BlockFmhaShape::kM0 <= 64; }; + /* Number of V tiles the pipeline loads from v_dram_window inside its Gemm0 loop (read there as NumPrefetchV); this policy fixes it at 2. Instantiation fails to compile unless kN0 / kK1 >= 2. */ template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetNumPrefetchV() + { + constexpr index_t k1_loops = Problem::BlockFmhaShape::kN0 / Problem::BlockFmhaShape::kK1; + + // usually kN0 is 128 and kK1 is 32 or 16, i.e. k1_loops is 4 or 8 + static_assert(k1_loops >= 2, "Check failed!"); + + return 2; + }; + template CK_TILE_DEVICE static constexpr auto GetNumKVLdsBuffers() {