From 238e78d82ecfee26c0a08d6bbd8426f2176a885c Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sun, 13 Apr 2025 09:43:58 +0000 Subject: [PATCH] Update the in pipeline codes --- .../hstu_attention_fwd_pipeline.hpp | 93 ++++++------------- 1 file changed, 28 insertions(+), 65 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 80b637b051..9d3348f730 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -167,6 +167,7 @@ struct HstuAttentionFwdPipelineQRKSVS constexpr auto NumPrefetchV = Policy::template GetNumPrefetchV(); static_assert(NumKLdsBuffers >= 2); + static_assert(NumPrefetchV >= 2); auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), q_dram_block_window_tmp.get_window_lengths(), @@ -341,8 +342,6 @@ struct HstuAttentionFwdPipelineQRKSVS sequence{}), k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]); - __builtin_amdgcn_sched_barrier(0); - const auto bias_tile = load_tile(bias_dram_window); // load bias tile static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) { @@ -350,6 +349,8 @@ struct HstuAttentionFwdPipelineQRKSVS move_tile_window(v_dram_window, {0, kK1}); }); + __builtin_amdgcn_sched_barrier(0); + // STAGE 2, scale_s, add bias, mask, siLU if constexpr(kHasBias) { @@ -445,73 +446,35 @@ struct HstuAttentionFwdPipelineQRKSVS __builtin_amdgcn_sched_barrier(0); - // STAGE 3, KV gemm - if constexpr(k1_loops > 1) - { - if constexpr(NumPrefetchV == 1) // NumVLdsBuffers == 2 + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + if constexpr(i_k1 < k1_loops - NumPrefetchV) + v_tiles[number{}] = load_tile(v_dram_window); + + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_windows[number{}]); + + if constexpr(std::is_same_v) { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - v_tiles[I0] = load_tile(v_dram_window); - - block_sync_lds(); - gemm_1(o_acc, - get_slice_tile( - p, sequence<0, i_k1 * kK1>{}, sequence{}), - v_lds_windows[number{}]); - - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_tiles[I0]); - store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_shuffle_tmp)); - } - else - { - store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_tiles[I0])); - } - - move_tile_window(v_dram_window, {0, kK1}); - }); + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]); + store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_shuffle_tmp)); } - else // NumVLdsBuffers == 3 or 2 + else { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - if constexpr(i_k1 < k1_loops - NumPrefetchV) - v_tiles[number{}] = load_tile(v_dram_window); - - block_sync_lds(); - gemm_1(o_acc, - get_slice_tile( - p, sequence<0, i_k1 * kK1>{}, sequence{}), - v_lds_windows[number{}]); - - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, - v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]); - store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_shuffle_tmp)); - } - else - { - store_tile( - v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, - v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); - } - - if constexpr(i_k1 < k1_loops - NumPrefetchV) - move_tile_window(v_dram_window, {0, kK1}); - }); + store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], + tile_elementwise_in(v_element_func, + v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); } - } + + if constexpr(i_k1 < k1_loops - NumPrefetchV) + move_tile_window(v_dram_window, {0, kK1}); + }); + // move K tile windows move_tile_window(k_dram_block_window, {kN0, 0});