mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Add IsFirstVLdsBufferOverlapLastKLdsBuffer() check to reduce call of s_barrier()
This commit is contained in:
@@ -342,6 +342,8 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
|
||||
static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) {
|
||||
@@ -349,8 +351,6 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, siLU
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
@@ -416,8 +416,6 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
randval_lds_ptr, seqlen_k_start + i_loop * kN0, s, null_randval_window);
|
||||
}
|
||||
|
||||
// ensure gemm_0 has finished access of k-Lds for all warps
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0x7f);
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
@@ -426,12 +424,20 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
|
||||
|
||||
// ensure gemm_0 has finished access of k-Lds for all warps
|
||||
if constexpr(Policy::template IsFirstVLdsBufferOverlapLastKLdsBuffer<Problem>())
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
store_tile(
|
||||
v_lds_windows[I0],
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
// ensure gemm_0 has finished access of k-Lds for all warps
|
||||
if constexpr(Policy::template IsFirstVLdsBufferOverlapLastKLdsBuffer<Problem>())
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
store_tile(v_lds_windows[I0],
|
||||
tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch
|
||||
}
|
||||
|
||||
@@ -321,6 +321,35 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
return integer_least_multiple(single_k_lds_buffer_size - single_v_lds_buffer_size, 64);
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstVLdsBufferOverlapLastKLdsBuffer()
|
||||
{
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
constexpr index_t k0_loops = BlockFmhaShape::kQKHeaddim / BlockFmhaShape::kK0;
|
||||
constexpr index_t num_k_lds_buffers = GetNumKLdsBuffers<Problem>();
|
||||
constexpr index_t num_v_lds_buffers = GetNumVLdsBuffers<Problem>();
|
||||
|
||||
constexpr index_t last_k_lds_buffer_offset =
|
||||
MakeKLdsBlockDescriptor<Problem>().get_element_space_size() / num_k_lds_buffers *
|
||||
((k0_loops - 1) % num_k_lds_buffers) * sizeof(typename Problem::KDataType);
|
||||
|
||||
constexpr index_t last_k_lds_buffer_end =
|
||||
last_k_lds_buffer_offset + MakeKLdsBlockDescriptor<Problem>().get_element_space_size() /
|
||||
num_k_lds_buffers * sizeof(typename Problem::KDataType);
|
||||
|
||||
constexpr index_t first_v_lds_buffer_size =
|
||||
MakeVLdsBlockDescriptor<Problem>().get_element_space_size() / num_v_lds_buffers *
|
||||
sizeof(typename Problem::VDataType);
|
||||
|
||||
constexpr index_t first_v_lds_buffer_offset = GetExclusiveKLdsBytes<Problem>();
|
||||
constexpr index_t first_v_lds_buffer_end =
|
||||
first_v_lds_buffer_offset + first_v_lds_buffer_size;
|
||||
|
||||
return !((first_v_lds_buffer_offset >= last_k_lds_buffer_end) ||
|
||||
(first_v_lds_buffer_end <= last_k_lds_buffer_offset));
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstKLdsBufferOverlapLastVLdsBuffer()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user