Refine the interleaving in the loop of Gemm0

This commit is contained in:
Qianfeng Zhang
2025-12-07 12:24:45 +00:00
parent 5722f8afbc
commit 044f554bf7
2 changed files with 41 additions and 36 deletions

View File

@@ -179,12 +179,12 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
constexpr index_t k1_loops = kN0 / kK1;
static_assert(k1_loops >= 2,
"k1_loops >= 2 required due to pre-storing two v_tiles to Lds");
// usually kN0 is 128, kK1 is 32/16
static_assert(k1_loops >= 2, "k1_loops >= 2 required to use this pipeline");
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
constexpr index_t NumPrefetchV = 2;
constexpr index_t NumPrefetchV = Policy::template GetNumPrefetchV<Problem>();
static_assert(k1_loops >= NumPrefetchV, "Check failed!");
constexpr bool kPreloadWholeNextIterationK =
@@ -377,18 +377,22 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
{
k_tiles[number<i_k1 + 1>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
}
else
{
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
};
if constexpr(i_k1 < NumPrefetchV)
{
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
};
if constexpr(i_k1 == k1_loops - 1)
{
// prefetch all k_tiles for next iteration
static_for<0, k1_loops, 1>{}([&](auto ii_k1) {
k_tiles[number<ii_k1>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
});
}
};
block_sync_lds();
gemm_0(sacc_tile,
@@ -414,12 +418,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
{
k_tiles[number<i_k1 + 1>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
}
else
};
if constexpr(i_k1 < NumPrefetchV)
{
v_tiles[I0] = load_tile(v_dram_window);
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
}
};
block_sync_lds();
gemm_0(sacc_tile,
@@ -444,27 +449,20 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]));
if constexpr(i_k1 == 0)
if constexpr(i_k1 < NumPrefetchV)
{
// prefetch first v_tile
v_tiles[I0] = load_tile(v_dram_window);
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
};
// prefetch first two k_tiles for next iteration
if constexpr(i_k1 == 1)
{
k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
k_tiles[I1] = load_tile(k_dram_window);
// prefetch k_tile for next iteration
k_tiles[i_k1] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
};
// prefetch other k_tiles for next iteration
if constexpr(i_k1 >= 2)
if constexpr(i_k1 >= NumPrefetchV)
{
k_tiles[number<i_k1>{}] = load_tile(k_dram_window);
k_tiles[i_k1] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
};
@@ -488,9 +486,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]));
if constexpr(i_k1 == 0)
if constexpr(i_k1 < NumPrefetchV)
{
v_tiles[I0] = load_tile(v_dram_window);
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
};
@@ -521,10 +519,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
{
k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
}
else
};
if constexpr(i_k1 < NumPrefetchV)
{
v_tiles[number<i_k1 - (k1_loops - 1)>{}] = load_tile(v_dram_window);
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
};
@@ -545,11 +544,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
__builtin_amdgcn_sched_barrier(0x000000001);
static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) {
v_tiles[i_buf] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
});
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
// STAGE 2, scale_s, add bias, mask, softmax

View File

@@ -23,6 +23,17 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
return Problem::BlockFmhaShape::kM0 <= 64;
};
// Policy hook reporting how many V tiles are loaded ahead of use; the pipeline
// reads this value as NumPrefetchV. The count is fixed at two, so the problem
// shape must satisfy kN0 / kK1 >= 2, which is checked at compile time.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetNumPrefetchV()
{
    using FmhaShape = typename Problem::BlockFmhaShape;

    // typical shapes: kN0 = 128 with kK1 = 32 or 16
    constexpr index_t num_k1_steps   = FmhaShape::kN0 / FmhaShape::kK1;
    constexpr index_t num_prefetch_v = 2;

    // cannot prefetch more V tiles than there are kK1 steps in one kN0 tile
    static_assert(num_k1_steps >= num_prefetch_v, "Check failed!");

    return num_prefetch_v;
};
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetNumKVLdsBuffers()
{