CK-UA: widen FP8 K/V async loads to dwordx4 (per-warp-group load-thread count)

The K/V async-load width selector (GetKVAlignmentBytes via GetAlignmentK/V)
computed its dwordx4 budget from the full block (kBlockSize=512 threads), so
the 4 KB FP8 prefill tile (4 KB / 512 = 8 B/thread) never tiled cleanly at
16 B/thread and fell back to dword. With the FA4 per-warp-group decoupling, a
single 4-warp group (256 threads) fills the tile by itself, so
4 KB / 256 = exactly 16 B/thread = dwordx4.

Thread the load warp count into GetAlignmentK/V as a NumWarps template param
(default = shape NumWarps, so all sizing/paged/decode instantiations are
byte-identical). GetAlignmentK/V convert it to a thread count
(NumWarps * get_warp_size()) and forward that to GetKVAlignmentBytes as the
new NumLoadThreads param. Only the load-path callers
(MakeK/VDramTileDistribution, the LDS store/load descriptors, and the
kAlignmentK/V DRAM-view vector) pass the decoupled GetK/VLoadNumWarps count to
unlock the wide load.

Effect: 36 global_load_lds_dword -> 9 dwordx4 and 36 buffer_load_dword -> 9
dwordx4 (both runtime branches); VGPR 181->173; LDS/SGPR unchanged. Accuracy
PASS 0% (non-causal + causal). Latency-neutral on sq8192 (the kernel is
memory-latency bound, not load-issue bound), but with a strictly better
instruction/VGPR footprint.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
juuso-oskari
2026-06-09 16:40:21 +00:00
parent c11722bf3e
commit a4d3ff34fb
2 changed files with 37 additions and 14 deletions

View File

@@ -224,10 +224,18 @@ struct UnifiedAttentionPipeline
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr ck_tile::index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
// The DRAM-view vector length must match the K/V load distribution's
// KVector, which the FA4 decoupling widens to the per-warp-group load count
// (GetK/VLoadNumWarps). Passing the same warp count here keeps the global
// buffer_load width in lock-step with the async-copy descriptors.
static constexpr ck_tile::index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
kPadHeadDimQ ? 1
: Policy::template GetAlignmentK<Problem,
Policy::template GetKLoadNumWarps<Problem>()>();
static constexpr ck_tile::index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
kPadHeadDimV ? 1
: Policy::template GetAlignmentV<Problem,
Policy::template GetVLoadNumWarps<Problem>()>();
static constexpr ck_tile::index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();

View File

@@ -63,7 +63,9 @@ struct UnifiedAttentionPipelineDefaultPolicy
//
// The selector below picks dwordx4 whenever it tiles cleanly and falls
// back to dword (matches the historical FP8 path) on the prefill tier.
template <typename Problem, index_t ElementSizeInBytes>
template <typename Problem,
index_t ElementSizeInBytes,
index_t NumLoadThreads = Problem::kBlockSize>
CK_TILE_DEVICE static constexpr index_t GetKVAlignmentBytes()
{
#if defined(__gfx950__)
@@ -71,7 +73,13 @@ struct UnifiedAttentionPipelineDefaultPolicy
constexpr index_t tile_elems =
Problem::UnifiedAttentionShape::kPageBlockSize *
Problem::UnifiedAttentionShape::kHeadDim;
constexpr index_t block_size = Problem::kBlockSize;
// Threads that actually cooperate on this tile's load. Default is the
// whole block (kBlockSize), but the FA4 per-warp-group decoupling has a
// single 4-warp group fill the tile by itself, so the width budget is
// that group's thread count -- which lets the small FP8 prefill tile
// (4 KB / 256 thr = 16 B/thr) finally tile cleanly at dwordx4 instead of
// falling back to 4x as many dword loads.
constexpr index_t block_size = NumLoadThreads;
// KVector_elems for 16 B/lane = 16 / ElementSizeInBytes.
// NumIssues * KVector_bytes * block_size == tile_bytes,
// so the divisibility check is tile_elems * ElementSizeInBytes
@@ -91,23 +99,30 @@ struct UnifiedAttentionPipelineDefaultPolicy
#endif
}
template <typename Problem>
// NumWarps = the waves that cooperate on the K/V load (default = the full
// block). The FA4 decoupling loads K/V with a single 4-warp group, so the
// load-path callers pass that group's warp count to widen the load.
template <typename Problem,
ck_tile::index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps>
CK_TILE_DEVICE static constexpr auto GetAlignmentK()
{
using namespace ck_tile;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr index_t NumLoadThreads = NumWarps * get_warp_size();
constexpr index_t MaxReadSizeInBytes =
GetKVAlignmentBytes<Problem, sizeof(KDataType)>();
GetKVAlignmentBytes<Problem, sizeof(KDataType), NumLoadThreads>();
return MaxReadSizeInBytes / sizeof(KDataType);
}
template <typename Problem>
template <typename Problem,
ck_tile::index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps>
CK_TILE_DEVICE static constexpr auto GetAlignmentV()
{
using namespace ck_tile;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t NumLoadThreads = NumWarps * get_warp_size();
constexpr index_t MaxReadSizeInBytes =
GetKVAlignmentBytes<Problem, sizeof(VDataType)>();
GetKVAlignmentBytes<Problem, sizeof(VDataType), NumLoadThreads>();
return MaxReadSizeInBytes / sizeof(VDataType);
}
@@ -156,7 +171,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
constexpr index_t NumWarps = NumWarpsOverride;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t KVector = GetAlignmentK<Problem, NumWarpsOverride>(); // this is for global load
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
@@ -197,7 +212,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
constexpr index_t NumWarps = NumWarpsOverride;
constexpr index_t WarpSize = ck_tile::get_warp_size(); // 64
constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
constexpr index_t KVector = GetAlignmentV<Problem, NumWarpsOverride>(); // this is for global load
// 4
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
@@ -418,7 +433,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
constexpr index_t WarpSize = ck_tile::get_warp_size();
[[maybe_unused]] constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t KVector = GetAlignmentK<Problem, NumWarpsOverride>(); // this is for global load
constexpr index_t kPad =
kKLdsPadInBytes /
sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps.
@@ -478,7 +493,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t KVector = GetAlignmentK<Problem, NumWarpsOverride>(); // this is for global load
constexpr index_t kPad =
kKLdsPadInBytes /
sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps
@@ -602,7 +617,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
constexpr index_t WarpSize = ck_tile::get_warp_size();
[[maybe_unused]] constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
constexpr index_t KVector = GetAlignmentV<Problem, NumWarpsOverride>(); // this is for global load
constexpr index_t kPad =
kVLdsPadInBytes /
sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps.
@@ -661,7 +676,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t KVector = GetAlignmentK<Problem, NumWarpsOverride>(); // this is for global load
constexpr index_t kPad =
kVLdsPadInBytes /
sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps