CK-UA: freeze docs + comment cleanup (+ gated decode-ring scaffolding)

- Add an architecture README for the unified_attention kernel folder: file
  map, per-CTA work assignment, online-softmax math + scale fusion, FA4 vs
  serial-decode regimes, paged-KV tiers, split-KV, and a tuning-knobs/failed-
  experiments table. Intended as the reference for the FlyDSL port.
- pipeline.hpp: condense the ~230-line experiment-macro header into terse,
  README-backed one-liners (all 13 macro definitions preserved bit-for-bit).
- kernel.hpp: merge the duplicated/contradictory "Step D" SWA-clip comment.
- Gated multi-stage decode async-ring scaffolding (UA_DECODE_STAGES, default
  2 = bit-identical; deeper depth measured perf-neutral, decode is BW-bound).

Full matrix 263/263 PASS, 0 fail; comment-only kernel edits are
behavior-neutral (target fp8 decode shape unchanged at ~88us).

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
juuso-oskari
2026-06-17 09:14:44 +00:00
parent be398c224f
commit 382bb198eb
4 changed files with 389 additions and 171 deletions

View File

@@ -0,0 +1,168 @@
# CK Tile — Unified Attention kernel
FlashAttention-style fused attention forward for **paged and contiguous KV**,
covering **prefill and decode** in one operator, for **fp8 / bf16 / fp16** on
gfx950 (MI350/MI355). This is the kernel behind aiter's `unified_attention` op;
the test/perf/analysis tooling lives in `aiter/ua-test-scripts/` (see its README).
This document is the architecture reference for reading the kernel and for
porting it (e.g. to FlyDSL). The functional contract mirrors Triton's
`unified_attention` reference.
---
## File map
| file | role |
|---|---|
| `kernel/unified_attention_kernel.hpp` | Host `MakeKargs` / grid helpers + the device `operator()`: per-CTA index math (grid → kv_head/seq/q_block/split), Q/K/V/O DRAM window setup, mask construction, split-KV partitioning, the call into the pipeline, and the epilogue / split-KV workspace writes. |
| `pipeline/unified_attention_pipeline.hpp` | The core per-CTA attention pipeline (the hot loop). Two regimes share one file: the **FA4** 2-warp-group matrix‖softmax overlap (prefill) and the **single-warp-group serial deferred-PV** pipeline (decode). |
| `pipeline/unified_attention_pipeline_default_policy.hpp` | Compile-time policy: tile distributions, LDS descriptors, warp-gemm selection, load-width/alignment selection, smem sizing, async-ring depth, and tuning constants. |
| `pipeline/unified_attention_pipeline_problem.hpp` | Bundles dtypes + shape + traits + mask into the `Problem` type. |
| `pipeline/tile_unified_attention_shape.hpp` | Block/warp tile dims: `kBlockM`, `kBlockQ`, `kPageBlockSize`, `kHeadDim`. |
| `pipeline/tile_unified_attention_traits.hpp` | Padding + occupancy (`kBlockPerCu`) traits. |
| `pipeline/unified_attention_core_loop_scheduler.hpp` | Per-phase `sched_group_barrier` instruction-mix hints that enforce the FA4 phase overlap (kept in lock-step with the macros in the pipeline header). |
| `block/block_masking.hpp` | Causal / sliding-window (FA-style left/right) mask. |
Concrete shape/dtype **instances** (the JIT-compiled translation units) live in
`composable_kernel/example/ck_tile/42_unified_attention/instances/`, dispatched by
`unified_attention.cpp` there.
---
## Per-CTA work assignment (`kernel` `operator()`)
One CTA computes one `(kv_head, q_block, split)` tuple:
- **Decode grid** `dim3(num_kv_heads, num_seqs, num_splits)` (`gridDim.y > 1`):
direct mapping, no binary search, no padding CTAs.
- **Prefill grid** `dim3(num_kv_heads * total_num_q_blocks, 1, num_splits)`:
`blockIdx.x` is folded; a binary search over `query_start_len_ptr` recovers the
sequence, and out-of-range q-blocks early-return.
`num_queries_per_kv` (GQA ratio) is a **runtime** value: `kBlockQ_dyn = kBlockM /
num_queries_per_kv`, so one compiled binary serves MHA and any GQA-N that divides
`kBlockM`. The `kBlockM`-row MFMA tile packs `num_queries_per_kv` consecutive
query rows per KV token; when `kBlockM % num_queries_per_kv != 0` the last 12
rows spill into the next q-tile's first token (co-owned), which drives several
correctness invariants in the index math — see the comments around
`last_tile_row_q_off` and the split-KV partition.
---
## Math (online softmax)
Standard streaming softmax with running max `m`, running sum `l`, output
accumulator `o_acc`, over KV tiles of `kPageBlockSize` tokens:
- Scores `S = scale_s · (Q·Kᵀ)`, masked, then `P = exp2(S m)` (base-2).
- `scale_s` is pre-fused on the host (`MakeKargs`): `sm_scale · q_descale ·
k_descale · log2(e)`. The `log2(e)` lets the device use full/þrate `exp2`
instead of `exp`; the fp8 Q/K per-tensor descales fold in so the inner loop
carries a single scalar (matches Triton's `qk_scale = sm_scale·q_scale·k_scale`).
- The fp8 V per-tensor descale `v_descale` is **deferred** to the post-loop
`o_acc · v_descale / l` step — exact, since V is a linear factor on the
unnormalised output. Non-fp8 dtypes pass `1.0f` (free no-op).
- Output is `o_acc / l`; `lse = m + log(l)` (natural-log domain) is returned for
split-KV combine.
**PV is deferred one tile** (double-buffered on parity): the sequence per tile is
`alu1/pack(prev) → PV(prev) → QK(cur) → alu0/rowmax(cur) → D_upd/rescale`. This
is the known-correct ordering shared by both regimes.
---
## Two pipeline regimes
### FA4 (prefill, 2 warp groups, `NumWarpGroups == 2`)
FlashAttention-4-style overlap. The deferred-PV sequence is cut into two phases:
- **MATRIX** phase: `PV(k-1) + QK(k)` — matrix pipe only.
- **SOFTMAX** phase: `alu1/exp + alu0/rowmax + D_upd/rescale` — VALU/MUFU only.
The two warp groups are primed one phase apart (WG0 in MATRIX while WG1 in
SOFTMAX), so on each SIMD the matrix work of one wave hides the transcendental
work of its co-resident partner. K/V are prefetched a tile ahead into a shared
double buffer at the per-phase block barrier (issued cooperatively by all 8
warps). The `core_loop_scheduler` hints reserve the per-phase instruction mix.
For fp8 the QK-C and PV-A per-thread layouts diverge (PV is forced to
`WGAttrNumAccess::Single`), so after packing, `P` is round-tripped through an LDS
window in canonical (M,N) order and reloaded in the PV-A distribution.
### Serial (decode, single warp group, `kFA4 == false`)
The same deferred-PV pipeline run serially by one 4-warp group, with K/V
double-buffered in LDS. Decode is **HBM-bandwidth bound** at long context; see
`aiter/ua-test-scripts/decode_pipeline_*.md`.
---
## Paged KV
When `kIsPaged`, KV tokens are resolved through `block_tables` (per-sequence page
lists). Performance hinges on keeping page-index resolution off the critical
path, via tiers selected at compile time:
- **Constexpr page size** (`kPageSize_ > 0`, the `ps16/ps32/ps64/ps128`
instances): strength-reduces every `/ % * page_size` to shifts and enables the
exact tier gates below. `kPageSize_ == 0` is the runtime-page-size catch-all.
- **Scalar-promote / single-page SRD rebase**: when a wave's load lands within one
page, fold the page base into the buffer SRD once per wave and drop the per-lane
block-table path.
- **Tier-2 LDS-resident page-table cache** (`kPageTableLdsEntries = 4096`, 16 KiB):
resolves the multi-page fallback's page indices from LDS instead of per-lane
global reads. Coverage `≤ 4096 · page_size` tokens; beyond that the kernel
traps (a runtime fallback was measured 30% from register pressure).
When `!kIsPaged` (contiguous/THD) the logical KV token index *is* its physical
row (per-sequence base folded into the K/V pointer), so all paging math compiles
out.
## Split-KV
`num_splits > 1` partitions the KV range across CTAs in `blockIdx.z`; each writes
fp32 `o_acc`/`lse` workspaces that a separate combine kernel merges. The
partition is computed over the **causal-independent full-sequence** block count
(not the per-tile causal horizon) so a token co-owned by adjacent q-tiles maps to
the same split in both — otherwise non-dividing-GQA + causal races on the shared
token. `num_splits == 1` skips this path entirely.
---
## Tuning knobs & failed experiments
Active defaults (do not change without re-validating correctness + perf):
| macro / policy | default | effect |
|---|---|---|
| `CONDITIONAL_RESCALE` | `1` | FA4-only: carry accumulators in a committed-max frame and skip the online rescale while shifted scores stay ≤ `τ` (`CONDITIONAL_RESCALE_TAU = 8`). Mathematically exact. |
| `UA_FA4_PREFETCH_IN_SOFTMAX` | `1` | bf16/fp16: issue next-tile K/V async prefetch from the SOFTMAX phase (keeps MATRIX pure-matrix). |
| `GetKVAlignmentBytes` | dwordx4 where it tiles | Widen fp8 decode K/V async loads to 16 B/lane (the narrow 4 B/lane default was the main fp8-slower-than-bf16 decode regression). |
| `kKFallbackLds` (`UA_K_FALLBACK_LDS`) | `1` | Resolve the multi-page K fallback (ps16/ps32) page indices via the LDS cache instead of per-lane global reads. |
| `UA_DECODE_STAGES` | `2` | Decode async-ring depth (deeper buffering was measured perf-neutral; decode is BW-bound). |
Experiments kept **OFF** (measured losers — retained as one-line gates so the
rationale isn't relitigated): `UA_FA4_PACKED_SHIFT`, `UA_FA4_PACKED_ALU1_RESCALE`
(together ~3% on canonical fp8 prefill — softmax is hidden under the overlap),
`UA_FA4_PACKED_ROWSUM` (13%, serial chain beats the log-depth tree),
`UA_DYNAMIC_SETPRIO`, `MOVE_FMHA_MASK_TO_COMPUTE` (fp8 +8.8% regression),
`MOVE_FMHA_MASK_TO_GEMM1`, `UA_FA4_PIN_PACK_IN_SOFTMAX`. `UA_FA4_EXP2_APPROX`
(Schraudolph 2^x) is an *approximation* — off by default, validate accuracy first.
---
## Building & testing
The kernel is consumed via aiter's JIT module `module_unified_attention`. After
editing any file here you **must** force a fresh build — the `.so` is not rebuilt
automatically:
```bash
# from the aiter repo root
rm -rf aiter/jit/build/module_unified_attention aiter/jit/module_unified_attention.so
AITER_REBUILD=1 HIP_VISIBLE_DEVICES=2 python3 op_tests/test_unified_attention_ck.py --full
```
See `aiter/ua-test-scripts/README.md` for correctness/perf, and
`aiter/ua-test-scripts/analysis/README.md` for ISA/VGPR/overlap-trace tooling
(including the JIT-free standalone driver, which stamps every build so a stale
binary can never be measured).

View File

@@ -660,27 +660,16 @@ struct UnifiedAttentionKernel
return FmhaMask{cur_batch_query_len, seq_len};
}();
// Step D: Sliding-Window-Attention tile-range clip.
// The per-pixel mask check inside the pipeline already returns the
// correct (zeroed) score for tokens outside the SWA window, so
// skipping this block is correctness-preserving. The point of the
// clip is to skip entire KV tiles that fall completely outside the
// window — for long-context decode that's the difference between
// O(seq_k / kPageBlockSize) and O(window / kPageBlockSize)
// iterations. The intersection with the current split's
// [num_blocks_start, num_blocks) is taken so split-KV stays correct.
// Step D: Sliding-Window-Attention KV-tile clip.
//
// This is REQUIRED for correctness, not just an optimisation. The
// online-softmax pipeline interleaves `m` / `l` updates with prefetch
// and warp-group barriers; an all-(-inf) tile (one wholly outside the
// SWA window) feeds NaN/garbage into the `m` accumulator at the
// barrier boundary, corrupting subsequent tiles. Skipping these
// tiles entirely keeps every iterated tile either fully-inside the
// window or a true edge tile that the per-pixel mask can clean up.
//
// The intersection with the current split's [num_blocks_start,
// num_blocks) is taken so split-KV stays correct.
// Step D: Sliding-Window-Attention KV-tile clip — skip entire KV tiles
// that fall completely outside the window. REQUIRED for correctness, not
// just an optimisation: the online-softmax pipeline interleaves m/l
// updates with prefetch and warp-group barriers, so an all-(-inf) tile
// (wholly outside the window) would feed NaN/garbage into the m
// accumulator at a barrier boundary and corrupt later tiles. For
// long-context decode it is also the difference between
// O(seq_k/kPageBlockSize) and O(window/kPageBlockSize) iterations. The
// intersection with the current split's [num_blocks_start, num_blocks)
// keeps split-KV correct.
if constexpr(FmhaMask::IsMasking && FmhaMask::IsLocal)
{
const auto sw_range = mask.GetTileRangeAlongX(

View File

@@ -3,180 +3,94 @@
#pragma once
// EXPERIMENT: move the FA4 V LDS transpose-read out of the SOFTMAX phase and
// into the MATRIX phase (right before its PV consumer). The default ("Stage B")
// issues the V ds_read in the *preceding* softmax to hide its latency under the
// softmax VALU — but the ATT trace shows that read stalls ~69% and sits on the
// (longer/critical) softmax phase. Moving it to MATRIX, which has barrier slack,
// takes it off the critical path. =1 to enable.
// ===========================================================================
// Tuning knobs & experiment toggles.
// See the "Tuning knobs & failed experiments" table in this folder's README.md
// for the rationale and measured results behind each default. These MUST be set
// BEFORE including unified_attention_core_loop_scheduler.hpp: that header's
// per-phase __builtin_amdgcn_sched_group_barrier hints are gated on the same
// macros and must stay in lock-step with the code motion in this file.
// ===========================================================================
// OFF (measured loser): pin the fp32->fp8 P-pack tail of fmha_alu1 to the
// SOFTMAX phase instead of letting it sink into the next MATRIX slot.
#ifndef UA_FA4_PIN_PACK_IN_SOFTMAX
// Experiment (option 3): fence the fp32->fp8 P-pack (cvt_pk_fp8 tail of
// fmha_alu1) so it retires inside the SOFTMAX phase instead of sinking past the
// matrix barrier into the next MATRIX slot. In the MATRIX slot the pack (a VALU
// op) contends with the co-resident warp group's softmax VALU (the v_max3
// rowmax tree) on the shared SIMD issue port; pinning it to SOFTMAX trades that
// cross-wave contention for in-phase exposure on the (longer) softmax phase.
#define UA_FA4_PIN_PACK_IN_SOFTMAX 0
#endif
// UA_FA4_PREFETCH_IN_SOFTMAX: issue the next-tile K/V async DRAM prefetch from
// the SOFTMAX phase instead of the MATRIX phase. Gated to the 2-byte (bf16/fp16)
// path in the loop (fp8 keeps the matrix-phase prefetch it was tuned for). The
// bf16 prefetch is double the VMEM bytes and its buffer_load-to-LDS issue was
// landing in the (lgkm-stalled) MATRIX phase; moving the *issue* to the VALU-
// bound SOFTMAX phase keeps MATRIX pure-matrix. Residency is unchanged: the next
// MATRIX still drains the load via s_waitcnt_vmcnt<0> + block barrier before any
// K/V LDS read, so this only moves WHERE the async load is kicked off.
// ON (bf16/fp16 only): issue the next-tile K/V async prefetch from the SOFTMAX
// phase so the MATRIX phase stays pure-matrix. Residency is unchanged (the next
// MATRIX still drains the load before any K/V LDS read); this only moves WHERE
// the load is kicked off. fp8 keeps the matrix-phase prefetch it was tuned for.
#ifndef UA_FA4_PREFETCH_IN_SOFTMAX
#define UA_FA4_PREFETCH_IN_SOFTMAX 1
#endif
// FMHA_MASK PLACEMENT: pick exactly one of:
// - both 0 → baseline (mask in K-side memory phase, W0-3 phase 1
// / W4-7 phase 2, right after `cl_load(memK)`).
// - MOVE_FMHA_MASK_TO_COMPUTE=1: hoist mask onto the compute phase
// (W0-3 phase 0 / W4-7 phase 1), right after `fmha_alu1`.
// Experiment 1.5 finding: bf16 0.33%, **fp8 +8.8% regression**
// because the FP8 cvt+bperm cluster inside `fmha_alu1` makes the
// compute phase already-saturated; adding T_mask oversubscribes
// it and the empirical cost is ~2× the bare instruction count.
// - MOVE_FMHA_MASK_TO_GEMM1=1: place mask at the START of the
// gemm1 phase (W0-3 phase 2 / W4-7 phase 3), right before
// `cl_calc(xdl_SP_p23_reg_idx, gemm1)`. This is the latest legal
// placement: `cl_calc(p23, gemm1)` ends with `fmha_alu0(p01_idx)`
// which READS `sp[p01_idx].sp_compute` to compute `m_latest`, so
// mask MUST run before that. Phase 3 (V-mem on W0-3, gemm1 on
// W4-7) is too late and silently corrupts the row-max.
//
// For W4-7 the `++i_total_loops` also defers from end of phase 2
// to start of phase 3 (after mask, before cl_calc) so mask sees
// the same i_total_loops value as gemm0 of this iter.
//
// Per-barrier algebra (mask added to gemm1 phase = T_D on both
// warp groups, removed from K-mem = T_K on both):
// - B1 wait = |T_C (T_D + T_mask)|. With baseline T_C > T_D
// on FP8, the gap closes — DECREASES by T_mask.
// - B2 wait = |(T_K T_mask) T_C| — DECREASES by T_mask.
// - B3 wait = |(T_D + T_mask) (T_K T_mask)|
// = |T_D T_K + 2·T_mask| — DECREASES by 2·T_mask.
// - Net: 4·T_mask total wait (vs 2·T_mask for compute), and
// gemm1 phase has no FP8 cvt+bperm so it should absorb the
// mask without the FP8 oversubscription that hit compute.
//
// Must be defined BEFORE including unified_attention_core_loop_scheduler.hpp
// — that header's `__builtin_amdgcn_sched_group_barrier` per-phase
// hints are gated on these macros and need to stay in lockstep with
// the code motion in this file.
// OFF (measured losers): alternative placements of the FMHA mask within the FA4
// phases. Default (both 0) = mask in the K-side memory phase. _TO_COMPUTE
// oversubscribes the fp8 compute phase (+8.8% regression); _TO_GEMM1 is the
// latest legal placement (must precede the fmha_alu0 that reads sp_compute) but
// did not win. At most one may be 1.
#define MOVE_FMHA_MASK_TO_COMPUTE 0
#define MOVE_FMHA_MASK_TO_GEMM1 0
#if MOVE_FMHA_MASK_TO_COMPUTE && MOVE_FMHA_MASK_TO_GEMM1
#error "MOVE_FMHA_MASK_TO_COMPUTE and MOVE_FMHA_MASK_TO_GEMM1 are mutually exclusive"
#endif
// UA_DYNAMIC_SETPRIO (warp-group-balance plan A2, HipKittens-style)
// 0 (default): static per-warp-group priority, set once at loop entry
// (W0-3 → s_setprio(0), W4-7 → s_setprio(1)). Baseline, bit-identical.
// 1: dynamic priority around the gemm MFMA cluster. `cl_calc` raises
// s_setprio(1) for the duration of the gemm (QK/PV MFMAs + trailing
// fmha_alu0) and drops back to s_setprio(0) after. The two warp groups
// are offset by two phases and co-resident (one wave of each group per
// SIMD), so the group currently in the compute cluster outbids the
// group currently issuing memory for the shared VALU/MFMA issue port —
// targeting the ARBITER_NOT_WIN stall that gates the compute side
// (W0-3: 37.8% of its stalls). Under the macro the static W4-7=1 entry
// is neutralised to 0 so the non-compute baseline is uniformly prio 0.
// OFF (measured loser): dynamic s_setprio around the gemm MFMA cluster (vs the
// default static per-warp-group priority set once at loop entry).
#ifndef UA_DYNAMIC_SETPRIO
#define UA_DYNAMIC_SETPRIO 0
#endif
// UA_FA4_PACKED_SHIFT: emit the softmax score-shift (sp_delta = sp_compute *
// scale_s - scale_s * rowmax) as packed v_pk_fma_f32 (2 f32/instr) instead of 64
// scalar v_fma_f32. Bit-identical: each thread holds one rowmax
// (m.thread_buf_.size()==1) so the FMA addend is uniform across the thread's
// score elements and is broadcast into both packed lanes. Mirrors the hand-tuned
// ASM softmax (v_pk_fma_f32 for the rebase). Halves the shift instruction count.
//
// MEASURED REGRESSION, default OFF. Together with UA_FA4_PACKED_ALU1_RESCALE this
// costs ~3% on the canonical fp8 prefill shape (GPU2, same-session 3-run median:
// packed 1825 TF/s vs scalar 1877 TF/s). The softmax score-shift is hidden under
// the ping-pong overlap, so collapsing the FMAs does not shorten the critical
// path; it only perturbs the scheduler and loses. (An earlier "+4.5%" reading was
// a confounded GPU0 measurement.) Kept gated off for documentation; do not enable.
// OFF (measured ~-3% together with UA_FA4_PACKED_ALU1_RESCALE): emit the softmax
// score-shift as packed v_pk_fma_f32 instead of scalar v_fma_f32. Bit-identical,
// but the shift is hidden under the ping-pong overlap so collapsing it only
// perturbs the scheduler and loses.
#ifndef UA_FA4_PACKED_SHIFT
#define UA_FA4_PACKED_SHIFT 0
#endif
// UA_FA4_EXP2_APPROX: replace the per-element softmax exp (quarter-rate
// v_exp_f32) with the Schraudolph 2^x bit-trick (full-rate). The score-shift FMA
// in fmha_alu0 absorbs the 2^23 scale and the Schraudolph bias, so fmha_alu1 only
// needs a single v_cvt_u32_f32 per element instead of v_exp_f32. The per-row
// max-delta rescale keeps the exact v_exp_f32 (only 1/row, not on the hot path).
// This is an APPROXIMATION (~1e-3 rel error per element) -- it mirrors the ASM
// softmax fast SKU and is only applied on the non-masked, no-softcap path
// (compile-time gate below). Numerics-changing => default OFF; validate accuracy
// before enabling.
// OFF (approximation, ~1e-3 rel error): replace the quarter-rate v_exp_f32
// softmax with the full-rate Schraudolph 2^x bit-trick. Numerics-changing ->
// validate accuracy before enabling; only legal on the non-masked, no-softcap
// path (compile-time gate below).
#ifndef UA_FA4_EXP2_APPROX
#define UA_FA4_EXP2_APPROX 0
#endif
// Schraudolph 2^x bit-trick constants: bits = round(2^23 * x + (127*2^23 - 486411))
// reinterpreted as f32 ~= 2^x. Min-error offset matches the hand-tuned ASM softmax.
// Schraudolph 2^x constants: bits = round(2^23*x + (127*2^23 - 486411)) as f32.
#define UA_EXP2_SCHRAUDOLPH_SCALE 8388608.0f // 2^23
#define UA_EXP2_SCHRAUDOLPH_BIAS 1064866805.0f // 127*2^23 - 486411
// UA_FA4_PACKED_ROWSUM: reduce a thread's row-sum of probabilities with packed
// v_pk_add_f32 into a 2-wide partial, then a single scalar combine, instead of the
// scalar v_add_f32 dependency chain that block_tile_reduce emits. Halves the
// in-thread adds and shortens the latency chain feeding the cross-lane permlane.
// Reassociates the sum (rounding differs at the ULP level) -- safe within the fp8
// /bf16 attention tolerances.
//
// MEASURED LOSER (-13%): a serial v_pk_add_f32 accumulation is a 32-deep latency
// chain that is WORSE than block_tile_reduce's log-depth tree, and the dead scalar
// reduce is not DCE'd. Kept gated off for documentation; do not enable.
// OFF (measured -13%): reduce a thread's probability row-sum with packed
// v_pk_add_f32. The serial packed chain is deeper-latency than block_tile_reduce's
// log-depth tree, and the dead scalar reduce is not DCE'd.
#ifndef UA_FA4_PACKED_ROWSUM
#define UA_FA4_PACKED_ROWSUM 0
#endif
// UA_FA4_PACKED_ALU1_RESCALE: pack the 6-register o_acc partial rescale in
// fmha_alu1 (elementwise *= o_acc_scale) with v_pk_mul_f32, matching the packed
// rescale in fmha_alu_D_upd. Independent elementwise scale (no dependency chain),
// and it halves the number of asm-volatile scheduling boundaries (6 scalar
// v_mul_f32 -> 3 v_pk_mul_f32). Bit-identical.
//
// MEASURED REGRESSION, default OFF -- see the note on UA_FA4_PACKED_SHIFT above:
// the two together cost ~3% (1877 -> 1825 TF/s, GPU2) because the rescale is
// hidden under the ping-pong overlap. (An earlier "+4%" reading was a confounded
// GPU0 measurement.) Kept gated off for documentation; do not enable.
// OFF (measured ~-3% together with UA_FA4_PACKED_SHIFT): pack the o_acc partial
// rescale in fmha_alu1 with v_pk_mul_f32. Bit-identical, but the rescale is hidden
// under the overlap so it does not shorten the critical path.
#ifndef UA_FA4_PACKED_ALU1_RESCALE
#define UA_FA4_PACKED_ALU1_RESCALE 0
#endif
// CONDITIONAL_RESCALE (PLAN_conditional_rescale Part 2)
// 0 (default): always-rescale online softmax — the o_acc/l accumulators are
// renormalised to the true running max `m` every KV tile (the expensive
// 128-VGPR `v_pk_mul_f32` rescale tail in fmha_alu_D_upd + the 6-reg
// partial in fmha_alu1). Bit-identical to the pre-Part-2 kernel.
// 1: FA4-style conditional (skipped) rescale. Carry the accumulators in the
// frame of a *committed* max `m_commit` that only advances (with a rescale)
// when the true max pulls ahead by more than τ = log2 of the safe exp2
// bound. Between commits the shifted scores stay ≤ τ (exp2 ≤ 2^τ, fp32-
// safe) so o_acc/l just accumulate — the rescale multiplies are skipped.
// The decision is made wave-uniformly (ballot: rescale if ANY lane needs
// it) so the guard is a scalar branch with no divergence. Mathematically
// exact (the m_commit frame cancels in o_acc/l; LSE uses m_commit), so no
// end-of-loop correction is needed. Only applied on the 2-warp-group
// prefill path (see kCondRescale); decode keeps always-rescale. Part-1's
// --headroom instrument predicts ~85% (prefill) of rescales are skippable.
// Defined BEFORE the includes so unified_attention_core_loop_scheduler.hpp can
// gate its per-phase sched_group_barrier VALU hints on it (the gemm1+D_upd
// phase reserves ~36 VALU slots for the rescale tail that this skips).
// ON (FA4 prefill only): conditional (skipped) online-softmax rescale. Carry the
// o_acc/l accumulators in the frame of a committed max `m_commit` that advances
// (with a rescale) only when the true max pulls ahead by more than τ; between
// commits the shifted scores stay <= τ (exp2 <= 2^τ, fp32-safe) so the rescale
// multiplies are skipped. The decision is wave-uniform (ballot) so it is a scalar
// branch with no divergence. Mathematically exact (m_commit cancels in o_acc/l;
// LSE uses m_commit). Decode keeps always-rescale (see kCondRescale). Defined
// before the includes so the core-loop scheduler can drop the rescale's VALU slot
// reservation in lock-step.
#if !defined(CONDITIONAL_RESCALE)
#define CONDITIONAL_RESCALE 1
#endif
// τ in scaled-logit (log2) units. exp2(τ) bounds the un-rescaled scores; 8 =>
// max intermediate exp2 == 256, comfortably inside fp32 range even summed over
// thousands of keys. FA4 uses the same log2(256)=8.
// τ in scaled-logit (log2) units; exp2(τ) bounds the un-rescaled scores. 8 =>
// max intermediate exp2 == 256, fp32-safe even summed over thousands of keys.
#if !defined(CONDITIONAL_RESCALE_TAU)
#define CONDITIONAL_RESCALE_TAU 8.0f
#endif
@@ -666,6 +580,12 @@ struct UnifiedAttentionPipeline
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// Design A deep async ring depth (decode: kDecodeStages; FA4: 2). The K
// buffers occupy LDS slots [0, kRingStages); the V buffers follow at
// [kRingStages, 2*kRingStages) -- the V store descriptor's K-buffer base
// (KBufCount template arg) must therefore equal kRingStages too.
constexpr index_t kRingStages = Policy::template GetRingStages<Problem>();
constexpr index_t KStoreWarpShift = Policy::template GetKStoreWarpShift<Problem>();
auto k_lds_window_store = generate_tuple(
[&](auto i_buf) {
@@ -675,15 +595,17 @@ struct UnifiedAttentionPipeline
KLoadNumWarps,
KStoreWarpShift>(i_buf));
},
number<2>{});
number<kRingStages>{});
auto v_lds_window_store = generate_tuple(
[&](auto i_buf) {
return make_lds_tile_window<KDataType>(
smem_ptr,
Policy::template MakeVLdsStoreBlockDescriptor<Problem, VLoadNumWarps>(i_buf));
Policy::template MakeVLdsStoreBlockDescriptor<Problem,
VLoadNumWarps,
kRingStages>(i_buf));
},
number<2>{});
number<kRingStages>{});
statically_indexed_array<
decltype(make_tile_window(
@@ -691,7 +613,7 @@ struct UnifiedAttentionPipeline
nullptr,
Policy::template MakeKLdsLoadBlockDescriptor<Problem, KLoadNumWarps>()),
Policy::template MakeKRegTileDistribution<Problem>())),
2>
kRingStages>
k_lds_window_load;
statically_indexed_array<
@@ -700,7 +622,7 @@ struct UnifiedAttentionPipeline
nullptr,
Policy::template MakeVLdsLoadBlockDescriptor<Problem, VLoadNumWarps>()),
Policy::template MakeVRegTileDistribution<Problem>())),
2>
kRingStages>
v_lds_window_load;
decltype(make_static_distributed_tensor<QDataType>(
@@ -783,8 +705,10 @@ struct UnifiedAttentionPipeline
bool need_rescale = true;
#endif
// initialize k_lds_window and v_lds_window
static_for<0, 2, 1>{}([&](auto idx) {
// initialize k_lds_window and v_lds_window. K buffers occupy LDS slots
// [0, kRingStages); V buffers follow at [kRingStages, 2*kRingStages),
// matching the V store descriptor's KBufCount==kRingStages base.
static_for<0, kRingStages, 1>{}([&](auto idx) {
k_lds_window_load(idx) = make_tile_window(
make_lds_tile_window<KDataType>(
static_cast<char*>(smem_ptr) + (idx)*Policy::template GetSmemSizeKV<Problem>(),
@@ -792,11 +716,12 @@ struct UnifiedAttentionPipeline
Policy::template MakeKRegTileDistribution<Problem>());
});
static_for<0, 2, 1>{}([&](auto idx) {
static_for<0, kRingStages, 1>{}([&](auto idx) {
v_lds_window_load(idx) =
make_tile_window(make_lds_tile_window<VDataType>(
static_cast<char*>(smem_ptr) +
(idx + 2) * Policy::template GetSmemSizeKV<Problem>(),
(idx + kRingStages) *
Policy::template GetSmemSizeKV<Problem>(),
Policy::template MakeVLdsLoadBlockDescriptor<Problem,
VLoadNumWarps>()),
Policy::template MakeVRegTileDistribution<Problem>());
@@ -2967,6 +2892,12 @@ struct UnifiedAttentionPipeline
gemm(xdl_SP_p23_reg_idx, /*gemm_idx=*/number<1>{});
};
// Set by the N-deep decode ring (kRingStages>2) when it finalizes the
// last tile's deferred PV *inline* (it must, because that tile's V lives
// in LDS ring slot r%kRingStages, not the parity slot fmha_post_process
// assumes). Suppresses the shared post-process below in that case.
bool decode_ring_final_pv_done = false;
// pre-stage
{
ASM_MARKER("before pre-stage");
@@ -3020,12 +2951,19 @@ struct UnifiedAttentionPipeline
goto label_main_loops_exit;
}
if(2 < num_total_loop)
// K2 prefetch into buf0 (the freed K0 slot). Only the 2-buffer
// pipeline does this here; the N-deep ring (kRingStages>2) fills
// K2..K(N-1) into their own distinct slots inside the decode
// block below, so it must NOT pre-issue K2 into buf0.
if constexpr(kRingStages == 2)
{
K_mem_load(number<0>{}); // mem_K2
if(2 < num_total_loop)
{
K_mem_load(number<0>{}); // mem_K2
s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
__builtin_amdgcn_s_barrier();
s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
__builtin_amdgcn_s_barrier();
}
}
ASM_MARKER("end pre-stage");
@@ -3043,6 +2981,8 @@ struct UnifiedAttentionPipeline
// K1 in LDS (K buf 1) if num_total_loop >= 2
// K2 loading to LDS (K buf 0) if num_total_loop >= 3
if constexpr(kRingStages == 2)
{
// Step 1: consume V0, K1 -> produce PV(0), QK(1)
s_waitcnt_vmcnt<0>();
__builtin_amdgcn_s_barrier();
@@ -3119,6 +3059,93 @@ struct UnifiedAttentionPipeline
fmha_alu_D_upd();
i_total_loops++;
}
} // end if constexpr(kRingStages == 2)
else
{
// ===== Design A: N-deep async KV ring (decode) =====
// The pre-stage left: K_r0 consumed, K_r1 -> slot 1,
// V_r0 -> slot 0, and sp(0)=QK(r0) (alu1 pending). Fill the
// rest of the ring so K_r1..K_r(N-1) and V_r0..V_r(N-2) are
// all in flight (N-1 of each); slot(tile t) = t % N. Loads
// issued past the split end harmlessly reload the last tile
// into slots that are never consumed (K/V_mem_load stop
// advancing their DRAM window at the end), so no bound
// guards are needed and the in-flight count -- hence the
// vmcnt threshold below -- stays constant for every stage.
static_for<2, kRingStages, 1>{}(
[&](auto m) { K_mem_load(number<m % kRingStages>{}); });
static_for<1, kRingStages - 1, 1>{}(
[&](auto m) { V_mem_load(number<m % kRingStages>{}); });
// With the ring full, only the two oldest in-flight tiles
// (the K_r / V_(r-1) consumed this stage) must be resident;
// the other N-2 of each stay in flight -> deeper HBM overlap
// than the 2-buffer path's full vmcnt<0> drain. (N==2 would
// reduce to vmcnt<0>, i.e. the original; that case is handled
// by the branch above.)
constexpr index_t kRingVmcnt =
(kRingStages - 2) * (K_mem_su_ld_insts + V_mem_su_ld_insts);
// Unroll by N so every LDS ring slot is a compile-time index.
// Stage s processes relative tile r = r_base + s, r_base ≡ 1
// (mod N): read K_r from slot (s+1)%N and V_(r-1) from slot s;
// refill the two slots freed last stage (K_(r-1) slot s,
// V_(r-2) slot (s-1)%N) at the top -- different LDS buffers
// than the ones read below, so the async prefetch has no WAR
// and overlaps this stage's PV+QK. sp keeps its 2-way
// deferred-PV parity (N is even).
while(i_total_loops < num_total_loop)
{
static_for<0, kRingStages, 1>{}([&](auto s) {
constexpr index_t k_rd = (s + 1) % kRingStages;
constexpr index_t v_rd = s % kRingStages;
constexpr index_t k_pf = s % kRingStages;
constexpr index_t v_pf = (s + kRingStages - 1) % kRingStages;
constexpr index_t sp_pv = s % 2;
constexpr index_t sp_qk = (s + 1) % 2;
if(i_total_loops < num_total_loop)
{
s_waitcnt_vmcnt<kRingVmcnt>();
__builtin_amdgcn_s_barrier();
slide_page_table();
K_mem_load(number<k_pf>{}); // K_(r+N-1) -> freed K slot
V_mem_load(number<v_pf>{}); // V_(r+N-2) -> freed V slot
V_lds_load(number<v_rd>{}); // V_(r-1)
s_waitcnt_lgkmcnt<0>();
fmha_alu1(number<sp_pv>{}); // finalize sp(r-1) -> P
gemm(number<sp_pv>{}, /*gemm_idx=*/number<1>{}); // PV
K_lds_load(number<k_rd>{}); // K_r
s_waitcnt_lgkmcnt<0>();
gemm(number<sp_qk>{}, /*gemm_idx=*/number<0>{}); // QK -> sp(r)
fmha_mask(number<sp_qk>{});
fmha_alu0(number<sp_qk>{});
fmha_alu_D_upd();
i_total_loops++;
// Last tile: its PV would be deferred to a stage
// that never runs. Finalize it now -- V_r lives in
// the compile-time slot (s+1)%N (== k_rd). Drain
// first so V_r (left in flight by the staged wait)
// is resident; this also tells the shared epilogue
// not to re-run the (parity-indexed) post-process.
if(i_total_loops >= num_total_loop)
{
s_waitcnt_vmcnt<0>();
__builtin_amdgcn_s_barrier();
V_lds_load(number<k_rd>{}); // V_r
s_waitcnt_lgkmcnt<0>();
fmha_alu1(number<sp_qk>{});
gemm(number<sp_qk>{}, /*gemm_idx=*/number<1>{}); // PV
decode_ring_final_pv_done = true;
}
}
});
}
}
}
else
{
@@ -3203,8 +3230,11 @@ struct UnifiedAttentionPipeline
if(!(num_iters % 2))
fa4_epi(number<1>{});
}
else
else if(!decode_ring_final_pv_done)
{
// (The N-deep decode ring already finalized the last tile's PV
// inline -- see decode_ring_final_pv_done -- because its V sits in
// an LDS ring slot, not the parity slot this post-process assumes.)
if(num_iters % 2)
{
fmha_post_process(number<1>{});

View File

@@ -687,6 +687,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
// alone tile the full V buffer. Default = the shape's NumWarps (cooperative).
template <typename Problem,
ck_tile::index_t NumWarpsOverride = Problem::UnifiedAttentionShape::NumWarps,
ck_tile::index_t KBufCount = 2,
ck_tile::index_t IBuf = 0>
CK_TILE_DEVICE static constexpr auto
MakeVLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
@@ -728,7 +729,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
number<WarpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<(IBuf + 2) * GetSingleSmemElementSpaceSize<Problem>()>{},
number<(IBuf + KBufCount) * GetSingleSmemElementSpaceSize<Problem>()>{},
number<KVector>{},
number<1>{});
@@ -855,6 +856,33 @@ struct UnifiedAttentionPipelineDefaultPolicy
#endif
static constexpr bool kFA4WG1LoadsK = UA_FA4_WG1_LOADS_K;
// Design A (decode deep async ring). Number of K/V LDS landing buffers for
// the single-warp-group (decode) path: raising it from 2 keeps N-1 KV-tile
// DRAM loads in flight so the loop can stage `vmcnt` partial waits instead of
// a full per-tile drain (the memory-bound long-context decode regime). Must
// be EVEN (the deferred-PV score double-buffer keeps 2-way parity, which is
// only compile-time resolvable across the N-unroll when N is even) and >= 2.
// Default 2 == the original 2-buffer serial pipeline, bit-identical. The
// 2-warp-group FA4 (prefill) path always uses 2 regardless. LDS cost is
// 2*N*GetSmemSizeKV, so larger N may cost occupancy on the LDS-bound decode
// tiers -- sweep per tier.
#ifndef UA_DECODE_STAGES
#define UA_DECODE_STAGES 2
#endif
static constexpr ck_tile::index_t kDecodeStages = UA_DECODE_STAGES;
static_assert(kDecodeStages >= 2 && (kDecodeStages % 2 == 0),
"UA_DECODE_STAGES must be an even integer >= 2");
// Ring depth actually used by a given kernel instance: the deep ring is a
// decode-only (single-warp-group) lever; the FA4 prefill path keeps 2.
template <typename Problem>
CK_TILE_DEVICE static constexpr ck_tile::index_t GetRingStages()
{
constexpr ck_tile::index_t NumWarpGroups =
Problem::kBlockSize / NumThreadPerWarpGroup;
return (NumWarpGroups == 1) ? kDecodeStages : 2;
}
// Number of waves that cooperate on a V DRAM->LDS load. For the 2-warp-group
// FA4 path with kFA4WG0LoadsV, this is one warp group's waves (so WG0 alone
// fills the tile); otherwise it's the full block (original cooperative load).
@@ -898,7 +926,10 @@ struct UnifiedAttentionPipelineDefaultPolicy
template <typename Problem>
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return 4 * GetSmemSizeKV<Problem>();
// kRingStages K buffers + kRingStages V buffers. Decode uses
// kDecodeStages (>=2), FA4 prefill stays 2 -> the default (N==2)
// reproduces the original 4*GetSmemSizeKV budget exactly.
return 2 * GetRingStages<Problem>() * GetSmemSizeKV<Problem>();
}
};