Use LDS to indirectly load the Q-tile, enabling dwordx4 loads and avoiding wasted cachelines

This commit is contained in:
Qianfeng Zhang
2025-05-21 16:44:39 +00:00
parent a1346aaf3e
commit 81f7b139e0
2 changed files with 207 additions and 7 deletions

View File

@@ -68,6 +68,9 @@ struct HstuAttentionFwdPipelineQRKSVS
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kGemmSingleRepM = Policy::template GetQKBlockGemmSingleRepM<Problem>();
static constexpr index_t kGemmNumRepM = kM0 / kGemmSingleRepM;
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::Traits::kBlockPerCu != -1)
return Problem::Traits::kBlockPerCu;
@@ -173,10 +176,11 @@ struct HstuAttentionFwdPipelineQRKSVS
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
OaccBlockTileType o_acc;
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQRegTileDistribution<Problem>());
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
@@ -188,15 +192,31 @@ struct HstuAttentionFwdPipelineQRKSVS
{seqlen_k_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());
auto q_tile = load_tile(q_dram_window);
using q_dram_tile_type = decltype(load_tile(q_dram_window));
statically_indexed_array<q_dram_tile_type, kGemmNumRepM> q_dram_tiles;
clear_tile(o_acc);
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
q_dram_tiles[i_rep] = load_tile(q_dram_window);
move_tile_window(q_dram_window, {kGemmSingleRepM, 0});
});
auto k_tile = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
__builtin_amdgcn_sched_barrier(0);
// Q tile in LDS
QKVDataType* q_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_write_window = make_tile_window(
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
auto q_lds_read_window =
make_tile_window(q_lds,
Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0},
Policy::template MakeQRegSingleRepMTileDistribution<Problem>());
// K tile in LDS
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
@@ -284,12 +304,54 @@ struct HstuAttentionFwdPipelineQRKSVS
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
}();
using q_reg_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
Policy::template MakeQRegSingleRepMTileDistribution<Problem>()));
statically_indexed_array<q_reg_tile_type, kGemmNumRepM> q_reg_tiles;
using q_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
Policy::template MakeQRegTileDistribution<Problem>()));
q_tile_type q_tile;
{
clear_tile(o_acc);
constexpr index_t complete_tile_thread_buf_size = q_tile_type::get_thread_buffer_size();
constexpr index_t splitted_tile_thread_buf_size =
q_reg_tile_type::get_thread_buffer_size();
static_assert(complete_tile_thread_buf_size ==
kGemmNumRepM * splitted_tile_thread_buf_size,
"Check failed!");
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
store_tile(q_lds_write_window, q_dram_tiles[i_rep]);
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice written
// by each wavefront is read by itself
__builtin_amdgcn_s_waitcnt(0xc07f);
q_reg_tiles[i_rep] = load_tile(q_lds_read_window);
static_for<0, splitted_tile_thread_buf_size, 1>{}([&](auto i_buf) {
q_tile.get_thread_buffer()[i_rep * splitted_tile_thread_buf_size + i_buf] =
q_reg_tiles[i_rep].get_thread_buffer()[i_buf];
});
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice read
// by each wavefront is over-written by itself
});
};
q_tile = tile_elementwise_in(q_element_func, q_tile);
auto seqlen_k_curr = seqlen_k_start;
index_t i_loop = 0;
// ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile
__builtin_amdgcn_s_barrier();
while(i_loop < num_loops)
{
static_for<0, k1_loops, 1>{}([&](auto i_k1) {

View File

@@ -23,6 +23,16 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
return 3;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSingleRepMTileDistribution()
{
    // Register-tile distribution for a single M-repetition slice of the Q
    // block, taken from the QK block-gemm's A-operand layout so the slice
    // read back from LDS matches what the gemm consumes.
    using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
    constexpr index_t kSingleRepM = GetQKBlockGemmSingleRepM<Problem>();
    constexpr index_t kHeadDimK   = Problem::BlockFmhaShape::kQKHeaddim;
    return BlockGemm::template MakeABlockTileDistribution<kSingleRepM, kHeadDimK>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
{
@@ -43,6 +53,30 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
return WG::WarpGemmAttribute::kKPerThread;
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
{
    // K-dimension packing factor for the Q tile in LDS: pack 8 elements
    // when the QK warp-gemm reads at least 8 per thread along K, else 4.
    constexpr bool wide_k_per_thread = GetQKWarpGemmKPerThreadSize<Problem>() >= 8;
    return wide_k_per_thread ? 8 : 4;
}
// Vector alignment (in elements) for loading one Q slice
// (GetQKBlockGemmSingleRepM rows x kQKHeaddim columns) from global memory:
// the widest per-thread vector, capped at 16 bytes, such that each of the
// kBlockSize threads still covers a whole number of vectors.
// NOTE(review): element size is taken from Problem::QDataType here, while
// the companion MakeQDramSingleRepMTileDistribution() and GetSmemSizeQ()
// use Problem::QKVDataType -- presumably the same element type; confirm
// these cannot diverge.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
// One M-repetition slice of the Q block: rows per slice times head dim.
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
// Cap at 16 bytes per load (the dwordx4 width this commit targets).
constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
return min(MaxVectorSize, ElemPerThread);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
{
@@ -114,6 +148,95 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
GetVSingleSmemElementSpaceSize<Problem>());
};
// LDS descriptor for staging one Q M-repetition slice
// (GetQKBlockGemmSingleRepM x kQKHeaddim). Two layouts are chosen by the
// warp-gemm's per-thread K width; both store K-packs contiguously and then
// merge the split K dimensions back into a logical (M, K) view.
// NOTE(review): the "+ kKPack" term in the leading stride adds one pack of
// padding per outer K step -- presumably to avoid LDS bank conflicts on the
// read path; confirm against the distributions that read this descriptor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
// kKPack: LDS packing unit along K; kKVector: global-load vector width.
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
constexpr index_t kKVector = GetAlignmentQ<Problem>();
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
{
// Wide-K case: load vector and LDS pack coincide, so a 3-level
// (K-outer, M, K-pack) layout suffices.
static_assert(kKVector == kKPack);
constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<kMPerBlock * kKPack + kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
// Collapse (K-outer, K-pack) back into a single K axis -> (M, K) view.
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
q_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return q_lds_block_desc;
}
else
{
// Narrow-K case: one global-load vector spans several LDS packs, so K
// is split three ways: (K / vector, vector / pack, pack).
static_assert(kKVector % kKPack == 0);
constexpr auto q_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kMPerBlock>{},
number<kKPack>{}),
make_tuple(number<kMPerBlock * kKVector + kKPack>{},
number<kMPerBlock * kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
// Merge the three K sub-axes back into one K axis -> (M, K) view.
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
q_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kKPack>{}))),
make_tuple(sequence<2>{}, sequence<0, 1, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return q_lds_block_desc;
};
}
// Thread distribution for loading one Q M-repetition slice from global
// memory with maximal vector width: each thread loads KPerThread
// contiguous elements (up to 16 bytes), KThreads threads cover the head
// dimension, and the remaining lanes/warps tile the M dimension.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramSingleRepMTileDistribution()
{
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
// Widest vector is 16 bytes, further limited by elements available per
// thread for this slice.
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t KPerThread = kMaxVecLoad;
constexpr index_t KThreads = kKPerBlock / KPerThread;
// Lanes left over after covering K tile the M dimension; warps stack on
// top of that.
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
// for Q-Tile [64, 128], the encoding is [4W * 4T * 4E, 16T * 8E]
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NumWarps, MThreadPerWarp, MPerThread>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
@@ -345,6 +468,13 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetQKBlockGemmSingleRepM()
{
    // Rows of the Q block covered by one gemm-0 M-repetition:
    // (warps along M) * (per-warp tile extent along M).
    using Shape = typename Problem::BlockFmhaShape;
    constexpr index_t warps_m = Shape::Gemm0BlockWarps::at(number<0>{});
    constexpr index_t tile_m  = Shape::Gemm0WarpTile::at(number<0>{});
    return warps_m * tile_m;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
@@ -488,6 +618,13 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
return (k1_loops - 1 + 1) % num_kv_lds_buffers == 0;
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
{
    // LDS bytes needed to stage one Q M-repetition slice: element-space
    // size of the Q LDS descriptor (including its padding) times the
    // element size.
    constexpr auto elem_count = MakeQLdsBlockDescriptor<Problem>().get_element_space_size();
    return elem_count * sizeof(typename Problem::QKVDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
{
@@ -500,7 +637,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(0);
return max(GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(0),
GetSmemSizeQ<Problem>());
}
};