diff --git a/include/ck_tile/ops/unified_attention/README.md b/include/ck_tile/ops/unified_attention/README.md new file mode 100644 index 0000000000..576cefda76 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/README.md @@ -0,0 +1,168 @@ +# CK Tile — Unified Attention kernel + +FlashAttention-style fused attention forward for **paged and contiguous KV**, +covering **prefill and decode** in one operator, for **fp8 / bf16 / fp16** on +gfx950 (MI350/MI355). This is the kernel behind aiter's `unified_attention` op; +the test/perf/analysis tooling lives in `aiter/ua-test-scripts/` (see its README). + +This document is the architecture reference for reading the kernel and for +porting it (e.g. to FlyDSL). The functional contract mirrors Triton's +`unified_attention` reference. + +--- + +## File map + +| file | role | +|---|---| +| `kernel/unified_attention_kernel.hpp` | Host `MakeKargs` / grid helpers + the device `operator()`: per-CTA index math (grid → kv_head/seq/q_block/split), Q/K/V/O DRAM window setup, mask construction, split-KV partitioning, the call into the pipeline, and the epilogue / split-KV workspace writes. | +| `pipeline/unified_attention_pipeline.hpp` | The core per-CTA attention pipeline (the hot loop). Two regimes share one file: the **FA4** 2-warp-group matrix‖softmax overlap (prefill) and the **single-warp-group serial deferred-PV** pipeline (decode). | +| `pipeline/unified_attention_pipeline_default_policy.hpp` | Compile-time policy: tile distributions, LDS descriptors, warp-gemm selection, load-width/alignment selection, smem sizing, async-ring depth, and tuning constants. | +| `pipeline/unified_attention_pipeline_problem.hpp` | Bundles dtypes + shape + traits + mask into the `Problem` type. | +| `pipeline/tile_unified_attention_shape.hpp` | Block/warp tile dims: `kBlockM`, `kBlockQ`, `kPageBlockSize`, `kHeadDim`. | +| `pipeline/tile_unified_attention_traits.hpp` | Padding + occupancy (`kBlockPerCu`) traits. | +| `pipeline/unified_attention_core_loop_scheduler.hpp` | Per-phase `sched_group_barrier` instruction-mix hints that enforce the FA4 phase overlap (kept in lock-step with the macros in the pipeline header). | +| `block/block_masking.hpp` | Causal / sliding-window (FA-style left/right) mask. | + +Concrete shape/dtype **instances** (the JIT-compiled translation units) live in +`composable_kernel/example/ck_tile/42_unified_attention/instances/`, dispatched by +`unified_attention.cpp` there. + +--- + +## Per-CTA work assignment (`kernel` `operator()`) + +One CTA computes one `(kv_head, q_block, split)` tuple: + +- **Decode grid** `dim3(num_kv_heads, num_seqs, num_splits)` (`gridDim.y > 1`): + direct mapping, no binary search, no padding CTAs. +- **Prefill grid** `dim3(num_kv_heads * total_num_q_blocks, 1, num_splits)`: + `blockIdx.x` is folded; a binary search over `query_start_len_ptr` recovers the + sequence, and out-of-range q-blocks early-return. + +`num_queries_per_kv` (GQA ratio) is a **runtime** value: `kBlockQ_dyn = kBlockM / +num_queries_per_kv`, so one compiled binary serves MHA and any GQA-N that divides +`kBlockM`. The `kBlockM`-row MFMA tile packs `num_queries_per_kv` consecutive +query rows per KV token; when `kBlockM % num_queries_per_kv != 0` the last 1–2 +rows spill into the next q-tile's first token (co-owned), which drives several +correctness invariants in the index math — see the comments around +`last_tile_row_q_off` and the split-KV partition. + +--- + +## Math (online softmax) + +Standard streaming softmax with running max `m`, running sum `l`, output +accumulator `o_acc`, over KV tiles of `kPageBlockSize` tokens: + +- Scores `S = scale_s · (Q·Kᵀ)`, masked, then `P = exp2(S − m)` (base-2). +- `scale_s` is pre-fused on the host (`MakeKargs`): `sm_scale · q_descale · + k_descale · log2(e)`. The `log2(e)` lets the device use full/þrate `exp2` + instead of `exp`; the fp8 Q/K per-tensor descales fold in so the inner loop + carries a single scalar (matches Triton's `qk_scale = sm_scale·q_scale·k_scale`). +- The fp8 V per-tensor descale `v_descale` is **deferred** to the post-loop + `o_acc · v_descale / l` step — exact, since V is a linear factor on the + unnormalised output. Non-fp8 dtypes pass `1.0f` (free no-op). +- Output is `o_acc / l`; `lse = m + log(l)` (natural-log domain) is returned for + split-KV combine. + +**PV is deferred one tile** (double-buffered on parity): the sequence per tile is +`alu1/pack(prev) → PV(prev) → QK(cur) → alu0/rowmax(cur) → D_upd/rescale`. This +is the known-correct ordering shared by both regimes. + +--- + +## Two pipeline regimes + +### FA4 (prefill, 2 warp groups, `NumWarpGroups == 2`) +FlashAttention-4-style overlap. The deferred-PV sequence is cut into two phases: +- **MATRIX** phase: `PV(k-1) + QK(k)` — matrix pipe only. +- **SOFTMAX** phase: `alu1/exp + alu0/rowmax + D_upd/rescale` — VALU/MUFU only. + +The two warp groups are primed one phase apart (WG0 in MATRIX while WG1 in +SOFTMAX), so on each SIMD the matrix work of one wave hides the transcendental +work of its co-resident partner. K/V are prefetched a tile ahead into a shared +double buffer at the per-phase block barrier (issued cooperatively by all 8 +warps). The `core_loop_scheduler` hints reserve the per-phase instruction mix. + +For fp8 the QK-C and PV-A per-thread layouts diverge (PV is forced to +`WGAttrNumAccess::Single`), so after packing, `P` is round-tripped through an LDS +window in canonical (M,N) order and reloaded in the PV-A distribution. + +### Serial (decode, single warp group, `kFA4 == false`) +The same deferred-PV pipeline run serially by one 4-warp group, with K/V +double-buffered in LDS. Decode is **HBM-bandwidth bound** at long context; see +`aiter/ua-test-scripts/decode_pipeline_*.md`. + +--- + +## Paged KV + +When `kIsPaged`, KV tokens are resolved through `block_tables` (per-sequence page +lists). Performance hinges on keeping page-index resolution off the critical +path, via tiers selected at compile time: + +- **Constexpr page size** (`kPageSize_ > 0`, the `ps16/ps32/ps64/ps128` + instances): strength-reduces every `/ % * page_size` to shifts and enables the + exact tier gates below. `kPageSize_ == 0` is the runtime-page-size catch-all. +- **Scalar-promote / single-page SRD rebase**: when a wave's load lands within one + page, fold the page base into the buffer SRD once per wave and drop the per-lane + block-table path. +- **Tier-2 LDS-resident page-table cache** (`kPageTableLdsEntries = 4096`, 16 KiB): + resolves the multi-page fallback's page indices from LDS instead of per-lane + global reads. Coverage `≤ 4096 · page_size` tokens; beyond that the kernel + traps (a runtime fallback was measured −30% from register pressure). + +When `!kIsPaged` (contiguous/THD) the logical KV token index *is* its physical +row (per-sequence base folded into the K/V pointer), so all paging math compiles +out. + +## Split-KV + +`num_splits > 1` partitions the KV range across CTAs in `blockIdx.z`; each writes +fp32 `o_acc`/`lse` workspaces that a separate combine kernel merges. The +partition is computed over the **causal-independent full-sequence** block count +(not the per-tile causal horizon) so a token co-owned by adjacent q-tiles maps to +the same split in both — otherwise non-dividing-GQA + causal races on the shared +token. `num_splits == 1` skips this path entirely. + +--- + +## Tuning knobs & failed experiments + +Active defaults (do not change without re-validating correctness + perf): + +| macro / policy | default | effect | +|---|---|---| +| `CONDITIONAL_RESCALE` | `1` | FA4-only: carry accumulators in a committed-max frame and skip the online rescale while shifted scores stay ≤ `τ` (`CONDITIONAL_RESCALE_TAU = 8`). Mathematically exact. | +| `UA_FA4_PREFETCH_IN_SOFTMAX` | `1` | bf16/fp16: issue next-tile K/V async prefetch from the SOFTMAX phase (keeps MATRIX pure-matrix). | +| `GetKVAlignmentBytes` | dwordx4 where it tiles | Widen fp8 decode K/V async loads to 16 B/lane (the narrow 4 B/lane default was the main fp8-slower-than-bf16 decode regression). | +| `kKFallbackLds` (`UA_K_FALLBACK_LDS`) | `1` | Resolve the multi-page K fallback (ps16/ps32) page indices via the LDS cache instead of per-lane global reads. | +| `UA_DECODE_STAGES` | `2` | Decode async-ring depth (deeper buffering was measured perf-neutral; decode is BW-bound). | + +Experiments kept **OFF** (measured losers — retained as one-line gates so the +rationale isn't relitigated): `UA_FA4_PACKED_SHIFT`, `UA_FA4_PACKED_ALU1_RESCALE` +(together ~−3% on canonical fp8 prefill — softmax is hidden under the overlap), +`UA_FA4_PACKED_ROWSUM` (−13%, serial chain beats the log-depth tree), +`UA_DYNAMIC_SETPRIO`, `MOVE_FMHA_MASK_TO_COMPUTE` (fp8 +8.8% regression), +`MOVE_FMHA_MASK_TO_GEMM1`, `UA_FA4_PIN_PACK_IN_SOFTMAX`. `UA_FA4_EXP2_APPROX` +(Schraudolph 2^x) is an *approximation* — off by default, validate accuracy first. + +--- + +## Building & testing + +The kernel is consumed via aiter's JIT module `module_unified_attention`. After +editing any file here you **must** force a fresh build — the `.so` is not rebuilt +automatically: + +```bash +# from the aiter repo root +rm -rf aiter/jit/build/module_unified_attention aiter/jit/module_unified_attention.so +AITER_REBUILD=1 HIP_VISIBLE_DEVICES=2 python3 op_tests/test_unified_attention_ck.py --full +``` + +See `aiter/ua-test-scripts/README.md` for correctness/perf, and +`aiter/ua-test-scripts/analysis/README.md` for ISA/VGPR/overlap-trace tooling +(including the JIT-free standalone driver, which stamps every build so a stale +binary can never be measured). diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 714779e1d5..17e9b62aaa 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -660,27 +660,16 @@ struct UnifiedAttentionKernel return FmhaMask{cur_batch_query_len, seq_len}; }(); - // Step D: Sliding-Window-Attention tile-range clip. - // The per-pixel mask check inside the pipeline already returns the - // correct (zeroed) score for tokens outside the SWA window, so - // skipping this block is correctness-preserving. The point of the - // clip is to skip entire KV tiles that fall completely outside the - // window — for long-context decode that's the difference between - // O(seq_k / kPageBlockSize) and O(window / kPageBlockSize) - // iterations. The intersection with the current split's - // [num_blocks_start, num_blocks) is taken so split-KV stays correct. - // Step D: Sliding-Window-Attention KV-tile clip. - // - // This is REQUIRED for correctness, not just an optimisation. The - // online-softmax pipeline interleaves `m` / `l` updates with prefetch - // and warp-group barriers; an all-(-inf) tile (one wholly outside the - // SWA window) feeds NaN/garbage into the `m` accumulator at the - // barrier boundary, corrupting subsequent tiles. Skipping these - // tiles entirely keeps every iterated tile either fully-inside the - // window or a true edge tile that the per-pixel mask can clean up. - // - // The intersection with the current split's [num_blocks_start, - // num_blocks) is taken so split-KV stays correct. + // Step D: Sliding-Window-Attention KV-tile clip — skip entire KV tiles + // that fall completely outside the window. REQUIRED for correctness, not + // just an optimisation: the online-softmax pipeline interleaves m/l + // updates with prefetch and warp-group barriers, so an all-(-inf) tile + // (wholly outside the window) would feed NaN/garbage into the m + // accumulator at a barrier boundary and corrupt later tiles. For + // long-context decode it is also the difference between + // O(seq_k/kPageBlockSize) and O(window/kPageBlockSize) iterations. The + // intersection with the current split's [num_blocks_start, num_blocks) + // keeps split-KV correct. if constexpr(FmhaMask::IsMasking && FmhaMask::IsLocal) { const auto sw_range = mask.GetTileRangeAlongX( diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index c7d94edd42..1ffcdd28b9 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -3,180 +3,94 @@ #pragma once -// EXPERIMENT: move the FA4 V LDS transpose-read out of the SOFTMAX phase and -// into the MATRIX phase (right before its PV consumer). The default ("Stage B") -// issues the V ds_read in the *preceding* softmax to hide its latency under the -// softmax VALU — but the ATT trace shows that read stalls ~69% and sits on the -// (longer/critical) softmax phase. Moving it to MATRIX, which has barrier slack, -// takes it off the critical path. =1 to enable. +// =========================================================================== +// Tuning knobs & experiment toggles. +// See the "Tuning knobs & failed experiments" table in this folder's README.md +// for the rationale and measured results behind each default. These MUST be set +// BEFORE including unified_attention_core_loop_scheduler.hpp: that header's +// per-phase __builtin_amdgcn_sched_group_barrier hints are gated on the same +// macros and must stay in lock-step with the code motion in this file. +// =========================================================================== + +// OFF (measured loser): pin the fp32->fp8 P-pack tail of fmha_alu1 to the +// SOFTMAX phase instead of letting it sink into the next MATRIX slot. #ifndef UA_FA4_PIN_PACK_IN_SOFTMAX -// Experiment (option 3): fence the fp32->fp8 P-pack (cvt_pk_fp8 tail of -// fmha_alu1) so it retires inside the SOFTMAX phase instead of sinking past the -// matrix barrier into the next MATRIX slot. In the MATRIX slot the pack (a VALU -// op) contends with the co-resident warp group's softmax VALU (the v_max3 -// rowmax tree) on the shared SIMD issue port; pinning it to SOFTMAX trades that -// cross-wave contention for in-phase exposure on the (longer) softmax phase. #define UA_FA4_PIN_PACK_IN_SOFTMAX 0 #endif -// UA_FA4_PREFETCH_IN_SOFTMAX: issue the next-tile K/V async DRAM prefetch from -// the SOFTMAX phase instead of the MATRIX phase. Gated to the 2-byte (bf16/fp16) -// path in the loop (fp8 keeps the matrix-phase prefetch it was tuned for). The -// bf16 prefetch is double the VMEM bytes and its buffer_load-to-LDS issue was -// landing in the (lgkm-stalled) MATRIX phase; moving the *issue* to the VALU- -// bound SOFTMAX phase keeps MATRIX pure-matrix. Residency is unchanged: the next -// MATRIX still drains the load via s_waitcnt_vmcnt<0> + block barrier before any -// K/V LDS read, so this only moves WHERE the async load is kicked off. +// ON (bf16/fp16 only): issue the next-tile K/V async prefetch from the SOFTMAX +// phase so the MATRIX phase stays pure-matrix. Residency is unchanged (the next +// MATRIX still drains the load before any K/V LDS read); this only moves WHERE +// the load is kicked off. fp8 keeps the matrix-phase prefetch it was tuned for. #ifndef UA_FA4_PREFETCH_IN_SOFTMAX #define UA_FA4_PREFETCH_IN_SOFTMAX 1 #endif -// FMHA_MASK PLACEMENT: pick exactly one of: -// - both 0 → baseline (mask in K-side memory phase, W0-3 phase 1 -// / W4-7 phase 2, right after `cl_load(memK)`). -// - MOVE_FMHA_MASK_TO_COMPUTE=1: hoist mask onto the compute phase -// (W0-3 phase 0 / W4-7 phase 1), right after `fmha_alu1`. -// Experiment 1.5 finding: bf16 −0.33%, **fp8 +8.8% regression** -// because the FP8 cvt+bperm cluster inside `fmha_alu1` makes the -// compute phase already-saturated; adding T_mask oversubscribes -// it and the empirical cost is ~2× the bare instruction count. -// - MOVE_FMHA_MASK_TO_GEMM1=1: place mask at the START of the -// gemm1 phase (W0-3 phase 2 / W4-7 phase 3), right before -// `cl_calc(xdl_SP_p23_reg_idx, gemm1)`. This is the latest legal -// placement: `cl_calc(p23, gemm1)` ends with `fmha_alu0(p01_idx)` -// which READS `sp[p01_idx].sp_compute` to compute `m_latest`, so -// mask MUST run before that. Phase 3 (V-mem on W0-3, gemm1 on -// W4-7) is too late and silently corrupts the row-max. -// -// For W4-7 the `++i_total_loops` also defers from end of phase 2 -// to start of phase 3 (after mask, before cl_calc) so mask sees -// the same i_total_loops value as gemm0 of this iter. -// -// Per-barrier algebra (mask added to gemm1 phase = T_D on both -// warp groups, removed from K-mem = T_K on both): -// - B1 wait = |T_C − (T_D + T_mask)|. With baseline T_C > T_D -// on FP8, the gap closes — DECREASES by T_mask. -// - B2 wait = |(T_K − T_mask) − T_C| — DECREASES by T_mask. -// - B3 wait = |(T_D + T_mask) − (T_K − T_mask)| -// = |T_D − T_K + 2·T_mask| — DECREASES by 2·T_mask. -// - Net: −4·T_mask total wait (vs −2·T_mask for compute), and -// gemm1 phase has no FP8 cvt+bperm so it should absorb the -// mask without the FP8 oversubscription that hit compute. -// -// Must be defined BEFORE including unified_attention_core_loop_scheduler.hpp -// — that header's `__builtin_amdgcn_sched_group_barrier` per-phase -// hints are gated on these macros and need to stay in lockstep with -// the code motion in this file. +// OFF (measured losers): alternative placements of the FMHA mask within the FA4 +// phases. Default (both 0) = mask in the K-side memory phase. _TO_COMPUTE +// oversubscribes the fp8 compute phase (+8.8% regression); _TO_GEMM1 is the +// latest legal placement (must precede the fmha_alu0 that reads sp_compute) but +// did not win. At most one may be 1. #define MOVE_FMHA_MASK_TO_COMPUTE 0 #define MOVE_FMHA_MASK_TO_GEMM1 0 #if MOVE_FMHA_MASK_TO_COMPUTE && MOVE_FMHA_MASK_TO_GEMM1 #error "MOVE_FMHA_MASK_TO_COMPUTE and MOVE_FMHA_MASK_TO_GEMM1 are mutually exclusive" #endif -// UA_DYNAMIC_SETPRIO (warp-group-balance plan A2, HipKittens-style) -// 0 (default): static per-warp-group priority, set once at loop entry -// (W0-3 → s_setprio(0), W4-7 → s_setprio(1)). Baseline, bit-identical. -// 1: dynamic priority around the gemm MFMA cluster. `cl_calc` raises -// s_setprio(1) for the duration of the gemm (QK/PV MFMAs + trailing -// fmha_alu0) and drops back to s_setprio(0) after. The two warp groups -// are offset by two phases and co-resident (one wave of each group per -// SIMD), so the group currently in the compute cluster outbids the -// group currently issuing memory for the shared VALU/MFMA issue port — -// targeting the ARBITER_NOT_WIN stall that gates the compute side -// (W0-3: 37.8% of its stalls). Under the macro the static W4-7=1 entry -// is neutralised to 0 so the non-compute baseline is uniformly prio 0. +// OFF (measured loser): dynamic s_setprio around the gemm MFMA cluster (vs the +// default static per-warp-group priority set once at loop entry). #ifndef UA_DYNAMIC_SETPRIO #define UA_DYNAMIC_SETPRIO 0 #endif -// UA_FA4_PACKED_SHIFT: emit the softmax score-shift (sp_delta = sp_compute * -// scale_s - scale_s * rowmax) as packed v_pk_fma_f32 (2 f32/instr) instead of 64 -// scalar v_fma_f32. Bit-identical: each thread holds one rowmax -// (m.thread_buf_.size()==1) so the FMA addend is uniform across the thread's -// score elements and is broadcast into both packed lanes. Mirrors the hand-tuned -// ASM softmax (v_pk_fma_f32 for the rebase). Halves the shift instruction count. -// -// MEASURED REGRESSION, default OFF. Together with UA_FA4_PACKED_ALU1_RESCALE this -// costs ~3% on the canonical fp8 prefill shape (GPU2, same-session 3-run median: -// packed 1825 TF/s vs scalar 1877 TF/s). The softmax score-shift is hidden under -// the ping-pong overlap, so collapsing the FMAs does not shorten the critical -// path; it only perturbs the scheduler and loses. (An earlier "+4.5%" reading was -// a confounded GPU0 measurement.) Kept gated off for documentation; do not enable. +// OFF (measured ~-3% together with UA_FA4_PACKED_ALU1_RESCALE): emit the softmax +// score-shift as packed v_pk_fma_f32 instead of scalar v_fma_f32. Bit-identical, +// but the shift is hidden under the ping-pong overlap so collapsing it only +// perturbs the scheduler and loses. #ifndef UA_FA4_PACKED_SHIFT #define UA_FA4_PACKED_SHIFT 0 #endif -// UA_FA4_EXP2_APPROX: replace the per-element softmax exp (quarter-rate -// v_exp_f32) with the Schraudolph 2^x bit-trick (full-rate). The score-shift FMA -// in fmha_alu0 absorbs the 2^23 scale and the Schraudolph bias, so fmha_alu1 only -// needs a single v_cvt_u32_f32 per element instead of v_exp_f32. The per-row -// max-delta rescale keeps the exact v_exp_f32 (only 1/row, not on the hot path). -// This is an APPROXIMATION (~1e-3 rel error per element) -- it mirrors the ASM -// softmax fast SKU and is only applied on the non-masked, no-softcap path -// (compile-time gate below). Numerics-changing => default OFF; validate accuracy -// before enabling. +// OFF (approximation, ~1e-3 rel error): replace the quarter-rate v_exp_f32 +// softmax with the full-rate Schraudolph 2^x bit-trick. Numerics-changing -> +// validate accuracy before enabling; only legal on the non-masked, no-softcap +// path (compile-time gate below). #ifndef UA_FA4_EXP2_APPROX #define UA_FA4_EXP2_APPROX 0 #endif -// Schraudolph 2^x bit-trick constants: bits = round(2^23 * x + (127*2^23 - 486411)) -// reinterpreted as f32 ~= 2^x. Min-error offset matches the hand-tuned ASM softmax. +// Schraudolph 2^x constants: bits = round(2^23*x + (127*2^23 - 486411)) as f32. #define UA_EXP2_SCHRAUDOLPH_SCALE 8388608.0f // 2^23 #define UA_EXP2_SCHRAUDOLPH_BIAS 1064866805.0f // 127*2^23 - 486411 -// UA_FA4_PACKED_ROWSUM: reduce a thread's row-sum of probabilities with packed -// v_pk_add_f32 into a 2-wide partial, then a single scalar combine, instead of the -// scalar v_add_f32 dependency chain that block_tile_reduce emits. Halves the -// in-thread adds and shortens the latency chain feeding the cross-lane permlane. -// Reassociates the sum (rounding differs at the ULP level) -- safe within the fp8 -// /bf16 attention tolerances. -// -// MEASURED LOSER (-13%): a serial v_pk_add_f32 accumulation is a 32-deep latency -// chain that is WORSE than block_tile_reduce's log-depth tree, and the dead scalar -// reduce is not DCE'd. Kept gated off for documentation; do not enable. +// OFF (measured -13%): reduce a thread's probability row-sum with packed +// v_pk_add_f32. The serial packed chain is deeper-latency than block_tile_reduce's +// log-depth tree, and the dead scalar reduce is not DCE'd. #ifndef UA_FA4_PACKED_ROWSUM #define UA_FA4_PACKED_ROWSUM 0 #endif -// UA_FA4_PACKED_ALU1_RESCALE: pack the 6-register o_acc partial rescale in -// fmha_alu1 (elementwise *= o_acc_scale) with v_pk_mul_f32, matching the packed -// rescale in fmha_alu_D_upd. Independent elementwise scale (no dependency chain), -// and it halves the number of asm-volatile scheduling boundaries (6 scalar -// v_mul_f32 -> 3 v_pk_mul_f32). Bit-identical. -// -// MEASURED REGRESSION, default OFF -- see the note on UA_FA4_PACKED_SHIFT above: -// the two together cost ~3% (1877 -> 1825 TF/s, GPU2) because the rescale is -// hidden under the ping-pong overlap. (An earlier "+4%" reading was a confounded -// GPU0 measurement.) Kept gated off for documentation; do not enable. +// OFF (measured ~-3% together with UA_FA4_PACKED_SHIFT): pack the o_acc partial +// rescale in fmha_alu1 with v_pk_mul_f32. Bit-identical, but the rescale is hidden +// under the overlap so it does not shorten the critical path. #ifndef UA_FA4_PACKED_ALU1_RESCALE #define UA_FA4_PACKED_ALU1_RESCALE 0 #endif -// CONDITIONAL_RESCALE (PLAN_conditional_rescale Part 2) -// 0 (default): always-rescale online softmax — the o_acc/l accumulators are -// renormalised to the true running max `m` every KV tile (the expensive -// 128-VGPR `v_pk_mul_f32` rescale tail in fmha_alu_D_upd + the 6-reg -// partial in fmha_alu1). Bit-identical to the pre-Part-2 kernel. -// 1: FA4-style conditional (skipped) rescale. Carry the accumulators in the -// frame of a *committed* max `m_commit` that only advances (with a rescale) -// when the true max pulls ahead by more than τ = log2 of the safe exp2 -// bound. Between commits the shifted scores stay ≤ τ (exp2 ≤ 2^τ, fp32- -// safe) so o_acc/l just accumulate — the rescale multiplies are skipped. -// The decision is made wave-uniformly (ballot: rescale if ANY lane needs -// it) so the guard is a scalar branch with no divergence. Mathematically -// exact (the m_commit frame cancels in o_acc/l; LSE uses m_commit), so no -// end-of-loop correction is needed. Only applied on the 2-warp-group -// prefill path (see kCondRescale); decode keeps always-rescale. Part-1's -// --headroom instrument predicts ~85% (prefill) of rescales are skippable. -// Defined BEFORE the includes so unified_attention_core_loop_scheduler.hpp can -// gate its per-phase sched_group_barrier VALU hints on it (the gemm1+D_upd -// phase reserves ~36 VALU slots for the rescale tail that this skips). +// ON (FA4 prefill only): conditional (skipped) online-softmax rescale. Carry the +// o_acc/l accumulators in the frame of a committed max `m_commit` that advances +// (with a rescale) only when the true max pulls ahead by more than τ; between +// commits the shifted scores stay <= τ (exp2 <= 2^τ, fp32-safe) so the rescale +// multiplies are skipped. The decision is wave-uniform (ballot) so it is a scalar +// branch with no divergence. Mathematically exact (m_commit cancels in o_acc/l; +// LSE uses m_commit). Decode keeps always-rescale (see kCondRescale). Defined +// before the includes so the core-loop scheduler can drop the rescale's VALU slot +// reservation in lock-step. #if !defined(CONDITIONAL_RESCALE) #define CONDITIONAL_RESCALE 1 #endif -// τ in scaled-logit (log2) units. exp2(τ) bounds the un-rescaled scores; 8 => -// max intermediate exp2 == 256, comfortably inside fp32 range even summed over -// thousands of keys. FA4 uses the same log2(256)=8. +// τ in scaled-logit (log2) units; exp2(τ) bounds the un-rescaled scores. 8 => +// max intermediate exp2 == 256, fp32-safe even summed over thousands of keys. #if !defined(CONDITIONAL_RESCALE_TAU) #define CONDITIONAL_RESCALE_TAU 8.0f #endif @@ -666,6 +580,12 @@ struct UnifiedAttentionPipeline const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + // Design A deep async ring depth (decode: kDecodeStages; FA4: 2). The K + // buffers occupy LDS slots [0, kRingStages); the V buffers follow at + // [kRingStages, 2*kRingStages) -- the V store descriptor's K-buffer base + // (KBufCount template arg) must therefore equal kRingStages too. + constexpr index_t kRingStages = Policy::template GetRingStages(); + constexpr index_t KStoreWarpShift = Policy::template GetKStoreWarpShift(); auto k_lds_window_store = generate_tuple( [&](auto i_buf) { @@ -675,15 +595,17 @@ struct UnifiedAttentionPipeline KLoadNumWarps, KStoreWarpShift>(i_buf)); }, - number<2>{}); + number{}); auto v_lds_window_store = generate_tuple( [&](auto i_buf) { return make_lds_tile_window( smem_ptr, - Policy::template MakeVLdsStoreBlockDescriptor(i_buf)); + Policy::template MakeVLdsStoreBlockDescriptor(i_buf)); }, - number<2>{}); + number{}); statically_indexed_array< decltype(make_tile_window( @@ -691,7 +613,7 @@ struct UnifiedAttentionPipeline nullptr, Policy::template MakeKLdsLoadBlockDescriptor()), Policy::template MakeKRegTileDistribution())), - 2> + kRingStages> k_lds_window_load; statically_indexed_array< @@ -700,7 +622,7 @@ struct UnifiedAttentionPipeline nullptr, Policy::template MakeVLdsLoadBlockDescriptor()), Policy::template MakeVRegTileDistribution())), - 2> + kRingStages> v_lds_window_load; decltype(make_static_distributed_tensor( @@ -783,8 +705,10 @@ struct UnifiedAttentionPipeline bool need_rescale = true; #endif - // initialize k_lds_window and v_lds_window - static_for<0, 2, 1>{}([&](auto idx) { + // initialize k_lds_window and v_lds_window. K buffers occupy LDS slots + // [0, kRingStages); V buffers follow at [kRingStages, 2*kRingStages), + // matching the V store descriptor's KBufCount==kRingStages base. + static_for<0, kRingStages, 1>{}([&](auto idx) { k_lds_window_load(idx) = make_tile_window( make_lds_tile_window( static_cast(smem_ptr) + (idx)*Policy::template GetSmemSizeKV(), @@ -792,11 +716,12 @@ struct UnifiedAttentionPipeline Policy::template MakeKRegTileDistribution()); }); - static_for<0, 2, 1>{}([&](auto idx) { + static_for<0, kRingStages, 1>{}([&](auto idx) { v_lds_window_load(idx) = make_tile_window(make_lds_tile_window( static_cast(smem_ptr) + - (idx + 2) * Policy::template GetSmemSizeKV(), + (idx + kRingStages) * + Policy::template GetSmemSizeKV(), Policy::template MakeVLdsLoadBlockDescriptor()), Policy::template MakeVRegTileDistribution()); @@ -2967,6 +2892,12 @@ struct UnifiedAttentionPipeline gemm(xdl_SP_p23_reg_idx, /*gemm_idx=*/number<1>{}); }; + // Set by the N-deep decode ring (kRingStages>2) when it finalizes the + // last tile's deferred PV *inline* (it must, because that tile's V lives + // in LDS ring slot r%kRingStages, not the parity slot fmha_post_process + // assumes). Suppresses the shared post-process below in that case. + bool decode_ring_final_pv_done = false; + // pre-stage { ASM_MARKER("before pre-stage"); @@ -3020,12 +2951,19 @@ struct UnifiedAttentionPipeline goto label_main_loops_exit; } - if(2 < num_total_loop) + // K2 prefetch into buf0 (the freed K0 slot). Only the 2-buffer + // pipeline does this here; the N-deep ring (kRingStages>2) fills + // K2..K(N-1) into their own distinct slots inside the decode + // block below, so it must NOT pre-issue K2 into buf0. + if constexpr(kRingStages == 2) { - K_mem_load(number<0>{}); // mem_K2 + if(2 < num_total_loop) + { + K_mem_load(number<0>{}); // mem_K2 - s_waitcnt_vmcnt(); - __builtin_amdgcn_s_barrier(); + s_waitcnt_vmcnt(); + __builtin_amdgcn_s_barrier(); + } } ASM_MARKER("end pre-stage"); @@ -3043,6 +2981,8 @@ struct UnifiedAttentionPipeline // K1 in LDS (K buf 1) if num_total_loop >= 2 // K2 loading to LDS (K buf 0) if num_total_loop >= 3 + if constexpr(kRingStages == 2) + { // Step 1: consume V0, K1 -> produce PV(0), QK(1) s_waitcnt_vmcnt<0>(); __builtin_amdgcn_s_barrier(); @@ -3119,6 +3059,93 @@ struct UnifiedAttentionPipeline fmha_alu_D_upd(); i_total_loops++; } + } // end if constexpr(kRingStages == 2) + else + { + // ===== Design A: N-deep async KV ring (decode) ===== + // The pre-stage left: K_r0 consumed, K_r1 -> slot 1, + // V_r0 -> slot 0, and sp(0)=QK(r0) (alu1 pending). Fill the + // rest of the ring so K_r1..K_r(N-1) and V_r0..V_r(N-2) are + // all in flight (N-1 of each); slot(tile t) = t % N. Loads + // issued past the split end harmlessly reload the last tile + // into slots that are never consumed (K/V_mem_load stop + // advancing their DRAM window at the end), so no bound + // guards are needed and the in-flight count -- hence the + // vmcnt threshold below -- stays constant for every stage. + static_for<2, kRingStages, 1>{}( + [&](auto m) { K_mem_load(number{}); }); + static_for<1, kRingStages - 1, 1>{}( + [&](auto m) { V_mem_load(number{}); }); + + // With the ring full, only the two oldest in-flight tiles + // (the K_r / V_(r-1) consumed this stage) must be resident; + // the other N-2 of each stay in flight -> deeper HBM overlap + // than the 2-buffer path's full vmcnt<0> drain. (N==2 would + // reduce to vmcnt<0>, i.e. the original; that case is handled + // by the branch above.) + constexpr index_t kRingVmcnt = + (kRingStages - 2) * (K_mem_su_ld_insts + V_mem_su_ld_insts); + + // Unroll by N so every LDS ring slot is a compile-time index. + // Stage s processes relative tile r = r_base + s, r_base ≡ 1 + // (mod N): read K_r from slot (s+1)%N and V_(r-1) from slot s; + // refill the two slots freed last stage (K_(r-1) slot s, + // V_(r-2) slot (s-1)%N) at the top -- different LDS buffers + // than the ones read below, so the async prefetch has no WAR + // and overlaps this stage's PV+QK. sp keeps its 2-way + // deferred-PV parity (N is even). + while(i_total_loops < num_total_loop) + { + static_for<0, kRingStages, 1>{}([&](auto s) { + constexpr index_t k_rd = (s + 1) % kRingStages; + constexpr index_t v_rd = s % kRingStages; + constexpr index_t k_pf = s % kRingStages; + constexpr index_t v_pf = (s + kRingStages - 1) % kRingStages; + constexpr index_t sp_pv = s % 2; + constexpr index_t sp_qk = (s + 1) % 2; + + if(i_total_loops < num_total_loop) + { + s_waitcnt_vmcnt(); + __builtin_amdgcn_s_barrier(); + + slide_page_table(); + K_mem_load(number{}); // K_(r+N-1) -> freed K slot + V_mem_load(number{}); // V_(r+N-2) -> freed V slot + + V_lds_load(number{}); // V_(r-1) + s_waitcnt_lgkmcnt<0>(); + fmha_alu1(number{}); // finalize sp(r-1) -> P + gemm(number{}, /*gemm_idx=*/number<1>{}); // PV + + K_lds_load(number{}); // K_r + s_waitcnt_lgkmcnt<0>(); + gemm(number{}, /*gemm_idx=*/number<0>{}); // QK -> sp(r) + fmha_mask(number{}); + fmha_alu0(number{}); + fmha_alu_D_upd(); + i_total_loops++; + + // Last tile: its PV would be deferred to a stage + // that never runs. Finalize it now -- V_r lives in + // the compile-time slot (s+1)%N (== k_rd). Drain + // first so V_r (left in flight by the staged wait) + // is resident; this also tells the shared epilogue + // not to re-run the (parity-indexed) post-process. + if(i_total_loops >= num_total_loop) + { + s_waitcnt_vmcnt<0>(); + __builtin_amdgcn_s_barrier(); + V_lds_load(number{}); // V_r + s_waitcnt_lgkmcnt<0>(); + fmha_alu1(number{}); + gemm(number{}, /*gemm_idx=*/number<1>{}); // PV + decode_ring_final_pv_done = true; + } + } + }); + } + } } else { @@ -3203,8 +3230,11 @@ struct UnifiedAttentionPipeline if(!(num_iters % 2)) fa4_epi(number<1>{}); } - else + else if(!decode_ring_final_pv_done) { + // (The N-deep decode ring already finalized the last tile's PV + // inline -- see decode_ring_final_pv_done -- because its V sits in + // an LDS ring slot, not the parity slot this post-process assumes.) if(num_iters % 2) { fmha_post_process(number<1>{}); diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp index f691b4dcb0..b61961ba3f 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp @@ -687,6 +687,7 @@ struct UnifiedAttentionPipelineDefaultPolicy // alone tile the full V buffer. Default = the shape's NumWarps (cooperative). template CK_TILE_DEVICE static constexpr auto MakeVLdsStoreBlockDescriptor(ck_tile::number = ck_tile::number<0>{}) @@ -728,7 +729,7 @@ struct UnifiedAttentionPipelineDefaultPolicy number{}, number{}, number<1>{}), - number<(IBuf + 2) * GetSingleSmemElementSpaceSize()>{}, + number<(IBuf + KBufCount) * GetSingleSmemElementSpaceSize()>{}, number{}, number<1>{}); @@ -855,6 +856,33 @@ struct UnifiedAttentionPipelineDefaultPolicy #endif static constexpr bool kFA4WG1LoadsK = UA_FA4_WG1_LOADS_K; + // Design A (decode deep async ring). Number of K/V LDS landing buffers for + // the single-warp-group (decode) path: raising it from 2 keeps N-1 KV-tile + // DRAM loads in flight so the loop can stage `vmcnt` partial waits instead of + // a full per-tile drain (the memory-bound long-context decode regime). Must + // be EVEN (the deferred-PV score double-buffer keeps 2-way parity, which is + // only compile-time resolvable across the N-unroll when N is even) and >= 2. + // Default 2 == the original 2-buffer serial pipeline, bit-identical. The + // 2-warp-group FA4 (prefill) path always uses 2 regardless. LDS cost is + // 2*N*GetSmemSizeKV, so larger N may cost occupancy on the LDS-bound decode + // tiers -- sweep per tier. +#ifndef UA_DECODE_STAGES +#define UA_DECODE_STAGES 2 +#endif + static constexpr ck_tile::index_t kDecodeStages = UA_DECODE_STAGES; + static_assert(kDecodeStages >= 2 && (kDecodeStages % 2 == 0), + "UA_DECODE_STAGES must be an even integer >= 2"); + + // Ring depth actually used by a given kernel instance: the deep ring is a + // decode-only (single-warp-group) lever; the FA4 prefill path keeps 2. + template + CK_TILE_DEVICE static constexpr ck_tile::index_t GetRingStages() + { + constexpr ck_tile::index_t NumWarpGroups = + Problem::kBlockSize / NumThreadPerWarpGroup; + return (NumWarpGroups == 1) ? kDecodeStages : 2; + } + // Number of waves that cooperate on a V DRAM->LDS load. For the 2-warp-group // FA4 path with kFA4WG0LoadsV, this is one warp group's waves (so WG0 alone // fills the tile); otherwise it's the full block (original cooperative load). @@ -898,7 +926,10 @@ struct UnifiedAttentionPipelineDefaultPolicy template CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - return 4 * GetSmemSizeKV(); + // kRingStages K buffers + kRingStages V buffers. Decode uses + // kDecodeStages (>=2), FA4 prefill stays 2 -> the default (N==2) + // reproduces the original 4*GetSmemSizeKV budget exactly. + return 2 * GetRingStages() * GetSmemSizeKV(); } };