composable_kernel/include
juuso-oskari 9d7cc3ee9e CK-UA: extend FP8 to the 16x16x32 _m16 decode tier via LDS roundtrip
The 32x32x16 tiers (prefill_d{64,128}, decode_d{64,128}_m{32,64,128}) keep
the cheap in-register `ds_bpermute_b32` cross-lane swap that fixes the
QK-C / PV-A per-thread alias for the union'd `sp_compute` / `p`.

The 16x16x32 m16 tiers (decode_d{64,128}_m16) cannot use the swap -- the
MFMA puts the paired-lane bit at a different position and the
sub=0/sub=1 4-fp8 chunks no longer map onto each other. We add a
layout-agnostic LDS roundtrip as the `else` branch, gated by the same
`PVWarpTile` constexpr:

  - Hoist two distribution-bound windows over the existing `p_lds`
    region (one bound to the QK-C output distribution, one to the PV-A
    input distribution). Done once per kernel invocation.
  - In `fmha_alu1`, after the cvt_pk_fp8_f32 packing chain, view the
    union's bytes as a `static_distributed_tensor<fp8>` in the QK-C
    distribution, `store_tile` it through `p_lds` in canonical (M, N)
    order, `s_barrier`, then `load_tile` back with the PV-A
    distribution and copy into `sp(idx).p`.
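
The roundtrip described above can be illustrated with a small host-side toy
model. This is only a sketch under stated assumptions, not the kernel code:
the tile shape, the two lane/register layouts, and every name below
(`producer_coord`, `consumer_coord`, `lds_roundtrip`, ...) are hypothetical
stand-ins for the QK-C and PV-A distributions and for `p_lds`. It shows why
writing in canonical (M, N) order and reading back makes the exchange
independent of how the two layouts relate to each other.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Toy host-side model. None of these names are CK APIs; shapes are made up.
constexpr int kM = 4, kN = 8;         // hypothetical tile shape (M x N)
constexpr int kLanes = 8, kRegs = 4;  // kLanes * kRegs == kM * kN

struct Coord { int m, n; };

// Hypothetical producer layout (stands in for the QK-C output distribution):
// each lane owns 4 consecutive elements of one row.
constexpr Coord producer_coord(int lane, int reg) {
    return {lane / 2, (lane % 2) * 4 + reg};
}

// Hypothetical consumer layout (stands in for the PV-A input distribution):
// each lane owns one column, with its registers striding over rows.
constexpr Coord consumer_coord(int lane, int reg) {
    return {reg, lane};
}

using Regs = std::array<std::array<std::uint8_t, kRegs>, kLanes>;

// Layout-agnostic exchange: store in canonical (M, N) order, then load back
// using the consumer layout. The local array plays the role of shared memory.
inline Regs lds_roundtrip(const Regs& src) {
    std::array<std::uint8_t, kM * kN> lds{};
    for (int lane = 0; lane < kLanes; ++lane) {
        for (int reg = 0; reg < kRegs; ++reg) {
            const Coord c = producer_coord(lane, reg);
            lds[c.m * kN + c.n] = src[lane][reg];
        }
    }
    // On a GPU, a block-wide barrier would separate the stores from the loads.
    Regs dst{};
    for (int lane = 0; lane < kLanes; ++lane) {
        for (int reg = 0; reg < kRegs; ++reg) {
            const Coord c = consumer_coord(lane, reg);
            dst[lane][reg] = lds[c.m * kN + c.n];
        }
    }
    return dst;
}

// Fill each producer register with its canonical index m * kN + n.
inline Regs make_producer_regs() {
    Regs src{};
    for (int lane = 0; lane < kLanes; ++lane) {
        for (int reg = 0; reg < kRegs; ++reg) {
            const Coord c = producer_coord(lane, reg);
            src[lane][reg] = static_cast<std::uint8_t>(c.m * kN + c.n);
        }
    }
    return src;
}

// After the roundtrip, every consumer register should hold the element that
// the consumer layout assigns to it.
inline bool roundtrip_matches_consumer_layout() {
    const Regs dst = lds_roundtrip(make_producer_regs());
    for (int lane = 0; lane < kLanes; ++lane) {
        for (int reg = 0; reg < kRegs; ++reg) {
            const Coord c = consumer_coord(lane, reg);
            if (dst[lane][reg] != c.m * kN + c.n) return false;
        }
    }
    return true;
}
```

Neither mapping function knows anything about the other, which is the
property the `else` branch relies on. The cost is the store, the barrier, and
the load, which is why the in-register swap is kept where the layouts allow it.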

A/B-tested a uniform LDS roundtrip (no fast path) against the split: the
pure-LDS variant regressed decode_m128 by ~1.5x end-to-end (CK FP8 dropped
from ~0.39x of Triton FP8 to ~0.16x), driven by the extra block-wide
barrier on the 4-warp decode path. Keeping the swap for 32x32x16
preserves the previously tuned perf.

Dispatcher (`unified_attention.cpp`) now FP8-enables every UA variant
including decode_d{64,128}_m16. Four new instance .cpp files
(`unified_attention_d{64,128}_fp8_{mask,nmask}_decode_t.cpp`)
instantiate the m16 FP8 kernels.

Pytest (`test_unified_attention_ck_correctness.py`):
  - 245 BF16/FP16: pass (no regression from the pipeline edit).
  - 160 FP8: pass (was 112 before m16 enablement).
  - 80 skipped: block_size<32 or query_len>kv_len -- pre-existing.

Single-shape m16 dispatches verified on gfx950:
  b=128 sq=1 hq=hk=8 d=128 fp8 PASS  (CK 0.109 ms / Triton 0.043 ms)
  b=128 sq=1 hq=hk=8 d=64  fp8 PASS  (CK 0.077 ms / Triton 0.039 ms)

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-15 20:00:35 +00:00