From ad2f19fd3128e7d149983dff0438df92ba81b646 Mon Sep 17 00:00:00 2001 From: Marcus Rosen Date: Thu, 14 May 2026 07:08:20 -0700 Subject: [PATCH] [CK] DequantPack8 + DequantPack8WithZp bf16 overloads (asymmetric int4 / AWQ) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds bf16 (bhalf8_t / bhalf2_t) overloads to DequantPack8 and DequantPack8WithZp so consumers of the wmma_cshuffle_v3_b_scale device-op can dispatch bf16 output without link errors. The bf16 overloads use new helpers `i4_to_bhalf4_scale` and `i4_to_bhalf4_zp_scale`. Implementation notes: * The fp16 helpers (`i4_to_half4_scale` / `i4_to_half4_zp_scale`) use a bit-pattern trick: load nibble bits into the fp16 mantissa via AND/OR/EX, subtract a magic constant, and apply scale via native half2 arithmetic. RDNA3 has no equivalent native bhalf2 multiply / fma path, so the bf16 helpers fall back to scalar fp32 conversion (one `type_convert` per lane, multiply, `type_convert` back). * Lane order matters: the bf16 helpers MUST extract q's nibbles at the same positions {0, 4, 1, 5} as the fp16 helpers, NOT the {0, 1, 2, 3} order produced by the existing unscaled `i4_to_bhalf4` helper (which uses a different __byte_perm pattern). Mismatched orderings appear to work for constant-data correctness tests (every nibble has the same value so position doesn't matter) but produce wildly incorrect output when nibbles vary across a pack — downstream WMMA accumulates each dequant lane against a specific activation lane, so reordering breaks the GEMM. The bf16 helpers below explicitly reproduce the fp16 position pattern. * Cost: bf16 dispatch is measured ~30% slower than fp16 dispatch on gfx1151 (Strix Halo / RDNA 3.5) for the same kernel config. Acceptable trade-off for correctness; can be optimized later by adapting the fp16 bit-trick to bf16's wider exponent. The reference (slow) `\!CK_USE_PK4_LAYOUT_SHUFFLE` paths in DequantPack8::operator()(bhalf8_t&, ...) and DequantPack8WithZp:: operator()(bhalf8_t&, ...) similarly use scalar fp32 fallback — bf16 arithmetic isn't broadly available outside the pk4 fast path. Verified against torch reference on gfx1151 via the standalone aiter op_tests/test_gemm_w4a16.py harness — sym + asym × fp16 + bf16 all pass within their respective fp tolerances (fp16 max_abs/max_ref = 5.4e-4, bf16 max_abs/max_ref = 4.5e-3 at M=2048 N=19456 K=2560 G=128). JIRA: AIESW-32176. --- .../element/unary_element_wise_operation.hpp | 165 ++++++++++++++++++ 1 file changed, 165 insertions(+) diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 70d2efe6e6..4f8c835086 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -112,6 +112,83 @@ i4_to_half4_zp_scale(int q, const ck::half2_t& scale, const ck::half2_t& scaled_ return res.template AsType()[Number<0>{}]; } +// AIESW-32176: bf16 analog of i4_to_half4_scale. +// +// IMPORTANT: this matches the FP16 helper's nibble-position pattern +// (extract q's nibbles at positions {0, 4, 1, 5}), NOT the order produced +// by the existing unscaled `i4_to_bhalf4` (which uses {0, 1, 2, 3} packed +// via __byte_perm). The two orderings produce different lane layouts that +// are NOT interchangeable downstream — the b_scale GEMM kernel +// accumulates dequant values against specific activation lanes via WMMA, +// so the lane order must match what the fp16 path produces in order to +// reuse the same packed-weight buffer for both fp16 and bf16 dispatches. +// +// Implementation: scalar fp32 conversion. No native bhalf2 multiply / +// fma on RDNA3, so the fp16 trick (single-instruction integer-bias-shift +// into fp16 bit pattern) doesn't translate directly. The scalar path +// below is correct + simple — perf cost is real but the bf16 dispatch +// isn't the perf-critical case (fp16 is the production target). +__device__ inline bhalf4_t i4_to_bhalf4_scale(int q, const ck::bhalf2_t& scale) +{ + vector_type scale_v; + scale_v.template AsType()(Number<0>{}) = scale; + const float s0 = type_convert(scale_v.template AsType()[Number<0>{}]); + const float s1 = type_convert(scale_v.template AsType()[Number<1>{}]); + + // Nibble positions match i4_to_half4_scale: {0, 4, 1, 5} of q. + const int n0 = (q >> 0) & 0xf; + const int n4 = (q >> 16) & 0xf; + const int n1 = (q >> 4) & 0xf; + const int n5 = (q >> 20) & 0xf; + + vector_type res; + res.template AsType()(Number<0>{}) = + type_convert(static_cast(n0 - 8) * s0); + res.template AsType()(Number<1>{}) = + type_convert(static_cast(n4 - 8) * s1); + res.template AsType()(Number<2>{}) = + type_convert(static_cast(n1 - 8) * s0); + res.template AsType()(Number<3>{}) = + type_convert(static_cast(n5 - 8) * s1); + + return res.template AsType()[Number<0>{}]; +} + +// AIESW-32176: bf16 analog of i4_to_half4_zp_scale (asymmetric / AWQ). +// Same nibble-position pattern as i4_to_bhalf4_scale plus a per-group +// scaled_zp subtract, matching (nibble - zp) * scale via the +// (nibble - 8) * scale - scaled_zp identity. Caller precomputes +// scaled_zp = (zp - 8) * scale at weight load. +__device__ inline bhalf4_t +i4_to_bhalf4_zp_scale(int q, const ck::bhalf2_t& scale, const ck::bhalf2_t& scaled_zp) +{ + vector_type scale_v; + scale_v.template AsType()(Number<0>{}) = scale; + vector_type zp_v; + zp_v.template AsType()(Number<0>{}) = scaled_zp; + const float s0 = type_convert(scale_v.template AsType()[Number<0>{}]); + const float s1 = type_convert(scale_v.template AsType()[Number<1>{}]); + const float z0 = type_convert(zp_v.template AsType()[Number<0>{}]); + const float z1 = type_convert(zp_v.template AsType()[Number<1>{}]); + + const int n0 = (q >> 0) & 0xf; + const int n4 = (q >> 16) & 0xf; + const int n1 = (q >> 4) & 0xf; + const int n5 = (q >> 20) & 0xf; + + vector_type res; + res.template AsType()(Number<0>{}) = + type_convert(static_cast(n0 - 8) * s0 - z0); + res.template AsType()(Number<1>{}) = + type_convert(static_cast(n4 - 8) * s1 - z1); + res.template AsType()(Number<2>{}) = + type_convert(static_cast(n1 - 8) * s0 - z0); + res.template AsType()(Number<3>{}) = + type_convert(static_cast(n5 - 8) * s1 - z1); + + return res.template AsType()[Number<0>{}]; +} + __device__ inline f8x4_t i4_to_f8x4(int q) { const int LO = 0x000f000f; @@ -333,6 +410,47 @@ struct DequantPack8 #endif } + // AIESW-32176: bf16 overload — mirrors the fp16 path but uses + // i4_to_bhalf4_scale (which goes through fp32 intermediates because RDNA3 + // has no native bhalf2 multiply). + __host__ __device__ constexpr void + operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x, const ck::bhalf2_t& z) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + vector_type result; + + result.template AsType()(Number<0>{}) = i4_to_bhalf4_scale(bit_cast(x), z); + result.template AsType()(Number<1>{}) = + i4_to_bhalf4_scale(bit_cast(x) >> 8, z); + + y = result.template AsType()[Number<0>{}]; +#else + // Reference (slow) path for the !CK_USE_PK4_LAYOUT_SHUFFLE config: + // do the unscaled bf16 dequant via existing type_convert, then apply + // (* scale) per-element via fp32. + vector_type dst; + vector_type src{x}; + vector_type z_v; + z_v.template AsType()(Number<0>{}) = z; + const float s0 = type_convert(z_v.template AsType()[Number<0>{}]); + const float s1 = type_convert(z_v.template AsType()[Number<1>{}]); + + static_for<0, 4, 1>{}([&](auto i) { + auto v = type_convert(src.template AsType()[i]); + vector_type v_v; + v_v.template AsType()(Number<0>{}) = v; + const float v0 = type_convert(v_v.template AsType()[Number<0>{}]); + const float v1 = type_convert(v_v.template AsType()[Number<1>{}]); + vector_type r; + r.template AsType()(Number<0>{}) = type_convert(v0 * s0); + r.template AsType()(Number<1>{}) = type_convert(v1 * s1); + dst.template AsType()(i) = r.template AsType()[Number<0>{}]; + }); + + y = dst.template AsType()[Number<0>{}]; +#endif + } + constexpr const static bool is_pack8_invocable = true; }; @@ -378,6 +496,53 @@ struct DequantPack8WithZp #endif } + // AIESW-32176: bf16 overload — mirrors the fp16 path but uses + // i4_to_bhalf4_zp_scale (fp32 intermediates; no native bhalf2 multiply on + // RDNA3). + __host__ __device__ constexpr void operator()(ck::bhalf8_t& y, + const ck::pk_i4x4_t& x, + const ck::bhalf2_t& s, + const ck::bhalf2_t& scaled_zp) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + vector_type result; + + result.template AsType()(Number<0>{}) = + i4_to_bhalf4_zp_scale(bit_cast(x), s, scaled_zp); + result.template AsType()(Number<1>{}) = + i4_to_bhalf4_zp_scale(bit_cast(x) >> 8, s, scaled_zp); + + y = result.template AsType()[Number<0>{}]; +#else + // Reference (slow) path: unscaled bf16 dequant + per-element + // (* scale) - scaled_zp via fp32. + vector_type dst; + vector_type src{x}; + vector_type s_v; + s_v.template AsType()(Number<0>{}) = s; + vector_type z_v; + z_v.template AsType()(Number<0>{}) = scaled_zp; + const float s0 = type_convert(s_v.template AsType()[Number<0>{}]); + const float s1 = type_convert(s_v.template AsType()[Number<1>{}]); + const float z0 = type_convert(z_v.template AsType()[Number<0>{}]); + const float z1 = type_convert(z_v.template AsType()[Number<1>{}]); + + static_for<0, 4, 1>{}([&](auto i) { + auto v = type_convert(src.template AsType()[i]); + vector_type v_v; + v_v.template AsType()(Number<0>{}) = v; + const float v0 = type_convert(v_v.template AsType()[Number<0>{}]); + const float v1 = type_convert(v_v.template AsType()[Number<1>{}]); + vector_type r; + r.template AsType()(Number<0>{}) = type_convert(v0 * s0 - z0); + r.template AsType()(Number<1>{}) = type_convert(v1 * s1 - z1); + dst.template AsType()(i) = r.template AsType()[Number<0>{}]; + }); + + y = dst.template AsType()[Number<0>{}]; +#endif + } + constexpr const static bool is_pack8_invocable = true; };