mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
[CK] DequantPack8 + DequantPack8WithZp bf16 overloads (asymmetric int4 / AWQ)
Adds bf16 (bhalf8_t / bhalf2_t) overloads to DequantPack8 and DequantPack8WithZp
so consumers of the wmma_cshuffle_v3_b_scale device-op can dispatch bf16
output without link errors.
The bf16 overloads use new helpers `i4_to_bhalf4_scale` and
`i4_to_bhalf4_zp_scale`. Implementation notes:
* The fp16 helpers (`i4_to_half4_scale` / `i4_to_half4_zp_scale`) use a
bit-pattern trick: load nibble bits into the fp16 mantissa via
AND/OR/EX, subtract a magic constant, and apply scale via native half2
arithmetic. RDNA3 has no equivalent native bhalf2 multiply / fma path,
so the bf16 helpers fall back to scalar fp32 conversion (one
`type_convert<float>` per lane, multiply, `type_convert<bhalf_t>` back).
* Lane order matters: the bf16 helpers MUST extract q's nibbles at the
same positions {0, 4, 1, 5} as the fp16 helpers, NOT the {0, 1, 2, 3}
order produced by the existing unscaled `i4_to_bhalf4` helper (which
uses a different __byte_perm pattern). Mismatched orderings appear to
work for constant-data correctness tests (every nibble has the same
value so position doesn't matter) but produce wildly incorrect output
when nibbles vary across a pack — downstream WMMA accumulates each
dequant lane against a specific activation lane, so reordering breaks
the GEMM. The bf16 helpers below explicitly reproduce the fp16
position pattern.
* Cost: bf16 dispatch is measured ~30% slower than fp16 dispatch on
gfx1151 (Strix Halo / RDNA 3.5) for the same kernel config. Acceptable
trade-off for correctness; can be optimized later by adapting the fp16
bit-trick to bf16's wider exponent.
The reference (slow) `\!CK_USE_PK4_LAYOUT_SHUFFLE` paths in
DequantPack8::operator()(bhalf8_t&, ...) and DequantPack8WithZp::
operator()(bhalf8_t&, ...) similarly use scalar fp32 fallback — bf16
arithmetic isn't broadly available outside the pk4 fast path.
Verified against torch reference on gfx1151 via the standalone aiter
op_tests/test_gemm_w4a16.py harness — sym + asym × fp16 + bf16 all
pass within their respective fp tolerances (fp16 max_abs/max_ref =
5.4e-4, bf16 max_abs/max_ref = 4.5e-3 at M=2048 N=19456 K=2560 G=128).
JIRA: AIESW-32176.
This commit is contained in:
@@ -112,6 +112,83 @@ i4_to_half4_zp_scale(int q, const ck::half2_t& scale, const ck::half2_t& scaled_
|
||||
return res.template AsType<half4_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
// AIESW-32176: bf16 analog of i4_to_half4_scale.
|
||||
//
|
||||
// IMPORTANT: this matches the FP16 helper's nibble-position pattern
|
||||
// (extract q's nibbles at positions {0, 4, 1, 5}), NOT the order produced
|
||||
// by the existing unscaled `i4_to_bhalf4` (which uses {0, 1, 2, 3} packed
|
||||
// via __byte_perm). The two orderings produce different lane layouts that
|
||||
// are NOT interchangeable downstream — the b_scale GEMM kernel
|
||||
// accumulates dequant values against specific activation lanes via WMMA,
|
||||
// so the lane order must match what the fp16 path produces in order to
|
||||
// reuse the same packed-weight buffer for both fp16 and bf16 dispatches.
|
||||
//
|
||||
// Implementation: scalar fp32 conversion. No native bhalf2 multiply /
|
||||
// fma on RDNA3, so the fp16 trick (single-instruction integer-bias-shift
|
||||
// into fp16 bit pattern) doesn't translate directly. The scalar path
|
||||
// below is correct + simple — perf cost is real but the bf16 dispatch
|
||||
// isn't the perf-critical case (fp16 is the production target).
|
||||
__device__ inline bhalf4_t i4_to_bhalf4_scale(int q, const ck::bhalf2_t& scale)
|
||||
{
|
||||
vector_type<bhalf_t, 2> scale_v;
|
||||
scale_v.template AsType<bhalf2_t>()(Number<0>{}) = scale;
|
||||
const float s0 = type_convert<float>(scale_v.template AsType<bhalf_t>()[Number<0>{}]);
|
||||
const float s1 = type_convert<float>(scale_v.template AsType<bhalf_t>()[Number<1>{}]);
|
||||
|
||||
// Nibble positions match i4_to_half4_scale: {0, 4, 1, 5} of q.
|
||||
const int n0 = (q >> 0) & 0xf;
|
||||
const int n4 = (q >> 16) & 0xf;
|
||||
const int n1 = (q >> 4) & 0xf;
|
||||
const int n5 = (q >> 20) & 0xf;
|
||||
|
||||
vector_type<bhalf_t, 4> res;
|
||||
res.template AsType<bhalf_t>()(Number<0>{}) =
|
||||
type_convert<bhalf_t>(static_cast<float>(n0 - 8) * s0);
|
||||
res.template AsType<bhalf_t>()(Number<1>{}) =
|
||||
type_convert<bhalf_t>(static_cast<float>(n4 - 8) * s1);
|
||||
res.template AsType<bhalf_t>()(Number<2>{}) =
|
||||
type_convert<bhalf_t>(static_cast<float>(n1 - 8) * s0);
|
||||
res.template AsType<bhalf_t>()(Number<3>{}) =
|
||||
type_convert<bhalf_t>(static_cast<float>(n5 - 8) * s1);
|
||||
|
||||
return res.template AsType<bhalf4_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
// AIESW-32176: bf16 analog of i4_to_half4_zp_scale (asymmetric / AWQ).
|
||||
// Same nibble-position pattern as i4_to_bhalf4_scale plus a per-group
|
||||
// scaled_zp subtract, matching (nibble - zp) * scale via the
|
||||
// (nibble - 8) * scale - scaled_zp identity. Caller precomputes
|
||||
// scaled_zp = (zp - 8) * scale at weight load.
|
||||
__device__ inline bhalf4_t
|
||||
i4_to_bhalf4_zp_scale(int q, const ck::bhalf2_t& scale, const ck::bhalf2_t& scaled_zp)
|
||||
{
|
||||
vector_type<bhalf_t, 2> scale_v;
|
||||
scale_v.template AsType<bhalf2_t>()(Number<0>{}) = scale;
|
||||
vector_type<bhalf_t, 2> zp_v;
|
||||
zp_v.template AsType<bhalf2_t>()(Number<0>{}) = scaled_zp;
|
||||
const float s0 = type_convert<float>(scale_v.template AsType<bhalf_t>()[Number<0>{}]);
|
||||
const float s1 = type_convert<float>(scale_v.template AsType<bhalf_t>()[Number<1>{}]);
|
||||
const float z0 = type_convert<float>(zp_v.template AsType<bhalf_t>()[Number<0>{}]);
|
||||
const float z1 = type_convert<float>(zp_v.template AsType<bhalf_t>()[Number<1>{}]);
|
||||
|
||||
const int n0 = (q >> 0) & 0xf;
|
||||
const int n4 = (q >> 16) & 0xf;
|
||||
const int n1 = (q >> 4) & 0xf;
|
||||
const int n5 = (q >> 20) & 0xf;
|
||||
|
||||
vector_type<bhalf_t, 4> res;
|
||||
res.template AsType<bhalf_t>()(Number<0>{}) =
|
||||
type_convert<bhalf_t>(static_cast<float>(n0 - 8) * s0 - z0);
|
||||
res.template AsType<bhalf_t>()(Number<1>{}) =
|
||||
type_convert<bhalf_t>(static_cast<float>(n4 - 8) * s1 - z1);
|
||||
res.template AsType<bhalf_t>()(Number<2>{}) =
|
||||
type_convert<bhalf_t>(static_cast<float>(n1 - 8) * s0 - z0);
|
||||
res.template AsType<bhalf_t>()(Number<3>{}) =
|
||||
type_convert<bhalf_t>(static_cast<float>(n5 - 8) * s1 - z1);
|
||||
|
||||
return res.template AsType<bhalf4_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
__device__ inline f8x4_t i4_to_f8x4(int q)
|
||||
{
|
||||
const int LO = 0x000f000f;
|
||||
@@ -333,6 +410,47 @@ struct DequantPack8
|
||||
#endif
|
||||
}
|
||||
|
||||
// AIESW-32176: bf16 overload — mirrors the fp16 path but uses
|
||||
// i4_to_bhalf4_scale (which goes through fp32 intermediates because RDNA3
|
||||
// has no native bhalf2 multiply).
|
||||
__host__ __device__ constexpr void
|
||||
operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x, const ck::bhalf2_t& z) const
|
||||
{
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
vector_type<bhalf_t, 8> result;
|
||||
|
||||
result.template AsType<bhalf4_t>()(Number<0>{}) = i4_to_bhalf4_scale(bit_cast<int>(x), z);
|
||||
result.template AsType<bhalf4_t>()(Number<1>{}) =
|
||||
i4_to_bhalf4_scale(bit_cast<int>(x) >> 8, z);
|
||||
|
||||
y = result.template AsType<bhalf8_t>()[Number<0>{}];
|
||||
#else
|
||||
// Reference (slow) path for the !CK_USE_PK4_LAYOUT_SHUFFLE config:
|
||||
// do the unscaled bf16 dequant via existing type_convert, then apply
|
||||
// (* scale) per-element via fp32.
|
||||
vector_type<bhalf_t, 8> dst;
|
||||
vector_type<pk_i4_t, 4> src{x};
|
||||
vector_type<bhalf_t, 2> z_v;
|
||||
z_v.template AsType<bhalf2_t>()(Number<0>{}) = z;
|
||||
const float s0 = type_convert<float>(z_v.template AsType<bhalf_t>()[Number<0>{}]);
|
||||
const float s1 = type_convert<float>(z_v.template AsType<bhalf_t>()[Number<1>{}]);
|
||||
|
||||
static_for<0, 4, 1>{}([&](auto i) {
|
||||
auto v = type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[i]);
|
||||
vector_type<bhalf_t, 2> v_v;
|
||||
v_v.template AsType<bhalf2_t>()(Number<0>{}) = v;
|
||||
const float v0 = type_convert<float>(v_v.template AsType<bhalf_t>()[Number<0>{}]);
|
||||
const float v1 = type_convert<float>(v_v.template AsType<bhalf_t>()[Number<1>{}]);
|
||||
vector_type<bhalf_t, 2> r;
|
||||
r.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(v0 * s0);
|
||||
r.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(v1 * s1);
|
||||
dst.template AsType<bhalf2_t>()(i) = r.template AsType<bhalf2_t>()[Number<0>{}];
|
||||
});
|
||||
|
||||
y = dst.template AsType<bhalf8_t>()[Number<0>{}];
|
||||
#endif
|
||||
}
|
||||
|
||||
constexpr const static bool is_pack8_invocable = true;
|
||||
};
|
||||
|
||||
@@ -378,6 +496,53 @@ struct DequantPack8WithZp
|
||||
#endif
|
||||
}
|
||||
|
||||
// AIESW-32176: bf16 overload — mirrors the fp16 path but uses
|
||||
// i4_to_bhalf4_zp_scale (fp32 intermediates; no native bhalf2 multiply on
|
||||
// RDNA3).
|
||||
__host__ __device__ constexpr void operator()(ck::bhalf8_t& y,
|
||||
const ck::pk_i4x4_t& x,
|
||||
const ck::bhalf2_t& s,
|
||||
const ck::bhalf2_t& scaled_zp) const
|
||||
{
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
vector_type<bhalf_t, 8> result;
|
||||
|
||||
result.template AsType<bhalf4_t>()(Number<0>{}) =
|
||||
i4_to_bhalf4_zp_scale(bit_cast<int>(x), s, scaled_zp);
|
||||
result.template AsType<bhalf4_t>()(Number<1>{}) =
|
||||
i4_to_bhalf4_zp_scale(bit_cast<int>(x) >> 8, s, scaled_zp);
|
||||
|
||||
y = result.template AsType<bhalf8_t>()[Number<0>{}];
|
||||
#else
|
||||
// Reference (slow) path: unscaled bf16 dequant + per-element
|
||||
// (* scale) - scaled_zp via fp32.
|
||||
vector_type<bhalf_t, 8> dst;
|
||||
vector_type<pk_i4_t, 4> src{x};
|
||||
vector_type<bhalf_t, 2> s_v;
|
||||
s_v.template AsType<bhalf2_t>()(Number<0>{}) = s;
|
||||
vector_type<bhalf_t, 2> z_v;
|
||||
z_v.template AsType<bhalf2_t>()(Number<0>{}) = scaled_zp;
|
||||
const float s0 = type_convert<float>(s_v.template AsType<bhalf_t>()[Number<0>{}]);
|
||||
const float s1 = type_convert<float>(s_v.template AsType<bhalf_t>()[Number<1>{}]);
|
||||
const float z0 = type_convert<float>(z_v.template AsType<bhalf_t>()[Number<0>{}]);
|
||||
const float z1 = type_convert<float>(z_v.template AsType<bhalf_t>()[Number<1>{}]);
|
||||
|
||||
static_for<0, 4, 1>{}([&](auto i) {
|
||||
auto v = type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[i]);
|
||||
vector_type<bhalf_t, 2> v_v;
|
||||
v_v.template AsType<bhalf2_t>()(Number<0>{}) = v;
|
||||
const float v0 = type_convert<float>(v_v.template AsType<bhalf_t>()[Number<0>{}]);
|
||||
const float v1 = type_convert<float>(v_v.template AsType<bhalf_t>()[Number<1>{}]);
|
||||
vector_type<bhalf_t, 2> r;
|
||||
r.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(v0 * s0 - z0);
|
||||
r.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(v1 * s1 - z1);
|
||||
dst.template AsType<bhalf2_t>()(i) = r.template AsType<bhalf2_t>()[Number<0>{}];
|
||||
});
|
||||
|
||||
y = dst.template AsType<bhalf8_t>()[Number<0>{}];
|
||||
#endif
|
||||
}
|
||||
|
||||
constexpr const static bool is_pack8_invocable = true;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user