Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-01 12:11:19 +00:00
[CK] Add split-K support for ABQuantGrouped in block_scale_gemm (#4816)

## Changes

### Split-K support in `gemm_quant_kernel.hpp`

- **`SplitKBatchOffset`**: Added `aq_group_offset` and `aq_k_split_offset` fields (mirroring the existing `bq_*` fields for B) to track each split-K batch's position within the AQ scale tensor. For `ABQuantGrouped`, both offsets are computed from `k_id * KRead` divided by `AQuantGroupSize::kK`.
- **`MakeAQBlockWindow`**: Added an `aq_group_offset` parameter (defaulting to 0 for non-split-K paths) so the AQ tensor view's K-group dimension reflects only the remaining K-groups from the split-K offset, consistent with how `MakeBQBlockWindow` handles the BQ tensor.
- **`RunGemm`**: Threads `aq_k_split_offset` through to `MakeAQBlockWindow` when in split-K mode.

### Constraints in `IsSupportedArgument()`

Four constraints gate split-K (`k_batch > 1`) for ABQuantGrouped:

1. **Mode check** — split-K is only allowed for `BQuantGrouped` (no preshuffle) or `ABQuantGrouped` (no `APreshuffleQuant`). Any other quant mode with `k_batch > 1` returns `false`.
2. **B quant group alignment** — `KRead` (the per-batch K slice) must be divisible by `BQuantGroupSize::kK`. Each batch must operate on complete B quantization groups; a partial group would require splitting a scale value across batches.
3. **A quant group alignment** (new, ABQuantGrouped only) — `KRead` must also be divisible by `AQuantGroupSize::kK`, for the same reason applied to the AQ scale tensor.
4. **Minimum 2 K-tile iterations per batch** (new) — the software-pipelined GEMM kernels (CompV3 family) prefetch one tile ahead, so they require `per_batch_num_loop = KRead / KPerBlock >= 2`. When `KRead == KPerBlock` (i.e. each batch is exactly one tile), the prefetch reads into the next batch's memory region and produces incorrect results. Configurations where `K == k_batch * KPerBlock` are therefore rejected.
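The offset arithmetic and the divisibility constraints above can be sketched as follows. This is an illustrative standalone sketch, not the kernel's actual code: the names `SplitKOffsets`, `make_aq_offsets`, and `is_splitk_supported` are hypothetical, the constants stand in for the kernel's template parameters, and the quant-mode check (constraint 1) is omitted for brevity.

```cpp
#include <cassert>

// Illustrative constants standing in for the kernel's template parameters.
constexpr int KPerBlock        = 128; // K tile processed per main-loop iteration
constexpr int AQuantGroupSizeK = 128; // AQuantGroupSize::kK
constexpr int BQuantGroupSizeK = 128; // BQuantGroupSize::kK

// Per-batch position within the AQ scale tensor for split-K batch `k_id`
// (in the spirit of the SplitKBatchOffset fields described above).
struct SplitKOffsets
{
    int aq_group_offset;   // offset in units of K quantization groups
    int aq_k_split_offset; // same quantity, threaded to the AQ block window
};

SplitKOffsets make_aq_offsets(int k_id, int KRead)
{
    // Both offsets are k_id * KRead expressed in quantization groups.
    const int groups = (k_id * KRead) / AQuantGroupSizeK;
    return {groups, groups};
}

// Gates 2-4 for ABQuantGrouped split-K, as described above.
bool is_splitk_supported(int K, int k_batch)
{
    if(k_batch == 1)
        return true;                  // no split-K, nothing to check
    const int KRead = K / k_batch;    // per-batch K slice
    if(KRead % BQuantGroupSizeK != 0) // (2) complete B quant groups per batch
        return false;
    if(KRead % AQuantGroupSizeK != 0) // (3) complete A quant groups per batch
        return false;
    if(KRead / KPerBlock < 2)         // (4) pipeline prefetches one tile ahead,
        return false;                 //     so each batch needs >= 2 K tiles
    return true;
}
```

For example, `K = 1024` with `k_batch = 2` gives `KRead = 512` and four K-tile iterations per batch, which passes; `K = 256` with `k_batch = 2` gives `KRead = KPerBlock = 128` and a single iteration, which is rejected by constraint 4.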
### Example update (`run_gemm_quant_example.inc`)

Updated the comment above the `IsSupportedArgument` call to document that split-K is now supported for both `BQuantGrouped` (no preshuffle) and `ABQuantGrouped` (no `APreshuffleQuant`).

## Unit Tests

Two new test files cover decode and prefill tile shapes across a range of `k_batch` values (2–8), data types (FP8, BF8), and quantization group sizes (1×1×128 and 1×128×128 for B):

- `test_gemm_quant_abquant_splitk_decode.cpp` — uses the decode tile shape (M=16, N=64, K_tile=256)
- `test_gemm_quant_abquant_splitk_prefill.cpp` — uses the prefill tile shape (M=128, N=128, K_tile=128)

Each test calls `run_test_with_validation`, which runs the kernel and checks correctness against a CPU reference. Configurations excluded from the tests are annotated with comments explaining which constraint they violate (typically the `per_batch_num_loop >= 2` requirement).

## Prerequisites

This PR depends on #4429, which must be merged first.
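The validation step in these tests amounts to an elementwise tolerance comparison of kernel output against the CPU reference. A hypothetical sketch of such a check (`all_close` is illustrative, not the actual ck_tile test utility):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical elementwise tolerance check in the spirit of the tests'
// CPU-reference validation; not the actual ck_tile test harness.
bool all_close(const std::vector<float>& out,
               const std::vector<float>& ref,
               float rtol, float atol)
{
    if(out.size() != ref.size())
        return false;
    for(std::size_t i = 0; i < out.size(); ++i)
    {
        // Mixed absolute/relative bound, scaled by the reference magnitude.
        if(std::fabs(out[i] - ref[i]) > atol + rtol * std::fabs(ref[i]))
            return false;
    }
    return true;
}
```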
Quant GEMM Matrix Multiplication
This folder contains examples of quant GEMMs using the ck_tile tile-programming implementation.
- AQuant kernel: blocks of the A matrix share scales; custom GEMM pipeline
- BQuant kernel: blocks of the B matrix share scales; custom GEMM pipeline
- Row- and column-wise scaled: all elements in a row of A and all elements in a column of B share one quantization scale; the element-wise scaling is applied in the epilogue
- Tensor-wise scaled: a single scalar scale is shared across the whole A or B tensor
Quantization Mode Comparison
| Quant Mode | A Matrix Organization | A Scale Shape | B Matrix Organization | B Scale Shape |
|---|---|---|---|---|
| AQuant | Blocks along K dimension; each M×GroupSize block shares one scale | [M, K/GroupSize] | Not quantized | N/A |
| BQuant | Not quantized | N/A | Blocks along K dimension; each GroupSize×N block shares one scale | [K/GroupSize, N] |
| RowColQuant | Per-row quantization; all K elements in each row share one scale | [M, 1] | Per-column quantization; all K elements in each column share one scale | [1, N] |
| TensorQuant | Tensor-wise quantization; all M×K elements share one scale | [1] | Tensor-wise quantization; all K×N elements share one scale | [1] |
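As a concrete illustration of the block-quantization semantics in the table, a minimal CPU reference for the BQuant mode might look as follows. This is a hypothetical sketch, not the ck_tile implementation; it assumes row-major A and B held as plain floats, with B's scale tensor shaped [K/GroupSize, N] as in the table.

```cpp
#include <cassert>
#include <vector>

// Minimal CPU reference for BQuant semantics: each GroupSize x N block of B
// shares one scale, so the B scale tensor has shape [K/GroupSize, N].
// Illustrative only; not the ck_tile kernel or its test reference.
std::vector<float> bquant_gemm_ref(const std::vector<float>& A,       // [M, K] row-major
                                   const std::vector<float>& B,       // [K, N] row-major, pre-scaling values
                                   const std::vector<float>& b_scale, // [K/GroupSize, N]
                                   int M, int N, int K, int GroupSize)
{
    std::vector<float> C(M * N, 0.0f);
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
            {
                // The scale for B[k][n] is shared by its whole K-group k/GroupSize.
                const float s = b_scale[(k / GroupSize) * N + n];
                C[m * N + n] += A[m * K + k] * (B[k * N + n] * s);
            }
    return C;
}
```

For M=N=1, K=4, GroupSize=2, A={1,1,1,1}, B={1,2,3,4}, and scales {2,10}, the result is (1+2)·2 + (3+4)·10 = 76.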
Features
- Preshuffled GEMM: shuffle the B (weight) matrix into the warp layout so the GEMM computation can bypass shared memory. The best-performing GEMM variant.
- TransposeC: transpose the C output layout to get the best coalesced scale reads.
- Preshuffled Quant: preshuffle the quant (scale) tensor so multiple quant warp blocks can be loaded along the selected dimension.
- Precision: supports fp16, bf16, fp8, bf8, int4 (B matrix only), and uint8, where each byte is split into two fp4 values in the pipeline (B matrix only).
- Validation: CPU/GPU validation and error-tolerance options.
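For the uint8 precision mentioned in the feature list, each byte of B carries two packed 4-bit values that the pipeline unpacks. A hedged sketch of the unpacking step (which nibble comes first, and how each nibble decodes to an fp4 value, are assumptions here; the actual ck_tile pipeline may differ):

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Split one packed uint8 into its two 4-bit fields.
// Nibble order is an assumption for illustration; the real pipeline
// may order and decode the fp4 values differently.
std::pair<uint8_t, uint8_t> unpack_two_nibbles(uint8_t packed)
{
    const uint8_t lo = packed & 0x0F;        // low nibble
    const uint8_t hi = (packed >> 4) & 0x0F; // high nibble
    return {lo, hi};
}
```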
build

```bash
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx942) or leave it blank
../script/cmake-ck-dev.sh ../ <arch>
# Compile the quant kernels
make tile_example_gemm_quant -j
```

This produces an executable at `build/bin/tile_example_gemm_quant`.
example
args:
```
-h               Print help message (default:false)
-m               m dimension (default:3840)
-n               n dimension (default:4096)
-k               k dimension (default:2048)
-a_layout        A tensor data layout - R for Row or C for Column (default:R)
-b_layout        B tensor data layout - R for Row or C for Column (default:C)
-bq_layout       Bq tensor data layout - R for Row or C for Column (default:C)
-c_layout        C tensor data layout - R for Row or C for Column (default:R)
-stride_a        Tensor A stride (default:0)
-stride_q        Tensor AQ stride (default:0)
-stride_b        Tensor B stride (default:0)
-stride_c        Tensor C stride (default:0)
-v               0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1)
-prec            Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for BQuant: fp8, bf8, fp8i4, bf8i4, mxbf16bf16, mxbf16bf8, or mxbf16fp4 (default for both: fp8)
-warmup          Number of iterations before benchmarking the kernel (default:50)
-repeat          Number of iterations to benchmark the kernel (default:1000)
-timer           gpu: gpu timer, cpu: cpu timer (default:gpu)
-split_k         SplitK value (default:1)
-device          Device id that will be used to run the kernel (default:0)
-init            0: random, 1: linear, 2: constant(1) (default:0)
-flush_cache     Flush cache before running the kernel (default:true)
-rotating_count  Rotating count (default:1000)
-quant_mode      Choose aquant, bquant, tensor or rowcol (default:bquant)
-preshuffleb     Enable preshuffle of tensor B (default:false)
-preshufflequant Enable preshuffle of quant tensor (default:false)
-group_size      Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128 (default:1x1x128)
```
The user needs to select the correct config mapping for each quant mode:

| | quant_mode as runtime argument | Corresponding cpp file | GemmConfig at the top of the cpp file |
|---|---|---|---|
| For selecting AQuant | aquant | gemm_aquant_quantgrouped.cpp | GemmConfigQuantDecode |
| For selecting AQuant with preshuffle quant | aquant | gemm_aquant_quantgrouped_preshufflequant.cpp | GemmConfigPreshuffleQuantDecode |
| For selecting BQuant | bquant | gemm_bquant_quantgrouped_<prec_type>.cpp | GemmConfigQuantDecode (or) GemmConfigQuantPrefill |
| For selecting BQuant with preshuffle quant | bquant | gemm_bquant_quantgrouped_preshufflequant.cpp | GemmConfigPreshuffleQuantDecode (or) GemmConfigPreshuffleBQuantPrefill |
| For selecting PreShuffle B with BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb.cpp | GemmConfigPreshuffleB_BQuant_Decode (or) GemmConfigPreshuffleB_BQuant_Prefill |
| For selecting PreShuffle B with preshuffle BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp | GemmConfigPreshuffleB_PreshuffleBQuant_Decode (or) GemmConfigPreshuffleB_PreshuffleBQuant_Prefill |
| For selecting RowCol quant | rowcol | gemm_quant_rowcol.cpp | GemmConfigRowColQuant |