juuso-oskari 63c75277a0 CK-UA: enable FP8 (e4m3) for prefill/m128 and the 32x32x16 small-tile decode variants
Full pipeline support for FP8 (e4m3fn on gfx950 / e4m3fnuz on gfx942)
in the unified-attention kernel, gated to the 32x32x16 MFMA tiers in
both d=64 and d=128 ladders: prefill_d{64,128}, decode_d{64,128}_m128,
decode_d128_m32, and decode_d64_m64. The 16x16x32 _m16 tiers stay
BF16/FP16-only -- the QK-C and PV-A per-thread layouts there differ
by an M<->N swap that the current slot-swap fixup cannot express; a
full per-thread transpose (most likely via LDS) is needed.

Pipeline (unified_attention_pipeline.hpp):
* `fmha_alu1` now performs a cross-lane P-tile fixup right after the
  FP8 packing of softmax(P): a `ds_bpermute_b32` between paired lanes
  `lane` and `lane ^ 32` that swaps sub=0 slot[k_base+4..k_base+7]
  with sub=1 slot[k_base..k_base+3] for every 8-element FP8 chunk.
  This realigns the packed FP8 P operand with PV-A's `Single` AttrNumAccess
  per-thread layout, which is necessary because the QK-C output and
  PV-A input alias byte-for-byte via the sp_compute/p union -- and
  for FP8 the two warp-gemm layouts no longer agree (BF16/FP16 keep
  Double AttrNumAccess in the PV gemm, which matches QK-C natively).
  Gated on `Gemm1WarpTile == 32x32x16`; FP8-only (BF16/FP16 paths take
  the existing cvt_pk path unchanged).
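
A host-side emulation of the fixup may help when reviewing it. This is a sketch under stated assumptions, not the kernel code: it assumes sub = lane / 32, four FP8 bytes per dword (so each 8-element chunk is two dwords), and models ds_bpermute_b32 as "lane i reads the dword that lane src[i] provided". `bpermute`, `PackedP` and `p_tile_fixup` are hypothetical names.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr int kWaveSize = 64;

// Host stand-in for ds_bpermute_b32: lane i receives the dword provided by
// lane src_lane[i]. (The real instruction takes a byte address, src_lane * 4.)
std::array<std::uint32_t, kWaveSize> bpermute(
    const std::array<std::uint32_t, kWaveSize>& provided,
    const std::array<int, kWaveSize>& src_lane) {
    std::array<std::uint32_t, kWaveSize> out{};
    for (int lane = 0; lane < kWaveSize; ++lane) {
        assert(src_lane[lane] >= 0 && src_lane[lane] < kWaveSize);
        out[lane] = provided[src_lane[lane]];
    }
    return out;
}

// Packed P per lane: chunk c occupies dword 2c (slot[k_base..k_base+3])
// and dword 2c+1 (slot[k_base+4..k_base+7]).
template <std::size_t kChunks>
using PackedP = std::array<std::array<std::uint32_t, 2 * kChunks>, kWaveSize>;

// Swap the sub=0 upper dword with the sub=1 lower dword of every chunk,
// assuming sub = lane / 32 (hypothetical mapping).
template <std::size_t kChunks>
void p_tile_fixup(PackedP<kChunks>& regs) {
    std::array<int, kWaveSize> partner{};
    for (int lane = 0; lane < kWaveSize; ++lane) partner[lane] = lane ^ 32;

    for (std::size_t c = 0; c < kChunks; ++c) {
        std::array<std::uint32_t, kWaveSize> provided{};
        for (int lane = 0; lane < kWaveSize; ++lane) {
            // sub=0 lanes send (and later overwrite) their upper dword,
            // sub=1 lanes send (and later overwrite) their lower dword.
            const std::size_t slot = 2 * c + (lane < 32 ? 1u : 0u);
            provided[lane] = regs[lane][slot];
        }
        const auto received = bpermute(provided, partner);  // one exchange per chunk
        for (int lane = 0; lane < kWaveSize; ++lane) {
            const std::size_t slot = 2 * c + (lane < 32 ? 1u : 0u);
            regs[lane][slot] = received[lane];
        }
    }
}
```

Because each lane overwrites exactly the dword it sent, a single bpermute per chunk performs the swap; the real kernel does this in registers for all lanes at once rather than in a loop.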

Default policy (unified_attention_pipeline_default_policy.hpp):
* PV warp gemm now selects `WGAttrNumAccessEnum::Single` when V is
  fp8/bf8 and `Double` otherwise. This is forced by load_tile_transpose's
  SubMinDim = 8 B / sizeof(V) constraint (64 bits' worth of V elements):
  for FP8, SubMinDim = 8 and kABKPerLane = 8, so only Single satisfies
  the validation static_asserts.
* GetAlignmentK / GetAlignmentV on gfx950 drop to 4 B/lane for fp8/
  bf8. The natural 16 B/lane async-load that BF16/FP16 use leaves
  NumIssues = 0 for the FP8 tile shapes we compile, while 8 B/lane
  (dwordx2) is not among the dword / dwordx3 / dwordx4 widths allowed
  in amd_buffer_addressing_builtins. 4 B/lane gives NumIssues >= 1 on
  every targeted variant and is the same alignment the gfx942
  fallback already used. BF16/FP16 keep the full 16 B/lane path so
  existing perf is unchanged.
* GetSmemSizeKV adds a `VLoadDescSize` lower bound. The element span of
  MakeVLdsLoadBlockDescriptor exceeds the banked SingleVSize only for
  FP8 (small per-lane KVector + fixed kVLdsPadInBytes = 64), so without
  the bound FP8 trips the GetSmemSizeKV static_asserts. BF16/FP16 are
  unaffected.
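
The SubMinDim argument in the PV warp gemm bullet can be checked with a few constexpr lines. This is a simplified model, not CK's static_asserts: it assumes the validation requires each of the NumAccess sub-loads to hold a whole multiple of SubMinDim elements, and it assumes kABKPerLane = 8 for both FP8 and BF16. `NumAccess`, `access_is_valid` and `pv_attr_num_access` are hypothetical stand-ins.

```cpp
// Hypothetical stand-in for WGAttrNumAccessEnum; the value is the number of sub-loads.
enum class NumAccess { Single = 1, Double = 2 };

// SubMinDim: how many V elements fit in 64 bits (8 bytes).
constexpr int sub_min_dim(int elem_bytes) { return 8 / elem_bytes; }

// Simplified validation model: every sub-load must cover a whole number of
// SubMinDim-element groups.
constexpr bool access_is_valid(int k_per_lane, int elem_bytes, NumAccess access) {
    const int per_access = k_per_lane / static_cast<int>(access);
    return per_access > 0 && per_access % sub_min_dim(elem_bytes) == 0;
}

// Selection rule from the commit: Single for 1-byte V (fp8/bf8), Double otherwise.
constexpr NumAccess pv_attr_num_access(int elem_bytes) {
    return elem_bytes == 1 ? NumAccess::Single : NumAccess::Double;
}

constexpr int kABKPerLane = 8;  // assumed for the 32x32x16 PV gemm
static_assert(!access_is_valid(kABKPerLane, 1, NumAccess::Double), "FP8: 4 < SubMinDim 8");
static_assert(access_is_valid(kABKPerLane, 1, pv_attr_num_access(1)), "FP8 picks Single");
static_assert(access_is_valid(kABKPerLane, 2, pv_attr_num_access(2)), "BF16 keeps Double");
```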
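
The GetAlignmentK / GetAlignmentV bullet reduces to integer arithmetic. A sketch, assuming NumIssues is the number of whole block-wide loads that fit in the tile; the 32 x 64 tile and 256-thread block are illustrative numbers, not one of the compiled shapes, and all function names are hypothetical.

```cpp
// Widths accepted for the async buffer-to-LDS load, per the commit:
// dword (4 B), dwordx3 (12 B), dwordx4 (16 B). 8 B (dwordx2) is rejected.
constexpr bool is_valid_async_width(int bytes_per_lane) {
    return bytes_per_lane == 4 || bytes_per_lane == 12 || bytes_per_lane == 16;
}

// Assumed model: number of full block-wide loads that cover the tile.
constexpr int num_issues(int tile_bytes, int block_threads, int bytes_per_lane) {
    return tile_bytes / (block_threads * bytes_per_lane);
}

// Rule from the commit for gfx950: 4 B/lane for 1-byte K/V, 16 B/lane otherwise.
constexpr int gfx950_kv_bytes_per_lane(int elem_bytes) { return elem_bytes == 1 ? 4 : 16; }

// Illustrative only: a 32 x 64 FP8 tile (2048 B) loaded by 256 threads.
constexpr int kTileBytes = 32 * 64 * 1;
constexpr int kThreads = 256;
static_assert(num_issues(kTileBytes, kThreads, 16) == 0, "16 B/lane: no full issue");
static_assert(!is_valid_async_width(8), "8 B/lane is not an allowed width");
static_assert(num_issues(kTileBytes, kThreads, 4) == 2, "4 B/lane: two issues");
```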
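
The GetSmemSizeKV change is a max() over two sizes. The sketch below uses hypothetical names and made-up byte counts (a 4096 B banked V size and a descriptor span 64 B larger) to show why sizing from SingleVSize alone fails the "descriptor fits" check while the new lower bound passes it.

```cpp
#include <algorithm>
#include <cstddef>

// Bytes addressed by an LDS descriptor whose largest element offset is max_elem_offset.
constexpr std::size_t desc_span_bytes(std::size_t max_elem_offset, std::size_t elem_bytes) {
    return (max_elem_offset + 1) * elem_bytes;
}

// Sizing with the new lower bound: cover both the banked V layout and
// everything the V LDS load descriptor can reach.
constexpr std::size_t smem_size_kv(std::size_t banked_single_v_size,
                                   std::size_t v_load_desc_size) {
    return std::max(banked_single_v_size, v_load_desc_size);
}

// The kind of condition the static_asserts enforce (simplified).
constexpr bool desc_fits(std::size_t smem_bytes, std::size_t v_load_desc_size) {
    return v_load_desc_size <= smem_bytes;
}

// Illustrative FP8 numbers only.
constexpr std::size_t kBanked = 4096;
constexpr std::size_t kDesc = desc_span_bytes(4159, 1);  // 4160 B
static_assert(!desc_fits(kBanked, kDesc), "old sizing: descriptor overruns the buffer");
static_assert(desc_fits(smem_size_kv(kBanked, kDesc), kDesc), "new lower bound covers it");
```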

Warp-gemm headers + dispatcher:
* New `WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed_T<AttrNumAccess>`
  template alias in warp_gemm.hpp (mirrors the existing BF16 32x32x16
  CTransposed template), used by the PV gemm to thread the FP8
  Single AttrNumAccess through.
* New Dispatcher specialization for
  <fp8_t, fp8_t, float, 32, 32, 16, true, false, false, EDouble>
  in warp_gemm_dispatcher.hpp routing to the new template.
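
The warp-gemm addition follows the usual alias-template pattern: the access pattern becomes a template parameter, and the policy picks it from the V type. A minimal standalone sketch with placeholder types (`fp8_t` and `bf16_t` here are stand-ins sized like the real ones; `WarpGemm32x32x16CTransposedSketch` and `PvWarpGemm` are hypothetical). It does not reproduce the dispatcher's full key.

```cpp
#include <cstdint>
#include <type_traits>

// Placeholder element types (1 and 2 bytes, like fp8 and bf16).
struct fp8_t  { std::uint8_t bits; };
struct bf16_t { std::uint16_t bits; };

enum class AttrNumAccess { Single, Double };

// Stand-in for a 32x32x16 C-transposed warp gemm whose per-thread A/B
// access pattern is a template parameter.
template <typename ABType, AttrNumAccess Access>
struct WarpGemm32x32x16CTransposedSketch {
    using ab_type = ABType;
    static constexpr AttrNumAccess kAttrNumAccess = Access;
};

// Alias templates that thread the access pattern through.
template <AttrNumAccess Access>
using WarpGemmFp8_32x32x16_CTransposed_T = WarpGemm32x32x16CTransposedSketch<fp8_t, Access>;
template <AttrNumAccess Access>
using WarpGemmBf16_32x32x16_CTransposed_T = WarpGemm32x32x16CTransposedSketch<bf16_t, Access>;

// Policy-side choice: Single for 1-byte V, Double otherwise.
template <typename VType>
inline constexpr AttrNumAccess kPvAttrNumAccess =
    sizeof(VType) == 1 ? AttrNumAccess::Single : AttrNumAccess::Double;

template <typename VType>
using PvWarpGemm = std::conditional_t<
    std::is_same_v<VType, fp8_t>,
    WarpGemmFp8_32x32x16_CTransposed_T<kPvAttrNumAccess<VType>>,
    WarpGemmBf16_32x32x16_CTransposed_T<kPvAttrNumAccess<VType>>>;
```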

ABI / dispatcher (unified_attention.{cpp,hpp}, unified_attention_impl.hpp):
* New `fp8` value in `unified_attention_args::data_type_enum` (selects
  e4m3fn on gfx950 via CK_TILE_USE_OCP_FP8, e4m3fnuz elsewhere).
* New `unified_attention_problem_traits<...::fp8>` alias:
  qkvp_dtype = ck_tile::fp8_t, acc_dtype = float, o_dtype = bf16_t
  (matches the Triton reference), lse_dtype = float.
* Per-tensor `q_descale` / `k_descale` / `v_descale` floats on
  `unified_attention_args` (defaulting to 1.0f, so non-FP8 callers are
  unaffected). The pipeline folds q_descale*k_descale into the softmax
  scale and applies v_descale once to o_acc after the 1/l norm --
  same semantics as Triton's q_scale/k_scale/v_scale.
* `dispatch_variant<>` enables FP8 on prefill_d{64,128},
  decode_d{64,128}_m128, decode_d128_m32, decode_d64_m64. The
  16x16x32 _m16 tiers return (false, -1.f) for now (see the layout note
  at the top).
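
The new dtype value and its traits specialization amount to a compile-time lookup. A sketch with hypothetical placeholder types and a simplified enum; only the FP8 mapping listed above is encoded, since the BF16/FP16 traits are not described in this message.

```cpp
#include <cstdint>
#include <type_traits>

// Placeholder element types standing in for ck_tile::fp8_t / bf16_t.
struct fp8_t  { std::uint8_t bits; };
struct bf16_t { std::uint16_t bits; };

// Simplified mirror of unified_attention_args::data_type_enum.
enum class data_type_enum { fp16, bf16, fp8 };

template <data_type_enum DT>
struct problem_traits;  // only specialized dtypes are usable

template <>
struct problem_traits<data_type_enum::fp8> {
    using qkvp_dtype = fp8_t;   // Q, K, V and P stored as FP8
    using acc_dtype  = float;   // QK and PV accumulation
    using o_dtype    = bf16_t;  // output, matching the Triton reference
    using lse_dtype  = float;   // log-sum-exp
};
```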
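
Folding the descales is exact in real arithmetic: scale * (q_descale q) . (k_descale k) = (scale * q_descale * k_descale) q . k, and because v_descale multiplies every V row it factors out of sum_j p_j v_j / l. The single-query-row sketch below (hypothetical helper names, plain float math rather than FP8) checks that the folded form matches dequantize-then-attend up to rounding.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

using Vec = std::vector<float>;
using Mat = std::vector<Vec>;  // one row per key (K) or per value (V)

// Single-row softmax attention: o = sum_j softmax(score_scale * q.k_j) * v_j.
Vec attention_row(const Vec& q, const Mat& k, const Mat& v, float score_scale) {
    assert(!k.empty() && k.size() == v.size());
    Vec s(k.size());
    float m = -std::numeric_limits<float>::infinity();
    for (std::size_t j = 0; j < k.size(); ++j) {
        assert(k[j].size() == q.size());
        float dot = 0.f;
        for (std::size_t d = 0; d < q.size(); ++d) dot += q[d] * k[j][d];
        s[j] = score_scale * dot;
        m = std::max(m, s[j]);
    }
    Vec o(v[0].size(), 0.f);
    float l = 0.f;
    for (std::size_t j = 0; j < v.size(); ++j) {
        const float p = std::exp(s[j] - m);  // unnormalized softmax weight
        l += p;
        for (std::size_t d = 0; d < o.size(); ++d) o[d] += p * v[j][d];
    }
    for (float& x : o) x /= l;  // the 1/l normalization
    return o;
}

// Reference: dequantize Q, K, V first, then attend.
Vec attend_dequant_first(Vec q, Mat k, Mat v, float scale,
                         float q_descale, float k_descale, float v_descale) {
    for (float& x : q) x *= q_descale;
    for (Vec& row : k) for (float& x : row) x *= k_descale;
    for (Vec& row : v) for (float& x : row) x *= v_descale;
    return attention_row(q, k, v, scale);
}

// Folded: q/k descales join the softmax scale; v_descale is applied once after 1/l.
Vec attend_folded(const Vec& q, const Mat& k, const Mat& v, float scale,
                  float q_descale, float k_descale, float v_descale) {
    Vec o = attention_row(q, k, v, scale * q_descale * k_descale);
    for (float& x : o) x *= v_descale;
    return o;
}
```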
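
The tier gate in `dispatch_variant<>` can be pictured as a small constexpr function. This sketch assumes the variant names shown (the m16 names in particular are hypothetical) and that the float in the returned pair is a result value that reads -1 when nothing is launched; the launch itself is stubbed.

```cpp
#include <utility>

enum class Variant {
    prefill_d64, prefill_d128,
    decode_d64_m128, decode_d128_m128, decode_d128_m32, decode_d64_m64,
    decode_d64_m16, decode_d128_m16  // 16x16x32 tiers (hypothetical names)
};
enum class DType { fp16, bf16, fp8 };

constexpr bool is_16x16x32_tier(Variant v) {
    return v == Variant::decode_d64_m16 || v == Variant::decode_d128_m16;
}

// Returns (launched, value). FP8 on a 16x16x32 tier is refused with (false, -1.f).
constexpr std::pair<bool, float> dispatch_variant_sketch(Variant v, DType dt) {
    if (dt == DType::fp8 && is_16x16x32_tier(v)) return {false, -1.f};
    return {true, 0.f};  // stub: a real dispatcher would launch the kernel here
}
```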

Instances:
* 12 new FP8 .cpp files under example/.../42_unified_attention/
  instances/ covering the 6 enabled variants x {mask, nmask}.

Validation: the FP8 pytest sweep gives 112 passed / 0 failed / 128
skipped (m16 tiers); the BF16/FP16 sweep passes 245 / 245 (no regression).
Functional correctness is within the FP8 quant-noise tolerance the
Triton FP8 suite uses (atol/rtol = 1.5e-1). Perf still trails Triton
across the enabled tiers (CK FP8 / Triton FP8 = 0.39-0.69x on the
shapes we benchmarked); that's a separate workstream.
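
For reference, the usual allclose criterion with the tolerances quoted above, assuming the suite applies it elementwise with `expected` as the reference output (as in numpy.allclose / torch.allclose): |actual - expected| <= atol + rtol * |expected|. A minimal sketch:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Elementwise |actual - expected| <= atol + rtol * |expected|.
bool allclose(const std::vector<float>& actual, const std::vector<float>& expected,
              float atol = 1.5e-1f, float rtol = 1.5e-1f) {
    assert(actual.size() == expected.size());
    for (std::size_t i = 0; i < actual.size(); ++i) {
        const float tol = atol + rtol * std::fabs(expected[i]);
        // Written as !(<=) so that a NaN difference counts as a failure.
        if (!(std::fabs(actual[i] - expected[i]) <= tol)) return false;
    }
    return true;
}
```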

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-15 17:34:50 +00:00