From 98f9b4a47bce83505e5372ee00bd77ff5e81a204 Mon Sep 17 00:00:00 2001
From: Qianfeng Zhang
Date: Thu, 4 Dec 2025 10:25:16 +0000
Subject: [PATCH] Add a path that prefetches the whole next-iteration K in the
 pipeline

---
 ...mha_pipeline_qr_ks_vs_whole_k_prefetch.hpp | 249 ++++++++++++------
 1 file changed, 166 insertions(+), 83 deletions(-)

diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp
index 78a7f18342..5ceab925b6 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp
@@ -179,8 +179,14 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
         constexpr index_t k1_loops = kN0 / kK1;
 
+        static_assert(k1_loops >= 2,
+                      "k1_loops >= 2 required due to pre-storing two v_tiles to Lds");
+
         constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
 
+        constexpr bool kPreloadWholeNextIterationK =
+            Policy::template IsPreloadWholeNextIterationK<Problem>();
+
         // Block GEMM
         constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
         constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
 
@@ -233,12 +239,27 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
         static_assert(k1_loops >= NumPrefetchK, "Check failed!");
 
         // only prefetch two k tiles to save vgprs consumption
-        statically_indexed_array<decltype(load_tile(k_dram_window)), NumPrefetchK> k_tiles;
+        auto k_tiles = [&]() {
+            if constexpr(kPreloadWholeNextIterationK)
+                return statically_indexed_array<decltype(load_tile(k_dram_window)), k1_loops>{};
+            else
+                return statically_indexed_array<decltype(load_tile(k_dram_window)), NumPrefetchK>{};
+        }();
 
-        static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) {
-            k_tiles[i_k1] = load_tile(k_dram_window);
-            move_tile_window(k_dram_window, {kK1, 0});
-        });
+        if constexpr(kPreloadWholeNextIterationK)
+        {
+            static_for<0, k1_loops, 1>{}([&](auto i_k1) {
+                k_tiles[i_k1] = load_tile(k_dram_window);
+                move_tile_window(k_dram_window, {kK1, 0});
+            });
+        }
+        else
+        {
+            static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) {
+                k_tiles[i_k1] = load_tile(k_dram_window);
+                move_tile_window(k_dram_window, {kK1, 0});
+            });
+        };
 
         __builtin_amdgcn_sched_barrier(0);
 
@@ -314,7 +335,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
             {bias_origin.at(number<0>{}), seqlen_k_start},
             Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
 
-        // assuming no random values need be saved, this is try when this pipeline is called from
-        // xformers, since we have a separate kernel to generated randomm values
+        // assuming no random values need to be saved; this is true when the pipeline is called
+        // from xformers, since we have a separate kernel to generate random values
         auto null_randval_window = [&]() {
             if constexpr(kHasDropout)
 
@@ -354,42 +375,84 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
         do
         {
             // STAGE 1, Gemm_0 ( S = Q@K )
-            static_for<0, k1_loops, 1>{}([&](auto i_k1) {
-                store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
-                           k_tiles[number<i_k1 % NumPrefetchK>{}]);
-
-                __builtin_amdgcn_sched_barrier(0x00000001);
-
-                if constexpr(i_k1 < k1_loops - NumPrefetchK)
-                {
-                    k_tiles[number<i_k1 % NumPrefetchK>{}] = load_tile(k_dram_window);
-                    move_tile_window(k_dram_window, {kK1, 0});
-                }
-                else
-                { // load v_tiles used in current iteration
-                    v_tiles[number<i_k1 - (k1_loops - NumPrefetchK)>{}] = load_tile(v_dram_window);
-                    move_tile_window(v_dram_window, {0, kK1});
-                };
-
-                __builtin_amdgcn_sched_barrier(0x00000001);
-
-                block_sync_lds();
-
-                // execute current unroll of gemm_0
-                gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
-
-                sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
-
-                auto tmp_tile = cast_tile<SMPLComputeDataType>(sacc_tile);
-
-                set_slice_tile(pcomp_tile,
-                               tmp_tile,
-                               sequence<0, i_k1 * kK1>{},
-                               sequence<kM0, (i_k1 + 1) * kK1>{});
-            });
-
-            __builtin_amdgcn_sched_barrier(0);
+            if constexpr(kPreloadWholeNextIterationK)
+            {
+                static_for<0, k1_loops, 1>{}([&](auto i_k1) {
+                    store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
+                               k_tiles[i_k1]);
+
+                    __builtin_amdgcn_sched_barrier(0x00000001);
+
+                    // load v_tiles used in current iteration
+                    v_tiles[i_k1] = load_tile(v_dram_window);
+                    move_tile_window(v_dram_window, {0, kK1});
+
+                    __builtin_amdgcn_sched_barrier(0x00000001);
+
+                    block_sync_lds();
+
+                    // execute current unroll of gemm_0
+                    gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
+
+                    sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
+
+                    auto tmp_tile = cast_tile<SMPLComputeDataType>(sacc_tile);
+
+                    set_slice_tile(pcomp_tile,
+                                   tmp_tile,
+                                   sequence<0, i_k1 * kK1>{},
+                                   sequence<kM0, (i_k1 + 1) * kK1>{});
+                });
+            }
+            else
+            {
+                static_for<0, k1_loops, 1>{}([&](auto i_k1) {
+                    store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
+                               k_tiles[number<i_k1 % NumPrefetchK>{}]);
+
+                    __builtin_amdgcn_sched_barrier(0x00000001);
+
+                    if constexpr(i_k1 < k1_loops - NumPrefetchK)
+                    {
+                        k_tiles[number<i_k1 % NumPrefetchK>{}] = load_tile(k_dram_window);
+                        move_tile_window(k_dram_window, {kK1, 0});
+                    }
+                    else
+                    {
+                        // load v_tiles used in current iteration
+                        v_tiles[number<i_k1 - (k1_loops - NumPrefetchK)>{}] =
+                            load_tile(v_dram_window);
+                        move_tile_window(v_dram_window, {0, kK1});
+                    };
+
+                    __builtin_amdgcn_sched_barrier(0x00000001);
+
+                    block_sync_lds();
+
+                    // execute current unroll of gemm_0
+                    gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
+
+                    sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
+
+                    auto tmp_tile = cast_tile<SMPLComputeDataType>(sacc_tile);
+
+                    set_slice_tile(pcomp_tile,
+                                   tmp_tile,
+                                   sequence<0, i_k1 * kK1>{},
+                                   sequence<kM0, (i_k1 + 1) * kK1>{});
+                });
+            }
+
+            __builtin_amdgcn_sched_barrier(0x00000001);
+
+            if constexpr(!kPreloadWholeNextIterationK)
+            {
+                static_for<NumPrefetchK, k1_loops, 1>{}([&](auto i_k1) {
+                    // load v_tiles used in current iteration
+                    v_tiles[i_k1] = load_tile(v_dram_window);
+                    move_tile_window(v_dram_window, {0, kK1});
+                });
+            }
 
             const auto bias_tile = load_tile(bias_dram_window); // load bias tile
 
@@ -445,32 +508,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
 
             __builtin_amdgcn_sched_barrier(0x00000001);
 
-            using v_shuffled_tile_type = decltype(make_static_distributed_tensor<VDataType>(
-                Policy::template MakeShuffledVRegTileDistribution<Problem>()));
-
-            v_shuffled_tile_type v_shuffled_tile;
-
-            shuffle_tile(v_shuffled_tile, v_tiles[number<0>{}]);
-
-            // check whether first V-LdsBufer overlap with last K-LdsBuffer,
-            // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
-            if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
-            {
-                __builtin_amdgcn_s_barrier();
-            };
-
-            store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile);
-
-            __builtin_amdgcn_sched_barrier(0x00000001);
-
-            static_for<NumPrefetchK, k1_loops, 1>{}([&](auto i_k1) {
-                // load v_tiles used in current iteration
-                v_tiles[i_k1] = load_tile(v_dram_window);
-                move_tile_window(v_dram_window, {0, kK1});
-            });
-
-            __builtin_amdgcn_sched_barrier(0x00000001);
-
             auto m_local = block_tile_reduce<SMPLComputeDataType>(
                 pcomp_tile, sequence<1>{}, f_max, -numeric<SMPLComputeDataType>::infinity());
 
             block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
 
@@ -538,44 +575,90 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
 
             seqlen_k_curr += kN0;
 
+            __builtin_amdgcn_sched_barrier(0x00000001);
+
             auto p =
                 cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
 
-            // k1_loops >= 2 required
-            shuffle_tile(v_shuffled_tile, v_tiles[number<1>{}]);
+            __builtin_amdgcn_sched_barrier(0x00000001);
 
-            store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_shuffled_tile);
+            using v_shuffled_tile_type = decltype(make_static_distributed_tensor<VDataType>(
+                Policy::template MakeShuffledVRegTileDistribution<Problem>()));
+
+            v_shuffled_tile_type v_shuffled_tile;
+
+            shuffle_tile(v_shuffled_tile, v_tiles[number<0>{}]);
+
+            // check whether first V-LdsBuffer overlaps with last K-LdsBuffer,
+            // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
+            if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
+            {
+                __builtin_amdgcn_s_barrier();
+            };
+
+            store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile);
 
             __builtin_amdgcn_sched_barrier(0x00000001);
 
             // STAGE 3, Gemm_1 ( O = P@V )
-            static_for<0, k1_loops, 1>{}([&](auto i_k1) {
-                if constexpr(i_k1 < NumPrefetchK)
-                {
-                    // load k_tiles used by next iteration
-                    k_tiles[i_k1] = load_tile(k_dram_window);
-                    move_tile_window(k_dram_window, {kK1, 0});
-                };
-
-                __builtin_amdgcn_sched_barrier(0x00000001);
-
-                block_sync_lds();
-
-                gemm_1(
-                    o_acc,
-                    get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
-                    v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
-
-                if constexpr(i_k1 < k1_loops - 2)
-                {
-                    __builtin_amdgcn_sched_barrier(0x00000001);
-
-                    shuffle_tile(v_shuffled_tile, v_tiles[number<i_k1 + 2>{}]);
-                    store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}],
-                               v_shuffled_tile);
-
-                    __builtin_amdgcn_sched_barrier(0x00000001);
-                };
-            });
+            if constexpr(kPreloadWholeNextIterationK)
+            {
+                static_for<0, k1_loops, 1>{}([&](auto i_k1) {
+                    // load k_tiles used by next iteration
+                    k_tiles[i_k1] = load_tile(k_dram_window);
+                    move_tile_window(k_dram_window, {kK1, 0});
+
+                    __builtin_amdgcn_sched_barrier(0x00000001);
+
+                    block_sync_lds();
+
+                    gemm_1(o_acc,
+                           get_slice_tile(
+                               p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
+                           v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
+
+                    if constexpr(i_k1 < k1_loops - 1)
+                    {
+                        __builtin_amdgcn_sched_barrier(0x00000001);
+
+                        shuffle_tile(v_shuffled_tile, v_tiles[number<i_k1 + 1>{}]);
+                        store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
+                                   v_shuffled_tile);
+
+                        __builtin_amdgcn_sched_barrier(0x00000001);
+                    };
+                });
+            }
+            else
+            {
+                static_for<0, k1_loops, 1>{}([&](auto i_k1) {
+                    if constexpr(i_k1 < NumPrefetchK)
+                    {
+                        // load k_tiles used by next iteration
+                        k_tiles[i_k1] = load_tile(k_dram_window);
+                        move_tile_window(k_dram_window, {kK1, 0});
+                    };
+
+                    __builtin_amdgcn_sched_barrier(0x00000001);
+
+                    block_sync_lds();
+
+                    gemm_1(o_acc,
+                           get_slice_tile(
+                               p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
+                           v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
+
+                    if constexpr(i_k1 < k1_loops - 1)
+                    {
+                        __builtin_amdgcn_sched_barrier(0x00000001);
+
+                        shuffle_tile(v_shuffled_tile, v_tiles[number<i_k1 + 1>{}]);
+                        store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
+                                   v_shuffled_tile);
+
+                        __builtin_amdgcn_sched_barrier(0x00000001);
+                    };
+                });
+            }
 
             // check whether last V-LdsBuffer overlap with first K-LdsBuffer,
             // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
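
Note for reviewers: the following is a minimal standalone sketch of the two schedules this patch creates, not CK code. `Tile`, `load_k_tile`, `load_v_tile`, `store_to_lds`, and `gemm` are hypothetical stand-ins for `load_tile`/`store_tile`/`gemm_0`/`gemm_1`; the softmax stage is elided, and the V shuffle plus LDS ring buffering are folded into `store_to_lds`. It illustrates one point: with the whole-K preload, every QK step can issue a V load because no K reloads compete for the loop body, while the rolling path spends the first `k1_loops - NumPrefetchK` steps reloading K and defers the remaining V loads until after the QK loop.

```cpp
// Sketch of the two K-prefetch schedules (hypothetical helpers, not the
// ck_tile API). Compile with: g++ -std=c++17 sketch.cpp
#include <array>
#include <cstdio>

constexpr int k1_loops     = 4; // kN0 / kK1
constexpr int NumPrefetchK = 2; // rolling K prefetch depth

struct Tile { int id = -1; };
Tile load_k_tile(int i) { std::printf("load K sub-tile %d\n", i); return {i}; }
Tile load_v_tile(int i) { std::printf("load V sub-tile %d\n", i); return {i}; }
void store_to_lds(const Tile& t) { std::printf("  store tile %d to LDS\n", t.id); }
void gemm(const char* which, int i) { std::printf("  %s on sub-tile %d\n", which, i); }

template <bool PreloadWholeK>
void one_main_loop_iteration(int iter)
{
    // K sub-tiles held in registers at iteration start: the whole next tile
    // (k1_loops of them) or only the rolling window (NumPrefetchK of them).
    constexpr int num_k_regs = PreloadWholeK ? k1_loops : NumPrefetchK;
    std::array<Tile, num_k_regs> k_tiles;
    std::array<Tile, k1_loops>   v_tiles;

    for(int i = 0; i < num_k_regs; ++i)
        k_tiles[i] = load_k_tile(iter * k1_loops + i);

    // STAGE 1: S = Q@K, one kK1 slice per step.
    for(int i = 0; i < k1_loops; ++i)
    {
        store_to_lds(k_tiles[PreloadWholeK ? i : i % NumPrefetchK]);
        if constexpr(PreloadWholeK)
        {
            v_tiles[i] = load_v_tile(i); // V loads issue unconditionally
        }
        else
        {
            if(i < k1_loops - NumPrefetchK) // refill the rolling K window
                k_tiles[i % NumPrefetchK] = load_k_tile(iter * k1_loops + i + NumPrefetchK);
            else // last NumPrefetchK steps: start the V loads instead
                v_tiles[i - (k1_loops - NumPrefetchK)] =
                    load_v_tile(i - (k1_loops - NumPrefetchK));
        }
        gemm("gemm_0 (Q@K)", i);
    }

    if constexpr(!PreloadWholeK) // remaining V loads happen after gemm_0
    {
        for(int i = NumPrefetchK; i < k1_loops; ++i)
            v_tiles[i] = load_v_tile(i);
    }

    // ... softmax would run here ...

    // STAGE 3: O = P@V; next iteration's K loads overlap with gemm_1.
    for(int i = 0; i < k1_loops; ++i)
    {
        if(PreloadWholeK || i < NumPrefetchK)
            load_k_tile((iter + 1) * k1_loops + i); // prefetch for next iteration
        store_to_lds(v_tiles[i]);
        gemm("gemm_1 (P@V)", i);
    }
}

int main()
{
    one_main_loop_iteration<true>(0);  // whole-K preload path
    one_main_loop_iteration<false>(0); // rolling NumPrefetchK path
}
```

The trade-off mirrors the comment in the source: the whole-K path keeps `k1_loops` K sub-tiles in registers instead of `NumPrefetchK`, spending VGPRs to get earlier and more uniformly spaced global loads.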