From c2e6ab8516f44d3f101db8b9763a46255bf21382 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sun, 13 Apr 2025 10:58:32 +0000 Subject: [PATCH] Add IsFirstVLdsBufferOverlapLastKLdsBuffer() check to reduce call of s_barrier() --- .../hstu_attention_fwd_pipeline.hpp | 14 ++++++--- ..._attention_fwd_pipeline_default_policy.hpp | 29 +++++++++++++++++++ 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 9d3348f730..f97bde7126 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -342,6 +342,8 @@ struct HstuAttentionFwdPipelineQRKSVS sequence{}), k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]); + __builtin_amdgcn_sched_barrier(0); + const auto bias_tile = load_tile(bias_dram_window); // load bias tile static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) { @@ -349,8 +351,6 @@ struct HstuAttentionFwdPipelineQRKSVS move_tile_window(v_dram_window, {0, kK1}); }); - __builtin_amdgcn_sched_barrier(0); - // STAGE 2, scale_s, add bias, mask, siLU if constexpr(kHasBias) { @@ -416,8 +416,6 @@ struct HstuAttentionFwdPipelineQRKSVS randval_lds_ptr, seqlen_k_start + i_loop * kN0, s, null_randval_window); } - // ensure gemm_0 has finished access of k-Lds for all warps - __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0x7f); if constexpr(std::is_same_v) @@ -426,12 +424,20 @@ struct HstuAttentionFwdPipelineQRKSVS Policy::template MakeShuffledVRegBlockDescriptor()); shuffle_tile(v_shuffle_tmp, v_tiles[I0]); + // ensure gemm_0 has finished access of k-Lds for all warps + if constexpr(Policy::template IsFirstVLdsBufferOverlapLastKLdsBuffer()) + __builtin_amdgcn_s_barrier(); + store_tile( v_lds_windows[I0], tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch } else { + // ensure gemm_0 has finished access of k-Lds for all warps + if constexpr(Policy::template IsFirstVLdsBufferOverlapLastKLdsBuffer()) + __builtin_amdgcn_s_barrier(); + store_tile(v_lds_windows[I0], tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch } diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index ad165e7c00..0b40e353f0 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -321,6 +321,35 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return integer_least_multiple(single_k_lds_buffer_size - single_v_lds_buffer_size, 64); }; + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstVLdsBufferOverlapLastKLdsBuffer() + { + using BlockFmhaShape = remove_cvref_t; + + constexpr index_t k0_loops = BlockFmhaShape::kQKHeaddim / BlockFmhaShape::kK0; + constexpr index_t num_k_lds_buffers = GetNumKLdsBuffers(); + constexpr index_t num_v_lds_buffers = GetNumVLdsBuffers(); + + constexpr index_t last_k_lds_buffer_offset = + MakeKLdsBlockDescriptor().get_element_space_size() / num_k_lds_buffers * + ((k0_loops - 1) % num_k_lds_buffers) * sizeof(typename Problem::KDataType); + + constexpr index_t last_k_lds_buffer_end = + last_k_lds_buffer_offset + MakeKLdsBlockDescriptor().get_element_space_size() / + num_k_lds_buffers * sizeof(typename Problem::KDataType); + + constexpr index_t first_v_lds_buffer_size = + MakeVLdsBlockDescriptor().get_element_space_size() / num_v_lds_buffers * + sizeof(typename Problem::VDataType); + + constexpr index_t first_v_lds_buffer_offset = GetExclusiveKLdsBytes(); + constexpr index_t first_v_lds_buffer_end = + first_v_lds_buffer_offset + first_v_lds_buffer_size; + + return !((first_v_lds_buffer_offset >= last_k_lds_buffer_end) || + (first_v_lds_buffer_end <= last_k_lds_buffer_offset)); + } + template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstKLdsBufferOverlapLastVLdsBuffer() {