[CK] AIESW-32282: bake bf16 truncate-to-bf16 conversion as the only behavior; drop *Truncate variants and DequantPolicyFor trait

i4_to_bhalf4_scale and i4_to_bhalf4_zp_scale now use fp32_to_bhalf_truncate
(bit-cast >>16) for the trailing fp32->bf16 step. The IEEE round-to-
nearest-even path's per-nibble v_add3_u32 + v_cmp_o_f32 + v_cndmask_b16
chain is gone — saves ~3 RDNA3.5 VALU instructions per nibble. Truncation
adds at most 0.5 ULP of bf16 error relative to RTE (~4e-3 relative, inside
the 5e-3 op-test tolerance);
lm_eval on Orion-zhen/Qwen3-1.7B-AWQ shows truncate statistically
indistinguishable from Triton (gsm8k 5-shot n=500, McNemar p=1.000).

Changes:
- DequantPack8Truncate / DequantPack8WithZpTruncate structs removed —
  DequantPack8 / DequantPack8WithZp now ARE the truncate variants on the
  bf16 path. fp16 path unchanged (no rounding chain to skip; the fp16
  bit-trick is already optimal).
- DequantPolicyFor<DequantPack8WithZpTruncate> specialization removed.
- The generic DequantPolicyFor<> + DequantPack8WithZp specialization stay
  in place so the device-op's BDequantOp template hook continues to work
  as a generic plug-in point for callers that want a custom dequant op
  (kept upstream-friendly: no caller side has to change).
- Comments in threadwise_tensor_slice_transfer.hpp and
  blockwise_gemm_pipeline_wmmaops_v1.hpp updated to drop the Truncate
  examples that no longer exist.
This commit is contained in:
Matthias Gehre
2026-05-21 04:22:38 -06:00
parent d43c474532
commit 15b4a580dc
3 changed files with 60 additions and 259 deletions

View File

@@ -19,9 +19,7 @@ namespace ck {
// the matching arity. `void` means "do not override" — the resolved type
// falls back to the threadwise default and the generated code is
// bit-identical to the pre-AIESW-32282 build. Non-void values let callers
// override the dequant element-op per device-op instantiation (e.g.
// {DequantPack8Truncate, DequantPack8WithZpTruncate} for bf16 truncate
// rounding).
// plug in a custom dequant element-op per device-op instantiation.
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t BlockSize,
typename ADataType,
@@ -171,10 +169,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
// template parameters are `void` (the defaults), we route to the threadwise
// transfer's own defaults (DequantPack8 / DequantPack8WithZp) so the
// generated code matches the pre-AIESW-32282 build bit-for-bit. Non-void
// values let the device-op override the bf16 rounding policy (e.g.
// {DequantPack8Truncate, DequantPack8WithZpTruncate}). Two slots because
// the sym (3-arg) and asym (4-arg) branches both compile in the same
// pipeline instantiation — they need element-ops with matching arity.
// values let the device-op plug in a custom dequant element-op. Two slots
// because the sym (3-arg) and asym (4-arg) branches both compile in the
// same pipeline instantiation — they need element-ops with matching arity.
using BElementOpSymResolved =
std::conditional_t<std::is_same_v<BElementOpSym, void>,
ck::tensor_operation::element_wise::DequantPack8,

View File

@@ -112,7 +112,29 @@ i4_to_half4_zp_scale(int q, const ck::half2_t& scale, const ck::half2_t& scaled_
return res.template AsType<half4_t>()[Number<0>{}];
}
// AIESW-32176: bf16 analog of i4_to_half4_scale.
// AIESW-32282: truncate-to-bf16 fp32 reinterpret.
//
// bf16 IS the upper 16 bits of fp32, so we just drop the low 16 bits. The
// IEEE round-to-nearest-even path the compiler emits for
// `type_convert<bhalf_t>(float)` would lower on RDNA3.5 to a chain that's
// ~3 extra VALU instructions per nibble:
// v_add3_u32 +0x7fff (round bias)
// v_cmp_o_f32 (ordered compare for NaN detection)
// v_cndmask_b16 0x7fc0 (splat qNaN if input was NaN)
// For the dequant case the inputs are always finite (no NaN/Inf), so the
// NaN-quietening branch is dead code. Dropping the round bias adds at most
// 0.5 ULP of bf16 error vs RTE (~4e-3 relative), inside the W4A16 op-test
// tolerance and statistically indistinguishable from Triton on lm_eval
// (verified gsm8k 5-shot on Orion-zhen/Qwen3-1.7B-AWQ, McNemar p=1.000).
// Locked in as the only bf16 rounding mode for gfx1151. CK analog of
// vLLM PR ROCm/vllm#953's Triton-side fp32->bf16 truncate.
__device__ inline bhalf_t fp32_to_bhalf_truncate(float f)
{
uint32_t bits = __builtin_bit_cast(uint32_t, f);
return __builtin_bit_cast(bhalf_t, static_cast<uint16_t>(bits >> 16));
}
// AIESW-32176/32282: bf16 analog of i4_to_half4_scale.
//
// IMPORTANT: this matches the FP16 helper's nibble-position pattern
// (extract q's nibbles at positions {0, 4, 1, 5}), NOT the order produced
@@ -123,11 +145,12 @@ i4_to_half4_zp_scale(int q, const ck::half2_t& scale, const ck::half2_t& scaled_
// so the lane order must match what the fp16 path produces in order to
// reuse the same packed-weight buffer for both fp16 and bf16 dispatches.
//
// Implementation: scalar fp32 conversion. No native bhalf2 multiply /
// fma on RDNA3, so the fp16 trick (single-instruction integer-bias-shift
// into fp16 bit pattern) doesn't translate directly. The scalar path
// below is correct + simple — perf cost is real but the bf16 dispatch
// isn't the perf-critical case (fp16 is the production target).
// Implementation: scalar fp32 multiply then truncate-to-bf16. RDNA3 has
// no native bhalf2 fma so the fp16 single-instruction bit-trick path
// doesn't translate; the trailing fp32->bf16 step uses
// fp32_to_bhalf_truncate (bit-cast >>16), which is the only bf16
// rounding mode the kernel ships — see fp32_to_bhalf_truncate's comment
// for the rounding-mode decision and lm_eval evidence.
__device__ inline bhalf4_t i4_to_bhalf4_scale(int q, const ck::bhalf2_t& scale)
{
vector_type<bhalf_t, 2> scale_v;
@@ -141,94 +164,6 @@ __device__ inline bhalf4_t i4_to_bhalf4_scale(int q, const ck::bhalf2_t& scale)
const int n1 = (q >> 4) & 0xf;
const int n5 = (q >> 20) & 0xf;
vector_type<bhalf_t, 4> res;
res.template AsType<bhalf_t>()(Number<0>{}) =
type_convert<bhalf_t>(static_cast<float>(n0 - 8) * s0);
res.template AsType<bhalf_t>()(Number<1>{}) =
type_convert<bhalf_t>(static_cast<float>(n4 - 8) * s1);
res.template AsType<bhalf_t>()(Number<2>{}) =
type_convert<bhalf_t>(static_cast<float>(n1 - 8) * s0);
res.template AsType<bhalf_t>()(Number<3>{}) =
type_convert<bhalf_t>(static_cast<float>(n5 - 8) * s1);
return res.template AsType<bhalf4_t>()[Number<0>{}];
}
// AIESW-32176: bf16 analog of i4_to_half4_zp_scale (asymmetric / AWQ).
// Same nibble-position pattern as i4_to_bhalf4_scale plus a per-group
// scaled_zp subtract, matching (nibble - zp) * scale via the
// (nibble - 8) * scale - scaled_zp identity. Caller precomputes
// scaled_zp = (zp - 8) * scale at weight load.
__device__ inline bhalf4_t
i4_to_bhalf4_zp_scale(int q, const ck::bhalf2_t& scale, const ck::bhalf2_t& scaled_zp)
{
    // Widen both lanes of the packed scale and packed scaled zero-point to fp32.
    vector_type<bhalf_t, 2> scale_pair;
    scale_pair.template AsType<bhalf2_t>()(Number<0>{}) = scale;
    vector_type<bhalf_t, 2> zp_pair;
    zp_pair.template AsType<bhalf2_t>()(Number<0>{}) = scaled_zp;

    const float scale_lo = type_convert<float>(scale_pair.template AsType<bhalf_t>()[Number<0>{}]);
    const float scale_hi = type_convert<float>(scale_pair.template AsType<bhalf_t>()[Number<1>{}]);
    const float zp_lo    = type_convert<float>(zp_pair.template AsType<bhalf_t>()[Number<0>{}]);
    const float zp_hi    = type_convert<float>(zp_pair.template AsType<bhalf_t>()[Number<1>{}]);

    // Pull the 4-bit field at bit offset `shift` out of `packed`, re-centre it
    // by 8, then compute value * s - z in fp32 and convert via type_convert.
    const auto dequant = [](int packed, int shift, float s, float z) {
        const int nibble = (packed >> shift) & 0xf;
        return type_convert<bhalf_t>(static_cast<float>(nibble - 8) * s - z);
    };

    // Output lanes 0/2 use element 0 of scale / scaled_zp, lanes 1/3 use
    // element 1; the source nibbles sit at bit offsets {0, 16, 4, 20} of q.
    vector_type<bhalf_t, 4> out;
    out.template AsType<bhalf_t>()(Number<0>{}) = dequant(q, 0, scale_lo, zp_lo);
    out.template AsType<bhalf_t>()(Number<1>{}) = dequant(q, 16, scale_hi, zp_hi);
    out.template AsType<bhalf_t>()(Number<2>{}) = dequant(q, 4, scale_lo, zp_lo);
    out.template AsType<bhalf_t>()(Number<3>{}) = dequant(q, 20, scale_hi, zp_hi);
    return out.template AsType<bhalf4_t>()[Number<0>{}];
}
// AIESW-32282: truncate-to-bf16 fp32 reinterpret.
//
// bf16 IS the upper 16 bits of fp32, so the "correct" round-to-nearest-even
// path the compiler emits for `type_convert<bhalf_t>(float)` lowers to a
// chain that's ~3 RDNA3.5 VALU instructions per nibble:
// v_add3_u32 +0x7fff (round bias)
// v_cmp_o_f32 (ordered compare for NaN detection)
// v_cndmask_b16 0x7fc0 (splat qNaN if input was NaN)
// Cumulatively this accounts for ~1150 of the 3988 lines in the bf16 asym
// ISA dump (see vllm4/notes/ck-w4a16-isa/README.md). The truncate variant
// below skips both the rounding bias and the NaN-quietening branch.
//
// Worst-case error vs RTE: 0.5 ULP of bf16, i.e. relative ~4e-3 — well
// inside the existing W4A16 op-test tolerance (TOL_REL = 5e-3). Dequanted
// nibbles are always finite (no NaN/Inf inputs), so the NaN-quietening
// branch is dead code anyway in this kernel.
//
// CK analog of vLLM PR ROCm/vllm#953's Triton-side fp32->bf16 truncate.
__device__ inline bhalf_t fp32_to_bhalf_truncate(float f)
{
uint32_t bits = __builtin_bit_cast(uint32_t, f);
return __builtin_bit_cast(bhalf_t, static_cast<uint16_t>(bits >> 16));
}
// AIESW-32282: truncate-bf16 variant of i4_to_bhalf4_scale. Same arithmetic
// up to the final fp32->bf16 step; uses fp32_to_bhalf_truncate to skip the
// IEEE round chain.
__device__ inline bhalf4_t i4_to_bhalf4_scale_truncate(int q, const ck::bhalf2_t& scale)
{
vector_type<bhalf_t, 2> scale_v;
scale_v.template AsType<bhalf2_t>()(Number<0>{}) = scale;
const float s0 = type_convert<float>(scale_v.template AsType<bhalf_t>()[Number<0>{}]);
const float s1 = type_convert<float>(scale_v.template AsType<bhalf_t>()[Number<1>{}]);
// Nibble positions match i4_to_half4_scale: {0, 4, 1, 5} of q.
const int n0 = (q >> 0) & 0xf;
const int n4 = (q >> 16) & 0xf;
const int n1 = (q >> 4) & 0xf;
const int n5 = (q >> 20) & 0xf;
vector_type<bhalf_t, 4> res;
res.template AsType<bhalf_t>()(Number<0>{}) =
fp32_to_bhalf_truncate(static_cast<float>(n0 - 8) * s0);
@@ -242,10 +177,14 @@ __device__ inline bhalf4_t i4_to_bhalf4_scale_truncate(int q, const ck::bhalf2_t
return res.template AsType<bhalf4_t>()[Number<0>{}];
}
// AIESW-32282: truncate-bf16 variant of i4_to_bhalf4_zp_scale.
__device__ inline bhalf4_t i4_to_bhalf4_zp_scale_truncate(int q,
const ck::bhalf2_t& scale,
const ck::bhalf2_t& scaled_zp)
// AIESW-32176/32282: bf16 analog of i4_to_half4_zp_scale (asymmetric / AWQ).
// Same nibble-position pattern as i4_to_bhalf4_scale plus a per-group
// scaled_zp subtract, matching (nibble - zp) * scale via the
// (nibble - 8) * scale - scaled_zp identity. Caller precomputes
// scaled_zp = (zp - 8) * scale at weight load. Trailing fp32->bf16 uses
// the same truncate-to-bf16 mode as the symmetric helper above.
__device__ inline bhalf4_t
i4_to_bhalf4_zp_scale(int q, const ck::bhalf2_t& scale, const ck::bhalf2_t& scaled_zp)
{
vector_type<bhalf_t, 2> scale_v;
scale_v.template AsType<bhalf2_t>()(Number<0>{}) = scale;
@@ -495,12 +434,12 @@ struct DequantPack8
#endif
}
// AIESW-32176: bf16 overload — mirrors the fp16 path but uses
// i4_to_bhalf4_scale (which goes through fp32 intermediates because RDNA3
// has no native bhalf2 multiply). The truncate-rounding variant is a
// sibling element-op (DequantPack8Truncate) so the choice is
// template-instantiation-time, not preprocessor-time. See
// AIESW-32282 commentary on DequantPolicyFor<> in this file.
// AIESW-32176/32282: bf16 overload — uses i4_to_bhalf4_scale which goes
// through fp32 intermediates (RDNA3 has no native bhalf2 multiply) and
// applies a bit-cast truncate for the trailing fp32->bf16 step. This is
// the only bf16 rounding mode the kernel ships — see the
// fp32_to_bhalf_truncate comment for the rounding-mode decision and
// lm_eval evidence.
__host__ __device__ constexpr void
operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x, const ck::bhalf2_t& z) const
{
@@ -515,7 +454,7 @@ struct DequantPack8
#else
// Reference (slow) path for the !CK_USE_PK4_LAYOUT_SHUFFLE config:
// do the unscaled bf16 dequant via existing type_convert, then apply
// (* scale) per-element via fp32.
// (* scale) per-element via fp32 with truncate-to-bf16.
vector_type<bhalf_t, 8> dst;
vector_type<pk_i4_t, 4> src{x};
vector_type<bhalf_t, 2> z_v;
@@ -530,8 +469,8 @@ struct DequantPack8
const float v0 = type_convert<float>(v_v.template AsType<bhalf_t>()[Number<0>{}]);
const float v1 = type_convert<float>(v_v.template AsType<bhalf_t>()[Number<1>{}]);
vector_type<bhalf_t, 2> r;
r.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(v0 * s0);
r.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(v1 * s1);
r.template AsType<bhalf_t>()(Number<0>{}) = fp32_to_bhalf_truncate(v0 * s0);
r.template AsType<bhalf_t>()(Number<1>{}) = fp32_to_bhalf_truncate(v1 * s1);
dst.template AsType<bhalf2_t>()(i) = r.template AsType<bhalf2_t>()[Number<0>{}];
});
@@ -584,11 +523,10 @@ struct DequantPack8WithZp
#endif
}
// AIESW-32176: bf16 overload — mirrors the fp16 path but uses
// i4_to_bhalf4_zp_scale (fp32 intermediates; no native bhalf2 multiply on
// RDNA3). The truncate-rounding variant is a sibling element-op
// (DequantPack8WithZpTruncate) — see AIESW-32282 commentary on
// DequantPolicyFor<> in this file.
// AIESW-32176/32282: bf16 overload — uses i4_to_bhalf4_zp_scale (fp32
// intermediates because RDNA3 has no native bhalf2 multiply) with a
// bit-cast truncate for the trailing fp32->bf16 step. Locked in as the
// only bf16 rounding mode — see the fp32_to_bhalf_truncate comment.
__host__ __device__ constexpr void operator()(ck::bhalf8_t& y,
const ck::pk_i4x4_t& x,
const ck::bhalf2_t& s,
@@ -605,134 +543,7 @@ struct DequantPack8WithZp
y = result.template AsType<bhalf8_t>()[Number<0>{}];
#else
// Reference (slow) path: unscaled bf16 dequant + per-element
// (* scale) - scaled_zp via fp32.
vector_type<bhalf_t, 8> dst;
vector_type<pk_i4_t, 4> src{x};
vector_type<bhalf_t, 2> s_v;
s_v.template AsType<bhalf2_t>()(Number<0>{}) = s;
vector_type<bhalf_t, 2> z_v;
z_v.template AsType<bhalf2_t>()(Number<0>{}) = scaled_zp;
const float s0 = type_convert<float>(s_v.template AsType<bhalf_t>()[Number<0>{}]);
const float s1 = type_convert<float>(s_v.template AsType<bhalf_t>()[Number<1>{}]);
const float z0 = type_convert<float>(z_v.template AsType<bhalf_t>()[Number<0>{}]);
const float z1 = type_convert<float>(z_v.template AsType<bhalf_t>()[Number<1>{}]);
static_for<0, 4, 1>{}([&](auto i) {
auto v = type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[i]);
vector_type<bhalf_t, 2> v_v;
v_v.template AsType<bhalf2_t>()(Number<0>{}) = v;
const float v0 = type_convert<float>(v_v.template AsType<bhalf_t>()[Number<0>{}]);
const float v1 = type_convert<float>(v_v.template AsType<bhalf_t>()[Number<1>{}]);
vector_type<bhalf_t, 2> r;
r.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(v0 * s0 - z0);
r.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(v1 * s1 - z1);
dst.template AsType<bhalf2_t>()(i) = r.template AsType<bhalf2_t>()[Number<0>{}];
});
y = dst.template AsType<bhalf8_t>()[Number<0>{}];
#endif
}
constexpr const static bool is_pack8_invocable = true;
};
// AIESW-32282: truncate-bf16 variant of DequantPack8.
// fp16 path: identical to DequantPack8 (the fp16 bit-trick is already
// optimal and has no rounding chain to skip — i4_to_half4_scale).
// bf16 path: routes through i4_to_bhalf4_scale_truncate, which replaces
// the trailing IEEE round-to-nearest-even with a bit-cast >>16 truncate
// (saves v_add3_u32 + v_cmp_o_f32 + v_cndmask_b16 per nibble). Worst-case
// error 0.5 ULP of bf16, inside the 5e-3 op-test tolerance.
struct DequantPack8Truncate
{
    static constexpr const char* name = "DequantPack8Truncate";

    // Generic overload: declared only. The concrete fp16 / bf16 overloads
    // below are the ones with bodies.
    template <typename Y, typename X, typename Z>
    __host__ __device__ void operator()(Y& y, const X& x, const Z& z) const;

    // fp16 overload: forwards unchanged to DequantPack8.
    __host__ __device__ constexpr void
    operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const
    {
        DequantPack8{}(y, x, z);
    }

    // bf16 overload: dequantizes eight packed int4 values into y using the
    // two-lane scale z, with the fp32->bf16 step done by truncation.
    __host__ __device__ constexpr void
    operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x, const ck::bhalf2_t& z) const
    {
#if CK_USE_PK4_LAYOUT_SHUFFLE
        // Two helper calls, each producing four outputs: one on the packed
        // word as-is, one on the word shifted right by 8 bits.
        const int packed = bit_cast<int>(x);
        vector_type<bhalf_t, 8> out;
        out.template AsType<bhalf4_t>()(Number<0>{}) = i4_to_bhalf4_scale_truncate(packed, z);
        out.template AsType<bhalf4_t>()(Number<1>{}) =
            i4_to_bhalf4_scale_truncate(packed >> 8, z);
        y = out.template AsType<bhalf8_t>()[Number<0>{}];
#else
        // Reference (slow) path: convert each pk_i4 pair to unscaled bf16 via
        // type_convert, multiply by the scale in fp32, then truncate to bf16.
        vector_type<bhalf_t, 8> out;
        vector_type<pk_i4_t, 4> packed_pairs{x};

        vector_type<bhalf_t, 2> scale_pair;
        scale_pair.template AsType<bhalf2_t>()(Number<0>{}) = z;
        const float scale_lo =
            type_convert<float>(scale_pair.template AsType<bhalf_t>()[Number<0>{}]);
        const float scale_hi =
            type_convert<float>(scale_pair.template AsType<bhalf_t>()[Number<1>{}]);

        static_for<0, 4, 1>{}([&](auto idx) {
            auto unscaled = type_convert<bhalf2_t>(packed_pairs.template AsType<pk_i4_t>()[idx]);
            vector_type<bhalf_t, 2> unscaled_pair;
            unscaled_pair.template AsType<bhalf2_t>()(Number<0>{}) = unscaled;
            const float val_lo =
                type_convert<float>(unscaled_pair.template AsType<bhalf_t>()[Number<0>{}]);
            const float val_hi =
                type_convert<float>(unscaled_pair.template AsType<bhalf_t>()[Number<1>{}]);

            vector_type<bhalf_t, 2> scaled_pair;
            scaled_pair.template AsType<bhalf_t>()(Number<0>{}) =
                fp32_to_bhalf_truncate(val_lo * scale_lo);
            scaled_pair.template AsType<bhalf_t>()(Number<1>{}) =
                fp32_to_bhalf_truncate(val_hi * scale_hi);
            out.template AsType<bhalf2_t>()(idx) =
                scaled_pair.template AsType<bhalf2_t>()[Number<0>{}];
        });

        y = out.template AsType<bhalf8_t>()[Number<0>{}];
#endif
    }

    static constexpr bool is_pack8_invocable = true;
};
// AIESW-32282: truncate-bf16 variant of DequantPack8WithZp (asymmetric AWQ).
// Mirrors DequantPack8Truncate: fp16 delegates to DequantPack8WithZp,
// bf16 routes through i4_to_bhalf4_zp_scale_truncate.
struct DequantPack8WithZpTruncate
{
static constexpr const char* name = "DequantPack8WithZpTruncate";
template <typename Y, typename X, typename S, typename ZP>
__host__ __device__ void operator()(Y& y, const X& x, const S& s, const ZP& zp) const;
// fp16: delegate to DequantPack8WithZp.
__host__ __device__ constexpr void operator()(ck::half8_t& y,
const ck::pk_i4x4_t& x,
const ck::half2_t& s,
const ck::half2_t& scaled_zp) const
{
DequantPack8WithZp{}(y, x, s, scaled_zp);
}
__host__ __device__ constexpr void operator()(ck::bhalf8_t& y,
const ck::pk_i4x4_t& x,
const ck::bhalf2_t& s,
const ck::bhalf2_t& scaled_zp) const
{
#if CK_USE_PK4_LAYOUT_SHUFFLE
vector_type<bhalf_t, 8> result;
result.template AsType<bhalf4_t>()(Number<0>{}) =
i4_to_bhalf4_zp_scale_truncate(bit_cast<int>(x), s, scaled_zp);
result.template AsType<bhalf4_t>()(Number<1>{}) =
i4_to_bhalf4_zp_scale_truncate(bit_cast<int>(x) >> 8, s, scaled_zp);
y = result.template AsType<bhalf8_t>()[Number<0>{}];
#else
// Reference (slow) path with truncate-to-bf16.
// (* scale) - scaled_zp via fp32 with truncate-to-bf16.
vector_type<bhalf_t, 8> dst;
vector_type<pk_i4_t, 4> src{x};
vector_type<bhalf_t, 2> s_v;
@@ -787,12 +598,6 @@ struct DequantPolicyFor<DequantPack8WithZp>
using sym_type = DequantPack8;
using asym_type = DequantPack8WithZp;
};
// DequantPolicyFor specialization keyed on DequantPack8WithZpTruncate: pairs
// the two truncate element-ops, exposing DequantPack8Truncate as sym_type and
// DequantPack8WithZpTruncate as asym_type.
template <>
struct DequantPolicyFor<DequantPack8WithZpTruncate>
{
using sym_type = DequantPack8Truncate;
using asym_type = DequantPack8WithZpTruncate;
};
struct PassThroughPack2
{

View File

@@ -1484,9 +1484,9 @@ struct ThreadwiseTensorSliceTransfer_v4
// AIESW-32282: BElementOp (defaulted to ck::tensor_operation::element_wise::DequantPack8)
// selects which i4 -> {fp16,bf16} dequant element-op the pk_i4 path
// instantiates per-call. The default preserves the original hardcoded
// behavior, so existing callers compile bit-identically. Callers that want
// a different rounding policy (e.g. DequantPack8Truncate for bf16) pass it
// explicitly via the template parameter on the call site.
// behavior, so existing callers compile bit-identically. Callers that
// need a different dequant element-op pass it explicitly via the
// template parameter on the call site.
template <typename BElementOp = ck::tensor_operation::element_wise::DequantPack8,
typename SrcRefToOriginDisplacement,
typename DstOriginIdx,
@@ -1699,9 +1699,8 @@ struct ThreadwiseTensorSliceTransfer_v4
//
// AIESW-32282: BElementOp (defaulted to
// ck::tensor_operation::element_wise::DequantPack8WithZp) lets callers
// override the asymmetric dequant element-op per call (e.g. pass
// DequantPack8WithZpTruncate to skip the bf16 round-to-nearest-even
// chain). Default preserves prior hardcoded behavior bit-identically.
// override the asymmetric dequant element-op per call. Default preserves
// prior hardcoded behavior bit-identically.
template <typename BElementOp = ck::tensor_operation::element_wise::DequantPack8WithZp,
typename SrcRefToOriginDisplacement,
typename DstOriginIdx,