mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
## Motivation This PR adds ML-based kernel selection heuristics to the CK Tile dispatcher, enabling fast and accurate automatic kernel selection for Universal Gemm kernels. Instead of requiring exhaustive search through 4600+ kernel configurations (taking ~46 seconds per problem shape), the ML heuristic predicts optimal kernels in microseconds while achieving >98% of oracle-best performance. ## Technical Details **ML infrastructure** https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics * Feature Engine ([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)): 55-feature extraction including problem dimensions, kernel configuration, tile efficiency, and hardware profile * Training Pipeline ([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)): LightGBM regression with log-transform, GroupKFold cross-validation, warm-start support * Predictor ([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)): Kernel ranking and TFLOPS prediction for problem shapes * Evaluation ([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)): Comprehensive metrics including efficiency, NDCG@k, shape family analysis **Data Generation Tools:** * [generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py): Build and benchmark kernels across diverse problem shapes * 
[convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py): Convert benchmark JSON to training-ready parquet format * [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py): Parse streaming benchmark logs into canonical datasets **Examples** * [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp): C++ example demonstrating ML-based kernel selection * [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py): Python example with validation **Pre-trained Models (projects/composablekernel/dispatcher/heuristics/models/):** * gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean efficiency) * gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean efficiency) ## Test Plan * Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8 * All shape families tested: tiny M (M<8), small M, medium M, large M (M≥1024) * All pipeline types: compv3, compv4, mem ## Test Result **fp16 Model (gfx950, RCR layout)** * Mean Efficiency: 99.36% * P10 Efficiency: 98.05% (90th percentile of shapes achieve ≥98% of oracle best) * Min Efficiency: 95.45% **fp8 Model (gfx950, RCR layout)** * Mean Efficiency: 98.28% (original), 97.51% (wide coverage) * P10 Efficiency: 94.64% (original), 93.89% (wide coverage) * Min Efficiency: 84.5% ## Submission Checklist - [x ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
--------- Co-authored-by: Vidyasagar Ananthan <vidyasagar.ananthan@amd.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
167 lines
5.4 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Supplementary edge-case benchmark generator for N=1 and K=1 dimensions.
|
|
|
|
These shapes represent vector-matrix multiply (N=1), rank-1 updates (K=1),
|
|
and other degenerate GEMM cases that stress tile efficiency and padding logic.
|
|
"""
|
|
|
|
import json
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
|
|
def generate_edge_shapes():
    """Build the deduplicated, sorted list of degenerate (M, N, K) GEMM shapes.

    Covers N=1 (vector-matrix multiply), K=1 (rank-1 update), M=N=1
    (dot product), the remaining scalar-vector cases, the fully
    degenerate 1x1x1 problem, and small-N / small-K bands in [2, 16].

    Returns:
        Sorted list of unique (m, n, k) tuples.
    """
    # Dimension palettes shared across the categories below.
    pow2_dims = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
    wide_dims = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 7168, 8192]
    sparse_dims = [1, 16, 64, 256, 1024, 4096, 8192]
    medium_dims = [64, 256, 1024, 4096]
    small_dims = [2, 3, 4, 7, 8, 15, 16]

    collected = set()

    # N=1: vector-matrix multiply / single output column.
    collected.update((m, 1, k) for m in pow2_dims for k in wide_dims)

    # K=1: rank-1 update / outer product.
    collected.update((m, n, 1) for m in pow2_dims for n in wide_dims)

    # M=1, N=1: dot product.
    collected.update((1, 1, k) for k in sparse_dims)

    # M=1, K=1: scalar-vector.
    collected.update((1, n, 1) for n in sparse_dims)

    # N=1, K=1: scalar-vector.
    collected.update((m, 1, 1) for m in sparse_dims)

    # Fully degenerate 1x1x1.
    collected.add((1, 1, 1))

    # Small N band (2-16).
    collected.update(
        (m, n, k) for m in medium_dims for n in small_dims for k in medium_dims
    )

    # Small K band (2-16).
    collected.update(
        (m, n, k) for m in medium_dims for n in medium_dims for k in small_dims
    )

    return sorted(collected)
|
|
|
|
|
|
def run_shapes(bin_dir, shapes, out_file, warmup=3, repeat=10):
    """Benchmark every fp8 RCR kernel executable against each (M, N, K) shape.

    Writes a streaming log to *out_file*: a header per shape followed by one
    validated JSON result block per successful kernel run. Progress and
    per-kernel failures go to stderr.

    Args:
        bin_dir: Directory containing benchmark_gemm_universal_fp8_rcr_*
            executables.
        shapes: Iterable of (m, n, k) problem dimensions.
        out_file: Open, writable text file-like object for the streaming log.
        warmup: Warmup iteration count forwarded to each kernel binary.
        repeat: Timed iteration count forwarded to each kernel binary.

    Returns:
        Number of valid JSON benchmark results written to out_file
        (0 when no executables are found).
    """
    executables = sorted(Path(bin_dir).glob("benchmark_gemm_universal_fp8_rcr_*"))
    if not executables:
        print(f"ERROR: No executables found in {bin_dir}", file=sys.stderr)
        return 0

    total = 0
    for idx, (m, n, k) in enumerate(shapes):
        out_file.write("\n========================================\n")
        out_file.write(f"Shape {idx + 1}: M={m} N={n} K={k} dtype=fp8 layout=rcr\n")
        out_file.write("========================================\n")
        out_file.write(f"Found {len(executables)} kernels\n")
        out_file.flush()

        for exe in executables:
            try:
                result = subprocess.run(
                    [
                        str(exe),
                        f"-m={m}",
                        f"-n={n}",
                        f"-k={k}",
                        f"-warmup={warmup}",
                        f"-repeat={repeat}",
                        "-verify=0",
                    ],
                    capture_output=True,
                    text=True,
                    timeout=60,
                )
            except (subprocess.TimeoutExpired, OSError) as err:
                # Best-effort sweep: skip kernels that hang or fail to launch,
                # but surface the failure instead of silently swallowing every
                # exception (the old bare `Exception` catch hid real bugs too).
                print(
                    f" WARN: {exe.name} M={m} N={n} K={k}: {err}",
                    file=sys.stderr,
                )
                continue

            # The binaries emit a single JSON result object surrounded by
            # human-readable text; extract the outermost {...} span.
            output = result.stdout
            json_start = output.find("{")
            json_end = output.rfind("}") + 1
            if json_start < 0 or json_end <= json_start:
                continue
            json_block = output[json_start:json_end]
            try:
                json.loads(json_block)  # validate before logging
            except json.JSONDecodeError:
                continue
            out_file.write(json_block + "\n")
            total += 1

        out_file.flush()
        print(
            f" Shape {idx + 1}/{len(shapes)}: M={m} N={n} K={k}",
            file=sys.stderr,
            flush=True,
        )

    return total
|
|
|
|
|
|
def main():
    """Drive the edge-case sweep: generate shapes, benchmark them in batches.

    Each 25-shape batch gets its own log file under data/edge_dims/ so an
    interrupted run still leaves complete per-batch logs behind. All progress
    output goes to stderr; benchmark results stream into the batch logs.
    """
    bin_dir = "/workspace/ck_tile/bin"
    out_dir = Path("data/edge_dims")
    out_dir.mkdir(parents=True, exist_ok=True)

    shapes = generate_edge_shapes()
    print(f"Generated {len(shapes)} edge-case shapes", file=sys.stderr, flush=True)

    # Summarize category coverage up front for quick sanity checking.
    n1_count = sum(1 for m, n, k in shapes if n == 1)
    k1_count = sum(1 for m, n, k in shapes if k == 1)
    both1 = sum(1 for m, n, k in shapes if n == 1 and k == 1)
    small_n = sum(1 for m, n, k in shapes if 2 <= n <= 16)
    small_k = sum(1 for m, n, k in shapes if 2 <= k <= 16)
    print(
        f" N=1: {n1_count}, K=1: {k1_count}, both=1: {both1}",
        file=sys.stderr,
        flush=True,
    )
    print(
        f" Small N(2-16): {small_n}, Small K(2-16): {small_k}",
        file=sys.stderr,
        flush=True,
    )

    batch_size = 25
    total = 0
    # enumerate(..., start=1) replaces the manual batch_idx counter.
    for batch_idx, i in enumerate(range(0, len(shapes), batch_size), start=1):
        batch = shapes[i : i + batch_size]
        out_path = out_dir / f"edge_dims_batch_{batch_idx:03d}.log"
        print(
            f"\nBatch {batch_idx}: shapes {i + 1}-{i + len(batch)} -> {out_path}",
            file=sys.stderr,
            flush=True,
        )

        with open(out_path, "w") as f:
            f.write(f"CK Tile Edge Dims Benchmark Batch {batch_idx}\n")
            f.write("GPU ID: 0\nImplementation: gemm_universal\n\n")
            count = run_shapes(bin_dir, batch, f, warmup=3, repeat=10)
        total += count

        print(
            f" Batch {batch_idx} done: {count} results", file=sys.stderr, flush=True
        )

    print(
        f"\nTotal: {total} benchmarks across {len(shapes)} shapes",
        file=sys.stderr,
        flush=True,
    )


if __name__ == "__main__":
    main()