From 54cd431f16d49a00d8348d38d9f99e0c38c9d6e0 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Wed, 5 Nov 2025 16:23:05 +0000 Subject: [PATCH] Improve the softmax+trload pipeline by using kN0=64 and prefetching only two k tiles --- .../hstu_attention_fwd_setting.hpp | 2 +- ...ntion_with_softmax_fwd_trload_pipeline.hpp | 92 +++++++++++++++---- 2 files changed, 74 insertions(+), 20 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp index 9c9ac286b4..096bb69a8c 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp @@ -243,7 +243,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<64> template <> struct HstuAttentionWithSoftmaxFwdBlockTile<128> { - using type = ck_tile::sequence<128, 32, 128, 16, 128>; + using type = ck_tile::sequence<128, 64, 128, 16, 128>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index 7f0cc215c3..64da3fe9d2 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -219,9 +219,14 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad using k_tile_type = decltype(load_tile(k_dram_window)); - statically_indexed_array k_tiles; + constexpr index_t NumPrefetchK = 2; - static_for<0, k1_loops, 1>{}([&](auto i_k1) { + static_assert(k1_loops >= NumPrefetchK, "Check failed!"); + + // prefetch only two k tiles to reduce VGPR consumption + statically_indexed_array k_tiles; + + static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) { k_tiles[i_k1] = load_tile(k_dram_window); 
move_tile_window(k_dram_window, {kK1, 0}); }); @@ -391,14 +396,23 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad { // STAGE 1, Gemm_0 ( S = Q@K ) static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile(k_lds_write_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[i_k1])); + store_tile( + k_lds_write_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}])); __builtin_amdgcn_sched_barrier(0x00000001); - // load v_tiles used in current iteration - v_tiles[i_k1] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {kK1, 0}); + if constexpr(i_k1 < k1_loops - NumPrefetchK) + { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + } + else + { + // load v_tiles used in current iteration + v_tiles[number{}] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }; __builtin_amdgcn_sched_barrier(0x00000001); @@ -477,6 +491,28 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad }); }; + __builtin_amdgcn_sched_barrier(0x00000001); + + // check whether first V-LdsBuffer overlaps with last K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) + { + __builtin_amdgcn_s_barrier(); + }; + + store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_tiles[number<0>{}])); + + __builtin_amdgcn_sched_barrier(0x00000001); + + static_for{}([&](auto i_k1) { + // load v_tiles used in current iteration + v_tiles[i_k1] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }); + + __builtin_amdgcn_sched_barrier(0x00000001); + auto m_local = block_tile_reduce( pcomp_tile, sequence<1>{}, f_max, -numeric::infinity()); block_tile_reduce_sync(m_local, f_max, bool_constant{}); @@ -544,35 +580,53 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad auto p = cast_tile(tile_elementwise_in(p_compute_element_func, 
pcomp_tile)); - // check whether first V-LdsBufer overlap with last K-LdsBuffer, + // check whether second V-LdsBuffer overlaps with last K-LdsBuffer, // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 - if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) + if constexpr((k1_loops - 1) % NumKVLdsBuffers == 3 % NumKVLdsBuffers) { __builtin_amdgcn_s_barrier(); }; + store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_tiles[number<1>{}])); + + __builtin_amdgcn_sched_barrier(0x00000001); + // STAGE 3, Gemm_1 ( O = P@V ) static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_tiles[number{}])); - - __builtin_amdgcn_sched_barrier(0x00000001); - - // load k_tiles used by next iteration - k_tiles[i_k1] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + if constexpr(i_k1 < NumPrefetchK) + { + // load k_tiles used by next iteration + k_tiles[i_k1] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + }; __builtin_amdgcn_sched_barrier(0x00000001); block_sync_lds(); - __builtin_amdgcn_sched_barrier(0x00000001); - gemm_1( o_acc, get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}), v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]); + + if constexpr(i_k1 < k1_loops - 2) + { + __builtin_amdgcn_sched_barrier(0x00000001); + + store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_tiles[number{}])); + + __builtin_amdgcn_sched_barrier(0x00000001); + }; }); + + // check whether last V-LdsBuffer overlaps with first K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((k1_loops - 1 + 2) % NumKVLdsBuffers == 0) + { + __builtin_amdgcn_s_barrier(); + }; } while(seqlen_k_curr < seqlen_k_end); constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();