Commit Graph

1578 Commits

Author SHA1 Message Date
juuso-oskari
8abbd21a01 CK-UA: VGPR-pressure toggles for kv128 probing (all default OFF)
Adds compile-time levers, all guarded and bit-identical to production when
unset, used to characterise why prefill_d128 fp8 fits KV tile 64 but not 128
under the 256-VGPR/wave ceiling (see ua-test-scripts/kv128_vgpr_findings.md):

- UA_PREFILL_D128_BLOCKSIZE (default 64): KV-tile override for probing kv128.
- UA_FA4_INPLACE_DELTA (default 0): drop sp_delta, scale-shift/exp2 in place on
  sp_compute (fmha_alu_D_upd reads only m/l/o_acc/rowsum_p, never raw scores, so
  bit-identical). VGPR-neutral on its own (compiler already reclaims sp_delta).
- UA_FA4_SHARED_SPCOMPUTE (default 0): keep ONE shared fp32 sp_compute + a
  2-slot fp8 P ping-pong instead of a 2-slot union{sp_compute,p}. The deferred
  PV only needs one live fp32 score; this cuts kv128 spills 173 -> 126. (Forces
  in-place delta; slightly regresses kv64 so it is a kv128-only lever.)
- UA_FA4_UNION_KV (default 0): union k_tile/v_tile (ASM-style). VGPR-neutral;
  kept as a documented dead end (compiler already overlaps their live ranges).

P thread-buffer size exposed as a type-derived constexpr (kPThreadBufSize) so
the static_assert/static_for sites work when sp(idx) is the runtime proxy.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-11 15:29:04 +00:00
juuso-oskari
9aa380e6c2 CK-UA: wide 32x32x64 FP8 MMA with cvt-only P relayout + V-read in MATRIX
- Add strategy C (cvt-only, barrier-free) QK-C->PV-A FP8 relayout for the
  K=64 wide v_mfma_f32_32x32x64 tile: QK-C and PV-A per-thread layouts
  coincide under the wide MMA, so the relayout is just the fp32->fp8 pack
  (matches the ASM kernel's _softmax_pack_P_fp8). Gate kFP8RelayoutWithinWave
  for K=64 in addition to K=16; both are FA4-safe (no in-softmax barrier).
- Wire the wide-MMA variant config (example) + relayout default policy.
- Move the FA4 V LDS transpose-read out of the preceding SOFTMAX into the
  MATRIX phase, off the longer/critical softmax path (UA_FA4_VLOAD_IN_MATRIX=1).
- Add UA_FA4_PIN_PACK_IN_SOFTMAX experiment toggle (default 0).

Measured: wide MMA narrowed the structural gap vs the ASM fp8 kernel from
~1.75x to ~1.16x at b1/h5/sq75600/d128 (1711 TF standalone).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-11 14:41:47 +00:00
juuso-oskari
a4d3ff34fb CK-UA: widen FP8 K/V async loads to dwordx4 (per-WG load-thread count)
The K/V async-load width selector (GetKVAlignmentBytes via GetAlignmentK/V)
computed its dwordx4 budget from the full block (kBlockSize=512 threads), so
the 4 KB FP8 prefill tile never tiled cleanly and fell back to dword. With the
FA4 per-warp-group decoupling a single 4-warp group (256 thr) fills the tile by
itself -> 4 KB / 256 = exactly 16 B/thr = dwordx4.

Thread the load-thread count into GetAlignmentK/V as a NumWarps template param
(default = shape NumWarps, so all sizing/paged/decode instantiations are
byte-identical). Only the load-path callers (MakeK/VDramTileDistribution, the
LDS store/load descriptors, and the kAlignmentK/V DRAM-view vector) pass the
decoupled GetK/VLoadNumWarps count to unlock the wide load.

Effect: 36 global_load_lds dword issues -> 9 dwordx4, and 36 buffer_load dword
issues -> 9 dwordx4 (both runtime branches); VGPR 181->173; LDS/SGPR unchanged.
Accuracy PASS, 0% mismatch (non-causal + causal). Latency-neutral on sq8192 (the
kernel is memory-latency bound, not load-issue bound) but a strictly better
instruction/VGPR footprint.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-09 16:40:21 +00:00
juuso-oskari
c11722bf3e CK-UA: decouple K/V DRAM loads per warp-group + early V read (FA4 fp8)
Split the cooperative K/V cache loads across the two FA4 warp groups so
each group owns exactly one tile's DRAM load and address arithmetic:
WG0 loads V, WG1 loads K, and both read from the shared LDS buffers.

- kFA4WG0LoadsV / kFA4WG1LoadsK policy flags + GetVLoadNumWarps /
  GetKLoadNumWarps: the owning group's 4 waves alone fill the tile via
  4-warp descriptors; the partner skips the load and reads from LDS.
- High-warp-group support for the raw async path: the raw store bakes the
  absolute warp id into the LDS M0, so WG1 (waves 4-7) needs a base shift
  (GetKStoreWarpShift / WarpIdShift in MakeKLdsStoreBlockDescriptor) to
  map back to the 4-warp layout, plus WG-relative (warp % NumWarps) page
  offsets so the gather token positions are correct.
- Stage B: move each tile's V LDS read into the PRECEDING softmax phase so
  the read latency hides under softmax VALU. Safe because V is now single-
  group-owned; uses drain-before-barrier (vmcnt<0> then s_barrier) so all
  4 cooperating writer waves' slices are published before the read.
- Gate per-tile offset refresh per warp-group (WG0 refreshes V, WG1 K), so
  each wave fetches a block-table page index for one tile instead of both;
  loop counters stay uniform.

Validated 0% mismatch vs GPU reference, causal + non-causal, sq 256..8192.
Net latency vs the cooperative baseline: causal about -3% to -4.6%,
non-causal about -2% to -4.7% across sq 2048..16384 (d128 fp8).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-09 14:35:00 +00:00
juuso-oskari
26bc49f733 CK-UA: un-union kv_tile so K ds_read overlaps PV MFMA (FA4 fp8)
kv_tile held k_tile/v_tile in a *union* to save VGPRs, but on gfx950 the
union forced a hard serialization: K_lds_load wrote the same registers
the PV MFMA was reading (v_tile), so the K ds_read could not start until
the PV MFMA retired -> full LDS latency exposed at the QK gemm's
s_waitcnt_lgkmcnt<0> (ATT: ~half of all memwait stall).

Make k_tile/v_tile separate registers and pin K_lds_load between the PV
and QK MFMAs with sched_barrier so its ds_read executes on the LSU
*during* the PV MFMA (latency hidden) without being hoisted above PV
(which would race the partner WG's cooperative K load on long contexts).

Occupancy is VGPR-bound here (160KB LDS), and .vgpr_count is unchanged
(172 -> 172), so the change is free. Standalone fp8 d128 sq8192:
~515.7 -> ~497.5 us (-3.7%), memwait 31% -> 19%, accuracy 0% mismatch.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-09 10:46:56 +00:00
juuso-oskari
5a44f5885f WIP(unified_attention): contiguous/non-paged kIsPaged path + non-causal full-range fix
Checkpoint before merging upstream CK. Adds the kIsPaged=false kernel
instances (d64/d128 bf16/fp8, mask/nmask), folds kv_start into base
pointers, and fixes the non-causal KV-range envelope (scan full seq_len
when !FmhaMask::IsMasking instead of the causal horizon).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-08 14:24:05 +00:00
juuso-oskari
0f009a3442 CK-UA: fix split-KV partition for non-dividing GQA + add ps128 decode instances
Under split-KV, a KV token co-owned by two query tiles (which happens only
when num_queries_per_kv does not divide kBlockM, e.g. d=128 qpkv=6) was
assigned its split partition from the per-tile causal horizon
(total_num_kv_blocks, which grows with the query tile index). The two owning
tiles then reduced disjoint KV-block ranges for that shared token and the
combine step merged partials over different ranges -> a ~1-row error / NaN on
the tile-boundary token. MHA and ratios that divide kBlockM are immune (no
token is shared across tiles).

Fix: derive blocks_per_split from the causal-INDEPENDENT full-sequence block
count so split s maps to the same blocks in every query tile, then clamp only
the END by the per-tile causal horizon. The duplicate co-owned store becomes
idempotent again. num_splits == 1 is unchanged.
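
The partition change described above can be sketched on the host as follows. This is a minimal model, not the kernel code: `BlockRange`, `ceil_div`, `split_range_per_tile` and `split_range_fixed` are hypothetical names, and ceil-division rounding is an assumption.

```cpp
#include <algorithm>
#include <cassert>

// Half-open range of KV blocks reduced by one split.
struct BlockRange
{
    int begin; // first KV block (inclusive)
    int end;   // one past the last KV block (exclusive)
    constexpr bool empty() const { return begin >= end; }
};

// Assumption: splits are sized by ceil-division; the real kernel may round differently.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Pre-fix model: the split size follows the per-tile causal horizon, so the
// same split index can start at different blocks in different query tiles.
constexpr BlockRange split_range_per_tile(int split, int num_splits, int tile_horizon_blocks)
{
    const int bps = ceil_div(tile_horizon_blocks, num_splits);
    return {split * bps, std::min((split + 1) * bps, tile_horizon_blocks)};
}

// Post-fix model: the split size follows the full-sequence block count, and
// only the END of the range is clamped by the per-tile causal horizon.
constexpr BlockRange split_range_fixed(int split,
                                       int num_splits,
                                       int full_seq_blocks,
                                       int tile_horizon_blocks)
{
    const int bps = ceil_div(full_seq_blocks, num_splits);
    return {split * bps, std::min((split + 1) * bps, tile_horizon_blocks)};
}
```

With 16 full-sequence blocks, 4 splits, and two tiles whose horizons are 8 and 12 blocks, the pre-fix model starts split 1 at block 2 in one tile and block 3 in the other, while the post-fix model starts it at block 4 in both.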

Also adds the d128 bf16 page_size=128 decode instances (mask/nmask x
default/s/t) plus the matching dispatch in unified_attention.cpp and the
fmha_batch_prefill codegen hook.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-08 08:46:55 +00:00
juuso-oskari
5d1def74a6 CK-UA: remove legacy ping-pong pipeline; FA4 is the only 2-WG path
Pipeline cleanup (-fav4):
  * Delete the 8-wave compute/memory ping-pong baseline (the ~200-line
    monolithic `core_loop` lambda + its 2-warp-group dispatch). It was
    reachable only under -DUA_FA4_PIPELINE=0 and never beat FA4 on any
    measured prefill shape, so it was dead under the default build.
  * Drop the UA_FA4_PIPELINE toggle entirely. kFA4 is now derived purely
    from NumWarpGroups==2 + the 32x32x16 within-wave FP8 P-relayout
    invariant, with a static_assert pinning that every 2-WG instance is
    FA4-capable (fails the build loudly instead of running an empty loop).
  * Remove the now-orphaned ADD_SBARRIER_FOR_PHASE0/PHASE2 knobs (they
    only gated barriers inside the deleted core_loop). MOVE_FMHA_MASK_*
    stay (still consumed by the FA4 core-loop scheduler).
  * The non-FA4 pre-stage + fmha_post_process epilogue are retained: they
    are shared by the single-warp-group (NumWarpGroups==1) serial decode
    path, where kFA4 is false.

Behaviour-preserving for the default build: FA4 prefill perf is bit-for-
bit unchanged (b16 sq=sk=10000 fp8 CK=5.76ms before/after) and the full
decode regression (d{64,128} x {bf16,fp8} x split-KV {2,64}) still PASSes.

Add opt-in prefill fallback knob (unified_attention.cpp):
  * AITER_UA_PREFILL_FALLBACK=1 routes prefill-sized shapes to the 4-warp
    single-warp-group *serial* decode_*_m128 instances instead of FA4.
    Reuses already-compiled instances (no extra binary). OFF by default:
    the serial path has no matrix/softmax overlap and measured ~0.66-0.70x
    Triton vs FA4's ~0.73-0.80x on gfx950 fp8 GQA-12/2 (i.e. SLOWER than
    FA4). Kept as a diagnostic / robustness A-B knob only.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-03 09:15:30 +00:00
juuso-oskari
374536f19a CK-UA: checkpoint FA4 pipeline + int64 Q/O base-offset fix
Working state before the pipeline cleanup/refactor:
  * FA4 matrix-softmax warp-group overlap pipeline (UA_FA4_PIPELINE=1).
  * Widen per-CTA query/output base offsets to long_index_t so large
    total_q (big-batch prefill) can't overflow int32 and fault on the
    output store (cache_ptr_int32_overflow_possible only covers K/V).
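
A minimal sketch of the overflow that the widening guards against. The tensor layout, the shape values, and the helper names `q_base_offset` and `overflows_int32` are hypothetical; `long_index_t` is assumed here to be a 64-bit signed integer.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Assumption: long_index_t is a 64-bit signed index type.
using long_index_t = std::int64_t;

// Hypothetical helper: element offset of query row q_row in a row-major
// [total_q, num_heads, head_dim] tensor, computed entirely in 64 bits.
constexpr long_index_t q_base_offset(long_index_t q_row,
                                     long_index_t num_heads,
                                     long_index_t head_dim)
{
    return q_row * num_heads * head_dim;
}

// True when the offset cannot be represented as a 32-bit signed index.
constexpr bool overflows_int32(long_index_t offset)
{
    return offset > std::numeric_limits<std::int32_t>::max();
}
```

For example, row 300000 with 64 heads of dimension 128 sits at element 2,457,600,000, which is past the int32 limit of 2,147,483,647.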

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-03 08:47:43 +00:00
juuso-oskari
64d3e00077 Merge remote-tracking branch 'origin/dlejeune/ua-swa-v2' into jukorhon/unified-attention-ck
2026-06-01 12:45:43 +00:00
juuso-oskari
9373fab553 CK-UA: replace FP8 repack ds_bpermute with v_permlane32_swap_b32 (gfx950)
The FP8 32x32x16 fmha_alu1 repack trades each lane's "bad" 4-byte pack
with the lane 32 away (QK-C vs PV-A layouts differ by an l^32 swap).
Since 3431615ff this used an LDS-crossbar ds_bpermute plus an is_sub_0
v_cndmask mux.

gfx950 exposes v_permlane32_swap_b32 (__builtin_amdgcn_permlane32_swap,
feature permlane32-swap) which does the l^32 exchange in a single VALU
op with no LDS round-trip. Verified on-device that permlane32_swap(
lo_pack, hi_pack) returns {out_lo, out_hi} for every lane, folding both
the cross-lane swap and the per-lane sub-block muxing into one
instruction. Guarded #if defined(__gfx950__); ds_bpermute kept as the
#else fallback (gfx942 lacks the feature).

ISA (prefill_d128 fp8 instance): 12 ds_bpermute -> 0, replaced by 12
v_permlane32_swap_b32; v_cndmask muxing removed. FP8 prefill + decode
PASS vs torch reference. Clean A/B (median of 3, b=4 FP8 prefill):
sq=sk 2048/5000/10000 -> 1.6% / 1.9% / 2.1% faster, scaling with the
per-iter repack count.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-01 11:41:52 +00:00
juuso-oskari
87658a9518 CK-UA: hoist wave-uniform warp id out of the async-load issue loop
tile_scatter_gather::async_load_raw / async_load_raw_long recompute
get_warp_id() (threadIdx.x/warp_size + a convergent v_readfirstlane) at
every K/V load issue to form the m0 / LDS-wave base. The value is
wave-uniform and constant for the window's lifetime, but LLVM cannot
hoist or CSE it across the load loop: v_readfirstlane is convergent and
the m0 set is an asm-volatile with a memory clobber, which together pin
the recompute to each issue.

Materialize the warp id once at window construction (cached_warp_id_,
set only for the global-memory gather windows that issue these loads)
and read the cached SGPR in both async paths. ISA: the per-issue
s_lshr ÷64 and v_readfirstlane drop out of the loop (warp-id readfirstlane
sites 36 -> 11, the ÷64 shift down to 2 static sites).
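
The hoist can be illustrated with a small host-side model. This is only a sketch: `WarpIdProbe`, `sum_bases_recompute` and `GatherWindow` are hypothetical stand-ins (not the real tile_scatter_gather types), and the warp size of 64 is taken from the "÷64" shift mentioned above.

```cpp
#include <cassert>

// Hypothetical stand-in for the warp-id computation; counts how often it runs.
struct WarpIdProbe
{
    int calls = 0;
    int get(int thread_id, int warp_size)
    {
        ++calls;
        return thread_id / warp_size;
    }
};

// Pre-change model: the warp id is recomputed at every load issue.
inline int sum_bases_recompute(WarpIdProbe& probe, int thread_id, int issues, int bytes_per_wave)
{
    int sum = 0;
    for(int i = 0; i < issues; ++i)
        sum += probe.get(thread_id, 64) * bytes_per_wave;
    return sum;
}

// Post-change model: the warp id is computed once at window construction
// and the cached value is reused by every issue.
struct GatherWindow
{
    GatherWindow(WarpIdProbe& probe, int thread_id) : cached_warp_id_(probe.get(thread_id, 64)) {}

    int sum_bases(int issues, int bytes_per_wave) const
    {
        int sum = 0;
        for(int i = 0; i < issues; ++i)
            sum += cached_warp_id_ * bytes_per_wave;
        return sum;
    }

    int cached_warp_id_;
};
```

Both variants produce the same per-issue base; only the number of warp-id computations differs (one per issue versus one per window).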

Matched sweeps (line-tables-identical codegen, GQA-6 d128 page64):
  bf16 prefill: -7.25% aggregate, 12/12 shapes improved, 0 regressions
                (CK/Triton at sq>=5000 moves ~0.83x -> ~0.90x)
  fp8  all:     -0.4..-1.3% aggregate (fp8 prefill is gated by the
                ds_bpermute repack, which masks the addressing savings)
Correctness vs torch reference: PASS (fp8 + bf16).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-01 11:05:21 +00:00
juuso-oskari
7fc24c8c45 CK-UA: within-tile page-table dedup + UA-owned core-loop scheduler
Two prefill_d128 changes on the unified-attention pingpong (checkpoint):

1. refresh_{k,v}_offsets: dedup the per-issue page-table lookup. With a
   compile-time page_size the issue->page map is a pure compile-time
   function in two provable regimes (page-divides-tile / tile-divides-
   page), so phys_page is resolved once per distinct page instead of
   once per issue -- collapsing to a single ds_read + readfirstlane at
   page_size >= kPageBlockSize. Gated on kHasCePageSize; the runtime-
   page-size scalar-promote and per-lane fallbacks stay byte-identical.
   Measured fp8 prefill (ps=64), amir-shape sweep: +6.8% aggregate
   (5-7%/shape, scaling with seqlen); B2 K-mem barrier straggler
   -21..25%, total mean barrier stall -12%. Correctness verified
   fp8 ps={32,64} and bf16 ps={16,32,64}.

   (A cross-tile phys_page memo was prototyped and reverted: the Tier-2
   LDS read is already cheap/hidden post-dedup, so the runtime guard +
   loop-carried dep it needed was a net ~0.3% regression.)

2. Fork the FMHA CoreLoopScheduler into a UA-owned UAcoreLoopScheduler
   and thread MOVE_FMHA_MASK_TO_COMPUTE through its sched_group_barrier
   hints so the per-phase instruction-mix hint stays in lockstep with
   mask code motion. With the macro at 0 the table is byte-identical to
   the upstream FMHA scheduler (same hints, same codegen).
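
The two regimes of the page-table dedup in item 1 can be sketched as a compile-time function. This is a simplified model under stated assumptions: `lookups_per_tile` and `page_slot_of_issue` are hypothetical names, the tile is assumed to be page-aligned, and `tokens_per_issue` is an illustrative parameter rather than a value from the kernel.

```cpp
#include <cassert>

// Page-table lookups needed per tile once duplicates are removed.
// Returns -1 when neither regime applies (the runtime fallbacks would be used).
constexpr int lookups_per_tile(int tile_tokens, int page_size)
{
    return (page_size >= tile_tokens && page_size % tile_tokens == 0)
               ? 1                             // tile-divides-page: one page covers the tile
               : (tile_tokens % page_size == 0)
                     ? tile_tokens / page_size // page-divides-tile: one lookup per page
                     : -1;
}

// Which page of the tile a given load issue falls into, when each issue
// covers tokens_per_issue consecutive tokens.
constexpr int page_slot_of_issue(int issue, int tokens_per_issue, int page_size)
{
    return issue * tokens_per_issue / page_size;
}
```

With a 64-token tile this gives 4 lookups at page_size 16 and a single lookup at page_size 64 or 128, which matches the "single ds_read + readfirstlane" case described above.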

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-01 09:01:21 +00:00
Damien Lejeune
2b4020af0d Merge branch 'jukorhon/unified-attention-ck' into dlejeune/ua-swa-v2
2026-05-28 09:18:57 +00:00
Damien Lejeune
1cc12ab5f3 Step D tile clip + first SWA instances (large prefill tier)
2026-05-27 14:37:53 +00:00
juuso-oskari
a3714e82cf CK-UA: revert unrelated fmha touches not consumed by unified_attention
Eight files outside the UA scope had drifted onto this branch over time
via earlier commits whose subject lines explicitly carried no "CK-UA:"
prefix — they are independent fmha bug fixes and codegen additions that
do not touch any code path the unified_attention example, kernel or
pipeline actually compiles or includes.

Reverted to the merge-base content of:

  include/ck_tile/ops/fmha/block/block_masking.hpp    (-39 lines)
      added by 6729989b9 "Fix FMHA split-KV for paged-KV with
      page_block_size < kN0"

  include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp (-2 lines)
      added by ec2db01e4 "Fix fmha_fwd early-exit bug: seqlen_q <=
      min_seqlen_q should be <"

  example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
  example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py
  example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
  example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
  example/ck_tile/01_fmha/fmha_fwd_runner.hpp
  example/ck_tile/01_fmha/mask.hpp
      added by 63821af1f / cb6fb2802 / c5600bc8a / e5272603c / 10564b0c4
      / cd7ba6e2e / 07ba03bcb — split-KV decode tiles, codegen tweaks,
      and a sliding-window mask fix, all in the 01_fmha example program
      (a separate build target; the UA example lives in
      42_unified_attention and pulls in zero 01_fmha sources).

These commits are still reachable from the branch's reflog and from
their original commit hashes; they should each be cherry-picked onto
their own branches and sent upstream as standalone fmha bug-fix PRs —
they look like clean fixes that upstream would welcome, but they don't
belong in the UA PR's scope.

Verified empirically: clean JIT rebuild of module_unified_attention
followed by both regression shapes pass at full perf
  b=128/sk=16384/d=128/bf16  : 1.5152 ms, 5672 GB/s, PASS
  b=1/sk=1M/d=128/bf16 nb=70k : 0.7677 ms, 5594 GB/s, PASS
matching the pre-revert numbers to within run-to-run noise.

Branch's shared-CK touch surface after this revert: tile_scatter_gather.hpp
(+152 from our async_load_raw_long method), load_tile.hpp (+21 from the
sister dispatcher), warp_gemm[.|_dispatcher.]hpp (+13 for the FP8 e4m3
small-tile registration), and the new amd_global_load_lds_raw.hpp file.
Down from 14 shared files to 4.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-27 13:24:11 +00:00
juuso-oskari
7772504f54 CK-UA: relocate amd_async_global_load_lds_raw to its own header
This helper (added by 1f6942143 to support the unified_attention >2GB-cache
decode path, then extended in 46e622539 with the clang<21 inline-asm
fallback) was inlined into amd_buffer_addressing.hpp and
amd_buffer_addressing_builtins.hpp purely for topical fit — both files
also house the other amd_async_* helpers. Functionally the helper has
exactly one caller (tile_scatter_gather::async_load_raw_long), and it
doesn't exist anywhere upstream.

Move it into its own header, include/ck_tile/core/arch/amd_global_load_lds_raw.hpp,
and revert the two long-standing HW-utility headers to bit-identical-to-
upstream. Net effect:

  amd_buffer_addressing.hpp           — 0 lines diff vs merge-base
  amd_buffer_addressing_builtins.hpp  — 0 lines diff vs merge-base
  amd_global_load_lds_raw.hpp         — new file (157 lines)
  tile_scatter_gather.hpp             — +1 include line

The CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN macro lives with the
helper in the new file, so the toolchain gate also leaves no footprint
in the addressing headers.

Verified zero perf delta on the two key UA shapes vs. the pre-relocation
build (b=128/sk=16384/d=128/bf16 and b=1/sk=1M/d=128/bf16); both PASS at
~1.51 ms / 5674 GB/s and 0.77 ms / 5609 GB/s respectively, matching
prior runs within run-to-run noise.

Motivation: shrink the surface area an eventual upstream PR would have
to defend on long-standing core HW headers. Anyone reviewing now sees
the addition as a single new file rather than a +233-line edit across
two of CK's most central utility headers.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-27 13:15:55 +00:00
Damien Lejeune
086512d842 Add IsLocal argument to trait
2026-05-27 13:02:39 +00:00
Damien Lejeune
cea9adab59 Add SWA parameters end-to-end
2026-05-27 12:53:05 +00:00
Damien Lejeune
6753ddfbd4 Add the IsOutOfSinkBound alias + update mask cmd line argument in the example
2026-05-27 12:45:41 +00:00
juuso-oskari
46e6225397 CK-UA: gate dwordx3/x4 global_load_lds builtin on clang≥21, inline-asm fallback
The size=12 and size=16 ImmArg overloads of __builtin_amdgcn_global_load_lds
for gfx950 only landed in AMD clang ~21 (present in ROCm ≥ 7.11 / clang 22,
absent in ROCm 7.1.1 / clang 20). Building this CK branch on the older
toolchain failed during semantic analysis of amd_buffer_addressing_builtins.hpp:

    error: invalid size value
       __builtin_amdgcn_global_load_lds(gptr, lptr, 16, ...);
    note: size must be 1, 2, or 4

The error is unavoidable as soon as the unified_attention pipeline is built —
its `if (cache_ptr_int32_overflow_possible)` dispatch is a runtime branch,
not `if constexpr`, so the `bytes ∈ {12, 16}` instantiations are compiled
regardless of whether any workload at runtime takes that path.

Fix: introduce CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN, gated on
__clang_major__ >= 21 (overridable). When 0, emit
`global_load_lds_dwordx{1,3,4}` via inline asm, with M0 set explicitly
through `s_mov_b32` from the addrspace(3) `lptr` narrowed to its 32-bit
LDS byte offset and wave-uniformed via `readfirstlane`. The assembler
accepts the mnemonic and emits the same HW instruction the builtin
would lower to (verified zero perf delta vs. the builtin path across
the full decode regression sweep — all 8 (b, d, dtype) configs match
to within ≤ 1.5% run-to-run noise when the fallback is force-on).
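
The shape of the toolchain gate can be sketched as below. The macro name and the clang-major threshold come from the text above; the exact form of the real definition is not shown in this log, so the `#ifndef` override pattern and the `uses_builtin` / `has_wide_builtin` probes are assumptions for illustration.

```cpp
#include <cassert>

// Overridable gate: pass -DCK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN=0 (or =1)
// on the command line to force a path; otherwise decide from the clang version.
#ifndef CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN
#if defined(__clang_major__) && __clang_major__ >= 21
#define CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN 1
#else
#define CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN 0
#endif
#endif

// Hypothetical probe: true when this build would call the builtin,
// false when it would emit the inline-asm fallback.
constexpr bool uses_builtin() { return CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN != 0; }

// Hypothetical probe: the same version rule as a plain function.
constexpr bool has_wide_builtin(int clang_major) { return clang_major >= 21; }
```

Under this rule clang 20 (ROCm 7.1.1) takes the fallback and clang 22 takes the builtin, consistent with the versions quoted above.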

Two simpler "issue N× size=4" decompositions were tried and rejected:
INST.OFFSET stepping by 4 reproduces the dwordx4 layout for no shape;
stepping by 256 with `gptr += 4` per issue happens to pass on one
big-cache decode shape (b=1 / sk=1M) but fails on b=128 / sk=16384 /
d=128 / bf16. The native dwordx4's in-LDS sub-issue ordering doesn't
reduce to any combination of dword INST.OFFSET steps we could find that
survives all decode shapes; asking the assembler for the literal
instruction sidesteps the question.

The dormant amd_buffer_addressing.hpp copy (used only when CK_TILE_USE_
BUFFER_ADDRESSING_BUILTIN is forced to 0, which doesn't happen on clang
≥ 20) gets the same treatment so toggling the macro doesn't reintroduce
the bug.

Allows building jukorhon/unified-attention-ck on ROCm 7.1.1 unchanged;
upgrading to a newer ROCm container remains the recommended option.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-27 12:45:18 +00:00
juuso-oskari
2645149bbf CK-UA: shrink Tier-2 page-table LDS cache to per-split window
The decode pipeline's Tier-2 LDS page-table cache bulk-loads
block_tables_ptr_[block_table_offset + i] into block_tables_lds[i] at
kernel entry so subsequent refresh_*_offsets calls can resolve the
phys_page via a one-cycle ds_read_b32 broadcast instead of a scoreboarded
s_load_dword. Before this commit the bulk load covered an *absolute*
prefix [0, num_blocks) of the batch's page table — i.e. every page index
this CTA could ever produce, even the ones earlier splits skipped over.

That conservative load (1) wasted bulk-load bytes on splits 1+ and (2)
made the 4096-entry static cap a *total-sequence* limit rather than a
per-split limit, so the long-context decode

    b=64 sq=1 sk=128000 hq=64 hk=8 d=128 bf16 page=16

(total_kv_pages = 8000, num_splits = 4, last split window = [6000, 8000))
tripped `assert(split_end_page <= kPageTableLdsEntries)` on splits 2/3.
The wrapper's split-KV scheduler would otherwise keep each CTA's working
set well under the cap (2000 pages here vs 4096 cap) — the assert was
firing on data the kernel doesn't actually need.

Make the cache per-split-window: bulk-load only
[split_start_page, split_end_page) where
    split_start_page = ⌊num_blocks_start · kPageBlockSize / page_size⌋,
and shift refresh_{k,v}_offsets' lookup by split_start_page so the LDS
index stays in [0, split_window_pages). split_start_page is a
kernel-entry constant, so the subtract folds into the s_load_dword's
immediate offset on every refresh call — no per-iteration cost.

On the prefill path (num_blocks_start == 0) split_start_page == 0 and
the change collapses to the original absolute-indexed load, so prefill
codegen is bit-identical.
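
The window arithmetic can be checked with a small host-side model. This is a sketch under assumptions: `PageWindow`, `split_window`, `lds_index` and the two `fits_*` helpers are hypothetical names, the end-of-window rounding (ceil) is assumed, and kPageBlockSize = 32 is an illustrative value chosen so that the failing shape's last split starts at page 6000.

```cpp
#include <cassert>

constexpr int kPageTableLdsEntries = 4096; // static cap quoted in the message above

// Half-open window of absolute page indices held in LDS for one split.
struct PageWindow
{
    int start_page;
    int end_page;
    constexpr int size() const { return end_page - start_page; }
};

// split_start_page = floor(num_blocks_start * kPageBlockSize / page_size).
// Assumption: the window end is rounded up to a whole page.
constexpr PageWindow split_window(int num_blocks_start,
                                  int num_blocks_end,
                                  int page_block_size,
                                  int page_size)
{
    return {num_blocks_start * page_block_size / page_size,
            (num_blocks_end * page_block_size + page_size - 1) / page_size};
}

// LDS slot for an absolute page index after the shift by split_start_page.
constexpr int lds_index(int abs_page, PageWindow w) { return abs_page - w.start_page; }

// Old check: the absolute end page must fit the cap.
constexpr bool fits_absolute(PageWindow w) { return w.end_page <= kPageTableLdsEntries; }

// New check: only the per-split window must fit the cap.
constexpr bool fits_window(PageWindow w) { return w.size() <= kPageTableLdsEntries; }
```

For the last split of the failing shape (blocks 3000..4000 at 32 tokens per block, page_size 16) the window is [6000, 8000): it fails the absolute check but its 2000 entries fit the per-window check.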

Verified on gfx950 (256 CUs):
* Originally failing shape now passes correctness and runs at
  5713 GB/s (vs Triton 4402 GB/s).
* Default regression (b=128 sq=1 sk=16384 hq=64 hk=8, 3 runs × 4
  d/dtype combos) all PASS at both block_size=32 (baseline path)
  and block_size=16; bandwidth unchanged vs pre-patch baseline.
* Long-context decode shapes that previously couldn't run at
  page_size=16 — sk ∈ {65K, 262K, 524K} — now all PASS at
  ~5350-5690 GB/s.
* Prefill (b=1 sq=4096) PASS, perf unchanged.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-27 09:49:08 +00:00
juuso-oskari
badc807025 CK-UA: enable Tier-2 LDS page-table cache on decode + fix split-KV bulk-load OOB
Three coupled changes to the Tier-2 LDS-resident page-table cache:

1. Drop the `kBlockSize >= 8 * warp_size` gate from both the runtime
   kScalarPromote{K,V}PageIdx predicates and the static
   GetPageTableLdsBytes() allocator. The original conservative gate
   excluded TinyDecode (kBlockSize == 64); the trade-off has since
   flipped now that bf16 m16 doubles the per-tile iter count (and
   thus the per-tile page-table refresh count) via the halved kBlockN
   change. Enabling Tier-2 on TinyDecode eliminates the per-iter
   `s_waitcnt vmcnt(0)` drains that the per-lane block_tables_ptr_
   vector loads were forcing.

2. Fix a silent corruption bug in GetPageTableLdsBytes(): the LDS-
   allocation gate carried the old hedge while operator()'s runtime
   gate had already dropped it, so on TinyDecode the bulk-load
   path wrote into LDS regions belonging to the K/V double buffers
   above it. Both gates now share the same constexpr predicate.

3. Split-KV bulk-load correction. refresh_*_offsets indexes
   block_tables_lds by absolute page index (= block_table_offset-
   relative), so on splits 1+ where the CTA's split_token_offset > 0
   the original bulk load only covered pages
   [0, num_pages_for_split) and read OOB. We now load
   [block_table_offset, block_table_offset + split_end_page) to
   cover every absolute page the CTA can index.

Also: add explicit `s_waitcnt_lgkmcnt<0>()` after the bulk-load. On
multi-warp tiers the s_barrier carries the LDS-write drain
implicitly; on single-warp TinyDecode LLVM elides s_barrier entirely
and the refresh path reads stale LDS without the explicit drain.
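
The "both gates share the same constexpr predicate" fix in item 2 follows a common pattern, sketched here. All names (`Cfg`, `use_page_table_lds`, `page_table_lds_bytes`, `bulk_load_enabled`) are hypothetical, and the 4-byte entry size is an assumption based on the 32-bit page indices mentioned elsewhere in this log.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr int kPageTableLdsEntries = 4096; // static cap quoted elsewhere in this log

// Hypothetical traits bundle standing in for the pipeline's compile-time config.
template <bool kScalarPromote>
struct Cfg
{
    static constexpr bool scalar_promote = kScalarPromote;
};

// Single source of truth for "is the LDS page-table cache in use".
template <typename C>
constexpr bool use_page_table_lds()
{
    return C::scalar_promote;
}

// Allocator side: reserves LDS only when the shared predicate holds.
template <typename C>
constexpr std::size_t page_table_lds_bytes()
{
    return use_page_table_lds<C>() ? kPageTableLdsEntries * sizeof(std::int32_t) : 0;
}

// Runtime side: the bulk load is gated by the very same predicate, so it can
// never write into LDS that the allocator did not reserve.
template <typename C>
constexpr bool bulk_load_enabled()
{
    return use_page_table_lds<C>();
}
```

Because both functions call one predicate, a future change to the gate cannot update the runtime path without also updating the allocation.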

Validated: correctness sweep across bf16/fp16/fp8 × {decode, prefill}
× b in {1,32,128,256}, sk up to 128k. Decode perf: ~1.18x geomean vs
Triton on long-context d=128 GQA8 (was 1.5x+ pre-fix).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-26 08:21:10 +00:00
juuso-oskari
310efc556f CK-UA: halve kBlockN for bf16/fp16 m16 decode + generalise PVAttrNumAccess
The decode_d128_m16 tier was VGPR-saturated and LDS-bound on bf16/fp16
(probe_decode_d128 showed VGPR=256 + AGPR overflow, ~2x fp8's LDS at
the same kBlockN), capping it at 1 CTA/CU. Halving kBlockN for the
non-fp8 path on the m16 tier sheds enough LDS and VGPR pressure to
fit 3-4 CTAs/CU (LDS-bound). The halved kBlockN forces a smaller-K
MFMA shape on the m16 PV gemm (16x16x32 -> 16x16x16); we also auto-
adjust WarpGemm::K so PVAttrNumAccess picks Single vs Double access
correctly. The PVAttrNumAccess derivation is now generic — driven by
(kABKPerLane, SubMinDim) rather than just (dtype) — so the new
shape compiles without per-variant special-casing.

Variants only affected where cfg::BlockSize/2 >= WarpGemm::N (i.e.
decode_d128_m16); m32/m128/prefill keep their un-halved tiles since
they use 32x32 N-warps.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-26 08:20:55 +00:00
juuso-oskari
89b54563b6 CK-UA: skip post-load page_offsets refresh on final K/V tile
The K_mem_load / V_mem_load lambdas unconditionally call refresh_*_offsets
after each load to prepare per-element page_offsets for the next tile.
On the *last* tile of the (split-KV per-split) loop the next load is
never issued, but the refresh still reads
block_tables[block_table_offset + (last_relative_tile + 1)] — one past
the seq's last valid logical_page on the final split. When block_tables
happens to be the last allocation in a memory page, that read faults.

The PyTorch caching allocator hides the bug for small workloads (the
4-byte OOB lands in adjacent live memory and just returns garbage), but
it reproduces reliably once a workload deep-copies >~30 distinct
block_tables tensors and the allocator scatters them across unmapped
page boundaries. The fault is not split-KV specific — the single-launch
path (num_splits == 1) hits the same OOB on the final tile of the only
"split". Verified on MI355 with a 200-config decode FP8 sweep (b ∈
{1,4,8,16,32}, sk ∈ {512,1024,2048}, d ∈ {64,128}, GQA-{2,8}, bs ∈
{16,32,64}, bf16+fp16, ±FP8): 200/200 pass against the reference; the same
configs hit "memory access fault by GPU node" at iter ~27 before the
fix.

Note on the gate: k_block_idx / v_block_idx are 0-based *relative to
this split*, while num_total_loop is the absolute end index, so the
correct bound is `num_total_loop - num_blocks_start` (= per-split iter
count). Skipping the refresh leaves k_page_offsets / v_page_offsets
stale on the final iter, which is harmless because no subsequent load
consumes them.
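
The gate can be modelled on the host as below. This is a simplified sketch: `is_last_tile` and `max_refreshed_tile` are hypothetical names, and the model assumes one page-table entry per tile.

```cpp
#include <algorithm>
#include <cassert>

// The block index is relative to the split, so the last tile is identified
// against the per-split iteration count, not the absolute end index.
constexpr bool is_last_tile(int rel_block_idx, int num_total_loop, int num_blocks_start)
{
    return rel_block_idx + 1 == num_total_loop - num_blocks_start;
}

// Highest absolute tile index whose page entry the refresh step reads over
// one split (assumption: one page-table entry per tile). Returns -1 if none.
inline int max_refreshed_tile(int num_blocks_start, int num_total_loop, bool skip_on_last)
{
    const int iters = num_total_loop - num_blocks_start;
    int max_read    = -1;
    for(int rel = 0; rel < iters; ++rel)
    {
        if(skip_on_last && is_last_tile(rel, num_total_loop, num_blocks_start))
            continue; // no next tile is loaded, so nothing to refresh
        max_read = std::max(max_read, num_blocks_start + rel + 1);
    }
    return max_read;
}
```

For a split covering tiles 6..9 (num_blocks_start = 6, num_total_loop = 10) the unguarded loop reads entry 10, one past the last valid tile, while the guarded loop stops at entry 9.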

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-25 09:57:17 +00:00
juuso-oskari
06e1a70e7a CK-UA: constexpr page_size (Tier 3) — prefill_d64 fp8 -15.8%, prefill_d128 fp8 -6.3%
Promote the runtime `page_size` argument to a non-type template parameter
`kPageSize_` on UnifiedAttentionPipeline. Thread it through
unified_attention_kernel_traits and dispatch_variant<V> so the host-side
dispatcher routes on args.page_blk_size ∈ {16, 32, 64} to a constexpr-
pinned prefill instance; values outside that menu (or any decode variant)
fall back to the existing kPageSize_=0 runtime-page-size instance.

Two wins fold together on the prefill tiers:

1. Strength-reduction. Every `/ page_size`, `* page_size`, and `% page_size`
   in the per-tile address chain collapses to a literal-folded shift /
   multiply-by-magic (`/ 32` → shr 5, etc).

2. Wider Tier-0/Tier-2 gate. The scalar-promote + LDS-cache fast path now
   uses the *real* precondition `KY0_step_N <= kPageSize` at compile time
   instead of the conservative `KY0_step_N <= 16` hedge — so prefill_d128
   bf16/fp16 (KY0_step_N=32), prefill_d64 fp8 (KY0_step_N=32), and
   prefill_d64 bf16/fp16 (KY0_step_N=64) also enter the fast path at
   their natural page sizes.
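
Both effects can be sketched with a non-type template parameter. This is an illustration only: `PageMath`, `page_of_dispatch` and `enters_fast_path` are hypothetical names, and keeping the `<= 16` hedge for the kPageSize == 0 fallback is an assumption.

```cpp
#include <cassert>

// kPageSize == 0 selects the runtime-page-size fallback, mirroring the
// kPageSize_=0 convention described above.
template <int kPageSize>
struct PageMath
{
    // With kPageSize > 0 the divisor is a compile-time constant, so the
    // compiler can reduce the division to a shift for power-of-two sizes.
    static constexpr int page_of(int token, int runtime_page_size)
    {
        return token / (kPageSize > 0 ? kPageSize : runtime_page_size);
    }

    static constexpr int offset_in_page(int token, int runtime_page_size)
    {
        return token % (kPageSize > 0 ? kPageSize : runtime_page_size);
    }
};

// Hypothetical host-side dispatch on the {16, 32, 64} menu.
inline int page_of_dispatch(int token, int page_size)
{
    switch(page_size)
    {
    case 16: return PageMath<16>::page_of(token, page_size);
    case 32: return PageMath<32>::page_of(token, page_size);
    case 64: return PageMath<64>::page_of(token, page_size);
    default: return PageMath<0>::page_of(token, page_size); // runtime fallback
    }
}

// Fast-path gate: the real precondition when the page size is known at
// compile time; the conservative hedge otherwise (assumption for the fallback).
template <int kPageSize>
constexpr bool enters_fast_path(int ky0_step_n)
{
    return kPageSize > 0 ? ky0_step_n <= kPageSize : ky0_step_n <= 16;
}
```

Under this model a variant with KY0_step_N = 32 is rejected by the runtime-page-size instance but accepted by the ps32 instance.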

Measured impact (sq=sk=75600, MI355, n=30 iters, GQA-8):

  variant            KY0_step_N  ps   before   after    Δ
  prefill_d128 fp8   16          32   119.0    111.5    -6.3 %
  prefill_d128 bf16  32          32   132.7    130.3    -1.8 %
  prefill_d64  fp8   32          32    80.9     68.1   -15.8 %
  prefill_d64  bf16  64          64    74.4     73.4    -1.3 %

Decode variants stay on the kPageSize_=0 instances (Tier-0 gate gates them
out anyway — <8 warps — and the binary-size cost isn't justified). All
sweep_fp8.sh shapes + 21 multi-seed multi-sk-length prefill shapes
correctness-PASS. Pre-existing Tier-2 LDS-cache limit (4096 entries)
documented in the pipeline header — same constraint applies to the
kPageSize_=0 fallback so this is not a regression.

36 new prefill instance files: prefill_d{64,128} × {fp16, bf16, fp8} ×
{mask, nmask} × {ps16, ps32, ps64}.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-19 12:46:39 +00:00
juuso-oskari
045b1f57bf CK-UA: widen FP8 K/V async loads to dwordx4 where the tile allows it
GetAlignmentK / GetAlignmentV previously returned a blanket 4 B/lane
(one dword) for every FP8/BF8 tile, citing the gfx950 LDS-direct load
constraint (only dword / dwordx3 / dwordx4 are supported). That cap was
correct for the 8-warp prefill variants (kBlockSize=512, NumIssues drops
to 0.5 at 16 B/lane) but over-applied to every decode tier, where the
1/2/4-warp tile geometry has plenty of headroom.

Refactor the alignment selector into GetKVAlignmentBytes<>, which picks
dwordx4 whenever NumIssues = kPageBlockSize*kHeadDim/(kBlockSize*16)
is an integer >= 1 and falls back to dword otherwise. BF16/FP16 paths
stay at 16 B/lane on every compiled tile, so existing perf is unchanged.
FP8 prefill_d{64,128} also keep the historical dword path because
NumIssues = 0.5 there. FP8 decode_d{64,128}_m{16,32,64,128} now use
dwordx4: same byte volume per K/V tile but 4x fewer async-load issues
(SQ_INSTS_VMEM 131M -> 33M on b=128 sq=1 sk=128000 d=64).
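
The selection rule can be written out as a small constexpr function. This is a sketch, not the real GetKVAlignmentBytes<>: `kv_alignment_bytes` and `num_issues` are hypothetical names, and the tile is described by its byte size (kPageBlockSize*kHeadDim for 1-byte FP8 elements). The 4096-byte tile used in the note below is an illustrative value.

```cpp
#include <cassert>

// Bytes each lane moves per async-load issue: 16 (dwordx4) when
// NumIssues = tile_bytes / (block_threads * 16) is an integer >= 1,
// otherwise 4 (dword).
constexpr int kv_alignment_bytes(int tile_bytes, int block_threads)
{
    return (tile_bytes >= block_threads * 16 && tile_bytes % (block_threads * 16) == 0) ? 16 : 4;
}

// Number of load issues needed to move the whole tile at a given width.
constexpr int num_issues(int tile_bytes, int block_threads, int bytes_per_lane)
{
    return tile_bytes / (block_threads * bytes_per_lane);
}
```

A 4096-byte tile with 512 threads gives NumIssues = 0.5 and stays on dword; the same tile with 64 threads selects dwordx4 and needs 4 issues instead of 16.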

Wall-clock impact on the long-context decode sweep (HIP_VISIBLE_DEVICES=2,
ITERS=20, WARMUP=5, MI355):

  shape                                    dtype  before     after      speedup
  decode  d=64  sq=1     sk=128000 b=128   fp8      7.17 ms    4.57 ms  1.57x
  decode  d=64  sq=1     sk=128000 b=256   fp8     16.24 ms    9.51 ms  1.71x
  decode  d=128 sq=1     sk=128000 b=128   fp8     13.11 ms    7.15 ms  1.83x
  decode  d=128 sq=1     sk=128000 b=256   fp8     31.37 ms    9.78 ms  3.21x
  decode  d=64  sq=1     sk=128000 b=4     fp8      0.42 ms    0.22 ms  1.92x
  decode  d=128 sq=1     sk=128000 b=4     fp8      0.80 ms    0.42 ms  1.93x
  prefill d=64  sq=75600 sk=75600  b=1     fp8     81.4  ms   81.2  ms  1.00x  (dword fallback)
  prefill d=128 sq=75600 sk=75600  b=1     fp8    143.5  ms  143.6  ms  1.00x  (dword fallback)

Correctness verified across fp8/bf16/fp16, causal/non-causal, and all 7
compiled tile variants. Full PMC + PC-sample analysis is in
ua-test-scripts/rocprof_analysis/BOTTLENECK_ANALYSIS.md section 8.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-19 08:06:29 +00:00
juuso-oskari
7a319d9a4b CK-UA: drop redundant phase-0 s_barrier (-3% fp8 prefill_d128 decode)
`ADD_SBARRIER_FOR_PHASE0=1` added an extra `s_barrier()` at the start of
every `cl_p` half of every KV iteration, on top of the three barriers
that already gate the LDS hand-offs in phases 1/2/3.

rocprofv3 bottleneck analysis (b=4 sq=8 sk=4096 hq=64 hk=8 d=128 fp8):
the prefill_d128 8-warp variant spends ~15% of GUI_ACTIVE cycles at
s_barriers and shows %any_wait ≈ 200%. PC sampling pinpoints the
phase-0 `s_barrier` (right after softmax rescale, before async prefetch)
as a top hotspot.

Examining the data flow shows the phase-0 barrier is redundant:
  - phase1's `s_waitcnt vmcnt(...); s_barrier` guards the K-LDS write
    (from the previous iter's K async load) before any warp reads it.
  - phase2's `s_waitcnt lgkmcnt(0); s_barrier` guards the softmax-P
    LDS write before gemm1 reads it.
  - phase3's `s_waitcnt vmcnt(...); s_barrier` guards the V-LDS write
    before the next iter's V-LDS read.

These three already provide every cross-warp ordering the pipeline
needs. The phase-0 barrier was purely defensive.

Measurement: 0.1945 → 0.1883 ms (n=300 iters × 3 trials, single shape).
Correctness verified against the Triton reference on fp8/bf16/fp16 ×
{b=4/32/128} × {sq=1/4/8} × {causal,non-causal} × d∈{64,128}.

Leaving the macro and the `=1` documented path in place so the previous
behaviour can be restored if a future arch/shape regresses.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-18 19:16:48 +00:00
juuso-oskari
3431615ff0 CK-UA: fuse FP8 cvt + cross-lane swap to hide ds_bpermute latency
Previously the 32x32x16 FP8 P-tile cvt and the QK-C -> PV-A cross-lane
swap ran in two separate static_for loops back-to-back inside fmha_alu1:
the whole tile was cvt'd into p.thread_buf_ first, then a second pass
issued one ds_bpermute_b32 per 8-fp8 K-chunk and read/wrote the same
buffer to swap the "bad" 4-byte halves between paired lanes.

The ds_bpermute has nontrivial LDS-DMA latency that the scheduler has
no way to hide when it lives alone in a tight serial loop with the
gather/scatter packs around it.

Fuse the two into one 8-fp8-per-iter loop:
  1. cvt 8 fp32 -> 2 packed uint32 (lo_pack=slot[0..3], hi_pack=slot[4..7])
     using the chained cvt_pk_fp8_f32 pattern matching cast_tile_pk_fp8_fp32.
  2. Pick own_bad = (sub==0 ? hi_pack : lo_pack) and issue ds_bpermute on it.
  3. Write back all 8 fp8 bytes; the "good" half lands first so its byte
     stores can overlap with the in-flight ds_bpermute, and the next
     iter's cvts can begin while the swap is still pending.
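
The data movement in steps 2 and 3 can be modelled on the host for one 8-fp8 chunk. This is a sketch under stated assumptions, not the kernel code: it assumes `sub` is lane bit 5 (so `lane` and `lane ^ 32` form a pair, as the FP8-enablement commit describes), and it models `ds_bpermute_b32` as a plain gather from the partner lane. The function name `swap_bad_halves` is hypothetical.

```python
WAVE_SIZE = 64


def swap_bad_halves(lo_pack, hi_pack):
    """Simulate the cross-lane exchange for one 8-fp8 chunk.

    lo_pack[lane] and hi_pack[lane] stand in for the two packed uint32
    values (slots 0..3 and slots 4..7) each lane holds after the cvt step.
    """
    # Assumption: sub is lane bit 5, pairing lane with lane ^ 32.
    sub = [lane >> 5 for lane in range(WAVE_SIZE)]
    # Step 2: each lane offers the half that belongs to its partner.
    own_bad = [hi_pack[l] if sub[l] == 0 else lo_pack[l] for l in range(WAVE_SIZE)]
    # ds_bpermute_b32 modelled as a gather from the partner lane.
    received = [own_bad[l ^ 32] for l in range(WAVE_SIZE)]
    # Step 3: the "good" half stays put, the received half fills the other slot.
    new_lo = [lo_pack[l] if sub[l] == 0 else received[l] for l in range(WAVE_SIZE)]
    new_hi = [received[l] if sub[l] == 0 else hi_pack[l] for l in range(WAVE_SIZE)]
    return new_lo, new_hi


# Tag every packed value with its origin so the exchange is easy to read.
lo = [("lo", lane) for lane in range(WAVE_SIZE)]
hi = [("hi", lane) for lane in range(WAVE_SIZE)]
new_lo, new_hi = swap_bad_halves(lo, hi)
print(new_lo[0], new_hi[0])    # what lane 0 holds after the swap
print(new_lo[32], new_hi[32])  # what its partner, lane 32, holds
```

In this model the "good" half never depends on the gather, which is the property that lets its stores be issued while the exchange is still in flight.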

The 16x16x32 LDS-roundtrip branch keeps the original separated cvt
loop (no swap latency to hide there since the relayout goes through
LDS, not ds_bpermute).

Single-shape FP8 perf on gfx950 GPU 2 (CUDA graph, 50 iters):
  decode d=128 b=4 sq=8 sk=4096:  0.2106 -> 0.1951 ms  (-7.4%)
  decode d=64  b=4 sq=8 sk=4096:  0.1464 -> 0.1208 ms  (-17.5%)
  prefill d=128 b=2 sq=512 sk=4k: 0.2558 -> 0.2220 ms  (-13.2%)

BF16 unchanged (0.2046 -> 0.2039 ms, within noise).

Correctness: pytest UA correctness suite 405 passed / 80 skipped
(245 BF16/FP16 + 160 FP8), unchanged from before.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-18 15:48:01 +00:00
juuso-oskari
9d7cc3ee9e CK-UA: extend FP8 to the 16x16x32 _m16 decode tier via LDS roundtrip
The 32x32x16 tiers (prefill_d{64,128}, decode_d{64,128}_m{32,64,128}) keep
the cheap in-register `ds_bpermute_b32` cross-lane swap that fixes the
QK-C / PV-A per-thread alias for the union'd `sp_compute` / `p`.

The 16x16x32 m16 tiers (decode_d{64,128}_m16) cannot use the swap -- the
MFMA puts the paired-lane bit at a different position and the
sub=0/sub=1 4-fp8 chunks no longer map onto each other. We add a
layout-agnostic LDS roundtrip as the `else` branch, gated by the same
`PVWarpTile` constexpr:

  - Hoist two distribution-bound windows over the existing `p_lds`
    region (one bound to the QK-C output distribution, one to the PV-A
    input distribution). Done once per kernel invocation.
  - In `fmha_alu1`, after the cvt_pk_fp8_f32 packing chain, view the
    union's bytes as a `static_distributed_tensor<fp8>` in the QK-C
    distribution, `store_tile` it through `p_lds` in canonical (M, N)
    order, `s_barrier`, then `load_tile` back with the PV-A
    distribution and copy into `sp(idx).p`.

A/B'd a uniform LDS-roundtrip (no fast-path) vs the split: pure LDS
regressed decode_m128 by ~1.5x end-to-end (CK FP8 dropped from
~0.39x of Triton FP8 to ~0.16x), driven by the extra block-wide
barrier on the 4-warp decode path. Keeping the swap for 32x32x16
preserves the previously-tuned perf.

Dispatcher (`unified_attention.cpp`) now FP8-enables every UA variant
including decode_d{64,128}_m16. Four new instance .cpp files
(`unified_attention_d{64,128}_fp8_{mask,nmask}_decode_t.cpp`)
instantiate the m16 FP8 kernels.

Pytest (`test_unified_attention_ck_correctness.py`):
  - 245 BF16/FP16: pass (no regression from the pipeline edit).
  - 160 FP8: pass (was 112 before m16 enablement).
  - 80 skipped: block_size<32 or query_len>kv_len -- pre-existing.

Single-shape m16 dispatches verified on gfx950:
  b=128 sq=1 hq=hk=8 d=128 fp8 PASS  (CK 0.109 ms / Triton 0.043 ms)
  b=128 sq=1 hq=hk=8 d=64  fp8 PASS  (CK 0.077 ms / Triton 0.039 ms)

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-15 20:00:35 +00:00
juuso-oskari
63c75277a0 CK-UA: enable FP8 (e4m3) for prefill/m128 and the 32x32x16 small-tile decode variants
Full pipeline support for FP8 (e4m3fn on gfx950 / e4m3fnuz on gfx942)
in the unified-attention kernel, gated to the 32x32x16 MFMA tiers in
both d=64 and d=128 ladders: prefill_d{64,128}, decode_d{64,128}_m128,
decode_d128_m32, and decode_d64_m64. The 16x16x32 _m16 tiers stay
BF16/FP16-only -- the QK-C and PV-A per-thread layouts there differ
by an M<->N swap that the current slot-swap fixup cannot express; a
full per-thread transpose (most likely via LDS) is needed.

Pipeline (unified_attention_pipeline.hpp):
* `fmha_alu1` now performs a cross-lane P-tile fixup right after the
  FP8 packing of softmax(P). It's a `ds_bpermute_b32` between paired
  lanes `lane ^ 32`, swapping sub=0 slot[k_base+4..k_base+7] with
  sub=1 slot[k_base..k_base+3] for every 8-fp8 chunk. This realigns
  the FP8 packed P operand with PV-A's `Single` AttrNumAccess
  per-thread layout, which is necessary because the QK-C output and
  PV-A input alias byte-for-byte via the sp_compute/p union -- and
  for FP8 the two warp-gemm layouts no longer agree (BF16/FP16 keep
  Double AttrNumAccess in the PV gemm, which matches QK-C natively).
  Gated on `Gemm1WarpTile == 32x32x16`; FP8-only (BF16/FP16 paths take
  the existing cvt_pk path unchanged).

Default policy (unified_attention_pipeline_default_policy.hpp):
* PV warp gemm now selects `WGAttrNumAccessEnum::Single` when V is
  fp8/bf8 and `Double` otherwise. Forced by load_tile_transpose's
  SubMinDim = 64-bit / sizeof(V) constraint: for FP8 SubMinDim=8 and
  kABKPerLane=8 only Single satisfies the validation static_asserts.
* GetAlignmentK / GetAlignmentV on gfx950 drop to 4 B/lane for fp8/
  bf8. The natural 16 B/lane async-load that BF16/FP16 use leaves
  NumIssues = 0 for the FP8 tile shapes we compile, and 8 B/lane
  fails the dword / dwordx3 / dwordx4 constraint in
  amd_buffer_addressing_builtins. 4 B/lane gives NumIssues >= 1 on
  every targeted variant and is the same alignment the gfx942
  fallback already used. BF16/FP16 keep the full 16 B/lane path so
  existing perf is unchanged.
* GetSmemSizeKV adds a `VLoadDescSize` lower bound. The
  MakeVLdsLoadBlockDescriptor's element span dominates the banked
  SingleVSize only for FP8 (small per-lane KVector + fixed
  kVLdsPadInBytes = 64), so without it FP8 hits the GetSmemSizeKV
  static_asserts. BF16/FP16 are unaffected.

Warp-gemm headers + dispatcher:
* New `WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed_T<AttrNumAccess>`
  template alias in warp_gemm.hpp (mirrors the existing BF16 32x32x16
  CTransposed template), used by the PV gemm to thread the FP8
  Single AttrNumAccess through.
* New Dispatcher specialization for
  <fp8_t, fp8_t, float, 32, 32, 16, true, false, false, EDouble>
  in warp_gemm_dispatcher.hpp routing to the new template.

ABI / dispatcher (unified_attention.{cpp,hpp}, unified_attention_impl.hpp):
* New `fp8` value in `unified_attention_args::data_type_enum` (selects
  e4m3fn on gfx950 via CK_TILE_USE_OCP_FP8, e4m3fnuz elsewhere).
* New `unified_attention_problem_traits<...::fp8>` alias:
  qkvp_dtype = ck_tile::fp8_t, acc_dtype = float, o_dtype = bf16_t
  (matches the Triton reference), lse_dtype = float.
* Per-tensor `q_descale` / `k_descale` / `v_descale` floats on
  `unified_attention_args` (default 1.0f so non-FP8 round-trips
  cleanly). The pipeline folds q_descale*k_descale into the softmax
  scale and applies v_descale once to o_acc after the 1/l norm --
  same semantics as Triton's q_scale/k_scale/v_scale.
* `dispatch_variant<>` enables FP8 on prefill_d{64,128},
  decode_d{64,128}_m128, decode_d128_m32, decode_d64_m64. The
  16x16x32 _m16 tiers return (false, -1.f) for now (see top comment).
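
The descale folding described above (q_descale*k_descale merged into the softmax scale, v_descale applied once to the normalised output) can be checked numerically. The sketch below uses plain Python floats as stand-ins for the FP8 values; the helper names and the sample numbers are hypothetical and only illustrate why the folded form matches dequantising each tensor first.

```python
import math


def attention_row(q, keys, values, scale):
    """Single-query softmax attention in plain Python."""
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    l = sum(p)
    dim = len(values[0])
    return [sum(pj * v[d] for pj, v in zip(p, values)) / l for d in range(dim)]


def scaled(vec, s):
    """Multiply every element of vec by the per-tensor factor s."""
    return [x * s for x in vec]


# Hypothetical per-tensor descales and "quantised" inputs.
q_descale, k_descale, v_descale = 0.5, 0.25, 2.0
softmax_scale = 0.125
q_raw = [1.0, -2.0, 3.0]
k_raw = [[0.5, 1.0, -1.0], [2.0, 0.0, 1.0], [-1.0, 1.5, 0.5]]
v_raw = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0]]

# Reference: dequantise every tensor, then run attention.
reference = attention_row(
    scaled(q_raw, q_descale),
    [scaled(k, k_descale) for k in k_raw],
    [scaled(v, v_descale) for v in v_raw],
    softmax_scale,
)

# Folded: raw tensors, descales merged into the scale and the final output.
folded = scaled(
    attention_row(q_raw, k_raw, v_raw, softmax_scale * q_descale * k_descale),
    v_descale,
)

print(max(abs(a - b) for a, b in zip(reference, folded)))
```

The two forms agree because the Q and K factors are constant across the row and V's factor commutes with the weighted sum, so neither changes the softmax weights.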

Instances:
* 12 new FP8 .cpp files under example/.../42_unified_attention/
  instances/ covering the 6 enabled variants x {mask, nmask}.

Validation: 112 / 0 / 128 in the FP8 pytest sweep (passed / failed /
m16-skipped); 245 / 245 in the BF16/FP16 sweep (no regression).
Functional correctness is within the FP8 quant-noise tolerance the
Triton FP8 suite uses (atol/rtol = 1.5e-1). Perf still trails Triton
across the enabled tiers (CK FP8 / Triton FP8 = 0.39-0.69x on the
shapes we benchmarked); that's a separate workstream.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-15 17:34:50 +00:00
juuso-oskari
c0e985d075 CK-UA: document why per-issue SRD-rebase path was tried and dropped
Replace the speculative TODO-style comment next to the K_mem_load /
V_mem_load dispatch with a record of the actual experiment: we
implemented async_load_tile_raw_rebased (buffer_load_dword_lds with a
per-issue SRD whose 48-bit base absorbs the wave-uniform page offset),
verified correctness on multiple big-cache decode shapes, and measured
it against the existing async_load_tile_raw_long path on an isolated
GPU. Rebased was at best tied with long and at worst ~6% slower
(b=1 sk=1M d=64 GQA8: 2.46 ms vs 2.32 ms; b=8 sk=200k d=128 GQA8:
2.12 ms vs 2.02 ms). The workloads are compute / softmax bound, not
K/V load bandwidth bound, so the buffer_load throughput edge never
materialises, while the per-issue SRD construction adds real SGPR
pressure.

No functional change in this commit -- only the explanatory comment is
updated so the next person who eyes the same idea finds the receipts
before re-implementing.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-15 10:18:39 +00:00
juuso-oskari
1f69421434 CK-UA: dispatch K/V async load on cache_ptr_int32_overflow_possible
The shared-SRD buffer_load_dword_lds path that K_mem_load / V_mem_load use
wraps the per-lane voffset (int32 bytes) once
  num_blocks * page_size * row_stride * sizeof(T) > INT32_MAX,
silently returning wrong data on large paged-KV pools (e.g. >4 GB caches).
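
As a rough illustration of when the wrap condition above triggers, the inequality can be evaluated on the host. The helper name `overflow_possible` is hypothetical, and the sketch does not claim to reproduce how the real flag is computed. The row stride (8 KV heads x 64 elements) and the 2-byte element size are assumptions applied to the regression shape quoted in this commit's test note.

```python
INT32_MAX = 2**31 - 1


def overflow_possible(num_blocks: int, page_size: int, row_stride: int, elem_bytes: int) -> bool:
    """True when the largest per-lane byte offset no longer fits in int32."""
    return num_blocks * page_size * row_stride * elem_bytes > INT32_MAX


# Shape from the test note: --num-blocks 1200000 --block-size 16, hk=8, d=64.
# Assumed: row_stride = hk * d = 512 elements, 2 bytes per element.
big_cache = overflow_possible(1_200_000, 16, 8 * 64, 2)
# A small pool with the same geometry stays well below the limit.
small_cache = overflow_possible(1_000, 16, 8 * 64, 2)
print(big_cache, small_cache)
```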

Add a second path, async_load_tile_raw_long, that issues the same load via
__builtin_amdgcn_global_load_lds with per-lane 64-bit base pointers, lifting
both 4 GB limits (SRD size + voffset). Per-issue LDS pointers are computed
explicitly because the intrinsic sets m0 itself, so the old m0_set / m0_inc
bookkeeping doesn't apply. The path also clamps lane_elem_off to the live
buffer range to mimic the original SRD's hardware OOB behaviour.

Dispatch is a wave-uniform runtime branch on a new
cache_ptr_int32_overflow_possible flag plumbed from
unified_attention_args through MakeKargs into the pipeline operator().
Small caches keep the original buffer_load throughput; only the (rare)
>4 GB cache pays the global_load_lds cost.

k_page_offsets / v_page_offsets are widened to long_index_t. The original
buffer_load path implicitly narrows back to int32 when forwarding through
async_get_vectorized_elements_raw, which is intentional and safe whenever
the overflow flag is false.

For diagnostics, also derive a constexpr KWaveSpanInN =
(LaneGroups - 1) * NumWarps + 1 inside the pipeline; when this exceeds
page_size a single buffer_load spans multiple random pages, so the
per-issue SRD-rebase optimisation (not implemented yet) would not apply
even on a sub-4 GB cache. Informational only today.

Test: ua-test-scripts correctness sweep (245/245 pass), plus
  test_single_shape.py -b 32 -sq 8192 -sk 120000 -hq 64 -hk 8 -d 64 \
      --num-blocks 1200000 --block-size 16 --test
which previously returned wrong data due to the int32 wrap and now passes
with max abs diff 1.22e-04 vs Triton.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-15 09:00:43 +00:00
juuso-oskari
d77f0bea63 CK-UA: collapse MHA/GQA variants -- one binary per (head_dim, kBlockM)
After moving kBlockQ to runtime in the previous commit, the static
NumQPerKV in `variant_config<V>` and the runtime-vs-static assert in
the kernel became the only things still tying a compiled binary to a
specific num_queries_per_kv. Drop both and the existing instances now
serve every num_qpkv that divides kBlockM evenly.

Concretely:
  * `variant_config<V>` -- remove the NumQPerKV field from every
    specialization.
  * `unified_attention_kernel_traits` -- remove the `num_queries_per_kv`
    / `kBlockQ = kBlockM / num_qpkv` derivation. The BlockTile's 2nd
    entry (the static `kBlockQ` exposed via UnifiedAttentionShape) is
    anchored at kBlockM so it describes the "num_qpkv == 1" fallback;
    the actual kBlockQ is always the runtime value.
  * `unified_attention_kernel_launch` -- recompute kBlockQ at host time
    from `args.num_queries_per_kv` for the total_num_q_blocks math.
  * `unified_attention_kernel.hpp` -- drop the
    `assert(kBlockQ_dyn == kBlockQ)` (it enforced the very coupling we
    just removed).
  * `unified_attention.cpp::select_config` -- collapse the two
    per-num_qpkv code paths into a single (head_dim, avg_rows,
    max_rows) ladder, where avg_rows = avg_q * num_qpkv.

Variant renames (8 variants):
  prefill_d128_mha       -> prefill_d128
  decode_d128_mha_m128   -> decode_d128_m128
  decode_d128_mha_m32    -> decode_d128_m32
  decode_d128_mha_m16    -> decode_d128_m16
  prefill_d64_gqa8       -> prefill_d64
  decode_d64_gqa8_m128   -> decode_d64_m128
  decode_d64_gqa8_m64    -> decode_d64_m64
  decode_d64_gqa8_m16    -> decode_d64_m16

The 16 d=64 instance files lose their `_gqa8` infix to match the
d=128 naming (file count unchanged: 16 dtypes x mask combos per
head_dim).

Validation:
  * Correctness suite: 241/245 (same 4 pre-existing int32-overflow
    failures in the prefill rebased-pointer path).
  * d=128 GQA-8 (a NEW combo we never had a binary for) -- runs
    correctly on the existing decode_d128_m* binaries with num_qpkv=8
    at runtime. max abs diff <= 1e-2 vs the torch reference at ql in
    {1, 4, 16}.
  * d=64 MHA (also a new combo) -- runs correctly on the existing
    decode_d64_m* binaries with num_qpkv=1. Same tolerance.
  * Perf sweep (b=4..256, sk=120000, MI300):
      d=64  GQA-8: speedups 1.28x..1.84x vs Triton (within 0.6%
                   of baseline).
      d=128 MHA:   speedups 0.98x..1.14x vs Triton (within 0.3%
                   of baseline).

Unlocked: adding new (head_dim, num_qpkv) combos no longer requires
new kernel binaries -- just a host-side heuristic update mapping the
combo to the appropriate (kBlockM, BlockWarps) ladder.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-12 12:15:55 +00:00
juuso-oskari
614afea7eb CK-UA: derive kBlockQ at runtime, decouple from variant template
kBlockQ (= kBlockM / num_queries_per_kv) was constexpr in
`UnifiedAttentionShape` / the kernel-traits, forcing one kernel
instance per (kBlockM, num_qpkv) pair even though the matmul tile is
fully determined by kBlockM and kHeadDim. Audit confirmed kBlockQ
only feeds:

  * arithmetic in `unified_attention_kernel.hpp` (loop bounds, Q-tile
    indexing, query_len padding),
  * `pad_tensor_view` size tuples for Q/O/LSE DRAM views,
  * one `mask.IsEdgeTile(... number<kBlockQ>{} ...)` call inside the
    pipeline's per-K-tile mask check.

None of these structurally need a compile-time value:

* `pad_tensor_view` already accepts mixed runtime/compile-time tuple
  elements (e.g. it's passed plain `1` next to `kHeadDimPadded`).
* `IsEdgeTile` only does runtime arithmetic on the tile size; adding a
  runtime overload that accepts `index_t` is trivial (the compile-time
  one now forwards to it).

Wiring:
  * `block_masking.hpp` -- add an `IsEdgeTile(..., index_t tile_h,
    index_t tile_w)` overload; the existing `number<>` overload just
    forwards to it.
  * `unified_attention_pipeline.hpp` -- new optional
    `num_queries_per_kv` arg on the pipeline's `operator()` (default 0
    keeps existing call sites unchanged). Computes
    `kBlockQ_dyn = (num_qpkv > 0) ? (kBlockM / num_qpkv) : kBlockQ`
    once at the top, uses it in the IsEdgeTile call.
  * `unified_attention_kernel.hpp` -- compute
    `const index_t kBlockQ_dyn = kBlockM / kargs.num_queries_per_kv`
    once and replace every per-call `kBlockQ` use with `kBlockQ_dyn`.
    Pass `kargs.num_queries_per_kv` through to the pipeline. The
    debug-only assert(`kBlockQ_dyn == kBlockQ`) keeps the static and
    dynamic values in lock-step until we actually collapse variants.
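
The runtime derivation above is a single integer division with a fallback. The sketch below mirrors the quoted expression in Python; the function name and the sample tile sizes are hypothetical.

```python
def k_block_q_dyn(k_block_m: int, num_qpkv: int, k_block_q_static: int) -> int:
    """Mirror of `(num_qpkv > 0) ? (kBlockM / num_qpkv) : kBlockQ`.

    num_qpkv == 0 is the "argument not supplied" default, which keeps the
    compile-time kBlockQ for existing call sites.
    """
    return k_block_m // num_qpkv if num_qpkv > 0 else k_block_q_static


# Hypothetical tile: kBlockM = 128 rows.
print(k_block_q_dyn(128, 8, 128))  # GQA-8: 16 query tokens per tile
print(k_block_q_dyn(128, 1, 128))  # MHA: all 128 rows are query tokens
print(k_block_q_dyn(128, 0, 16))   # legacy call site: static value is kept
```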

Perf A/B (b=4..256, sk=120000, MI300):

  d=128 MHA (num_qpkv = 1, runtime div is trivial):
    BW within +/-0.2% across all batch sizes (noise).

  d=64 GQA-8 (num_qpkv = 8, runtime division actually happens):
    speedups 1.28x..2.14x vs Triton -- identical to baseline.

Correctness suite stays at 241/245 (same 4 pre-existing int32-overflow
failures in the d=128 prefill rebased-pointer path).

This is a no-op on perf and unlocks a follow-up where we collapse the
two num_qpkv values per (head_dim, kBlockM) -- e.g. the future d=128
GQA-8 variant can reuse the existing decode_d128_mha_* instances by
just passing a different runtime num_queries_per_kv.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-12 12:01:59 +00:00
juuso-oskari
25364aa634 Add KV-segment parallelism to CK unified attention pipeline
End-to-end split-KV (FlashDecoding-style) for the CK unified attention
kernel. The host launches a single 3D grid with z == num_splits; each
CTA computes its KV-range slice and writes a normalized (o_acc, lse)
partial to FP32 workspaces, which the caller reduces into the final
output.
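
The caller-side reduction mentioned above is the usual log-sum-exp weighted merge of normalised partials. The following sketch is not the caller's actual code: it is a plain-Python model with hypothetical names (`partial_attention`, `combine_splits`) that shows why a partial with lse = -inf contributes zero weight.

```python
import math


def partial_attention(scores, values, dim):
    """Normalised output and natural-log LSE for one KV slice."""
    if not scores:  # masked-empty slice
        return [0.0] * dim, -math.inf
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    l = sum(p)
    o = [sum(pj * v[d] for pj, v in zip(p, values)) / l for d in range(dim)]
    return o, m + math.log(l)


def combine_splits(partials, dim):
    """Merge (o, lse) partials into the output of the unsplit softmax."""
    lse_max = max(lse for _, lse in partials)
    if lse_max == -math.inf:  # every slice was empty
        return [0.0] * dim, -math.inf
    weights = [math.exp(lse - lse_max) for _, lse in partials]
    total = sum(weights)
    out = [
        sum(w * o[d] for w, (o, _) in zip(weights, partials)) / total
        for d in range(dim)
    ]
    return out, lse_max + math.log(total)


scores = [0.3, -1.2, 2.0, 0.7]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [0.5, 0.5]]

full_o, full_lse = partial_attention(scores, values, 2)
split_o, split_lse = combine_splits(
    [
        partial_attention(scores[:2], values[:2], 2),
        partial_attention(scores[2:], values[2:], 2),
        partial_attention([], [], 2),  # empty split, weight exp(-inf) = 0
    ],
    2,
)
print(full_o, split_o)
```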

Pipeline changes:
- operator() returns ck_tile::make_tuple(o_acc, lse) instead of just
  o_acc. The masked-empty early-exit returns lse = -inf so downstream
  combine weighs the partial as zero.
- LSE is built in the natural-log domain from the pipeline's *unscaled*
  rowmax: lse = (scale_s / log2(e)) * m + log(l). Previously we used
  m / log2(e) + log(l), which dropped the per-head scale and produced
  LSE values ~1/scale too large.
- Fix post-process parity: which SP register is left in the
  alu0-done/alu1-pending state at loop exit depends on the parity of
  the *iteration count* (= num_total_loop - num_blocks_start), not on
  num_total_loop alone. For non-split (num_blocks_start == 0) the two
  parities coincide; for splits starting at an odd tile they don't.
- Fix split-KV page-table offset: num_blocks_start is counted in
  kPageBlockSize-sized tiles, but block_tables is indexed in
  page_size-sized pages — shifting block_table_offset by num_blocks_start
  reads the wrong pages whenever kPageBlockSize != page_size. Replaced
  with split_token_offset = num_blocks_start * kPageBlockSize added to
  logical_token before /page_size, so the page lookup uses the absolute
  token position.
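
The corrected LSE expression can be sanity-checked against a direct log-sum-exp. The sketch assumes that `scale_s` already carries the log2(e) factor, i.e. that the pipeline evaluates the softmax with exp2 (an assumption, not something this message states), and that `m` is the rowmax of the raw, unscaled scores. All names are hypothetical.

```python
import math

LOG2E = math.log2(math.e)


def lse_from_pipeline_state(raw_scores, softmax_scale):
    """Rebuild the natural-log LSE from an exp2-domain softmax state."""
    scale_s = softmax_scale * LOG2E  # assumed exp2-domain scale
    m = max(raw_scores)              # unscaled rowmax
    l = sum(2.0 ** (scale_s * (s - m)) for s in raw_scores)
    return (scale_s / LOG2E) * m + math.log(l)


raw_scores = [4.0, -3.0, 10.0, 7.5]
softmax_scale = 0.125

lse = lse_from_pipeline_state(raw_scores, softmax_scale)
reference = math.log(sum(math.exp(softmax_scale * s) for s in raw_scores))
print(lse, reference)
```

Since `m` is unscaled, the first term has to be multiplied by the natural-log softmax scale (`scale_s / log2(e)`) before it can be added to `log(l)`.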

Kernel + dispatcher:
- Drop kargs.i_split; each CTA reads i_split = blockIdx.z.
- GridSize{2D,Decode} now take num_splits and add it as the z-dim
  (defaults to 1, so non-split callers see dim3(..., 1, 1)).
- New write path: when num_splits > 1, the kernel skips the user
  epilogue and instead writes the FP32 (o_acc, lse) tile pair into
  workspace tensors at [head, split, batch_start_token, ...] using
  Default2DEpilogue (UseRawStore=true) for o_acc and store_tile for
  lse. Host strides come from kargs.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-12 08:42:09 +00:00
juuso-oskari
473869aba5 Lift kPageBlockSize <= page_size constraint in CK-UA pipeline
Refactor the K/V DRAM access in the unified-attention pipeline to use
tile_scatter_gather with a unified per-(thread, Y0-iter) page-offset
formula:

    logical_token = tile_idx * kPageBlockSize + thread_N_pos + i * Y0_step_N
    logical_page  = logical_token / page_size
    within_page   = logical_token % page_size
    phys_page     = block_tables[block_table_offset + logical_page]
    page_offsets[i] = (phys_page * page_size + within_page) * row_stride

The page indirection now lives entirely in page_offsets, refreshed via
update_page_idx() between iters. The per-iter SRD rebase
(set_bottom_tensor_view_data_ptr + init_raw) and the use_ptr_rebase
overflow heuristic are gone.
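
The page-offset formula above translates directly into a host-side model. The sketch follows the five lines of the formula one for one; the function signature, the block table contents, and the sample sizes are hypothetical.

```python
def page_offsets(tile_idx, thread_n_pos, y0_step_n, num_y0_iters,
                 page_block_size, page_size, row_stride,
                 block_tables, block_table_offset=0):
    """Element offsets for one thread across its Y0 iterations."""
    offsets = []
    for i in range(num_y0_iters):
        logical_token = tile_idx * page_block_size + thread_n_pos + i * y0_step_n
        logical_page = logical_token // page_size
        within_page = logical_token % page_size
        phys_page = block_tables[block_table_offset + logical_page]
        offsets.append((phys_page * page_size + within_page) * row_stride)
    return offsets


# Hypothetical setup: a 32-token tile spanning two 16-token pages.
offsets = page_offsets(
    tile_idx=1, thread_n_pos=3, y0_step_n=16, num_y0_iters=2,
    page_block_size=32, page_size=16, row_stride=64,
    block_tables=[7, 2, 5, 0],
)
print(offsets)
```

With `page_block_size` twice `page_size`, the two iterations land in different physical pages (5 and 0 here), which is the multi-page tile case this commit enables.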

Effects:
 - The assertion kv_page_size_in_blocks >= 1 (i.e. kPageBlockSize <=
   page_size) in the kernel is dropped. Tiles may now span multiple
   cache pages, as long as Y0_step_N (= N1*N2 from the K/V tile dist)
   divides page_size so that a wave-wide load never straddles a page.
 - Pipeline arg renamed kv_page_size_in_blocks -> page_size (PageSize
   in tokens). Kernel passes kargs.page_size through directly.
 - Validated correctness vs Triton on bf16 / d=64 / decode_s with
   block_size in {16, 32, 64}; max abs diff 1.22e-04 in all cases.
   Perf is on par with the prior pass-1 scaffolding (~3.6 ms on the
   131072-context shape).

TODO(overflow): page_offsets are index_t; caches whose
num_blocks * page_size * row_stride exceeds INT32_MAX will wrap.
A future change should plumb long_index_t through the scatter-gather
load path or compute a per-batch min-page shift in a pre-pass.

TODO(unsupported regime): page_size < Y0_step_N (a wave crosses a page
mid-iter) needs per-lane VGPR SRDs and is not implemented.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-11 10:04:01 +00:00
root
8506db8761 Fix int32 overflow in CK-UA pipeline via pointer rebasing
tensor_coordinate::get_offset() returns index_t (int32), causing overflow
when page_idx * block_size * stride > 2^31 (~131K blocks for d64/GQA-8).

Fix: rebase K/V data pointer for each page using int64 arithmetic instead
of set_window_origin with large offsets. After rebasing p_data_ and
buffer_size_, call init_raw() to refresh the AMD buffer resource descriptor,
then set_window_origin({0,0}) to reset cached coordinates.

Tested: num_blocks up to 2M with nkh=1/8, blk=32/64. All pass.
Made-with: Cursor
2026-04-02 09:39:07 +00:00
root
e8587b86c2 Fix CK-UA pipeline: s_waitcnt_vmcnt<0> in fmha_post_process
The final V tile's async load was not properly waited on before reading
from LDS: s_waitcnt_vmcnt<K_inst> allowed V_inst outstanding loads
(a no-op when K_inst == V_inst). The last loop iteration never prefetches
K, so only V is outstanding. Use s_waitcnt_vmcnt<0> unconditionally.

This partially fixes the BS32 race condition for production workloads
(maxk >= 256). A deeper pipeline race remains for very short KV
sequences (maxk < ~165, 2-5 pages) with block_size=32 at high batch.

Made-with: Cursor
2026-04-01 23:04:07 +00:00
root
87d16738bf WIP: CK-UA KV-segment parallelism - kernel args and split range
Added split-KV fields to UnifiedAttentionVarlenKargs (num_splits,
i_split, lse_acc_ptr, o_acc_ptr + strides). Modified operator() to
compute per-split KV range using blocks_per_split.

INCOMPLETE: The pipeline returns normalized o_acc but the split-KV
combine kernel needs unnormalized o_acc + lse. Need to modify the
pipeline to optionally return m and l values alongside o_acc.

The kernel changes compile but the epilogue needs the split path
(write to float accumulators instead of final output).

Made-with: Cursor
2026-04-01 19:09:59 +00:00
root
cd7ba6e2e8 Add unified attention (42_unified_attention)
Squashed from aghamari/unified-attention-decode-opt branch.

CK tile paged-KV attention kernel optimized for decode with 4-tier
dispatch (tiny/small/medium/large), 16x16 MFMA, 2D decode grid,
head-group merging. Supports hdim=64 GQA-8 and hdim=128 MHA with
block_size=32.

Made-with: Cursor
2026-04-01 16:39:15 +00:00
root
ec2db01e4a Fix fmha_fwd early-exit bug: seqlen_q <= min_seqlen_q should be <
The kSkipMinSeqlenQ optimization incorrectly used <= comparison, causing
the kernel to skip batches where seqlen_q equals min_seqlen_q. This
happens in the common case of no padding (all batches have the same
seqlen_q == min_seqlen_q), producing all-zero output silently.

Changed to strict < so batches with exactly min_seqlen_q tokens are
still processed.

Made-with: Cursor
2026-04-01 16:24:31 +00:00
root
6729989b97 Fix FMHA split-KV for paged-KV with page_block_size < kN0
Cherry-picked from aghamari/unified-attention-decode-opt (fadf0d585).
- block_masking.hpp: 5-param GetTileRangeAlongX for GenericAttentionMask
- fmha_fwd_splitkv.py: bn0=32 for hdim=64

Made-with: Cursor
2026-04-01 16:24:19 +00:00
root
4c5e290378 Add unified attention (42_unified_attention) and topk_softmax_decode
Squashed from aghamari/unified-attention-decode-opt branch.

42_unified_attention: CK tile paged-KV attention kernel optimized for
decode with 4-tier dispatch (tiny/small/medium/large), 16x16 MFMA,
2D decode grid, head-group merging. Supports hdim=64 GQA-8 and
hdim=128 MHA with block_size=32.

topk_softmax_decode: fused topk + softmax kernel for M=1 MoE decode.

Made-with: Cursor
2026-04-01 16:24:04 +00:00
Chinmay Dattanand Kuchinad
2bb69a24ea [rocm-libraries] ROCm/rocm-libraries#5776 (commit ee1bbcb)
[CK] Fix async pivot mismatch in persistent GEMM kernel
 scheduler (#5776)

## Motivation

Fix pivot mismatch in the persistent GEMM kernel's async input scheduler
that causes **GPU hangs** and incorrect results when used with AsyncTP
(Asynchronous Tensor Parallelism) on ROCm.

PyTorch's `_fused_all_gather_matmul_native` uses this persistent GEMM
kernel with chunk signals to overlap communication and computation. The
pivot mechanism ensures each rank starts computing from its own local
shard first (which is already available), then moves to remote chunks as
they arrive over the network.

Because of the pivot mismatch, the kernel frequently waits on signals
for chunks that have not yet arrived while reading data from entirely
different chunks. This desynchronization reliably triggers hangs during
multi-GPU native AsyncTP execution. This fix is required to enable
functional AsyncTP support on ROCm.

## Technical Details

In the persistent kernel loop (`UniversalGemmKernel::operator()`), the
M-tile coordinate used for data selection (`i_m`) and the M-tile
coordinate used for the chunk-signal wait (`chunk_idx`) were derived
from inconsistent bases:

* `i_m` was computed from the **unpivoted** tile index `iM`.
* `chunk_idx` was computed from the **pivoted** expression `(iM +
tile_idx_pivot)`.

This means the kernel could wait for chunk N's signal but then read from
chunk M's memory, or vice versa. The mismatch scales with GPU count:
with 2 GPUs ~50% of tiles are wrong, with 4 GPUs ~75%, etc.

**The Fix:**
Introduce a single pivoted M-tile index (`iM_eff`) and derive both `i_m`
and `chunk_idx` from it. This guarantees the kernel always waits for the
correct chunk before reading its data.
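
A toy model makes the mismatch concrete. It assumes one chunk per rank, a pivot of `rank * tiles_per_chunk`, and wrap-around of the pivoted tile index; none of these details are spelled out in this description, and the function name is hypothetical. Averaged over ranks, the model yields a mismatch rate of (n-1)/n, which is consistent with the ~50% and ~75% figures quoted above.

```python
def mismatched_fraction(num_ranks: int, tiles_per_chunk: int) -> float:
    """Fraction of tiles whose waited-on chunk differs from the chunk read."""
    num_m_tiles = num_ranks * tiles_per_chunk
    wrong = 0
    for rank in range(num_ranks):
        pivot = rank * tiles_per_chunk  # assumption: start at the local shard
        for i_m in range(num_m_tiles):
            i_m_eff = (i_m + pivot) % num_m_tiles  # assumption: index wraps
            data_chunk = i_m // tiles_per_chunk      # buggy: unpivoted index
            wait_chunk = i_m_eff // tiles_per_chunk  # signal uses pivoted index
            wrong += data_chunk != wait_chunk
    return wrong / (num_ranks * num_m_tiles)


print(mismatched_fraction(2, 4))
print(mismatched_fraction(4, 4))
```

Deriving both `data_chunk` and `wait_chunk` from `i_m_eff`, as the fix does, makes the two indices identical by construction.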

*(Note: Minor cosmetic `clang-format` changes were also pulled in
alongside the fix).*

## Test Plan

1. Build PyTorch with this CK change.
2. Run the specific multi-GPU AsyncTP native test:
`timeout 180s env HIP_VISIBLE_DEVICES=0,1 pytest
test/distributed/test_symmetric_memory.py -k
test_fused_all_gather_matmul_native -q -s -x`

## Test Result

Tests verify correct overlapping execution without hangs or accuracy
mismatches when running the AsyncTP native path with non-zero pivots.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-01 16:22:08 +00:00
aledudek
119712bd90 [rocm-libraries] ROCm/rocm-libraries#4469 (commit 0844cb0)
[CK_TILE] Add pooling in tile_engine

## Motivation

Add pooling in ck tile engine

## Technical Details


## Test Plan


## Test Result


## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-01 07:32:36 +00:00
Yi DING
791afc6465 [rocm-libraries] ROCm/rocm-libraries#5991 (commit 8d85e8e)
[CK_TILE] Fix FMHA BWD IGLP incorrect results due to AGPR
 misallocation (#5991)

## Motivation

After PR #5790 removed the `if constexpr(FmhaMask::IsMasking)` guard around
the `num_total_loop <= 0` early-exit check, the IGLP pipeline
(`BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP`) produces incorrect dK/dV gradients
for non-masking kernels (even with the fix in #5915). Assembly inspection
confirms that the CFG change causes the LLVM register allocator to reuse
AGPR accumulators as scratch destinations in the dK/dV reduction loop,
breaking the loop-carried accumulation across Q-tile iterations.

## Technical Details

- Add `[[unlikely]]` to the `num_total_loop <= 0` early-exit in
  `BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP`. This attribute is load-bearing: it
  restores the CFG shape that the register allocator needs to correctly
  assign dedicated AGPRs to each column of the dK/dV accumulator.
- Only the IGLP pipeline is affected; the other two BWD pipelines do not
  exhibit this issue.

## Test Plan

## Test Result

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-01 05:45:19 +00:00
Estevan Vedovelli
a33b5be1b9 [rocm-libraries] ROCm/rocm-libraries#6022 (commit 54b284a)
[CK] contraction: extend GetTypeString() to include
 layout-differentiating params (#6022)

## Motivation

Consumers that identify kernels by their `GetTypeString()` (such as
hipTensor's actor-critic kernel selection, which hashes the string into a
stable cross-platform UID) were silently dropping one of two colliding
variants during registry insertion.

`GetTypeString()` in `DeviceContractionMultipleD_Xdl_CShuffle` previously
printed 13 template parameters, omitting `ABlockTransferSrcScalarPerVector`,
`BBlockTransferSrcScalarPerVector`, `ABlockLdsExtraM`, and `BBlockLdsExtraN`.

These four parameters determine the block-transfer access width and LDS
padding strategy, and are precisely what differentiates the `kk`, `kn`,
`mk`, and `mn` layout variants from one another when all other geometry
parameters are equal. Two instantiations with identical 13-parameter strings
are distinct C++ types that accept different stride layouts and reject each
other's arguments via `IsSupportedArgument`.

This patch extends the output to 17 parameters so that every distinct
template instantiation of this class produces a unique
`GetTypeString()`.

## Technical Details

`include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp`:
- extend `GetTypeString()` from 13 to 17 parameters including
`ABlockTransferSrcScalarPerVector`,
`BBlockTransferSrcScalarPerVector`, `ABlockLdsExtraM`, and
`BBlockLdsExtraN`.
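
The collision can be reproduced in miniature. The sketch below is a heavily reduced, hypothetical stand-in (`ToyContraction`, `ShortTypeString`, `FullTypeString`, and `CountDistinctKeys` are invented for illustration); only the four parameter names come from this PR, and the real class has many more template parameters.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <sstream>
#include <string>

// Hypothetical toy template, NOT the real device op.
template <int BlockSize,
          int MPerBlock,
          int ABlockTransferSrcScalarPerVector,
          int BBlockTransferSrcScalarPerVector,
          int ABlockLdsExtraM,
          int BBlockLdsExtraN>
struct ToyContraction
{
    // Old-style string: geometry only, omits the layout-differentiating params.
    static std::string ShortTypeString()
    {
        std::ostringstream s;
        s << "ToyContraction<" << BlockSize << ", " << MPerBlock << ">";
        return s.str();
    }

    // Extended string: every template parameter is printed.
    static std::string FullTypeString()
    {
        std::ostringstream s;
        s << "ToyContraction<" << BlockSize << ", " << MPerBlock << ", "
          << ABlockTransferSrcScalarPerVector << ", "
          << BBlockTransferSrcScalarPerVector << ", "
          << ABlockLdsExtraM << ", " << BBlockLdsExtraN << ">";
        return s.str();
    }
};

// Two distinct types that share the same geometry parameters.
using VariantA = ToyContraction<256, 128, 4, 4, 1, 1>;
using VariantB = ToyContraction<256, 128, 1, 4, 0, 1>;

// A registry keyed by the type string keeps one entry per distinct key.
template <typename... Strings>
std::size_t CountDistinctKeys(const Strings&... keys)
{
    std::set<std::string> registry{keys...};
    return registry.size();
}
```

Keyed by the short string, `VariantA` and `VariantB` occupy a single registry slot; keyed by the full string, each gets its own entry.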

## Test Plan

Build CK and hipTensor with these changes, and verify that hipTensor can
differentiate and select the correct kernels across layout variations.

## Test Result

CK builds correctly and hipTensor selects the kernels correctly.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-31 15:19:43 +00:00
Bartłomiej Kocot
ef4ff4667d [rocm-libraries] ROCm/rocm-libraries#5842 (commit 04c5690)
[CK][CK Tile] Force padding for atomic_add bf16 C tensor
 (#5842)

## Motivation

Force padding for atomic_add bf16 C tensor to avoid memfaults.

## Technical Details

- add a global atomic add for bf16 and enable it
- add padding for bf16 atomic add due to the lack of OOB (out-of-bounds) checking
- remove padding for non-contiguous dims in conv for other cases
- minor bwd data conv fixes
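
One plausible reading of the padding item: when the store path itself offers no out-of-bounds protection, a partial last tile must be guarded in software so that it never touches memory past the logical extent. The host-side sketch below shows only that idea; `PadToTile` and `AccumulateTileGuarded` are illustrative names, plain `float` stands in for bf16, and none of this is CK Tile API.

```cpp
#include <cassert>

// Round `length` up to the next multiple of `tile` (generic helper).
constexpr int PadToTile(int length, int tile)
{
    return ((length + tile - 1) / tile) * tile;
}

// Accumulate `value` into one tile of `c`, skipping indices that fall outside
// the logical length. Returns how many elements were actually updated.
inline int AccumulateTileGuarded(float* c, int length, int tile_start, int tile_size, float value)
{
    int updated = 0;
    for(int i = tile_start; i < tile_start + tile_size; ++i)
    {
        if(i < length) // software bounds check standing in for hardware OOB
        {
            c[i] += value;
            ++updated;
        }
    }
    return updated;
}
```

With a logical length of 5 and a tile size of 4, the padded extent is 8 and the second tile updates only index 4.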

## Test Plan

test_grouped_conv_*_tile

## Test Result

pending

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-31 08:03:41 +00:00
jakpiase
66dc81d530 [rocm-libraries] ROCm/rocm-libraries#5729 (commit 516c974)
[CK_TILE] Changed cshuffle LDS descriptor to naive layout
 (#5729)

## Motivation
This PR changes the gemm/convolution cshuffle layout to a plain (naive) one
to improve cshuffle operation performance.

## Technical Details
Before this change, the cshuffle LDS descriptor applied several
transformations that were probably aimed at reducing LDS bank conflicts.
The transformations themselves were slow enough to negatively impact
overall performance.

## Test Plan
There is no need for additional tests, since current tests cover this
functionality.
2026-03-31 03:40:25 +00:00