mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 04:37:02 +00:00
Adds an optional asymmetric int4 dequant path to the wmma_cshuffle_v3
b_scale GEMM device-op. The existing path computes (nibble - 8) * scale
(symmetric uint4b8); the new path computes (nibble - zp) * scale with a
per-group zero point zp (matching AWQ and other per-group-zero-point
checkpoints). It is implemented via the algebraic identity:
(nibble - zp) * scale == (nibble - 8) * scale - (zp - 8) * scale
Callers precompute scaled_zp = (zp - 8) * scale once at weight load and
pass the resulting [N, K/G] tensor (G = quantization group size)
alongside the existing scale tensor. The hot loop pays one extra fp16
vector subtract per dequant pack.
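A minimal host-side sketch of the identity and the weight-load precompute, assuming float in place of fp16 and row-major [N, K/G] std::vector storage. precompute_scaled_zp, dequant_sym and dequant_asym are hypothetical names used only for illustration, not CK APIs.

```cpp
// Sketch of the zero-point identity behind the asymmetric path.
// Assumptions: float stands in for fp16, the [N, K/G] tensors are row-major
// std::vectors, and every name here is hypothetical (not a CK API).
#include <cassert>
#include <cstddef>
#include <vector>

// Done once at weight load: scaled_zp = (zp - 8) * scale, element-wise over
// the [N, K/G] grid, i.e. the same shape and layout as the scale tensor.
std::vector<float> precompute_scaled_zp(const std::vector<int>& zp,
                                        const std::vector<float>& scale) {
    assert(zp.size() == scale.size());
    std::vector<float> scaled_zp(scale.size());
    for (std::size_t i = 0; i < scale.size(); ++i)
        scaled_zp[i] = static_cast<float>(zp[i] - 8) * scale[i];
    return scaled_zp;
}

// Existing symmetric dequant: (nibble - 8) * scale.
float dequant_sym(int nibble, float scale) {
    return static_cast<float>(nibble - 8) * scale;
}

// Asymmetric dequant in hot-loop form: the symmetric result minus one term.
float dequant_asym(int nibble, float scale, float scaled_zp) {
    return dequant_sym(nibble, scale) - scaled_zp;
}
```

With power-of-two scales both forms are exact in float; in fp16 the two-step form can differ from a direct (nibble - zp) * scale in the last bits due to rounding.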
Changes are additive and gated on the optional `p_b_zero_point` being
non-null at the device-op level (defaults to nullptr) and on
`BZeroPointStruct` being non-Empty at the gridwise/blockwise level
(defaults to Empty). Symmetric callers see no signature change and
compile bit-identically.
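The two-level gating can be sketched as below, under simplifying assumptions: Empty, ZeroPointView, Argument, make_argument, dequant and run are hypothetical stand-ins for the device-op and gridwise types, and a single scalar dequant stands in for the kernel. The runtime nullptr check chooses between two instantiations, and the Empty instantiation contains no zero-point code.

```cpp
// Sketch of the two-level gating: a runtime nullptr check at the device-op
// level selects which compile-time instantiation runs, and BZeroPointStruct
// defaults to Empty so the symmetric instantiation is unchanged.
// Assumptions: all types and functions here are hypothetical stand-ins,
// not CK's actual types or signatures.
#include <cassert>
#include <type_traits>

struct Empty {};

struct ZeroPointView {              // wraps the precomputed scaled_zp tensor
    const float* p;
    float load(int g) const { return p[g]; }
};

struct Argument {
    const float* p_b_scale;
    const float* p_b_zero_point;    // nullptr selects the symmetric path
};

// Trailing defaulted argument: existing call sites compile unchanged.
Argument make_argument(const float* p_b_scale,
                       const float* p_b_zero_point = nullptr) {
    return Argument{p_b_scale, p_b_zero_point};
}

template <typename BZeroPointStruct = Empty>
float dequant(const Argument& arg, int nibble, int g,
              const BZeroPointStruct& b_zp = {}) {
    float v = static_cast<float>(nibble - 8) * arg.p_b_scale[g];
    if constexpr (!std::is_same_v<BZeroPointStruct, Empty>)
        v -= b_zp.load(g);          // compiled only into the asymmetric variant
    return v;
}

// Runtime pointer check picks the instantiation.
float run(const Argument& arg, int nibble, int g) {
    if (arg.p_b_zero_point != nullptr)
        return dequant(arg, nibble, g, ZeroPointView{arg.p_b_zero_point});
    return dequant(arg, nibble, g);  // BZeroPointStruct = Empty
}
```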
Files:
* element/unary_element_wise_operation.hpp
+ `i4_to_half4_zp_scale()`: pk_i4 -> half4 dequant with per-group
scaled_zp subtract (mirrors `i4_to_half4_scale`).
+ `DequantPack8WithZp`: element-op wrapping the above; mirrors
`DequantPack8`. `is_pack8_invocable = true`.
* thread/threadwise_tensor_slice_transfer.hpp
+ `ThreadwiseTensorSliceTransfer_v4::Run()` overload that takes both
`scale` and `scaled_zp` instead of just `scale`, and dispatches to
`DequantPack8WithZp` on the pk_i4 path. Non-pk_i4 SrcData falls
back to the plain dequant path with scaled_zp ignored.
* block/blockwise_gemm_pipeline_wmmaops_v1.hpp
+ `BZeroPointStruct = Empty` defaulted template arg on Run().
+ `if constexpr(HasBZp)` branch picks the new threadwise Run() overload
and runs an extra GlobalLoad cycle for the zero-point struct,
synchronized with the existing BScaleStruct cadence.
* block/blockwise_gemm_pipeline_wmmaops_v3.hpp
+ `BZeroPointStruct = Empty` defaulted template arg on Run() for
signature compatibility with v1. v3 (Intrawave) does not implement
the asymmetric path; a non-Empty BZeroPointStruct compiles but is
silently ignored. Callers are responsible for routing asymmetric
configs to v1 (Interwave).
* grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp
+ Forward `b_zero_point_struct` (defaulted Empty) from Run() to the
blockwise pipeline's Run().
* grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp
+ `p_b_zero_point_grid` field on Argument (defaulted nullptr).
+ Run() builds a BZeroPoint struct from `p_b_zero_point_grid` when
non-null (mirroring the BScale struct construction) and threads it
through to Base::Run; symmetric path stays bit-identical when
p_b_zero_point_grid is nullptr.
* device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp
+ Optional `p_b_zero_point = nullptr` trailing arg on `MakeArgument`.
Stored on the Argument and propagated through the Run path. Existing
symmetric callers see no signature change.
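A scalar sketch of what the new element-op and the threadwise dispatch compute, assuming float in place of half, a low-nibble-first layout (CK's real pk_i4 layout and fp16 bit tricks differ), one quantization group per pack of 8, and a plain scale multiply as the non-pk_i4 fallback. pk_i4_t, dequant_pack8_with_zp and threadwise_dequant are hypothetical names, not the CK identifiers.

```cpp
// Scalar sketch of the pack-8 dequant with zero point and the SrcData dispatch.
// Assumptions: float stands in for half; pk_i4_t is a stand-in for CK's
// packed-int4 type; nibbles are unpacked low-first; all 8 values share one
// group, hence one scale and one scaled_zp. All names are hypothetical.
#include <array>
#include <cassert>
#include <cstdint>
#include <type_traits>

using Float8 = std::array<float, 8>;

struct pk_i4_t { std::uint32_t data; };  // 8 nibbles = one pack-8 unit

// Analogue of DequantPack8WithZp: the symmetric pack math plus one
// per-group subtract of the precomputed scaled_zp.
Float8 dequant_pack8_with_zp(std::uint32_t packed, float scale, float scaled_zp) {
    Float8 out{};
    for (int i = 0; i < 8; ++i) {
        int nibble = static_cast<int>((packed >> (4 * i)) & 0xFu);
        out[i] = static_cast<float>(nibble - 8) * scale - scaled_zp;
    }
    return out;
}

// Analogue of the new Run() overload's dispatch: pk_i4 sources take the
// zero-point path; any other SrcData is scaled plainly and scaled_zp is ignored.
template <typename SrcData>
Float8 threadwise_dequant(const SrcData* src, float scale, float scaled_zp) {
    if constexpr (std::is_same_v<SrcData, pk_i4_t>) {
        return dequant_pack8_with_zp(src->data, scale, scaled_zp);
    } else {
        (void)scaled_zp;  // not meaningful for non-pk_i4 sources
        Float8 out{};
        for (int i = 0; i < 8; ++i) out[i] = static_cast<float>(src[i]) * scale;
        return out;
    }
}
```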
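The blockwise change (a defaulted Empty template argument, an if constexpr(HasBZp) branch, and a zero-point load on the same cadence as the scale load) might look roughly like the scalar loop below. Empty, ScaleStream and run_k_loop are hypothetical; one iteration stands in for one K step, a new group starts every iters_per_group steps, and the plain accumulate stands in for the WMMA.

```cpp
// Sketch of the blockwise pattern: BZeroPointStruct = Empty by default, and
// an if constexpr(HasBZp) branch that refreshes the zero point exactly when
// the scale is refreshed. All names are hypothetical stand-ins.
#include <cassert>
#include <type_traits>
#include <vector>

struct Empty {};

struct ScaleStream {                 // per-group values for one column n
    std::vector<float> values;
    float load(int group) const { return values[group]; }
};

template <typename BZeroPointStruct = Empty>
float run_k_loop(const std::vector<int>& nibbles, const ScaleStream& b_scale,
                 int iters_per_group, const BZeroPointStruct& b_zp = {}) {
    constexpr bool HasBZp = !std::is_same_v<BZeroPointStruct, Empty>;
    float acc = 0.f, scale = 0.f, scaled_zp = 0.f;
    for (int k = 0; k < static_cast<int>(nibbles.size()); ++k) {
        if (k % iters_per_group == 0) {          // group boundary
            scale = b_scale.load(k / iters_per_group);
            if constexpr (HasBZp)                // extra load, same cadence
                scaled_zp = b_zp.load(k / iters_per_group);
        }
        float w = static_cast<float>(nibbles[k] - 8) * scale;
        if constexpr (HasBZp) w -= scaled_zp;
        acc += w;                                // stand-in for the WMMA
    }
    return acc;
}
```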
Backward compatibility verified: with p_b_zero_point=nullptr / BZeroPointStruct=Empty,
the gridwise + blockwise dispatch produces results bit-identical to the
symmetric pre-change path. Tested with the standalone
example_gemm_wmma_fp16_pk_i4_v3_b_scale runner (CPU-reference verify
mode) on Strix Halo (gfx1151) at M=2048 N=19456 K=2560 (gate_up_proj
prefill shape) and at M=2048 for qkv (N=6144), o_proj (N=2560 K=4096),
and down_proj (N=2560 K=9728); all pass.
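To reproduce the kind of comparison a CPU-reference verify mode performs, a small host-only sketch can compare a direct (B - zp) * scale reference GEMM against the identity form the kernel uses. It assumes float math, one unpacked nibble per int, and G dividing K; Problem, gemm_reference, gemm_identity and all_close are hypothetical and are not taken from the CK example.

```cpp
// Host-only sketch: reference GEMM with direct zero-point dequant vs. the
// identity form (B - 8) * scale - scaled_zp. Assumptions: float math, B as
// one nibble (0..15) per int in [N, K] row-major, G divides K.
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct Problem {
    int M, N, K, G;
    std::vector<float> a;       // [M, K]
    std::vector<int> b;         // [N, K]
    std::vector<float> scale;   // [N, K/G]
    std::vector<int> zp;        // [N, K/G]
};

// C[m][n] = sum_k A[m][k] * dequant(B[n][k]); use_identity picks the form.
std::vector<float> gemm(const Problem& p, bool use_identity) {
    const int groups = p.K / p.G;
    std::vector<float> c(static_cast<std::size_t>(p.M) * p.N, 0.f);
    for (int m = 0; m < p.M; ++m)
        for (int n = 0; n < p.N; ++n)
            for (int k = 0; k < p.K; ++k) {
                const int g = n * groups + k / p.G;
                const float s = p.scale[g];
                const int q = p.b[n * p.K + k];
                const float w = use_identity
                    ? static_cast<float>(q - 8) * s - static_cast<float>(p.zp[g] - 8) * s
                    : static_cast<float>(q - p.zp[g]) * s;
                c[m * p.N + n] += p.a[m * p.K + k] * w;
            }
    return c;
}

std::vector<float> gemm_reference(const Problem& p) { return gemm(p, false); }
std::vector<float> gemm_identity(const Problem& p) { return gemm(p, true); }

// Element-wise comparison with combined absolute/relative tolerance.
bool all_close(const std::vector<float>& x, const std::vector<float>& y,
               float rtol = 1e-3f, float atol = 1e-3f) {
    if (x.size() != y.size()) return false;
    for (std::size_t i = 0; i < x.size(); ++i)
        if (std::fabs(x[i] - y[i]) > atol + rtol * std::fabs(y[i])) return false;
    return true;
}
```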
Use case: vLLM W4A16 prefill GEMM dispatch on Strix Halo (gfx1151)
for AWQ-quantized models (e.g. Qwen/Qwen3-4B-AWQ).
JIRA: AIESW-32176.