mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
[CK][CK TILE]Autotuning heuristics infra for universal GEMM kernel selection (#5676) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation This PR adds ML-based kernel selection heuristics to the CK Tile dispatcher, enabling fast and accurate automatic kernel selection for Universal Gemm kernels. Instead of requiring exhaustive search through 4600+ kernel configurations (taking ~46 seconds per problem shape), the ML heuristic predicts optimal kernels in microseconds while achieving >98% of oracle-best performance. ## Technical Details **ML infrastructure** https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics * Feature Engine ([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)): 55-feature extraction including problem dimensions, kernel configuration, tile efficiency, and hardware profile * Training Pipeline ([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)): LightGBM regression with log-transform, GroupKFold cross-validation, warm-start support * Predictor ([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)): Kernel ranking and TFLOPS prediction for problem shapes * Evaluation ([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)): Comprehensive metrics including efficiency, NDCG@k, shape family analysis **Data Generation Tools:** * 
[generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py): Build and benchmark kernels across diverse problem shapes * [convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py): Convert benchmark JSON to training-ready parquet format * [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py): Parse streaming benchmark logs into canonical datasets **Examples** * [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp): C++ example demonstrating ML-based kernel selection * [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py): Python example with validation **Pre-trained Models (projects/composablekernel/dispatcher/heuristics/models/):** * gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean efficiency) * gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean efficiency) ## Test Plan * Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8 * All shape families tested: tiny M (M<8), small M, medium M, large M (M≥1024) * All pipeline types: compv3, compv4, mem ## Test Result **fp16 Model (gfx950, RCR layout)** * Mean Efficiency: 99.36% * P10 Efficiency: 98.05% (90th percentile of shapes achieve ≥98% of oracle best) * Min Efficiency: 95.45% **fp8 Model (gfx950, RCR layout)** * Mean Efficiency: 98.28% (original), 97.51% (wide coverage) * P10 Efficiency: 94.64% (original), 93.89% (wide coverage) * Min Efficiency: 84.5% ## 
Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
290 lines
9.1 KiB
Python
290 lines
9.1 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Wide-coverage benchmark data generator.
|
|
|
|
Generates benchmark results for hundreds of diverse GEMM shapes across all
|
|
corner cases: skinny M, tall N, deep K, M=1, prime dimensions, power-of-2,
|
|
LLM inference shapes, training shapes, and edge cases.
|
|
|
|
Runs all 4608 kernels in /workspace/ck_tile/bin/ against each shape and
|
|
writes streaming log output parseable by data_pipeline.py.
|
|
|
|
Usage:
|
|
python3 generate_wide_coverage.py --bin_dir /workspace/ck_tile/bin \
|
|
--out_dir data/ --batch_size 20 --warmup 3 --repeat 10
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
|
|
def generate_shape_list():
    """Enumerate the wide-coverage (M, N, K) GEMM shape sweep.

    Covers every corner case the heuristic must handle:
      1. M=1 (single token inference) -- the hardest case
      2. Tiny M (2-16) -- small batch inference
      3. Small M (32-128) -- medium batch
      4. Medium M (256-2048) -- large batch / training
      5. Large M (4096-20480) -- very large batch
      6. Square shapes (powers of 2)
      7. Skinny M, tall N (M << N)
      8. Tall M, skinny N (M >> N)
      9. Deep K (K >> M, N) -- compute-bound
      10. Shallow K (K << M, N) -- memory-bound
      11. Prime dimensions -- worst-case for tiling
      12. LLM-specific shapes (DeepSeek, LLaMA, etc.)
      13. Non-power-of-2 common sizes

    Returns:
        Sorted list of unique (M, N, K) tuples.
    """
    collected = set()

    def add_grid(ms, ns, ks):
        # Add the full Cartesian product of the candidate M/N/K lists.
        collected.update((m, n, k) for m in ms for n in ns for k in ks)

    # 1. M=1 (single token) across various N, K
    add_grid(
        [1],
        [512, 1024, 1536, 2048, 3072, 4096, 4608, 7168, 8192, 11008, 14336, 28672],
        [256, 512, 1024, 1536, 2048, 2304, 4096, 7168, 8192],
    )

    # 2. Tiny M (2-16)
    add_grid([2, 4, 8, 16], [512, 1536, 4096, 7168], [256, 1024, 4096, 7168])

    # 3. Small M (32-128)
    add_grid(
        [32, 48, 64, 96, 128],
        [512, 1536, 4096, 7168, 8192],
        [256, 512, 2048, 4096, 7168],
    )

    # 4. Medium M (256-2048)
    add_grid(
        [256, 384, 512, 768, 1024, 1536, 2048],
        [512, 1536, 4096, 7168],
        [256, 1024, 2048, 4096, 7168],
    )

    # 5. Large M (4096-20480)
    add_grid(
        [4096, 8192, 12288, 16384, 20480],
        [512, 1536, 4096, 7168],
        [256, 1024, 2048, 7168],
    )

    # 6. Square powers of two: 32 .. 8192
    collected.update((2**p, 2**p, 2**p) for p in range(5, 14))

    # 7. Skinny M, tall N
    add_grid([1, 4, 16, 64], [8192, 16384, 28672], [1024, 4096, 8192])

    # 8. Tall M, skinny N
    add_grid([4096, 8192, 16384], [32, 64, 128, 256], [1024, 4096])

    # 9. Deep K (compute-bound)
    add_grid([16, 64, 256], [16, 64, 256], [4096, 8192, 16384, 32768])

    # 10. Shallow K (memory-bound)
    add_grid([1024, 4096, 8192], [1024, 4096, 8192], [16, 32, 64, 128])

    # 11. Prime dimensions -- worst case for tiling
    primes = [17, 31, 37, 127, 251, 509, 1021, 2039, 4093]
    collected.update((p, p, p) for p in primes)
    for p in primes[:5]:
        # One prime axis at a time against an otherwise-regular 4096 shape.
        collected.update({(p, 4096, 4096), (4096, p, 4096), (4096, 4096, p)})

    # 12. LLM-specific shapes
    collected.update(
        [
            # DeepSeek MoE
            (1, 1536, 7168),
            (1, 4608, 7168),
            (1, 7168, 2048),
            (1, 7168, 2304),
            (1, 7168, 256),
            (1, 576, 7168),
            (1, 512, 7168),
            (1, 3072, 1536),
            # LLaMA-7B
            (1, 4096, 4096),
            (32, 4096, 4096),
            (128, 4096, 4096),
            (1, 4096, 11008),
            (32, 4096, 11008),
            (1, 11008, 4096),
            (32, 11008, 4096),
            # LLaMA-70B
            (1, 8192, 8192),
            (32, 8192, 8192),
            (128, 8192, 8192),
            (1, 8192, 28672),
            (32, 8192, 28672),
            (1, 28672, 8192),
            # GPT-style attention
            (128, 128, 64),
            (128, 128, 128),
            (256, 256, 64),
            (512, 512, 64),
            (1024, 1024, 64),
            (2048, 2048, 64),
        ]
    )

    # 13. Non-power-of-2 common sizes
    for m in [48, 96, 192, 384, 576, 768, 1152, 1536, 2304, 3072, 4608, 6144]:
        collected.update({(m, m, m), (m, 4096, 4096)})

    return sorted(collected)
|
|
|
|
|
|
def run_shape_batch(bin_dir, shapes, out_file, warmup=3, repeat=10):
    """Benchmark every fp8 RCR kernel executable against a batch of shapes.

    For each (M, N, K) shape, runs every ``benchmark_gemm_universal_fp8_rcr_*``
    executable in *bin_dir* and appends a streaming-log section (shape header
    followed by one validated JSON result block per successful kernel run) to
    *out_file*, in the format parsed by data_pipeline.py.

    Args:
        bin_dir: Directory containing the benchmark executables.
        shapes: Iterable of (M, N, K) tuples to benchmark.
        out_file: Writable text file receiving the streaming log output.
        warmup: Warmup iterations passed to each kernel executable.
        repeat: Timed iterations passed to each kernel executable.

    Returns:
        Total number of benchmark results (valid JSON blocks) written.
    """
    executables = sorted(Path(bin_dir).glob("benchmark_gemm_universal_fp8_rcr_*"))
    if not executables:
        print(f"ERROR: No executables found in {bin_dir}", file=sys.stderr)
        return 0

    total_benchmarks = 0

    for shape_idx, (m, n, k) in enumerate(shapes):
        out_file.write("\n========================================\n")
        out_file.write(
            f"Shape {shape_idx + 1}: M={m} N={n} K={k} dtype=fp8 layout=rcr\n"
        )
        out_file.write("========================================\n")
        out_file.write(f"Found {len(executables)} kernels\n")
        out_file.flush()

        failures = 0  # kernels that timed out, crashed, or emitted no JSON
        for exe in executables:
            # Best-effort sweep: a hung or crashing kernel must not abort the
            # whole batch. TimeoutExpired is a subclass of Exception, so one
            # broad handler around the subprocess call alone is sufficient.
            try:
                result = subprocess.run(
                    [
                        str(exe),
                        f"-m={m}",
                        f"-n={n}",
                        f"-k={k}",
                        f"-warmup={warmup}",
                        f"-repeat={repeat}",
                        "-verify=0",
                    ],
                    capture_output=True,
                    text=True,
                    timeout=60,
                )
            except Exception:
                failures += 1
                continue

            output = result.stdout
            # The executable is expected to print a single JSON object, so
            # slice from the first '{' to the last '}' of its stdout.
            json_start = output.find("{")
            json_end = output.rfind("}") + 1
            if json_start >= 0 and json_end > json_start:
                json_block = output[json_start:json_end]
                try:
                    json.loads(json_block)  # validate before logging
                except json.JSONDecodeError:
                    failures += 1
                else:
                    out_file.write(json_block + "\n")
                    total_benchmarks += 1
            else:
                failures += 1

        out_file.flush()
        # Per-shape progress, including how many kernels produced no result.
        print(
            f" Shape {shape_idx + 1}/{len(shapes)}: M={m} N={n} K={k} "
            f"({len(executables)} kernels, {failures} failed)",
            file=sys.stderr,
            flush=True,
        )

    return total_benchmarks
|
|
|
|
|
|
def main():
    """CLI entry point: benchmark all shapes in batches of streaming log files."""
    parser = argparse.ArgumentParser(
        description="Generate wide-coverage benchmark data"
    )
    parser.add_argument(
        "--bin_dir",
        default="/workspace/ck_tile/bin",
        help="Directory with benchmark executables",
    )
    parser.add_argument("--out_dir", default="data", help="Output directory")
    parser.add_argument(
        "--batch_size", type=int, default=25, help="Shapes per output file"
    )
    parser.add_argument("--warmup", type=int, default=3)
    parser.add_argument("--repeat", type=int, default=10)
    parser.add_argument(
        "--max_shapes", type=int, default=None, help="Limit total shapes (for testing)"
    )
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    shapes = generate_shape_list()
    if args.max_shapes:
        shapes = shapes[: args.max_shapes]

    def log(msg):
        # All progress goes to stderr so stdout stays clean for redirection.
        print(msg, file=sys.stderr, flush=True)

    log(f"Generated {len(shapes)} unique shapes")
    log(f"Bin dir: {args.bin_dir}")
    log(f"Output dir: {args.out_dir}")
    log(f"Batch size: {args.batch_size}")

    grand_total = 0
    batch_starts = range(0, len(shapes), args.batch_size)
    for batch_idx, start in enumerate(batch_starts, start=1):
        batch = shapes[start : start + args.batch_size]
        out_path = out_dir / f"wide_coverage_batch_{batch_idx:03d}.log"

        log(f"\nBatch {batch_idx}: shapes {start + 1}-{start + len(batch)} -> {out_path}")

        # Each batch gets its own log file with a small header that
        # data_pipeline.py uses to identify the run.
        with open(out_path, "w") as f:
            f.write(f"CK Tile Wide Coverage Benchmark Batch {batch_idx}\n")
            f.write("GPU ID: 0\n")
            f.write("Implementation: gemm_universal\n\n")
            count = run_shape_batch(
                args.bin_dir, batch, f, warmup=args.warmup, repeat=args.repeat
            )
        grand_total += count

        log(f" Batch {batch_idx} complete: {count} benchmarks")

    log(f"\nTotal: {grand_total} benchmarks across {len(shapes)} shapes")
|
|
|
|
|
|
# Script entry point: run the full wide-coverage sweep when executed directly.
if __name__ == "__main__":
    main()
|