mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
CK-UA: widen FP8 K/V async loads to dwordx4 (per-WG load-thread count)
The K/V async-load width selector (GetKVAlignmentBytes via GetAlignmentK/V) computed its dwordx4 budget from the full block (kBlockSize=512 threads), so the 4 KB FP8 prefill tile never tiled cleanly and fell back to dword. With the FA4 per-warp-group decoupling a single 4-warp group (256 thr) fills the tile by itself -> 4 KB / 256 = exactly 16 B/thr = dwordx4. Thread the load-thread count into GetAlignmentK/V as a NumWarps template param (default = shape NumWarps, so all sizing/paged/decode instantiations are byte-identical). Only the load-path callers (MakeK/VDramTileDistribution, the LDS store/load descriptors, and the kAlignmentK/V DRAM-view vector) pass the decoupled GetK/VLoadNumWarps count to unlock the wide load. Effect: global_load_lds_dword 36->9 dwordx4 and buffer_load_dword 36->9 dwordx4 (both runtime branches); VGPR 181->173; LDS/SGPR unchanged. Accuracy PASS 0% (non-causal + causal). Latency-neutral on sq8192 (kernel is memory-latency bound, not load-issue bound) but a strictly-better instruction/VGPR footprint. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -224,10 +224,18 @@ struct UnifiedAttentionPipeline
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr ck_tile::index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
// The DRAM-view vector length must match the K/V load distribution's
|
||||
// KVector, which the FA4 decoupling widens to the per-warp-group load count
|
||||
// (GetK/VLoadNumWarps). Passing the same warp count here keeps the global
|
||||
// buffer_load width in lock-step with the async-copy descriptors.
|
||||
static constexpr ck_tile::index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
kPadHeadDimQ ? 1
|
||||
: Policy::template GetAlignmentK<Problem,
|
||||
Policy::template GetKLoadNumWarps<Problem>()>();
|
||||
static constexpr ck_tile::index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
kPadHeadDimV ? 1
|
||||
: Policy::template GetAlignmentV<Problem,
|
||||
Policy::template GetVLoadNumWarps<Problem>()>();
|
||||
|
||||
static constexpr ck_tile::index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
|
||||
@@ -63,7 +63,9 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
//
|
||||
// The selector below picks dwordx4 whenever it tiles cleanly and falls
|
||||
// back to dword (matches the historical FP8 path) on the prefill tier.
|
||||
template <typename Problem, index_t ElementSizeInBytes>
|
||||
template <typename Problem,
|
||||
index_t ElementSizeInBytes,
|
||||
index_t NumLoadThreads = Problem::kBlockSize>
|
||||
CK_TILE_DEVICE static constexpr index_t GetKVAlignmentBytes()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
@@ -71,7 +73,13 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
constexpr index_t tile_elems =
|
||||
Problem::UnifiedAttentionShape::kPageBlockSize *
|
||||
Problem::UnifiedAttentionShape::kHeadDim;
|
||||
constexpr index_t block_size = Problem::kBlockSize;
|
||||
// Threads that actually cooperate on this tile's load. Default is the
|
||||
// whole block (kBlockSize), but the FA4 per-warp-group decoupling has a
|
||||
// single 4-warp group fill the tile by itself, so the width budget is
|
||||
// that group's thread count -- which lets the small FP8 prefill tile
|
||||
// (4 KB / 256 thr = 16 B/thr) finally tile cleanly at dwordx4 instead of
|
||||
// falling back to 4x as many dword loads.
|
||||
constexpr index_t block_size = NumLoadThreads;
|
||||
// KVector_elems for 16 B/lane = 16 / ElementSizeInBytes.
|
||||
// NumIssues * KVector_bytes * kBlockSize == tile_bytes,
|
||||
// so the divisibility check is tile_elems * ElementSizeInBytes
|
||||
@@ -91,23 +99,30 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
// NumWarps = the waves that cooperate on the K/V load (default = the full
|
||||
// block). The FA4 decoupling loads K/V with a single 4-warp group, so the
|
||||
// load-path callers pass that group's warp count to widen the load.
|
||||
template <typename Problem,
|
||||
ck_tile::index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps>
|
||||
CK_TILE_DEVICE static constexpr auto GetAlignmentK()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
constexpr index_t NumLoadThreads = NumWarps * get_warp_size();
|
||||
constexpr index_t MaxReadSizeInBytes =
|
||||
GetKVAlignmentBytes<Problem, sizeof(KDataType)>();
|
||||
GetKVAlignmentBytes<Problem, sizeof(KDataType), NumLoadThreads>();
|
||||
return MaxReadSizeInBytes / sizeof(KDataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem,
|
||||
ck_tile::index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps>
|
||||
CK_TILE_DEVICE static constexpr auto GetAlignmentV()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t NumLoadThreads = NumWarps * get_warp_size();
|
||||
constexpr index_t MaxReadSizeInBytes =
|
||||
GetKVAlignmentBytes<Problem, sizeof(VDataType)>();
|
||||
GetKVAlignmentBytes<Problem, sizeof(VDataType), NumLoadThreads>();
|
||||
return MaxReadSizeInBytes / sizeof(VDataType);
|
||||
}
|
||||
|
||||
@@ -156,7 +171,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
constexpr index_t NumWarps = NumWarpsOverride;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t KVector = GetAlignmentK<Problem, NumWarpsOverride>(); // this is for global load
|
||||
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
|
||||
@@ -197,7 +212,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
constexpr index_t NumWarps = NumWarpsOverride;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size(); // 64
|
||||
|
||||
constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
|
||||
constexpr index_t KVector = GetAlignmentV<Problem, NumWarpsOverride>(); // this is for global load
|
||||
// 4
|
||||
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
@@ -418,7 +433,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
[[maybe_unused]] constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t KVector = GetAlignmentK<Problem, NumWarpsOverride>(); // this is for global load
|
||||
constexpr index_t kPad =
|
||||
kKLdsPadInBytes /
|
||||
sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps.
|
||||
@@ -478,7 +493,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t KVector = GetAlignmentK<Problem, NumWarpsOverride>(); // this is for global load
|
||||
constexpr index_t kPad =
|
||||
kKLdsPadInBytes /
|
||||
sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps
|
||||
@@ -602,7 +617,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
[[maybe_unused]] constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
|
||||
constexpr index_t KVector = GetAlignmentV<Problem, NumWarpsOverride>(); // this is for global load
|
||||
constexpr index_t kPad =
|
||||
kVLdsPadInBytes /
|
||||
sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps.
|
||||
@@ -661,7 +676,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t KVector = GetAlignmentK<Problem, NumWarpsOverride>(); // this is for global load
|
||||
constexpr index_t kPad =
|
||||
kVLdsPadInBytes /
|
||||
sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps
|
||||
|
||||
Reference in New Issue
Block a user