From ddf0f1c8edb4d79028a80fd7fdf8ffbfe2d8527c Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Thu, 25 Dec 2025 14:30:57 +0000 Subject: [PATCH] Update the NumPrefetchK and NumPrefetchV in the softmax pipeline on mi300 to achieve better interleaving --- ...tu_attention_with_softmax_fwd_pipeline.hpp | 49 +++++++++---------- 1 file changed, 23 insertions(+), 26 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index f23bffe772..c7d2d76f5e 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -160,9 +160,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS constexpr index_t n0_loops = kN0 / kN0Sub; constexpr index_t k1_loops = kN0 / kK1; - static_assert(n0_loops >= k1_loops, "n0_loops >= k1_loops required by this pipeline"); - static_assert(k1_loops >= 2, - "k1_loops >= 2 required due to pre-storing two v_tiles to Lds"); + static_assert(n0_loops >= 2, "n0_loops >= 2 required by this pipeline"); + static_assert(k1_loops >= 2, "k1_loops >= 2 required by this pipeline"); constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); @@ -318,9 +317,13 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS auto seqlen_k_curr = seqlen_k_start; + constexpr index_t NumPrefetchV = 2; + + static_assert(NumPrefetchV >= NumPrefetchK); + using v_tile_type = decltype(load_tile(v_dram_window)); - statically_indexed_array v_tiles; + statically_indexed_array v_tiles; do { @@ -339,7 +342,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS } else { - if constexpr(i_n0 - (n0_loops - NumPrefetchK) < k1_loops) + // We assume NumPrefetchV >= NumPrefetchK + if constexpr(i_n0 - (n0_loops - NumPrefetchK) < NumPrefetchK) { // load v_tiles used in current iteration v_tiles[number{}] = @@ -433,7 +437,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS shuffle_tile(v_shuffled_tile, v_tiles[number<0>{}]); // check whether first V-LdsBufer overlap with last K-LdsBuffer, - // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + // this does not occur when n0_loops == 2/4 and NumKVLdsBuffers == 4 if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) { __builtin_amdgcn_s_barrier(); @@ -444,7 +448,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS __builtin_amdgcn_sched_barrier(0x00000001); - static_for{}([&](auto i_k1) { + static_for{}([&](auto i_k1) { // load v_tiles used in current iteration v_tiles[i_k1] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); @@ -519,27 +523,20 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); - // k1_loops >= 2 required - shuffle_tile(v_shuffled_tile, v_tiles[number<1>{}]); - - // check whether second V-LdsBufer overlap with last K-LdsBuffer, - // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 - if constexpr((n0_loops - 1) % NumKVLdsBuffers == 3 % NumKVLdsBuffers) - { - __builtin_amdgcn_s_barrier(); - }; - - store_tile( - v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index); - __builtin_amdgcn_sched_barrier(0x00000001); // STAGE 3, Gemm_1 ( O = P@V ) static_for<0, k1_loops, 1>{}([&](auto i_k1) { - if constexpr(i_k1 < NumPrefetchK) + if constexpr(i_k1 < k1_loops - NumPrefetchV) { - // load k_tiles used by next iteration - k_tiles[i_k1] = load_tile(k_dram_window); + v_tiles[number{}] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }; + + if constexpr((i_k1 >= k1_loops - NumPrefetchV) && + (i_k1 - (k1_loops - NumPrefetchV) < NumPrefetchK)) + { + k_tiles[number{}] = load_tile(k_dram_window); move_tile_window(k_dram_window, {kN0Sub, 0}); }; @@ -552,12 +549,12 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}), v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]); - if constexpr(i_k1 < k1_loops - 2) + if constexpr(i_k1 < k1_loops - 1) { __builtin_amdgcn_sched_barrier(0x00000001); - shuffle_tile(v_shuffled_tile, v_tiles[number{}]); - store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}], + shuffle_tile(v_shuffled_tile, v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]); + store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index);