Update to only pre-load one v_tile during Gemm0 loop

This commit is contained in:
Qianfeng Zhang
2025-12-22 08:41:54 +00:00
parent db5c12db89
commit 57cf989f63
2 changed files with 69 additions and 61 deletions

View File

@@ -373,14 +373,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
move_tile_window(k_dram_window, {kN0Sub, 0});
};
if constexpr(i_n0 < NumPrefetchV)
{
v_tiles[i_n0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
};
if constexpr(i_n0 == n0_loops - 1)
{
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
// prefetch all k_tiles for next iteration
static_for<0, n0_loops, 1>{}([&](auto ii_n0) {
k_tiles[number<ii_n0>{}] = load_tile(k_dram_window);
@@ -413,9 +410,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
move_tile_window(k_dram_window, {kN0Sub, 0});
};
if constexpr(i_n0 < NumPrefetchV)
if constexpr(i_n0 == n0_loops - 1)
{
v_tiles[i_n0] = load_tile(v_dram_window);
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
};
@@ -441,9 +438,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
partition_index);
if constexpr(i_n0 < NumPrefetchV)
if constexpr(i_n0 == 0)
{
v_tiles[i_n0] = load_tile(v_dram_window);
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
};
@@ -470,9 +467,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
partition_index);
if constexpr(i_n0 < NumPrefetchV)
if constexpr(i_n0 == 0)
{
v_tiles[i_n0] = load_tile(v_dram_window);
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
};
@@ -507,10 +504,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
if constexpr(i_n0 == n0_loops - 1)
{
static_for<0, NumPrefetchV, 1>{}([&](auto i_k1) {
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
});
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
};
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -593,6 +588,31 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
__builtin_amdgcn_sched_barrier(0);
auto v_shuffled_tile = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegTileDistribution<Problem>());
shuffle_tile(v_shuffled_tile, tile_elementwise_in(v_element_func, v_tiles[I0]));
// check whether the first V-LdsBuffer overlaps with the last K-LdsBuffer;
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
{
__builtin_amdgcn_s_barrier();
};
store_tile(
v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) {
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
});
__builtin_amdgcn_sched_barrier(0);
constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
@@ -653,22 +673,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
__builtin_amdgcn_sched_barrier(0x00000001);
auto v_shuffled_tile = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegTileDistribution<Problem>());
shuffle_tile(v_shuffled_tile, tile_elementwise_in(v_element_func, v_tiles[I0]));
// check whether the first V-LdsBuffer overlaps with the last K-LdsBuffer;
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
{
__builtin_amdgcn_s_barrier();
};
store_tile(
v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
__builtin_amdgcn_sched_barrier(0x00000001);

View File

@@ -377,14 +377,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
move_tile_window(k_dram_window, {kN0Sub, 0});
};
if constexpr(i_n0 < NumPrefetchV)
{
v_tiles[i_n0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
};
if constexpr(i_n0 == n0_loops - 1)
{
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
// prefetch all k_tiles for next iteration
static_for<0, n0_loops, 1>{}([&](auto ii_n0) {
k_tiles[number<ii_n0>{}] = load_tile(k_dram_window);
@@ -417,9 +414,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
move_tile_window(k_dram_window, {kN0Sub, 0});
};
if constexpr(i_n0 < NumPrefetchV)
if constexpr(i_n0 == n0_loops - 1)
{
v_tiles[i_n0] = load_tile(v_dram_window);
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
};
@@ -445,9 +442,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
partition_index);
if constexpr(i_n0 < NumPrefetchV)
if constexpr(i_n0 == 0)
{
v_tiles[i_n0] = load_tile(v_dram_window);
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
};
@@ -474,9 +471,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
partition_index);
if constexpr(i_n0 < NumPrefetchV)
if constexpr(i_n0 == 0)
{
v_tiles[i_n0] = load_tile(v_dram_window);
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
};
@@ -511,10 +508,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
if constexpr(i_n0 == n0_loops - 1)
{
static_for<0, NumPrefetchV, 1>{}([&](auto i_k1) {
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
});
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
};
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -597,6 +592,28 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
__builtin_amdgcn_sched_barrier(0);
// check whether the first V-LdsBuffer overlaps with the last K-LdsBuffer;
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
{
__builtin_amdgcn_s_barrier();
};
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_tiles[I0]),
partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) {
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
});
__builtin_amdgcn_sched_barrier(0);
constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
@@ -657,19 +674,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
__builtin_amdgcn_sched_barrier(0x00000001);
// check whether the first V-LdsBuffer overlaps with the last K-LdsBuffer;
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
{
__builtin_amdgcn_s_barrier();
};
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_tiles[I0]),
partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
__builtin_amdgcn_sched_barrier(0x00000001);