mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
[CK] Add split-K support for ABQuantGrouped in block_scale_gemm (#4816)

## Changes

### Split-K support in `gemm_quant_kernel.hpp`

- **`SplitKBatchOffset`**: Added `aq_group_offset` and `aq_k_split_offset` fields (mirroring the existing `bq_*` fields for B) to track each split-K batch's position within the AQ scale tensor. For `ABQuantGrouped`, both offsets are computed by dividing `k_id * KRead` by `AQuantGroupSize::kK`.
- **`MakeAQBlockWindow`**: Added an `aq_group_offset` parameter (defaulting to 0 for non-split-K paths) so that the K-group dimension of the AQ tensor view covers only the K-groups remaining after the split-K offset. This is consistent with how `MakeBQBlockWindow` handles the BQ tensor.
- **`RunGemm`**: Threads `aq_k_split_offset` through to `MakeAQBlockWindow` when running in split-K mode.

### Constraints in `IsSupportedArgument()`

Four constraints gate split-K (`k_batch > 1`) for ABQuantGrouped:

1. **Mode check**: split-K is only allowed for `BQuantGrouped` (no preshuffle) or `ABQuantGrouped` (no `APreshuffleQuant`). Any other quant mode with `k_batch > 1` returns `false`.
2. **B quant group alignment**: `KRead` (the per-batch K slice) must be divisible by `BQuantGroupSize::kK`. Each batch must operate on complete B quantization groups; a partial group would require splitting a scale value across batches.
3. **A quant group alignment** (new, ABQuantGrouped only): `KRead` must also be divisible by `AQuantGroupSize::kK`, for the same reason applied to the AQ scale tensor.
4. **Minimum of 2 K-tile iterations per batch** (new): the software-pipelined GEMM kernels (CompV3 family) prefetch one tile ahead, so they require `per_batch_num_loop = KRead / KPerBlock >= 2`. When `KRead == KPerBlock` (i.e., each batch is exactly one tile), the prefetch reads into the next batch's memory region and produces incorrect results. Configurations where `K == k_batch * KPerBlock` are therefore rejected.
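The offset arithmetic and the four checks above can be modeled host-side in a few lines. This is a hypothetical sketch, not the CK implementation: `QuantMode`, `SplitKArgs`, `k_read`, `group_offset`, `remaining_groups` and `is_supported_split_k` are invented names, a single `preshuffle` flag stands in for the mode-specific preshuffle options, and it assumes `K` divides evenly by `k_batch` so that `KRead = K / k_batch` (the kernel's actual `KRead` computation may round differently).

```cpp
#include <cassert>

// Hypothetical host-side model of the split-K rules; not part of the CK API.
enum class QuantMode { BQuantGrouped, ABQuantGrouped, Other };

struct SplitKArgs {
    QuantMode mode;
    bool preshuffle;        // simplification: one flag for B preshuffle / APreshuffleQuant
    long long K;            // total reduction length
    long long k_batch;      // number of split-K batches
    long long k_per_block;  // KPerBlock: K extent of one pipeline tile
    long long aq_group_k;   // AQuantGroupSize::kK
    long long bq_group_k;   // BQuantGroupSize::kK
};

// Per-batch K slice. Assumption of this sketch: K splits evenly across batches.
long long k_read(const SplitKArgs& a) {
    assert(a.K % a.k_batch == 0);
    return a.K / a.k_batch;
}

// First quant K-group owned by batch k_id: (k_id * KRead) / group size.
long long group_offset(long long k_id, long long KRead, long long group_k) {
    return k_id * KRead / group_k;
}

// K-groups left in the scale-tensor view after skipping `offset` groups.
long long remaining_groups(long long K, long long group_k, long long offset) {
    return (K + group_k - 1) / group_k - offset;
}

bool is_supported_split_k(const SplitKArgs& a) {
    if (a.k_batch <= 1) return true;  // the constraints only gate k_batch > 1

    // 1. Mode check: BQuantGrouped or ABQuantGrouped, without preshuffle.
    const bool grouped = a.mode == QuantMode::BQuantGrouped ||
                         a.mode == QuantMode::ABQuantGrouped;
    if (!grouped || a.preshuffle) return false;

    if (a.K % a.k_batch != 0) return false;  // sketch-only: uneven splits not modeled
    const long long KRead = k_read(a);

    // 2. Each batch must cover whole B quantization groups.
    if (KRead % a.bq_group_k != 0) return false;

    // 3. ABQuantGrouped only: each batch must also cover whole A groups.
    if (a.mode == QuantMode::ABQuantGrouped && KRead % a.aq_group_k != 0) return false;

    // 4. The prefetching pipeline needs at least two K tiles per batch.
    return KRead / a.k_per_block >= 2;
}
```

For example, with `KPerBlock = 128` and 128-wide A and B groups, `K = 1024, k_batch = 4` gives `KRead = 256` (two tiles per batch) and passes, while `K = 512, k_batch = 4` gives `KRead == KPerBlock` and is rejected by constraint 4. In the passing case the per-batch AQ group offsets are 0, 2, 4 and 6, so the AQ view shrinks from 8 remaining groups for batch 0 to 2 for batch 3.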

### Example update (`run_gemm_quant_example.inc`)

Updated the comment above the `IsSupportedArgument` call to document that split-K is now supported for both `BQuantGrouped` (no preshuffle) and `ABQuantGrouped` (no `APreshuffleQuant`).

## Unit Tests

Two new test files cover the decode and prefill tile shapes across a range of `k_batch` values (2–8), data types (FP8, BF8) and B quantization group sizes (1×1×128 and 1×128×128):

- `test_gemm_quant_abquant_splitk_decode.cpp`: uses the decode tile shape (M=16, N=64, K_tile=256).
- `test_gemm_quant_abquant_splitk_prefill.cpp`: uses the prefill tile shape (M=128, N=128, K_tile=128).

Each test calls `run_test_with_validation`, which runs the kernel and checks the result against a CPU reference. Configurations excluded from the tests are annotated with comments explaining which constraint they violate (typically the `per_batch_num_loop >= 2` requirement).

## Prerequisites

This PR depends on #4429, which must be merged first.