[CK][CK TILE] Autotuning heuristics infra for universal GEMM kernel selection (#5676)

## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile dispatcher, enabling fast and accurate automatic kernel selection for Universal GEMM kernels. Instead of requiring an exhaustive search through 4600+ kernel configurations (~46 seconds per problem shape), the ML heuristic predicts optimal kernels in microseconds while achieving >98% of oracle-best performance.

## Technical Details

**ML infrastructure** (https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics):

* Feature Engine ([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)): 55-feature extraction covering problem dimensions, kernel configuration, tile efficiency, and hardware profile
* Training Pipeline ([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)): LightGBM regression with log-transform, GroupKFold cross-validation, and warm-start support
* Predictor ([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)): kernel ranking and TFLOPS prediction for problem shapes
* Evaluation ([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)): comprehensive metrics including efficiency, NDCG@k, and shape-family analysis

**Data Generation Tools:**

* [generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py): build and benchmark kernels across diverse problem shapes
* [convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py): convert benchmark JSON to training-ready parquet format
* [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py): parse streaming benchmark logs into canonical datasets

**Examples:**

* [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp): C++ example demonstrating ML-based kernel selection
* [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py): Python example with validation

**Pre-trained Models** (`projects/composablekernel/dispatcher/heuristics/models/`):

* `gemm_universal_fp8_gfx950/`: fp8 RCR model (42K trees, 97.51% mean efficiency)
* `gemm_universal_fp16_gfx950/`: fp16 RCR model (20K trees, 99.36% mean efficiency)

## Test Plan

* Evaluated on 25 diverse shapes for fp16 and 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M (M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**

* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (90% of shapes achieve ≥98% of oracle best)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**

* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Data Generation Guide
This document explains how to build benchmark binaries from the CK Tile engine, generate benchmark datasets, and manage them for the ML kernel performance prediction system.
Overview
The ML heuristic needs benchmark data: measured TFLOPS, latency, and bandwidth for every (problem shape, kernel config) pair. The tile engine builds one executable per kernel configuration. Each executable benchmarks a single kernel on a given problem size and outputs JSON with performance metrics.
```
CK source --> CMake configure --> ninja build --> benchmark binaries
                                                  (4608 per op/dtype/layout)

benchmark binaries --> run on GPU --> streaming log --> parquet dataset
(per shape)                           (JSON blocks)     (canonical schema)
```
Prerequisites
- ROCm: HIP >= 6.0.3 (for gfx950: HIP >= 6.0.4)
- Build tools: CMake >= 3.21, Ninja, HIP-aware clang compiler
- Python: 3.10+ with `pandas`, `pyarrow`
- GPU: ROCm-capable AMD GPU (MI250X, MI300X, MI355X, etc.)
Part 1: Building Benchmark Binaries from the Tile Engine
If you already have pre-built binaries (e.g., in /workspace/ck_tile/bin/),
skip to Part 2. This section explains how to build them from source.
Step 1: CMake Configure
From the CK repository root:
```shell
cmake -S /workspace/rocm-libraries/projects/composablekernel \
  -B build \
  -DCMAKE_BUILD_TYPE=Release \
  -DGPU_TARGETS="gfx950" \
  -DGEMM_UNIVERSAL_DATATYPE="fp8" \
  -DGEMM_UNIVERSAL_LAYOUT="rcr" \
  -G Ninja
```
Key CMake variables:
| Variable | Default | Description |
|---|---|---|
| `GPU_TARGETS` | (required) | Target GPU architectures. Supported: gfx90a, gfx942, gfx950, gfx1201. Semicolon-separated for multiple. |
| `GEMM_UNIVERSAL_DATATYPE` | `"fp8;fp16"` | Data types to build. Options: fp8, fp16, bf16, bf8. Semicolon-separated. |
| `GEMM_UNIVERSAL_LAYOUT` | `"rcr;rrr;crr;ccr"` | Layouts to build. Semicolon-separated. |
| `GEMM_UNIVERSAL_CONFIG_FILE` | `"default_config.json"` | Kernel config file (in the `configs/` directory). Controls which tile sizes, warp configs, pipelines, etc. are enumerated. |
| `ENABLE_CCACHE_GEMM_UNIVERSAL` | `OFF` | Enable ccache for faster rebuilds. |
Example: build only fp8 RCR for gfx950 (fastest, ~4608 kernels):
```shell
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release \
  -DGPU_TARGETS="gfx950" \
  -DGEMM_UNIVERSAL_DATATYPE="fp8" \
  -DGEMM_UNIVERSAL_LAYOUT="rcr" \
  -G Ninja
```
Example: build all dtypes and layouts (slow, ~4608 * 4 * 4 = ~73K kernels):
```shell
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release \
  -DGPU_TARGETS="gfx950" \
  -DGEMM_UNIVERSAL_DATATYPE="fp8;fp16;bf16;bf8" \
  -DGEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" \
  -G Ninja
```
What happens during configure
- CMake calls `gemm_universal_instance_builder.py --list_kernels` to enumerate all valid kernel configurations from the config JSON.
- It writes `gemm_universal_kernel_list.txt` (one kernel per line) and `gemm_universal_kernel_count.txt` to the build directory.
- For each kernel, it creates a ninja build target.
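After configuring, the two generated files can be cross-checked as a quick sanity test. A minimal sketch: only the two filenames come from the build system's behavior described above; the helper itself is hypothetical.

```python
from pathlib import Path

def check_kernel_manifest(build_dir):
    """Verify that the kernel list and count files written at CMake
    configure time agree with each other."""
    build = Path(build_dir)
    kernels = [
        line.strip()
        for line in (build / "gemm_universal_kernel_list.txt").read_text().splitlines()
        if line.strip()
    ]
    count = int((build / "gemm_universal_kernel_count.txt").read_text().strip())
    if count != len(kernels):
        raise RuntimeError(
            f"count file says {count}, list file has {len(kernels)} kernels"
        )
    return count
```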
Step 2: Build
```shell
# Build all benchmarks for the configured dtypes/layouts
ninja -C build benchmark_gemm_universal_all

# Or build a specific dtype/layout combo
ninja -C build benchmark_gemm_universal_fp8_rcr

# Or build by pipeline type
ninja -C build benchmark_gemm_universal_compv4_pipeline
ninja -C build benchmark_gemm_universal_mem_pipeline

# Or build a single specific kernel
ninja -C build benchmark_gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128
```
Build time estimates:
- ~4608 kernels (one dtype, one layout): 1-4 hours depending on CPU cores
- Use `-j <N>` to control parallelism: `ninja -C build -j 32 benchmark_gemm_universal_fp8_rcr`
Step 3: Verify binaries
Binaries are placed in build/bin/:
```shell
ls build/bin/benchmark_gemm_universal_fp8_rcr_* | wc -l
# Expected: 4608 (for default config)

# Test one binary
./build/bin/benchmark_gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128 \
  -m=1024 -n=1024 -k=1024 -warmup=3 -repeat=10 -verify=0
```
Kernel config files
The config files live in:
```
tile_engine/ops/gemm/gemm_universal/configs/
  default_config.json        # Default: full enumeration
  default_ci_config.json     # CI: reduced set for fast testing
  user_provided_config.json  # Custom: your own subset
```

To use a custom config:

```shell
cmake ... -DGEMM_UNIVERSAL_CONFIG_FILE="user_provided_config.json"
```
The config controls which tile sizes (e.g., 128x128x64, 256x256x32), warp configurations (e.g., 2x2x1, 1x4x1), pipelines (compv3, compv4, mem), schedulers, and other parameters are included in the kernel enumeration.
Building StreamK / other ops
The same pattern applies to other tile engine ops:
```shell
# StreamK
ninja -C build benchmark_gemm_streamk_fp8_rcr

# Grouped convolution
ninja -C build benchmark_grouped_conv_fwd_fp16_nhwgc
```
Each op has its own instance builder and config directory.
Part 2: Running Benchmarks and Generating Data
Quick Start
1. Run benchmarks for a set of shapes
Each binary accepts -m=, -n=, -k=, -warmup=, -repeat=, -verify= flags
and outputs JSON to stdout:
```shell
/workspace/ck_tile/bin/benchmark_gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128 \
  -m=1024 -n=1024 -k=1024 -warmup=3 -repeat=10 -verify=0
```
Output:
```json
{
  "name": "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_...",
  "problem": {
    "split_k": 1, "m": 1024, "n": 1024, "k": 1024,
    "dtype_a": "fp8", "dtype_b": "fp8", ...
  },
  "perf_result": {
    "latency(ms)": 0.04,
    "tflops(TFlops)": 204.60,
    "bandwidth(GB/s)": 624.39
  }
}
```
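The per-binary invocation and JSON output above can be driven from Python. A minimal sketch: the JSON field names follow the sample output shown here, but the helper names (`run_benchmark`, `parse_benchmark_json`) are hypothetical, not part of the shipped tooling.

```python
import json
import subprocess

def run_benchmark(binary, m, n, k, warmup=3, repeat=10):
    """Run one benchmark executable and return its parsed JSON output."""
    out = subprocess.run(
        [binary, f"-m={m}", f"-n={n}", f"-k={k}",
         f"-warmup={warmup}", f"-repeat={repeat}", "-verify=0"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_benchmark_json(out)

def parse_benchmark_json(stdout):
    """Flatten the benchmark JSON into one record for the dataset."""
    record = json.loads(stdout)
    perf = record["perf_result"]
    return {
        "kernel_name": record["name"],
        "m": record["problem"]["m"],
        "n": record["problem"]["n"],
        "k": record["problem"]["k"],
        "tflops": perf["tflops(TFlops)"],
        "latency_ms": perf["latency(ms)"],
        "bandwidth_gb_s": perf["bandwidth(GB/s)"],
    }
```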
2. Batch generation using provided scripts
Wide coverage (diverse shapes across all regimes):
```shell
python3 generate_wide_coverage.py \
  --bin_dir /workspace/ck_tile/bin \
  --out_dir data/wide_coverage \
  --batch_size 25 \
  --warmup 3 --repeat 10
```
Edge-case dimensions (N=1, K=1, small N/K):
```shell
python3 generate_edge_dims.py
```
Both scripts write streaming log files that data_pipeline.py can parse.
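The streaming logs interleave progress text with JSON blocks. Extracting the blocks, as the parsing step must do, can be sketched like this (`iter_json_blocks` is a hypothetical helper, not the actual `data_pipeline.py` implementation):

```python
import json

def iter_json_blocks(log_text):
    """Yield every top-level JSON object embedded in a streaming log,
    skipping progress lines and other non-JSON text between blocks."""
    decoder = json.JSONDecoder()
    pos = 0
    while True:
        start = log_text.find("{", pos)
        if start == -1:
            return
        try:
            # raw_decode returns the object and the absolute end index
            obj, end = decoder.raw_decode(log_text, start)
        except ValueError:
            pos = start + 1  # '{' was not the start of valid JSON; keep scanning
            continue
        yield obj
        pos = end
```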
3. Parse logs into parquet
```shell
python3 data_pipeline.py <log_file> \
  -o data/my_dataset.parquet \
  --arch gfx950 \
  --capture_hw
```
The --capture_hw flag runs rocminfo once and injects the GPU hardware
profile (CU count, clock speed, cache sizes, etc.) into every row.
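Capturing the hardware profile amounts to parsing a few fields out of `rocminfo` output. A minimal sketch: `capture_hw_profile` is hypothetical, and the field labels (`Compute Unit`, `Max Clock Freq. (MHz)`, `Name`) follow common `rocminfo` output but may differ across ROCm versions.

```python
import re
import subprocess

def capture_hw_profile(rocminfo_text=None):
    """Extract a small hardware profile from rocminfo output; runs
    rocminfo itself when no text is supplied."""
    if rocminfo_text is None:
        rocminfo_text = subprocess.run(
            ["rocminfo"], capture_output=True, text=True, check=True
        ).stdout
    profile = {}
    cu = re.search(r"Compute Unit:\s*(\d+)", rocminfo_text)
    if cu:
        profile["cu_count"] = int(cu.group(1))
    clk = re.search(r"Max Clock Freq\.?\s*\(MHz\):\s*(\d+)", rocminfo_text)
    if clk:
        profile["max_clock_mhz"] = int(clk.group(1))
    name = re.search(r"Name:\s*(gfx\w+)", rocminfo_text)
    if name:
        profile["arch"] = name.group(1)
    return profile
```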
Canonical Data Schema
Every parquet file follows this schema:
| Column | Type | Description |
|---|---|---|
| `op_type` | str | gemm_universal, gemm_streamk, etc. |
| `dtype` | str | fp8, fp16, bf16, bf8 |
| `layout` | str | rcr, rrr, crr, ccr |
| `arch` | str | gfx942, gfx950, etc. |
| `kernel_name` | str | Full kernel identifier |
| `m`, `n`, `k` | int | Problem dimensions |
| `split_k` | int | Split-K factor (1 = standard) |
| `measured_tflops` | float | Ground-truth TFLOPS |
| `latency_ms` | float | Measured latency |
| `bandwidth_gb_s` | float | Measured bandwidth |
| `is_valid` | bool | True if tflops > 0 and latency > 0 |
| `tile_m`, `tile_n`, `tile_k` | int | Tile dimensions |
| `warp_m`, `warp_n`, `warp_k` | int | Warp config |
| `warp_tile_m/n/k` | int | Warp tile dimensions |
| `pipeline` | str | compv3, compv4, mem, etc. |
| `scheduler` | str | intrawave, interwave |
| `epilogue` | str | cshuffle, default |
| `pad_m`, `pad_n`, `pad_k` | bool | Padding flags |
| `persistent` | bool | Persistent kernel flag |
| `run_id` | str | Unique collection run identifier |
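A lightweight, pure-Python check of rows against the schema can catch malformed data before it reaches parquet. A sketch with an abbreviated column set; `validate_row` and `CANONICAL_SCHEMA` are hypothetical names, not part of the shipped pipeline.

```python
# (column, type) pairs from the canonical schema above (abbreviated;
# a real check would list every column).
CANONICAL_SCHEMA = {
    "op_type": str, "dtype": str, "layout": str, "arch": str,
    "kernel_name": str, "m": int, "n": int, "k": int, "split_k": int,
    "measured_tflops": float, "latency_ms": float, "bandwidth_gb_s": float,
    "is_valid": bool, "pipeline": str, "scheduler": str, "run_id": str,
}

def validate_row(row):
    """Return a list of schema violations for one dataset row."""
    errors = []
    for col, typ in CANONICAL_SCHEMA.items():
        if col not in row:
            errors.append(f"missing column: {col}")
        elif not isinstance(row[col], typ):
            errors.append(
                f"{col}: expected {typ.__name__}, got {type(row[col]).__name__}"
            )
    return errors
```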
Shape Selection Guidelines
Good training data requires diverse shapes. Cover all of these regimes:
By M dimension (batch size / output rows)
- M=1: single-token inference (hardest case for tiling)
- Tiny M (2-16): small batch inference
- Small M (32-128): medium batch
- Medium M (256-2048): large batch / training
- Large M (4096-20480): very large batch
By N and K dimension
- N=1: vector-matrix multiply (degenerate)
- K=1: rank-1 update / outer product (degenerate)
- Small N or K (2-16): stress tile efficiency
- Deep K (K > 4096): compute-bound regime
- Shallow K (K < 256): memory-bound regime
By shape family
- Square: M ~ N ~ K (powers of 2)
- Tall: M >> N (tall output matrix)
- Wide: N >> M (wide output matrix)
- Deep-K: K >> M and K >> N
Special cases
- Prime dimensions: 17, 31, 127, 251, 509, 1021, 2039, 4093 (worst-case for tile alignment, tests padding logic)
- Non-power-of-2: 48, 96, 192, 384, 576, 768, 1536, 3072, 4608 (common in LLM architectures)
- LLM inference shapes: DeepSeek, LLaMA-7B, LLaMA-70B MLP/attention dims
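The regimes above can be assembled into a concrete shape list programmatically. A sketch: the dimension pools are illustrative samples of each regime, not a prescribed sweep, and `build_shape_list` is a hypothetical helper.

```python
# Dimension pools drawn from the regimes above (illustrative values).
M_VALUES = [1, 4, 16, 64, 128, 512, 2048, 4096, 8192]
PRIMES = [17, 31, 127, 251, 509, 1021, 2039, 4093]
LLM_DIMS = [48, 96, 192, 384, 576, 768, 1536, 3072, 4608]

def build_shape_list():
    """Assemble a diverse (M, N, K) list covering the shape families."""
    shapes = set()
    # Square shapes (powers of 2)
    for d in [256, 512, 1024, 2048, 4096]:
        shapes.add((d, d, d))
    # Tall / wide / deep-K families
    shapes.update((m, 256, 1024) for m in M_VALUES)            # vary M
    shapes.update((256, n, 1024) for n in LLM_DIMS)            # vary N
    shapes.update((256, 256, k) for k in [64, 4096, 16384])    # shallow/deep K
    # Degenerate and alignment stress cases
    shapes.update((m, 1, 1024) for m in [64, 1024])            # N=1
    shapes.update((m, 1024, 1) for m in [64, 1024])            # K=1
    shapes.update((p, p, p) for p in PRIMES)                   # prime dims
    return sorted(shapes)
```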
Minimum recommended coverage
For a production-quality model, aim for:
- At least 200 unique (M, N, K) shapes
- At least 10 shapes per shape family
- All kernel configs (4608 for fp8 RCR) run against every shape
- Multiple layouts if training a cross-layout model
Benchmark Quality Guidelines
Warmup and repeat
- Minimum `warmup=3`, `repeat=10` for fast iteration
- Production quality: `warmup=5`, `repeat=20` for stable measurements
- The `perf_result` values are averaged over `repeat` iterations
Noise handling
- Use median latency when aggregating multiple runs of the same benchmark
- Flag measurements where coefficient of variation exceeds 10%
- Avoid benchmarking under thermal throttling (check GPU temperature)
- Lock GPU clocks if possible for reproducibility
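The first two noise-handling rules can be folded into one aggregation step. A minimal sketch, assuming each benchmark was re-run several times; `aggregate_latency` is a hypothetical helper.

```python
import statistics

def aggregate_latency(latencies_ms, cv_threshold=0.10):
    """Aggregate repeated runs of one benchmark: report the median
    latency and flag the measurement when the coefficient of
    variation exceeds the threshold (10% by default)."""
    median = statistics.median(latencies_ms)
    mean = statistics.fmean(latencies_ms)
    cv = statistics.stdev(latencies_ms) / mean if mean > 0 else float("inf")
    return {"latency_ms": median, "cv": cv, "noisy": cv > cv_threshold}
```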
Environment metadata
Store with every dataset:
- GPU model and architecture (from `rocminfo`)
- ROCm driver version
- Clock mode (default / locked)
- Git hash of the CK tile engine build (if available)
- Timestamp
Adding Data for a New Op
To generate benchmark data for a new operation (e.g., gemm_streamk):
1. Build the binaries using the tile engine:
   `ninja -C build benchmark_gemm_streamk_fp8_rcr`
2. Write a generation script (or modify `generate_wide_coverage.py`):
   - Change the executable glob pattern to match the new op
   - Add any op-specific CLI flags the binaries need
3. Run and parse:
   `python3 data_pipeline.py my_streamk_run.log -o data/gemm_streamk_fp8_gfx950.parquet --arch gfx950`
4. Train:
   `python3 train.py --op gemm_streamk --dtype fp8 --arch gfx950 --data_dir data/ --out_dir models/gemm_streamk_fp8_gfx950`
Adding Data for a New Layout
Same binaries, same shapes -- just change the layout filter:
```shell
# Build rrr binaries
ninja -C build benchmark_gemm_universal_fp8_rrr

# Generate and parse
# ... (same flow, different bin_dir or executable glob)

# Train a cross-layout model by putting all layouts in the same data_dir
python3 train.py --data_dir data/ --out_dir models/gemm_universal_fp8_gfx950_all_layouts
```
The feature engine includes layout as a categorical feature, so one model
can handle all layouts.
Incremental Data Collection
When you have a trained model and want to add more data:
- Generate new data (new shapes, new layouts, etc.)
- Parse into parquet alongside existing data
- Warm-start from the previous model:

```shell
python3 train.py --data_dir data/ --out_dir models/v2 \
  --warm_start models/v1 \
  --warm_start_n_estimators 200
```
This adds 200 new trees on top of the existing model. The feature schema must match exactly (enforced automatically).
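A schema-match check like the one enforced at warm start can be sketched as below. Assumptions: `feature_spec.json` is the file named in the model directory layout of this guide, but its exact contents and how `train.py` performs the check are not shown here; `check_warm_start_schema` is a hypothetical helper.

```python
import json
from pathlib import Path

def check_warm_start_schema(old_model_dir, new_feature_spec):
    """Raise if the new training run's feature schema differs from the
    one the previous model was trained with (warm start requires an
    exact match)."""
    old_spec = json.loads(Path(old_model_dir, "feature_spec.json").read_text())
    if old_spec != new_feature_spec:
        old_keys, new_keys = set(old_spec), set(new_feature_spec)
        raise ValueError(
            f"feature schema mismatch: missing={sorted(old_keys - new_keys)}, "
            f"extra={sorted(new_keys - old_keys)}"
        )
```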
File Organization
Recommended directory structure:
```
heuristics/
  data/
    gemm_universal_fp8_rcr_gfx950.parquet    # original 108 shapes
    wide_coverage/                           # batch log files
      wide_coverage_batch_001.log
      wide_coverage_batch_002.log
      ...
    edge_dims/                               # N=1, K=1 edge cases
      edge_dims_batch_001.log
      ...
  models/
    gemm_universal_fp8_gfx950/               # trained model artifacts
      model_tflops.lgbm
      model_latency.lgbm
      model_bandwidth.lgbm
      feature_spec.json
      train_manifest.json
      cv_metrics_tflops.json
      eval_report.json
      ...
```
Troubleshooting
Benchmark binary exits with non-zero code
Some kernel configs are invalid for certain problem sizes (e.g., tile_m=256
with M=16). The data pipeline marks these as is_valid=False and they are
filtered out during training. This is expected.
Edge dims produce very few results
N=1 and K=1 shapes are degenerate -- most kernel configurations have minimum dimension requirements and will fail or produce zero TFLOPS. The small number of valid results is still useful (it tells the model which configs work for these shapes).
Benchmarks are slow
Each shape requires running all 4608 kernel executables sequentially. At ~0.01s per kernel, that is ~46 seconds per shape. For 700 shapes, expect ~9 hours. Tips:
- Run on a dedicated GPU (no other workloads)
- Use `--batch_size 25` to get incremental output
- Parse and train on partial data while generation continues
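The arithmetic behind the estimate, as a tiny helper (the function name is hypothetical; the defaults are the figures quoted above):

```python
def estimate_runtime_hours(n_shapes, n_kernels=4608, sec_per_kernel=0.01):
    """Back-of-envelope wall-clock estimate for running every kernel
    sequentially against every shape."""
    per_shape_s = n_kernels * sec_per_kernel  # ~46 s per shape
    return n_shapes * per_shape_s / 3600.0
```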
Data from different GPUs / driver versions
Store run_id and hardware metadata with each dataset. Training on mixed
data is allowed but not recommended for production models. Filter to a
single run_id or arch for clean experiments.