Use shared ring LDS buffers for K/V to avoid overlap between first-K/last-V or last-K/first-V buffers

This commit is contained in:
Qianfeng Zhang
2025-04-18 09:47:43 +00:00
parent f12a47218f
commit 88e54a8989
2 changed files with 75 additions and 77 deletions

View File

@@ -220,8 +220,7 @@ struct HstuAttentionFwdPipelineQRKSVS
Policy::template MakeVDramTileDistribution<Problem>());
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<QKVDataType*>(static_cast<char*>(smem_ptr) +
Policy::template GetExclusiveKLdsBytes<Problem>()),
reinterpret_cast<QKVDataType*>(smem_ptr),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
@@ -396,7 +395,7 @@ struct HstuAttentionFwdPipelineQRKSVS
if constexpr(kHasDropout)
{
auto randval_lds_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>();
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
randval_lds_ptr, seqlen_k_start + i_loop * kN0, s, null_randval_window);
@@ -411,6 +410,7 @@ struct HstuAttentionFwdPipelineQRKSVS
shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
// ensure gemm_0 has finished accessing the K-LDS for all warps
// the overlap only occurs when k0_loops is 3/5/7 and NumKLdsBuffers is 2
if constexpr(Policy::template IsFirstVLdsBufferOverlapLastKLdsBuffer<Problem>())
__builtin_amdgcn_s_barrier();
@@ -479,6 +479,7 @@ struct HstuAttentionFwdPipelineQRKSVS
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
v_lds_windows[number<(k1_loops - 1) % NumVLdsBuffers>{}]);
// the overlap only occurs when k1_loops is 3/5/7 and NumVLdsBuffers is 2
if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer<Problem>())
{
__builtin_amdgcn_sched_barrier(0);

View File

@@ -58,6 +58,44 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
return 8 / sizeof(QKVDataType);
}
template <typename Problem>
// Element count (not bytes) of a single K LDS ring buffer, including the
// bank-conflict padding of kKPack elements per kKVector-wide chunk of the
// K dimension (this matches the kNPerBlock * kKVector + kKPack row stride
// used by MakeKLdsBlockDescriptor).
CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize()
{
    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
    constexpr index_t kKPack     = GetSmemKPackK<Problem>();
    constexpr index_t kKVector   = GetAlignmentK<Problem>();
    static_assert(kKVector % kKPack == 0);
    // kKPerBlock * kKPack / kKVector is the total padding: kKPack elements
    // for each of the kKPerBlock / kKVector chunks.
    return kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
} // no trailing ';' -- redundant semicolons after member functions trip -Wextra-semi
template <typename Problem>
// Element count (not bytes) of a single V LDS ring buffer, padded by kKPack
// elements per PixelsPerRow row to reduce LDS bank conflicts (this matches
// the PixelsPerRow + kKPack innermost stride used by MakeVLdsBlockDescriptor).
CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize()
{
    using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;

    // 32 banks x 4 bytes each; how many elements span the full set of banks.
    constexpr index_t Banks        = 32; // TODO: need change based on arch
    constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QKVDataType);
    constexpr index_t kKPack       = GetSmemKPackV<Problem>();
    static_assert(PixelsPerRow % kKPack == 0);
    constexpr index_t NPerRow    = PixelsPerRow / kKPack;
    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
    static_assert(kNPerBlock % NPerRow == 0);
    static_assert(kKPerBlock % kKPack == 0);

    // (rows along K) * (rows along N) * (padded row length)
    return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
} // no trailing ';' -- redundant semicolons after member functions trip -Wextra-semi
template <typename Problem>
// Element count of one shared K/V ring-buffer slot. K and V tiles reuse the
// same slots, so each slot must be large enough for the bigger of the two.
CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
{
    return max(GetKSingleSmemElementSpaceSize<Problem>(),
               GetVSingleSmemElementSpaceSize<Problem>());
} // no trailing ';' -- redundant semicolons after member functions trip -Wextra-semi
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
@@ -69,19 +107,26 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
static_assert(kKVector % kKPack == 0);
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumKLdsBuffers>{},
number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kNPerBlock>{},
number<kKPack>{}),
make_tuple(number<kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector>{},
number<kNPerBlock * kKVector + kKPack>{},
number<kNPerBlock * kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr index_t KSingleSmemElementSpaceSize =
kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
constexpr auto k_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumKLdsBuffers>{},
number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kNPerBlock>{},
number<kKPack>{}),
make_tuple(number<SingleSmemElementSpaceSize>{},
number<kNPerBlock * kKVector + kKPack>{},
number<kNPerBlock * kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
@@ -147,13 +192,17 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t VSingleSmemElementSpaceSize =
(kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumVLdsBuffers>{},
number<kKPerBlock / kKPack>{},
number<kNPerBlock / NPerRow>{},
number<NPerRow>{},
number<kKPack>{}),
make_tuple(number<VSingleSmemElementSpaceSize>{},
make_tuple(number<SingleSmemElementSpaceSize>{},
number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
number<PixelsPerRow + kKPack>{},
number<kKPack>{},
@@ -305,22 +354,6 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
}
// leave some exclusive space so that the second v_lds buffer will never overlap with the first
// k_lds buffer
template <typename Problem>
// Bytes reserved at the start of the K LDS region that V must not reuse, so
// that a V ring buffer never aliases the tail of a (larger) K ring buffer.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetExclusiveKLdsBytes()
{
constexpr index_t single_k_lds_buffer_size =
GetSmemSizeK<Problem>() / GetNumKLdsBuffers<Problem>();
constexpr index_t single_v_lds_buffer_size =
GetSmemSizeV<Problem>() / GetNumVLdsBuffers<Problem>();
// If a V buffer is at least as large as a K buffer, no exclusive prefix is needed.
if constexpr(single_k_lds_buffer_size <= single_v_lds_buffer_size)
return 0;
else
// Round the size difference up to a multiple of 64 bytes -- presumably for
// LDS access alignment; TODO confirm the alignment requirement.
return integer_least_multiple(single_k_lds_buffer_size - single_v_lds_buffer_size, 64);
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstVLdsBufferOverlapLastKLdsBuffer()
{
@@ -328,26 +361,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t k0_loops = BlockFmhaShape::kQKHeaddim / BlockFmhaShape::kK0;
constexpr index_t num_k_lds_buffers = GetNumKLdsBuffers<Problem>();
constexpr index_t num_v_lds_buffers = GetNumVLdsBuffers<Problem>();
constexpr index_t last_k_lds_buffer_offset =
MakeKLdsBlockDescriptor<Problem>().get_element_space_size() / num_k_lds_buffers *
((k0_loops - 1) % num_k_lds_buffers) * sizeof(typename Problem::KDataType);
constexpr index_t last_k_lds_buffer_end =
last_k_lds_buffer_offset + MakeKLdsBlockDescriptor<Problem>().get_element_space_size() /
num_k_lds_buffers * sizeof(typename Problem::KDataType);
constexpr index_t first_v_lds_buffer_size =
MakeVLdsBlockDescriptor<Problem>().get_element_space_size() / num_v_lds_buffers *
sizeof(typename Problem::VDataType);
constexpr index_t first_v_lds_buffer_offset = GetExclusiveKLdsBytes<Problem>();
constexpr index_t first_v_lds_buffer_end =
first_v_lds_buffer_offset + first_v_lds_buffer_size;
return !((first_v_lds_buffer_offset >= last_k_lds_buffer_end) ||
(first_v_lds_buffer_end <= last_k_lds_buffer_offset));
return (k0_loops - 1) % num_k_lds_buffers == 0;
}
template <typename Problem>
@@ -356,43 +371,25 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
constexpr index_t k1_loops = BlockFmhaShape::kN0 / BlockFmhaShape::kK1;
constexpr index_t num_k_lds_buffers = GetNumKLdsBuffers<Problem>();
constexpr index_t num_v_lds_buffers = GetNumVLdsBuffers<Problem>();
constexpr index_t last_v_lds_buffer_offset =
MakeVLdsBlockDescriptor<Problem>().get_element_space_size() / num_v_lds_buffers *
((k1_loops - 1) % num_v_lds_buffers) * sizeof(typename Problem::VDataType);
constexpr index_t first_k_lds_buffer_size =
MakeKLdsBlockDescriptor<Problem>().get_element_space_size() / num_k_lds_buffers *
sizeof(typename Problem::QKVDataType);
return GetExclusiveKLdsBytes<Problem>() + last_v_lds_buffer_offset <
first_k_lds_buffer_size;
return (k1_loops - 1) % num_v_lds_buffers == 0;
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
{
return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::QKVDataType);
}
constexpr index_t num_kv_lds_buffers =
max(GetNumKLdsBuffers<Problem>(), GetNumVLdsBuffers<Problem>());
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
{
return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
return num_kv_lds_buffers * GetSingleSmemElementSpaceSize<Problem>() *
sizeof(typename Problem::QKVDataType);
}
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
// assume V can reuse the other shared memory by K except the first
// assume Dropout can reuse the shared memory by V
return GetExclusiveKLdsBytes<Problem>() +
max(GetSmemSizeK<Problem>() - GetExclusiveKLdsBytes<Problem>(),
max(GetSmemSizeV<Problem>(), GetSmemSizeDropout<Problem>(0)));
return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(0);
}
};