Load Q through Lds

This commit is contained in:
Qianfeng Zhang
2025-12-14 15:46:37 +00:00
parent 12c88731c6
commit c3d3487ca4
2 changed files with 230 additions and 23 deletions

View File

@@ -74,6 +74,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kGemmSingleRepM = Policy::template GetQKBlockGemmSingleRepM<Problem>();
static constexpr index_t kGemmNumRepM = kM0 / kGemmSingleRepM;
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
@@ -220,10 +223,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
OaccBlockTileType o_acc;
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kSubQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQRegTileDistribution<Problem>());
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kGemmSingleRepM>{}, number<kSubQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
@@ -235,6 +239,14 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
{seqlen_k_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());
using q_dram_tile_type = decltype(load_tile(q_dram_window));
statically_indexed_array<q_dram_tile_type, kGemmNumRepM> q_dram_tiles;
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
q_dram_tiles[i_rep] = load_tile(q_dram_window);
move_tile_window(q_dram_window, {kGemmSingleRepM, 0});
});
using k_tile_type = decltype(load_tile(k_dram_window));
// prefetch only two K tiles to reduce VGPR consumption
@@ -250,9 +262,22 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
__builtin_amdgcn_sched_barrier(0x00000001);
auto q_tile = load_tile(q_dram_window);
// pass an explicit partition_index to the LDS tile window so that warp_id is kept in a vgpr
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
__builtin_amdgcn_sched_barrier(0x00000001);
// Q tile in LDS
QDataType* q_lds_ptr = static_cast<QDataType*>(smem_ptr);
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_write_window = make_tile_window(
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
auto q_lds_read_window =
make_tile_window(q_lds,
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
{0, 0},
Policy::template MakeQRegSingleRepMTileDistribution<Problem>(),
partition_index);
// K tile in LDS
KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr);
@@ -354,19 +379,58 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
}();
using q_reg_tile_type = decltype(make_static_distributed_tensor<QDataType>(
Policy::template MakeQRegSingleRepMTileDistribution<Problem>()));
statically_indexed_array<q_reg_tile_type, kGemmNumRepM> q_reg_tiles;
using q_tile_type = decltype(make_static_distributed_tensor<QDataType>(
Policy::template MakeQRegTileDistribution<Problem>()));
q_tile_type q_tile;
{
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
store_tile(q_lds_write_window, q_dram_tiles[i_rep], partition_index);
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice written
// by each wavefront is read by itself
__builtin_amdgcn_s_waitcnt(0xc07f);
q_reg_tiles[i_rep] = load_tile(q_lds_read_window);
__builtin_amdgcn_s_waitcnt(0xc07f);
// the following code does not cause the compiler to generate any actual instructions
set_slice_tile(q_tile,
q_reg_tiles[i_rep],
sequence<i_rep * kGemmSingleRepM, 0>{},
sequence<(i_rep + 1) * kGemmSingleRepM, kQKHeaddim>{});
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice read
// by each wavefront is over-written by itself
});
clear_tile(o_acc);
set_tile(m, -numeric<CompDataType>::infinity());
clear_tile(l);
};
q_tile = tile_elementwise_in(q_element_func, q_tile);
auto seqlen_k_curr = seqlen_k_start;
__builtin_amdgcn_sched_barrier(0x00000001);
// ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0x00000001);
using v_tile_type = decltype(load_tile(v_dram_window));
statically_indexed_array<v_tile_type, k1_loops> v_tiles;
// provide partition_index for LDS tile window with so that warp_id is in vgpr
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
do
{
// STAGE 1, Gemm_0 ( S = Q@K )

View File

@@ -40,6 +40,16 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
return 4;
}
// Register tile distribution for ONE M-repetition of the Q tile.
// It is the A-operand distribution of the QK block gemm, requested for a tile of
// GetQKBlockGemmSingleRepM<Problem>() rows by kQKHeaddim columns.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSingleRepMTileDistribution()
{
    constexpr index_t kSingleRepM = GetQKBlockGemmSingleRepM<Problem>();
    constexpr index_t kHeadDim    = Problem::BlockFmhaShape::kQKHeaddim;

    using QKBlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;

    return QKBlockGemm::template MakeABlockTileDistribution<kSingleRepM, kHeadDim>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
{
@@ -90,16 +100,28 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
}
// K-pack (number of contiguous K elements kept together) used for the Q layout in LDS:
// 8 when the QK warp gemm consumes at least 8 K-elements per thread, otherwise 4.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
{
    constexpr bool kWideKPerThread = GetQKWarpGemmKPerThreadSize<Problem>() >= 8;

    return kWideKPerThread ? 8 : 4;
}
// Per-thread vector length (in elements) used when accessing Q.
//
// The result is the smaller of
//   - the number of QDataType elements that fit in 16 bytes, and
//   - the number of elements each thread owns when one single-rep Q tile
//     (GetQKBlockGemmSingleRepM<Problem>() x kSubQKHeaddim) is divided evenly
//     among the kBlockSize threads of the block.
//
// The value is derived from the tile shape only. A second, warp-gemm based
// computation used to sit in front of this one; it declared MaxVectorSize a
// second time in the same scope and returned early, leaving the tile-shape
// computation unreachable. Only one computation is kept so that the function is
// well-formed and agrees with MakeQDramSingleRepMTileDistribution, which splits
// the same tile using this vector length.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
    using QDataType = remove_cvref_t<typename Problem::QDataType>;

    constexpr index_t kBlockSize = Problem::kBlockSize;
    constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;

    // widest access considered: 16 bytes worth of elements
    constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
    // elements of the single-rep tile handled by each thread
    constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;

    return min(MaxVectorSize, ElemPerThread);
}
template <typename Problem>
@@ -209,6 +231,118 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
GetVSingleSmemElementSpaceSize<Problem>());
};
// Builds the tensor descriptor that describes how one single-rep Q tile is laid out in LDS.
// The tile has GetQKBlockGemmSingleRepM<Problem>() rows and kSubQKHeaddim columns, and both
// branches finish by exposing a 2-D (M, K) view on top of the underlying layout.
//
// The layout is chosen at compile time from GetQKWarpGemmKPerThreadSize<Problem>():
//   >= 8 : rows are grouped MLdsLayer at a time and an XOR transform is applied over the
//          (row-group, K-chunk) dimensions
//   else : the tile is stored as slabs of kKVector columns, each slab followed by kKPack
//          elements of padding
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
constexpr index_t kKVector = GetAlignmentQ<Problem>();
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
{
// this layout is only written for the case where the vector length equals the K-pack
static_assert(kKVector == kKPack);
using QDataType = remove_cvref_t<typename Problem::QDataType>;
constexpr index_t DataTypeSize = sizeof(QDataType);
// 128 contiguous bytes mapped to 32 banks with each bank 4 contiguous bytes
// MLdsLayer = number of kKPerBlock-wide rows that fit in those 128 bytes, clamped to >= 1
constexpr auto MLdsLayer =
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
// step 0: packed 3-D layout [M / MLdsLayer, (K / kKPack) * MLdsLayer, kKPack]
constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock / MLdsLayer>{},
number<kKPerBlock / kKPack * MLdsLayer>{},
number<kKPack>{}),
make_tuple(number<kKPerBlock * MLdsLayer>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
// step 1: XOR transform over dims 0 and 1; the kKPack dim is passed through unchanged
constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor(
q_lds_block_desc_0,
make_tuple(
make_xor_transform(make_tuple(number<kMPerBlock / MLdsLayer>{},
number<kKPerBlock / kKPack * MLdsLayer>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
// step 2: split dim 1 into (K / kKPack, MLdsLayer), giving a 4-D view
//         [M / MLdsLayer, K / kKPack, MLdsLayer, kKPack]
constexpr auto q_lds_block_desc_k0_mldslayer_m_k1 = transform_tensor_descriptor(
q_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<kMPerBlock / MLdsLayer>{}),
make_unmerge_transform(
make_tuple(number<kKPerBlock / kKPack>{}, number<MLdsLayer>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
// step 3: merge dims (0, 2) into M and dims (1, 3) into K, giving the final 2-D (M, K) view
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
q_lds_block_desc_k0_mldslayer_m_k1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return q_lds_block_desc;
}
else
{
// the vector length must be a whole number of K-packs
static_assert(kKVector % kKPack == 0);
// 4-D layout [K / kKVector, kKVector / kKPack, M, kKPack];
// the stride of dim 0 is kMPerBlock * kKVector + kKPack, i.e. every kKVector-wide slab
// is followed by kKPack elements of padding
constexpr auto q_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kMPerBlock>{},
number<kKPack>{}),
make_tuple(number<kMPerBlock * kKVector + kKPack>{},
number<kMPerBlock * kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
// dim 2 becomes M as-is; dims (0, 1, 3) are merged into K, giving the 2-D (M, K) view
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
q_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kKPack>{}))),
make_tuple(sequence<2>{}, sequence<0, 1, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return q_lds_block_desc;
};
}
// Tile distribution describing how the threads of a block share one single-rep Q tile
// (GetQKBlockGemmSingleRepM<Problem>() rows by kSubQKHeaddim columns).
//
// Along K every thread owns GetAlignmentQ<Problem>() contiguous elements; the threads left
// over in a warp, then the warps of the block, then a per-thread repeat count cover M.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramSingleRepMTileDistribution()
{
    constexpr index_t kTileM = GetQKBlockGemmSingleRepM<Problem>();
    constexpr index_t kTileK = Problem::BlockFmhaShape::kSubQKHeaddim;

    // K direction: vector per thread, then threads across the row
    constexpr index_t kVecK     = GetAlignmentQ<Problem>();
    constexpr index_t kThreadsK = kTileK / kVecK;

    // M direction: remaining lanes of the warp, all warps, then rows per thread
    constexpr index_t kThreadsMPerWarp = get_warp_size() / kThreadsK;
    constexpr index_t kWarpCount       = Problem::kBlockSize / get_warp_size();
    constexpr index_t kRowsPerThread   = kTileM / (kThreadsMPerWarp * kWarpCount);

    // for Q-Tile [64, 128], the encoding is [4W * 4T * 4E, 16T * 8E]
    using MLengths = sequence<kWarpCount, kThreadsMPerWarp, kRowsPerThread>;
    using KLengths = sequence<kThreadsK, kVecK>;

    using Encoding = tile_distribution_encoding<sequence<1>,
                                                tuple<MLengths, KLengths>,
                                                tuple<sequence<1>, sequence<1, 2>>,
                                                tuple<sequence<0>, sequence<1, 0>>,
                                                sequence<1, 2>,
                                                sequence<2, 1>>;

    return make_static_tile_distribution(Encoding{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
@@ -317,19 +451,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
{
using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
constexpr index_t kKVector = GetAlignmentK<Problem>();
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t KPerThread = kMaxVecLoad;
constexpr index_t KPerThread = kKVector;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
@@ -446,6 +574,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
sequence<1, 2>>{});
}
// Number of M rows covered by one repetition of the QK block gemm:
// (entry 0 of Gemm0BlockWarps) * (entry 0 of Gemm0WarpTile).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetQKBlockGemmSingleRepM()
{
    using FmhaShape = typename Problem::BlockFmhaShape;

    return FmhaShape::Gemm0BlockWarps::at(number<0>{}) *
           FmhaShape::Gemm0WarpTile::at(number<0>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
@@ -587,6 +722,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
}
// Bytes of LDS needed to hold one single-rep Q tile: the element space size of the
// Q LDS descriptor multiplied by the size of one Q element.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
{
    constexpr auto q_lds_desc = MakeQLdsBlockDescriptor<Problem>();

    return q_lds_desc.get_element_space_size() * sizeof(typename Problem::QDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
{
@@ -605,7 +747,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
// Total bytes of LDS required by the pipeline.
//
// The pipeline stages Q through LDS starting at the base of the shared-memory buffer and
// afterwards reuses that same memory for K, so the requirement is the LARGER of
//   - the K/V buffers plus the dropout buffer, and
//   - the Q staging buffer,
// not their sum.
//
// An earlier `return GetSmemSizeKV + GetSmemSizeDropout` preceded the max(), which made
// the max() unreachable and let the reported size fall below what the Q staging buffer
// needs; only the max() form is kept.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    return max(GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(),
               GetSmemSizeQ<Problem>());
}
};