Improve both the with_softmax and no_softmax pipelines

This commit is contained in:
Qianfeng Zhang
2025-11-04 15:18:58 +00:00
parent bc22b83b19
commit 99993acca4
3 changed files with 104 additions and 50 deletions

View File

@@ -73,7 +73,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<64>
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<128>
{
using type = ck_tile::sequence<128, 32, 128, 16, 128>;
using type = ck_tile::sequence<128, 64, 128, 16, 128>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};

View File

@@ -448,6 +448,27 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
});
}
__builtin_amdgcn_sched_barrier(0x00000001);
using v_shuffled_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
Policy::template MakeShuffledVRegTileDistribution<Problem>()));
v_shuffled_tile_type v_shuffled_tile;
shuffle_tile(v_shuffled_tile, v_tiles[number<0>{}]);
        // check whether first V-LdsBuffer overlaps with last K-LdsBuffer,
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
{
__builtin_amdgcn_s_barrier();
};
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffled_tile));
__builtin_amdgcn_sched_barrier(0x00000001);
tile_elementwise_inout(f_silu, pcomp_tile);
tile_elementwise_inout([&](auto& x) { x = x * type_convert<CompDataType>(scale_p); },
@@ -466,28 +487,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
using v_shuffled_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
Policy::template MakeShuffledVRegTileDistribution<Problem>()));
statically_indexed_array<v_shuffled_tile_type, k1_loops> v_shuffled_tiles;
static_for<0, k1_loops, 1>{}(
[&](auto i_k1) { shuffle_tile(v_shuffled_tiles[i_k1], v_tiles[i_k1]); });
        // check whether first V-LdsBuffer overlaps with last K-LdsBuffer,
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
{
__builtin_amdgcn_s_barrier();
};
// STAGE 3, Gemm_1 ( O = P@V )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffled_tiles[i_k1]));
__builtin_amdgcn_sched_barrier(0x00000001);
// load k_tiles used by next iteration
k_tiles[i_k1] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
@@ -500,6 +501,17 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
o_acc,
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
if constexpr(i_k1 < k1_loops - 1)
{
__builtin_amdgcn_sched_barrier(0x00000001);
shuffle_tile(v_shuffled_tile, v_tiles[number<i_k1 + 1>{}]);
store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffled_tile));
__builtin_amdgcn_sched_barrier(0x00000001);
};
});
        // check whether last V-LdsBuffer overlaps with first K-LdsBuffer,

View File

@@ -219,9 +219,14 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
using k_tile_type = decltype(load_tile(k_dram_window));
statically_indexed_array<k_tile_type, k1_loops> k_tiles;
constexpr index_t NumPrefetchK = 2;
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
static_assert(k1_loops >= NumPrefetchK, "Check failed!");
// only prefetch two k tiles to save vgprs consumption
statically_indexed_array<k_tile_type, NumPrefetchK> k_tiles;
static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) {
k_tiles[i_k1] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
});
@@ -391,14 +396,23 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
{
// STAGE 1, Gemm_0 ( S = Q@K )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(k_lds_write_windows[i_k1],
tile_elementwise_in(k_element_func, k_tiles[i_k1]));
store_tile(
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_k1 % NumPrefetchK>{}]));
__builtin_amdgcn_sched_barrier(0x00000001);
// load v_tiles used in current iteration
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
if constexpr(i_k1 < k1_loops - NumPrefetchK)
{
k_tiles[number<i_k1 % NumPrefetchK>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
}
else
{
// load v_tiles used in current iteration
v_tiles[number<i_k1 - (k1_loops - NumPrefetchK)>{}] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
};
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -417,8 +431,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
sequence<kM0, (i_k1 + 1) * kK1>{});
});
__builtin_amdgcn_sched_barrier(0x00000001);
// STAGE 2, scale_s, add bias, mask, siLU
if constexpr(kHasBias)
{
@@ -477,6 +489,35 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
});
};
__builtin_amdgcn_sched_barrier(0x00000001);
using v_shuffled_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
Policy::template MakeShuffledVRegTileDistribution<Problem>()));
v_shuffled_tile_type v_shuffled_tile;
shuffle_tile(v_shuffled_tile, v_tiles[number<0>{}]);
        // check whether first V-LdsBuffer overlaps with last K-LdsBuffer,
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
{
__builtin_amdgcn_s_barrier();
};
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffled_tile));
__builtin_amdgcn_sched_barrier(0x00000001);
static_for<NumPrefetchK, k1_loops, 1>{}([&](auto i_k1) {
// load v_tiles used in current iteration
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
});
__builtin_amdgcn_sched_barrier(0x00000001);
auto m_local = block_tile_reduce<CompDataType>(
pcomp_tile, sequence<1>{}, f_max, -numeric<CompDataType>::infinity());
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
@@ -544,31 +585,21 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
using v_shuffled_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
Policy::template MakeShuffledVRegTileDistribution<Problem>()));
shuffle_tile(v_shuffled_tile, v_tiles[number<1>{}]);
statically_indexed_array<v_shuffled_tile_type, k1_loops> v_shuffled_tiles;
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffled_tile));
static_for<0, k1_loops, 1>{}(
[&](auto i_k1) { shuffle_tile(v_shuffled_tiles[i_k1], v_tiles[i_k1]); });
        // check whether first V-LdsBuffer overlaps with last K-LdsBuffer,
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
{
__builtin_amdgcn_s_barrier();
};
__builtin_amdgcn_sched_barrier(0x00000001);
// STAGE 3, Gemm_1 ( O = P@V )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffled_tiles[i_k1]));
__builtin_amdgcn_sched_barrier(0x00000001);
// load k_tiles used by next iteration
k_tiles[i_k1] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
if constexpr(i_k1 < NumPrefetchK)
{
// load k_tiles used by next iteration
k_tiles[i_k1] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
};
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -578,6 +609,17 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
o_acc,
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
if constexpr(i_k1 < k1_loops - 2)
{
__builtin_amdgcn_sched_barrier(0x00000001);
shuffle_tile(v_shuffled_tile, v_tiles[number<i_k1 + 2>{}]);
store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffled_tile));
__builtin_amdgcn_sched_barrier(0x00000001);
};
});
        // check whether last V-LdsBuffer overlaps with first K-LdsBuffer,