Add whole-next-iteration K prefetch path to the pipeline

Qianfeng Zhang
2025-12-04 10:25:16 +00:00
parent 5fada1ce99
commit 98f9b4a47b
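At a glance, the change keeps the existing ring of NumPrefetchK K tiles as one path and adds a second, policy-selected path that keeps all k1_loops K tiles of the next iteration resident in vgprs. A condensed, illustrative sketch of the two buffering strategies — plain C++ with std::array standing in for the tile types, and the constants chosen only as examples:

    #include <array>
    #include <cstdio>

    constexpr int k1_loops     = 4; // example value, not the policy's
    constexpr int NumPrefetchK = 2; // example value, not the policy's

    template <bool kPreloadWholeNextIterationK>
    void k_buffering()
    {
        // compile-time choice of buffer size, mirroring the k_tiles lambda in the diff
        auto k_tiles = []() {
            if constexpr(kPreloadWholeNextIterationK)
                return std::array<int, k1_loops>{};     // whole next-iteration K
            else
                return std::array<int, NumPrefetchK>{}; // small prefetch ring
        }();
        std::printf("buffer holds %zu k tiles\n", k_tiles.size());
    }

    int main()
    {
        k_buffering<false>(); // ring of NumPrefetchK tiles, refilled mid-iteration
        k_buffering<true>();  // whole next-iteration K resident in vgprs
    }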


@@ -179,8 +179,14 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
constexpr index_t k1_loops = kN0 / kK1;
static_assert(k1_loops >= 2,
              "k1_loops >= 2 required due to pre-storing two v_tiles to Lds");
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
constexpr bool kPreloadWholeNextIterationK =
    Policy::template IsPreloadWholeNextIterationK<Problem>();

// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
@@ -233,12 +239,27 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
static_assert(k1_loops >= NumPrefetchK, "Check failed!");
// when not preloading the whole next-iteration K, only prefetch
// NumPrefetchK k tiles to save vgpr consumption
auto k_tiles = [&]() {
    if constexpr(kPreloadWholeNextIterationK)
        return statically_indexed_array<k_tile_type, k1_loops>{};
    else
        return statically_indexed_array<k_tile_type, NumPrefetchK>{};
}();

if constexpr(kPreloadWholeNextIterationK)
{
    static_for<0, k1_loops, 1>{}([&](auto i_k1) {
        k_tiles[i_k1] = load_tile(k_dram_window);
        move_tile_window(k_dram_window, {kK1, 0});
    });
}
else
{
    static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) {
        k_tiles[i_k1] = load_tile(k_dram_window);
        move_tile_window(k_dram_window, {kK1, 0});
    });
}
__builtin_amdgcn_sched_barrier(0);
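The immediately-invoked lambda above is needed because the two branches yield different types (arrays of k1_loops versus NumPrefetchK tiles), which a plain ternary cannot express; if constexpr discards the untaken branch, so each instantiation deduces a single return type. A minimal standalone sketch of the idiom, using std::array as a stand-in:

    #include <array>

    // the two if constexpr branches return *different types*; only one
    // survives per instantiation, so auto deduction stays unambiguous
    template <bool Big>
    constexpr auto make_buffer()
    {
        auto buf = []() {
            if constexpr(Big)
                return std::array<int, 8>{};
            else
                return std::array<int, 2>{};
        }();
        return buf; // std::array<int, 8> or std::array<int, 2>
    }

    static_assert(make_buffer<true>().size() == 8);
    static_assert(make_buffer<false>().size() == 2);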
@@ -314,7 +335,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
{bias_origin.at(number<0>{}), seqlen_k_start},
Policy::template MakeBiasDramTileDistribution<Problem>());
// assuming no random values need be saved; this is true when the pipeline is called from
// xformers, since we have a separate kernel to generate random values
auto null_randval_window = [&]() {
if constexpr(kHasDropout)
@@ -354,42 +375,84 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
do
{
// STAGE 1, Gemm_0 ( S = Q@K )
if constexpr(kPreloadWholeNextIterationK)
{
    static_for<0, k1_loops, 1>{}([&](auto i_k1) {
        store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
                   k_tiles[i_k1]);
        __builtin_amdgcn_sched_barrier(0x00000001);
        // the whole K block was already preloaded, so every unroll only
        // loads a v_tile used in the current iteration
        v_tiles[i_k1] = load_tile(v_dram_window);
        move_tile_window(v_dram_window, {0, kK1});
        __builtin_amdgcn_sched_barrier(0x00000001);
        block_sync_lds();
        // execute current unroll of gemm_0
        gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
        sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
        auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
        set_slice_tile(pcomp_tile,
                       tmp_tile,
                       sequence<0, i_k1 * kK1>{},
                       sequence<kM0, (i_k1 + 1) * kK1>{});
    });
}
else
{
    static_for<0, k1_loops, 1>{}([&](auto i_k1) {
        store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
                   k_tiles[number<i_k1 % NumPrefetchK>{}]);
        __builtin_amdgcn_sched_barrier(0);
        __builtin_amdgcn_sched_barrier(0x00000001);
        if constexpr(i_k1 < k1_loops - NumPrefetchK)
        {
            k_tiles[number<i_k1 % NumPrefetchK>{}] = load_tile(k_dram_window);
            move_tile_window(k_dram_window, {kK1, 0});
        }
        else
        {
            // load v_tiles used in current iteration
            v_tiles[number<i_k1 - (k1_loops - NumPrefetchK)>{}] =
                load_tile(v_dram_window);
            move_tile_window(v_dram_window, {0, kK1});
        }
        __builtin_amdgcn_sched_barrier(0x00000001);
        block_sync_lds();
        // execute current unroll of gemm_0
        gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
        sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
        auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
        set_slice_tile(pcomp_tile,
                       tmp_tile,
                       sequence<0, i_k1 * kK1>{},
                       sequence<kM0, (i_k1 + 1) * kK1>{});
    });
}
__builtin_amdgcn_sched_barrier(0x00000001);

if constexpr(!kPreloadWholeNextIterationK)
{
    static_for<NumPrefetchK, k1_loops, 1>{}([&](auto i_k1) {
        // load v_tiles used in current iteration
        v_tiles[i_k1] = load_tile(v_dram_window);
        move_tile_window(v_dram_window, {0, kK1});
    });
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
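To make the two STAGE 1 schedules easier to compare: on the preload path every unroll pairs the gemm_0 step with a V-tile load (next-iteration K arrives later, during gemm_1), while the ring path refills a K slot on early unrolls and only switches to V loads on the last NumPrefetchK unrolls. A toy trace of both schedules — printf stands in for the real loads and GEMMs, and the constants are example values, not the policy's:

    #include <cstdio>

    constexpr int k1_loops     = 4; // example value
    constexpr int NumPrefetchK = 2; // example value

    template <bool kPreloadWholeNextIterationK>
    void stage1_schedule()
    {
        for(int i_k1 = 0; i_k1 < k1_loops; ++i_k1)
        {
            std::printf("unroll %d: store k tile to LDS\n", i_k1);
            if constexpr(kPreloadWholeNextIterationK)
            {
                // whole next K comes later in gemm_1, so each unroll loads a V tile
                std::printf("unroll %d: load v tile %d\n", i_k1, i_k1);
            }
            else
            {
                if(i_k1 < k1_loops - NumPrefetchK)
                    std::printf("unroll %d: refill k ring slot %d\n",
                                i_k1, i_k1 % NumPrefetchK);
                else
                    std::printf("unroll %d: load v tile %d\n",
                                i_k1, i_k1 - (k1_loops - NumPrefetchK));
            }
            std::printf("unroll %d: gemm_0\n", i_k1);
        }
    }

    int main()
    {
        stage1_schedule<true>();  // whole-K-preload path
        stage1_schedule<false>(); // ring-prefetch path
    }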
@@ -445,32 +508,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
__builtin_amdgcn_sched_barrier(0x00000001);
__builtin_amdgcn_sched_barrier(0x00000001);
auto m_local = block_tile_reduce<CompDataType>(
pcomp_tile, sequence<1>{}, f_max, -numeric<CompDataType>::infinity());
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
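For context, this reduction computes the per-row running maximum over the S block that online softmax needs, folding with f_max from a -infinity identity. A scalar sketch of the same reduction (hypothetical helper, not the tile API):

    #include <algorithm>
    #include <limits>
    #include <vector>

    // per-row max of the S = Q@K^T block, the m_local of online softmax
    std::vector<float> row_max(const std::vector<std::vector<float>>& s_block)
    {
        std::vector<float> m;
        m.reserve(s_block.size());
        for(const auto& row : s_block)
        {
            float v = -std::numeric_limits<float>::infinity();
            for(float x : row)
                v = std::max(v, x); // f_max with -inf as the identity element
            m.push_back(v);
        }
        return m;
    }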
@@ -538,44 +575,90 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
seqlen_k_curr += kN0;
__builtin_amdgcn_sched_barrier(0x00000001);
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
using v_shuffled_tile_type = decltype(make_static_distributed_tensor<VDataType>(
    Policy::template MakeShuffledVRegTileDistribution<Problem>()));
v_shuffled_tile_type v_shuffled_tile;
shuffle_tile(v_shuffled_tile, v_tiles[number<0>{}]);
// check whether the first V-LdsBuffer overlaps with the last K-LdsBuffer;
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
{
    __builtin_amdgcn_s_barrier();
}
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile);
__builtin_amdgcn_sched_barrier(0x00000001);
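The % NumKVLdsBuffers indexing rotates K and V tiles through a small pool of LDS buffers, and the guard above inserts an extra s_barrier only for schedules where the first V store would land on the LDS slot the last K read still occupies. A toy model of the slot rotation, with example buffer counts rather than the policy's values:

    #include <cstdio>

    constexpr int NumKVLdsBuffers = 4; // example value
    constexpr int k1_loops        = 4; // example value

    int main()
    {
        // K unroll i reads/writes LDS slot i % NumKVLdsBuffers
        for(int i = 0; i < k1_loops; ++i)
            std::printf("k tile %d -> lds slot %d\n", i, i % NumKVLdsBuffers);
        // the first shuffled V tile goes to slot 2 % NumKVLdsBuffers; it can
        // collide with the slot of the last K unroll exactly when
        // (k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers,
        // which is the condition guarding the extra s_barrier above
        std::printf("first v tile -> lds slot %d (last k slot: %d)\n",
                    2 % NumKVLdsBuffers, (k1_loops - 1) % NumKVLdsBuffers);
    }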
// STAGE 3, Gemm_1 ( O = P@V )
if constexpr(kPreloadWholeNextIterationK)
{
    static_for<0, k1_loops, 1>{}([&](auto i_k1) {
        // load k_tiles used by next iteration: the whole next K block is
        // prefetched on this path, one tile per unroll
        k_tiles[i_k1] = load_tile(k_dram_window);
        move_tile_window(k_dram_window, {kK1, 0});
        __builtin_amdgcn_sched_barrier(0x00000001);
        block_sync_lds();
        gemm_1(o_acc,
               get_slice_tile(
                   p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
               v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
        if constexpr(i_k1 < k1_loops - 1)
        {
            __builtin_amdgcn_sched_barrier(0x00000001);
            shuffle_tile(v_shuffled_tile, v_tiles[number<i_k1 + 1>{}]);
            store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
                       v_shuffled_tile);
            __builtin_amdgcn_sched_barrier(0x00000001);
        }
    });
}
else
{
    static_for<0, k1_loops, 1>{}([&](auto i_k1) {
        if constexpr(i_k1 < NumPrefetchK)
        {
            // load k_tiles used by next iteration
            k_tiles[i_k1] = load_tile(k_dram_window);
            move_tile_window(k_dram_window, {kK1, 0});
        }
        __builtin_amdgcn_sched_barrier(0x00000001);
        block_sync_lds();
        gemm_1(o_acc,
               get_slice_tile(
                   p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
               v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
        if constexpr(i_k1 < k1_loops - 1)
        {
            __builtin_amdgcn_sched_barrier(0x00000001);
            shuffle_tile(v_shuffled_tile, v_tiles[number<i_k1 + 1>{}]);
            store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
                       v_shuffled_tile);
            __builtin_amdgcn_sched_barrier(0x00000001);
        }
    });
}
// check whether the last V-LdsBuffer overlaps with the first K-LdsBuffer;
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4