Update the whole_k_prefetch_trload pipeline to prefetch two k_tiles for the next iteration in the non-whole-k-prefetch path

This commit is contained in:
Qianfeng Zhang
2025-12-23 10:23:26 +00:00
parent 489e2554ea
commit f5b4d5dc26

View File

@@ -198,6 +198,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
constexpr bool kPreloadWholeNextIterationK =
Policy::template IsPreloadWholeNextIterationK<Problem>();
// This path prefetches two k_tiles for next iteration, so it has the opportunity to
// prefetch two v_tiles during Gemm0
if constexpr(!kPreloadWholeNextIterationK)
{
static_assert(NumPrefetchV >= 2);
};
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
@@ -243,17 +250,22 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
using k_tile_type = decltype(load_tile(k_dram_window));
// only prefetch two k tiles to reduce vgpr consumption
auto k_tiles = [&]() {
if constexpr(kPreloadWholeNextIterationK)
return statically_indexed_array<k_tile_type, n0_loops>{};
else
return statically_indexed_array<k_tile_type, 1>{};
return statically_indexed_array<k_tile_type, 2>{};
}();
k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kN0Sub, 0});
if constexpr(!kPreloadWholeNextIterationK)
{
k_tiles[I1] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kN0Sub, 0});
};
__builtin_amdgcn_sched_barrier(0x00000001);
// provide partition_index for the LDS tile window so that warp_id is kept in a vgpr
@@ -495,20 +507,20 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
{
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[I0]),
tile_elementwise_in(k_element_func, k_tiles[number<i_n0 % 2>{}]),
partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
if constexpr(i_n0 < n0_loops - 1)
if constexpr(i_n0 < n0_loops - 2)
{
k_tiles[I0] = load_tile(k_dram_window);
k_tiles[number<i_n0 % 2>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kN0Sub, 0});
};
if constexpr(i_n0 == n0_loops - 1)
if constexpr(i_n0 >= n0_loops - 2)
{
v_tiles[I0] = load_tile(v_dram_window);
v_tiles[number<i_n0 - (n0_loops - 2)>{}] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
};
@@ -607,10 +619,20 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
__builtin_amdgcn_sched_barrier(0x00000001);
static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) {
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
});
if constexpr(kPreloadWholeNextIterationK)
{
static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) {
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
});
}
else
{
static_for<2, NumPrefetchV, 1>{}([&](auto i_k1) {
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
});
};
__builtin_amdgcn_sched_barrier(0);
@@ -698,6 +720,18 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
}
};
if constexpr(i_k1 == k1_loops - NumPrefetchV + 1)
{
if constexpr(!kPreloadWholeNextIterationK)
{
if(seqlen_k_curr < seqlen_k_end)
{
k_tiles[I1] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kN0Sub, 0});
};
}
};
block_sync_lds();
gemm_1(
o_acc,