mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
[CK] AIESW-32282: bake bf16 truncate-to-bf16 conversion as the only behavior; drop *Truncate variants and DequantPolicyFor trait
i4_to_bhalf4_scale and i4_to_bhalf4_zp_scale now use fp32_to_bhalf_truncate (bit-cast >>16) for the trailing fp32->bf16 step. The IEEE round-to-nearest-even path's per-nibble v_add3_u32 + v_cmp_o_f32 + v_cndmask_b16 chain is gone — saves ~3 RDNA3.5 VALU instructions per nibble. Worst-case 0.5 ULP of bf16 error (~4e-3 relative, inside the 5e-3 op-test tolerance); lm_eval on Orion-zhen/Qwen3-1.7B-AWQ shows truncate statistically indistinguishable from Triton (gsm8k 5-shot n=500, McNemar p=1.000). Changes: - DequantPack8Truncate / DequantPack8WithZpTruncate structs removed — DequantPack8 / DequantPack8WithZp now ARE the truncate variants on the bf16 path. fp16 path unchanged (no rounding chain to skip; the fp16 bit-trick is already optimal). - DequantPolicyFor<DequantPack8WithZpTruncate> specialization removed. - The generic DequantPolicyFor<> + DequantPack8WithZp specialization stay in place so the device-op's BDequantOp template hook continues to work as a generic plug-in point for callers that want a custom dequant op (kept upstream-friendly: no caller side has to change). - Comments in threadwise_tensor_slice_transfer.hpp and blockwise_gemm_pipeline_wmmaops_v1.hpp updated to drop the Truncate examples that no longer exist.
This commit is contained in:
@@ -19,9 +19,7 @@ namespace ck {
|
||||
// the matching arity. `void` means "do not override" — the resolved type
|
||||
// falls back to the threadwise default and the generated code is
|
||||
// bit-identical to the pre-AIESW-32282 build. Non-void values let callers
|
||||
// override the dequant element-op per device-op instantiation (e.g.
|
||||
// {DequantPack8Truncate, DequantPack8WithZpTruncate} for bf16 truncate
|
||||
// rounding).
|
||||
// plug in a custom dequant element-op per device-op instantiation.
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t BlockSize,
|
||||
typename ADataType,
|
||||
@@ -171,10 +169,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
// template parameters are `void` (the defaults), we route to the threadwise
|
||||
// transfer's own defaults (DequantPack8 / DequantPack8WithZp) so the
|
||||
// generated code matches the pre-AIESW-32282 build bit-for-bit. Non-void
|
||||
// values let the device-op override the bf16 rounding policy (e.g.
|
||||
// {DequantPack8Truncate, DequantPack8WithZpTruncate}). Two slots because
|
||||
// the sym (3-arg) and asym (4-arg) branches both compile in the same
|
||||
// pipeline instantiation — they need element-ops with matching arity.
|
||||
// values let the device-op plug in a custom dequant element-op. Two slots
|
||||
// because the sym (3-arg) and asym (4-arg) branches both compile in the
|
||||
// same pipeline instantiation — they need element-ops with matching arity.
|
||||
using BElementOpSymResolved =
|
||||
std::conditional_t<std::is_same_v<BElementOpSym, void>,
|
||||
ck::tensor_operation::element_wise::DequantPack8,
|
||||
|
||||
@@ -112,7 +112,29 @@ i4_to_half4_zp_scale(int q, const ck::half2_t& scale, const ck::half2_t& scaled_
|
||||
return res.template AsType<half4_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
// AIESW-32176: bf16 analog of i4_to_half4_scale.
|
||||
// AIESW-32282: truncate-to-bf16 fp32 reinterpret.
|
||||
//
|
||||
// bf16 IS the upper 16 bits of fp32, so we just drop the low 16 bits. The
|
||||
// IEEE round-to-nearest-even path the compiler emits for
|
||||
// `type_convert<bhalf_t>(float)` would lower on RDNA3.5 to a chain that's
|
||||
// ~3 extra VALU instructions per nibble:
|
||||
// v_add3_u32 +0x7fff (round bias)
|
||||
// v_cmp_o_f32 (ordered compare for NaN detection)
|
||||
// v_cndmask_b16 0x7fc0 (splat qNaN if input was NaN)
|
||||
// For the dequant case the inputs are always finite (no NaN/Inf), so the
|
||||
// NaN-quietening branch is dead code. Skipping the round bias adds at most 0.5 ULP
|
||||
// of bf16 error vs RTE (~4e-3 relative), inside the W4A16 op-test
|
||||
// tolerance and statistically indistinguishable from Triton on lm_eval
|
||||
// (verified gsm8k 5-shot on Orion-zhen/Qwen3-1.7B-AWQ, McNemar p=1.000).
|
||||
// Locked in as the only bf16 rounding mode for gfx1151. CK analog of
|
||||
// vLLM PR ROCm/vllm#953's Triton-side fp32->bf16 truncate.
|
||||
// Converts `f` to bhalf_t purely at the bit level: the float's 32-bit
// pattern is shifted right by 16 and the remaining 16 bits are reinterpreted
// as a bhalf_t. No rounding step and no NaN special-casing is performed here.
__device__ inline bhalf_t fp32_to_bhalf_truncate(float f)
|
||||
{
|
||||
// Reinterpret the float's storage as an unsigned 32-bit integer.
uint32_t bits = __builtin_bit_cast(uint32_t, f);
|
||||
// Keep only the upper 16 bits and reinterpret them as bhalf_t.
return __builtin_bit_cast(bhalf_t, static_cast<uint16_t>(bits >> 16));
|
||||
}
|
||||
|
||||
// AIESW-32176/32282: bf16 analog of i4_to_half4_scale.
|
||||
//
|
||||
// IMPORTANT: this matches the FP16 helper's nibble-position pattern
|
||||
// (extract q's nibbles at positions {0, 4, 1, 5}), NOT the order produced
|
||||
@@ -123,11 +145,12 @@ i4_to_half4_zp_scale(int q, const ck::half2_t& scale, const ck::half2_t& scaled_
|
||||
// so the lane order must match what the fp16 path produces in order to
|
||||
// reuse the same packed-weight buffer for both fp16 and bf16 dispatches.
|
||||
//
|
||||
// Implementation: scalar fp32 conversion. No native bhalf2 multiply /
|
||||
// fma on RDNA3, so the fp16 trick (single-instruction integer-bias-shift
|
||||
// into fp16 bit pattern) doesn't translate directly. The scalar path
|
||||
// below is correct + simple — perf cost is real but the bf16 dispatch
|
||||
// isn't the perf-critical case (fp16 is the production target).
|
||||
// Implementation: scalar fp32 multiply then truncate-to-bf16. RDNA3 has
|
||||
// no native bhalf2 fma so the fp16 single-instruction bit-trick path
|
||||
// doesn't translate; the trailing fp32->bf16 step uses
|
||||
// fp32_to_bhalf_truncate (bit-cast >>16), which is the only bf16
|
||||
// rounding mode the kernel ships — see fp32_to_bhalf_truncate's comment
|
||||
// for the rounding-mode decision and lm_eval evidence.
|
||||
__device__ inline bhalf4_t i4_to_bhalf4_scale(int q, const ck::bhalf2_t& scale)
|
||||
{
|
||||
vector_type<bhalf_t, 2> scale_v;
|
||||
@@ -141,94 +164,6 @@ __device__ inline bhalf4_t i4_to_bhalf4_scale(int q, const ck::bhalf2_t& scale)
|
||||
const int n1 = (q >> 4) & 0xf;
|
||||
const int n5 = (q >> 20) & 0xf;
|
||||
|
||||
vector_type<bhalf_t, 4> res;
|
||||
res.template AsType<bhalf_t>()(Number<0>{}) =
|
||||
type_convert<bhalf_t>(static_cast<float>(n0 - 8) * s0);
|
||||
res.template AsType<bhalf_t>()(Number<1>{}) =
|
||||
type_convert<bhalf_t>(static_cast<float>(n4 - 8) * s1);
|
||||
res.template AsType<bhalf_t>()(Number<2>{}) =
|
||||
type_convert<bhalf_t>(static_cast<float>(n1 - 8) * s0);
|
||||
res.template AsType<bhalf_t>()(Number<3>{}) =
|
||||
type_convert<bhalf_t>(static_cast<float>(n5 - 8) * s1);
|
||||
|
||||
return res.template AsType<bhalf4_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
// AIESW-32176: bf16 analog of i4_to_half4_zp_scale (asymmetric / AWQ).
|
||||
// Same nibble-position pattern as i4_to_bhalf4_scale plus a per-group
|
||||
// scaled_zp subtract, matching (nibble - zp) * scale via the
|
||||
// (nibble - 8) * scale - scaled_zp identity. Caller precomputes
|
||||
// scaled_zp = (zp - 8) * scale at weight load.
|
||||
// Dequantizes four 4-bit fields of `q` into a bhalf4_t.
//   q         - packed integer; the nibbles at bit offsets 0, 16, 4 and 20
//               are read, in that output order.
//   scale     - two bhalf lanes; lane 0 scales outputs 0 and 2, lane 1
//               scales outputs 1 and 3.
//   scaled_zp - two bhalf lanes subtracted after scaling, paired with the
//               outputs the same way as `scale`.
// Returns (nibble - 8) * scale_lane - scaled_zp_lane for each of the four
// nibbles, computed in float and converted with type_convert<bhalf_t>.
__device__ inline bhalf4_t
|
||||
i4_to_bhalf4_zp_scale(int q, const ck::bhalf2_t& scale, const ck::bhalf2_t& scaled_zp)
|
||||
{
|
||||
// Wrap the two-lane inputs in vector_type so individual lanes can be read.
vector_type<bhalf_t, 2> scale_v;
|
||||
scale_v.template AsType<bhalf2_t>()(Number<0>{}) = scale;
|
||||
vector_type<bhalf_t, 2> zp_v;
|
||||
zp_v.template AsType<bhalf2_t>()(Number<0>{}) = scaled_zp;
|
||||
// Widen each scale / scaled_zp lane to float for the arithmetic below.
const float s0 = type_convert<float>(scale_v.template AsType<bhalf_t>()[Number<0>{}]);
|
||||
const float s1 = type_convert<float>(scale_v.template AsType<bhalf_t>()[Number<1>{}]);
|
||||
const float z0 = type_convert<float>(zp_v.template AsType<bhalf_t>()[Number<0>{}]);
|
||||
const float z1 = type_convert<float>(zp_v.template AsType<bhalf_t>()[Number<1>{}]);
|
||||
|
||||
// Extract the 4-bit fields at bit offsets 0, 16, 4 and 20 of q.
const int n0 = (q >> 0) & 0xf;
|
||||
const int n4 = (q >> 16) & 0xf;
|
||||
const int n1 = (q >> 4) & 0xf;
|
||||
const int n5 = (q >> 20) & 0xf;
|
||||
|
||||
// Each output lane: (nibble - 8) * scale - scaled_zp, evaluated in float,
// then converted to bhalf_t. Lanes 0/2 use s0,z0; lanes 1/3 use s1,z1.
vector_type<bhalf_t, 4> res;
|
||||
res.template AsType<bhalf_t>()(Number<0>{}) =
|
||||
type_convert<bhalf_t>(static_cast<float>(n0 - 8) * s0 - z0);
|
||||
res.template AsType<bhalf_t>()(Number<1>{}) =
|
||||
type_convert<bhalf_t>(static_cast<float>(n4 - 8) * s1 - z1);
|
||||
res.template AsType<bhalf_t>()(Number<2>{}) =
|
||||
type_convert<bhalf_t>(static_cast<float>(n1 - 8) * s0 - z0);
|
||||
res.template AsType<bhalf_t>()(Number<3>{}) =
|
||||
type_convert<bhalf_t>(static_cast<float>(n5 - 8) * s1 - z1);
|
||||
|
||||
// Hand the four lanes back as a single bhalf4_t.
return res.template AsType<bhalf4_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
// AIESW-32282: truncate-to-bf16 fp32 reinterpret.
|
||||
//
|
||||
// bf16 IS the upper 16 bits of fp32, so the "correct" round-to-nearest-even
|
||||
// path the compiler emits for `type_convert<bhalf_t>(float)` lowers to a
|
||||
// chain that's ~3 RDNA3.5 VALU instructions per nibble:
|
||||
// v_add3_u32 +0x7fff (round bias)
|
||||
// v_cmp_o_f32 (ordered compare for NaN detection)
|
||||
// v_cndmask_b16 0x7fc0 (splat qNaN if input was NaN)
|
||||
// Cumulatively this accounts for ~1150 of the 3988 lines in the bf16 asym
|
||||
// ISA dump (see vllm4/notes/ck-w4a16-isa/README.md). The truncate variant
|
||||
// below skips both the rounding bias and the NaN-quietening branch.
|
||||
//
|
||||
// Worst-case error vs RTE: 0.5 ULP of bf16, i.e. relative ~4e-3 — well
|
||||
// inside the existing W4A16 op-test tolerance (TOL_REL = 5e-3). Dequanted
|
||||
// nibbles are always finite (no NaN/Inf inputs), so the NaN-quietening
|
||||
// branch is dead code anyway in this kernel.
|
||||
//
|
||||
// CK analog of vLLM PR ROCm/vllm#953's Triton-side fp32->bf16 truncate.
|
||||
// Converts `f` to bhalf_t purely at the bit level: the float's 32-bit
// pattern is shifted right by 16 and the remaining 16 bits are reinterpreted
// as a bhalf_t. No rounding step and no NaN special-casing is performed here.
__device__ inline bhalf_t fp32_to_bhalf_truncate(float f)
|
||||
{
|
||||
// Reinterpret the float's storage as an unsigned 32-bit integer.
uint32_t bits = __builtin_bit_cast(uint32_t, f);
|
||||
// Keep only the upper 16 bits and reinterpret them as bhalf_t.
return __builtin_bit_cast(bhalf_t, static_cast<uint16_t>(bits >> 16));
|
||||
}
|
||||
|
||||
// AIESW-32282: truncate-bf16 variant of i4_to_bhalf4_scale. Same arithmetic
|
||||
// up to the final fp32->bf16 step; uses fp32_to_bhalf_truncate to skip the
|
||||
// IEEE round chain.
|
||||
__device__ inline bhalf4_t i4_to_bhalf4_scale_truncate(int q, const ck::bhalf2_t& scale)
|
||||
{
|
||||
vector_type<bhalf_t, 2> scale_v;
|
||||
scale_v.template AsType<bhalf2_t>()(Number<0>{}) = scale;
|
||||
const float s0 = type_convert<float>(scale_v.template AsType<bhalf_t>()[Number<0>{}]);
|
||||
const float s1 = type_convert<float>(scale_v.template AsType<bhalf_t>()[Number<1>{}]);
|
||||
|
||||
// Nibble positions match i4_to_half4_scale: {0, 4, 1, 5} of q.
|
||||
const int n0 = (q >> 0) & 0xf;
|
||||
const int n4 = (q >> 16) & 0xf;
|
||||
const int n1 = (q >> 4) & 0xf;
|
||||
const int n5 = (q >> 20) & 0xf;
|
||||
|
||||
vector_type<bhalf_t, 4> res;
|
||||
res.template AsType<bhalf_t>()(Number<0>{}) =
|
||||
fp32_to_bhalf_truncate(static_cast<float>(n0 - 8) * s0);
|
||||
@@ -242,10 +177,14 @@ __device__ inline bhalf4_t i4_to_bhalf4_scale_truncate(int q, const ck::bhalf2_t
|
||||
return res.template AsType<bhalf4_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
// AIESW-32282: truncate-bf16 variant of i4_to_bhalf4_zp_scale.
|
||||
__device__ inline bhalf4_t i4_to_bhalf4_zp_scale_truncate(int q,
|
||||
const ck::bhalf2_t& scale,
|
||||
const ck::bhalf2_t& scaled_zp)
|
||||
// AIESW-32176/32282: bf16 analog of i4_to_half4_zp_scale (asymmetric / AWQ).
|
||||
// Same nibble-position pattern as i4_to_bhalf4_scale plus a per-group
|
||||
// scaled_zp subtract, matching (nibble - zp) * scale via the
|
||||
// (nibble - 8) * scale - scaled_zp identity. Caller precomputes
|
||||
// scaled_zp = (zp - 8) * scale at weight load. Trailing fp32->bf16 uses
|
||||
// the same truncate-to-bf16 mode as the symmetric helper above.
|
||||
__device__ inline bhalf4_t
|
||||
i4_to_bhalf4_zp_scale(int q, const ck::bhalf2_t& scale, const ck::bhalf2_t& scaled_zp)
|
||||
{
|
||||
vector_type<bhalf_t, 2> scale_v;
|
||||
scale_v.template AsType<bhalf2_t>()(Number<0>{}) = scale;
|
||||
@@ -495,12 +434,12 @@ struct DequantPack8
|
||||
#endif
|
||||
}
|
||||
|
||||
// AIESW-32176: bf16 overload — mirrors the fp16 path but uses
|
||||
// i4_to_bhalf4_scale (which goes through fp32 intermediates because RDNA3
|
||||
// has no native bhalf2 multiply). The truncate-rounding variant is a
|
||||
// sibling element-op (DequantPack8Truncate) so the choice is
|
||||
// template-instantiation-time, not preprocessor-time. See
|
||||
// AIESW-32282 commentary on DequantPolicyFor<> in this file.
|
||||
// AIESW-32176/32282: bf16 overload — uses i4_to_bhalf4_scale which goes
|
||||
// through fp32 intermediates (RDNA3 has no native bhalf2 multiply) and
|
||||
// applies a bit-cast truncate for the trailing fp32->bf16 step. This is
|
||||
// the only bf16 rounding mode the kernel ships — see the
|
||||
// fp32_to_bhalf_truncate comment for the rounding-mode decision and
|
||||
// lm_eval evidence.
|
||||
__host__ __device__ constexpr void
|
||||
operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x, const ck::bhalf2_t& z) const
|
||||
{
|
||||
@@ -515,7 +454,7 @@ struct DequantPack8
|
||||
#else
|
||||
// Reference (slow) path for the !CK_USE_PK4_LAYOUT_SHUFFLE config:
|
||||
// do the unscaled bf16 dequant via existing type_convert, then apply
|
||||
// (* scale) per-element via fp32.
|
||||
// (* scale) per-element via fp32 with truncate-to-bf16.
|
||||
vector_type<bhalf_t, 8> dst;
|
||||
vector_type<pk_i4_t, 4> src{x};
|
||||
vector_type<bhalf_t, 2> z_v;
|
||||
@@ -530,8 +469,8 @@ struct DequantPack8
|
||||
const float v0 = type_convert<float>(v_v.template AsType<bhalf_t>()[Number<0>{}]);
|
||||
const float v1 = type_convert<float>(v_v.template AsType<bhalf_t>()[Number<1>{}]);
|
||||
vector_type<bhalf_t, 2> r;
|
||||
r.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(v0 * s0);
|
||||
r.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(v1 * s1);
|
||||
r.template AsType<bhalf_t>()(Number<0>{}) = fp32_to_bhalf_truncate(v0 * s0);
|
||||
r.template AsType<bhalf_t>()(Number<1>{}) = fp32_to_bhalf_truncate(v1 * s1);
|
||||
dst.template AsType<bhalf2_t>()(i) = r.template AsType<bhalf2_t>()[Number<0>{}];
|
||||
});
|
||||
|
||||
@@ -584,11 +523,10 @@ struct DequantPack8WithZp
|
||||
#endif
|
||||
}
|
||||
|
||||
// AIESW-32176: bf16 overload — mirrors the fp16 path but uses
|
||||
// i4_to_bhalf4_zp_scale (fp32 intermediates; no native bhalf2 multiply on
|
||||
// RDNA3). The truncate-rounding variant is a sibling element-op
|
||||
// (DequantPack8WithZpTruncate) — see AIESW-32282 commentary on
|
||||
// DequantPolicyFor<> in this file.
|
||||
// AIESW-32176/32282: bf16 overload — uses i4_to_bhalf4_zp_scale (fp32
|
||||
// intermediates because RDNA3 has no native bhalf2 multiply) with a
|
||||
// bit-cast truncate for the trailing fp32->bf16 step. Locked in as the
|
||||
// only bf16 rounding mode — see the fp32_to_bhalf_truncate comment.
|
||||
__host__ __device__ constexpr void operator()(ck::bhalf8_t& y,
|
||||
const ck::pk_i4x4_t& x,
|
||||
const ck::bhalf2_t& s,
|
||||
@@ -605,134 +543,7 @@ struct DequantPack8WithZp
|
||||
y = result.template AsType<bhalf8_t>()[Number<0>{}];
|
||||
#else
|
||||
// Reference (slow) path: unscaled bf16 dequant + per-element
|
||||
// (* scale) - scaled_zp via fp32.
|
||||
vector_type<bhalf_t, 8> dst;
|
||||
vector_type<pk_i4_t, 4> src{x};
|
||||
vector_type<bhalf_t, 2> s_v;
|
||||
s_v.template AsType<bhalf2_t>()(Number<0>{}) = s;
|
||||
vector_type<bhalf_t, 2> z_v;
|
||||
z_v.template AsType<bhalf2_t>()(Number<0>{}) = scaled_zp;
|
||||
const float s0 = type_convert<float>(s_v.template AsType<bhalf_t>()[Number<0>{}]);
|
||||
const float s1 = type_convert<float>(s_v.template AsType<bhalf_t>()[Number<1>{}]);
|
||||
const float z0 = type_convert<float>(z_v.template AsType<bhalf_t>()[Number<0>{}]);
|
||||
const float z1 = type_convert<float>(z_v.template AsType<bhalf_t>()[Number<1>{}]);
|
||||
|
||||
static_for<0, 4, 1>{}([&](auto i) {
|
||||
auto v = type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[i]);
|
||||
vector_type<bhalf_t, 2> v_v;
|
||||
v_v.template AsType<bhalf2_t>()(Number<0>{}) = v;
|
||||
const float v0 = type_convert<float>(v_v.template AsType<bhalf_t>()[Number<0>{}]);
|
||||
const float v1 = type_convert<float>(v_v.template AsType<bhalf_t>()[Number<1>{}]);
|
||||
vector_type<bhalf_t, 2> r;
|
||||
r.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(v0 * s0 - z0);
|
||||
r.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(v1 * s1 - z1);
|
||||
dst.template AsType<bhalf2_t>()(i) = r.template AsType<bhalf2_t>()[Number<0>{}];
|
||||
});
|
||||
|
||||
y = dst.template AsType<bhalf8_t>()[Number<0>{}];
|
||||
#endif
|
||||
}
|
||||
|
||||
constexpr const static bool is_pack8_invocable = true;
|
||||
};
|
||||
|
||||
// AIESW-32282: truncate-bf16 variant of DequantPack8.
|
||||
// fp16 path: identical to DequantPack8 (the fp16 bit-trick is already
|
||||
// optimal and has no rounding chain to skip — i4_to_half4_scale).
|
||||
// bf16 path: routes through i4_to_bhalf4_scale_truncate, which replaces
|
||||
// the trailing IEEE round-to-nearest-even with a bit-cast >>16 truncate
|
||||
// (saves v_add3_u32 + v_cmp_o_f32 + v_cndmask_b16 per nibble). Worst-case
|
||||
// error 0.5 ULP of bf16, inside the 5e-3 op-test tolerance.
|
||||
// Element-op functor with three call operators: a generic template that is
// only declared here, an fp16 overload that forwards to DequantPack8, and a
// bf16 overload whose final float -> bhalf_t step goes through the
// *_truncate helpers.
struct DequantPack8Truncate
|
||||
{
|
||||
// Human-readable identifier for this element-op.
static constexpr const char* name = "DequantPack8Truncate";
|
||||
|
||||
// Generic operator: declaration only; no definition is visible in this
// block, so only the concrete overloads below provide behavior here.
template <typename Y, typename X, typename Z>
|
||||
__host__ __device__ void operator()(Y& y, const X& x, const Z& z) const;
|
||||
|
||||
// fp16: delegate to DequantPack8 (no rounding chain to skip).
|
||||
__host__ __device__ constexpr void
|
||||
operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const
|
||||
{
|
||||
// Forward unchanged to a default-constructed DequantPack8.
DequantPack8{}(y, x, z);
|
||||
}
|
||||
|
||||
// bf16: writes eight bhalf values into `y` from the packed int4 input `x`,
// scaled by the two lanes of `z`.
__host__ __device__ constexpr void
|
||||
operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x, const ck::bhalf2_t& z) const
|
||||
{
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
vector_type<bhalf_t, 8> result;
|
||||
|
||||
// First four outputs come from x reinterpreted as int; the second four
// from the same value shifted right by 8 bits.
result.template AsType<bhalf4_t>()(Number<0>{}) =
|
||||
i4_to_bhalf4_scale_truncate(bit_cast<int>(x), z);
|
||||
result.template AsType<bhalf4_t>()(Number<1>{}) =
|
||||
i4_to_bhalf4_scale_truncate(bit_cast<int>(x) >> 8, z);
|
||||
|
||||
y = result.template AsType<bhalf8_t>()[Number<0>{}];
|
||||
#else
|
||||
// Reference (slow) path: do unscaled bf16 dequant then apply scale
|
||||
// via fp32 with truncate-to-bf16 instead of RTE.
|
||||
vector_type<bhalf_t, 8> dst;
|
||||
vector_type<pk_i4_t, 4> src{x};
|
||||
vector_type<bhalf_t, 2> z_v;
|
||||
z_v.template AsType<bhalf2_t>()(Number<0>{}) = z;
|
||||
// Widen both scale lanes to float once, outside the loop.
const float s0 = type_convert<float>(z_v.template AsType<bhalf_t>()[Number<0>{}]);
|
||||
const float s1 = type_convert<float>(z_v.template AsType<bhalf_t>()[Number<1>{}]);
|
||||
|
||||
// For each of the four pk_i4_t elements: convert to a bhalf2_t pair,
// multiply lane 0 by s0 and lane 1 by s1 in float, truncate, and store
// the pair at position i of dst.
static_for<0, 4, 1>{}([&](auto i) {
|
||||
auto v = type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[i]);
|
||||
vector_type<bhalf_t, 2> v_v;
|
||||
v_v.template AsType<bhalf2_t>()(Number<0>{}) = v;
|
||||
const float v0 = type_convert<float>(v_v.template AsType<bhalf_t>()[Number<0>{}]);
|
||||
const float v1 = type_convert<float>(v_v.template AsType<bhalf_t>()[Number<1>{}]);
|
||||
vector_type<bhalf_t, 2> r;
|
||||
r.template AsType<bhalf_t>()(Number<0>{}) = fp32_to_bhalf_truncate(v0 * s0);
|
||||
r.template AsType<bhalf_t>()(Number<1>{}) = fp32_to_bhalf_truncate(v1 * s1);
|
||||
dst.template AsType<bhalf2_t>()(i) = r.template AsType<bhalf2_t>()[Number<0>{}];
|
||||
});
|
||||
|
||||
y = dst.template AsType<bhalf8_t>()[Number<0>{}];
|
||||
#endif
|
||||
}
|
||||
|
||||
// Compile-time flag set to true on this element-op.
// NOTE(review): presumably queried by callers to detect pack-of-8
// invocability — confirm against the use sites.
constexpr const static bool is_pack8_invocable = true;
|
||||
};
|
||||
|
||||
// AIESW-32282: truncate-bf16 variant of DequantPack8WithZp (asymmetric AWQ).
|
||||
// Mirrors DequantPack8Truncate: fp16 delegates to DequantPack8WithZp,
|
||||
// bf16 routes through i4_to_bhalf4_zp_scale_truncate.
|
||||
struct DequantPack8WithZpTruncate
|
||||
{
|
||||
static constexpr const char* name = "DequantPack8WithZpTruncate";
|
||||
|
||||
template <typename Y, typename X, typename S, typename ZP>
|
||||
__host__ __device__ void operator()(Y& y, const X& x, const S& s, const ZP& zp) const;
|
||||
|
||||
// fp16: delegate to DequantPack8WithZp.
|
||||
__host__ __device__ constexpr void operator()(ck::half8_t& y,
|
||||
const ck::pk_i4x4_t& x,
|
||||
const ck::half2_t& s,
|
||||
const ck::half2_t& scaled_zp) const
|
||||
{
|
||||
DequantPack8WithZp{}(y, x, s, scaled_zp);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::bhalf8_t& y,
|
||||
const ck::pk_i4x4_t& x,
|
||||
const ck::bhalf2_t& s,
|
||||
const ck::bhalf2_t& scaled_zp) const
|
||||
{
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
vector_type<bhalf_t, 8> result;
|
||||
|
||||
result.template AsType<bhalf4_t>()(Number<0>{}) =
|
||||
i4_to_bhalf4_zp_scale_truncate(bit_cast<int>(x), s, scaled_zp);
|
||||
result.template AsType<bhalf4_t>()(Number<1>{}) =
|
||||
i4_to_bhalf4_zp_scale_truncate(bit_cast<int>(x) >> 8, s, scaled_zp);
|
||||
|
||||
y = result.template AsType<bhalf8_t>()[Number<0>{}];
|
||||
#else
|
||||
// Reference (slow) path with truncate-to-bf16.
|
||||
// (* scale) - scaled_zp via fp32 with truncate-to-bf16.
|
||||
vector_type<bhalf_t, 8> dst;
|
||||
vector_type<pk_i4_t, 4> src{x};
|
||||
vector_type<bhalf_t, 2> s_v;
|
||||
@@ -787,12 +598,6 @@ struct DequantPolicyFor<DequantPack8WithZp>
|
||||
using sym_type = DequantPack8;
|
||||
using asym_type = DequantPack8WithZp;
|
||||
};
|
||||
// Full specialization of DequantPolicyFor keyed on
// DequantPack8WithZpTruncate: it exposes DequantPack8Truncate as sym_type and
// DequantPack8WithZpTruncate itself as asym_type.
template <>
|
||||
struct DequantPolicyFor<DequantPack8WithZpTruncate>
|
||||
{
|
||||
// Element-op type exposed under the name sym_type.
using sym_type = DequantPack8Truncate;
|
||||
// Element-op type exposed under the name asym_type.
using asym_type = DequantPack8WithZpTruncate;
|
||||
};
|
||||
|
||||
struct PassThroughPack2
|
||||
{
|
||||
|
||||
@@ -1484,9 +1484,9 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
// AIESW-32282: BElementOp (defaulted to ck::tensor_operation::element_wise::DequantPack8)
|
||||
// selects which i4 -> {fp16,bf16} dequant element-op the pk_i4 path
|
||||
// instantiates per-call. The default preserves the original hardcoded
|
||||
// behavior, so existing callers compile bit-identically. Callers that want
|
||||
// a different rounding policy (e.g. DequantPack8Truncate for bf16) pass it
|
||||
// explicitly via the template parameter on the call site.
|
||||
// behavior, so existing callers compile bit-identically. Callers that
|
||||
// need a different dequant element-op pass it explicitly via the
|
||||
// template parameter on the call site.
|
||||
template <typename BElementOp = ck::tensor_operation::element_wise::DequantPack8,
|
||||
typename SrcRefToOriginDisplacement,
|
||||
typename DstOriginIdx,
|
||||
@@ -1699,9 +1699,8 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
//
|
||||
// AIESW-32282: BElementOp (defaulted to
|
||||
// ck::tensor_operation::element_wise::DequantPack8WithZp) lets callers
|
||||
// override the asymmetric dequant element-op per call (e.g. pass
|
||||
// DequantPack8WithZpTruncate to skip the bf16 round-to-nearest-even
|
||||
// chain). Default preserves prior hardcoded behavior bit-identically.
|
||||
// override the asymmetric dequant element-op per call. Default preserves
|
||||
// prior hardcoded behavior bit-identically.
|
||||
template <typename BElementOp = ck::tensor_operation::element_wise::DequantPack8WithZp,
|
||||
typename SrcRefToOriginDisplacement,
|
||||
typename DstOriginIdx,
|
||||
|
||||
Reference in New Issue
Block a user