From 25521a7e0616012b783a9344e16b61ffe08a2d35 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 5 Dec 2025 15:58:33 +0000 Subject: [PATCH] Switch the codes based on the iteration index (first/intermediate/last) --- ...mha_pipeline_qr_ks_vs_whole_k_prefetch.hpp | 321 +++++++++++------- 1 file changed, 194 insertions(+), 127 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index 5ceab925b6..d159b550a4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -154,9 +154,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch void* smem_ptr, DropoutType& dropout) const { - ignore = q_element_func; - ignore = k_element_func; - // xformers path does not require the pipeline to output random values for host // verification, since a separate kernel is used to generate random values ignore = randval_dram_block_window_tmp; @@ -177,6 +174,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + constexpr index_t k1_loops = kN0 / kK1; static_assert(k1_loops >= 2, @@ -184,6 +184,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); + constexpr index_t NumPrefetchV = 2; + static_assert(k1_loops >= NumPrefetchV, "Check failed!"); + constexpr bool kPreloadWholeNextIterationK = Policy::template IsPreloadWholeNextIterationK(); @@ -218,10 +221,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch q_dram_block_window_tmp.get_window_origin(), Policy::template MakeQRegTileDistribution()); - auto q_tile = load_tile(q_dram_window); - - __builtin_amdgcn_sched_barrier(0); - const auto q_origin = q_dram_window.get_window_origin(); const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); @@ -234,34 +233,22 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch using k_tile_type = decltype(load_tile(k_dram_window)); - constexpr index_t NumPrefetchK = 2; - - static_assert(k1_loops >= NumPrefetchK, "Check failed!"); - // only prefetch two k tiles to save vgprs consumption auto k_tiles = [&]() { if constexpr(kPreloadWholeNextIterationK) return statically_indexed_array{}; else - return statically_indexed_array{}; + return statically_indexed_array{}; }(); - if constexpr(kPreloadWholeNextIterationK) - { - static_for<0, k1_loops, 1>{}([&](auto i_k1) { - k_tiles[i_k1] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); - }); - } - else - { - static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) { - k_tiles[i_k1] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); - }); - }; + k_tiles[I0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); - __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_sched_barrier(0x00000001); + + auto q_tile = load_tile(q_dram_window); + + __builtin_amdgcn_sched_barrier(0x00000001); // K tile in LDS KDataType* k_lds_ptr = static_cast(smem_ptr); @@ -377,51 +364,167 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch // STAGE 1, Gemm_0 ( S = Q@K ) if constexpr(kPreloadWholeNextIterationK) { - static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile(k_lds_write_windows[number{}], - k_tiles[i_k1]); + if(seqlen_k_curr == seqlen_k_start) // at first iteration + { + if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration + { + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + store_tile( + k_lds_write_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}])); - __builtin_amdgcn_sched_barrier(0x00000001); + if constexpr(i_k1 < k1_loops - 1) + { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + } + else + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); - // load v_tiles used in current iteration - v_tiles[i_k1] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); + // prefetch all k_tiles for next iteration + static_for<0, k1_loops, 1>{}([&](auto ii_k1) { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + }); + } - __builtin_amdgcn_sched_barrier(0x00000001); + block_sync_lds(); + gemm_0(sacc_tile, + q_tile, + k_lds_read_windows[number{}]); - block_sync_lds(); + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_k1 * kK1>{}, + sequence{}); + }); + } + else // the iteration is also the last iteration + { + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + store_tile( + k_lds_write_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}])); - // execute current unroll of gemm_0 - gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); + if constexpr(i_k1 < k1_loops - 1) + { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + } + else + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + } - sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + block_sync_lds(); + gemm_0(sacc_tile, + q_tile, + k_lds_read_windows[number{}]); - auto tmp_tile = cast_tile(sacc_tile); + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_k1 * kK1>{}, + sequence{}); + }); + }; + } + else // at intermediate and last iteration + { + if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration + { + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + store_tile( + k_lds_write_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}])); - set_slice_tile(pcomp_tile, - tmp_tile, - sequence<0, i_k1 * kK1>{}, - sequence{}); - }); + if constexpr(i_k1 == 0) + { + // prefetch first v_tile + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }; + + // prefetch first two k_tiles for next iteration + if constexpr(i_k1 == 1) + { + k_tiles[I0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + + k_tiles[I1] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + }; + + // prefetch other k_tiles for next iteration + if constexpr(i_k1 >= 2) + { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + }; + + block_sync_lds(); + gemm_0(sacc_tile, + q_tile, + k_lds_read_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_k1 * kK1>{}, + sequence{}); + }); + } + else // last iteration + { + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + store_tile( + k_lds_write_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}])); + + if constexpr(i_k1 == 0) + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }; + + block_sync_lds(); + gemm_0(sacc_tile, + q_tile, + k_lds_read_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_k1 * kK1>{}, + sequence{}); + }); + }; + } } - else + else // only preload one unroll of K for next iteration { static_for<0, k1_loops, 1>{}([&](auto i_k1) { store_tile(k_lds_write_windows[number{}], - k_tiles[number{}]); + tile_elementwise_in(k_element_func, k_tiles[I0])); __builtin_amdgcn_sched_barrier(0x00000001); - if constexpr(i_k1 < k1_loops - NumPrefetchK) + if constexpr(i_k1 < k1_loops - 1) { - k_tiles[number{}] = load_tile(k_dram_window); + k_tiles[I0] = load_tile(k_dram_window); move_tile_window(k_dram_window, {kK1, 0}); } else { - // load v_tiles used in current iteration - v_tiles[number{}] = - load_tile(v_dram_window); + v_tiles[number{}] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); }; @@ -429,13 +532,10 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch block_sync_lds(); - // execute current unroll of gemm_0 gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); - sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); - + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); auto tmp_tile = cast_tile(sacc_tile); - set_slice_tile(pcomp_tile, tmp_tile, sequence<0, i_k1 * kK1>{}, @@ -445,14 +545,10 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch __builtin_amdgcn_sched_barrier(0x000000001); - if constexpr(!kPreloadWholeNextIterationK) - { - static_for{}([&](auto i_k1) { - // load v_tiles used in current iteration - v_tiles[i_k1] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - }); - } + static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) { + v_tiles[i_buf] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }); const auto bias_tile = load_tile(bias_dram_window); // load bias tile @@ -577,16 +673,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch __builtin_amdgcn_sched_barrier(0x00000001); - auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); - - __builtin_amdgcn_sched_barrier(0x00000001); - - using v_shuffled_tile_type = decltype(make_static_distributed_tensor( - Policy::template MakeShuffledVRegTileDistribution())); - - v_shuffled_tile_type v_shuffled_tile; - - shuffle_tile(v_shuffled_tile, v_tiles[number<0>{}]); + auto v_shuffled_tile = make_static_distributed_tensor( + Policy::template MakeShuffledVRegTileDistribution()); + shuffle_tile(v_shuffled_tile, tile_elementwise_in(v_element_func, v_tiles[I0])); // check whether first V-LdsBufer overlap with last K-LdsBuffer, // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 @@ -599,66 +688,44 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch __builtin_amdgcn_sched_barrier(0x00000001); - // STAGE 3, Gemm_1 ( O = P@V ) - if constexpr(kPreloadWholeNextIterationK) + auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); + + __builtin_amdgcn_sched_barrier(0x00000001); + + if constexpr(!kPreloadWholeNextIterationK) { - static_for<0, k1_loops, 1>{}([&](auto i_k1) { - // load k_tiles used by next iteration - k_tiles[i_k1] = load_tile(k_dram_window); + if(seqlen_k_curr < seqlen_k_end) + { + k_tiles[I0] = load_tile(k_dram_window); move_tile_window(k_dram_window, {kK1, 0}); - - __builtin_amdgcn_sched_barrier(0x00000001); - - block_sync_lds(); - - gemm_1(o_acc, - get_slice_tile( - p, sequence<0, i_k1 * kK1>{}, sequence{}), - v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]); - - if constexpr(i_k1 < k1_loops - 1) - { - __builtin_amdgcn_sched_barrier(0x00000001); - - shuffle_tile(v_shuffled_tile, v_tiles[number{}]); - store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}], - v_shuffled_tile); - - __builtin_amdgcn_sched_barrier(0x00000001); - }; - }); + }; } - else - { - static_for<0, k1_loops, 1>{}([&](auto i_k1) { - if constexpr(i_k1 < NumPrefetchK) - { - // load k_tiles used by next iteration - k_tiles[i_k1] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); - }; - __builtin_amdgcn_sched_barrier(0x00000001); + __builtin_amdgcn_sched_barrier(0x00000001); - block_sync_lds(); + // STAGE 3, Gemm_1 ( O = P@V ) + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + if constexpr(i_k1 < k1_loops - NumPrefetchV) + { + v_tiles[number{}] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }; - gemm_1(o_acc, - get_slice_tile( - p, sequence<0, i_k1 * kK1>{}, sequence{}), - v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]); + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]); - if constexpr(i_k1 < k1_loops - 1) - { - __builtin_amdgcn_sched_barrier(0x00000001); - - shuffle_tile(v_shuffled_tile, v_tiles[number{}]); - store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}], - v_shuffled_tile); - - __builtin_amdgcn_sched_barrier(0x00000001); - }; - }); - } + if constexpr(i_k1 < k1_loops - 1) + { + shuffle_tile(v_shuffled_tile, + tile_elementwise_in(v_element_func, + v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); + store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}], + v_shuffled_tile); + }; + }); // check whether last V-LdsBuffer overlap with first K-LdsBuffer, // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4