mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Use shared ring Lds buffers for K/V to avoid over-lapping between first-K/last-V or last-K/first-V
This commit is contained in:
@@ -220,8 +220,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
// V tile in LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<QKVDataType*>(static_cast<char*>(smem_ptr) +
|
||||
Policy::template GetExclusiveKLdsBytes<Problem>()),
|
||||
reinterpret_cast<QKVDataType*>(smem_ptr),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
@@ -396,7 +395,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_lds_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>();
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
randval_lds_ptr, seqlen_k_start + i_loop * kN0, s, null_randval_window);
|
||||
@@ -411,6 +410,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
|
||||
|
||||
// ensure gemm_0 has finished access of k-Lds for all warps
|
||||
// the over-lap only occurs when k0_loops is 3/5/7, NumKLdsBuffers is 2
|
||||
if constexpr(Policy::template IsFirstVLdsBufferOverlapLastKLdsBuffer<Problem>())
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
@@ -479,6 +479,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
v_lds_windows[number<(k1_loops - 1) % NumVLdsBuffers>{}]);
|
||||
|
||||
// the over-lap only occurs when k1_loops is 3/5/7, NumVLdsBuffers is 2
|
||||
if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer<Problem>())
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -58,6 +58,44 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
return 8 / sizeof(QKVDataType);
|
||||
}
|
||||
|
||||
// Number of elements spanned by one K buffer in LDS, computed from the block
// shape (kN0 x kK0), the K pack size and the K vector (alignment) size.
// The K LDS descriptor code visible in this diff static_asserts that its own
// per-buffer size equals this value.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
|
||||
// the vector size must be a whole number of packs
static_assert(kKVector % kKPack == 0);
|
||||
|
||||
// kKPerBlock * kNPerBlock data elements plus kKPack extra elements for each
// of the (kKPerBlock / kKVector) vector groups
// NOTE(review): the extra term looks like padding -- confirm
return kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
|
||||
};
|
||||
|
||||
// Number of elements spanned by one V buffer in LDS, computed from the block
// shape (kN1 x kK1), the V pack size and a row width derived from the bank count.
// The V LDS descriptor code visible in this diff static_asserts that its own
// per-buffer size equals this value.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize()
|
||||
{
|
||||
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t Banks = 32; // TODO: need change based on arch
|
||||
// elements per row of Banks * 4 bytes
// NOTE(review): presumably 4 bytes per LDS bank -- confirm
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QKVDataType);
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(PixelsPerRow % kKPack == 0);
|
||||
// number of kKPack-sized groups that fit in one row
constexpr index_t NPerRow = PixelsPerRow / kKPack;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
static_assert(kNPerBlock % NPerRow == 0);
|
||||
static_assert(kKPerBlock % kKPack == 0);
|
||||
|
||||
// (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) rows, each occupying
// PixelsPerRow + kKPack elements
// NOTE(review): the + kKPack per row looks like padding -- confirm
return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
|
||||
};
|
||||
|
||||
// Element count of one shared K/V LDS buffer slot: the larger of the single K
// buffer size and the single V buffer size, so a slot can hold either tile.
// In the hunks visible here this value is the buffer-index stride of both the
// K and V LDS descriptors and the per-buffer size used by GetSmemSizeKV().
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
|
||||
{
|
||||
return max(GetKSingleSmemElementSpaceSize<Problem>(),
|
||||
GetVSingleSmemElementSpaceSize<Problem>());
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
|
||||
{
|
||||
@@ -69,19 +107,26 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
|
||||
static_assert(kKVector % kKPack == 0);
|
||||
|
||||
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumKLdsBuffers>{},
|
||||
number<kKPerBlock / kKVector>{},
|
||||
number<kKVector / kKPack>{},
|
||||
number<kNPerBlock>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector>{},
|
||||
number<kNPerBlock * kKVector + kKPack>{},
|
||||
number<kNPerBlock * kKPack>{},
|
||||
number<kKPack>{},
|
||||
number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
constexpr index_t KSingleSmemElementSpaceSize =
|
||||
kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
|
||||
|
||||
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
|
||||
|
||||
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
|
||||
|
||||
constexpr auto k_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<NumKLdsBuffers>{},
|
||||
number<kKPerBlock / kKVector>{},
|
||||
number<kKVector / kKPack>{},
|
||||
number<kNPerBlock>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<SingleSmemElementSpaceSize>{},
|
||||
number<kNPerBlock * kKVector + kKPack>{},
|
||||
number<kNPerBlock * kKPack>{},
|
||||
number<kKPack>{},
|
||||
number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
|
||||
k_lds_block_desc_0,
|
||||
@@ -147,13 +192,17 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
constexpr index_t VSingleSmemElementSpaceSize =
|
||||
(kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
|
||||
|
||||
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
|
||||
|
||||
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
|
||||
|
||||
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumVLdsBuffers>{},
|
||||
number<kKPerBlock / kKPack>{},
|
||||
number<kNPerBlock / NPerRow>{},
|
||||
number<NPerRow>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<VSingleSmemElementSpaceSize>{},
|
||||
make_tuple(number<SingleSmemElementSpaceSize>{},
|
||||
number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
|
||||
number<PixelsPerRow + kKPack>{},
|
||||
number<kKPack>{},
|
||||
@@ -305,22 +354,6 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
// leave some exclusive space so that the second v_lds buffer will never overlap with the first
|
||||
// k_lds buffer
// Returns 0 when a single K buffer is no larger than a single V buffer;
// otherwise returns the size difference in bytes passed through
// integer_least_multiple(..., 64)
// NOTE(review): presumably rounds up to a multiple of 64 bytes -- confirm
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetExclusiveKLdsBytes()
|
||||
{
|
||||
// bytes per buffer = total bytes for K (resp. V) divided by its buffer count
constexpr index_t single_k_lds_buffer_size =
|
||||
GetSmemSizeK<Problem>() / GetNumKLdsBuffers<Problem>();
|
||||
constexpr index_t single_v_lds_buffer_size =
|
||||
GetSmemSizeV<Problem>() / GetNumVLdsBuffers<Problem>();
|
||||
|
||||
if constexpr(single_k_lds_buffer_size <= single_v_lds_buffer_size)
|
||||
return 0;
|
||||
else
|
||||
return integer_least_multiple(single_k_lds_buffer_size - single_v_lds_buffer_size, 64);
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstVLdsBufferOverlapLastKLdsBuffer()
|
||||
{
|
||||
@@ -328,26 +361,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
|
||||
constexpr index_t k0_loops = BlockFmhaShape::kQKHeaddim / BlockFmhaShape::kK0;
|
||||
constexpr index_t num_k_lds_buffers = GetNumKLdsBuffers<Problem>();
|
||||
constexpr index_t num_v_lds_buffers = GetNumVLdsBuffers<Problem>();
|
||||
|
||||
constexpr index_t last_k_lds_buffer_offset =
|
||||
MakeKLdsBlockDescriptor<Problem>().get_element_space_size() / num_k_lds_buffers *
|
||||
((k0_loops - 1) % num_k_lds_buffers) * sizeof(typename Problem::KDataType);
|
||||
|
||||
constexpr index_t last_k_lds_buffer_end =
|
||||
last_k_lds_buffer_offset + MakeKLdsBlockDescriptor<Problem>().get_element_space_size() /
|
||||
num_k_lds_buffers * sizeof(typename Problem::KDataType);
|
||||
|
||||
constexpr index_t first_v_lds_buffer_size =
|
||||
MakeVLdsBlockDescriptor<Problem>().get_element_space_size() / num_v_lds_buffers *
|
||||
sizeof(typename Problem::VDataType);
|
||||
|
||||
constexpr index_t first_v_lds_buffer_offset = GetExclusiveKLdsBytes<Problem>();
|
||||
constexpr index_t first_v_lds_buffer_end =
|
||||
first_v_lds_buffer_offset + first_v_lds_buffer_size;
|
||||
|
||||
return !((first_v_lds_buffer_offset >= last_k_lds_buffer_end) ||
|
||||
(first_v_lds_buffer_end <= last_k_lds_buffer_offset));
|
||||
return (k0_loops - 1) % num_k_lds_buffers == 0;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -356,43 +371,25 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
constexpr index_t k1_loops = BlockFmhaShape::kN0 / BlockFmhaShape::kK1;
|
||||
constexpr index_t num_k_lds_buffers = GetNumKLdsBuffers<Problem>();
|
||||
constexpr index_t num_v_lds_buffers = GetNumVLdsBuffers<Problem>();
|
||||
|
||||
constexpr index_t last_v_lds_buffer_offset =
|
||||
MakeVLdsBlockDescriptor<Problem>().get_element_space_size() / num_v_lds_buffers *
|
||||
((k1_loops - 1) % num_v_lds_buffers) * sizeof(typename Problem::VDataType);
|
||||
|
||||
constexpr index_t first_k_lds_buffer_size =
|
||||
MakeKLdsBlockDescriptor<Problem>().get_element_space_size() / num_k_lds_buffers *
|
||||
sizeof(typename Problem::QKVDataType);
|
||||
|
||||
return GetExclusiveKLdsBytes<Problem>() + last_v_lds_buffer_offset <
|
||||
first_k_lds_buffer_size;
|
||||
return (k1_loops - 1) % num_v_lds_buffers == 0;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
|
||||
{
|
||||
return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
|
||||
sizeof(typename Problem::QKVDataType);
|
||||
}
|
||||
constexpr index_t num_kv_lds_buffers =
|
||||
max(GetNumKLdsBuffers<Problem>(), GetNumVLdsBuffers<Problem>());
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
|
||||
{
|
||||
return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
|
||||
return num_kv_lds_buffers * GetSingleSmemElementSpaceSize<Problem>() *
|
||||
sizeof(typename Problem::QKVDataType);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
// assume V can reuse the other shared memory by K except the first
|
||||
// assume Dropout can reuse the shared memory by V
|
||||
return GetExclusiveKLdsBytes<Problem>() +
|
||||
max(GetSmemSizeK<Problem>() - GetExclusiveKLdsBytes<Problem>(),
|
||||
max(GetSmemSizeV<Problem>(), GetSmemSizeDropout<Problem>(0)));
|
||||
return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(0);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user