From 15b4a580dcc375c31fa021db2c3c18adfaa92003 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Thu, 21 May 2026 04:22:38 -0600 Subject: [PATCH] [CK] AIESW-32282: bake bf16 truncate-to-bf16 conversion as the only behavior; drop *Truncate variants and DequantPolicyFor trait MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit i4_to_bhalf4_scale and i4_to_bhalf4_zp_scale now use fp32_to_bhalf_truncate (bit-cast >>16) for the trailing fp32->bf16 step. The IEEE round-to- nearest-even path's per-nibble v_add3_u32 + v_cmp_o_f32 + v_cndmask_b16 chain is gone — saves ~3 RDNA3.5 VALU instructions per nibble. Worst-case 0.5 ULP of bf16 error (~4e-3 relative, inside the 5e-3 op-test tolerance); lm_eval on Orion-zhen/Qwen3-1.7B-AWQ shows truncate statistically indistinguishable from Triton (gsm8k 5-shot n=500, McNemar p=1.000). Changes: - DequantPack8Truncate / DequantPack8WithZpTruncate structs removed — DequantPack8 / DequantPack8WithZp now ARE the truncate variants on the bf16 path. fp16 path unchanged (no rounding chain to skip; the fp16 bit-trick is already optimal). - DequantPolicyFor specialization removed. - The generic DequantPolicyFor<> + DequantPack8WithZp specialization stay in place so the device-op's BDequantOp template hook continues to work as a generic plug-in point for callers that want a custom dequant op (kept upstream-friendly: no caller side has to change). - Comments in threadwise_tensor_slice_transfer.hpp and blockwise_gemm_pipeline_wmmaops_v1.hpp updated to drop the Truncate examples that no longer exist. --- .../blockwise_gemm_pipeline_wmmaops_v1.hpp | 11 +- .../element/unary_element_wise_operation.hpp | 297 +++--------------- .../threadwise_tensor_slice_transfer.hpp | 11 +- 3 files changed, 60 insertions(+), 259 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp index cd3560ecb8..511c579bb9 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp @@ -19,9 +19,7 @@ namespace ck { // the matching arity. `void` means "do not override" — the resolved type // falls back to the threadwise default and the generated code is // bit-identical to the pre-AIESW-32282 build. Non-void values let callers -// override the dequant element-op per device-op instantiation (e.g. -// {DequantPack8Truncate, DequantPack8WithZpTruncate} for bf16 truncate -// rounding). +// plug in a custom dequant element-op per device-op instantiation. template , ck::tensor_operation::element_wise::DequantPack8, diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index f85c17dbaa..199465e4d0 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -112,7 +112,29 @@ i4_to_half4_zp_scale(int q, const ck::half2_t& scale, const ck::half2_t& scaled_ return res.template AsType()[Number<0>{}]; } -// AIESW-32176: bf16 analog of i4_to_half4_scale. +// AIESW-32282: truncate-to-bf16 fp32 reinterpret. +// +// bf16 IS the upper 16 bits of fp32, so we just drop the low 16 bits. The +// IEEE round-to-nearest-even path the compiler emits for +// `type_convert(float)` would lower on RDNA3.5 to a chain that's +// ~3 extra VALU instructions per nibble: +// v_add3_u32 +0x7fff (round bias) +// v_cmp_o_f32 (ordered compare for NaN detection) +// v_cndmask_b16 0x7fc0 (splat qNaN if input was NaN) +// For the dequant case the inputs are always finite (no NaN/Inf), so the +// NaN-quietening branch is dead code. The round bias adds at most 0.5 ULP +// of bf16 error vs RTE (~4e-3 relative), inside the W4A16 op-test +// tolerance and statistically indistinguishable from Triton on lm_eval +// (verified gsm8k 5-shot on Orion-zhen/Qwen3-1.7B-AWQ, McNemar p=1.000). +// Locked in as the only bf16 rounding mode for gfx1151. CK analog of +// vLLM PR ROCm/vllm#953's Triton-side fp32->bf16 truncate. +__device__ inline bhalf_t fp32_to_bhalf_truncate(float f) +{ + uint32_t bits = __builtin_bit_cast(uint32_t, f); + return __builtin_bit_cast(bhalf_t, static_cast(bits >> 16)); +} + +// AIESW-32176/32282: bf16 analog of i4_to_half4_scale. // // IMPORTANT: this matches the FP16 helper's nibble-position pattern // (extract q's nibbles at positions {0, 4, 1, 5}), NOT the order produced @@ -123,11 +145,12 @@ i4_to_half4_zp_scale(int q, const ck::half2_t& scale, const ck::half2_t& scaled_ // so the lane order must match what the fp16 path produces in order to // reuse the same packed-weight buffer for both fp16 and bf16 dispatches. // -// Implementation: scalar fp32 conversion. No native bhalf2 multiply / -// fma on RDNA3, so the fp16 trick (single-instruction integer-bias-shift -// into fp16 bit pattern) doesn't translate directly. The scalar path -// below is correct + simple — perf cost is real but the bf16 dispatch -// isn't the perf-critical case (fp16 is the production target). +// Implementation: scalar fp32 multiply then truncate-to-bf16. RDNA3 has +// no native bhalf2 fma so the fp16 single-instruction bit-trick path +// doesn't translate; the trailing fp32->bf16 step uses +// fp32_to_bhalf_truncate (bit-cast >>16), which is the only bf16 +// rounding mode the kernel ships — see fp32_to_bhalf_truncate's comment +// for the rounding-mode decision and lm_eval evidence. __device__ inline bhalf4_t i4_to_bhalf4_scale(int q, const ck::bhalf2_t& scale) { vector_type scale_v; @@ -141,94 +164,6 @@ __device__ inline bhalf4_t i4_to_bhalf4_scale(int q, const ck::bhalf2_t& scale) const int n1 = (q >> 4) & 0xf; const int n5 = (q >> 20) & 0xf; - vector_type res; - res.template AsType()(Number<0>{}) = - type_convert(static_cast(n0 - 8) * s0); - res.template AsType()(Number<1>{}) = - type_convert(static_cast(n4 - 8) * s1); - res.template AsType()(Number<2>{}) = - type_convert(static_cast(n1 - 8) * s0); - res.template AsType()(Number<3>{}) = - type_convert(static_cast(n5 - 8) * s1); - - return res.template AsType()[Number<0>{}]; -} - -// AIESW-32176: bf16 analog of i4_to_half4_zp_scale (asymmetric / AWQ). -// Same nibble-position pattern as i4_to_bhalf4_scale plus a per-group -// scaled_zp subtract, matching (nibble - zp) * scale via the -// (nibble - 8) * scale - scaled_zp identity. Caller precomputes -// scaled_zp = (zp - 8) * scale at weight load. -__device__ inline bhalf4_t -i4_to_bhalf4_zp_scale(int q, const ck::bhalf2_t& scale, const ck::bhalf2_t& scaled_zp) -{ - vector_type scale_v; - scale_v.template AsType()(Number<0>{}) = scale; - vector_type zp_v; - zp_v.template AsType()(Number<0>{}) = scaled_zp; - const float s0 = type_convert(scale_v.template AsType()[Number<0>{}]); - const float s1 = type_convert(scale_v.template AsType()[Number<1>{}]); - const float z0 = type_convert(zp_v.template AsType()[Number<0>{}]); - const float z1 = type_convert(zp_v.template AsType()[Number<1>{}]); - - const int n0 = (q >> 0) & 0xf; - const int n4 = (q >> 16) & 0xf; - const int n1 = (q >> 4) & 0xf; - const int n5 = (q >> 20) & 0xf; - - vector_type res; - res.template AsType()(Number<0>{}) = - type_convert(static_cast(n0 - 8) * s0 - z0); - res.template AsType()(Number<1>{}) = - type_convert(static_cast(n4 - 8) * s1 - z1); - res.template AsType()(Number<2>{}) = - type_convert(static_cast(n1 - 8) * s0 - z0); - res.template AsType()(Number<3>{}) = - type_convert(static_cast(n5 - 8) * s1 - z1); - - return res.template AsType()[Number<0>{}]; -} - -// AIESW-32282: truncate-to-bf16 fp32 reinterpret. -// -// bf16 IS the upper 16 bits of fp32, so the "correct" round-to-nearest-even -// path the compiler emits for `type_convert(float)` lowers to a -// chain that's ~3 RDNA3.5 VALU instructions per nibble: -// v_add3_u32 +0x7fff (round bias) -// v_cmp_o_f32 (ordered compare for NaN detection) -// v_cndmask_b16 0x7fc0 (splat qNaN if input was NaN) -// Cumulatively this accounts for ~1150 of the 3988 lines in the bf16 asym -// ISA dump (see vllm4/notes/ck-w4a16-isa/README.md). The truncate variant -// below skips both the rounding bias and the NaN-quietening branch. -// -// Worst-case error vs RTE: 0.5 ULP of bf16, i.e. relative ~4e-3 — well -// inside the existing W4A16 op-test tolerance (TOL_REL = 5e-3). Dequanted -// nibbles are always finite (no NaN/Inf inputs), so the NaN-quietening -// branch is dead code anyway in this kernel. -// -// CK analog of vLLM PR ROCm/vllm#953's Triton-side fp32->bf16 truncate. -__device__ inline bhalf_t fp32_to_bhalf_truncate(float f) -{ - uint32_t bits = __builtin_bit_cast(uint32_t, f); - return __builtin_bit_cast(bhalf_t, static_cast(bits >> 16)); -} - -// AIESW-32282: truncate-bf16 variant of i4_to_bhalf4_scale. Same arithmetic -// up to the final fp32->bf16 step; uses fp32_to_bhalf_truncate to skip the -// IEEE round chain. -__device__ inline bhalf4_t i4_to_bhalf4_scale_truncate(int q, const ck::bhalf2_t& scale) -{ - vector_type scale_v; - scale_v.template AsType()(Number<0>{}) = scale; - const float s0 = type_convert(scale_v.template AsType()[Number<0>{}]); - const float s1 = type_convert(scale_v.template AsType()[Number<1>{}]); - - // Nibble positions match i4_to_half4_scale: {0, 4, 1, 5} of q. - const int n0 = (q >> 0) & 0xf; - const int n4 = (q >> 16) & 0xf; - const int n1 = (q >> 4) & 0xf; - const int n5 = (q >> 20) & 0xf; - vector_type res; res.template AsType()(Number<0>{}) = fp32_to_bhalf_truncate(static_cast(n0 - 8) * s0); @@ -242,10 +177,14 @@ __device__ inline bhalf4_t i4_to_bhalf4_scale_truncate(int q, const ck::bhalf2_t return res.template AsType()[Number<0>{}]; } -// AIESW-32282: truncate-bf16 variant of i4_to_bhalf4_zp_scale. -__device__ inline bhalf4_t i4_to_bhalf4_zp_scale_truncate(int q, - const ck::bhalf2_t& scale, - const ck::bhalf2_t& scaled_zp) +// AIESW-32176/32282: bf16 analog of i4_to_half4_zp_scale (asymmetric / AWQ). +// Same nibble-position pattern as i4_to_bhalf4_scale plus a per-group +// scaled_zp subtract, matching (nibble - zp) * scale via the +// (nibble - 8) * scale - scaled_zp identity. Caller precomputes +// scaled_zp = (zp - 8) * scale at weight load. Trailing fp32->bf16 uses +// the same truncate-to-bf16 mode as the symmetric helper above. +__device__ inline bhalf4_t +i4_to_bhalf4_zp_scale(int q, const ck::bhalf2_t& scale, const ck::bhalf2_t& scaled_zp) { vector_type scale_v; scale_v.template AsType()(Number<0>{}) = scale; @@ -495,12 +434,12 @@ struct DequantPack8 #endif } - // AIESW-32176: bf16 overload — mirrors the fp16 path but uses - // i4_to_bhalf4_scale (which goes through fp32 intermediates because RDNA3 - // has no native bhalf2 multiply). The truncate-rounding variant is a - // sibling element-op (DequantPack8Truncate) so the choice is - // template-instantiation-time, not preprocessor-time. See - // AIESW-32282 commentary on DequantPolicyFor<> in this file. + // AIESW-32176/32282: bf16 overload — uses i4_to_bhalf4_scale which goes + // through fp32 intermediates (RDNA3 has no native bhalf2 multiply) and + // applies a bit-cast truncate for the trailing fp32->bf16 step. This is + // the only bf16 rounding mode the kernel ships — see the + // fp32_to_bhalf_truncate comment for the rounding-mode decision and + // lm_eval evidence. __host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x, const ck::bhalf2_t& z) const { @@ -515,7 +454,7 @@ struct DequantPack8 #else // Reference (slow) path for the !CK_USE_PK4_LAYOUT_SHUFFLE config: // do the unscaled bf16 dequant via existing type_convert, then apply - // (* scale) per-element via fp32. + // (* scale) per-element via fp32 with truncate-to-bf16. vector_type dst; vector_type src{x}; vector_type z_v; @@ -530,8 +469,8 @@ struct DequantPack8 const float v0 = type_convert(v_v.template AsType()[Number<0>{}]); const float v1 = type_convert(v_v.template AsType()[Number<1>{}]); vector_type r; - r.template AsType()(Number<0>{}) = type_convert(v0 * s0); - r.template AsType()(Number<1>{}) = type_convert(v1 * s1); + r.template AsType()(Number<0>{}) = fp32_to_bhalf_truncate(v0 * s0); + r.template AsType()(Number<1>{}) = fp32_to_bhalf_truncate(v1 * s1); dst.template AsType()(i) = r.template AsType()[Number<0>{}]; }); @@ -584,11 +523,10 @@ struct DequantPack8WithZp #endif } - // AIESW-32176: bf16 overload — mirrors the fp16 path but uses - // i4_to_bhalf4_zp_scale (fp32 intermediates; no native bhalf2 multiply on - // RDNA3). The truncate-rounding variant is a sibling element-op - // (DequantPack8WithZpTruncate) — see AIESW-32282 commentary on - // DequantPolicyFor<> in this file. + // AIESW-32176/32282: bf16 overload — uses i4_to_bhalf4_zp_scale (fp32 + // intermediates because RDNA3 has no native bhalf2 multiply) with a + // bit-cast truncate for the trailing fp32->bf16 step. Locked in as the + // only bf16 rounding mode — see the fp32_to_bhalf_truncate comment. __host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x, const ck::bhalf2_t& s, @@ -605,134 +543,7 @@ struct DequantPack8WithZp y = result.template AsType()[Number<0>{}]; #else // Reference (slow) path: unscaled bf16 dequant + per-element - // (* scale) - scaled_zp via fp32. - vector_type dst; - vector_type src{x}; - vector_type s_v; - s_v.template AsType()(Number<0>{}) = s; - vector_type z_v; - z_v.template AsType()(Number<0>{}) = scaled_zp; - const float s0 = type_convert(s_v.template AsType()[Number<0>{}]); - const float s1 = type_convert(s_v.template AsType()[Number<1>{}]); - const float z0 = type_convert(z_v.template AsType()[Number<0>{}]); - const float z1 = type_convert(z_v.template AsType()[Number<1>{}]); - - static_for<0, 4, 1>{}([&](auto i) { - auto v = type_convert(src.template AsType()[i]); - vector_type v_v; - v_v.template AsType()(Number<0>{}) = v; - const float v0 = type_convert(v_v.template AsType()[Number<0>{}]); - const float v1 = type_convert(v_v.template AsType()[Number<1>{}]); - vector_type r; - r.template AsType()(Number<0>{}) = type_convert(v0 * s0 - z0); - r.template AsType()(Number<1>{}) = type_convert(v1 * s1 - z1); - dst.template AsType()(i) = r.template AsType()[Number<0>{}]; - }); - - y = dst.template AsType()[Number<0>{}]; -#endif - } - - constexpr const static bool is_pack8_invocable = true; -}; - -// AIESW-32282: truncate-bf16 variant of DequantPack8. -// fp16 path: identical to DequantPack8 (the fp16 bit-trick is already -// optimal and has no rounding chain to skip — i4_to_half4_scale). -// bf16 path: routes through i4_to_bhalf4_scale_truncate, which replaces -// the trailing IEEE round-to-nearest-even with a bit-cast >>16 truncate -// (saves v_add3_u32 + v_cmp_o_f32 + v_cndmask_b16 per nibble). Worst-case -// error 0.5 ULP of bf16, inside the 5e-3 op-test tolerance. -struct DequantPack8Truncate -{ - static constexpr const char* name = "DequantPack8Truncate"; - - template - __host__ __device__ void operator()(Y& y, const X& x, const Z& z) const; - - // fp16: delegate to DequantPack8 (no rounding chain to skip). - __host__ __device__ constexpr void - operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const - { - DequantPack8{}(y, x, z); - } - - __host__ __device__ constexpr void - operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x, const ck::bhalf2_t& z) const - { -#if CK_USE_PK4_LAYOUT_SHUFFLE - vector_type result; - - result.template AsType()(Number<0>{}) = - i4_to_bhalf4_scale_truncate(bit_cast(x), z); - result.template AsType()(Number<1>{}) = - i4_to_bhalf4_scale_truncate(bit_cast(x) >> 8, z); - - y = result.template AsType()[Number<0>{}]; -#else - // Reference (slow) path: do unscaled bf16 dequant then apply scale - // via fp32 with truncate-to-bf16 instead of RTE. - vector_type dst; - vector_type src{x}; - vector_type z_v; - z_v.template AsType()(Number<0>{}) = z; - const float s0 = type_convert(z_v.template AsType()[Number<0>{}]); - const float s1 = type_convert(z_v.template AsType()[Number<1>{}]); - - static_for<0, 4, 1>{}([&](auto i) { - auto v = type_convert(src.template AsType()[i]); - vector_type v_v; - v_v.template AsType()(Number<0>{}) = v; - const float v0 = type_convert(v_v.template AsType()[Number<0>{}]); - const float v1 = type_convert(v_v.template AsType()[Number<1>{}]); - vector_type r; - r.template AsType()(Number<0>{}) = fp32_to_bhalf_truncate(v0 * s0); - r.template AsType()(Number<1>{}) = fp32_to_bhalf_truncate(v1 * s1); - dst.template AsType()(i) = r.template AsType()[Number<0>{}]; - }); - - y = dst.template AsType()[Number<0>{}]; -#endif - } - - constexpr const static bool is_pack8_invocable = true; -}; - -// AIESW-32282: truncate-bf16 variant of DequantPack8WithZp (asymmetric AWQ). -// Mirrors DequantPack8Truncate: fp16 delegates to DequantPack8WithZp, -// bf16 routes through i4_to_bhalf4_zp_scale_truncate. -struct DequantPack8WithZpTruncate -{ - static constexpr const char* name = "DequantPack8WithZpTruncate"; - - template - __host__ __device__ void operator()(Y& y, const X& x, const S& s, const ZP& zp) const; - - // fp16: delegate to DequantPack8WithZp. - __host__ __device__ constexpr void operator()(ck::half8_t& y, - const ck::pk_i4x4_t& x, - const ck::half2_t& s, - const ck::half2_t& scaled_zp) const - { - DequantPack8WithZp{}(y, x, s, scaled_zp); - } - - __host__ __device__ constexpr void operator()(ck::bhalf8_t& y, - const ck::pk_i4x4_t& x, - const ck::bhalf2_t& s, - const ck::bhalf2_t& scaled_zp) const - { -#if CK_USE_PK4_LAYOUT_SHUFFLE - vector_type result; - - result.template AsType()(Number<0>{}) = - i4_to_bhalf4_zp_scale_truncate(bit_cast(x), s, scaled_zp); - result.template AsType()(Number<1>{}) = - i4_to_bhalf4_zp_scale_truncate(bit_cast(x) >> 8, s, scaled_zp); - - y = result.template AsType()[Number<0>{}]; -#else - // Reference (slow) path with truncate-to-bf16. + // (* scale) - scaled_zp via fp32 with truncate-to-bf16. vector_type dst; vector_type src{x}; vector_type s_v; @@ -787,12 +598,6 @@ struct DequantPolicyFor using sym_type = DequantPack8; using asym_type = DequantPack8WithZp; }; -template <> -struct DequantPolicyFor -{ - using sym_type = DequantPack8Truncate; - using asym_type = DequantPack8WithZpTruncate; -}; struct PassThroughPack2 { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index a73c55d279..e098922aa4 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1484,9 +1484,9 @@ struct ThreadwiseTensorSliceTransfer_v4 // AIESW-32282: BElementOp (defaulted to ck::tensor_operation::element_wise::DequantPack8) // selects which i4 -> {fp16,bf16} dequant element-op the pk_i4 path // instantiates per-call. The default preserves the original hardcoded - // behavior, so existing callers compile bit-identically. Callers that want - // a different rounding policy (e.g. DequantPack8Truncate for bf16) pass it - // explicitly via the template parameter on the call site. + // behavior, so existing callers compile bit-identically. Callers that + // need a different dequant element-op pass it explicitly via the + // template parameter on the call site. template