Add IsFirstVLdsBufferOverlapLastKLdsBuffer() check to reduce call of s_barrier()

This commit is contained in:
Qianfeng Zhang
2025-04-13 10:58:32 +00:00
parent 238e78d82e
commit c2e6ab8516
2 changed files with 39 additions and 4 deletions

View File

@@ -342,6 +342,8 @@ struct HstuAttentionFwdPipelineQRKSVS
sequence<kM0, k0_loops * kK0>{}),
k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
__builtin_amdgcn_sched_barrier(0);
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) {
@@ -349,8 +351,6 @@ struct HstuAttentionFwdPipelineQRKSVS
move_tile_window(v_dram_window, {0, kK1});
});
__builtin_amdgcn_sched_barrier(0);
// STAGE 2, scale_s, add bias, mask, siLU
if constexpr(kHasBias)
{
@@ -416,8 +416,6 @@ struct HstuAttentionFwdPipelineQRKSVS
randval_lds_ptr, seqlen_k_start + i_loop * kN0, s, null_randval_window);
}
// ensure gemm_0 has finished access of k-Lds for all warps
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0x7f);
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
@@ -426,12 +424,20 @@ struct HstuAttentionFwdPipelineQRKSVS
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
// ensure gemm_0 has finished access of k-Lds for all warps
if constexpr(Policy::template IsFirstVLdsBufferOverlapLastKLdsBuffer<Problem>())
__builtin_amdgcn_s_barrier();
store_tile(
v_lds_windows[I0],
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
// ensure gemm_0 has finished access of k-Lds for all warps
if constexpr(Policy::template IsFirstVLdsBufferOverlapLastKLdsBuffer<Problem>())
__builtin_amdgcn_s_barrier();
store_tile(v_lds_windows[I0],
tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch
}

View File

@@ -321,6 +321,35 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
return integer_least_multiple(single_k_lds_buffer_size - single_v_lds_buffer_size, 64);
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstVLdsBufferOverlapLastKLdsBuffer()
{
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
constexpr index_t k0_loops = BlockFmhaShape::kQKHeaddim / BlockFmhaShape::kK0;
constexpr index_t num_k_lds_buffers = GetNumKLdsBuffers<Problem>();
constexpr index_t num_v_lds_buffers = GetNumVLdsBuffers<Problem>();
constexpr index_t last_k_lds_buffer_offset =
MakeKLdsBlockDescriptor<Problem>().get_element_space_size() / num_k_lds_buffers *
((k0_loops - 1) % num_k_lds_buffers) * sizeof(typename Problem::KDataType);
constexpr index_t last_k_lds_buffer_end =
last_k_lds_buffer_offset + MakeKLdsBlockDescriptor<Problem>().get_element_space_size() /
num_k_lds_buffers * sizeof(typename Problem::KDataType);
constexpr index_t first_v_lds_buffer_size =
MakeVLdsBlockDescriptor<Problem>().get_element_space_size() / num_v_lds_buffers *
sizeof(typename Problem::VDataType);
constexpr index_t first_v_lds_buffer_offset = GetExclusiveKLdsBytes<Problem>();
constexpr index_t first_v_lds_buffer_end =
first_v_lds_buffer_offset + first_v_lds_buffer_size;
return !((first_v_lds_buffer_offset >= last_k_lds_buffer_end) ||
(first_v_lds_buffer_end <= last_k_lds_buffer_offset));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstKLdsBufferOverlapLastVLdsBuffer()
{