diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 70d2efe6e6..4f8c835086 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -112,6 +112,83 @@ i4_to_half4_zp_scale(int q, const ck::half2_t& scale, const ck::half2_t& scaled_ return res.template AsType()[Number<0>{}]; } +// AIESW-32176: bf16 analog of i4_to_half4_scale. +// +// IMPORTANT: this matches the FP16 helper's nibble-position pattern +// (extract q's nibbles at positions {0, 4, 1, 5}), NOT the order produced +// by the existing unscaled `i4_to_bhalf4` (which uses {0, 1, 2, 3} packed +// via __byte_perm). The two orderings produce different lane layouts that +// are NOT interchangeable downstream — the b_scale GEMM kernel +// accumulates dequant values against specific activation lanes via WMMA, +// so the lane order must match what the fp16 path produces in order to +// reuse the same packed-weight buffer for both fp16 and bf16 dispatches. +// +// Implementation: scalar fp32 conversion. No native bhalf2 multiply / +// fma on RDNA3, so the fp16 trick (single-instruction integer-bias-shift +// into fp16 bit pattern) doesn't translate directly. The scalar path +// below is correct + simple — perf cost is real but the bf16 dispatch +// isn't the perf-critical case (fp16 is the production target). +__device__ inline bhalf4_t i4_to_bhalf4_scale(int q, const ck::bhalf2_t& scale) +{ + vector_type scale_v; + scale_v.template AsType()(Number<0>{}) = scale; + const float s0 = type_convert(scale_v.template AsType()[Number<0>{}]); + const float s1 = type_convert(scale_v.template AsType()[Number<1>{}]); + + // Nibble positions match i4_to_half4_scale: {0, 4, 1, 5} of q. + const int n0 = (q >> 0) & 0xf; + const int n4 = (q >> 16) & 0xf; + const int n1 = (q >> 4) & 0xf; + const int n5 = (q >> 20) & 0xf; + + vector_type res; + res.template AsType()(Number<0>{}) = + type_convert(static_cast(n0 - 8) * s0); + res.template AsType()(Number<1>{}) = + type_convert(static_cast(n4 - 8) * s1); + res.template AsType()(Number<2>{}) = + type_convert(static_cast(n1 - 8) * s0); + res.template AsType()(Number<3>{}) = + type_convert(static_cast(n5 - 8) * s1); + + return res.template AsType()[Number<0>{}]; +} + +// AIESW-32176: bf16 analog of i4_to_half4_zp_scale (asymmetric / AWQ). +// Same nibble-position pattern as i4_to_bhalf4_scale plus a per-group +// scaled_zp subtract, matching (nibble - zp) * scale via the +// (nibble - 8) * scale - scaled_zp identity. Caller precomputes +// scaled_zp = (zp - 8) * scale at weight load. +__device__ inline bhalf4_t +i4_to_bhalf4_zp_scale(int q, const ck::bhalf2_t& scale, const ck::bhalf2_t& scaled_zp) +{ + vector_type scale_v; + scale_v.template AsType()(Number<0>{}) = scale; + vector_type zp_v; + zp_v.template AsType()(Number<0>{}) = scaled_zp; + const float s0 = type_convert(scale_v.template AsType()[Number<0>{}]); + const float s1 = type_convert(scale_v.template AsType()[Number<1>{}]); + const float z0 = type_convert(zp_v.template AsType()[Number<0>{}]); + const float z1 = type_convert(zp_v.template AsType()[Number<1>{}]); + + const int n0 = (q >> 0) & 0xf; + const int n4 = (q >> 16) & 0xf; + const int n1 = (q >> 4) & 0xf; + const int n5 = (q >> 20) & 0xf; + + vector_type res; + res.template AsType()(Number<0>{}) = + type_convert(static_cast(n0 - 8) * s0 - z0); + res.template AsType()(Number<1>{}) = + type_convert(static_cast(n4 - 8) * s1 - z1); + res.template AsType()(Number<2>{}) = + type_convert(static_cast(n1 - 8) * s0 - z0); + res.template AsType()(Number<3>{}) = + type_convert(static_cast(n5 - 8) * s1 - z1); + + return res.template AsType()[Number<0>{}]; +} + __device__ inline f8x4_t i4_to_f8x4(int q) { const int LO = 0x000f000f; @@ -333,6 +410,47 @@ struct DequantPack8 #endif } + // AIESW-32176: bf16 overload — mirrors the fp16 path but uses + // i4_to_bhalf4_scale (which goes through fp32 intermediates because RDNA3 + // has no native bhalf2 multiply). + __host__ __device__ constexpr void + operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x, const ck::bhalf2_t& z) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + vector_type result; + + result.template AsType()(Number<0>{}) = i4_to_bhalf4_scale(bit_cast(x), z); + result.template AsType()(Number<1>{}) = + i4_to_bhalf4_scale(bit_cast(x) >> 8, z); + + y = result.template AsType()[Number<0>{}]; +#else + // Reference (slow) path for the !CK_USE_PK4_LAYOUT_SHUFFLE config: + // do the unscaled bf16 dequant via existing type_convert, then apply + // (* scale) per-element via fp32. + vector_type dst; + vector_type src{x}; + vector_type z_v; + z_v.template AsType()(Number<0>{}) = z; + const float s0 = type_convert(z_v.template AsType()[Number<0>{}]); + const float s1 = type_convert(z_v.template AsType()[Number<1>{}]); + + static_for<0, 4, 1>{}([&](auto i) { + auto v = type_convert(src.template AsType()[i]); + vector_type v_v; + v_v.template AsType()(Number<0>{}) = v; + const float v0 = type_convert(v_v.template AsType()[Number<0>{}]); + const float v1 = type_convert(v_v.template AsType()[Number<1>{}]); + vector_type r; + r.template AsType()(Number<0>{}) = type_convert(v0 * s0); + r.template AsType()(Number<1>{}) = type_convert(v1 * s1); + dst.template AsType()(i) = r.template AsType()[Number<0>{}]; + }); + + y = dst.template AsType()[Number<0>{}]; +#endif + } + constexpr const static bool is_pack8_invocable = true; }; @@ -378,6 +496,53 @@ struct DequantPack8WithZp #endif } + // AIESW-32176: bf16 overload — mirrors the fp16 path but uses + // i4_to_bhalf4_zp_scale (fp32 intermediates; no native bhalf2 multiply on + // RDNA3). + __host__ __device__ constexpr void operator()(ck::bhalf8_t& y, + const ck::pk_i4x4_t& x, + const ck::bhalf2_t& s, + const ck::bhalf2_t& scaled_zp) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + vector_type result; + + result.template AsType()(Number<0>{}) = + i4_to_bhalf4_zp_scale(bit_cast(x), s, scaled_zp); + result.template AsType()(Number<1>{}) = + i4_to_bhalf4_zp_scale(bit_cast(x) >> 8, s, scaled_zp); + + y = result.template AsType()[Number<0>{}]; +#else + // Reference (slow) path: unscaled bf16 dequant + per-element + // (* scale) - scaled_zp via fp32. + vector_type dst; + vector_type src{x}; + vector_type s_v; + s_v.template AsType()(Number<0>{}) = s; + vector_type z_v; + z_v.template AsType()(Number<0>{}) = scaled_zp; + const float s0 = type_convert(s_v.template AsType()[Number<0>{}]); + const float s1 = type_convert(s_v.template AsType()[Number<1>{}]); + const float z0 = type_convert(z_v.template AsType()[Number<0>{}]); + const float z1 = type_convert(z_v.template AsType()[Number<1>{}]); + + static_for<0, 4, 1>{}([&](auto i) { + auto v = type_convert(src.template AsType()[i]); + vector_type v_v; + v_v.template AsType()(Number<0>{}) = v; + const float v0 = type_convert(v_v.template AsType()[Number<0>{}]); + const float v1 = type_convert(v_v.template AsType()[Number<1>{}]); + vector_type r; + r.template AsType()(Number<0>{}) = type_convert(v0 * s0 - z0); + r.template AsType()(Number<1>{}) = type_convert(v1 * s1 - z1); + dst.template AsType()(i) = r.template AsType()[Number<0>{}]; + }); + + y = dst.template AsType()[Number<0>{}]; +#endif + } + constexpr const static bool is_pack8_invocable = true; };