Update NumPrefetchK and NumPrefetchV in the softmax pipeline on MI300 to achieve better interleaving

This commit is contained in:
Qianfeng Zhang
2025-12-25 14:30:57 +00:00
parent 02cae85af5
commit ddf0f1c8ed

View File

@@ -160,9 +160,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
constexpr index_t n0_loops = kN0 / kN0Sub;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(n0_loops >= k1_loops, "n0_loops >= k1_loops required by this pipeline");
static_assert(k1_loops >= 2,
"k1_loops >= 2 required due to pre-storing two v_tiles to Lds");
static_assert(n0_loops >= 2, "n0_loops >= 2 required by this pipeline");
static_assert(k1_loops >= 2, "k1_loops >= 2 required by this pipeline");
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
@@ -318,9 +317,13 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
auto seqlen_k_curr = seqlen_k_start;
constexpr index_t NumPrefetchV = 2;
static_assert(NumPrefetchV >= NumPrefetchK);
using v_tile_type = decltype(load_tile(v_dram_window));
statically_indexed_array<v_tile_type, k1_loops> v_tiles;
statically_indexed_array<v_tile_type, NumPrefetchV> v_tiles;
do
{
@@ -339,7 +342,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
}
else
{
if constexpr(i_n0 - (n0_loops - NumPrefetchK) < k1_loops)
// We assume NumPrefetchV >= NumPrefetchK
if constexpr(i_n0 - (n0_loops - NumPrefetchK) < NumPrefetchK)
{
// load v_tiles used in current iteration
v_tiles[number<i_n0 - (n0_loops - NumPrefetchK)>{}] =
@@ -433,7 +437,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
shuffle_tile(v_shuffled_tile, v_tiles[number<0>{}]);
// check whether the first V-LdsBuffer overlaps with the last K-LdsBuffer,
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
// this does not occur when n0_loops == 2/4 and NumKVLdsBuffers == 4
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
{
__builtin_amdgcn_s_barrier();
@@ -444,7 +448,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
__builtin_amdgcn_sched_barrier(0x00000001);
static_for<NumPrefetchK, k1_loops, 1>{}([&](auto i_k1) {
static_for<NumPrefetchK, NumPrefetchV, 1>{}([&](auto i_k1) {
// load v_tiles used in current iteration
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
@@ -519,27 +523,20 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
// k1_loops >= 2 required
shuffle_tile(v_shuffled_tile, v_tiles[number<1>{}]);
// check whether the second V-LdsBuffer overlaps with the last K-LdsBuffer,
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 3 % NumKVLdsBuffers)
{
__builtin_amdgcn_s_barrier();
};
store_tile(
v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
// STAGE 3, Gemm_1 ( O = P@V )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
if constexpr(i_k1 < NumPrefetchK)
if constexpr(i_k1 < k1_loops - NumPrefetchV)
{
// load k_tiles used by next iteration
k_tiles[i_k1] = load_tile(k_dram_window);
v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
};
if constexpr((i_k1 >= k1_loops - NumPrefetchV) &&
(i_k1 - (k1_loops - NumPrefetchV) < NumPrefetchK))
{
k_tiles[number<i_k1 - (k1_loops - NumPrefetchV)>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kN0Sub, 0});
};
@@ -552,12 +549,12 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
if constexpr(i_k1 < k1_loops - 2)
if constexpr(i_k1 < k1_loops - 1)
{
__builtin_amdgcn_sched_barrier(0x00000001);
shuffle_tile(v_shuffled_tile, v_tiles[number<i_k1 + 2>{}]);
store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}],
shuffle_tile(v_shuffled_tile, v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]);
store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
v_shuffled_tile,
partition_index);