mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Update the pipeline code
This commit is contained in:
@@ -167,6 +167,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
constexpr auto NumPrefetchV = Policy::template GetNumPrefetchV<Problem>();
|
||||
|
||||
static_assert(NumKLdsBuffers >= 2);
|
||||
static_assert(NumPrefetchV >= 2);
|
||||
|
||||
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
@@ -341,8 +342,6 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
|
||||
static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) {
|
||||
@@ -350,6 +349,8 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, SiLU
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
@@ -445,73 +446,35 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
if constexpr(NumPrefetchV == 1) // NumVLdsBuffers == 2
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
if constexpr(i_k1 < k1_loops - NumPrefetchV)
|
||||
v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
|
||||
|
||||
block_sync_lds();
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);
|
||||
|
||||
if constexpr(std::is_same_v<VLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
|
||||
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp));
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_tiles[I0]));
|
||||
}
|
||||
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]);
|
||||
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp));
|
||||
}
|
||||
else // NumVLdsBuffers == 3 or 2
|
||||
else
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
if constexpr(i_k1 < k1_loops - NumPrefetchV)
|
||||
v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
|
||||
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);
|
||||
|
||||
if constexpr(std::is_same_v<VLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp,
|
||||
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]);
|
||||
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp));
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(
|
||||
v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func,
|
||||
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]));
|
||||
}
|
||||
|
||||
if constexpr(i_k1 < k1_loops - NumPrefetchV)
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func,
|
||||
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(i_k1 < k1_loops - NumPrefetchV)
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
|
||||
// move K tile windows
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user