Marcus Rosen ad2f19fd31 [CK] DequantPack8 + DequantPack8WithZp bf16 overloads (asymmetric int4 / AWQ)
Adds bf16 (bhalf8_t / bhalf2_t) overloads to DequantPack8 and DequantPack8WithZp
so consumers of the wmma_cshuffle_v3_b_scale device-op can dispatch bf16
output without link errors.

The bf16 overloads use new helpers `i4_to_bhalf4_scale` and
`i4_to_bhalf4_zp_scale`. Implementation notes:

* The fp16 helpers (`i4_to_half4_scale` / `i4_to_half4_zp_scale`) use a
  bit-pattern trick: load nibble bits into the fp16 mantissa via
  AND/OR/EX, subtract a magic constant, and apply scale via native half2
  arithmetic. RDNA3 has no equivalent native bhalf2 multiply / fma path,
  so the bf16 helpers fall back to scalar fp32 conversion (one
  `type_convert<float>` per lane, multiply, `type_convert<bhalf_t>` back).

* Lane order matters: the bf16 helpers MUST extract q's nibbles at the
  same positions {0, 4, 1, 5} as the fp16 helpers, NOT the {0, 1, 2, 3}
  order produced by the existing unscaled `i4_to_bhalf4` helper (which
  uses a different __byte_perm pattern). Mismatched orderings appear to
  work for constant-data correctness tests (every nibble has the same
  value so position doesn't matter) but produce wildly incorrect output
  when nibbles vary across a pack — downstream WMMA accumulates each
  dequant lane against a specific activation lane, so reordering breaks
  the GEMM. The bf16 helpers below explicitly reproduce the fp16
  position pattern.

* Cost: bf16 dispatch is measured ~30% slower than fp16 dispatch on
  gfx1151 (Strix Halo / RDNA 3.5) for the same kernel config. Acceptable
  trade-off for correctness; can be optimized later by adapting the fp16
  bit-trick to bf16's wider exponent.
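
The lane-order pitfall and the scalar fp32 fallback described above can be
illustrated with a small host-side sketch. This is not the CK
implementation: `bf16_bits`, `float_to_bf16`, `bf16_to_float`, `dequant4`,
`kFp16Order` and `kLinearOrder` are hypothetical names, the fixed offset of
8 assumes the symmetric int4 case, and the round-to-nearest-even conversion
is an assumption about how `type_convert<bhalf_t>` rounds.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical stand-in for a bf16 value: the upper 16 bits of an fp32.
using bf16_bits = std::uint16_t;

// fp32 -> bf16 with round-to-nearest-even (NaN handling omitted).
inline bf16_bits float_to_bf16(float f) {
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    u += 0x7FFFu + ((u >> 16) & 1u);
    return static_cast<bf16_bits>(u >> 16);
}

// bf16 -> fp32 is exact: place the 16 bits in the upper half of an fp32.
inline float bf16_to_float(bf16_bits b) {
    std::uint32_t u = static_cast<std::uint32_t>(b) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

// Nibble positions inside a packed 32-bit word of eight int4 values.
constexpr std::array<int, 4> kFp16Order{0, 4, 1, 5};   // order used by the fp16 helpers
constexpr std::array<int, 4> kLinearOrder{0, 1, 2, 3}; // order of the unscaled helper

// Scalar fallback: one fp32 convert, multiply, and bf16 convert per lane.
// Assumption: unsigned nibble with a fixed offset of 8 (symmetric int4).
inline std::array<bf16_bits, 4> dequant4(std::uint32_t q,
                                         const std::array<int, 4>& pos,
                                         float scale) {
    std::array<bf16_bits, 4> out{};
    for (int lane = 0; lane < 4; ++lane) {
        assert(pos[lane] >= 0 && pos[lane] < 8);
        const int nib = static_cast<int>((q >> (4 * pos[lane])) & 0xFu);
        out[lane] = float_to_bf16(static_cast<float>(nib - 8) * scale);
    }
    return out;
}
```

With a constant pack such as `0x11111111` both orders return identical
lanes, which is why a constant-data test cannot catch a mismatch. With a
varying pack such as `0x76543210` the two orders disagree from lane 1
onward.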

The reference (slow) `!CK_USE_PK4_LAYOUT_SHUFFLE` paths in
DequantPack8::operator()(bhalf8_t&, ...) and DequantPack8WithZp::
operator()(bhalf8_t&, ...) also use the scalar fp32 fallback, since bf16
arithmetic isn't broadly available outside the pk4 fast path.

Verified against torch reference on gfx1151 via the standalone aiter
op_tests/test_gemm_w4a16.py harness: all four combinations (sym / asym ×
fp16 / bf16) pass within their respective fp tolerances (fp16
max_abs/max_ref = 5.4e-4, bf16 max_abs/max_ref = 4.5e-3 at M=2048
N=19456 K=2560 G=128).
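
For readers unfamiliar with the figure quoted above, here is a minimal
sketch of one plausible reading of max_abs/max_ref: the largest absolute
element-wise error divided by the largest absolute reference value. The
harness's exact definition is not shown in this commit, so both the
formula and the helper name `max_abs_over_max_ref` are assumptions.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: max |out - ref| divided by max |ref|.
// Falls back to the raw max error when the reference is all zeros.
inline double max_abs_over_max_ref(const std::vector<float>& out,
                                   const std::vector<float>& ref) {
    double max_abs = 0.0;
    double max_ref = 0.0;
    const std::size_t n = std::min(out.size(), ref.size());
    for (std::size_t i = 0; i < n; ++i) {
        const double r = static_cast<double>(ref[i]);
        const double o = static_cast<double>(out[i]);
        max_abs = std::max(max_abs, std::fabs(o - r));
        max_ref = std::max(max_ref, std::fabs(r));
    }
    return max_ref > 0.0 ? max_abs / max_ref : max_abs;
}
```

Normalizing by the largest reference magnitude keeps the metric stable
when individual reference elements are near zero, which a per-element
relative error would not.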

JIRA: AIESW-32176.
2026-05-14 13:48:13 -07:00