composable_kernel/example/ck_tile/38_block_scale_gemm
Sami Remes ad4e2e7624 [rocm-libraries] ROCm/rocm-libraries#7199 (commit 23f7320)
[CK_TILE] [QuantGEMM] Fix SplitK tail handling and other
 improvements (#7199)

This pull request introduces improved and more robust split-K support
for quantized GEMM. The main changes add runtime validation, utility
functions for split-K batch calculations, pointer offset handling for
split-K in grouped kernels, and enhanced support for various tensor
layouts. The changes also improve error handling and provide more
flexibility for runtime tail handling in split-K pipelines.

**Split-K Support and Validation Enhancements:**

* Added runtime validation to ensure `k_batch` is a positive integer and
that split-K configurations do not produce empty final batches or
mismatched pipeline tails, with detailed error messages and logging for
misconfiguration.
* Introduced utility functions `get_splitk_batch_k_read` and
`get_splitk_last_batch_k` to compute per-batch K read sizes and handle
split rounding, ensuring correct and consistent split-K batch
partitioning.
* Changed the default value of `k_batch` in `QuantGemmHostArgs` to 1 (no
split-K) for safer default behavior.
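
The partitioning and validation described above can be sketched as follows. This is a minimal illustration, not the ck_tile implementation: `SplitKPlan`, `plan_splitk`, and the `k_granularity` parameter (assumed to be the pipeline's K tile size) are hypothetical names, and the real `get_splitk_batch_k_read` and `get_splitk_last_batch_k` may use different signatures and rounding rules.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical result type: how K is divided among the split-K batches.
struct SplitKPlan
{
    std::int64_t k_read;       // K elements read by every batch except the last
    std::int64_t last_batch_k; // K elements left for the final batch
};

// Round the per-batch K extent up to a multiple of k_granularity (assumed to
// be the pipeline's K tile size) and give the remainder to the last batch.
inline SplitKPlan plan_splitk(std::int64_t K, std::int64_t k_batch, std::int64_t k_granularity)
{
    if(k_batch < 1)
        throw std::invalid_argument("k_batch must be a positive integer, got " +
                                    std::to_string(k_batch));
    if(K < 1 || k_granularity < 1)
        throw std::invalid_argument("K and k_granularity must be positive");

    const std::int64_t chunk  = k_batch * k_granularity;
    const std::int64_t k_read = (K + chunk - 1) / chunk * k_granularity;
    const std::int64_t last   = K - k_read * (k_batch - 1);

    // Rounding k_read up can consume all of K before the final batch starts.
    if(last <= 0)
        throw std::invalid_argument("split-K leaves an empty final batch for K=" +
                                    std::to_string(K) + ", k_batch=" + std::to_string(k_batch));
    return {k_read, last};
}
```

Under these assumptions, `plan_splitk(2048, 3, 128)` gives two batches of 768 and a final batch of 512, while `plan_splitk(1024, 5, 256)` is rejected because the first four batches already cover all of K.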

**Pointer Offsets and Grouped Kernel Handling:**

* Updated `QuantGroupedGemmKernel` to apply split-K per-batch offsets to
all input pointers, mirroring the behavior of non-grouped kernels and
ensuring correctness for split-K launches.
* Modified AQ tensor view handling to correctly reflect the remaining
K-groups from the split-K batch's offset position, improving accuracy
for split-K in grouped kernels.
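
As a rough illustration of per-batch pointer offsets, the sketch below computes element offsets for A and B and the AQ window for a batch starting at `k_start`. All names here (`Layout`, `a_k_offset`, `b_k_offset`, `AqWindow`, `aq_window`) are hypothetical, and the layout conventions are assumptions (A is logically M x K, B is logically K x N, AQ holds one scale per K-group); the kernel's actual offset logic may differ.

```cpp
#include <cstdint>

enum class Layout { RowMajor, ColumnMajor };

// Element offset into A (logical M x K) for a batch starting at k_start.
// Row-major A has K contiguous; column-major A advances one leading-dimension
// stride per K element.
inline std::int64_t a_k_offset(Layout layout, std::int64_t k_start, std::int64_t stride_a)
{
    return layout == Layout::RowMajor ? k_start : k_start * stride_a;
}

// Element offset into B (logical K x N). Column-major B has K contiguous;
// row-major B advances one leading-dimension stride per K element.
inline std::int64_t b_k_offset(Layout layout, std::int64_t k_start, std::int64_t stride_b)
{
    return layout == Layout::ColumnMajor ? k_start : k_start * stride_b;
}

// Window of AQ scale groups visible to a batch: the first group it touches
// and how many K-groups remain from that position to the end of K.
struct AqWindow
{
    std::int64_t first_group;
    std::int64_t remaining_groups;
};

inline AqWindow aq_window(std::int64_t K, std::int64_t k_start, std::int64_t group_size)
{
    const std::int64_t total_groups = (K + group_size - 1) / group_size;
    const std::int64_t first_group  = k_start / group_size;
    return {first_group, total_groups - first_group};
}
```

For example, with K = 2048, a group size of 128, and a batch starting at `k_start = 768`, the AQ view would start at group 6 and span the 10 remaining groups rather than all 16.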

**Pipeline and Layout Flexibility:**

* Added support for runtime selection of split-K tail handling via a new
template parameter `RuntimeSplitKTail_`, with new helper methods to
dispatch GEMM pipelines accordingly.
* Improved handling for tensor layout cases, including preshuffled B and
both row-major and column-major AQ layouts, ensuring correct pointer
arithmetic and compatibility checks.
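
One common way to select a compile-time pipeline variant from a runtime value is a small dispatcher, sketched below. This only illustrates the general pattern: `run_pipeline`, `dispatch_pipeline`, `k_per_block`, and `prefetch_stages` are hypothetical, and the criterion used here (whether the batch has more K loops than prefetch stages) is an assumption rather than the rule used by `RuntimeSplitKTail_`.

```cpp
#include <cstdint>
#include <string>

// Stand-in for a GEMM pipeline specialised at compile time (hypothetical).
template <bool HasHotLoop>
std::string run_pipeline(std::int64_t num_loop)
{
    if constexpr(HasHotLoop)
        return "hot-loop:" + std::to_string(num_loop);
    else
        return "tail-only:" + std::to_string(num_loop);
}

// With split-K, the last batch can cover a different K extent than the other
// batches, so the variant is chosen per batch at runtime and then forwarded
// to the matching template instantiation.
inline std::string dispatch_pipeline(std::int64_t batch_k,
                                     std::int64_t k_per_block,
                                     std::int64_t prefetch_stages)
{
    const std::int64_t num_loop = (batch_k + k_per_block - 1) / k_per_block;
    if(num_loop > prefetch_stages)
        return run_pipeline<true>(num_loop);
    return run_pipeline<false>(num_loop);
}
```

Both instantiations are compiled, and the branch is taken once per launch or per batch, so the inner loop still benefits from the compile-time specialisation.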
2026-06-05 11:41:49 +00:00

Quant GEMM Matrix Multiplication

This folder contains examples of quant GEMMs using the ck_tile tile-programming implementation.

  • AQuant kernel with blocks of A matrix sharing scales: custom GEMM pipeline
  • BQuant kernel with blocks of B matrix sharing scales: custom GEMM pipeline
  • Row- and column-wise scaled: all elements in a row of the A matrix share one scale, all elements in a column of the B matrix share one scale, and the element-wise scaling is applied in the epilogue.
  • Tensor-wise scaled: Share the same scalar scale across the whole tensor of A or B
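
To make the row- and column-wise case concrete, here is a minimal host-side sketch of the epilogue scaling, assuming a row-major M x N accumulator, one A scale per row, and one B scale per column. The function name `apply_rowcol_scales` is hypothetical and is not part of ck_tile.

```cpp
#include <cstddef>
#include <vector>

// acc:     unscaled GEMM accumulator, M x N, row-major
// a_scale: one scale per row of A (length M)
// b_scale: one scale per column of B (length N)
inline std::vector<float> apply_rowcol_scales(const std::vector<float>& acc,
                                              const std::vector<float>& a_scale,
                                              const std::vector<float>& b_scale)
{
    const std::size_t M = a_scale.size();
    const std::size_t N = b_scale.size();
    std::vector<float> c(M * N);
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
            c[m * N + n] = acc[m * N + n] * a_scale[m] * b_scale[n];
    return c;
}
```

Because each scale is constant along K, it factors out of the dot product, which is why the scaling can wait until the epilogue. The tensor-wise case is the same computation with a single scalar for each of A and B.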

Quantization Mode Comparison

| Quant Mode | A Matrix Organization | A Scale Shape | B Matrix Organization | B Scale Shape |
|---|---|---|---|---|
| AQuant | Blocks along K dimension; each M×GroupSize block shares one scale | [M, K/GroupSize] | Not quantized | N/A |
| BQuant | Not quantized | N/A | Blocks along K dimension; each GroupSize×N block shares one scale | [K/GroupSize, N] |
| RowColQuant | Per-row quantization; all K elements in each row share one scale | [M, 1] | Per-column quantization; all K elements in each column share one scale | [1, N] |
| TensorQuant | Tensor-wise quantization; all M×K elements share one scale | [1] | Tensor-wise quantization; all K×N elements share one scale | [1] |
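
The BQuant row can be read as the following naive host-side reference, given as a sketch only. It assumes row-major storage, a B scale tensor of shape [K/GroupSize, N], K divisible by the group size, and quantized B values already widened to `float`; `bquant_gemm_ref` is a hypothetical name, not a ck_tile function.

```cpp
#include <cstddef>
#include <vector>

// a:       M x K, row-major, not quantized
// bq:      K x N, row-major, quantized values widened to float for simplicity
// b_scale: (K / group_size) x N, row-major, one scale per K-group and column
inline std::vector<float> bquant_gemm_ref(const std::vector<float>& a,
                                          const std::vector<float>& bq,
                                          const std::vector<float>& b_scale,
                                          std::size_t M,
                                          std::size_t N,
                                          std::size_t K,
                                          std::size_t group_size)
{
    std::vector<float> c(M * N, 0.f);
    const std::size_t num_groups = K / group_size; // assumes K % group_size == 0
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float total = 0.f;
            for(std::size_t g = 0; g < num_groups; ++g)
            {
                // Accumulate one K-group unscaled, then apply its scale once.
                float partial = 0.f;
                for(std::size_t k = g * group_size; k < (g + 1) * group_size; ++k)
                    partial += a[m * K + k] * bq[k * N + n];
                total += partial * b_scale[g * N + n];
            }
            c[m * N + n] = total;
        }
    return c;
}
```

Unlike the row/column and tensor-wise modes, the scale here changes along K, so it has to be applied per group inside the K loop rather than once in the epilogue. AQuant is symmetric, with the scale indexed by row and K-group instead.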

Features

  • Preshuffled GEMM: Pre-shuffle the B (weight) matrix into the warp layout so the GEMM can bypass shared memory. This is the best-performing GEMM option.
  • TransposeC: Transpose the C matrix output layout to get the best coalesced scale reads.
  • Preshuffled Quant: Preshuffle the input matrix to load multiple quant warp blocks along the selected dimension.
  • Precision: Supports fp16, bf16, fp8, bf8, int4 (for the B matrix), and uint8 (for the B matrix; split into two fp4 values in the pipeline).
  • Validation: CPU/GPU validation and error tolerance options.
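
As an illustration of splitting a packed uint8 into two 4-bit values, the sketch below unpacks the nibbles and decodes them as fp4. Two things are assumptions here and not taken from this README: that the low nibble holds the first element, and that fp4 uses the OCP MX E2M1 encoding (1 sign bit, 2 exponent bits, 1 mantissa bit). The function names are hypothetical.

```cpp
#include <array>
#include <cstdint>

// Split one byte into two 4-bit codes. Assumption: low nibble comes first.
inline std::array<std::uint8_t, 2> unpack_nibbles(std::uint8_t packed)
{
    return {static_cast<std::uint8_t>(packed & 0x0F),
            static_cast<std::uint8_t>(packed >> 4)};
}

// Decode a 4-bit code as E2M1 (assumed encoding): bit 3 is the sign and the
// low three bits index the representable magnitudes.
inline float decode_e2m1(std::uint8_t code)
{
    static constexpr float magnitudes[8] = {0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};
    const float value = magnitudes[code & 0x7];
    return (code & 0x8) ? -value : value;
}
```

Under these assumptions the byte `0x7A` unpacks to the codes `0xA` and `0x7`, which decode to -1.0 and 6.0.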

build

# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx942) or leave it blank
../script/cmake-ck-dev.sh  ../ <arch>
# Compile the quant kernels
make tile_example_gemm_quant -j

This produces the executable build/bin/tile_example_gemm_quant.

example

args:
               -h    Print help message (default:false)
               -m    m dimension (default:3840)
               -n    n dimension (default:4096)
               -k    k dimension (default:2048)
        -a_layout    A tensor data layout - R for Row or C for Column (default:R)
        -b_layout    B tensor data layout - R for Row or C for Column (default:C)
       -bq_layout    Bq tensor data layout - R for Row or C for Column (default:C)
        -c_layout    C tensor data layout - R for Row or C for Column (default:R)
        -stride_a    Tensor A stride (default:0)
        -stride_q    Tensor AQ stride (default:0)
        -stride_b    Tensor B stride (default:0)
        -stride_c    Tensor C stride (default:0)
               -v    0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1)
            -prec    Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8;  for Bquant: fp8, bf8, fp8i4, bf8i4, mxbf16bf16, mxbf16bf8 or mxbf16fp4 (default for both AQuant and Bquant: fp8)
          -warmup    Number of iterations before benchmarking the kernel (default:50)
          -repeat    Number of iterations to benchmark the kernel (default:1000)
           -timer    gpu:gpu timer, cpu:cpu timer (default:gpu)
         -split_k    SplitK value (default:1)
          -device    Device id that will be used to run the kernel (default:0)
            -init    0:random, 1:linear, 2:constant(1) (default:0)
     -flush_cache    Flush cache before running the kernel (default:true)
  -rotating_count    Rotating count (default:1000)
      -quant_mode    Choose aquant, bquant, tensor or rowcol (default:bquant)
     -preshuffleb    Enable preshuffle of tensor B (default:false)
 -preshufflequant   Enable preshuffle of quant tensor (default:false)
      -group_size    Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128 (default:1x1x128)
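
The `-group_size` value is an MxNxK string such as 1x32x128. A minimal parser for that format might look like the sketch below; `parse_group_size` is a hypothetical helper and is not the parser used by the example executable.

```cpp
#include <array>
#include <cstddef>
#include <sstream>
#include <stdexcept>
#include <string>

// Parse "MxNxK" (for example "1x32x128") into {M, N, K}.
// Rejects missing fields, non-positive values, and trailing characters.
inline std::array<int, 3> parse_group_size(const std::string& text)
{
    std::array<int, 3> dims{};
    std::istringstream in(text);
    for(std::size_t i = 0; i < dims.size(); ++i)
    {
        if(!(in >> dims[i]) || dims[i] <= 0)
            throw std::invalid_argument("bad group size: " + text);
        if(i + 1 < dims.size())
        {
            char sep = 0;
            if(!(in >> sep) || sep != 'x')
                throw std::invalid_argument("expected MxNxK, got: " + text);
        }
    }
    if(in.peek() != std::istringstream::traits_type::eof())
        throw std::invalid_argument("trailing characters in group size: " + text);
    return dims;
}
```

With this sketch, `parse_group_size("1x32x128")` yields {1, 32, 128}, and a two-field string such as "1x128" is rejected.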

Users need to select the correct config mapping for each quant mode:

| Use case | quant_mode as runtime argument | Corresponding cpp file | GemmConfig at the top of cpp file |
|---|---|---|---|
| For selecting AQuant | aquant | gemm_aquant_quantgrouped.cpp | GemmConfigQuantDecode |
| For selecting AQuant with Preshuffle quant | aquant | gemm_aquant_quantgrouped_preshufflequant.cpp | GemmConfigPreshuffleQuantDecode |
| For selecting BQuant | bquant | gemm_bquant_quantgrouped_<prec_type>.cpp | GemmConfigQuantDecode (or) GemmConfigQuantPrefill |
| For selecting BQuant with Preshuffle quant | bquant | gemm_bquant_quantgrouped_preshufflequant.cpp | GemmConfigPreshuffleQuantDecode (or) GemmConfigPreshuffleBQuantPrefill |
| For selecting PreShuffle B with BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb.cpp | GemmConfigPreshuffleB_BQuant_Decode (or) GemmConfigPreshuffleB_BQuant_Prefill |
| For selecting PreShuffle B with preshuffle BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp | GemmConfigPreshuffleB_PreshuffleBQuant_Decode (or) GemmConfigPreshuffleB_PreshuffleBQuant_Prefill |
| For selecting RowCol quant | rowcolquant | gemm_quant_rowcol | GemmConfigRowColQuant |