mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
Previously the wmma_cshuffle_v3 b_scale device-op's BElementwiseOperation
template parameter was carried as a struct member through the gridwise and
blockwise pipelines, but the per-nibble dequant call site in
ThreadwiseTensorSliceTransfer_v4::Run() hardcoded a local DequantPack8{} /
DequantPack8WithZp{} instance and ignored the template-supplied op.
This commit:
* Adds a new dedicated BDequantOp template parameter (defaulted to void
for upstream compatibility) to:
- device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp
- grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp
- grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp
- block/blockwise_gemm_pipeline_wmma_selector.hpp
- block/blockwise_gemm_pipeline_wmmaops_v1.hpp
The new slot is kept separate from BElementwiseOperation because the
latter is also consumed by the B global->LDS copy in
ThreadwiseTensorSliceTransfer_v3r1, which expects a 2-argument operator(),
whereas DequantPack8WithZp provides only 3- and 4-argument overloads.
* Adds a DequantPolicyFor<> trait in
element/unary_element_wise_operation.hpp that maps a "B dequant carrier"
type to the (sym, asym) pair the v1 Interwave pipeline must compile with.
It defaults to (void, void), so any non-dequant carrier (PassThrough
included) lowers to the existing behaviour.
* Updates ThreadwiseTensorSliceTransfer_v4::Run() (sym + asym overloads)
to take a templated BElementOp / BElementOpAsym and instantiate it
locally. Default template arguments preserve the prior DequantPack8{} /
DequantPack8WithZp{} behaviour bit-identically, so existing CK callers
(those that pass no BDequantOp) link unchanged.
* Adds bf16 truncate-via-bit-cast variants in
element/unary_element_wise_operation.hpp:
- fp32_to_bhalf_truncate(float)
- i4_to_bhalf4_scale_truncate(int, bhalf2_t)
- i4_to_bhalf4_zp_scale_truncate(int, bhalf2_t, bhalf2_t)
- DequantPack8Truncate / DequantPack8WithZpTruncate element-ops
These skip the IEEE round-to-nearest-even (RTE) chain that
type_convert<bhalf_t>(float) lowers to (v_add3_u32 applying the +0x7fff
bias, v_cmp_o_f32, and v_cndmask_b16 selecting 0x7fc0 for NaN
quietening), which accounts for ~1150 of the 3988 lines in the bf16 asym
ISA dump. Truncation raises the worst-case conversion error from 0.5 ULP
of bf16 (RTE) to just under 1 ULP, i.e. by 0.5 ULP (~4e-3 relative),
which is inside the W4A16 op-test TOL_REL=5e-3. The fp16 overloads of the
truncate variants delegate to the non-truncate path, since the fp16 bit
trick is already optimal and has no rounding chain to remove.
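The rounding difference between the two bf16 conversions can be illustrated on the host. This is a minimal C++ sketch, not CK code: the names bits_to_f32, bf16_to_f32, f32_to_bf16_trunc, f32_to_bf16_rte and max_rel_error are hypothetical, bf16 is modelled as a raw uint16_t, and the RTE path uses the common 0x7fff-bias-plus-LSB formulation with 0x7fc0 NaN quietening rather than reproducing the exact ISA sequence. It sweeps every low-half bit pattern under a fixed bf16 upper half and reports each conversion's worst relative error.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical host-side stand-ins, NOT CK's fp32_to_bhalf_truncate or
// type_convert<bhalf_t>. bf16 values are raw uint16_t bit patterns.

inline float bits_to_f32(std::uint32_t bits) {
    float f;
    std::memcpy(&f, &bits, sizeof f);
    return f;
}

inline float bf16_to_f32(std::uint16_t h) {
    return bits_to_f32(static_cast<std::uint32_t>(h) << 16);
}

// Truncation: keep the upper 16 bits of the fp32 pattern (round toward zero).
// Caveat: a NaN whose payload sits only in the low 16 bits truncates to Inf.
inline std::uint16_t f32_to_bf16_trunc(float f) {
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof u);
    return static_cast<std::uint16_t>(u >> 16);
}

// Round-to-nearest-even: add 0x7fff plus the LSB of the kept half, then shift.
// NaNs are replaced with the quiet NaN 0x7fc0.
inline std::uint16_t f32_to_bf16_rte(float f) {
    if (std::isnan(f)) return 0x7fc0;
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof u);
    u += 0x7fffu + ((u >> 16) & 1u);
    return static_cast<std::uint16_t>(u >> 16);
}

// Worst relative error of `conv` over all 65536 fp32 values that share the
// upper 16 bits `upper`, which must encode a normal (finite, nonzero-exponent)
// value below the top binade (RTE may round up to Inf there).
template <typename Conv>
double max_rel_error(std::uint16_t upper, Conv conv) {
    const std::uint16_t exp_bits = upper & 0x7F80u;
    assert(exp_bits != 0 && exp_bits != 0x7F80u);
    double worst = 0.0;
    for (std::uint32_t low = 0; low <= 0xFFFFu; ++low) {
        const float x = bits_to_f32((static_cast<std::uint32_t>(upper) << 16) | low);
        const double exact = x;
        const double got = bf16_to_f32(conv(x));
        worst = std::max(worst, std::fabs(got - exact) / std::fabs(exact));
    }
    return worst;
}
```

For inputs under the upper half 0x3F80 (values in [1, 1 + 2^-7)), the truncating conversion's worst relative error stays just below 2^-7 (just under 1 ULP), while the RTE conversion stays at or below 2^-8 (0.5 ULP). An input such as 0x3F80C000 shows the gap directly: truncation yields 0x3F80, RTE yields 0x3F81.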
CK analog of the Triton-side optimization in vLLM PR ROCm/vllm#953.
End-to-end measurement on RedHatAI/Qwen3-8B-quantized.w4a16 (bf16, native
dtype, --num-prompts 10 --output-len 1 --input-len 3968):
  Triton baseline: 3030 ms TTFT
  CK with RTE:     3278 ms TTFT (+8.2% LOSS vs Triton)
  CK with TRUNC:   2796 ms TTFT (-7.7% WIN vs Triton)
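The percentages are plain relative TTFT deltas against the Triton baseline. A minimal C++ helper makes the convention explicit; the name rel_delta_pct is hypothetical.

```cpp
// Relative TTFT change versus a baseline, in percent.
// Positive means slower than the baseline, negative means faster.
constexpr double rel_delta_pct(double ttft_ms, double baseline_ms) {
    return (ttft_ms - baseline_ms) / baseline_ms * 100.0;
}
```

rel_delta_pct(3278, 3030) is about +8.18 and rel_delta_pct(2796, 3030) about -7.72, matching the rounded figures; measured against the CK RTE build instead, TRUNC comes out at about -14.7%.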
The truncate variant closes the bf16 gap and overtakes Triton; RTE is a
regression on this hardware (gfx1151 / Strix Halo) because RDNA3 lacks
packed bf16 FMA and the rounding chain dominates the dequant pipeline.
Smoke test: all 16 (sym/asym x fp16/bf16 x G=32/G=128 x RTE/TRUNC)
combinations pass at TOL_REL=5e-3 in op_tests/test_gemm_w4a16.py.
Backward compatibility: with the new template argument left at its void
default, existing CK callers produce code bit-identical to before, because
the threadwise transfer's defaulted template argument resolves to
DequantPack8 / DequantPack8WithZp via std::conditional_t.
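The defaulting scheme described above (a void template argument, a DequantPolicyFor<>-style trait mapping a carrier to a (sym, asym) pair, and a std::conditional_t fallback) can be sketched in standalone C++17. This is a simplified sketch, not CK code: every type below (PassThroughLike, DequantPack8Like, DequantPack8WithZpLike, the *TruncLike variants, TruncCarrier, DequantPolicyForSketch, OpOrDefault, ThreadwiseTransferSketch) is hypothetical, the device/gridwise/blockwise layers are collapsed into one struct, and the call operators only mimic the arity of the real element-ops. The first two static_asserts also illustrate why the asym op cannot share the 2-argument BElementwiseOperation slot.

```cpp
// Requires C++17 (std::is_invocable_v, std::is_void_v).
#include <type_traits>

// Hypothetical stand-ins for the CK element-ops; only their call arity matters.
struct PassThroughLike {  // 2-arg shape, as a v3r1-style copy would call it
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const { y = static_cast<Y>(x); }
};
struct DequantPack8Like {  // symmetric: (dst, src, scale)
    template <typename Y, typename X, typename S>
    void operator()(Y& y, const X& x, const S& s) const { y = static_cast<Y>(x * s); }
};
struct DequantPack8WithZpLike {  // asymmetric: 3- and 4-arg overloads only
    template <typename Y, typename X, typename S>
    void operator()(Y& y, const X& x, const S& s) const { y = static_cast<Y>(x * s); }
    template <typename Y, typename X, typename S, typename Z>
    void operator()(Y& y, const X& x, const S& s, const Z& z) const {
        y = static_cast<Y>((x - z) * s);
    }
};
// Hypothetical truncating variants; they reuse the base call operators.
struct DequantPack8TruncLike : DequantPack8Like {};
struct DequantPack8WithZpTruncLike : DequantPack8WithZpLike {};

// Why a separate slot: the asym op cannot be called with the 2-arg shape.
static_assert(std::is_invocable_v<PassThroughLike, float&, const float&>);
static_assert(!std::is_invocable_v<DequantPack8WithZpLike, float&, const float&>);

// DequantPolicyFor<>-style trait: carrier -> (sym, asym), default (void, void).
template <typename Carrier>
struct DequantPolicyForSketch {
    using Sym = void;
    using Asym = void;
};
struct TruncCarrier {};  // hypothetical carrier that selects the truncate ops
template <>
struct DequantPolicyForSketch<TruncCarrier> {
    using Sym = DequantPack8TruncLike;
    using Asym = DequantPack8WithZpTruncLike;
};

// A void request falls back to the previously hardcoded op.
template <typename Requested, typename Fallback>
using OpOrDefault = std::conditional_t<std::is_void_v<Requested>, Fallback, Requested>;

template <typename BDequantOp = void>
struct ThreadwiseTransferSketch {
    using Policy = DequantPolicyForSketch<BDequantOp>;
    using SymOp = OpOrDefault<typename Policy::Sym, DequantPack8Like>;
    using AsymOp = OpOrDefault<typename Policy::Asym, DequantPack8WithZpLike>;

    // The op is default-constructed locally at the call site.
    template <typename Y, typename X, typename S>
    static void RunSym(Y& y, const X& x, const S& s) { SymOp{}(y, x, s); }

    template <typename Y, typename X, typename S, typename Z>
    static void RunAsym(Y& y, const X& x, const S& s, const Z& z) { AsymOp{}(y, x, s, z); }
};
```

With no argument, ThreadwiseTransferSketch<>::SymOp is DequantPack8Like, and a non-dequant carrier such as PassThroughLike also lands on the old defaults; only a carrier with a trait specialization, like TruncCarrier, swaps in different ops. Because the fallback is resolved at compile time, the default instantiation names exactly the same op type as the previously hardcoded one.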