composable_kernel/include/ck/tensor_operation/gpu/device
Matthias Gehre d43c474532 [CK] AIESW-32282: thread BElementwiseOperation dequant op down to ThreadwiseTensorSliceTransfer_v4 + bf16 truncate variant
Previously the wmma_cshuffle_v3 b_scale device-op's BElementwiseOperation
template parameter was carried as a struct member through the gridwise and
blockwise pipelines, but the per-nibble dequant call site in
ThreadwiseTensorSliceTransfer_v4::Run() hardcoded a local DequantPack8{} /
DequantPack8WithZp{} instance and ignored the template-supplied op.

This commit:

* Adds a new dedicated BDequantOp template parameter (defaulted to void
  for upstream compatibility) to:
    - device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp
    - grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp
    - grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp
    - block/blockwise_gemm_pipeline_wmma_selector.hpp
    - block/blockwise_gemm_pipeline_wmmaops_v1.hpp
  The new slot is separate from BElementwiseOperation because that one is
  also consumed at the B global->LDS copy via
  ThreadwiseTensorSliceTransfer_v3r1, which expects a 2-arg operator()
  while DequantPack8WithZp has 3/4-arg overloads only.

* Adds a DequantPolicyFor<> trait in
  element/unary_element_wise_operation.hpp that maps a "B dequant carrier"
  type to the (sym, asym) pair the v1 Interwave pipeline must compile with.
  Defaults to (void, void) so any non-dequant carrier (PassThrough
  included) lowers to existing behaviour.
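
  A minimal sketch of how such a trait could look, for illustration only.
  The member names Sym / Asym, the TruncateCarrier tag and the empty
  stand-in structs are assumptions, not the actual CK definitions:

```cpp
#include <cassert>
#include <type_traits>

// Empty stand-ins for the CK element-ops (hypothetical bodies).
struct PassThrough {};
struct DequantPack8Truncate {};
struct DequantPack8WithZpTruncate {};

// Hypothetical "B dequant carrier" tag that requests the truncate variants.
struct TruncateCarrier {};

// Primary template: any carrier that is not specialised maps to (void, void),
// which downstream code treats as "keep the existing behaviour".
template <typename Carrier>
struct DequantPolicyFor
{
    using Sym  = void; // assumed member name for the symmetric op
    using Asym = void; // assumed member name for the asymmetric (zero-point) op
};

// Specialisation: the carrier selects an explicit (sym, asym) pair.
template <>
struct DequantPolicyFor<TruncateCarrier>
{
    using Sym  = DequantPack8Truncate;
    using Asym = DequantPack8WithZpTruncate;
};

// PassThrough is not specialised, so it falls back to (void, void).
static_assert(std::is_void_v<DequantPolicyFor<PassThrough>::Sym>);
static_assert(std::is_void_v<DequantPolicyFor<PassThrough>::Asym>);
```

  With this shape, adding a new dequant variant means adding one
  specialisation; carriers that are not specialised are unaffected.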

* Updates ThreadwiseTensorSliceTransfer_v4::Run() (sym + asym overloads)
  to take a templated BElementOp / BElementOpAsym and instantiate it
  locally. Default arguments preserve the prior DequantPack8{} /
  DequantPack8WithZp{} behaviour bit-identically — existing CK callers
  (no BDequantOp passed) link unchanged.
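
  The defaulting mechanism can be sketched as follows. This is an
  illustration under assumptions, not the real Run() signature: ResolveOp,
  run_sym, NegateOp and the toy operator() bodies are hypothetical; only
  the void default and the std::conditional_t fallback to DequantPack8
  come from this commit message.

```cpp
#include <cassert>
#include <type_traits>

// Toy stand-in for the legacy op: maps an unsigned nibble to a signed value.
struct DequantPack8
{
    int operator()(int nibble) const { return nibble - 8; }
};

// Hypothetical caller-supplied op, used only to show that the override works.
struct NegateOp
{
    int operator()(int nibble) const { return -nibble; }
};

// If the requested op is void, fall back to the legacy op.
template <typename Requested, typename Legacy>
using ResolveOp = std::conditional_t<std::is_void_v<Requested>, Legacy, Requested>;

// Sketch of a call site: the op type is a defaulted template parameter and
// the op object is instantiated locally, as the commit describes for Run().
template <typename BElementOp = void>
int run_sym(int nibble)
{
    using Op = ResolveOp<BElementOp, DequantPack8>;
    return Op{}(nibble);
}

// Callers that pass nothing resolve to the legacy type.
static_assert(std::is_same_v<ResolveOp<void, DequantPack8>, DequantPack8>);
```

  Here run_sym(15) uses DequantPack8 and run_sym<NegateOp>(15) uses the
  supplied op, so existing call sites need no source changes.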

* Adds bf16 truncate-via-bit-cast variants in
  element/unary_element_wise_operation.hpp:
    - fp32_to_bhalf_truncate(float)
    - i4_to_bhalf4_scale_truncate(int, bhalf2_t)
    - i4_to_bhalf4_zp_scale_truncate(int, bhalf2_t, bhalf2_t)
    - DequantPack8Truncate / DequantPack8WithZpTruncate element-ops
  These skip the IEEE round-to-nearest-even chain that
  type_convert<bhalf_t>(float) lowers to (v_add3_u32 +0x7fff bias +
  v_cmp_o_f32 + v_cndmask_b16 0x7fc0 NaN-quietening), roughly 1150 of
  the 3988 lines in the bf16 asym ISA dump. Worst-case error vs RTE
  is 0.5 ULP of bf16 = ~4e-3 relative, well inside the W4A16 op-test
  TOL_REL=5e-3. fp16 overloads of the truncate variants delegate to the
  non-truncate path (the fp16 bit-trick is already optimal, no rounding
  chain to remove).

  CK analog of the Triton-side optimization in vLLM PR ROCm/vllm#953.
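
  For reference, a host-side sketch of the two fp32 -> bf16 conversions.
  The function names and the uint16_t representation are assumptions made
  for illustration; the CK versions operate on bhalf_t and compile to the
  GPU instructions listed above.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Truncate: keep the upper 16 bits of the fp32 pattern and drop the rest.
// No rounding and no NaN handling, so a NaN whose payload sits only in the
// low 16 bits would come out as infinity.
inline std::uint16_t fp32_to_bf16_truncate(float x)
{
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof bits);
    return static_cast<std::uint16_t>(bits >> 16);
}

// Round-to-nearest-even: add a 0x7fff bias plus the lowest kept bit so that
// ties round to even, and map NaN inputs to the quiet NaN 0x7fc0.
inline std::uint16_t fp32_to_bf16_rte(float x)
{
    if (x != x) // NaN check
        return 0x7fc0;
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof bits);
    bits += 0x7fffu + ((bits >> 16) & 1u);
    return static_cast<std::uint16_t>(bits >> 16);
}
```

  For x = 1.005859375f (bit pattern 0x3f80c000) the truncate path yields
  0x3f80 and the RTE path yields 0x3f81, so the two results differ by one
  bf16 code.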

End-to-end measurement on RedHatAI/Qwen3-8B-quantized.w4a16 (bf16, native
dtype, --num-prompts 10 --output-len 1 --input-len 3968):

  Triton baseline:  3030 ms TTFT
  CK with RTE:      3278 ms TTFT (+8.2% LOSS vs Triton)
  CK with TRUNC:    2796 ms TTFT (-7.7% WIN vs Triton)

The truncate variant closes the bf16 gap and overtakes Triton; RTE is a
regression on this hardware (gfx1151 / Strix Halo) because RDNA3 lacks
packed bf16 FMA and the rounding chain dominates the dequant pipeline.

Smoke test: all 16 (sym/asym x fp16/bf16 x G=32/G=128 x RTE/TRUNC)
combinations pass at TOL_REL=5e-3 in op_tests/test_gemm_w4a16.py.

Backward compatibility: with the new template arg left at its void
default, existing CK callers produce bit-identical code to before — the
threadwise transfer's defaulted template arg resolves to DequantPack8 /
DequantPack8WithZp via std::conditional_t.
2026-05-21 11:50:24 +02:00