mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
CK-UA: freeze docs + comment cleanup (+ gated decode-ring scaffolding)
- Add an architecture README for the unified_attention kernel folder: file map,
  per-CTA work assignment, online-softmax math + scale fusion, FA4 vs
  serial-decode regimes, paged-KV tiers, split-KV, and a
  tuning-knobs/failed-experiments table. Intended as the reference for the
  FlyDSL port.
- pipeline.hpp: condense the ~230-line experiment-macro header into terse,
  README-backed one-liners (all 13 macro definitions preserved bit-for-bit).
- kernel.hpp: merge the duplicated/contradictory "Step D" SWA-clip comment.
- Gated multi-stage decode async-ring scaffolding (UA_DECODE_STAGES, default 2 =
  bit-identical; deeper depth measured perf-neutral, decode is BW-bound).

Full matrix 263/263 PASS, 0 fail; comment-only kernel edits are behavior-neutral
(target fp8 decode shape unchanged at ~88us).

Co-authored-by: Cursor <cursoragent@cursor.com>
168
include/ck_tile/ops/unified_attention/README.md
Normal file
@@ -0,0 +1,168 @@
# CK Tile: Unified Attention kernel

FlashAttention-style fused attention forward for **paged and contiguous KV**,
covering **prefill and decode** in one operator, for **fp8 / bf16 / fp16** on
gfx950 (MI350/MI355). This is the kernel behind aiter's `unified_attention` op;
the test/perf/analysis tooling lives in `aiter/ua-test-scripts/` (see its README).

This document is the architecture reference for reading the kernel and for
porting it (e.g. to FlyDSL). The functional contract mirrors Triton's
`unified_attention` reference.

---

## File map

| file | role |
|---|---|
| `kernel/unified_attention_kernel.hpp` | Host `MakeKargs` / grid helpers + the device `operator()`: per-CTA index math (grid → kv_head/seq/q_block/split), Q/K/V/O DRAM window setup, mask construction, split-KV partitioning, the call into the pipeline, and the epilogue / split-KV workspace writes. |
| `pipeline/unified_attention_pipeline.hpp` | The core per-CTA attention pipeline (the hot loop). Two regimes share one file: the **FA4** 2-warp-group matrix‖softmax overlap (prefill) and the **single-warp-group serial deferred-PV** pipeline (decode). |
| `pipeline/unified_attention_pipeline_default_policy.hpp` | Compile-time policy: tile distributions, LDS descriptors, warp-gemm selection, load-width/alignment selection, smem sizing, async-ring depth, and tuning constants. |
| `pipeline/unified_attention_pipeline_problem.hpp` | Bundles dtypes + shape + traits + mask into the `Problem` type. |
| `pipeline/tile_unified_attention_shape.hpp` | Block/warp tile dims: `kBlockM`, `kBlockQ`, `kPageBlockSize`, `kHeadDim`. |
| `pipeline/tile_unified_attention_traits.hpp` | Padding + occupancy (`kBlockPerCu`) traits. |
| `pipeline/unified_attention_core_loop_scheduler.hpp` | Per-phase `sched_group_barrier` instruction-mix hints that enforce the FA4 phase overlap (kept in lock-step with the macros in the pipeline header). |
| `block/block_masking.hpp` | Causal / sliding-window (FA-style left/right) mask. |

Concrete shape/dtype **instances** (the JIT-compiled translation units) live in
`composable_kernel/example/ck_tile/42_unified_attention/instances/`, dispatched by
`unified_attention.cpp` there.

---

## Per-CTA work assignment (`kernel` `operator()`)

One CTA computes one `(kv_head, q_block, split)` tuple:

- **Decode grid** `dim3(num_kv_heads, num_seqs, num_splits)` (`gridDim.y > 1`):
  direct mapping, no binary search, no padding CTAs.
- **Prefill grid** `dim3(num_kv_heads * total_num_q_blocks, 1, num_splits)`:
  `blockIdx.x` is folded; a binary search over `query_start_len_ptr` recovers the
  sequence, and out-of-range q-blocks early-return.

`num_queries_per_kv` (GQA ratio) is a **runtime** value: `kBlockQ_dyn = kBlockM /
num_queries_per_kv`, so one compiled binary serves MHA and any GQA-N that divides
`kBlockM`. The `kBlockM`-row MFMA tile packs `num_queries_per_kv` consecutive
query rows per KV token; when `kBlockM % num_queries_per_kv != 0` the last 1–2
rows spill into the next q-tile's first token (co-owned), which drives several
correctness invariants in the index math. See the comments around
`last_tile_row_q_off` and the split-KV partition.
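
The prefill index recovery can be sketched on the host as below. This is a minimal model, not the kernel's code: it assumes the q-block numbering used by the Triton `unified_attention` reference, where sequence `s` starts at global q-block `cu_q[s] / block_q + s`. The names `QBlockCoord`, `block_q_dyn`, `find_seq_idx` and `decode_prefill_cta` are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side model of the prefill grid decode; not the kernel's code.
struct QBlockCoord
{
    int seq;           // sequence that owns this q-block
    int q_block_local; // q-block index inside that sequence
    bool in_range;     // false => padding CTA that would early-return
};

// Runtime q-tile height in tokens for a given GQA ratio.
inline int block_q_dyn(int block_m, int num_queries_per_kv)
{
    assert(num_queries_per_kv > 0 && num_queries_per_kv <= block_m);
    return block_m / num_queries_per_kv;
}

// cu_q holds num_seqs + 1 cumulative query lengths, with cu_q[0] == 0.
// Assumption: sequence s starts at global q-block cu_q[s] / block_q + s.
inline int find_seq_idx(const std::vector<int>& cu_q, int q_block_global, int block_q)
{
    int left  = 0;
    int right = static_cast<int>(cu_q.size()) - 1; // num_seqs
    while(left < right)
    {
        const int mid         = (left + right) / 2;
        const int first_block = cu_q[mid] / block_q + mid;
        if(first_block <= q_block_global)
            left = mid + 1;
        else
            right = mid;
    }
    return left - 1; // last sequence whose first q-block is <= q_block_global
}

inline QBlockCoord
decode_prefill_cta(const std::vector<int>& cu_q, int q_block_global, int block_q)
{
    const int seq       = find_seq_idx(cu_q, q_block_global, block_q);
    const int first     = cu_q[seq] / block_q + seq;
    const int local     = q_block_global - first;
    const int query_len = cu_q[seq + 1] - cu_q[seq];
    return {seq, local, local * block_q < query_len};
}
```

With `cu_q = {0, 5, 21}` and `block_q = 16`, q-blocks 0 and 1 map to the first tiles of sequences 0 and 1, while q-block 2 lands past the end of sequence 1 and is the kind of out-of-range block that early-returns.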

---

## Math (online softmax)

Standard streaming softmax with running max `m`, running sum `l`, output
accumulator `o_acc`, over KV tiles of `kPageBlockSize` tokens:

- Scores `S = scale_s · (Q·Kᵀ)`, masked, then `P = exp2(S − m)` (base-2).
- `scale_s` is pre-fused on the host (`MakeKargs`): `sm_scale · q_descale ·
  k_descale · log2(e)`. The `log2(e)` lets the device use full-rate `exp2`
  instead of `exp`; the fp8 Q/K per-tensor descales fold in so the inner loop
  carries a single scalar (matches Triton's `qk_scale = sm_scale·q_scale·k_scale`).
- The fp8 V per-tensor descale `v_descale` is **deferred** to the post-loop
  `o_acc · v_descale / l` step (exact, since V is a linear factor on the
  unnormalised output). Non-fp8 dtypes pass `1.0f` (free no-op).
- Output is `o_acc / l`; `lse = m + log(l)` (natural-log domain) is returned for
split-KV combine.
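
The bullets above can be checked numerically with a scalar model of one query row. This is a sketch under stated assumptions, not the pipeline's code: masking is omitted, V is reduced to a single column, and `RowResult` and `online_softmax_row` are hypothetical names. It also assumes that "natural-log domain" means the base-2 running max is converted back with a factor of `ln 2` before `log(l)` is added; the kernel may carry `m` in a different form.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

struct RowResult
{
    double out; // normalised output (one V column, for brevity)
    double lse; // log-sum-exp in natural-log units
};

// qk: raw Q.K^T dot products for one query row; v: matching V values.
inline RowResult online_softmax_row(const std::vector<double>& qk,
                                    const std::vector<double>& v,
                                    std::size_t tile,
                                    double sm_scale,
                                    double q_descale,
                                    double k_descale,
                                    double v_descale)
{
    assert(!qk.empty() && qk.size() == v.size() && tile > 0);
    const double log2e   = 1.4426950408889634;
    const double ln2     = 0.6931471805599453;
    const double scale_s = sm_scale * q_descale * k_descale * log2e; // fused once

    double m     = -std::numeric_limits<double>::infinity(); // running max (log2 units)
    double l     = 0.0;                                      // running sum
    double o_acc = 0.0;                                      // unnormalised output
    for(std::size_t begin = 0; begin < qk.size(); begin += tile)
    {
        const std::size_t end = std::min(begin + tile, qk.size());
        double m_new          = m;
        for(std::size_t j = begin; j < end; ++j)
            m_new = std::max(m_new, scale_s * qk[j]);

        // Rescale the history to the new max; exp2(-inf) == 0 on the first tile.
        const double alpha = std::exp2(m - m_new);
        l *= alpha;
        o_acc *= alpha;

        for(std::size_t j = begin; j < end; ++j)
        {
            const double p = std::exp2(scale_s * qk[j] - m_new);
            l += p;
            o_acc += p * v[j];
        }
        m = m_new;
    }
    // v_descale is applied once, after the loop.
    return {o_acc * v_descale / l, m * ln2 + std::log(l)};
}
```

Because `v_descale` multiplies every term of `o_acc`, applying it once after the loop gives the same result as scaling V inside the loop, which is what the deferred-descale bullet relies on.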

**PV is deferred one tile** (double-buffered on parity): the sequence per tile is
`alu1/pack(prev) → PV(prev) → QK(cur) → alu0/rowmax(cur) → D_upd/rescale`. This
is the known-correct ordering shared by both regimes.

---

## Two pipeline regimes

### FA4 (prefill, 2 warp groups, `NumWarpGroups == 2`)
FlashAttention-4-style overlap. The deferred-PV sequence is cut into two phases:
- **MATRIX** phase: `PV(k-1) + QK(k)` (matrix pipe only).
- **SOFTMAX** phase: `alu1/exp + alu0/rowmax + D_upd/rescale` (VALU/MUFU only).

The two warp groups are primed one phase apart (WG0 in MATRIX while WG1 in
SOFTMAX), so on each SIMD the matrix work of one wave hides the transcendental
work of its co-resident partner. K/V are prefetched a tile ahead into a shared
double buffer at the per-phase block barrier (issued cooperatively by all 8
warps). The `core_loop_scheduler` hints reserve the per-phase instruction mix.

For fp8 the QK-C and PV-A per-thread layouts diverge (PV is forced to
`WGAttrNumAccess::Single`), so after packing, `P` is round-tripped through an LDS
window in canonical (M,N) order and reloaded in the PV-A distribution.

### Serial (decode, single warp group, `kFA4 == false`)
The same deferred-PV pipeline run serially by one 4-warp group, with K/V
double-buffered in LDS. Decode is **HBM-bandwidth bound** at long context; see
`aiter/ua-test-scripts/decode_pipeline_*.md`.

---

## Paged KV

When `kIsPaged`, KV tokens are resolved through `block_tables` (per-sequence page
lists). Performance hinges on keeping page-index resolution off the critical
path, via tiers selected at compile time:

- **Constexpr page size** (`kPageSize_ > 0`, the `ps16/ps32/ps64/ps128`
  instances): strength-reduces every `/ % * page_size` to shifts and enables the
  exact tier gates below. `kPageSize_ == 0` is the runtime-page-size catch-all.
- **Scalar-promote / single-page SRD rebase**: when a wave's load lands within one
  page, fold the page base into the buffer SRD once per wave and drop the per-lane
  block-table path.
- **Tier-2 LDS-resident page-table cache** (`kPageTableLdsEntries = 4096`, 16 KiB):
  resolves the multi-page fallback's page indices from LDS instead of per-lane
  global reads. Coverage is `≤ 4096 · page_size` tokens; beyond that the kernel
  traps (a runtime fallback was measured at −30%, attributed to register pressure).

When `!kIsPaged` (contiguous/THD) the logical KV token index *is* its physical
row (per-sequence base folded into the K/V pointer), so all paging math compiles
out.
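
A host-side sketch of the page-index arithmetic follows. It assumes a KV cache laid out as `[num_pages, page_size, ...]`, so that a token's physical row is `page * page_size + offset`; that layout and the function names are illustrative assumptions, and only the value `4096` for `kPageTableLdsEntries` comes from the text above.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Constexpr page size (power of two): division and modulo reduce to shift and mask.
template <int kPageSizeLog2>
inline std::int64_t physical_row_constexpr(const std::vector<std::int32_t>& block_table,
                                           std::int64_t token)
{
    constexpr std::int64_t page_size = std::int64_t{1} << kPageSizeLog2;
    const std::int64_t page =
        block_table[static_cast<std::size_t>(token >> kPageSizeLog2)];
    return page * page_size + (token & (page_size - 1));
}

// Runtime page size catch-all: real integer division and modulo.
inline std::int64_t physical_row_runtime(const std::vector<std::int32_t>& block_table,
                                         std::int64_t token,
                                         std::int64_t page_size)
{
    assert(page_size > 0);
    const std::int64_t page = block_table[static_cast<std::size_t>(token / page_size)];
    return page * page_size + token % page_size;
}

// Tier-2 coverage check: the LDS cache holds at most 4096 page indices.
inline bool lds_cache_covers(std::int64_t seq_len, std::int64_t page_size)
{
    const std::int64_t kPageTableLdsEntries = 4096;
    return seq_len <= kPageTableLdsEntries * page_size;
}
```

For a page size of 16 the cache therefore covers sequences up to 65536 tokens; one token more falls outside the covered range.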

## Split-KV

`num_splits > 1` partitions the KV range across CTAs in `blockIdx.z`; each writes
fp32 `o_acc`/`lse` workspaces that a separate combine kernel merges. The
partition is computed over the **causal-independent full-sequence** block count
(not the per-tile causal horizon) so a token co-owned by adjacent q-tiles maps to
the same split in both; otherwise non-dividing-GQA + causal races on the shared
token. `num_splits == 1` skips this path entirely.
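
The partition-then-combine idea can be illustrated with a scalar model. This is a sketch, not the combine kernel: it assumes an even ceil-div partition over the full-sequence block count, and it assumes each split stores an already normalised output together with a natural-log `lse`; the workspace format of the real kernel is not specified here. `Partial`, `partition_blocks`, `attend_range` and `combine` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

struct Partial
{
    double o;   // normalised output of one split (one V column, for brevity)
    double lse; // natural-log log-sum-exp of that split
};

// Even ceil-div partition of [0, num_blocks) into num_splits block ranges.
inline std::vector<std::pair<std::size_t, std::size_t>>
partition_blocks(std::size_t num_blocks, std::size_t num_splits)
{
    assert(num_splits > 0);
    const std::size_t per_split = (num_blocks + num_splits - 1) / num_splits;
    std::vector<std::pair<std::size_t, std::size_t>> ranges;
    for(std::size_t s = 0; s < num_splits; ++s)
        ranges.push_back(std::make_pair(std::min(s * per_split, num_blocks),
                                        std::min((s + 1) * per_split, num_blocks)));
    return ranges;
}

// Plain softmax-weighted average over tokens [begin, end).
inline Partial attend_range(const std::vector<double>& logits,
                            const std::vector<double>& v,
                            std::size_t begin,
                            std::size_t end)
{
    const double neg_inf = -std::numeric_limits<double>::infinity();
    if(begin >= end)
        return {0.0, neg_inf}; // empty split contributes nothing
    double m = neg_inf;
    for(std::size_t j = begin; j < end; ++j)
        m = std::max(m, logits[j]);
    double l = 0.0;
    double o = 0.0;
    for(std::size_t j = begin; j < end; ++j)
    {
        const double w = std::exp(logits[j] - m);
        l += w;
        o += w * v[j];
    }
    return {o / l, m + std::log(l)};
}

// Merge per-split results, weighting each by exp(lse_i - lse_max).
inline Partial combine(const std::vector<Partial>& parts)
{
    double lse_max = -std::numeric_limits<double>::infinity();
    for(const Partial& p : parts)
        lse_max = std::max(lse_max, p.lse);
    assert(std::isfinite(lse_max)); // at least one non-empty split
    double denom = 0.0;
    double o     = 0.0;
    for(const Partial& p : parts)
    {
        const double w = std::exp(p.lse - lse_max);
        denom += w;
        o += w * p.o;
    }
    return {o / denom, lse_max + std::log(denom)};
}
```

Because the partition depends only on the total block count, every q-tile that touches a given KV block agrees on which split owns it, which is the property the paragraph above requires.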

---

## Tuning knobs & failed experiments

Active defaults (do not change without re-validating correctness + perf):

| macro / policy | default | effect |
|---|---|---|
| `CONDITIONAL_RESCALE` | `1` | FA4-only: carry accumulators in a committed-max frame and skip the online rescale while shifted scores stay ≤ `τ` (`CONDITIONAL_RESCALE_TAU = 8`). Mathematically exact. |
| `UA_FA4_PREFETCH_IN_SOFTMAX` | `1` | bf16/fp16: issue next-tile K/V async prefetch from the SOFTMAX phase (keeps MATRIX pure-matrix). |
| `GetKVAlignmentBytes` | dwordx4 where it tiles | Widen fp8 decode K/V async loads to 16 B/lane (the narrow 4 B/lane default was the main fp8-slower-than-bf16 decode regression). |
| `kKFallbackLds` (`UA_K_FALLBACK_LDS`) | `1` | Resolve the multi-page K fallback (ps16/ps32) page indices via the LDS cache instead of per-lane global reads. |
| `UA_DECODE_STAGES` | `2` | Decode async-ring depth (deeper buffering was measured perf-neutral; decode is BW-bound). |
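
The `CONDITIONAL_RESCALE` row can be illustrated with a scalar model of one query row. This is a sketch based only on the description in this commit (a committed max that advances when the tile max pulls ahead by more than `τ`), not the pipeline's implementation; the wave-uniform ballot is not modelled, and `CondRescaleResult` and `softmax_row_cond_rescale` are hypothetical names. Inputs are assumed to be already scaled into log2 units.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

struct CondRescaleResult
{
    double out;      // o_acc / l
    double lse_log2; // m_commit + log2(l), in log2 units
    int commits;     // how many times the accumulators were rescaled
};

inline CondRescaleResult softmax_row_cond_rescale(const std::vector<double>& s,
                                                  const std::vector<double>& v,
                                                  std::size_t tile,
                                                  double tau)
{
    assert(!s.empty() && s.size() == v.size() && tile > 0 && tau >= 0.0);
    double m_commit = -std::numeric_limits<double>::infinity();
    double l        = 0.0;
    double o_acc    = 0.0;
    int commits     = 0;
    for(std::size_t begin = 0; begin < s.size(); begin += tile)
    {
        const std::size_t end = std::min(begin + tile, s.size());
        double m_tile         = s[begin];
        for(std::size_t j = begin; j < end; ++j)
            m_tile = std::max(m_tile, s[j]);

        // Commit (and pay for the rescale) only when the max pulls ahead by > tau.
        // The first tile always commits because m_commit starts at -inf.
        if(m_tile - m_commit > tau)
        {
            const double alpha = std::exp2(m_commit - m_tile);
            l *= alpha;
            o_acc *= alpha;
            m_commit = m_tile;
            ++commits;
        }
        // Between commits the shifted scores stay <= tau, so exp2 stays <= 2^tau.
        for(std::size_t j = begin; j < end; ++j)
        {
            const double p = std::exp2(s[j] - m_commit);
            l += p;
            o_acc += p * v[j];
        }
    }
    return {o_acc / l, m_commit + std::log2(l), commits};
}
```

Calling it with `tau = 0` commits on every tile whose max rises, which is the always-rescale behaviour; with `tau = 8` the same inputs give the same `out` and `lse_log2` with fewer commits, since the `m_commit` frame cancels in `o_acc / l`.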

Experiments kept **OFF** (measured losers, retained as one-line gates so the
rationale isn't relitigated): `UA_FA4_PACKED_SHIFT`, `UA_FA4_PACKED_ALU1_RESCALE`
(together ~−3% on canonical fp8 prefill; softmax is hidden under the overlap),
`UA_FA4_PACKED_ROWSUM` (−13%, serial chain beats the log-depth tree),
`UA_DYNAMIC_SETPRIO`, `MOVE_FMHA_MASK_TO_COMPUTE` (fp8 +8.8% regression),
`MOVE_FMHA_MASK_TO_GEMM1`, `UA_FA4_PIN_PACK_IN_SOFTMAX`. `UA_FA4_EXP2_APPROX`
(Schraudolph 2^x) is an *approximation* — off by default, validate accuracy first.
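
For readers unfamiliar with the bit-trick, here is a standalone sketch of a Schraudolph-style `2^x` using the two constants defined in the pipeline header (`8388608.0f` and `1064866805.0f`). It is the plain textbook form, not the kernel's fused variant (which folds the scale and bias into the score-shift FMA); the function name `exp2_schraudolph` and the input-range assertion are assumptions of this sketch.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Approximate 2^x by writing round(2^23 * x + bias) into the bits of a float.
// The integer part of x lands in the exponent field and the fractional part
// in the mantissa, which linearly interpolates 2^x between powers of two.
inline float exp2_schraudolph(float x)
{
    const double scale = 8388608.0;    // 2^23
    const double bias  = 1064866805.0; // 127 * 2^23 - 486411
    const double t     = scale * static_cast<double>(x) + bias;
    // Keep the result a normal, finite float (sketch-only guard).
    assert(t > 8388608.0 && t < 2139095040.0);
    const std::uint32_t bits = static_cast<std::uint32_t>(std::llround(t));
    float result;
    std::memcpy(&result, &bits, sizeof(result));
    return result;
}
```

The accompanying check only bounds the relative error loosely (below 5% over `x` in `[-20, 8]`); it makes no claim about the accuracy the kernel's own variant achieves, so the advice above to validate accuracy before enabling the macro still applies.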

---

## Building & testing

The kernel is consumed via aiter's JIT module `module_unified_attention`. After
editing any file here you **must** force a fresh build, because the `.so` is not
rebuilt automatically:

```bash
# from the aiter repo root
rm -rf aiter/jit/build/module_unified_attention aiter/jit/module_unified_attention.so
AITER_REBUILD=1 HIP_VISIBLE_DEVICES=2 python3 op_tests/test_unified_attention_ck.py --full
```

See `aiter/ua-test-scripts/README.md` for correctness/perf, and
`aiter/ua-test-scripts/analysis/README.md` for ISA/VGPR/overlap-trace tooling
(including the JIT-free standalone driver, which stamps every build so a stale
binary can never be measured).
@@ -660,27 +660,16 @@ struct UnifiedAttentionKernel
     return FmhaMask{cur_batch_query_len, seq_len};
 }();

-// Step D: Sliding-Window-Attention tile-range clip.
-// The per-pixel mask check inside the pipeline already returns the
-// correct (zeroed) score for tokens outside the SWA window, so
-// skipping this block is correctness-preserving. The point of the
-// clip is to skip entire KV tiles that fall completely outside the
-// window -- for long-context decode that's the difference between
-// O(seq_k / kPageBlockSize) and O(window / kPageBlockSize)
-// iterations. The intersection with the current split's
-// [num_blocks_start, num_blocks) is taken so split-KV stays correct.
-// Step D: Sliding-Window-Attention KV-tile clip.
-//
-// This is REQUIRED for correctness, not just an optimisation. The
-// online-softmax pipeline interleaves `m` / `l` updates with prefetch
-// and warp-group barriers; an all-(-inf) tile (one wholly outside the
-// SWA window) feeds NaN/garbage into the `m` accumulator at the
-// barrier boundary, corrupting subsequent tiles. Skipping these
-// tiles entirely keeps every iterated tile either fully-inside the
-// window or a true edge tile that the per-pixel mask can clean up.
-//
-// The intersection with the current split's [num_blocks_start,
-// num_blocks) is taken so split-KV stays correct.
+// Step D: Sliding-Window-Attention KV-tile clip -- skip entire KV tiles
+// that fall completely outside the window. REQUIRED for correctness, not
+// just an optimisation: the online-softmax pipeline interleaves m/l
+// updates with prefetch and warp-group barriers, so an all-(-inf) tile
+// (wholly outside the window) would feed NaN/garbage into the m
+// accumulator at a barrier boundary and corrupt later tiles. For
+// long-context decode it is also the difference between
+// O(seq_k/kPageBlockSize) and O(window/kPageBlockSize) iterations. The
+// intersection with the current split's [num_blocks_start, num_blocks)
+// keeps split-KV correct.
 if constexpr(FmhaMask::IsMasking && FmhaMask::IsLocal)
 {
     const auto sw_range = mask.GetTileRangeAlongX(
@@ -3,180 +3,94 @@

 #pragma once

-// EXPERIMENT: move the FA4 V LDS transpose-read out of the SOFTMAX phase and
-// into the MATRIX phase (right before its PV consumer). The default ("Stage B")
-// issues the V ds_read in the *preceding* softmax to hide its latency under the
-// softmax VALU -- but the ATT trace shows that read stalls ~69% and sits on the
-// (longer/critical) softmax phase. Moving it to MATRIX, which has barrier slack,
-// takes it off the critical path. =1 to enable.
+// ===========================================================================
+// Tuning knobs & experiment toggles.
+// See the "Tuning knobs & failed experiments" table in this folder's README.md
+// for the rationale and measured results behind each default. These MUST be set
+// BEFORE including unified_attention_core_loop_scheduler.hpp: that header's
+// per-phase __builtin_amdgcn_sched_group_barrier hints are gated on the same
+// macros and must stay in lock-step with the code motion in this file.
+// ===========================================================================
+
+// OFF (measured loser): pin the fp32->fp8 P-pack tail of fmha_alu1 to the
+// SOFTMAX phase instead of letting it sink into the next MATRIX slot.
 #ifndef UA_FA4_PIN_PACK_IN_SOFTMAX
-// Experiment (option 3): fence the fp32->fp8 P-pack (cvt_pk_fp8 tail of
-// fmha_alu1) so it retires inside the SOFTMAX phase instead of sinking past the
-// matrix barrier into the next MATRIX slot. In the MATRIX slot the pack (a VALU
-// op) contends with the co-resident warp group's softmax VALU (the v_max3
-// rowmax tree) on the shared SIMD issue port; pinning it to SOFTMAX trades that
-// cross-wave contention for in-phase exposure on the (longer) softmax phase.
 #define UA_FA4_PIN_PACK_IN_SOFTMAX 0
 #endif

-// UA_FA4_PREFETCH_IN_SOFTMAX: issue the next-tile K/V async DRAM prefetch from
-// the SOFTMAX phase instead of the MATRIX phase. Gated to the 2-byte (bf16/fp16)
-// path in the loop (fp8 keeps the matrix-phase prefetch it was tuned for). The
-// bf16 prefetch is double the VMEM bytes and its buffer_load-to-LDS issue was
-// landing in the (lgkm-stalled) MATRIX phase; moving the *issue* to the VALU-
-// bound SOFTMAX phase keeps MATRIX pure-matrix. Residency is unchanged: the next
-// MATRIX still drains the load via s_waitcnt_vmcnt<0> + block barrier before any
-// K/V LDS read, so this only moves WHERE the async load is kicked off.
+// ON (bf16/fp16 only): issue the next-tile K/V async prefetch from the SOFTMAX
+// phase so the MATRIX phase stays pure-matrix. Residency is unchanged (the next
+// MATRIX still drains the load before any K/V LDS read); this only moves WHERE
+// the load is kicked off. fp8 keeps the matrix-phase prefetch it was tuned for.
 #ifndef UA_FA4_PREFETCH_IN_SOFTMAX
 #define UA_FA4_PREFETCH_IN_SOFTMAX 1
 #endif

-// FMHA_MASK PLACEMENT: pick exactly one of:
-// - both 0 → baseline (mask in K-side memory phase, W0-3 phase 1
-// / W4-7 phase 2, right after `cl_load(memK)`).
-// - MOVE_FMHA_MASK_TO_COMPUTE=1: hoist mask onto the compute phase
-// (W0-3 phase 0 / W4-7 phase 1), right after `fmha_alu1`.
-// Experiment 1.5 finding: bf16 −0.33%, **fp8 +8.8% regression**
-// because the FP8 cvt+bperm cluster inside `fmha_alu1` makes the
-// compute phase already-saturated; adding T_mask oversubscribes
-// it and the empirical cost is ~2× the bare instruction count.
-// - MOVE_FMHA_MASK_TO_GEMM1=1: place mask at the START of the
-// gemm1 phase (W0-3 phase 2 / W4-7 phase 3), right before
-// `cl_calc(xdl_SP_p23_reg_idx, gemm1)`. This is the latest legal
-// placement: `cl_calc(p23, gemm1)` ends with `fmha_alu0(p01_idx)`
-// which READS `sp[p01_idx].sp_compute` to compute `m_latest`, so
-// mask MUST run before that. Phase 3 (V-mem on W0-3, gemm1 on
-// W4-7) is too late and silently corrupts the row-max.
-//
-// For W4-7 the `++i_total_loops` also defers from end of phase 2
-// to start of phase 3 (after mask, before cl_calc) so mask sees
-// the same i_total_loops value as gemm0 of this iter.
-//
-// Per-barrier algebra (mask added to gemm1 phase = T_D on both
-// warp groups, removed from K-mem = T_K on both):
-// - B1 wait = |T_C − (T_D + T_mask)|. With baseline T_C > T_D
-// on FP8, the gap closes -- DECREASES by T_mask.
-// - B2 wait = |(T_K − T_mask) − T_C| -- DECREASES by T_mask.
-// - B3 wait = |(T_D + T_mask) − (T_K − T_mask)|
-// = |T_D − T_K + 2·T_mask| -- DECREASES by 2·T_mask.
-// - Net: −4·T_mask total wait (vs −2·T_mask for compute), and
-// gemm1 phase has no FP8 cvt+bperm so it should absorb the
-// mask without the FP8 oversubscription that hit compute.
-//
-// Must be defined BEFORE including unified_attention_core_loop_scheduler.hpp
-// -- that header's `__builtin_amdgcn_sched_group_barrier` per-phase
-// hints are gated on these macros and need to stay in lockstep with
-// the code motion in this file.
+// OFF (measured losers): alternative placements of the FMHA mask within the FA4
+// phases. Default (both 0) = mask in the K-side memory phase. _TO_COMPUTE
+// oversubscribes the fp8 compute phase (+8.8% regression); _TO_GEMM1 is the
+// latest legal placement (must precede the fmha_alu0 that reads sp_compute) but
+// did not win. At most one may be 1.
 #define MOVE_FMHA_MASK_TO_COMPUTE 0
 #define MOVE_FMHA_MASK_TO_GEMM1 0
 #if MOVE_FMHA_MASK_TO_COMPUTE && MOVE_FMHA_MASK_TO_GEMM1
 #error "MOVE_FMHA_MASK_TO_COMPUTE and MOVE_FMHA_MASK_TO_GEMM1 are mutually exclusive"
 #endif

-// UA_DYNAMIC_SETPRIO (warp-group-balance plan A2, HipKittens-style)
-// 0 (default): static per-warp-group priority, set once at loop entry
-// (W0-3 → s_setprio(0), W4-7 → s_setprio(1)). Baseline, bit-identical.
-// 1: dynamic priority around the gemm MFMA cluster. `cl_calc` raises
-// s_setprio(1) for the duration of the gemm (QK/PV MFMAs + trailing
-// fmha_alu0) and drops back to s_setprio(0) after. The two warp groups
-// are offset by two phases and co-resident (one wave of each group per
-// SIMD), so the group currently in the compute cluster outbids the
-// group currently issuing memory for the shared VALU/MFMA issue port --
-// targeting the ARBITER_NOT_WIN stall that gates the compute side
-// (W0-3: 37.8% of its stalls). Under the macro the static W4-7=1 entry
-// is neutralised to 0 so the non-compute baseline is uniformly prio 0.
+// OFF (measured loser): dynamic s_setprio around the gemm MFMA cluster (vs the
+// default static per-warp-group priority set once at loop entry).
 #ifndef UA_DYNAMIC_SETPRIO
 #define UA_DYNAMIC_SETPRIO 0
 #endif

-// UA_FA4_PACKED_SHIFT: emit the softmax score-shift (sp_delta = sp_compute *
-// scale_s - scale_s * rowmax) as packed v_pk_fma_f32 (2 f32/instr) instead of 64
-// scalar v_fma_f32. Bit-identical: each thread holds one rowmax
-// (m.thread_buf_.size()==1) so the FMA addend is uniform across the thread's
-// score elements and is broadcast into both packed lanes. Mirrors the hand-tuned
-// ASM softmax (v_pk_fma_f32 for the rebase). Halves the shift instruction count.
-//
-// MEASURED REGRESSION, default OFF. Together with UA_FA4_PACKED_ALU1_RESCALE this
-// costs ~3% on the canonical fp8 prefill shape (GPU2, same-session 3-run median:
-// packed 1825 TF/s vs scalar 1877 TF/s). The softmax score-shift is hidden under
-// the ping-pong overlap, so collapsing the FMAs does not shorten the critical
-// path; it only perturbs the scheduler and loses. (An earlier "+4.5%" reading was
-// a confounded GPU0 measurement.) Kept gated off for documentation; do not enable.
+// OFF (measured ~-3% together with UA_FA4_PACKED_ALU1_RESCALE): emit the softmax
+// score-shift as packed v_pk_fma_f32 instead of scalar v_fma_f32. Bit-identical,
+// but the shift is hidden under the ping-pong overlap so collapsing it only
+// perturbs the scheduler and loses.
 #ifndef UA_FA4_PACKED_SHIFT
 #define UA_FA4_PACKED_SHIFT 0
 #endif

-// UA_FA4_EXP2_APPROX: replace the per-element softmax exp (quarter-rate
-// v_exp_f32) with the Schraudolph 2^x bit-trick (full-rate). The score-shift FMA
-// in fmha_alu0 absorbs the 2^23 scale and the Schraudolph bias, so fmha_alu1 only
-// needs a single v_cvt_u32_f32 per element instead of v_exp_f32. The per-row
-// max-delta rescale keeps the exact v_exp_f32 (only 1/row, not on the hot path).
-// This is an APPROXIMATION (~1e-3 rel error per element) -- it mirrors the ASM
-// softmax fast SKU and is only applied on the non-masked, no-softcap path
-// (compile-time gate below). Numerics-changing => default OFF; validate accuracy
-// before enabling.
+// OFF (approximation, ~1e-3 rel error): replace the quarter-rate v_exp_f32
+// softmax with the full-rate Schraudolph 2^x bit-trick. Numerics-changing ->
+// validate accuracy before enabling; only legal on the non-masked, no-softcap
+// path (compile-time gate below).
 #ifndef UA_FA4_EXP2_APPROX
 #define UA_FA4_EXP2_APPROX 0
 #endif

-// Schraudolph 2^x bit-trick constants: bits = round(2^23 * x + (127*2^23 - 486411))
-// reinterpreted as f32 ~= 2^x. Min-error offset matches the hand-tuned ASM softmax.
+// Schraudolph 2^x constants: bits = round(2^23*x + (127*2^23 - 486411)) as f32.
 #define UA_EXP2_SCHRAUDOLPH_SCALE 8388608.0f // 2^23
 #define UA_EXP2_SCHRAUDOLPH_BIAS 1064866805.0f // 127*2^23 - 486411

-// UA_FA4_PACKED_ROWSUM: reduce a thread's row-sum of probabilities with packed
-// v_pk_add_f32 into a 2-wide partial, then a single scalar combine, instead of the
-// scalar v_add_f32 dependency chain that block_tile_reduce emits. Halves the
-// in-thread adds and shortens the latency chain feeding the cross-lane permlane.
-// Reassociates the sum (rounding differs at the ULP level) -- safe within the fp8
-// /bf16 attention tolerances.
-//
-// MEASURED LOSER (-13%): a serial v_pk_add_f32 accumulation is a 32-deep latency
-// chain that is WORSE than block_tile_reduce's log-depth tree, and the dead scalar
-// reduce is not DCE'd. Kept gated off for documentation; do not enable.
+// OFF (measured -13%): reduce a thread's probability row-sum with packed
+// v_pk_add_f32. The serial packed chain is deeper-latency than block_tile_reduce's
+// log-depth tree, and the dead scalar reduce is not DCE'd.
 #ifndef UA_FA4_PACKED_ROWSUM
 #define UA_FA4_PACKED_ROWSUM 0
 #endif

-// UA_FA4_PACKED_ALU1_RESCALE: pack the 6-register o_acc partial rescale in
-// fmha_alu1 (elementwise *= o_acc_scale) with v_pk_mul_f32, matching the packed
-// rescale in fmha_alu_D_upd. Independent elementwise scale (no dependency chain),
-// and it halves the number of asm-volatile scheduling boundaries (6 scalar
-// v_mul_f32 -> 3 v_pk_mul_f32). Bit-identical.
-//
-// MEASURED REGRESSION, default OFF -- see the note on UA_FA4_PACKED_SHIFT above:
-// the two together cost ~3% (1877 -> 1825 TF/s, GPU2) because the rescale is
-// hidden under the ping-pong overlap. (An earlier "+4%" reading was a confounded
-// GPU0 measurement.) Kept gated off for documentation; do not enable.
+// OFF (measured ~-3% together with UA_FA4_PACKED_SHIFT): pack the o_acc partial
+// rescale in fmha_alu1 with v_pk_mul_f32. Bit-identical, but the rescale is hidden
+// under the overlap so it does not shorten the critical path.
 #ifndef UA_FA4_PACKED_ALU1_RESCALE
 #define UA_FA4_PACKED_ALU1_RESCALE 0
 #endif

-// CONDITIONAL_RESCALE (PLAN_conditional_rescale Part 2)
-// 0 (default): always-rescale online softmax -- the o_acc/l accumulators are
-// renormalised to the true running max `m` every KV tile (the expensive
-// 128-VGPR `v_pk_mul_f32` rescale tail in fmha_alu_D_upd + the 6-reg
-// partial in fmha_alu1). Bit-identical to the pre-Part-2 kernel.
-// 1: FA4-style conditional (skipped) rescale. Carry the accumulators in the
-// frame of a *committed* max `m_commit` that only advances (with a rescale)
-// when the true max pulls ahead by more than τ = log2 of the safe exp2
-// bound. Between commits the shifted scores stay ≤ τ (exp2 ≤ 2^τ, fp32-
-// safe) so o_acc/l just accumulate -- the rescale multiplies are skipped.
-// The decision is made wave-uniformly (ballot: rescale if ANY lane needs
-// it) so the guard is a scalar branch with no divergence. Mathematically
-// exact (the m_commit frame cancels in o_acc/l; LSE uses m_commit), so no
-// end-of-loop correction is needed. Only applied on the 2-warp-group
-// prefill path (see kCondRescale); decode keeps always-rescale. Part-1's
-// --headroom instrument predicts ~85% (prefill) of rescales are skippable.
-// Defined BEFORE the includes so unified_attention_core_loop_scheduler.hpp can
-// gate its per-phase sched_group_barrier VALU hints on it (the gemm1+D_upd
-// phase reserves ~36 VALU slots for the rescale tail that this skips).
+// ON (FA4 prefill only): conditional (skipped) online-softmax rescale. Carry the
+// o_acc/l accumulators in the frame of a committed max `m_commit` that advances
+// (with a rescale) only when the true max pulls ahead by more than τ; between
+// commits the shifted scores stay <= τ (exp2 <= 2^τ, fp32-safe) so the rescale
+// multiplies are skipped. The decision is wave-uniform (ballot) so it is a scalar
+// branch with no divergence. Mathematically exact (m_commit cancels in o_acc/l;
+// LSE uses m_commit). Decode keeps always-rescale (see kCondRescale). Defined
+// before the includes so the core-loop scheduler can drop the rescale's VALU slot
+// reservation in lock-step.
 #if !defined(CONDITIONAL_RESCALE)
 #define CONDITIONAL_RESCALE 1
 #endif
-// τ in scaled-logit (log2) units. exp2(τ) bounds the un-rescaled scores; 8 =>
-// max intermediate exp2 == 256, comfortably inside fp32 range even summed over
-// thousands of keys. FA4 uses the same log2(256)=8.
+// τ in scaled-logit (log2) units; exp2(τ) bounds the un-rescaled scores. 8 =>
+// max intermediate exp2 == 256, fp32-safe even summed over thousands of keys.
 #if !defined(CONDITIONAL_RESCALE_TAU)
 #define CONDITIONAL_RESCALE_TAU 8.0f
 #endif
@@ -666,6 +580,12 @@ struct UnifiedAttentionPipeline
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
|
||||
// Design A deep async ring depth (decode: kDecodeStages; FA4: 2). The K
|
||||
// buffers occupy LDS slots [0, kRingStages); the V buffers follow at
|
||||
// [kRingStages, 2*kRingStages) -- the V store descriptor's K-buffer base
|
||||
// (KBufCount template arg) must therefore equal kRingStages too.
|
||||
constexpr index_t kRingStages = Policy::template GetRingStages<Problem>();
|
||||
|
||||
constexpr index_t KStoreWarpShift = Policy::template GetKStoreWarpShift<Problem>();
|
||||
auto k_lds_window_store = generate_tuple(
|
||||
[&](auto i_buf) {
|
||||
@@ -675,15 +595,17 @@ struct UnifiedAttentionPipeline
|
||||
KLoadNumWarps,
|
||||
KStoreWarpShift>(i_buf));
|
||||
},
|
||||
number<2>{});
|
||||
number<kRingStages>{});
|
||||
|
||||
auto v_lds_window_store = generate_tuple(
|
||||
[&](auto i_buf) {
|
||||
return make_lds_tile_window<KDataType>(
|
||||
smem_ptr,
|
||||
Policy::template MakeVLdsStoreBlockDescriptor<Problem, VLoadNumWarps>(i_buf));
|
||||
Policy::template MakeVLdsStoreBlockDescriptor<Problem,
|
||||
VLoadNumWarps,
|
||||
kRingStages>(i_buf));
|
||||
},
|
||||
number<2>{});
|
||||
number<kRingStages>{});
|
||||
|
||||
statically_indexed_array<
|
||||
decltype(make_tile_window(
|
||||
@@ -691,7 +613,7 @@ struct UnifiedAttentionPipeline
|
||||
nullptr,
|
||||
Policy::template MakeKLdsLoadBlockDescriptor<Problem, KLoadNumWarps>()),
|
||||
Policy::template MakeKRegTileDistribution<Problem>())),
|
||||
2>
|
||||
kRingStages>
|
||||
k_lds_window_load;
|
||||
|
||||
statically_indexed_array<
|
||||
@@ -700,7 +622,7 @@ struct UnifiedAttentionPipeline
|
||||
nullptr,
|
||||
Policy::template MakeVLdsLoadBlockDescriptor<Problem, VLoadNumWarps>()),
|
||||
Policy::template MakeVRegTileDistribution<Problem>())),
|
||||
2>
|
||||
kRingStages>
|
||||
v_lds_window_load;
|
||||
|
||||
decltype(make_static_distributed_tensor<QDataType>(
|
||||
@@ -783,8 +705,10 @@ struct UnifiedAttentionPipeline
|
||||
bool need_rescale = true;
|
||||
#endif
|
||||
|
||||
// initialize k_lds_window and v_lds_window
|
||||
static_for<0, 2, 1>{}([&](auto idx) {
|
||||
// initialize k_lds_window and v_lds_window. K buffers occupy LDS slots
|
||||
// [0, kRingStages); V buffers follow at [kRingStages, 2*kRingStages),
|
||||
// matching the V store descriptor's KBufCount==kRingStages base.
|
||||
static_for<0, kRingStages, 1>{}([&](auto idx) {
|
||||
k_lds_window_load(idx) = make_tile_window(
|
||||
make_lds_tile_window<KDataType>(
|
||||
static_cast<char*>(smem_ptr) + (idx)*Policy::template GetSmemSizeKV<Problem>(),
|
||||
@@ -792,11 +716,12 @@ struct UnifiedAttentionPipeline
                 Policy::template MakeKRegTileDistribution<Problem>());
         });

-        static_for<0, 2, 1>{}([&](auto idx) {
+        static_for<0, kRingStages, 1>{}([&](auto idx) {
             v_lds_window_load(idx) =
                 make_tile_window(make_lds_tile_window<VDataType>(
                                      static_cast<char*>(smem_ptr) +
-                                         (idx + 2) * Policy::template GetSmemSizeKV<Problem>(),
+                                         (idx + kRingStages) *
+                                             Policy::template GetSmemSizeKV<Problem>(),
                                      Policy::template MakeVLdsLoadBlockDescriptor<Problem,
                                                                                   VLoadNumWarps>()),
                                  Policy::template MakeVRegTileDistribution<Problem>());
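The slot layout set up in the two hunks above (K buffers in LDS slots `[0, kRingStages)`, V buffers in `[kRingStages, 2*kRingStages)`) reduces to plain offset arithmetic. The following is a minimal standalone sketch of that arithmetic, not the kernel's code: `kSlotBytes`, `k_slot_offset`, `v_slot_offset`, and `check_layout` are hypothetical names, and the 16 KiB slot size is an assumed stand-in for `GetSmemSizeKV<Problem>()`.

```cpp
#include <cassert>
#include <cstddef>

// Assumed value for illustration only; stands in for GetSmemSizeKV<Problem>().
constexpr std::size_t kSlotBytes = 16 * 1024;

// Byte offset of K ring slot `idx` from smem_ptr: K slots occupy [0, N).
constexpr std::size_t k_slot_offset(std::size_t idx) { return idx * kSlotBytes; }

// Byte offset of V ring slot `idx`: V slots follow the K slots at [N, 2N).
constexpr std::size_t v_slot_offset(std::size_t idx, std::size_t ring_stages)
{
    return (idx + ring_stages) * kSlotBytes;
}

// With N == 2 the V base is slot 2, i.e. the previously hard-coded (idx + 2).
static_assert(v_slot_offset(0, 2) == 2 * kSlotBytes, "N==2 matches the old layout");
// With N == 4 the last K slot ends exactly where the first V slot begins.
static_assert(k_slot_offset(3) + kSlotBytes == v_slot_offset(0, 4), "K and V regions abut");

// Runtime check: every K slot ends at or before the first V slot starts.
inline void check_layout(std::size_t ring_stages)
{
    for(std::size_t i = 0; i < ring_stages; ++i)
    {
        assert(k_slot_offset(i) + kSlotBytes <= v_slot_offset(0, ring_stages));
    }
}
```

Under these assumptions, raising the ring depth only moves the V base further out; the K offsets are unchanged, which is consistent with the default depth of 2 reproducing the original layout.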
@@ -2967,6 +2892,12 @@ struct UnifiedAttentionPipeline
             gemm(xdl_SP_p23_reg_idx, /*gemm_idx=*/number<1>{});
         };

+        // Set by the N-deep decode ring (kRingStages>2) when it finalizes the
+        // last tile's deferred PV *inline* (it must, because that tile's V lives
+        // in LDS ring slot r%kRingStages, not the parity slot fmha_post_process
+        // assumes). Suppresses the shared post-process below in that case.
+        bool decode_ring_final_pv_done = false;
+
         // pre-stage
         {
             ASM_MARKER("before pre-stage");
@@ -3020,12 +2951,19 @@ struct UnifiedAttentionPipeline
                 goto label_main_loops_exit;
             }

-            if(2 < num_total_loop)
+            // K2 prefetch into buf0 (the freed K0 slot). Only the 2-buffer
+            // pipeline does this here; the N-deep ring (kRingStages>2) fills
+            // K2..K(N-1) into their own distinct slots inside the decode
+            // block below, so it must NOT pre-issue K2 into buf0.
+            if constexpr(kRingStages == 2)
             {
-                K_mem_load(number<0>{}); // mem_K2
+                if(2 < num_total_loop)
+                {
+                    K_mem_load(number<0>{}); // mem_K2

-                s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
-                __builtin_amdgcn_s_barrier();
+                    s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
+                    __builtin_amdgcn_s_barrier();
+                }
             }

             ASM_MARKER("end pre-stage");
@@ -3043,6 +2981,8 @@ struct UnifiedAttentionPipeline
             // K1 in LDS (K buf 1) if num_total_loop >= 2
             // K2 loading to LDS (K buf 0) if num_total_loop >= 3

+            if constexpr(kRingStages == 2)
+            {
             // Step 1: consume V0, K1 -> produce PV(0), QK(1)
             s_waitcnt_vmcnt<0>();
             __builtin_amdgcn_s_barrier();
@@ -3119,6 +3059,93 @@ struct UnifiedAttentionPipeline
                     fmha_alu_D_upd();
                     i_total_loops++;
                 }
+            } // end if constexpr(kRingStages == 2)
+            else
+            {
+                // ===== Design A: N-deep async KV ring (decode) =====
+                // The pre-stage left: K_r0 consumed, K_r1 -> slot 1,
+                // V_r0 -> slot 0, and sp(0)=QK(r0) (alu1 pending). Fill the
+                // rest of the ring so K_r1..K_r(N-1) and V_r0..V_r(N-2) are
+                // all in flight (N-1 of each); slot(tile t) = t % N. Loads
+                // issued past the split end harmlessly reload the last tile
+                // into slots that are never consumed (K/V_mem_load stop
+                // advancing their DRAM window at the end), so no bound
+                // guards are needed and the in-flight count -- hence the
+                // vmcnt threshold below -- stays constant for every stage.
+                static_for<2, kRingStages, 1>{}(
+                    [&](auto m) { K_mem_load(number<m % kRingStages>{}); });
+                static_for<1, kRingStages - 1, 1>{}(
+                    [&](auto m) { V_mem_load(number<m % kRingStages>{}); });
+
+                // With the ring full, only the two oldest in-flight tiles
+                // (the K_r / V_(r-1) consumed this stage) must be resident;
+                // the other N-2 of each stay in flight -> deeper HBM overlap
+                // than the 2-buffer path's full vmcnt<0> drain. (N==2 would
+                // reduce to vmcnt<0>, i.e. the original; that case is handled
+                // by the branch above.)
+                constexpr index_t kRingVmcnt =
+                    (kRingStages - 2) * (K_mem_su_ld_insts + V_mem_su_ld_insts);
+
+                // Unroll by N so every LDS ring slot is a compile-time index.
+                // Stage s processes relative tile r = r_base + s, r_base ≡ 1
+                // (mod N): read K_r from slot (s+1)%N and V_(r-1) from slot s;
+                // refill the two slots freed last stage (K_(r-1) slot s,
+                // V_(r-2) slot (s-1)%N) at the top -- different LDS buffers
+                // than the ones read below, so the async prefetch has no WAR
+                // and overlaps this stage's PV+QK. sp keeps its 2-way
+                // deferred-PV parity (N is even).
+                while(i_total_loops < num_total_loop)
+                {
+                    static_for<0, kRingStages, 1>{}([&](auto s) {
+                        constexpr index_t k_rd = (s + 1) % kRingStages;
+                        constexpr index_t v_rd = s % kRingStages;
+                        constexpr index_t k_pf = s % kRingStages;
+                        constexpr index_t v_pf = (s + kRingStages - 1) % kRingStages;
+                        constexpr index_t sp_pv = s % 2;
+                        constexpr index_t sp_qk = (s + 1) % 2;
+
+                        if(i_total_loops < num_total_loop)
+                        {
+                            s_waitcnt_vmcnt<kRingVmcnt>();
+                            __builtin_amdgcn_s_barrier();
+
+                            slide_page_table();
+                            K_mem_load(number<k_pf>{}); // K_(r+N-1) -> freed K slot
+                            V_mem_load(number<v_pf>{}); // V_(r+N-2) -> freed V slot
+
+                            V_lds_load(number<v_rd>{}); // V_(r-1)
+                            s_waitcnt_lgkmcnt<0>();
+                            fmha_alu1(number<sp_pv>{}); // finalize sp(r-1) -> P
+                            gemm(number<sp_pv>{}, /*gemm_idx=*/number<1>{}); // PV
+
+                            K_lds_load(number<k_rd>{}); // K_r
+                            s_waitcnt_lgkmcnt<0>();
+                            gemm(number<sp_qk>{}, /*gemm_idx=*/number<0>{}); // QK -> sp(r)
+                            fmha_mask(number<sp_qk>{});
+                            fmha_alu0(number<sp_qk>{});
+                            fmha_alu_D_upd();
+                            i_total_loops++;
+
+                            // Last tile: its PV would be deferred to a stage
+                            // that never runs. Finalize it now -- V_r lives in
+                            // the compile-time slot (s+1)%N (== k_rd). Drain
+                            // first so V_r (left in flight by the staged wait)
+                            // is resident; this also tells the shared epilogue
+                            // not to re-run the (parity-indexed) post-process.
+                            if(i_total_loops >= num_total_loop)
+                            {
+                                s_waitcnt_vmcnt<0>();
+                                __builtin_amdgcn_s_barrier();
+                                V_lds_load(number<k_rd>{}); // V_r
+                                s_waitcnt_lgkmcnt<0>();
+                                fmha_alu1(number<sp_qk>{});
+                                gemm(number<sp_qk>{}, /*gemm_idx=*/number<1>{}); // PV
+                                decode_ring_final_pv_done = true;
+                            }
+                        }
+                    });
+                }
+            }
         }
         else
         {
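The per-stage slot arithmetic of the N-deep ring above can be checked in isolation. This is a hedged sketch rather than the kernel's implementation: `StageSlots`, `stage_slots`, `ring_vmcnt`, `parity_wraps`, and `check_no_war` are hypothetical helpers that only mirror the `constexpr index_t` definitions and the `kRingVmcnt` formula in the hunk, and the instruction counts passed to `ring_vmcnt` are made-up values.

```cpp
#include <cassert>

// Slot indices used by unrolled stage `s` of an N-deep ring.
struct StageSlots
{
    int k_rd, v_rd, k_pf, v_pf, sp_pv, sp_qk;
};

constexpr StageSlots stage_slots(int s, int n)
{
    return StageSlots{(s + 1) % n, s % n, s % n, (s + n - 1) % n, s % 2, (s + 1) % 2};
}

// vmcnt threshold: N-2 tiles' worth of K and V loads may remain in flight.
constexpr int ring_vmcnt(int n, int k_insts, int v_insts)
{
    return (n - 2) * (k_insts + v_insts);
}

// True when the sp buffer written by QK in the last unrolled stage is the one
// PV reads in stage 0 of the next pass. This holds only for even N.
constexpr bool parity_wraps(int n)
{
    return stage_slots(n - 1, n).sp_qk == stage_slots(0, n).sp_pv;
}

static_assert(parity_wraps(2) && parity_wraps(4), "even N keeps 2-way parity");
static_assert(!parity_wraps(3), "odd N would break parity across the unroll");
static_assert(ring_vmcnt(2, 4, 4) == 0, "N==2 degenerates to a full drain");

// Runtime check: within a stage, the prefetch never targets the slot being read.
inline void check_no_war(int n)
{
    for(int s = 0; s < n; ++s)
    {
        const StageSlots t = stage_slots(s, n);
        assert(t.k_pf != t.k_rd);
        assert(t.v_pf != t.v_rd);
    }
}
```

The `parity_wraps` check illustrates why the commit requires an even stage count: with an odd N, the score buffer produced by the last unrolled stage would not be the one the first stage of the next pass consumes.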
@@ -3203,8 +3230,11 @@ struct UnifiedAttentionPipeline
                 if(!(num_iters % 2))
                     fa4_epi(number<1>{});
             }
-            else
+            else if(!decode_ring_final_pv_done)
             {
+                // (The N-deep decode ring already finalized the last tile's PV
+                // inline -- see decode_ring_final_pv_done -- because its V sits in
+                // an LDS ring slot, not the parity slot this post-process assumes.)
                 if(num_iters % 2)
                 {
                     fmha_post_process(number<1>{});
@@ -687,6 +687,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
     // alone tile the full V buffer. Default = the shape's NumWarps (cooperative).
     template <typename Problem,
               ck_tile::index_t NumWarpsOverride = Problem::UnifiedAttentionShape::NumWarps,
+              ck_tile::index_t KBufCount = 2,
               ck_tile::index_t IBuf = 0>
     CK_TILE_DEVICE static constexpr auto
     MakeVLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
@@ -728,7 +729,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
                            number<WarpSize * KVector + kPad>{},
                            number<KVector>{},
                            number<1>{}),
-                number<(IBuf + 2) * GetSingleSmemElementSpaceSize<Problem>()>{},
+                number<(IBuf + KBufCount) * GetSingleSmemElementSpaceSize<Problem>()>{},
                 number<KVector>{},
                 number<1>{});
@@ -855,6 +856,33 @@ struct UnifiedAttentionPipelineDefaultPolicy
 #endif
     static constexpr bool kFA4WG1LoadsK = UA_FA4_WG1_LOADS_K;

+    // Design A (decode deep async ring). Number of K/V LDS landing buffers for
+    // the single-warp-group (decode) path: raising it from 2 keeps N-1 KV-tile
+    // DRAM loads in flight so the loop can stage `vmcnt` partial waits instead of
+    // a full per-tile drain (the memory-bound long-context decode regime). Must
+    // be EVEN (the deferred-PV score double-buffer keeps 2-way parity, which is
+    // only compile-time resolvable across the N-unroll when N is even) and >= 2.
+    // Default 2 == the original 2-buffer serial pipeline, bit-identical. The
+    // 2-warp-group FA4 (prefill) path always uses 2 regardless. LDS cost is
+    // 2*N*GetSmemSizeKV, so larger N may cost occupancy on the LDS-bound decode
+    // tiers -- sweep per tier.
+#ifndef UA_DECODE_STAGES
+#define UA_DECODE_STAGES 2
+#endif
+    static constexpr ck_tile::index_t kDecodeStages = UA_DECODE_STAGES;
+    static_assert(kDecodeStages >= 2 && (kDecodeStages % 2 == 0),
+                  "UA_DECODE_STAGES must be an even integer >= 2");
+
+    // Ring depth actually used by a given kernel instance: the deep ring is a
+    // decode-only (single-warp-group) lever; the FA4 prefill path keeps 2.
+    template <typename Problem>
+    CK_TILE_DEVICE static constexpr ck_tile::index_t GetRingStages()
+    {
+        constexpr ck_tile::index_t NumWarpGroups =
+            Problem::kBlockSize / NumThreadPerWarpGroup;
+        return (NumWarpGroups == 1) ? kDecodeStages : 2;
+    }
+
     // Number of waves that cooperate on a V DRAM->LDS load. For the 2-warp-group
     // FA4 path with kFA4WG0LoadsV, this is one warp group's waves (so WG0 alone
     // fills the tile); otherwise it's the full block (original cooperative load).
@@ -898,7 +926,10 @@ struct UnifiedAttentionPipelineDefaultPolicy
     template <typename Problem>
     CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
     {
-        return 4 * GetSmemSizeKV<Problem>();
+        // kRingStages K buffers + kRingStages V buffers. Decode uses
+        // kDecodeStages (>=2), FA4 prefill stays 2 -> the default (N==2)
+        // reproduces the original 4*GetSmemSizeKV budget exactly.
+        return 2 * GetRingStages<Problem>() * GetSmemSizeKV<Problem>();
     }
 };
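The ring-depth selection and LDS budget introduced in the policy hunks above can be sketched as standalone compile-time arithmetic. This is an illustration under stated assumptions, not the policy class itself: `ring_stages`, `smem_slots`, `valid_decode_stages`, and `checked_ring_stages` are hypothetical names, the warp-group count is passed in directly instead of being derived from `Problem::kBlockSize`, and the budget is expressed in units of one `GetSmemSizeKV` slot.

```cpp
#include <cassert>

// Same validity rule as the static_assert on UA_DECODE_STAGES: even and >= 2.
constexpr bool valid_decode_stages(int n) { return n >= 2 && n % 2 == 0; }

// Ring depth for a kernel instance: only the single-warp-group (decode) path
// takes the configurable depth; the 2-warp-group FA4 path stays at 2.
constexpr int ring_stages(int num_warp_groups, int decode_stages)
{
    return (num_warp_groups == 1) ? decode_stages : 2;
}

// LDS budget in units of one K/V slot: N K buffers plus N V buffers.
constexpr int smem_slots(int num_warp_groups, int decode_stages)
{
    return 2 * ring_stages(num_warp_groups, decode_stages);
}

static_assert(smem_slots(1, 2) == 4, "default depth matches the original 4 slots");
static_assert(smem_slots(2, 6) == 4, "FA4 prefill ignores the decode depth");
static_assert(smem_slots(1, 4) == 8, "depth 4 doubles the decode LDS footprint");
static_assert(!valid_decode_stages(3), "odd depths are rejected");

// Runtime variant that rejects invalid depths before selecting the ring depth.
inline int checked_ring_stages(int num_warp_groups, int decode_stages)
{
    assert(valid_decode_stages(decode_stages));
    return ring_stages(num_warp_groups, decode_stages);
}
```

The doubling at depth 4 is the occupancy cost the policy comment warns about, which is why the comment recommends sweeping the depth per decode tier.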