mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Adds bf16 (bhalf8_t / bhalf2_t) overloads to DequantPack8 and DequantPack8WithZp
so consumers of the wmma_cshuffle_v3_b_scale device-op can dispatch bf16
output without link errors.
The bf16 overloads use new helpers `i4_to_bhalf4_scale` and
`i4_to_bhalf4_zp_scale`. Implementation notes:
* The fp16 helpers (`i4_to_half4_scale` / `i4_to_half4_zp_scale`) use a
bit-pattern trick: load nibble bits into the fp16 mantissa with an AND
(nibble mask) and an OR (exponent constant `EX`), subtract a magic
constant, and apply the scale via native half2
arithmetic. RDNA3 has no equivalent native bhalf2 multiply / fma path,
so the bf16 helpers fall back to scalar fp32 conversion (one
`type_convert<float>` per lane, multiply, `type_convert<bhalf_t>` back).
* Lane order matters: the bf16 helpers MUST extract q's nibbles at the
same positions {0, 4, 1, 5} as the fp16 helpers, NOT the {0, 1, 2, 3}
order produced by the existing unscaled `i4_to_bhalf4` helper (which
uses a different __byte_perm pattern). Mismatched orderings appear to
work for constant-data correctness tests (every nibble has the same
value so position doesn't matter) but produce wildly incorrect output
when nibbles vary across a pack — downstream WMMA accumulates each
dequant lane against a specific activation lane, so reordering breaks
the GEMM. The bf16 helpers below explicitly reproduce the fp16
position pattern.
* Cost: bf16 dispatch measured ~30% slower than fp16 dispatch on
gfx1151 (Strix Halo / RDNA 3.5) for the same kernel config. This is an
acceptable trade-off for correctness; it can be optimized later by
adapting the fp16 bit-trick to bf16's layout (8-bit exponent, 7-bit
mantissa).
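To illustrate the first two notes, here is a host-side C++ sketch of the
packed-nibble trick. It is not the CK device code: the masks
0x000F000F / 0x00F000F0, the exponent constant 0x64006400, the offsets
1032 and 72, and the signed offset of 8 are assumptions based on the
usual form of this trick, and `half_bits_to_float` / `magic_dequant4`
are hypothetical names. The scale multiply is left out. The sketch shows
where the {0, 4, 1, 5} order comes from: each mask selects one nibble
from each 16-bit half of q.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstdint>

// Decode a normal fp16 bit pattern to float (sketch: no subnormal/inf/NaN).
inline float half_bits_to_float(std::uint16_t h) {
    const int sign = (h >> 15) & 1;
    const int exp  = (h >> 10) & 0x1F;
    const int man  = h & 0x3FF;
    assert(exp != 0 && exp != 0x1F);
    const float v = std::ldexp(1.0f + man / 1024.0f, exp - 15);
    return sign ? -v : v;
}

// Assumed constants, for illustration only.
constexpr std::uint32_t kLoMask = 0x000F000Fu; // selects nibbles 0 and 4
constexpr std::uint32_t kHiMask = 0x00F000F0u; // selects nibbles 1 and 5
constexpr std::uint32_t kExp    = 0x64006400u; // two fp16 copies of 1024.0

// Returns (nibble - 8) for four nibbles, in the order the trick produces them.
inline std::array<float, 4> magic_dequant4(std::uint32_t q) {
    const std::uint32_t lo = (q & kLoMask) | kExp; // each half holds 1024 + n
    const std::uint32_t hi = (q & kHiMask) | kExp; // each half holds 1024 + 16*n
    std::array<float, 4> out{};
    out[0] = half_bits_to_float(static_cast<std::uint16_t>(lo & 0xFFFFu)) - 1032.0f;        // nibble 0
    out[1] = half_bits_to_float(static_cast<std::uint16_t>(lo >> 16)) - 1032.0f;            // nibble 4
    out[2] = half_bits_to_float(static_cast<std::uint16_t>(hi & 0xFFFFu)) / 16.0f - 72.0f;  // nibble 1
    out[3] = half_bits_to_float(static_cast<std::uint16_t>(hi >> 16)) / 16.0f - 72.0f;      // nibble 5
    return out;
}
```

For q = 0x76543210 (nibble p holds the value p) this yields
{-8, -4, -7, -3}, i.e. nibbles 0, 4, 1, 5 minus 8. A constant pack such
as 0x33333333 yields the same value in every lane, which is why
constant-data tests cannot detect an ordering mismatch.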
The reference (slow) `!CK_USE_PK4_LAYOUT_SHUFFLE` paths in
DequantPack8::operator()(bhalf8_t&, ...) and DequantPack8WithZp::
operator()(bhalf8_t&, ...) similarly use scalar fp32 fallback — bf16
arithmetic isn't broadly available outside the pk4 fast path.
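A minimal host-side sketch of the scalar fp32 fallback for one lane
follows. It is an illustration, not the CK code: `bf16_bits` stands in
for `bhalf_t`, the formula (nibble - zp) * scale is an assumed form of
the zero-point dequantization (zp = 8 for the symmetric case), and the
round-to-nearest-even conversion may differ from what
`type_convert<bhalf_t>` does in a given CK build. All function names are
hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

using bf16_bits = std::uint16_t; // raw bf16 storage, stand-in for bhalf_t

// bf16 -> fp32: bf16 is the upper 16 bits of an fp32 value.
inline float bf16_to_float(bf16_bits b) {
    const std::uint32_t u = static_cast<std::uint32_t>(b) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

// fp32 -> bf16 with round-to-nearest-even (sketch: NaN is not handled).
inline bf16_bits float_to_bf16(float f) {
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    u += 0x7FFFu + ((u >> 16) & 1u);
    return static_cast<bf16_bits>(u >> 16);
}

// Dequantize the nibble at position `pos` (0 = least significant) of q:
// convert to fp32, multiply by the scale, round once back to bf16.
inline bf16_bits dequant_lane(std::uint32_t q, int pos, int zp, bf16_bits scale) {
    assert(pos >= 0 && pos < 8);
    const int n = static_cast<int>((q >> (4 * pos)) & 0xFu);
    const float v = static_cast<float>(n - zp) * bf16_to_float(scale);
    return float_to_bf16(v);
}
```

With q = 0x76543210, pos = 5, zp = 8 and a scale of 0.5, the lane
evaluates to (5 - 8) * 0.5 = -1.5, which is exactly representable in
bf16.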
Verified against torch reference on gfx1151 via the standalone aiter
op_tests/test_gemm_w4a16.py harness: all four {sym, asym} × {fp16, bf16}
combinations pass within their respective fp tolerances (fp16
max_abs/max_ref = 5.4e-4, bf16 max_abs/max_ref = 4.5e-3 at M=2048,
N=19456, K=2560, G=128).
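For readers reproducing the numbers, one plausible reading of the
max_abs/max_ref figure is sketched below. This is an assumption about
the metric, not the harness's actual code: max_abs is taken as the
largest absolute element-wise difference and max_ref as the largest
absolute reference value. `normalized_max_error` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// max_i |out[i] - ref[i]| divided by max_i |ref[i]| (assumed definition).
inline double normalized_max_error(const std::vector<float>& out,
                                   const std::vector<float>& ref) {
    assert(out.size() == ref.size() && !ref.empty());
    double max_abs = 0.0;
    double max_ref = 0.0;
    for (std::size_t i = 0; i < ref.size(); ++i) {
        const double o = static_cast<double>(out[i]);
        const double r = static_cast<double>(ref[i]);
        max_abs = std::max(max_abs, std::fabs(o - r));
        max_ref = std::max(max_ref, std::fabs(r));
    }
    assert(max_ref > 0.0); // an all-zero reference has no meaningful ratio
    return max_abs / max_ref;
}
```

For example, out = {1, 2, 4} against ref = {1, 2.5, 4} gives
0.5 / 4 = 0.125.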
JIRA: AIESW-32176.