Switch the code paths based on the iteration index (first/intermediate/last)

Qianfeng Zhang
2025-12-05 15:58:33 +00:00
parent c32949b285
commit 25521a7e06

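In outline, the change replaces the single generic STAGE 1 loop body with bodies specialized for where the current K-block sits in [seqlen_k_start, seqlen_k_end), so next-iteration prefetches are only issued when a next iteration actually exists. A standalone sketch of the dispatch, using the seqlen_k_* / kN0 names from the diff; the *_iteration() functions are hypothetical stand-ins for the inlined loop bodies:

#include <cstdio>

// hypothetical stand-ins for the inlined STAGE 1 loop bodies
void first_iteration()          { std::puts("compute + preload whole next-iteration K"); }
void first_and_last_iteration() { std::puts("compute only; nothing left to prefetch"); }
void intermediate_iteration()   { std::puts("compute + preload whole next-iteration K"); }
void last_iteration()           { std::puts("compute only; nothing left to prefetch"); }

void dispatch(int seqlen_k_curr, int seqlen_k_start, int seqlen_k_end, int kN0)
{
    if(seqlen_k_curr == seqlen_k_start) // first iteration
    {
        if(seqlen_k_curr < seqlen_k_end - kN0) // not also the last one
            first_iteration();
        else
            first_and_last_iteration();
    }
    else // intermediate or last iteration
    {
        if(seqlen_k_curr < seqlen_k_end - kN0)
            intermediate_iteration();
        else
            last_iteration();
    }
}

int main() // walk three K-blocks of width kN0 = 64 over seqlen_k = [0, 192)
{
    for(int cur = 0; cur < 192; cur += 64)
        dispatch(cur, 0, 192, 64);
}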

@@ -154,9 +154,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
void* smem_ptr,
DropoutType& dropout) const
{
ignore = q_element_func;
ignore = k_element_func;
// xformers path does not require the pipeline to output random values for host
// verification, since a separate kernel is used to generate random values
ignore = randval_dram_block_window_tmp;
@@ -177,6 +174,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
constexpr index_t k1_loops = kN0 / kK1;
static_assert(k1_loops >= 2,
@@ -184,6 +184,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
constexpr index_t NumPrefetchV = 2;
static_assert(k1_loops >= NumPrefetchV, "Check failed!");
constexpr bool kPreloadWholeNextIterationK =
Policy::template IsPreloadWholeNextIterationK<Problem>();
@@ -218,10 +221,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQRegTileDistribution<Problem>());
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
@@ -234,34 +233,22 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
using k_tile_type = decltype(load_tile(k_dram_window));
auto k_tiles = [&]() {
if constexpr(kPreloadWholeNextIterationK)
return statically_indexed_array<k_tile_type, k1_loops>{};
else
return statically_indexed_array<k_tile_type, 1>{};
}();
k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
__builtin_amdgcn_sched_barrier(0x00000001);
auto q_tile = load_tile(q_dram_window);
__builtin_amdgcn_sched_barrier(0x00000001);
// K tile in LDS
KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr);
@@ -377,51 +364,167 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
// STAGE 1, Gemm_0 ( S = Q@K )
if constexpr(kPreloadWholeNextIterationK)
{
if(seqlen_k_curr == seqlen_k_start) // at first iteration
{
if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration
{
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]));
__builtin_amdgcn_sched_barrier(0x00000001);
if constexpr(i_k1 < k1_loops - 1)
{
k_tiles[number<i_k1 + 1>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
}
else
{
// load v_tiles used in current iteration
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
// prefetch all k_tiles for next iteration
static_for<0, k1_loops, 1>{}([&](auto ii_k1) {
k_tiles[number<ii_k1>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
});
}
__builtin_amdgcn_sched_barrier(0x00000001);
block_sync_lds();
gemm_0(sacc_tile,
q_tile,
k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
set_slice_tile(pcomp_tile,
tmp_tile,
sequence<0, i_k1 * kK1>{},
sequence<kM0, (i_k1 + 1) * kK1>{});
});
}
else // the first iteration is also the last one
{
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]));
if constexpr(i_k1 < k1_loops - 1)
{
k_tiles[number<i_k1 + 1>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
}
else
{
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
}
block_sync_lds();
gemm_0(sacc_tile,
q_tile,
k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
set_slice_tile(pcomp_tile,
tmp_tile,
sequence<0, i_k1 * kK1>{},
sequence<kM0, (i_k1 + 1) * kK1>{});
});
};
}
else // an intermediate or the last iteration
{
if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration
{
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]));
if constexpr(i_k1 == 0)
{
// prefetch first v_tile
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
};
// prefetch first two k_tiles for next iteration
if constexpr(i_k1 == 1)
{
k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
k_tiles[I1] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
};
// prefetch the remaining k_tiles for next iteration
if constexpr(i_k1 >= 2)
{
k_tiles[number<i_k1>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
};
block_sync_lds();
gemm_0(sacc_tile,
q_tile,
k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
set_slice_tile(pcomp_tile,
tmp_tile,
sequence<0, i_k1 * kK1>{},
sequence<kM0, (i_k1 + 1) * kK1>{});
});
}
else // last iteration
{
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]));
if constexpr(i_k1 == 0)
{
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
};
block_sync_lds();
gemm_0(sacc_tile,
q_tile,
k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
set_slice_tile(pcomp_tile,
tmp_tile,
sequence<0, i_k1 * kK1>{},
sequence<kM0, (i_k1 + 1) * kK1>{});
});
};
}
}
else // only preload one unroll of K for next iteration
{
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[I0]));
__builtin_amdgcn_sched_barrier(0x00000001);
if constexpr(i_k1 < k1_loops - 1)
{
k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
}
else
{
// load v_tiles used in current iteration
v_tiles[number<i_k1 - (k1_loops - 1)>{}] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
};
@@ -429,13 +532,10 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
set_slice_tile(pcomp_tile,
tmp_tile,
sequence<0, i_k1 * kK1>{},
@@ -445,14 +545,10 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
__builtin_amdgcn_sched_barrier(0x00000001);
static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) {
v_tiles[i_buf] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
});
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
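The V prefetch above cycles through NumPrefetchV = 2 register buffers with modulo indexing while static_for unrolls the k1 loop at compile time. A standalone host-side sketch of that rotating-buffer pattern; static_for_ and load_slice are simplified stand-ins for the ck_tile utilities, not the library API:

#include <array>
#include <cstdio>
#include <utility>

// simplified stand-in for ck_tile::static_for: calls f with an
// integral_constant for each index, so `if constexpr` works on it
template <int N, class F>
void static_for_(F&& f)
{
    [&]<int... Is>(std::integer_sequence<int, Is...>) {
        (f(std::integral_constant<int, Is>{}), ...);
    }(std::make_integer_sequence<int, N>{});
}

int main()
{
    constexpr int k1_loops     = 4;
    constexpr int NumPrefetchV = 2; // double buffering, as in the diff
    std::array<int, NumPrefetchV> v_tiles{};
    auto load_slice = [](int i) { return i; }; // hypothetical DRAM load

    v_tiles[0] = load_slice(0); // prefetched during STAGE 1
    v_tiles[1] = load_slice(1); // prefetched by the static_for<1, NumPrefetchV> loop

    static_for_<k1_loops>([&](auto ic) {
        constexpr int i_k1 = decltype(ic)::value;
        // consume slice i_k1 (in the kernel: the LDS store + gemm_1 read)
        std::printf("consume slice %d from buffer %d\n",
                    v_tiles[i_k1 % NumPrefetchV], i_k1 % NumPrefetchV);
        // refill the just-freed buffer with the slice needed two iterations later
        if constexpr(i_k1 < k1_loops - NumPrefetchV)
            v_tiles[i_k1 % NumPrefetchV] = load_slice(i_k1 + NumPrefetchV);
    });
}

In the kernel the refill is issued before the gemm so the DRAM access overlaps the math; the sketch reorders it only to keep the buffer contents obvious.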
@@ -577,16 +673,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
__builtin_amdgcn_sched_barrier(0x00000001);
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
__builtin_amdgcn_sched_barrier(0x00000001);
auto v_shuffled_tile = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegTileDistribution<Problem>());
shuffle_tile(v_shuffled_tile, tile_elementwise_in(v_element_func, v_tiles[I0]));
// check whether the first V-LdsBuffer overlaps with the last K-LdsBuffer;
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
@@ -599,66 +688,44 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
__builtin_amdgcn_sched_barrier(0x00000001);
if constexpr(!kPreloadWholeNextIterationK)
{
// load k_tiles used by next iteration
if(seqlen_k_curr < seqlen_k_end)
{
k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
}
}
__builtin_amdgcn_sched_barrier(0x00000001);
block_sync_lds();
// STAGE 3, Gemm_1 ( O = P@V )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
if constexpr(i_k1 < k1_loops - NumPrefetchV)
{
v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
};
block_sync_lds();
gemm_1(
o_acc,
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
if constexpr(i_k1 < k1_loops - 1)
{
shuffle_tile(v_shuffled_tile,
tile_elementwise_in(v_element_func,
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]));
store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
v_shuffled_tile);
};
});
// check whether the last V-LdsBuffer overlaps with the first K-LdsBuffer;
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
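Both overlap comments can be sanity-checked with the index arithmetic the diff shows: K slice i_k1 occupies LDS buffer i_k1 % NumKVLdsBuffers, gemm_1 reads V from buffer (i_k1 + 2) % NumKVLdsBuffers, and the next V tile is stored into (i_k1 + 3) % NumKVLdsBuffers. A standalone check for the values named in the comments; the buffer-role assignments are read off the diff, not the full policy:

#include <cstdio>

int main()
{
    constexpr int k1_loops        = 2;
    constexpr int NumKVLdsBuffers = 4;
    // last K buffer written in STAGE 1 vs. first V buffer read in STAGE 3
    constexpr int last_k  = (k1_loops - 1) % NumKVLdsBuffers;     // == 1
    constexpr int first_v = (0 + 2) % NumKVLdsBuffers;            // == 2
    // last V buffer written in STAGE 3 (store at i_k1 == k1_loops - 2)
    // vs. first K buffer of the next iteration
    constexpr int last_v  = (k1_loops - 2 + 3) % NumKVLdsBuffers; // == 3
    constexpr int first_k = 0 % NumKVLdsBuffers;                  // == 0
    static_assert(last_k != first_v && last_v != first_k,
                  "V/K LDS buffers must not alias for k1_loops == 2, buffers == 4");
    std::printf("K last=%d, V first=%d; V last=%d, K first=%d -> no overlap\n",
                last_k, first_v, last_v, first_k);
}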