Softmax codegen wins for the canonical fp8 prefill shape (b1, sq=sk=75600,
hq=hk=5, d128, non-causal), matching the hand-tuned ASM softmax instruction mix:
* Packed score shift (UA_FA4_PACKED_SHIFT, default on): each thread holds one
rowmax, so the shift addend (-scale_s*max) is uniform across the thread's score
elements. Broadcast it and emit v_pk_fma_f32 (2 f32/instr) instead of 64 scalar
v_fma_f32 -> 237 v_pk_fma in the ISA. Bit-identical. +4.5%.
* Packed alu1 o_acc partial rescale (UA_FA4_PACKED_ALU1_RESCALE, default on):
pack the 6-register *= o_acc_scale with v_pk_mul_f32, halving the asm-volatile
scheduling boundaries. Bit-identical. +4% (barrier-wait drops 2504->1544 cyc).
Combined +8% (1649 -> 1783 TF/s standalone; ~1774 TF/s contiguous prod) over the
prior baseline; full regression matrix (bf16/fp8 prefill/decode/splitkv/long/ps16
+ fixtures) PASS.
Also consolidates the fastest fp8 prefill config as the compile-time default:
kv128 tile (UA_PREFILL_D128_BLOCKSIZE=128) + cooperative K/V load + single-sp +
wide 32x32x64 MMA, all 0-spill.
Gated-off experiments kept with measured verdicts: UA_FA4_EXP2_APPROX (Schraudolph
2^x, null -- exp is hidden by MFMA overlap) and UA_FA4_PACKED_ROWSUM (-13%, serial
chain worse than the tree reduce).
Co-authored-by: Cursor <cursoragent@cursor.com>
- Add strategy C (cvt-only, barrier-free) QK-C->PV-A FP8 relayout for the
K=64 wide v_mfma_f32_32x32x64 tile: QK-C and PV-A per-thread layouts
coincide under the wide MMA, so the relayout is just the fp32->fp8 pack
(matches the ASM kernel's _softmax_pack_P_fp8). Gate kFP8RelayoutWithinWave
for K=64 in addition to K=16; both are FA4-safe (no in-softmax barrier).
- Wire the wide-MMA variant config (example) + relayout default policy.
- Move the FA4 V LDS transpose-read out of the preceding SOFTMAX into the
MATRIX phase, off the longer/critical softmax path (UA_FA4_VLOAD_IN_MATRIX=1).
- Add UA_FA4_PIN_PACK_IN_SOFTMAX experiment toggle (default 0).
Measured: wide MMA closed the structural gap vs the ASM fp8 kernel from
~1.75x to ~1.16x at b1/h5/sq75600/d128 (1711 TF standalone).
Co-authored-by: Cursor <cursoragent@cursor.com>
Per-TU -DUA_STUB_INSTANCE emits a trivial host stub (no device kernel)
so a standalone trace build can compile exactly one UA instance and keep
its code object small -- avoids rocprofv3 ATT disassembling the full
instance set. Real launch vs stub selected via UA_KERNEL_DISPATCH_RESULT.
Co-authored-by: Cursor <cursoragent@cursor.com>
Checkpoint before merging upstream CK. Adds the kIsPaged=false kernel
instances (d64/d128 bf16/fp8, mask/nmask), folds kv_start into base
pointers, and fixes the non-causal KV-range envelope (scan full seq_len
when !FmhaMask::IsMasking instead of the causal horizon).
Co-authored-by: Cursor <cursoragent@cursor.com>
Under split-KV, a KV token co-owned by two query tiles (which happens only
when num_queries_per_kv does not divide kBlockM, e.g. d=128 qpkv=6) was
assigned its split partition from the per-tile causal horizon
(total_num_kv_blocks, which grows with the query tile index). The two owning
tiles then reduced disjoint KV-block ranges for that shared token and the
combine step merged partials over different ranges -> a ~1-row error / NaN on
the tile-boundary token. MHA and ratios that divide kBlockM are immune (no
token is shared across tiles).
Fix: derive blocks_per_split from the causal-INDEPENDENT full-sequence block
count so split s maps to the same blocks in every query tile, then clamp only
the END by the per-tile causal horizon. The duplicate co-owned store becomes
idempotent again. num_splits == 1 is unchanged.
Also adds the d128 bf16 page_size=128 decode instances (mask/nmask x
default/s/t) plus the matching dispatch in unified_attention.cpp and the
fmha_batch_prefill codegen hook.
Co-authored-by: Cursor <cursoragent@cursor.com>
Pipeline cleanup (-fav4):
* Delete the 8-wave compute/memory ping-pong baseline (the ~200-line
monolithic `core_loop` lambda + its 2-warp-group dispatch). It was
reachable only under -DUA_FA4_PIPELINE=0 and never beat FA4 on any
measured prefill shape, so it was dead under the default build.
* Drop the UA_FA4_PIPELINE toggle entirely. kFA4 is now derived purely
from NumWarpGroups==2 + the 32x32x16 within-wave FP8 P-relayout
invariant, with a static_assert pinning that every 2-WG instance is
FA4-capable (fails the build loudly instead of running an empty loop).
* Remove the now-orphaned ADD_SBARRIER_FOR_PHASE0/PHASE2 knobs (they
only gated barriers inside the deleted core_loop). MOVE_FMHA_MASK_*
stay (still consumed by the FA4 core-loop scheduler).
* The non-FA4 pre-stage + fmha_post_process epilogue are retained: they
are shared by the single-warp-group (NumWarpGroups==1) serial decode
path, where kFA4 is false.
Behaviour-preserving for the default build: FA4 prefill perf is bit-for-
bit unchanged (b16 sq=sk=10000 fp8 CK=5.76ms before/after) and the full
decode regression (d{64,128} x {bf16,fp8} x split-KV {2,64}) still PASSes.
Add opt-in prefill fallback knob (unified_attention.cpp):
* AITER_UA_PREFILL_FALLBACK=1 routes prefill-sized shapes to the 4-warp
single-warp-group *serial* decode_*_m128 instances instead of FA4.
Reuses already-compiled instances (no extra binary). OFF by default:
the serial path has no matrix/softmax overlap and measured ~0.66-0.70x
Triton vs FA4's ~0.73-0.80x on gfx950 fp8 GQA-12/2 (i.e. SLOWER than
FA4). Kept as a diagnostic / robustness A-B knob only.
Co-authored-by: Cursor <cursoragent@cursor.com>
Eight files outside the UA scope had drifted onto this branch over time
via earlier commits whose subject lines explicitly carried no "CK-UA:"
prefix — they are independent fmha bug fixes and codegen additions that
do not touch any code path the unified_attention example, kernel or
pipeline actually compiles or includes.
Reverted to the merge-base content of:
include/ck_tile/ops/fmha/block/block_masking.hpp (-39 lines)
added by 6729989b9 "Fix FMHA split-KV for paged-KV with
page_block_size < kN0"
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp (-2 lines)
added by ec2db01e4 "Fix fmha_fwd early-exit bug: seqlen_q <=
min_seqlen_q should be <"
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py
example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
example/ck_tile/01_fmha/fmha_fwd_runner.hpp
example/ck_tile/01_fmha/mask.hpp
added by 63821af1f / cb6fb2802 / c5600bc8a / e5272603c / 10564b0c4
/ cd7ba6e2e / 07ba03bcb — split-KV decode tiles, codegen tweaks,
and a sliding-window mask fix, all in the 01_fmha example program
(a separate build target; the UA example lives in
42_unified_attention and pulls in zero 01_fmha sources).
These commits are still reachable from the branch's reflog and from
their original commit hashes; they should each be cherry-picked onto
their own branches and sent upstream as standalone fmha bug-fix PRs —
they look like clean fixes that upstream would welcome, but they don't
belong in the UA PR's scope.
Verified empirically: clean JIT rebuild of module_unified_attention
followed by both regression shapes pass at full perf
b=128/sk=16384/d=128/bf16 : 1.5152 ms, 5672 GB/s, PASS
b=1/sk=1M/d=128/bf16 nb=70k : 0.7677 ms, 5594 GB/s, PASS
matching the pre-revert numbers to within run-to-run noise.
Branch's shared-CK touch surface after this revert: tile_scatter_gather.hpp
(+152 from our async_load_raw_long method), load_tile.hpp (+21 from the
sister dispatcher), warp_gemm[.|_dispatcher.]hpp (+13 for the FP8 e4m3
small-tile registration), and the new amd_global_load_lds_raw.hpp file.
Down from 14 shared files to 4.
Co-authored-by: Cursor <cursoragent@cursor.com>
The decode_d128_m16 tier was VGPR-saturated and LDS-bound on bf16/fp16
(probe_decode_d128 showed VGPR=256 + AGPR overflow, ~2x fp8's LDS at
the same kBlockN), capping it at 1 CTA/CU. Halving kBlockN for the
non-fp8 path on the m16 tier sheds enough LDS and VGPR pressure to
fit 3-4 CTAs/CU (LDS-bound). The halved kBlockN forces a smaller-K
MFMA shape on the m16 PV gemm (16x16x32 -> 16x16x16); we also auto-
adjust WarpGemm::K so PVAttrNumAccess picks Single vs Double access
correctly. The PVAttrNumAccess derivation is now generic — driven by
(kABKPerLane, SubMinDim) rather than just (dtype) — so the new
shape compiles without per-variant special-casing.
Variants only affected where cfg::BlockSize/2 >= WarpGemm::N (i.e.
decode_d128_m16); m32/m128/prefill keep their un-halved tiles since
they use 32x32 N-warps.
Co-authored-by: Cursor <cursoragent@cursor.com>
Promote the runtime `page_size` argument to a non-type template parameter
`kPageSize_` on UnifiedAttentionPipeline. Thread it through
unified_attention_kernel_traits and dispatch_variant<V> so the host-side
dispatcher routes on args.page_blk_size ∈ {16, 32, 64} to a constexpr-
pinned prefill instance; values outside that menu (or any decode variant)
fall back to the existing kPageSize_=0 runtime-page-size instance.
Two wins fold together on the prefill tiers:
1. Strength-reduction. Every `/ page_size`, `* page_size`, and `% page_size`
in the per-tile address chain collapses to a literal-folded shift /
multiply-by-magic (`/ 32` → shr 5, etc).
2. Wider Tier-0/Tier-2 gate. The scalar-promote + LDS-cache fast path now
uses the *real* precondition `KY0_step_N <= kPageSize` at compile time
instead of the conservative `KY0_step_N <= 16` hedge — so prefill_d128
bf16/fp16 (KY0_step_N=32), prefill_d64 fp8 (KY0_step_N=32), and
prefill_d64 bf16/fp16 (KY0_step_N=64) also enter the fast path at
their natural page sizes.
Measured impact (sq=sk=75600, MI355, n=30 iters, GQA-8):
variant KY0_step_N ps before after Δ
prefill_d128 fp8 16 32 119.0 111.5 -6.3 %
prefill_d128 bf16 32 32 132.7 130.3 -1.8 %
prefill_d64 fp8 32 32 80.9 68.1 -15.8 %
prefill_d64 bf16 64 64 74.4 73.4 -1.3 %
Decode variants stay on the kPageSize_=0 instances (Tier-0 gate gates them
out anyway — <8 warps — and the binary-size cost isn't justified). All
sweep_fp8.sh shapes + 21 multi-seed multi-sk-length prefill shapes
correctness-PASS. Pre-existing Tier-2 LDS-cache limit (4096 entries)
documented in the pipeline header — same constraint applies to the
kPageSize_=0 fallback so this is not a regression.
36 new prefill instance files: prefill_d{64,128} × {fp16, bf16, fp8} ×
{mask, nmask} × {ps16, ps32, ps64}.
Co-authored-by: Cursor <cursoragent@cursor.com>
The 32x32x16 tiers (prefill_d{64,128}, decode_d{64,128}_m{32,64,128}) keep
the cheap in-register `ds_bpermute_b32` cross-lane swap that fixes the
QK-C / PV-A per-thread alias for the union'd `sp_compute` / `p`.
The 16x16x32 m16 tiers (decode_d{64,128}_m16) cannot use the swap -- the
MFMA puts the paired-lane bit at a different position and the
sub=0/sub=1 4-fp8 chunks no longer map onto each other. We add a
layout-agnostic LDS roundtrip as the `else` branch, gated by the same
`PVWarpTile` constexpr:
- Hoist two distribution-bound windows over the existing `p_lds`
region (one bound to the QK-C output distribution, one to the PV-A
input distribution). Done once per kernel invocation.
- In `fmha_alu1`, after the cvt_pk_fp8_f32 packing chain, view the
union's bytes as a `static_distributed_tensor<fp8>` in the QK-C
distribution, `store_tile` it through `p_lds` in canonical (M, N)
order, `s_barrier`, then `load_tile` back with the PV-A
distribution and copy into `sp(idx).p`.
A/B'd a uniform LDS-roundtrip (no fast-path) vs the split: pure LDS
regressed decode_m128 by ~1.5x end-to-end (CK FP8 dropped from
~0.39x of Triton FP8 to ~0.16x), driven by the extra block-wide
barrier on the 4-warp decode path. Keeping the swap for 32x32x16
preserves the previously-tuned perf.
Dispatcher (`unified_attention.cpp`) now FP8-enables every UA variant
including decode_d{64,128}_m16. Four new instance .cpp files
(`unified_attention_d{64,128}_fp8_{mask,nmask}_decode_t.cpp`)
instantiate the m16 FP8 kernels.
Pytest (`test_unified_attention_ck_correctness.py`):
- 245 BF16/FP16: pass (no regression from the pipeline edit).
- 160 FP8: pass (was 112 before m16 enablement).
- 80 skipped: block_size<32 or query_len>kv_len -- pre-existing.
Single-shape m16 dispatches verified on gfx950:
b=128 sq=1 hq=hk=8 d=128 fp8 PASS (CK 0.109 ms / Triton 0.043 ms)
b=128 sq=1 hq=hk=8 d=64 fp8 PASS (CK 0.077 ms / Triton 0.039 ms)
Co-authored-by: Cursor <cursoragent@cursor.com>
Full pipeline support for FP8 (e4m3fn on gfx950 / e4m3fnuz on gfx942)
in the unified-attention kernel, gated to the 32x32x16 MFMA tiers in
both d=64 and d=128 ladders: prefill_d{64,128}, decode_d{64,128}_m128,
decode_d128_m32, and decode_d64_m64. The 16x16x32 _m16 tiers stay
BF16/FP16-only -- the QK-C and PV-A per-thread layouts there differ
by an M<->N swap that the current slot-swap fixup cannot express; a
full per-thread transpose (most likely via LDS) is needed.
Pipeline (unified_attention_pipeline.hpp):
* `fmha_alu1` now performs a cross-lane P-tile fixup right after the
FP8 packing of softmax(P). It's a `ds_bpermute_b32` between paired
lanes `lane ^ 32`, swapping sub=0 slot[k_base+4..k_base+7] with
sub=1 slot[k_base..k_base+3] for every 8-fp8 chunk. This realigns
the FP8 packed P operand with PV-A's `Single` AttrNumAccess
per-thread layout, which is necessary because the QK-C output and
PV-A input alias byte-for-byte via the sp_compute/p union -- and
for FP8 the two warp-gemm layouts no longer agree (BF16/FP16 keep
Double AttrNumAccess in the PV gemm, which matches QK-C natively).
Gated on `Gemm1WarpTile == 32x32x16`; FP8-only (BF16/FP16 paths take
the existing cvt_pk path unchanged).
Default policy (unified_attention_pipeline_default_policy.hpp):
* PV warp gemm now selects `WGAttrNumAccessEnum::Single` when V is
fp8/bf8 and `Double` otherwise. Forced by load_tile_transpose's
SubMinDim = 64-bit / sizeof(V) constraint: for FP8 SubMinDim=8 and
kABKPerLane=8 only Single satisfies the validation static_asserts.
* GetAlignmentK / GetAlignmentV on gfx950 drop to 4 B/lane for fp8/
bf8. The natural 16 B/lane async-load that BF16/FP16 use leaves
NumIssues = 0 for the FP8 tile shapes we compile, and 8 B/lane
fails the dword / dwordx3 / dwordx4 constraint in
amd_buffer_addressing_builtins. 4 B/lane gives NumIssues >= 1 on
every targeted variant and is the same alignment the gfx942
fallback already used. BF16/FP16 keep the full 16 B/lane path so
existing perf is unchanged.
* GetSmemSizeKV adds a `VLoadDescSize` lower bound. The
MakeVLdsLoadBlockDescriptor's element span dominates the banked
SingleVSize only for FP8 (small per-lane KVector + fixed
kVLdsPadInBytes = 64), so without it FP8 hits the GetSmemSizeKV
static_asserts. BF16/FP16 are unaffected.
Warp-gemm headers + dispatcher:
* New `WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed_T<AttrNumAccess>`
template alias in warp_gemm.hpp (mirrors the existing BF16 32x32x16
CTransposed template), used by the PV gemm to thread the FP8
Single AttrNumAccess through.
* New Dispatcher specialization for
<fp8_t, fp8_t, float, 32, 32, 16, true, false, false, EDouble>
in warp_gemm_dispatcher.hpp routing to the new template.
ABI / dispatcher (unified_attention.{cpp,hpp}, unified_attention_impl.hpp):
* New `fp8` value in `unified_attention_args::data_type_enum` (selects
e4m3fn on gfx950 via CK_TILE_USE_OCP_FP8, e4m3fnuz elsewhere).
* New `unified_attention_problem_traits<...::fp8>` alias:
qkvp_dtype = ck_tile::fp8_t, acc_dtype = float, o_dtype = bf16_t
(matches the Triton reference), lse_dtype = float.
* Per-tensor `q_descale` / `k_descale` / `v_descale` floats on
`unified_attention_args` (default 1.0f so non-FP8 round-trips
cleanly). The pipeline folds q_descale*k_descale into the softmax
scale and applies v_descale once to o_acc after the 1/l norm --
same semantics as Triton's q_scale/k_scale/v_scale.
* `dispatch_variant<>` enables FP8 on prefill_d{64,128},
decode_d{64,128}_m128, decode_d128_m32, decode_d64_m64. The
16x16x32 _m16 tiers return (false, -1.f) for now (see top comment).
Instances:
* 12 new FP8 .cpp files under example/.../42_unified_attention/
instances/ covering the 6 enabled variants x {mask, nmask}.
Validation: 112 / 0 / 128 in the FP8 pytest sweep (passed / failed /
m16-skipped); 245 / 245 in the BF16/FP16 sweep (no regression).
Functional correctness is within the FP8 quant-noise tolerance the
Triton FP8 suite uses (atol/rtol = 1.5e-1). Perf still trails Triton
across the enabled tiers (CK FP8 / Triton FP8 = 0.39-0.69x on the
shapes we benchmarked); that's a separate workstream.
Co-authored-by: Cursor <cursoragent@cursor.com>
The shared-SRD buffer_load_dword_lds path that K_mem_load / V_mem_load use
wraps the per-lane voffset (int32 bytes) once
num_blocks * page_size * row_stride * sizeof(T) > INT32_MAX,
silently returning wrong data on large paged-KV pools (e.g. >4 GB caches).
Add a second path, async_load_tile_raw_long, that issues the same load via
__builtin_amdgcn_global_load_lds with per-lane 64-bit base pointers, lifting
both 4 GB limits (SRD size + voffset). Per-issue LDS pointers are computed
explicitly because the intrinsic sets m0 itself, so the old m0_set / m0_inc
bookkeeping doesn't apply. The path also clamps lane_elem_off to the live
buffer range to mimic the original SRD's hardware OOB behaviour.
Dispatch is a wave-uniform runtime branch on a new
cache_ptr_int32_overflow_possible flag plumbed from
unified_attention_args through MakeKargs into the pipeline operator().
Small caches keep the original buffer_load throughput; only the (rare)
>4 GB cache pays the global_load_lds cost.
k_page_offsets / v_page_offsets are widened to long_index_t. The original
buffer_load path implicitly narrows back to int32 when forwarding through
async_get_vectorized_elements_raw, which is intentional and safe whenever
the overflow flag is false.
For diagnostics, also derive a constexpr KWaveSpanInN =
(LaneGroups - 1) * NumWarps + 1 inside the pipeline; when this exceeds
page_size a single buffer_load spans multiple random pages, so the
per-issue SRD-rebase optimisation (not implemented yet) would not apply
even on a sub-4 GB cache. Informational only today.
Test: ua-test-scripts correctness sweep (245/245 pass), plus
test_single_shape.py -b 32 -sq 8192 -sk 120000 -hq 64 -hk 8 -d 64 \
--num-blocks 1200000 --block-size 16 --test
which previously returned wrong data due to the int32 wrap and now passes
with max abs diff 1.22e-04 vs Triton.
Co-authored-by: Cursor <cursoragent@cursor.com>
After moving kBlockQ to runtime in the previous commit, the static
NumQPerKV in `variant_config<V>` and the runtime-vs-static assert in
the kernel became the only things still tying a compiled binary to a
specific num_queries_per_kv. Drop both and the existing instances now
serve every num_qpkv that divides kBlockM evenly.
Concretely:
* `variant_config<V>` -- remove the NumQPerKV field from every
specialization.
* `unified_attention_kernel_traits` -- remove the `num_queries_per_kv`
/ `kBlockQ = kBlockM / num_qpkv` derivation. The BlockTile's 2nd
entry (the static `kBlockQ` exposed via UnifiedAttentionShape) is
anchored at kBlockM so it describes the "num_qpkv == 1" fallback;
the actual kBlockQ is always the runtime value.
* `unified_attention_kernel_launch` -- recompute kBlockQ at host time
from `args.num_queries_per_kv` for the total_num_q_blocks math.
* `unified_attention_kernel.hpp` -- drop the
`assert(kBlockQ_dyn == kBlockQ)` (it enforced the very coupling we
just removed).
* `unified_attention.cpp::select_config` -- collapse the two
per-num_qpkv code paths into a single (head_dim, avg_rows,
max_rows) ladder, where avg_rows = avg_q * num_qpkv.
Variant renames (8 variants):
prefill_d128_mha -> prefill_d128
decode_d128_mha_m128 -> decode_d128_m128
decode_d128_mha_m32 -> decode_d128_m32
decode_d128_mha_m16 -> decode_d128_m16
prefill_d64_gqa8 -> prefill_d64
decode_d64_gqa8_m128 -> decode_d64_m128
decode_d64_gqa8_m64 -> decode_d64_m64
decode_d64_gqa8_m16 -> decode_d64_m16
The 16 d=64 instance files lose their `_gqa8` infix to match the
d=128 naming (file count unchanged: 16 dtypes x mask combos per
head_dim).
Validation:
* Correctness suite: 241/245 (same 4 pre-existing int32-overflow
failures in the prefill rebased-pointer path).
* d=128 GQA-8 (a NEW combo we never had a binary for) -- runs
correctly on the existing decode_d128_m* binaries with num_qpkv=8
at runtime. max abs diff <= 1e-2 vs the torch reference at ql in
{1, 4, 16}.
* d=64 MHA (also a new combo) -- runs correctly on the existing
decode_d64_m* binaries with num_qpkv=1. Same tolerance.
* Perf sweep (b=4..256, sk=120000, MI300):
d=64 GQA-8: speedups 1.28x..1.84x vs Triton (within 0.6%
of baseline).
d=128 MHA: speedups 0.98x..1.14x vs Triton (within 0.3%
of baseline).
Unlocked: adding new (head_dim, num_qpkv) combos no longer requires
new kernel binaries -- just a host-side heuristic update mapping the
combo to the appropriate (kBlockM, BlockWarps) ladder.
Co-authored-by: Cursor <cursoragent@cursor.com>
Replace 4 near-identical *_kernel_traits classes (~400 lines of repeated
shape/policy plumbing) with one templated `unified_attention_kernel_traits`
parameterized by `KernelVariant V`. The 6 dispatch_<variant> helpers in
unified_attention.cpp collapse into a single `dispatch_variant<V>` function
template that fans out over (dtype, mask).
Per-variant compile-time knobs (BlockM, BlockSize, warp count, MFMA shape,
pipeline policy, decode-grid flag) now live in one variant_config<V>
specialization each. "What's different between variants" is readable
top-to-bottom in a single block of code, and each instance .cpp shrinks to
a one-line `INST_UNIFIED_ATTENTION_DISPATCH(V, DTYPE, IS_MASK)` macro.
No behavior change. Correctness suite: 236/240 (same 4 known
num_blocks=32768 + d=128 MHA int32-overflow failures as baseline).
Co-authored-by: Cursor <cursoragent@cursor.com>
The bs32 variants existed because pre-fix the pipeline required
kBlockN <= page_size, so page_size=32 forced a kBlockN=32 kernel
family. The multi-page-tile fix (commit 473869aba) lifted that
constraint and made kBlockN compile-time-independent of the runtime
page size, so the bs32 family is now redundant: every non-bs32 variant
is correct for any page_size.
This was validated in advance by forcing use_bs32=false in the
dispatcher and running the full correctness suite -- 236/240, identical
to baseline (the 4 remaining failures are the pre-existing int32-
overflow case, orthogonal).
Removes:
* 16 instances/unified_attention_*_bs32_*.cpp files
* unified_attention_decode_bs32_kernel_traits in unified_attention_impl.hpp
* 3 _BS32 dispatch macros in unified_attention.cpp
* 3 _p32 entries from the KernelVariant enum
* 3 dispatch_*_p32 helper functions and their switch cases
* the page_blk_size branch in select_config (now a pure tile-tier ladder)
Net: 12 fewer compile units (build time -6s on JIT), 78 fewer dispatcher
lines, and "which kernel runs?" is now driven purely by Q-tile shape.
Co-authored-by: Cursor <cursoragent@cursor.com>
Until now every d=128 MHA workload took the 8-warp prefill kernel
(kBlockM=256, kBlockQ=256), wasting 255/256 Q rows on pure-decode
shapes where Q is 1. Add a dedicated 4-warp decode variant with
kBlockM=128 (kBlockQ=128) that cuts the Q-tile waste roughly in half.
* Four new instance files at instances/unified_attention_d128_*_decode.cpp,
each instantiating unified_attention_decode_kernel_traits<dt, mask, 128, 128, 1>.
* KernelVariant::decode_d128_mha_m128 wired into select_config: chosen
when both avg_q and max_seqlen_q fit in 128, else fall back to prefill.
Tests: ua-test-scripts/test_unified_attention_ck_correctness.py stays at
236/240 -- the pure-decode seq_lens pattern in head_config=(16,16,128)
now routes to the new variant and matches the torch reference. The 4
remaining failures are the pre-existing int32-overflow case (orthogonal).
Co-authored-by: Cursor <cursoragent@cursor.com>
The previous dispatcher was a 4-deep nested-if cascade that picked one
of seven DISPATCH_* macros based on (hdim, num_queries_per_kv, dtype,
mask, tile_tier, use_bs32). The macro names hid both the traits class
and the dispatch path, so reasoning about "what kernel runs for shape
X" required reading the whole file.
Replace it with two named layers:
1. KernelVariant enum -- a flat list of every compiled instance.
2. select_config(args) -- the only place runtime decisions live;
reads the problem and emits a KernelConfig{variant, ...}.
The final switch over the variant calls into per-variant dispatch
helpers that fan out over (dtype, mask) via the existing DISPATCH_*
macros. Behaviour is unchanged: each old (hdim, nqpkv, tier, p32) tuple
maps 1:1 to a KernelVariant, and the same instance is launched.
Follow-up commits in this series will:
- add a dedicated d=128 MHA decode variant
- delete the _p32 ("bs32") family now that the multi-page-tile fix
in the pipeline makes kBlockN independent of page_size
Test: ua-test-scripts/test_unified_attention_ck_correctness.py
stays at 236/240 (same 4 pre-existing int32-overflow failures).
Co-authored-by: Cursor <cursoragent@cursor.com>
End-to-end split-KV (FlashDecoding-style) for the CK unified attention
kernel. The host launches a single 3D grid with z == num_splits; each
CTA computes its KV-range slice and writes a normalized (o_acc, lse)
partial to FP32 workspaces, which the caller reduces into the final
output.
Pipeline changes:
- operator() returns ck_tile::make_tuple(o_acc, lse) instead of just
o_acc. The masked-empty early-exit returns lse = -inf so downstream
combine weighs the partial as zero.
- LSE is built in the natural-log domain from the pipeline's *unscaled*
rowmax: lse = (scale_s / log2(e)) * m + log(l). Previously we used
m / log2(e) + log(l), which dropped the per-head scale and produced
LSE values ~1/scale too large.
- Fix post-process parity: which SP register is left in the
alu0-done/alu1-pending state at loop exit depends on the parity of
the *iteration count* (= num_total_loop - num_blocks_start), not on
num_total_loop alone. For non-split (num_blocks_start == 0) the two
parities coincide; for splits starting at an odd tile they don't.
- Fix split-KV page-table offset: num_blocks_start is counted in
kPageBlockSize-sized tiles, but block_tables is indexed in
page_size-sized pages — shifting block_table_offset by num_blocks_start
reads the wrong pages whenever kPageBlockSize != page_size. Replaced
with split_token_offset = num_blocks_start * kPageBlockSize added to
logical_token before /page_size, so the page lookup uses the absolute
token position.
Kernel + dispatcher:
- Drop kargs.i_split; each CTA reads i_split = blockIdx.z.
- GridSize{2D,Decode} now take num_splits and add it as the z-dim
(defaults to 1, so non-split callers see dim3(..., 1, 1)).
- New write path: when num_splits > 1, the kernel skips the user
epilogue and instead writes the FP32 (o_acc, lse) tile pair into
workspace tensors at [head, split, batch_start_token, ...] using
Default2DEpilogue (UseRawStore=true) for o_acc and store_tile for
lse. Host strides come from kargs.
Co-authored-by: Cursor <cursoragent@cursor.com>
Decode grid (num_kv_heads, num_seqs) assumes each seq has <= kBlockQ
tokens. For mixed batches (decode + prefill), avg_q is low but some
seqs have hundreds of tokens, causing truncation. Added max_seqlen_q
to args and check it in select_tile_tier to force medium tier (1D
grid with Q tile iteration) for mixed batches.
362/362 no-window shapes now pass.
Made-with: Cursor
mask_info::decode('b:left,right,sink') always created mask_bottom_right
(IsLocal=false) which ignores the left window boundary. For sliding
window attention (left >= 0), use window_generic (IsLocal=true) so the
kernel respects the left boundary.
Fixes: CK split-KV producing identical results with and without sliding
window. Now 724/724 shapes pass correctness vs Triton.
Made-with: Cursor
1. Dual-tile: add both bn0=64 (preferred) and bn0=32 (fallback) for
hdim=64 on gfx9 and gfx12. The dispatch checks page_block_size %
bn0 == 0 at runtime to select the optimal tile. bn0=64 halves KV
iterations when page_block_size >= 64.
2. Tile dict now supports lists per hdim. The codegen loop iterates
over all tile variants, generating separate kernel instances for
each. Combine kernels are unaffected (tile-independent).
3. Enable kMergeNumHeadGroupsSeqLenQ for hdim=64 decode (previously
hdim=128 only). For GQA-8 with max_seqlen_q=1, this packs 8 head
groups into the M dimension. Only activates for no-mask instances
(kernel static_assert requires !kHasMask).
4. Add qr (non-async) pipeline for fwd non-bias group mode as
fallback after qr_async. The async pipeline on this branch has a
kernel-level bug where fmha_fwd launches but writes no output.
Made-with: Cursor
[CK_TILE] Fix Windows build in FMHA head grouping
## Motivation
This is a follow-up fix for [PR
#5018](https://github.com/ROCm/rocm-libraries/pull/5018).
[PR #5018](https://github.com/ROCm/rocm-libraries/pull/5018) added
LLC-aware FMHA head grouping / head-major scheduling on RDNA, but it
also introduced Linux-only code paths, including `<dirent.h>`, which
break Windows builds. This change fixes that by guarding the
Linux-specific LLC probing logic so non-Linux platforms can still build
correctly.
## Technical Details
- Guard `<dirent.h>` with `#ifdef __linux__`
- Guard KFD sysfs traversal logic with `#if defined(__linux__)`
- On non-Linux platforms, return `0` from
`get_kfd_sysfs_llc_cache_bytes()`
- Preserve existing fallback behavior through:
- `CK_TILE_FMHA_LLC_CACHE_MB`
- arch-based default LLC sizes
- no head grouping when no LLC size can be resolved
## Test Plan
<!-- Explain any relevant testing done to verify this PR. -->
## Test Result
<!-- Briefly summarize test outcomes. -->
## Submission Checklist
- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
[CK_TILE] add tf32 support
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
## Proposed changes
TF32 is added in CK on gfx942 and gfx950. This PR is to initiate tf32 in
CK_TILE on gfx942 and gfx950.
## Checklist
Please put an into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.
- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers
with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [x] I have run on all changed files
- [ ] Any dependent changes have been merged
## Discussion
CK Tile MX GEMM Packing Improvement
## Motivation
Reduce the scale loading size and also has better utilization of MFMA
scale selection.
## Technical Details
Add up the packing of mx scales.
## Test Plan
Use the existing test cases.
## Test Result
<!-- Briefly summarize test outcomes. -->
## Submission Checklist
- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
[CK_TILE] Add LLC-aware FMHA head grouping and head-major
scheduling on RDNA (#5018)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
## Motivation
Long-sequence FMHA can become memory-bound when K/V working sets exceed
Infinity Cache (LLC), causing repeated DRAM traffic across heads.
This PR introduces LLC-aware launch ordering improvements for FMHA
forward, and it is currently enabled only on gfx11 and gfx12. The
approach is inspired by
[`Dao-AILab/flash-attention#2217`](https://github.com/Dao-AILab/flash-attention/pull/2217),
adapted to CK’s kernel/runner structure and layout handling.
In this context, `bshd` is the layout used in Flash-Attention, while
`bhsd` is the default layout used by the CK Tile FMHA example.
## Technical Details
This PR adds two complementary strategies:
- For `bshd` input layout (`i_perm/o_perm=0`), enable explicit LLC-aware
head grouping:
- Estimate LLC size (env override, KFD sysfs, or arch default).
- Compute group size from K/V bytes per head vs LLC target.
- Launch FMHA forward repeatedly per head-group by slicing Q/K/V/O (and
related tensors).
- For `bhsd` input layout (`i_perm/o_perm=1`), apply implicit
launch-order adjustment:
- Keep a single kernel launch.
- Reinterpret block linearization in `GetTileIndex` to make execution
head-major,
improving temporal locality of per-head K/V reuse.
Additional integration updates:
- Propagate `num_head_q_total` and `head_start` through FMHA args/kargs.
- Use global head indexing for dropout RNG stream mapping so grouped
launches keep
deterministic/consistent dropout behavior.
- Keep fallback behavior unchanged when grouping is not beneficial or
disabled.
## Test Plan
- `test_ck_tile_fmha`
- `tile_example_fmha_fwd`
## Test Result
- `test_ck_tile_fmha`: all tests passed.
- `tile_example_fmha_fwd`: tested this on gfx1100, gfx1151, and gfx1201,
and all of them show higher performance compared to the baseline. The
improvement is consistent, and performance is well maintained even at
long sequence lengths.
./build/bin/tile_example_fmha_fwd -prec=bf16 -mode=0 -b=1 -h=24 -d=128
-s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}
- TFLOPs by sequence length target: gfx1100 layout: bhsd
SeqLen | Before | After | Speedup
-- | -- | -- | --
1024 | 56.27 | 61.48 | 1.09x
4096 | 67.10 | 72.27 | 1.08x
8192 | 65.99 | 71.64 | 1.09x
12288 | 61.60 | 76.61 | 1.24x
16384 | 58.99 | 75.74 | 1.28x
20480 | 57.32 | 74.42 | 1.30x
24576 | 56.89 | 74.25 | 1.31x
27280 | 18.93 | 24.48 | 1.29x
- TFLOPs by sequence length target: gfx1201 layout: bshd
SeqLen | Before | After | Speedup
-- | -- | -- | --
1024 | 66.79 | 65.90 | 0.99x
4096 | 85.90 | 86.80 | 1.01x
8192 | 77.06 | 90.29 | 1.17x
12288 | 58.36 | 88.98 | 1.52x
16384 | 52.12 | 88.88 | 1.71x
20480 | 48.11 | 88.42 | 1.84x
24576 | 47.12 | 89.07 | 1.89x
27280 | 49.05 | 50.31 | 1.03x
## Submission Checklist
- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
[CK Tile] Eight Waves pipeline GEMM
## Motivation
Eight waves pipeline was added for ABQuant. The goal of this PR is to
enable it also for GEMM
## Technical Details
Summary:
- Block:
- Create block struct for GEMM using eight warps specific distribution
encodings
- Use this block struct in ABQuant for encodings
- Pipeline:
- Create impl pipeline for eight waves which can be used by GEMM and
ABQuant as base (and for AQuant and BQuant in the future)
- Create eight waves pipeline for GEMM (this can not be easily
integrated in the existing async pipeline)
- Pipeline policy:
- Extract GEMM specific parts in the ABQuant policy to define GEMM
policy (then ABQuant use it as base and add Quant specific methods)
- Minor: naming was inconsistent between warp/wave, everything is now
referred to as eight waves
So overall we have:
- block struct directly used by GEMM -> ABQuant derived struct to
implement operator
- Impl base pipeline with general implementation -> GEMM and ABQuant
pipelines use it to avoid code duplication but still define their own
pipelines
- pipeline policy struct directly used by GEMM -> ABQuant derived policy
struct for Quant specific parts
## Test Plan
Added new tests for GEMM pipeline:
`test_ck_tile_gemm_pipeline_comp_async_eight_waves` (only gfx950
supports it).
Note: K padding test is disabled for this pipeline because it's not
implemented yet
## Submission Checklist
- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
[CK][CK Tile] Grouped Convolution Backward Weight set of
fixes (#5387)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
## Motivation
Grouped Convolution Backward Weight split k fixes for CK tile kernels
## Technical Details
- get k batch from kargs to get deduced k batch
- multiply zeroing size by data type size
- disable v6 (producing a incorrect results)
## Test Plan
test_grouped_convnd_bwd_weight_tile
## Test Result
Pass
## Submission Checklist
- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
[CK][Examples] Adding parameters for a couple of CK examples:
-gemm_add_add_mean_meansquare_xdl_fp16 -gemm_dl_quantization_int8
-gemm_xdl_bias_relu_quantization_int8 -gemm_xdl_quantization_int8
Signed-off-by: Michal Kulikowski <Michal.Kulikowski@amd.com>
ck_tile: add gtest unit tests for MX flatmm (gfx950)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
## Summary
- Add correctness unit tests for the MX-format flatmm kernel
(`example/ck_tile/18_flatmm/mxgemm`) under `test/ck_tile/flatmm/`
- Tests cover all five dtype combinations: FP4×FP4, FP8×FP8, FP6×FP6,
FP8×FP4, FP4×FP8
- Tests cover all four kernel dispatch paths (the `has_hot_loop` ×
`tail_num` product):
- `has_hot_loop=false, tail=ODD` (K=256, num_loop=1)
- `has_hot_loop=false, tail=EVEN` (K=512, num_loop=2)
- `has_hot_loop=true, tail=ODD` (K=768, num_loop=3)
- `has_hot_loop=true, tail=EVEN` (K=1024, num_loop=4)
- Remove unsupported `-split_k` CLI option from
`tile_example_mx_flatmm`; the pre-shuffled B layout is incompatible with
K-splitting and the option silently produced wrong results
## Changes
**New files (`test/ck_tile/flatmm/`):**
- `CMakeLists.txt` — builds 40 kernel instances as a shared OBJECT
library, links into 5 per-dtype test executables; forwards
`-DCK_TILE_USE_OCP_FP8` when `CK_USE_OCP_FP8` is ON
- `test_mx_flatmm_base.hpp` — base test fixture with
`run_test_with_validation(M, N, K, kbatch=1)`
- `test_mx_flatmm_fixtures.hpp` — concrete `TestMXFlatmm` typed test
class and type aliases
- `test_mx_flatmm_fp{4fp4,8fp8,6fp6,8fp4,4fp8}.cpp` — per-dtype
`TYPED_TEST_SUITE` files
**Modified files:**
- `example/ck_tile/18_flatmm/mxgemm/mx_flatmm_arch_traits.hpp` — moved
`preShuffleWeight` here (was in `mx_flatmm.cpp`) so it is includeable by
both the example and the tests
- `example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp` / `run_mx_flatmm.inc`
— removed `-split_k` CLI arg, hardcoded `k_batch=1`, fixed `k_split`
formula, updated call sites after `preShuffleWeight` move
- `test/ck_tile/CMakeLists.txt` — added `add_subdirectory(flatmm)`