Update the code in the pipeline

This commit is contained in:
Qianfeng Zhang
2025-04-13 09:43:58 +00:00
parent 53e567977e
commit 238e78d82e

View File

@@ -167,6 +167,7 @@ struct HstuAttentionFwdPipelineQRKSVS
constexpr auto NumPrefetchV = Policy::template GetNumPrefetchV<Problem>();
static_assert(NumKLdsBuffers >= 2);
static_assert(NumPrefetchV >= 2);
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
@@ -341,8 +342,6 @@ struct HstuAttentionFwdPipelineQRKSVS
sequence<kM0, k0_loops * kK0>{}),
k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
__builtin_amdgcn_sched_barrier(0);
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) {
@@ -350,6 +349,8 @@ struct HstuAttentionFwdPipelineQRKSVS
move_tile_window(v_dram_window, {0, kK1});
});
__builtin_amdgcn_sched_barrier(0);
// STAGE 2, scale_s, add bias, mask, siLU
if constexpr(kHasBias)
{
@@ -445,73 +446,35 @@ struct HstuAttentionFwdPipelineQRKSVS
__builtin_amdgcn_sched_barrier(0);
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{
if constexpr(NumPrefetchV == 1) // NumVLdsBuffers == 2
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 < k1_loops - NumPrefetchV)
v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
block_sync_lds();
gemm_1(
o_acc,
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
v_tiles[I0] = load_tile(v_dram_window);
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);
if constexpr(std::is_same_v<VLayout,
ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffle_tmp));
}
else
{
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_tiles[I0]));
}
move_tile_window(v_dram_window, {0, kK1});
});
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]);
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffle_tmp));
}
else // NumVLdsBuffers == 3 or 2
else
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 < k1_loops - NumPrefetchV)
v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);
if constexpr(std::is_same_v<VLayout,
ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp,
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]);
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffle_tmp));
}
else
{
store_tile(
v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(v_element_func,
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]));
}
if constexpr(i_k1 < k1_loops - NumPrefetchV)
move_tile_window(v_dram_window, {0, kK1});
});
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(v_element_func,
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]));
}
}
if constexpr(i_k1 < k1_loops - NumPrefetchV)
move_tile_window(v_dram_window, {0, kK1});
});
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});