[CK] DequantPack8 + DequantPack8WithZp bf16 overloads (asymmetric int4 / AWQ)

Adds bf16 (bhalf8_t / bhalf2_t) overloads to DequantPack8 and DequantPack8WithZp
so consumers of the wmma_cshuffle_v3_b_scale device-op can dispatch bf16
output without link errors.

The bf16 overloads use new helpers `i4_to_bhalf4_scale` and
`i4_to_bhalf4_zp_scale`. Implementation notes:

* The fp16 helpers (`i4_to_half4_scale` / `i4_to_half4_zp_scale`) use a
  bit-pattern trick: load nibble bits into the fp16 mantissa via
  AND/OR/EX, subtract a magic constant, and apply scale via native half2
  arithmetic. RDNA3 has no equivalent native bhalf2 multiply / fma path,
  so the bf16 helpers fall back to scalar fp32 conversion (one
  `type_convert<float>` per lane, multiply, `type_convert<bhalf_t>` back).

* Lane order matters: the bf16 helpers MUST extract q's nibbles at the
  same positions {0, 4, 1, 5} as the fp16 helpers, NOT the {0, 1, 2, 3}
  order produced by the existing unscaled `i4_to_bhalf4` helper (which
  uses a different __byte_perm pattern). Mismatched orderings appear to
  work for constant-data correctness tests (every nibble has the same
  value so position doesn't matter) but produce wildly incorrect output
  when nibbles vary across a pack — downstream WMMA accumulates each
  dequant lane against a specific activation lane, so reordering breaks
  the GEMM. The bf16 helpers below explicitly reproduce the fp16
  position pattern.

* Cost: bf16 dispatch is measured ~30% slower than fp16 dispatch on
  gfx1151 (Strix Halo / RDNA 3.5) for the same kernel config. Acceptable
  trade-off for correctness; can be optimized later by adapting the fp16
  bit-trick to bf16's wider exponent.

The reference (slow) `\!CK_USE_PK4_LAYOUT_SHUFFLE` paths in
DequantPack8::operator()(bhalf8_t&, ...) and DequantPack8WithZp::
operator()(bhalf8_t&, ...) similarly use scalar fp32 fallback — bf16
arithmetic isn't broadly available outside the pk4 fast path.

Verified against torch reference on gfx1151 via the standalone aiter
op_tests/test_gemm_w4a16.py harness — sym + asym × fp16 + bf16 all
pass within their respective fp tolerances (fp16 max_abs/max_ref =
5.4e-4, bf16 max_abs/max_ref = 4.5e-3 at M=2048 N=19456 K=2560 G=128).

JIRA: AIESW-32176.
This commit is contained in:
Marcus Rosen
2026-05-14 07:08:20 -07:00
parent fe31dbe57e
commit ad2f19fd31

View File

@@ -112,6 +112,83 @@ i4_to_half4_zp_scale(int q, const ck::half2_t& scale, const ck::half2_t& scaled_
return res.template AsType<half4_t>()[Number<0>{}];
}
// AIESW-32176: bf16 analog of i4_to_half4_scale.
//
// IMPORTANT: this matches the FP16 helper's nibble-position pattern
// (extract q's nibbles at positions {0, 4, 1, 5}), NOT the order produced
// by the existing unscaled `i4_to_bhalf4` (which uses {0, 1, 2, 3} packed
// via __byte_perm). The two orderings produce different lane layouts that
// are NOT interchangeable downstream — the b_scale GEMM kernel
// accumulates dequant values against specific activation lanes via WMMA,
// so the lane order must match what the fp16 path produces in order to
// reuse the same packed-weight buffer for both fp16 and bf16 dispatches.
//
// Implementation: scalar fp32 conversion. No native bhalf2 multiply /
// fma on RDNA3, so the fp16 trick (single-instruction integer-bias-shift
// into fp16 bit pattern) doesn't translate directly. The scalar path
// below is correct + simple — perf cost is real but the bf16 dispatch
// isn't the perf-critical case (fp16 is the production target).
__device__ inline bhalf4_t i4_to_bhalf4_scale(int q, const ck::bhalf2_t& scale)
{
vector_type<bhalf_t, 2> scale_v;
scale_v.template AsType<bhalf2_t>()(Number<0>{}) = scale;
const float s0 = type_convert<float>(scale_v.template AsType<bhalf_t>()[Number<0>{}]);
const float s1 = type_convert<float>(scale_v.template AsType<bhalf_t>()[Number<1>{}]);
// Nibble positions match i4_to_half4_scale: {0, 4, 1, 5} of q.
const int n0 = (q >> 0) & 0xf;
const int n4 = (q >> 16) & 0xf;
const int n1 = (q >> 4) & 0xf;
const int n5 = (q >> 20) & 0xf;
vector_type<bhalf_t, 4> res;
res.template AsType<bhalf_t>()(Number<0>{}) =
type_convert<bhalf_t>(static_cast<float>(n0 - 8) * s0);
res.template AsType<bhalf_t>()(Number<1>{}) =
type_convert<bhalf_t>(static_cast<float>(n4 - 8) * s1);
res.template AsType<bhalf_t>()(Number<2>{}) =
type_convert<bhalf_t>(static_cast<float>(n1 - 8) * s0);
res.template AsType<bhalf_t>()(Number<3>{}) =
type_convert<bhalf_t>(static_cast<float>(n5 - 8) * s1);
return res.template AsType<bhalf4_t>()[Number<0>{}];
}
// AIESW-32176: bf16 analog of i4_to_half4_zp_scale (asymmetric / AWQ).
// Same nibble-position pattern as i4_to_bhalf4_scale plus a per-group
// scaled_zp subtract, matching (nibble - zp) * scale via the
// (nibble - 8) * scale - scaled_zp identity. Caller precomputes
// scaled_zp = (zp - 8) * scale at weight load.
__device__ inline bhalf4_t
i4_to_bhalf4_zp_scale(int q, const ck::bhalf2_t& scale, const ck::bhalf2_t& scaled_zp)
{
vector_type<bhalf_t, 2> scale_v;
scale_v.template AsType<bhalf2_t>()(Number<0>{}) = scale;
vector_type<bhalf_t, 2> zp_v;
zp_v.template AsType<bhalf2_t>()(Number<0>{}) = scaled_zp;
const float s0 = type_convert<float>(scale_v.template AsType<bhalf_t>()[Number<0>{}]);
const float s1 = type_convert<float>(scale_v.template AsType<bhalf_t>()[Number<1>{}]);
const float z0 = type_convert<float>(zp_v.template AsType<bhalf_t>()[Number<0>{}]);
const float z1 = type_convert<float>(zp_v.template AsType<bhalf_t>()[Number<1>{}]);
const int n0 = (q >> 0) & 0xf;
const int n4 = (q >> 16) & 0xf;
const int n1 = (q >> 4) & 0xf;
const int n5 = (q >> 20) & 0xf;
vector_type<bhalf_t, 4> res;
res.template AsType<bhalf_t>()(Number<0>{}) =
type_convert<bhalf_t>(static_cast<float>(n0 - 8) * s0 - z0);
res.template AsType<bhalf_t>()(Number<1>{}) =
type_convert<bhalf_t>(static_cast<float>(n4 - 8) * s1 - z1);
res.template AsType<bhalf_t>()(Number<2>{}) =
type_convert<bhalf_t>(static_cast<float>(n1 - 8) * s0 - z0);
res.template AsType<bhalf_t>()(Number<3>{}) =
type_convert<bhalf_t>(static_cast<float>(n5 - 8) * s1 - z1);
return res.template AsType<bhalf4_t>()[Number<0>{}];
}
__device__ inline f8x4_t i4_to_f8x4(int q)
{
const int LO = 0x000f000f;
@@ -333,6 +410,47 @@ struct DequantPack8
#endif
}
// AIESW-32176: bf16 overload — mirrors the fp16 path but uses
// i4_to_bhalf4_scale (which goes through fp32 intermediates because RDNA3
// has no native bhalf2 multiply).
__host__ __device__ constexpr void
operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x, const ck::bhalf2_t& z) const
{
#if CK_USE_PK4_LAYOUT_SHUFFLE
vector_type<bhalf_t, 8> result;
result.template AsType<bhalf4_t>()(Number<0>{}) = i4_to_bhalf4_scale(bit_cast<int>(x), z);
result.template AsType<bhalf4_t>()(Number<1>{}) =
i4_to_bhalf4_scale(bit_cast<int>(x) >> 8, z);
y = result.template AsType<bhalf8_t>()[Number<0>{}];
#else
// Reference (slow) path for the !CK_USE_PK4_LAYOUT_SHUFFLE config:
// do the unscaled bf16 dequant via existing type_convert, then apply
// (* scale) per-element via fp32.
vector_type<bhalf_t, 8> dst;
vector_type<pk_i4_t, 4> src{x};
vector_type<bhalf_t, 2> z_v;
z_v.template AsType<bhalf2_t>()(Number<0>{}) = z;
const float s0 = type_convert<float>(z_v.template AsType<bhalf_t>()[Number<0>{}]);
const float s1 = type_convert<float>(z_v.template AsType<bhalf_t>()[Number<1>{}]);
static_for<0, 4, 1>{}([&](auto i) {
auto v = type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[i]);
vector_type<bhalf_t, 2> v_v;
v_v.template AsType<bhalf2_t>()(Number<0>{}) = v;
const float v0 = type_convert<float>(v_v.template AsType<bhalf_t>()[Number<0>{}]);
const float v1 = type_convert<float>(v_v.template AsType<bhalf_t>()[Number<1>{}]);
vector_type<bhalf_t, 2> r;
r.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(v0 * s0);
r.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(v1 * s1);
dst.template AsType<bhalf2_t>()(i) = r.template AsType<bhalf2_t>()[Number<0>{}];
});
y = dst.template AsType<bhalf8_t>()[Number<0>{}];
#endif
}
constexpr const static bool is_pack8_invocable = true;
};
@@ -378,6 +496,53 @@ struct DequantPack8WithZp
#endif
}
// AIESW-32176: bf16 overload — mirrors the fp16 path but uses
// i4_to_bhalf4_zp_scale (fp32 intermediates; no native bhalf2 multiply on
// RDNA3).
__host__ __device__ constexpr void operator()(ck::bhalf8_t& y,
const ck::pk_i4x4_t& x,
const ck::bhalf2_t& s,
const ck::bhalf2_t& scaled_zp) const
{
#if CK_USE_PK4_LAYOUT_SHUFFLE
vector_type<bhalf_t, 8> result;
result.template AsType<bhalf4_t>()(Number<0>{}) =
i4_to_bhalf4_zp_scale(bit_cast<int>(x), s, scaled_zp);
result.template AsType<bhalf4_t>()(Number<1>{}) =
i4_to_bhalf4_zp_scale(bit_cast<int>(x) >> 8, s, scaled_zp);
y = result.template AsType<bhalf8_t>()[Number<0>{}];
#else
// Reference (slow) path: unscaled bf16 dequant + per-element
// (* scale) - scaled_zp via fp32.
vector_type<bhalf_t, 8> dst;
vector_type<pk_i4_t, 4> src{x};
vector_type<bhalf_t, 2> s_v;
s_v.template AsType<bhalf2_t>()(Number<0>{}) = s;
vector_type<bhalf_t, 2> z_v;
z_v.template AsType<bhalf2_t>()(Number<0>{}) = scaled_zp;
const float s0 = type_convert<float>(s_v.template AsType<bhalf_t>()[Number<0>{}]);
const float s1 = type_convert<float>(s_v.template AsType<bhalf_t>()[Number<1>{}]);
const float z0 = type_convert<float>(z_v.template AsType<bhalf_t>()[Number<0>{}]);
const float z1 = type_convert<float>(z_v.template AsType<bhalf_t>()[Number<1>{}]);
static_for<0, 4, 1>{}([&](auto i) {
auto v = type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[i]);
vector_type<bhalf_t, 2> v_v;
v_v.template AsType<bhalf2_t>()(Number<0>{}) = v;
const float v0 = type_convert<float>(v_v.template AsType<bhalf_t>()[Number<0>{}]);
const float v1 = type_convert<float>(v_v.template AsType<bhalf_t>()[Number<1>{}]);
vector_type<bhalf_t, 2> r;
r.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(v0 * s0 - z0);
r.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(v1 * s1 - z1);
dst.template AsType<bhalf2_t>()(i) = r.template AsType<bhalf2_t>()[Number<0>{}];
});
y = dst.template AsType<bhalf8_t>()[Number<0>{}];
#endif
}
constexpr const static bool is_pack8_invocable = true;
};