CK-UA: decouple K/V DRAM loads per warp-group + early V read (FA4 fp8)

Split the cooperative K/V cache loads across the two FA4 warp groups so
each group owns exactly one tile's DRAM load and address arithmetic:
WG0 loads V, WG1 loads K, and both read from the shared LDS buffers.

- kFA4WG0LoadsV / kFA4WG1LoadsK policy flags + GetVLoadNumWarps /
  GetKLoadNumWarps: the owning group's 4 waves alone fill the tile via
  4-warp descriptors; the partner skips the load and reads from LDS.
- High-warp-group support for the raw async path: the raw store bakes the
  absolute warp id into the LDS M0, so WG1 (waves 4-7) needs a base shift
  (GetKStoreWarpShift / WarpIdShift in MakeKLdsStoreBlockDescriptor) to
  map back to the 4-warp layout, plus WG-relative (warp % NumWarps) page
  offsets so the gather token positions are correct.
- Stage B: move each tile's V LDS read into the PRECEDING softmax phase so
  the read latency hides under softmax VALU. Safe because V is now single-
  group-owned; uses drain-before-barrier (vmcnt<0> then s_barrier) so all
  4 cooperating writer waves' slices are published before the read.
- Gate per-tile offset refresh per warp-group (WG0 refreshes V, WG1 K), so
  each wave fetches a block-table page index for one tile instead of both;
  loop counters stay uniform.

Validated 0% mismatch vs GPU reference, causal + non-causal, sq 256..8192.
Net latency vs the cooperative baseline: causal ~3-4.6% lower, non-causal
~2-4.7% lower, across sq 2048..16384 (d128 fp8).

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
juuso-oskari
2026-06-09 14:35:00 +00:00
parent 26bc49f733
commit c11722bf3e
2 changed files with 248 additions and 91 deletions

View File

@@ -488,6 +488,19 @@ struct UnifiedAttentionPipeline
const index_t warp_group_id = get_warp_id() / 4;
// FA4 "WG0 loads V": warp group 0's 4 waves load the FULL V tile into
// the shared V LDS buffer (V descriptors use VLoadNumWarps == 4 waves);
// warp group 1 skips the V DRAM load and relies on the inter-phase
// barrier for residency. v_load_active gates the async V load issue.
constexpr index_t VLoadNumWarps = Policy::template GetVLoadNumWarps<Problem>();
constexpr index_t KLoadNumWarps = Policy::template GetKLoadNumWarps<Problem>();
constexpr index_t NumWarpGroups_ = Problem::kBlockSize / Policy::NumThreadPerWarpGroup;
const bool v_load_active =
(!Policy::kFA4WG0LoadsV) || (NumWarpGroups_ != 2) || (warp_group_id == 0);
// Symmetric: warp group 1 alone loads K (WG0 reads from shared LDS).
const bool k_load_active =
(!Policy::kFA4WG1LoadsK) || (NumWarpGroups_ != 2) || (warp_group_id == 1);
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetPVBlockGemm<Problem>();
@@ -531,34 +544,41 @@ struct UnifiedAttentionPipeline
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
constexpr index_t KStoreWarpShift = Policy::template GetKStoreWarpShift<Problem>();
auto k_lds_window_store = generate_tuple(
[&](auto i_buf) {
return make_lds_tile_window<KDataType>(
smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf));
smem_ptr,
Policy::template MakeKLdsStoreBlockDescriptor<Problem,
KLoadNumWarps,
KStoreWarpShift>(i_buf));
},
number<2>{});
auto v_lds_window_store = generate_tuple(
[&](auto i_buf) {
return make_lds_tile_window<KDataType>(
smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor<Problem>(i_buf));
smem_ptr,
Policy::template MakeVLdsStoreBlockDescriptor<Problem, VLoadNumWarps>(i_buf));
},
number<2>{});
statically_indexed_array<decltype(make_tile_window(
make_lds_tile_window<KDataType>(
nullptr,
Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
Policy::template MakeKRegTileDistribution<Problem>())),
2>
statically_indexed_array<
decltype(make_tile_window(
make_lds_tile_window<KDataType>(
nullptr,
Policy::template MakeKLdsLoadBlockDescriptor<Problem, KLoadNumWarps>()),
Policy::template MakeKRegTileDistribution<Problem>())),
2>
k_lds_window_load;
statically_indexed_array<decltype(make_tile_window(
make_lds_tile_window<VDataType>(
nullptr,
Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
Policy::template MakeVRegTileDistribution<Problem>())),
2>
statically_indexed_array<
decltype(make_tile_window(
make_lds_tile_window<VDataType>(
nullptr,
Policy::template MakeVLdsLoadBlockDescriptor<Problem, VLoadNumWarps>()),
Policy::template MakeVRegTileDistribution<Problem>())),
2>
v_lds_window_load;
decltype(make_static_distributed_tensor<QDataType>(
@@ -619,7 +639,7 @@ struct UnifiedAttentionPipeline
k_lds_window_load(idx) = make_tile_window(
make_lds_tile_window<KDataType>(
static_cast<char*>(smem_ptr) + (idx)*Policy::template GetSmemSizeKV<Problem>(),
Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
Policy::template MakeKLdsLoadBlockDescriptor<Problem, KLoadNumWarps>()),
Policy::template MakeKRegTileDistribution<Problem>());
});
@@ -628,7 +648,8 @@ struct UnifiedAttentionPipeline
make_tile_window(make_lds_tile_window<VDataType>(
static_cast<char*>(smem_ptr) +
(idx + 2) * Policy::template GetSmemSizeKV<Problem>(),
Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
Policy::template MakeVLdsLoadBlockDescriptor<Problem,
VLoadNumWarps>()),
Policy::template MakeVRegTileDistribution<Problem>());
});
@@ -714,8 +735,8 @@ struct UnifiedAttentionPipeline
// a large negative relative offset that the HW OOB check clamps to 0).
// A robust fix would either plumb long_index_t through the gather load
// path or compute a per-batch min-page shift in a pre-pass.
const auto k_dist = Policy::template MakeKDramTileDistribution<Problem>();
const auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
const auto k_dist = Policy::template MakeKDramTileDistribution<Problem, KLoadNumWarps>();
const auto v_dist = Policy::template MakeVDramTileDistribution<Problem, VLoadNumWarps>();
using KDstrType = decltype(k_dist);
using VDstrType = decltype(v_dist);
constexpr index_t KNRepeat =
@@ -729,8 +750,19 @@ struct UnifiedAttentionPipeline
VDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<1>{}] *
VDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<2>{}];
const auto k_thread_coord = k_dist.calculate_index();
const auto v_thread_coord = v_dist.calculate_index();
// WG-relative warp index for the gather page-offset computation. When a
// single warp group loads a tile (V by WG0 / K by WG1), only that
// group's waves issue the load, and their absolute warp ids must be
// folded back into [0, NumWarps) so k_thread_n_pos / v_thread_n_pos (the
// per-wave token position baked into page_idx) match the group-relative
// distribution. For the cooperative case NumWarps == full block, so the
// modulo is the identity. The scatter-gather's own get_partition_index
// use is harmless here: the gather (token) dim is zeroed and replaced by
// page_idx, and the remaining (head-dim) coordinate is lane-based.
const auto k_part = ck_tile::array<index_t, 2>{get_warp_id() % KLoadNumWarps, get_lane_id()};
const auto v_part = ck_tile::array<index_t, 2>{get_warp_id() % VLoadNumWarps, get_lane_id()};
const auto k_thread_coord = k_dist.calculate_index(k_part);
const auto v_thread_coord = v_dist.calculate_index(v_part);
const index_t k_thread_n_pos = k_thread_coord[number<0>{}];
const index_t v_thread_n_pos = v_thread_coord[number<0>{}];
@@ -1456,12 +1488,24 @@ struct UnifiedAttentionPipeline
// num_blocks_start.
const index_t num_iters_per_split = num_total_loop - num_blocks_start;
auto K_mem_load = [&](auto k_lds_write_idx) {
if(cache_ptr_int32_overflow_possible)
async_load_tile_raw_long(k_lds_window_store(k_lds_write_idx), k_dram_window);
else
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
// FA4 "WG1 loads K": only warp group 1's waves issue the K async load
// (its KLoadNumWarps==4 layout fills the full shared K tile). WG0
// skips it and reads K from shared LDS (barrier-synchronized).
if(k_load_active)
{
if(cache_ptr_int32_overflow_possible)
async_load_tile_raw_long(k_lds_window_store(k_lds_write_idx), k_dram_window);
else
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
}
k_block_idx++;
if(k_block_idx < num_iters_per_split)
// Only the K-loading warp group needs K offsets refreshed: with
// kFA4WG1LoadsK, WG0 never issues a K load, so computing its page
// offsets (incl. the block-table page-index ds_read) is pure waste.
// k_block_idx itself stays uniform across all waves so loop control
// and buffer parity never diverge. Gating here also means each wave
// fetches a page-table index for exactly ONE tile (K *or* V), not both.
if(k_load_active && k_block_idx < num_iters_per_split)
{
refresh_k_offsets(k_block_idx);
if constexpr(kRebaseKSrd)
@@ -1474,12 +1518,22 @@ struct UnifiedAttentionPipeline
};
auto V_mem_load = [&](auto v_lds_write_idx) {
if(cache_ptr_int32_overflow_possible)
async_load_tile_raw_long(v_lds_window_store(v_lds_write_idx), v_dram_window);
else
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
// FA4 "WG0 loads V": only warp group 0's waves issue the V async
// load (its VLoadNumWarps==4 layout fills the full shared V tile).
// WG1 skips the load; v_block_idx still advances on every wave so the
// loop's scalar state never diverges (only the V offset refresh below is
// gated to the loading warp group).
if(v_load_active)
{
if(cache_ptr_int32_overflow_possible)
async_load_tile_raw_long(v_lds_window_store(v_lds_write_idx), v_dram_window);
else
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
}
v_block_idx++;
if(v_block_idx < num_iters_per_split)
// Symmetric to K: only the V-loading warp group (WG0) refreshes V
// offsets; WG1 skips it (it never issues a V load). v_block_idx stays
// uniform for loop/buffer bookkeeping.
if(v_load_active && v_block_idx < num_iters_per_split)
{
refresh_v_offsets(v_block_idx);
if constexpr(kRebaseVSrd)
@@ -2165,41 +2219,23 @@ struct UnifiedAttentionPipeline
auto gemm1 = number<1>{};
// MATRIX phase: deferred PV(k-1) then QK(k). Pure matrix pipe.
// Consumes V(1-pi) / K(pi) resident in LDS; the union kv_tile holds
// v_tile for the PV then is overwritten with k_tile for the QK
// (V_lds → gemm1 → K_lds → gemm0 ordering).
// V-read hoist: issue the PV gemm's LDS→register read (v_rd == pi)
// EARLY — before the matrix phase's compute — so its ~LDS-latency
// overlaps the address-calc VALU of prefetch() (WG0) / the barrier
// exit (WG1) instead of being exposed right before the PV MFMA. The
// V buffer pi was populated by a prior prefetch and already waited
// on (vmcnt<0> at phase entry); prefetch writes K-buf[pi]/V-buf[1-pi]
// so there is no aliasing with this read.
// V-read hoist: issue the PV gemm's LDS→register read (v_rd == pi)
// EARLY — before the matrix phase's compute — so its ~LDS-latency
// overlaps the address-calc VALU of prefetch() (WG0) / the barrier
// exit (WG1) instead of being exposed right before the PV MFMA. The
// V buffer pi was populated by a prior prefetch and already waited
// on (vmcnt<0> at phase entry); prefetch writes K-buf[pi]/V-buf[1-pi]
// so there is no aliasing with this read.
// Consumes V(pi) / K(1-pi) resident in LDS; kv_tile holds v_tile for
// the PV and (separately) k_tile for the QK.
//
// NOTE: do NOT also hoist K_lds_load here. The QK gemm reads K-buf
// [1-pi], and in the WG1 softmax-first prologue that buffer is not
// yet guaranteed resident this early (its async load completes a
// phase later) — hoisting K corrupts long-context runs. K stays
// issued between the PV and QK MFMAs (its latency hides under PV).
// V-read hoist: issue the PV gemm's LDS→register read (v_rd == pi)
// EARLY — before the matrix phase's compute — so its ~LDS-latency
// overlaps the address-calc VALU of prefetch() (WG0) / the barrier
// exit (WG1) instead of being exposed right before the PV MFMA.
// V-read into SOFTMAX (Stage B): the PV gemm's V tile (v_rd == pi) is
// now read in the *preceding* SOFTMAX phase rather than at the top of
// this MATRIX phase, so its ~LDS latency overlaps the full softmax
// VALU (exp / rowsum / P-cvt) instead of only the prefetch address
// calc. This is safe now that V is loaded by a single warp group
// (kFA4WG0LoadsV): WG0's waves drain their own V loads (vmcnt<0>) and
// then cross the phase barrier, which publishes every writer wave's
// slice without depending on the partner group; WG1 reads an already-
// barrier-published V buffer. The pre-read lands in v_tile and this
// MATRIX phase consumes it directly (see fa4_softmax / the WG0 prime).
//
// NOTE: do NOT also hoist K_lds_load. K/V are loaded COOPERATIVELY
// by both warp groups; a wave's vmcnt only drains its OWN async
// loads, not the partner group's half, so a cooperatively-filled
// buffer is reliably resident only deeper into the phase. The PV
// gemm provides exactly that slack for the K read — moving K ahead
// of PV races the partner's load completion and corrupts long
// contexts. K stays issued between the PV and QK MFMAs.
// NOTE: do NOT hoist K_lds_load the same way. The QK gemm reads K-buf
// [1-pi] which the partner group (WG1) loads; WG0 has no own-vmcnt
// proof of its residency this early, only the barrier deeper in the
// phase. K stays issued between the PV and QK MFMAs (latency under PV).
auto fa4_vload = [&](auto pi) { V_lds_load(pi); };
auto fa4_matrix = [&](auto pi) {
@@ -2207,14 +2243,15 @@ struct UnifiedAttentionPipeline
auto qk_sp = number<1>{} - pi; // QK target slot
auto k_rd = number<1>{} - pi;
s_waitcnt_lgkmcnt<0>(); // wait the hoisted fa4_vload(pi)
s_waitcnt_lgkmcnt<0>(); // wait the V pre-read issued in prev SOFTMAX
gemm(pv_sp, gemm1); // o_acc += P(pi) @ V(k-1)
// K read into its OWN registers (kv_tile no longer a union), so
// this ds_read executes on the LSU *during* the PV MFMA above
// instead of waiting for it to retire. The sched_barrier pins it
// here (program order) so it is NOT hoisted above the PV gemm --
// that would race the partner WG's cooperative K load and
// corrupt long contexts (the residency hazard documented above).
// K read into its OWN registers (k_tile no longer aliases v_tile),
// so this ds_read executes on the LSU *during* the PV MFMA above
// rather than waiting for it to retire; the sched_barriers pin it
// here. K is now single-warp-group loaded (kFA4WG1LoadsK) so it is
// resident at the slot-A barrier, but issuing the read AFTER the
// PV gemm call (overlapping the in-flight MFMA) schedules strictly
// better than hoisting it ahead of PV — measured ~3-4% faster.
__builtin_amdgcn_sched_barrier(0);
K_lds_load(k_rd); // overlaps the PV MFMA (latency hidden)
__builtin_amdgcn_sched_barrier(0);
@@ -2256,20 +2293,32 @@ struct UnifiedAttentionPipeline
if constexpr(cl_p == 0)
{
// ---- slot A: MATRIX(pi) ‖ (WG1: SOFTMAX) ----
// V tile (buf pi) was pre-read into v_tile in the previous
// SOFTMAX phase (or the WG0 prime for the first tile).
ASM_MARKER("fa4 MATRIX Wave0-3");
s_waitcnt_vmcnt<0>();
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
fa4_vload(pi); // hoisted V read; latency hidden under prefetch
prefetch();
fa4_matrix(pi);
// ---- slot B: SOFTMAX(pi) ‖ (WG1: MATRIX) ----
// Pre-read the next MATRIX's V tile (buf 1-pi == the buffer
// this iteration's prefetch just filled), overlapping the
// softmax VALU below; v_tile survives into the next slot-A
// MATRIX (PV consumes it via lgkmcnt<0>). The V buffer is
// filled cooperatively by WG0's 4 waves, so a wave reads
// slices written by its peers: drain the load (vmcnt<0>) and
// then cross the phase barrier so all 4 waves' writes are
// published BEFORE the read (drain-before-barrier; reading
// after only an own-vmcnt races the peers' slices).
ASM_MARKER("fa4 SOFTMAX Wave0-3");
s_waitcnt_vmcnt<0>();
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
fa4_vload(number<1>{} - pi);
fa4_softmax(pi);
if(num_total_loop <= ++i_total_loops)
@@ -2281,20 +2330,28 @@ struct UnifiedAttentionPipeline
// WG1 is one phase ahead (primed by the FA4 prologue): it
// softmaxes the tile it QK'd in its previous MATRIX phase
// while WG0 runs the MATRIX of the same tile.
//
// Pre-read this iteration's slot-B MATRIX V tile (buf pi).
// That buffer was filled by WG0 and already drained+published
// by WG0's drain-before-barrier in its prior SOFTMAX slot, so
// the slot-A barrier just crossed guarantees all 4 writer
// waves' slices are visible. The read overlaps the softmax
// VALU below; v_tile survives into the slot-B MATRIX.
ASM_MARKER("fa4 SOFTMAX Wave4-7");
s_waitcnt_vmcnt<0>();
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
fa4_vload(pi);
prefetch();
fa4_softmax(number<1>{} - pi);
// ---- slot B: MATRIX(pi) ‖ (WG0: SOFTMAX) ----
// v_tile holds buf pi from the slot-A pre-read above.
ASM_MARKER("fa4 MATRIX Wave4-7");
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
fa4_vload(pi); // hoisted V read (overlaps barrier exit)
fa4_matrix(pi);
if(num_total_loop <= ++i_total_loops)
@@ -2508,6 +2565,11 @@ struct UnifiedAttentionPipeline
fmha_alu0(number<0>{});
fmha_alu_D_upd();
fmha_alu1(number<0>{}); // sp(0).p = P(0)
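// Prime v_tile for the first MATRIX(0): V buf 0 was loaded by WG0's
// waves in the pre-stage; vmcnt<0> drains this wave's own slice of it.
// NOTE(review): the in-loop pre-reads additionally cross a barrier
// before reading slices written by peer waves -- confirm that an earlier
// prologue barrier publishes the other WG0 waves' slices of V buf 0
// before this read.
// (Stage B reads each subsequent tile's V in the prior SOFTMAX.)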
s_waitcnt_vmcnt<0>();
V_lds_load(number<0>{});
while(core_loop_fa4(number<0>{}))
;
}

View File

@@ -141,7 +141,11 @@ struct UnifiedAttentionPipelineDefaultPolicy
return 16 / sizeof(VDataType);
}
template <typename Problem>
// NumWarpsOverride mirrors MakeVDramTileDistribution: the FA4 "WG1 loads K"
// path passes NumThreadPerWarpGroup/WarpSize so warp group 1's waves alone
// tile the full K buffer (the partner group reads it from shared LDS).
template <typename Problem,
ck_tile::index_t NumWarpsOverride = Problem::UnifiedAttentionShape::NumWarps>
CK_TILE_DEVICE static constexpr auto MakeKDramTileDistribution()
{
using namespace ck_tile;
@@ -149,7 +153,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t NumWarps = NumWarpsOverride;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
@@ -158,7 +162,8 @@ struct UnifiedAttentionPipelineDefaultPolicy
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
static_assert(NumIssues == kNPerBlock * kKPerBlock / (NumWarps * WarpSize * KVector));
static_cast<void>(kBlockSize);
constexpr index_t N0 = NumIssues;
constexpr index_t N1 = LaneGroups;
@@ -175,7 +180,13 @@ struct UnifiedAttentionPipelineDefaultPolicy
sequence<0, 1>>{});
}
template <typename Problem>
// NumWarpsOverride lets the FA4 "WG0 loads V" path request a distribution
// where only NumWarpsOverride waves cooperate on the load, so a single warp
// group fills the FULL V tile in the shared V LDS buffer by itself (the
// partner group skips the DRAM load and reads the tile from LDS).
// Default = the shape's NumWarps (the original block-cooperative load).
template <typename Problem,
ck_tile::index_t NumWarpsOverride = Problem::UnifiedAttentionShape::NumWarps>
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
{
using namespace ck_tile;
@@ -183,7 +194,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t NumWarps = NumWarpsOverride;
constexpr index_t WarpSize = ck_tile::get_warp_size(); // 64
constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
@@ -193,7 +204,10 @@ struct UnifiedAttentionPipelineDefaultPolicy
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
// NumWarps-relative form (NumWarps may be < the full block when the FA4
// "WG0 loads V" path requests a single-warp-group distribution).
static_assert(NumIssues == kNPerBlock * kKPerBlock / (NumWarps * WarpSize * KVector));
static_cast<void>(kBlockSize);
constexpr index_t N0 = NumIssues; // 8
constexpr index_t N1 = LaneGroups; // 2
@@ -378,7 +392,19 @@ struct UnifiedAttentionPipelineDefaultPolicy
static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4; // 4 dwords
static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords
template <typename Problem, ck_tile::index_t IBuf = 0>
// WarpIdShift handles a sub-block load issued by a NON-zero warp group via
// the raw async path. The raw store derives its LDS offset as
// M0 = base + size_per_wave * get_warp_id() (ABSOLUTE warp id 0..7)
// so a NumWarps-wide (e.g. 4-wave) layout only tiles correctly for warp ids
// 0..NumWarps-1. When warp group g (>0) alone fills the tile, its waves have
// absolute ids [g*NumWarps, (g+1)*NumWarps); shifting the descriptor base by
// -WarpIdShift*size_per_wave (WarpIdShift = g*NumWarps) maps them back to
// effective ids 0..NumWarps-1, i.e. the exact physical layout a warp-group-0
// load would produce -- so the (unshifted) read descriptor reads it directly.
template <typename Problem,
ck_tile::index_t NumWarpsOverride = Problem::UnifiedAttentionShape::NumWarps,
ck_tile::index_t WarpIdShift = 0,
ck_tile::index_t IBuf = 0>
CK_TILE_DEVICE static constexpr auto
MakeKLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
{
@@ -388,7 +414,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t NumWarps = NumWarpsOverride;
constexpr index_t WarpSize = ck_tile::get_warp_size();
[[maybe_unused]] constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
@@ -405,7 +431,8 @@ struct UnifiedAttentionPipelineDefaultPolicy
WarpSize /
LanesPerK; // how many groups (within a wave), they may load different N, but same K
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
static_assert(NumIssues == kNPerBlock * kKPerBlock / (NumWarps * WarpSize * KVector));
static_cast<void>(kBlockSize);
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
@@ -418,7 +445,8 @@ struct UnifiedAttentionPipelineDefaultPolicy
number<WarpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
number<IBuf * GetSingleSmemElementSpaceSize<Problem>() -
WarpIdShift*(WarpSize * KVector + kPad)>{},
number<KVector>{},
number<1>{});
@@ -436,7 +464,8 @@ struct UnifiedAttentionPipelineDefaultPolicy
return k_lds_block_desc_issues_warps_lanes;
}
template <typename Problem>
template <typename Problem,
ck_tile::index_t NumWarpsOverride = Problem::UnifiedAttentionShape::NumWarps>
CK_TILE_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor()
{
using namespace ck_tile;
@@ -445,7 +474,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t NumWarps = NumWarpsOverride;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
@@ -458,7 +487,8 @@ struct UnifiedAttentionPipelineDefaultPolicy
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
static_assert(NumIssues == kNPerBlock * kKPerBlock / (NumWarps * WarpSize * KVector));
static_cast<void>(kBlockSize);
constexpr auto k_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumIssues>{}, // n0
@@ -553,7 +583,12 @@ struct UnifiedAttentionPipelineDefaultPolicy
return max(max(SingleKSize, SingleVSize), VLoadDescSize);
}
template <typename Problem, ck_tile::index_t IBuf = 0>
// NumWarpsOverride mirrors MakeVDramTileDistribution: the FA4 "WG0 loads V"
// path passes NumThreadPerWarpGroup/WarpSize (== 4) so warp group 0's waves
// alone tile the full V buffer. Default = the shape's NumWarps (cooperative).
template <typename Problem,
ck_tile::index_t NumWarpsOverride = Problem::UnifiedAttentionShape::NumWarps,
ck_tile::index_t IBuf = 0>
CK_TILE_DEVICE static constexpr auto
MakeVLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
{
@@ -563,7 +598,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t NumWarps = NumWarpsOverride;
constexpr index_t WarpSize = ck_tile::get_warp_size();
[[maybe_unused]] constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
@@ -580,7 +615,8 @@ struct UnifiedAttentionPipelineDefaultPolicy
WarpSize /
LanesPerK; // how many groups (within a wave), they may load different N, but same K
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
static_assert(NumIssues == kNPerBlock * kKPerBlock / (NumWarps * WarpSize * KVector));
static_cast<void>(kBlockSize);
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
@@ -611,7 +647,8 @@ struct UnifiedAttentionPipelineDefaultPolicy
return v_lds_block_desc_issues_warps_lanes;
}
template <typename Problem>
template <typename Problem,
ck_tile::index_t NumWarpsOverride = Problem::UnifiedAttentionShape::NumWarps>
CK_TILE_DEVICE static constexpr auto MakeVLdsLoadBlockDescriptor()
{
using namespace ck_tile;
@@ -620,7 +657,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t NumWarps = NumWarpsOverride;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
@@ -633,7 +670,8 @@ struct UnifiedAttentionPipelineDefaultPolicy
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
static_assert(NumIssues == kNPerBlock * kKPerBlock / (NumWarps * WarpSize * KVector));
static_cast<void>(kBlockSize);
constexpr auto v_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumIssues>{}, // n0
@@ -688,6 +726,63 @@ struct UnifiedAttentionPipelineDefaultPolicy
return kv_element_space_size_in_bytes;
}
// FA4 "WG0 loads V": when the block runs as two warp groups, have
// ONLY warp group 0 (waves 0-3) load the full V tile into the shared V LDS
// buffer (V's DRAM dist + LDS descriptors use NumThreadPerWarpGroup/WarpSize
// == 4 waves so WG0 alone fills the tile). WG1 skips the V DRAM load
// entirely. No 2x DRAM, no extra LDS (V stays a shared 2-buffer). This
// decouples V's residency from the partner group's cooperative-load shard
// (only WG0's own waves write the tile) so the V LDS read can move into
// the SOFTMAX phase. K's load ownership is controlled separately by
// kFA4WG1LoadsK.
// Toggle to false to restore the block-cooperative (8-wave) V load.
static constexpr bool kFA4WG0LoadsV = true;
// Symmetric K decoupling: warp group 1 (waves 4-7) alone loads the full K
// tile into the shared K LDS buffer; warp group 0 reads it from shared LDS.
// Together with kFA4WG0LoadsV this balances DRAM-load work (WG0->V, WG1->K)
// and lets each group issue only one tile's load/address instructions.
// Toggle to false to restore the block-cooperative K load (GetKLoadNumWarps
// then returns the shape's NumWarps and GetKStoreWarpShift returns 0).
static constexpr bool kFA4WG1LoadsK = true;
// Wave count that participates in one V DRAM->LDS load.
// When the block is split into exactly two warp groups and kFA4WG0LoadsV is
// enabled, a single warp group's waves fill the whole tile; in every other
// configuration the load stays cooperative over the shape's NumWarps.
template <typename Problem>
CK_TILE_DEVICE static constexpr ck_tile::index_t GetVLoadNumWarps()
{
    constexpr ck_tile::index_t kWarpGroupCount = Problem::kBlockSize / NumThreadPerWarpGroup;
    if constexpr(kWarpGroupCount == 2 && kFA4WG0LoadsV)
    {
        // One warp group's worth of waves.
        return NumThreadPerWarpGroup / ck_tile::get_warp_size();
    }
    else
    {
        // Original block-cooperative load.
        return Problem::UnifiedAttentionShape::NumWarps;
    }
}
// Wave count that participates in one K DRAM->LDS load (counterpart of
// GetVLoadNumWarps). With exactly two warp groups and kFA4WG1LoadsK enabled
// this is one warp group's waves; otherwise it is the shape's NumWarps.
template <typename Problem>
CK_TILE_DEVICE static constexpr ck_tile::index_t GetKLoadNumWarps()
{
    constexpr ck_tile::index_t kWarpGroupCount = Problem::kBlockSize / NumThreadPerWarpGroup;
    if constexpr(kWarpGroupCount == 2 && kFA4WG1LoadsK)
    {
        // One warp group's worth of waves.
        return NumThreadPerWarpGroup / ck_tile::get_warp_size();
    }
    else
    {
        // Original block-cooperative load.
        return Problem::UnifiedAttentionShape::NumWarps;
    }
}
// Warp-id shift applied to the K LDS store descriptor for the raw async path
// (consumed as WarpIdShift by MakeKLdsStoreBlockDescriptor).
// With exactly two warp groups and kFA4WG1LoadsK enabled, K is loaded by warp
// group 1, whose absolute warp ids begin one warp group's worth of waves in,
// so the shift equals that wave count. In every other configuration no shift
// is applied.
template <typename Problem>
CK_TILE_DEVICE static constexpr ck_tile::index_t GetKStoreWarpShift()
{
    constexpr ck_tile::index_t kWarpGroupCount = Problem::kBlockSize / NumThreadPerWarpGroup;
    if constexpr(kWarpGroupCount == 2 && kFA4WG1LoadsK)
    {
        // First absolute warp id of warp group 1.
        return NumThreadPerWarpGroup / ck_tile::get_warp_size();
    }
    else
    {
        return 0;
    }
}
template <typename Problem>
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{