Improve the softmax+trload pipeline by using kN0=64 and prefetch only two k tiles

This commit is contained in:
Qianfeng Zhang
2025-11-05 16:23:05 +00:00
parent d190af2ef5
commit 54cd431f16
2 changed files with 74 additions and 20 deletions

View File

@@ -243,7 +243,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<64>
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<128>
{
using type = ck_tile::sequence<128, 32, 128, 16, 128>;
using type = ck_tile::sequence<128, 64, 128, 16, 128>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};

View File

@@ -219,9 +219,14 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
using k_tile_type = decltype(load_tile(k_dram_window));
statically_indexed_array<k_tile_type, k1_loops> k_tiles;
constexpr index_t NumPrefetchK = 2;
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
static_assert(k1_loops >= NumPrefetchK, "Check failed!");
// only prefetch two k tiles to save vgprs consumption
statically_indexed_array<k_tile_type, NumPrefetchK> k_tiles;
static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) {
k_tiles[i_k1] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
});
@@ -391,14 +396,23 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
{
// STAGE 1, Gemm_0 ( S = Q@K )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[i_k1]));
store_tile(
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_k1 % NumPrefetchK>{}]));
__builtin_amdgcn_sched_barrier(0x00000001);
// load v_tiles used in current iteration
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
if constexpr(i_k1 < k1_loops - NumPrefetchK)
{
k_tiles[number<i_k1 % NumPrefetchK>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
}
else
{
// load v_tiles used in current iteration
v_tiles[number<i_k1 - (k1_loops - NumPrefetchK)>{}] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
};
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -477,6 +491,28 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
});
};
__builtin_amdgcn_sched_barrier(0x00000001);
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
{
__builtin_amdgcn_s_barrier();
};
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_tiles[number<0>{}]));
__builtin_amdgcn_sched_barrier(0x00000001);
static_for<NumPrefetchK, k1_loops, 1>{}([&](auto i_k1) {
// load v_tiles used in current iteration
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
});
__builtin_amdgcn_sched_barrier(0x00000001);
auto m_local = block_tile_reduce<CompDataType>(
pcomp_tile, sequence<1>{}, f_max, -numeric<CompDataType>::infinity());
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
@@ -544,35 +580,53 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
// check whether second V-LdsBufer overlap with last K-LdsBuffer,
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 3 % NumKVLdsBuffers)
{
__builtin_amdgcn_s_barrier();
};
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_tiles[number<1>{}]));
__builtin_amdgcn_sched_barrier(0x00000001);
// STAGE 3, Gemm_1 ( O = P@V )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_tiles[number<i_k1>{}]));
__builtin_amdgcn_sched_barrier(0x00000001);
// load k_tiles used by next iteration
k_tiles[i_k1] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
if constexpr(i_k1 < NumPrefetchK)
{
// load k_tiles used by next iteration
k_tiles[i_k1] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
};
__builtin_amdgcn_sched_barrier(0x00000001);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0x00000001);
gemm_1(
o_acc,
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
if constexpr(i_k1 < k1_loops - 2)
{
__builtin_amdgcn_sched_barrier(0x00000001);
store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_tiles[number<i_k1 + 2>{}]));
__builtin_amdgcn_sched_barrier(0x00000001);
};
});
// check whether last V-LdsBuffer overlap with first K-LdsBuffer,
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((k1_loops - 1 + 2) % NumKVLdsBuffers == 0)
{
__builtin_amdgcn_s_barrier();
};
} while(seqlen_k_curr < seqlen_k_end);
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();