From 2d53d67b6d650b87938a79c297b21d27b22387f8 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Thu, 25 Dec 2025 14:32:44 +0000 Subject: [PATCH] Update the NumPrefetchK and NumPrefetchV in the softmax pipeline on mi350 to achieve better interleaving --- ...ntion_with_softmax_fwd_trload_pipeline.hpp | 53 +++++++++---------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index 69161209ec..ec409d7d0c 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -160,9 +160,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad constexpr index_t n0_loops = kN0 / kN0Sub; constexpr index_t k1_loops = kN0 / kK1; - static_assert(n0_loops >= k1_loops, "n0_loops >= k1_loops required by this pipeline"); - static_assert(k1_loops >= 2, - "k1_loops >= 2 required due to pre-storing two v_tiles to Lds"); + static_assert(n0_loops >= 2, "n0_loops >= 2 required by this pipeline"); + static_assert(k1_loops >= 2, "k1_loops >= 2 required by this pipeline"); constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); @@ -211,13 +210,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad using k_tile_type = decltype(load_tile(k_dram_window)); - constexpr index_t NumPrefetchK = (n0_loops <= 3) ? 1 : 2; + constexpr index_t NumPrefetchK = 1; static_assert(n0_loops >= NumPrefetchK, "Check failed!"); - static_assert(k1_loops >= 2, - "k1_loops >= 2 required due to pre-storing two v_tiles to Lds"); - // only prefetch two k tiles to save vgprs consumption statically_indexed_array k_tiles; @@ -321,9 +317,13 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad auto seqlen_k_curr = seqlen_k_start; + constexpr index_t NumPrefetchV = 2; + + static_assert(NumPrefetchV >= NumPrefetchK); + using v_tile_type = decltype(load_tile(v_dram_window)); - statically_indexed_array v_tiles; + statically_indexed_array v_tiles; do { @@ -342,7 +342,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad } else { - if constexpr(i_n0 - (n0_loops - NumPrefetchK) < k1_loops) + // We assume NumPrefetchV >= NumPrefetchK + if constexpr(i_n0 - (n0_loops - NumPrefetchK) < NumPrefetchK) { // load v_tiles used in current iteration v_tiles[number{}] = @@ -443,7 +444,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad __builtin_amdgcn_sched_barrier(0x00000001); - static_for{}([&](auto i_k1) { + static_for{}([&](auto i_k1) { // load v_tiles used in current iteration v_tiles[i_k1] = load_tile(v_dram_window); move_tile_window(v_dram_window, {kK1, 0}); @@ -460,6 +461,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad tile_elementwise_inout( [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); + __builtin_amdgcn_sched_barrier(0x00000001); + constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans(); sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); @@ -518,26 +521,20 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); - // check whether second V-LdsBufer overlap with last K-LdsBuffer, - // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 - if constexpr((k1_loops - 1) % NumKVLdsBuffers == 3 % NumKVLdsBuffers) - { - __builtin_amdgcn_s_barrier(); - }; - - // k1_loops >= 2 required - store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], - v_tiles[number<1>{}], - partition_index); - __builtin_amdgcn_sched_barrier(0x00000001); // STAGE 3, Gemm_1 ( O = P@V ) static_for<0, k1_loops, 1>{}([&](auto i_k1) { - if constexpr(i_k1 < NumPrefetchK) + if constexpr(i_k1 < k1_loops - NumPrefetchV) { - // load k_tiles used by next iteration - k_tiles[i_k1] = load_tile(k_dram_window); + v_tiles[number{}] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }; + + if constexpr((i_k1 >= k1_loops - NumPrefetchV) && + (i_k1 - (k1_loops - NumPrefetchV) < NumPrefetchK)) + { + k_tiles[number{}] = load_tile(k_dram_window); move_tile_window(k_dram_window, {kN0Sub, 0}); }; @@ -550,12 +547,12 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}), v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]); - if constexpr(i_k1 < k1_loops - 2) + if constexpr(i_k1 < k1_loops - 1) { __builtin_amdgcn_sched_barrier(0x00000001); - store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}], - v_tiles[number{}], + store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}], + v_tiles[number{}], partition_index); __builtin_amdgcn_sched_barrier(0x00000001);