Load Q directly from global memory to registers for BlockGemm

This commit is contained in:
Qianfeng Zhang
2025-12-20 12:50:18 +00:00
parent 57abd10b95
commit 3f6d26e9a7
4 changed files with 17 additions and 367 deletions

View File

@@ -96,9 +96,6 @@ struct BlockFmhaPipelineProblem
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kUseTrLoad = kUseTrLoad_;
// ToDo: should we define kUseTrLoad and kLoadWholeQTileOnceThroughLds here ?
static constexpr bool kLoadWholeQTileOnceThroughLds = kUseTrLoad ? true : false;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;

View File

@@ -75,9 +75,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kGemmSingleRepM = Policy::template GetQKBlockGemmSingleRepM<Problem>();
static constexpr index_t kGemmNumRepM = kM0 / kGemmSingleRepM;
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
@@ -223,11 +220,10 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
OaccBlockTileType o_acc;
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQRegTileDistribution<Problem>());
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
@@ -239,14 +235,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
{seqlen_k_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());
using q_dram_tile_type = decltype(load_tile(q_dram_window));
statically_indexed_array<q_dram_tile_type, kGemmNumRepM> q_dram_tiles;
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
q_dram_tiles[i_rep] = load_tile(q_dram_window);
move_tile_window(q_dram_window, {kGemmSingleRepM, 0});
});
using k_tile_type = decltype(load_tile(k_dram_window));
// only prefetch two k tiles to save vgprs consumption
@@ -260,24 +248,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kN0Sub, 0});
auto q_tile = load_tile(q_dram_window);
__builtin_amdgcn_sched_barrier(0x00000001);
// provide partition_index for the LDS tile window so that warp_id is kept in a vgpr
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
// Q tile in LDS
QDataType* q_lds_ptr = static_cast<QDataType*>(smem_ptr);
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_write_window = make_tile_window(
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
auto q_lds_read_window =
make_tile_window(q_lds,
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
{0, 0},
Policy::template MakeQRegSingleRepMTileDistribution<Problem>(),
partition_index);
// K tile in LDS
KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
@@ -368,47 +345,10 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
}();
using q_reg_tile_type = decltype(make_static_distributed_tensor<QDataType>(
Policy::template MakeQRegSingleRepMTileDistribution<Problem>()));
statically_indexed_array<q_reg_tile_type, kGemmNumRepM> q_reg_tiles;
using q_tile_type = decltype(make_static_distributed_tensor<QDataType>(
Policy::template MakeQRegTileDistribution<Problem>()));
q_tile_type q_tile;
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
store_tile(q_lds_write_window, q_dram_tiles[i_rep], partition_index);
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice written
// by each wavefront is read by itself
__builtin_amdgcn_s_waitcnt(0xc07f);
q_reg_tiles[i_rep] = load_tile(q_lds_read_window);
__builtin_amdgcn_s_waitcnt(0xc07f);
// the following codes will not generate actual instructions by the compiler
set_slice_tile(q_tile,
q_reg_tiles[i_rep],
sequence<i_rep * kGemmSingleRepM, 0>{},
sequence<(i_rep + 1) * kGemmSingleRepM, kQKHeaddim>{});
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice read
// by each wavefront is over-written by itself
});
q_tile = tile_elementwise_in(q_element_func, q_tile);
auto seqlen_k_curr = seqlen_k_start;
__builtin_amdgcn_sched_barrier(0x00000001);
// ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0x00000001);
using v_tile_type = decltype(load_tile(v_dram_window));
statically_indexed_array<v_tile_type, k1_loops> v_tiles;

View File

@@ -43,16 +43,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
return 4;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSingleRepMTileDistribution()
{
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr index_t kBlockGemmM = GetQKBlockGemmSingleRepM<Problem>();
return BlockGemm::
template MakeABlockTileDistribution<kBlockGemmM, Problem::BlockFmhaShape::kQKHeaddim>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
{
@@ -103,33 +93,16 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
{
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
return 8;
else
return 4;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
if constexpr(Problem::kLoadWholeQTileOnceThroughLds)
{
return Problem::GetQDramTileAccessMaxVectorSize();
}
else
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
return detail::
GetDramTileAccessMaxVectorSize<QDataType, kBlockSize, kMPerBlock, kKPerBlock>();
};
return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
}
template <typename Problem>
@@ -257,217 +230,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
GetVSingleSmemElementSpaceSize<Problem>());
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
constexpr index_t kMPerBlock = Problem::kLoadWholeQTileOnceThroughLds
? Problem::BlockFmhaShape::kM0
: GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
constexpr index_t kKVector = GetAlignmentQ<Problem>();
// for hdim96 and hdim160, use simplest layout
if constexpr(kKPerBlock < Problem::BlockFmhaShape::kSubQKHeaddim)
{
return make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
make_tuple(number<kKPerBlock>{}, number<1>{}),
number<kKVector>{},
number<1>{});
}
else if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
{
static_assert(kKVector == kKPack);
using QDataType = remove_cvref_t<typename Problem::QDataType>;
constexpr index_t DataTypeSize = sizeof(QDataType);
// 128 contiguous bytes mapped to 32 banks with each bank 4 contiguous bytes
constexpr auto MLdsLayer =
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock / MLdsLayer>{},
number<kKPerBlock / kKPack * MLdsLayer>{},
number<kKPack>{}),
make_tuple(number<kKPerBlock * MLdsLayer>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor(
q_lds_block_desc_0,
make_tuple(
make_xor_transform(make_tuple(number<kMPerBlock / MLdsLayer>{},
number<kKPerBlock / kKPack * MLdsLayer>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
constexpr auto q_lds_block_desc_k0_mldslayer_m_k1 = transform_tensor_descriptor(
q_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<kMPerBlock / MLdsLayer>{}),
make_unmerge_transform(
make_tuple(number<kKPerBlock / kKPack>{}, number<MLdsLayer>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
q_lds_block_desc_k0_mldslayer_m_k1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return q_lds_block_desc;
}
else
{
static_assert(kKVector % kKPack == 0);
constexpr auto q_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kMPerBlock>{},
number<kKPack>{}),
make_tuple(number<kMPerBlock * kKVector + kKPack>{},
number<kMPerBlock * kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
q_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kKPack>{}))),
make_tuple(sequence<2>{}, sequence<0, 1, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return q_lds_block_desc;
};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramSingleRepMTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKVector = GetAlignmentQ<Problem>();
constexpr index_t OtherK = kKPerBlock / kKVector;
if constexpr(kKPerBlock == Problem::BlockFmhaShape::kSubQKHeaddim)
// for kKPerBlock=32,64,128,256
{
static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");
constexpr index_t KPerThread = kKVector;
constexpr index_t KThreads = OtherK;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
// for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
else // for kKPerBlock=96,160
{
static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!");
// ToDo: need more consideration for hdim72
constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5;
constexpr index_t KThreads = OtherK / KRepPerThread;
static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!");
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
sequence<KRepPerThread, KThreads, kKVector>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 1>>,
sequence<1, 2, 2>,
sequence<1, 0, 2>>{});
};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKVector = GetAlignmentQ<Problem>();
constexpr index_t OtherK = kKPerBlock / kKVector;
if constexpr(kKPerBlock == Problem::BlockFmhaShape::kSubQKHeaddim)
// for kKPerBlock=32,64,128,256
{
static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");
constexpr index_t KPerThread = kKVector;
constexpr index_t KThreads = OtherK;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
// for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
else // for kKPerBlock=96,160
{
static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!");
// ToDo: need more consideration for hdim72
constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5;
constexpr index_t KThreads = OtherK / KRepPerThread;
static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!");
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
sequence<KRepPerThread, KThreads, kKVector>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 1>>,
sequence<1, 2, 2>,
sequence<1, 0, 2>>{});
};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
@@ -823,13 +585,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
sequence<1, 2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetQKBlockGemmSingleRepM()
{
return Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}) *
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
@@ -976,13 +731,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
{
return MakeQLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::QDataType);
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
{
@@ -1001,8 +749,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return max(GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(),
GetSmemSizeQ<Problem>());
return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>();
}
};

View File

@@ -65,8 +65,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
static constexpr bool kUseTrLoad = true;
static_assert(Problem::kLoadWholeQTileOnceThroughLds == true, "Check failed!");
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ =
@@ -226,11 +224,10 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
OaccBlockTileType o_acc;
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQRegTileDistribution<Problem>());
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
@@ -242,7 +239,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
{seqlen_k_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());
auto q_dram_tile = load_tile(q_dram_window);
auto q_tile = load_tile(q_dram_window);
using k_tile_type = decltype(load_tile(k_dram_window));
@@ -262,19 +259,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
// provide partition_index for the LDS tile window so that warp_id is kept in a vgpr
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
// Q tile in LDS
QDataType* q_lds_ptr = static_cast<QDataType*>(smem_ptr);
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_write_window = make_tile_window(
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
auto q_lds_read_window =
make_tile_window(q_lds,
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
{0, 0},
Policy::template MakeQRegTileDistribution<Problem>(),
partition_index);
// K tile in LDS
KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
@@ -361,18 +345,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
}();
store_tile(q_lds_write_window, q_dram_tile, partition_index);
clear_tile(o_acc);
__builtin_amdgcn_sched_barrier(0x00000001);
block_sync_lds();
auto q_tile = load_tile(q_lds_read_window);
q_tile = tile_elementwise_in(q_element_func, q_tile);
set_tile(m, -numeric<CompDataType>::infinity());
clear_tile(l);
@@ -380,13 +353,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
auto seqlen_k_curr = seqlen_k_start;
__builtin_amdgcn_sched_barrier(0x00000001);
// ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0x00000001);
using v_tile_type = decltype(load_tile(v_dram_window));
statically_indexed_array<v_tile_type, k1_loops> v_tiles;