composable_kernel/dispatcher/heuristics/generate_wide_coverage.py
Yaswanth Raparti c1127a36f5 [rocm-libraries] ROCm/rocm-libraries#5676 (commit 1d18339)
[CK][CK TILE] Autotuning heuristics infra for universal GEMM kernel selection (#5676)

## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile
dispatcher, enabling fast and accurate automatic kernel selection for
universal GEMM kernels. Instead of requiring an exhaustive search through
4600+ kernel configurations (~46 seconds per problem shape), the
ML heuristic predicts optimal kernels in microseconds while achieving
>98% of oracle-best performance.

## Technical Details

**ML infrastructure**

https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics
* Feature Engine
([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)):
55-feature extraction including problem dimensions, kernel
configuration, tile efficiency, and hardware profile
* Training Pipeline
([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)):
LightGBM regression with log-transform, GroupKFold cross-validation,
warm-start support
* Predictor
([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)):
Kernel ranking and TFLOPS prediction for problem shapes
* Evaluation
([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)):
Comprehensive metrics including efficiency, NDCG@k, shape family
analysis
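Because many training rows share the same (M, N, K) problem shape, grouped cross-validation is what keeps one shape's rows from appearing on both sides of a train/validation split, and the log-transform keeps high-TFLOPS rows from dominating the regression loss. A minimal, dependency-free sketch of that split logic (the `group_kfold` helper and toy rows below are illustrative; the real pipeline in train.py presumably uses scikit-learn's GroupKFold together with LightGBM):

```python
import math
from collections import defaultdict


def group_kfold(rows, key, n_splits=5):
    """Yield (train_idx, valid_idx) lists such that rows sharing a
    group key never straddle the train/valid boundary."""
    groups = defaultdict(list)
    for i, row in enumerate(rows):
        groups[key(row)].append(i)
    keys = sorted(groups)
    for fold in range(n_splits):
        valid = [i for k in keys[fold::n_splits] for i in groups[k]]
        valid_set = set(valid)
        train = [i for i in range(len(rows)) if i not in valid_set]
        yield train, valid


# Toy benchmark rows: (M, N, K, tflops). The regression target is
# log-transformed so large-TFLOPS shapes don't dominate the loss.
rows = [(1, 4096, 4096, 12.0), (1, 4096, 4096, 9.5),
        (128, 4096, 4096, 410.0), (8192, 8192, 8192, 980.0)]
targets = [math.log1p(r[3]) for r in rows]

for train, valid in group_kfold(rows, key=lambda r: r[:3], n_splits=2):
    train_shapes = {rows[i][:3] for i in train}
    valid_shapes = {rows[i][:3] for i in valid}
    assert not (train_shapes & valid_shapes)  # no shape on both sides
```

Grouping by shape rather than by row is what makes the reported efficiency numbers meaningful: the model is always evaluated on shapes it has never seen.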

**Data Generation Tools:**

*
[generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py):
Build and benchmark kernels across diverse problem shapes
*
[convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py):
Convert benchmark JSON to training-ready parquet format
*
[data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py):
Parse streaming benchmark logs into canonical datasets
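The streaming format these tools consume is a `Shape N: M=... N=... K=...` header line followed by one JSON result line per kernel (the format generate_wide_coverage.py writes, shown below). A sketch of the parsing step under that assumption; the real canonical-dataset logic lives in data_pipeline.py and may differ in detail:

```python
import json
import re

# Matches the per-shape header emitted by the benchmark generator.
SHAPE_RE = re.compile(r"^Shape \d+: M=(\d+) N=(\d+) K=(\d+)")


def parse_stream_log(lines):
    """Pair each JSON result line with the most recent shape header."""
    records, shape = [], None
    for line in lines:
        m = SHAPE_RE.match(line)
        if m:
            shape = tuple(int(g) for g in m.groups())
            continue
        line = line.strip()
        if shape and line.startswith("{"):
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                continue  # skip partial or corrupt lines
            rec["M"], rec["N"], rec["K"] = shape
            records.append(rec)
    return records


log = [
    "Shape 1: M=1 N=4096 K=4096 dtype=fp8 layout=rcr",
    '{"kernel": "k0", "tflops": 11.2}',
    "not json noise",
    '{"kernel": "k1", "tflops": 13.7}',
]
recs = parse_stream_log(log)
```

Tolerating non-JSON noise lines is the point of the streaming design: a crashed kernel corrupts at most its own line, not the batch.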

**Examples**
*
[09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp):
C++ example demonstrating ML-based kernel selection
*
[09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py):
Python example with validation

**Pre-trained Models
(projects/composablekernel/dispatcher/heuristics/models/):**
* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean
efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean
efficiency)

## Test Plan

* Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M
(M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**
* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (90% of shapes achieve at least 98.05% of the
oracle best)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**
* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%
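Efficiency here is measured per shape as the achieved TFLOPS of the selected kernel divided by the oracle-best TFLOPS for that shape, and P10 is the 10th-percentile shape, so 90% of shapes meet or beat it. A small sketch of how these summary statistics fall out (the ratios below are made up for illustration, not the measured fp8/fp16 results):

```python
import math


def summarize_efficiency(ratios):
    """Summarize per-shape efficiency (achieved / oracle-best TFLOPS)."""
    xs = sorted(ratios)
    n = len(xs)
    # Nearest-rank 10th percentile: the value 90% of shapes meet or beat.
    p10 = xs[max(0, math.ceil(n / 10) - 1)]
    return {"mean": sum(xs) / n, "p10": p10, "min": xs[0]}


# Illustrative per-shape efficiency ratios (20 shapes):
ratios = [0.85, 0.90, 0.93, 0.95] + [0.97 + 0.002 * i for i in range(16)]
stats = summarize_efficiency(ratios)
```

Reporting P10 and min alongside the mean matters: a model can post a high mean while still picking badly on a few outlier shapes, which is exactly what the fp8 min of 84.5% surfaces.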

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-02 02:26:32 +00:00


#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Wide-coverage benchmark data generator.

Generates benchmark results for hundreds of diverse GEMM shapes across all
corner cases: skinny M, tall N, deep K, M=1, prime dimensions, power-of-2,
LLM inference shapes, training shapes, and edge cases.

Runs all 4608 kernels in /workspace/ck_tile/bin/ against each shape and
writes streaming log output parseable by data_pipeline.py.

Usage:
    python3 generate_wide_coverage.py --bin_dir /workspace/ck_tile/bin \
        --out_dir data/ --batch_size 20 --warmup 3 --repeat 10
"""
import argparse
import json
import subprocess
import sys
from pathlib import Path


def generate_shape_list():
    """Generate a comprehensive list of (M, N, K) shapes covering all corner cases.

    Categories:
     1. M=1 (single token inference) -- the hardest case
     2. Tiny M (2-16) -- small batch inference
     3. Small M (32-128) -- medium batch
     4. Medium M (256-2048) -- large batch / training
     5. Large M (4096-20480) -- very large batch
     6. Square shapes (powers of 2)
     7. Skinny M, tall N (M << N)
     8. Tall M, skinny N (M >> N)
     9. Deep K (K >> M, N) -- compute-bound
    10. Shallow K (K << M, N) -- memory-bound
    11. Prime dimensions -- worst-case for tiling
    12. LLM-specific shapes (DeepSeek, LLaMA, etc.)
    13. Non-power-of-2 common sizes
    """
    shapes = set()

    # --- 1. M=1 (single token) across various N, K ---
    for n in [512, 1024, 1536, 2048, 3072, 4096, 4608, 7168, 8192, 11008, 14336, 28672]:
        for k in [256, 512, 1024, 1536, 2048, 2304, 4096, 7168, 8192]:
            shapes.add((1, n, k))

    # --- 2. Tiny M (2-16) ---
    for m in [2, 4, 8, 16]:
        for n in [512, 1536, 4096, 7168]:
            for k in [256, 1024, 4096, 7168]:
                shapes.add((m, n, k))

    # --- 3. Small M (32-128) ---
    for m in [32, 48, 64, 96, 128]:
        for n in [512, 1536, 4096, 7168, 8192]:
            for k in [256, 512, 2048, 4096, 7168]:
                shapes.add((m, n, k))

    # --- 4. Medium M (256-2048) ---
    for m in [256, 384, 512, 768, 1024, 1536, 2048]:
        for n in [512, 1536, 4096, 7168]:
            for k in [256, 1024, 2048, 4096, 7168]:
                shapes.add((m, n, k))

    # --- 5. Large M (4096-20480) ---
    for m in [4096, 8192, 12288, 16384, 20480]:
        for n in [512, 1536, 4096, 7168]:
            for k in [256, 1024, 2048, 7168]:
                shapes.add((m, n, k))

    # --- 6. Square shapes (powers of 2) ---
    for p in range(5, 14):  # 32 to 8192
        d = 2**p
        shapes.add((d, d, d))

    # --- 7. Skinny M, tall N ---
    for m in [1, 4, 16, 64]:
        for n in [8192, 16384, 28672]:
            for k in [1024, 4096, 8192]:
                shapes.add((m, n, k))

    # --- 8. Tall M, skinny N ---
    for m in [4096, 8192, 16384]:
        for n in [32, 64, 128, 256]:
            for k in [1024, 4096]:
                shapes.add((m, n, k))

    # --- 9. Deep K (K >> M, N) ---
    for m in [16, 64, 256]:
        for n in [16, 64, 256]:
            for k in [4096, 8192, 16384, 32768]:
                shapes.add((m, n, k))

    # --- 10. Shallow K (K << M, N) ---
    for m in [1024, 4096, 8192]:
        for n in [1024, 4096, 8192]:
            for k in [16, 32, 64, 128]:
                shapes.add((m, n, k))

    # --- 11. Prime dimensions ---
    primes = [17, 31, 37, 127, 251, 509, 1021, 2039, 4093]
    for p in primes:
        shapes.add((p, p, p))
    for p in primes[:5]:
        shapes.add((p, 4096, 4096))
        shapes.add((4096, p, 4096))
        shapes.add((4096, 4096, p))

    # --- 12. LLM-specific shapes ---
    llm_shapes = [
        # DeepSeek MoE
        (1, 1536, 7168),
        (1, 4608, 7168),
        (1, 7168, 2048),
        (1, 7168, 2304),
        (1, 7168, 256),
        (1, 576, 7168),
        (1, 512, 7168),
        (1, 3072, 1536),
        # LLaMA-7B
        (1, 4096, 4096),
        (32, 4096, 4096),
        (128, 4096, 4096),
        (1, 4096, 11008),
        (32, 4096, 11008),
        (1, 11008, 4096),
        (32, 11008, 4096),
        # LLaMA-70B
        (1, 8192, 8192),
        (32, 8192, 8192),
        (128, 8192, 8192),
        (1, 8192, 28672),
        (32, 8192, 28672),
        (1, 28672, 8192),
        # GPT-style attention
        (128, 128, 64),
        (128, 128, 128),
        (256, 256, 64),
        (512, 512, 64),
        (1024, 1024, 64),
        (2048, 2048, 64),
    ]
    shapes.update(llm_shapes)

    # --- 13. Non-power-of-2 common sizes ---
    for m in [48, 96, 192, 384, 576, 768, 1152, 1536, 2304, 3072, 4608, 6144]:
        shapes.add((m, m, m))
        shapes.add((m, 4096, 4096))

    return sorted(shapes)


def run_shape_batch(bin_dir, shapes, out_file, warmup=3, repeat=10):
    """Run all kernels against a batch of shapes, writing streaming log output."""
    executables = sorted(Path(bin_dir).glob("benchmark_gemm_universal_fp8_rcr_*"))
    if not executables:
        print(f"ERROR: No executables found in {bin_dir}", file=sys.stderr)
        return 0

    total_benchmarks = 0
    for shape_idx, (m, n, k) in enumerate(shapes):
        out_file.write("\n========================================\n")
        out_file.write(
            f"Shape {shape_idx + 1}: M={m} N={n} K={k} dtype=fp8 layout=rcr\n"
        )
        out_file.write("========================================\n")
        out_file.write(f"Found {len(executables)} kernels\n")
        out_file.flush()
        for exe in executables:
            try:
                result = subprocess.run(
                    [
                        str(exe),
                        f"-m={m}",
                        f"-n={n}",
                        f"-k={k}",
                        f"-warmup={warmup}",
                        f"-repeat={repeat}",
                        "-verify=0",
                    ],
                    capture_output=True,
                    text=True,
                    timeout=60,
                )
                output = result.stdout
                # Extract the JSON result block from the kernel's stdout.
                json_start = output.find("{")
                json_end = output.rfind("}") + 1
                if json_start >= 0 and json_end > json_start:
                    json_block = output[json_start:json_end]
                    try:
                        json.loads(json_block)  # validate before writing
                        out_file.write(json_block + "\n")
                        total_benchmarks += 1
                    except json.JSONDecodeError:
                        pass  # malformed output: skip this kernel's result
            except Exception:
                pass  # timeout or launch failure: skip this kernel
        out_file.flush()
        print(
            f"  Shape {shape_idx + 1}/{len(shapes)}: M={m} N={n} K={k} "
            f"({len(executables)} kernels)",
            file=sys.stderr,
            flush=True,
        )
    return total_benchmarks


def main():
    parser = argparse.ArgumentParser(
        description="Generate wide-coverage benchmark data"
    )
    parser.add_argument(
        "--bin_dir",
        default="/workspace/ck_tile/bin",
        help="Directory with benchmark executables",
    )
    parser.add_argument("--out_dir", default="data", help="Output directory")
    parser.add_argument(
        "--batch_size", type=int, default=25, help="Shapes per output file"
    )
    parser.add_argument("--warmup", type=int, default=3)
    parser.add_argument("--repeat", type=int, default=10)
    parser.add_argument(
        "--max_shapes", type=int, default=None, help="Limit total shapes (for testing)"
    )
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    shapes = generate_shape_list()
    if args.max_shapes:
        shapes = shapes[: args.max_shapes]

    print(f"Generated {len(shapes)} unique shapes", file=sys.stderr, flush=True)
    print(f"Bin dir: {args.bin_dir}", file=sys.stderr, flush=True)
    print(f"Output dir: {args.out_dir}", file=sys.stderr, flush=True)
    print(f"Batch size: {args.batch_size}", file=sys.stderr, flush=True)

    total = 0
    batch_idx = 0
    for i in range(0, len(shapes), args.batch_size):
        batch = shapes[i : i + args.batch_size]
        batch_idx += 1
        out_path = out_dir / f"wide_coverage_batch_{batch_idx:03d}.log"
        print(
            f"\nBatch {batch_idx}: shapes {i + 1}-{i + len(batch)} -> {out_path}",
            file=sys.stderr,
            flush=True,
        )
        with open(out_path, "w") as f:
            f.write(f"CK Tile Wide Coverage Benchmark Batch {batch_idx}\n")
            f.write("GPU ID: 0\n")
            f.write("Implementation: gemm_universal\n\n")
            count = run_shape_batch(
                args.bin_dir, batch, f, warmup=args.warmup, repeat=args.repeat
            )
        total += count
        print(
            f"  Batch {batch_idx} complete: {count} benchmarks",
            file=sys.stderr,
            flush=True,
        )
    print(
        f"\nTotal: {total} benchmarks across {len(shapes)} shapes",
        file=sys.stderr,
        flush=True,
    )


if __name__ == "__main__":
    main()