
Quant GEMM Matrix Multiplication

This folder contains examples of quant GEMMs using the ck_tile tile-programming implementation.

  • AQuant kernel with blocks of A matrix sharing scales: custom GEMM pipeline
  • BQuant kernel with blocks of B matrix sharing scales: custom GEMM pipeline
  • Row- and column-wise scaled: all elements in a row of the A matrix share one scale, all elements in a column of the B matrix share one scale, and the element-wise scaling is applied in the epilogue.
  • Tensor-wise scaled: a single scalar scale is shared across the entire A or B tensor.
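As a rough illustration of how these modes apply scales, here is a host-side NumPy sketch (the sizes and variable names are hypothetical, and the real kernels fuse these multiplies into the GEMM pipeline or epilogue rather than materializing dequantized tensors):

```python
import numpy as np

# Small hypothetical sizes; real kernels use tile/warp-sized blocks.
M, N, K, group = 4, 4, 8, 4

a = np.random.randn(M, K).astype(np.float32)
b = np.random.randn(K, N).astype(np.float32)

# BQuant: each GroupSize x N block of B shares one scale -> scale shape [K/group, N].
bq_scale = np.random.rand(K // group, N).astype(np.float32)
b_dequant = b * np.repeat(bq_scale, group, axis=0)  # expand scales along K
c_bquant = a @ b_dequant

# RowColQuant: per-row scale on A ([M, 1]) and per-column scale on B ([1, N]).
# The scaling factors out of the K-reduction, so it can run in the epilogue.
a_scale = np.random.rand(M, 1).astype(np.float32)
b_scale = np.random.rand(1, N).astype(np.float32)
c_rowcol = (a @ b) * a_scale * b_scale  # equals (a * a_scale) @ (b * b_scale)

# TensorQuant: one scalar scale per whole tensor (A and B each).
c_tensor = (a @ b) * 0.5 * 2.0
```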

Quantization Mode Comparison

AQuant
  A matrix: blocks along the K dimension; each M×GroupSize block shares one scale (A scale shape [M, K/GroupSize])
  B matrix: not quantized (B scale shape N/A)

BQuant
  A matrix: not quantized (A scale shape N/A)
  B matrix: blocks along the K dimension; each GroupSize×N block shares one scale (B scale shape [K/GroupSize, N])

RowColQuant
  A matrix: per-row quantization; all K elements in each row share one scale (A scale shape [M, 1])
  B matrix: per-column quantization; all K elements in each column share one scale (B scale shape [1, N])

TensorQuant
  A matrix: tensor-wise quantization; all M×K elements share one scale (A scale shape [1])
  B matrix: tensor-wise quantization; all K×N elements share one scale (B scale shape [1])
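The scale-shape column of the comparison above can be condensed into a small lookup. This is a hypothetical helper for reference, not part of the example code:

```python
# Hypothetical helper mirroring the table above: returns
# (A scale shape, B scale shape); None means "not quantized".
def scale_shapes(mode, M, N, K, group_size_k):
    if mode == "aquant":
        return (M, K // group_size_k), None
    if mode == "bquant":
        return None, (K // group_size_k, N)
    if mode == "rowcol":
        return (M, 1), (1, N)
    if mode == "tensor":
        return (1,), (1,)
    raise ValueError(f"unknown quant mode: {mode}")
```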

Features

  • Preshuffled GEMM: Pre-shuffle the B (weight) matrix into the warp layout so the GEMM computation can bypass shared memory. This is the best-performing GEMM path.
  • TransposeC: Transpose the C matrix output layout for the best coalesced scale reads.
  • Preshuffled Quant: Preshuffle the input matrix to load multiple Quant warp blocks along the selected dimension.
  • Precision: Supports fp16, bf16, fp8, bf8, int4 (for B Matrix).
  • Validation: CPU/GPU validation and error tolerance options.

build

# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx942) or leave it blank
../script/cmake-ck-dev.sh  ../ <arch>
# Compile the quant kernels
make tile_example_gemm_quant_basic -j

This will produce an executable at build/bin/tile_example_gemm_quant_basic.

example

args:
             -h    Print help message (default:false)
             -m    m dimension (default:3840)
             -n    n dimension (default:4096)
             -k    k dimension (default:2048)
      -a_layout    A tensor data layout - Row or Column (default:R)
      -b_layout    B tensor data layout - Row or Column (default:C)
     -bq_layout    Bq tensor data layout - Row or Column (default:C)
      -c_layout    C tensor data layout - Row or Column (default:R)
      -stride_a    Tensor A stride (default:0)
      -stride_q    Tensor AQ stride (default:0)
      -stride_b    Tensor B stride (default:0)
      -stride_c    Tensor C stride (default:0)
             -v    0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1)
          -prec    Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8;  for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8)
        -warmup    Number of iterations before benchmarking the kernel (default:50)
        -repeat    Number of iterations to benchmark the kernel (default:1000)
         -timer    gpu:gpu timer, cpu:cpu timer (default:gpu)
       -split_k    SplitK value (default:1)
        -device    Device id that will be used to run the kernel (default:0)
          -init    0:random, 1:linear, 2:constant(1) (default:0)
   -flush_cache    Flush cache before running the kernel (default:true)
-rotating_count    Rotating count (default:1000)
    -quant_mode    Choose aquant, bquant, tensor or rowcol (default:bquant)
   -preshuffleb    Enable preshuffle of tensor B (default:false)
    -group_size    Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128 (default:1x1x128)
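For instance, a BQuant run with the default problem size might look like the following. The flag spellings are taken from the list above; the `-flag=value` syntax and the binary path are assumptions that depend on your build setup:

```shell
# Run from the build directory after `make tile_example_gemm_quant_basic -j`.
# BQuant with fp8 inputs, CPU validation, and a 1x1x128 quant group size:
./bin/tile_example_gemm_quant_basic -quant_mode=bquant -prec=fp8 \
    -m=3840 -n=4096 -k=2048 -group_size=1x1x128 -v=1
```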

Users need to select the correct config for each quant mode:

  • AQuant: quant_mode aquant, config GemmConfigQuant
  • AQuant with preshuffle: quant_mode aquant, config GemmConfigPreshuffleQuant
  • BQuant: quant_mode bquant, config GemmConfigQuant
  • BQuant with preshuffled weight (B) matrix: quant_mode bquant, config GemmConfigPreshuffleB_Bquant_decode or GemmConfigPreshuffleB_Bquant_prefill
  • RowCol quant: quant_mode rowcolquant, config GemmConfigRowColQuant