diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index bcc85876a9..a6324bf6ee 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -373,14 +373,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch move_tile_window(k_dram_window, {kN0Sub, 0}); }; - if constexpr(i_n0 < NumPrefetchV) - { - v_tiles[i_n0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - }; - if constexpr(i_n0 == n0_loops - 1) { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + // prefetch all k_tiles for next iteration static_for<0, n0_loops, 1>{}([&](auto ii_n0) { k_tiles[number{}] = load_tile(k_dram_window); @@ -413,9 +410,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch move_tile_window(k_dram_window, {kN0Sub, 0}); }; - if constexpr(i_n0 < NumPrefetchV) + if constexpr(i_n0 == n0_loops - 1) { - v_tiles[i_n0] = load_tile(v_dram_window); + v_tiles[I0] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); }; @@ -441,9 +438,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch tile_elementwise_in(k_element_func, k_tiles[number{}]), partition_index); - if constexpr(i_n0 < NumPrefetchV) + if constexpr(i_n0 == 0) { - v_tiles[i_n0] = load_tile(v_dram_window); + v_tiles[I0] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); }; @@ -470,9 +467,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch tile_elementwise_in(k_element_func, k_tiles[number{}]), partition_index); - if constexpr(i_n0 < NumPrefetchV) + if constexpr(i_n0 == 0) { - v_tiles[i_n0] = load_tile(v_dram_window); + v_tiles[I0] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); }; @@ -507,10 +504,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch if constexpr(i_n0 == n0_loops - 1) { - static_for<0, NumPrefetchV, 1>{}([&](auto i_k1) { - v_tiles[i_k1] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - }); + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); }; __builtin_amdgcn_sched_barrier(0x00000001); @@ -593,6 +588,31 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch tile_elementwise_inout( [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); + __builtin_amdgcn_sched_barrier(0); + + auto v_shuffled_tile = make_static_distributed_tensor( + Policy::template MakeShuffledVRegTileDistribution()); + shuffle_tile(v_shuffled_tile, tile_elementwise_in(v_element_func, v_tiles[I0])); + + // check whether first V-LdsBufer overlap with last K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) + { + __builtin_amdgcn_s_barrier(); + }; + + store_tile( + v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index); + + __builtin_amdgcn_sched_barrier(0x00000001); + + static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) { + v_tiles[i_k1] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }); + + __builtin_amdgcn_sched_barrier(0); + constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans(); sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); @@ -653,22 +673,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch __builtin_amdgcn_sched_barrier(0x00000001); - auto v_shuffled_tile = make_static_distributed_tensor( - Policy::template MakeShuffledVRegTileDistribution()); - shuffle_tile(v_shuffled_tile, tile_elementwise_in(v_element_func, v_tiles[I0])); - - // check whether first V-LdsBufer overlap with last K-LdsBuffer, - // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 - if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) - { - __builtin_amdgcn_s_barrier(); - }; - - store_tile( - v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index); - - __builtin_amdgcn_sched_barrier(0x00000001); - auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); __builtin_amdgcn_sched_barrier(0x00000001); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp index df3f567a7c..a6056bcc99 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp @@ -377,14 +377,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad move_tile_window(k_dram_window, {kN0Sub, 0}); }; - if constexpr(i_n0 < NumPrefetchV) - { - v_tiles[i_n0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {kK1, 0}); - }; - if constexpr(i_n0 == n0_loops - 1) { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + // prefetch all k_tiles for next iteration static_for<0, n0_loops, 1>{}([&](auto ii_n0) { k_tiles[number{}] = load_tile(k_dram_window); @@ -417,9 +414,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad move_tile_window(k_dram_window, {kN0Sub, 0}); }; - if constexpr(i_n0 < NumPrefetchV) + if constexpr(i_n0 == n0_loops - 1) { - v_tiles[i_n0] = load_tile(v_dram_window); + v_tiles[I0] = load_tile(v_dram_window); move_tile_window(v_dram_window, {kK1, 0}); }; @@ -445,9 +442,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad tile_elementwise_in(k_element_func, k_tiles[number{}]), partition_index); - if constexpr(i_n0 < NumPrefetchV) + if constexpr(i_n0 == 0) { - v_tiles[i_n0] = load_tile(v_dram_window); + v_tiles[I0] = load_tile(v_dram_window); move_tile_window(v_dram_window, {kK1, 0}); }; @@ -474,9 +471,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad tile_elementwise_in(k_element_func, k_tiles[number{}]), partition_index); - if constexpr(i_n0 < NumPrefetchV) + if constexpr(i_n0 == 0) { - v_tiles[i_n0] = load_tile(v_dram_window); + v_tiles[I0] = load_tile(v_dram_window); move_tile_window(v_dram_window, {kK1, 0}); }; @@ -511,10 +508,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad if constexpr(i_n0 == n0_loops - 1) { - static_for<0, NumPrefetchV, 1>{}([&](auto i_k1) { - v_tiles[i_k1] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {kK1, 0}); - }); + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); }; __builtin_amdgcn_sched_barrier(0x00000001); @@ -597,6 +592,28 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad tile_elementwise_inout( [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); + __builtin_amdgcn_sched_barrier(0); + + // check whether first V-LdsBufer overlap with last K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) + { + __builtin_amdgcn_s_barrier(); + }; + + store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_tiles[I0]), + partition_index); + + __builtin_amdgcn_sched_barrier(0x00000001); + + static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) { + v_tiles[i_k1] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }); + + __builtin_amdgcn_sched_barrier(0); + constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans(); sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); @@ -657,19 +674,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad __builtin_amdgcn_sched_barrier(0x00000001); - // check whether first V-LdsBufer overlap with last K-LdsBuffer, - // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 - if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) - { - __builtin_amdgcn_s_barrier(); - }; - - store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_tiles[I0]), - partition_index); - - __builtin_amdgcn_sched_barrier(0x00000001); - auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); __builtin_amdgcn_sched_barrier(0x00000001);