Files
composable_kernel/dispatcher/examples/gemm/python/09_ml_heuristic.py
Yaswanth Raparti c1127a36f5 [rocm-libraries] ROCm/rocm-libraries#5676 (commit 1d18339)
[CK][CK TILE] Autotuning heuristics infra for universal GEMM kernel selection (#5676)

## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile
dispatcher, enabling fast and accurate automatic kernel selection for
universal GEMM kernels. Instead of exhaustively searching 4600+ kernel
configurations (~46 seconds per problem shape), the ML heuristic predicts
the optimal kernel in microseconds while achieving >98% of oracle-best
performance.

## Technical Details

**ML infrastructure**

https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics
* Feature Engine
([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)):
55-feature extraction including problem dimensions, kernel
configuration, tile efficiency, and hardware profile
* Training Pipeline
([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)):
LightGBM regression with log-transform, GroupKFold cross-validation,
warm-start support
* Predictor
([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)):
Kernel ranking and TFLOPS prediction for problem shapes
* Evaluation
([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)):
Comprehensive metrics including efficiency, NDCG@k, shape family
analysis

**Data Generation Tools:**

*
[generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py):
Build and benchmark kernels across diverse problem shapes
*
[convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py):
Convert benchmark JSON to training-ready parquet format
*
[data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py):
Parse streaming benchmark logs into canonical datasets

**Examples**
*
[09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp):
C++ example demonstrating ML-based kernel selection
*
[09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py):
Python example with validation

**Pre-trained Models
(projects/composablekernel/dispatcher/heuristics/models/):**
* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean
efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean
efficiency)

## Test Plan

* Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M
(M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**
* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (90% of shapes achieve at least this fraction of
oracle best)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**
* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%
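
The efficiency figures above follow from a simple per-shape ratio: TFLOPS of the ML-selected kernel divided by TFLOPS of the oracle-best kernel, summarized by mean, 10th percentile (P10), and minimum. A toy computation with made-up TFLOPS values:

```python
import numpy as np

# Made-up per-shape TFLOPS; real values come from evaluate.py runs.
ml_selected = np.array([95.0, 120.0, 60.0, 200.0])  # ML-chosen kernel per shape
oracle_best = np.array([96.0, 121.0, 62.0, 205.0])  # best kernel per shape

eff = ml_selected / oracle_best
mean_eff = eff.mean()
p10_eff = np.percentile(eff, 10)  # 90% of shapes do at least this well
min_eff = eff.min()
print(f"mean {mean_eff:.2%}  P10 {p10_eff:.2%}  min {min_eff:.2%}")
```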

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-02 02:26:32 +00:00


#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 09: ML-Based Kernel Selection
Uses a trained LightGBM model to select the optimal kernel for each problem
size. The model predicts TFLOPS for every candidate in the kernel pool and
picks the highest-scoring one, which is then JIT-compiled and run.
This replaces the hand-crafted rules in 08_heuristics.py with a data-driven
approach achieving 97-98% of oracle-best TFLOPS efficiency.
Complexity: *****
Prerequisites:
- Trained model in dispatcher/heuristics/models/gemm_universal_fp8_gfx950/
- lightgbm, pandas, numpy, pyarrow installed
Usage:
python3 09_ml_heuristic.py
python3 09_ml_heuristic.py --dtype fp16 --arch gfx942
"""
import sys
import argparse
import time
from pathlib import Path
from dataclasses import dataclass
from typing import List

sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "heuristics"))

import numpy as np
from ctypes_utils import (
    KernelConfig,
    setup_gemm_dispatcher,
    cleanup_gemm,
)
from predict import Predictor


@dataclass
class KernelSpec:
    """Kernel specification -- same structure as 08_heuristics.py"""

    name: str
    tile_m: int
    tile_n: int
    tile_k: int
    pipeline: str = "compv3"
    scheduler: str = "intrawave"
    wave_m: int = 2
    wave_n: int = 2
    wave_k: int = 1
    warp_m: int = 32
    warp_n: int = 32
    warp_k: int = 16


# Kernel pool: representative configs spanning small to large tiles,
# compv3/compv4/mem pipelines, and intrawave/interwave schedulers.
KERNEL_POOL = [
    # Small tiles
    KernelSpec("s_64x64_k32_v3", 64, 64, 32, "compv3", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k64_v3", 64, 64, 64, "compv3", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k128_v3", 64, 64, 128, "compv3", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k32_v4", 64, 64, 32, "compv4", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k64_mem", 64, 64, 64, "mem", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k128_mem", 64, 64, 128, "mem", warp_m=16, warp_n=16),
    # Medium tiles
    KernelSpec("m_128x128_k32_v3", 128, 128, 32, "compv3"),
    KernelSpec("m_128x128_k64_v3", 128, 128, 64, "compv3"),
    KernelSpec("m_128x128_k128_v3", 128, 128, 128, "compv3"),
    KernelSpec("m_128x128_k32_v4", 128, 128, 32, "compv4"),
    KernelSpec("m_128x128_k64_v4", 128, 128, 64, "compv4"),
    KernelSpec("m_128x128_k64_mem", 128, 128, 64, "mem"),
    KernelSpec("m_128x128_k128_mem", 128, 128, 128, "mem"),
    # Rectangular medium
    KernelSpec("r_64x128_k32", 64, 128, 32, "compv3", warp_m=16),
    KernelSpec("r_128x64_k32", 128, 64, 32, "compv3", warp_n=16),
    KernelSpec("r_64x128_k64", 64, 128, 64, "compv3", warp_m=16),
    KernelSpec("r_128x64_k64", 128, 64, 64, "compv3", warp_n=16),
    # Large tiles
    KernelSpec("l_256x128_k32", 256, 128, 32, "compv3"),
    KernelSpec("l_128x256_k32", 128, 256, 32, "compv3"),
    KernelSpec("l_256x256_k32", 256, 256, 32, "compv3"),
    KernelSpec("l_256x256_k64", 256, 256, 64, "compv3"),
    # Interwave variants
    KernelSpec("m_128x128_k64_iw", 128, 128, 64, "compv3", "interwave"),
    KernelSpec("m_128x128_k64_mem_iw", 128, 128, 64, "mem", "interwave"),
]


def spec_to_feature_dict(spec: KernelSpec, dtype: str, layout: str) -> dict:
    """Convert a KernelSpec to the dict format the feature engine expects.

    Note: pad_m/n/k default to True to match KernelConfig defaults and actual
    compiled kernels. This ensures the ML model receives the correct padding
    flags that will be used during JIT compilation.
    """
    return {
        "kernel_name": spec.name,
        "tile_m": spec.tile_m,
        "tile_n": spec.tile_n,
        "tile_k": spec.tile_k,
        "warp_m": spec.wave_m,
        "warp_n": spec.wave_n,
        "warp_k": spec.wave_k,
        "warp_tile_m": spec.warp_m,
        "warp_tile_n": spec.warp_n,
        "warp_tile_k": spec.warp_k,
        "pipeline": spec.pipeline,
        "scheduler": spec.scheduler,
        "epilogue": "cshuffle",
        "pad_m": True,  # Match KernelConfig default
        "pad_n": True,  # Match KernelConfig default
        "pad_k": True,  # Match KernelConfig default
        "persistent": False,
        "dtype": dtype,
        "layout": layout,
    }


def spec_to_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
    """Convert a KernelSpec to the dispatcher's KernelConfig for JIT compilation."""
    return KernelConfig(
        dtype_a=dtype,
        dtype_b=dtype,
        dtype_c=dtype,
        dtype_acc="fp32",
        layout_a="row",
        layout_b="col",
        layout_c="row",
        tile_m=spec.tile_m,
        tile_n=spec.tile_n,
        tile_k=spec.tile_k,
        wave_m=spec.wave_m,
        wave_n=spec.wave_n,
        wave_k=spec.wave_k,
        warp_m=spec.warp_m,
        warp_n=spec.warp_n,
        warp_k=spec.warp_k,
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        gfx_arch=arch,
    )


def ml_select_kernel(
    predictor: Predictor,
    pool: List[KernelSpec],
    M: int,
    N: int,
    K: int,
    dtype: str,
    layout: str,
) -> tuple:
    """Score all kernels in the pool and return (best_spec, predicted_tflops)."""
    problem = {"m": M, "n": N, "k": K, "dtype": dtype, "layout": layout, "split_k": 1}
    kernel_dicts = [spec_to_feature_dict(s, dtype, layout) for s in pool]
    ranked = predictor.rank_kernels(problem, kernel_dicts)
    if not ranked:
        return pool[0], 0.0
    best_name, best_tflops = ranked[0]
    best_spec = next((s for s in pool if s.name == best_name), pool[0])
    return best_spec, best_tflops


def main():
    parser = argparse.ArgumentParser(description="ML-based kernel selection for GEMM")
    parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16", "fp8"])
    parser.add_argument("--arch", default="gfx942")
    parser.add_argument(
        "--model_dir",
        default=str(
            Path(__file__).parent.parent.parent.parent
            / "heuristics"
            / "models"
            / "gemm_universal_fp8_gfx950"
        ),
    )
    parser.add_argument(
        "--no_run", action="store_true", help="Only predict, don't run GEMMs"
    )
    args = parser.parse_args()

    print("=" * 75)
    print(" Example 09: ML-Based Kernel Selection")
    print("=" * 75)
    print(f"\n Model: {args.model_dir}")
    print(f" Dtype: {args.dtype}")
    print(f" Arch: {args.arch}")
    print(f" Pool: {len(KERNEL_POOL)} kernels")

    predictor = Predictor(args.model_dir)
    print(" Model loaded successfully")
    # Neither bf16 nor fp8 has a native NumPy dtype, so host-side buffers
    # use float16 for all supported dtypes.
    np_dtype = np.float16
    test_sizes = [
        (128, 128, 64),
        (256, 256, 128),
        (512, 512, 256),
        (1024, 1024, 512),
        (2048, 2048, 1024),
    ]

    header = f"{'Shape':<20} {'Selected Kernel':<25} {'Pred TFLOPS':>12}"
    if not args.no_run:
        header += f" {'Time (ms)':>10} {'TFLOPS':>10} {'Status':<8}"
    print(f"\n {header}")
    print(" " + "-" * len(header))

    results = []
    for M, N, K in test_sizes:
        t0 = time.time()
        best_spec, pred_tflops = ml_select_kernel(
            predictor, KERNEL_POOL, M, N, K, args.dtype, "rcr"
        )
        _ = (time.time() - t0) * 1000  # ML selection time (unused)

        size_str = f"{M}x{N}x{K}"
        line = f" {size_str:<20} {best_spec.name:<25} {pred_tflops:>12.2f}"

        if args.no_run:
            print(line)
            results.append((size_str, best_spec.name, True, 0, pred_tflops))
            continue

        config = spec_to_kernel_config(best_spec, args.dtype, args.arch)
        setup = setup_gemm_dispatcher(
            config=config,
            registry_name=f"ml_{best_spec.name}",
            verbose=False,
            auto_rebuild=True,
        )
        if not setup.success:
            line += f" {'N/A':>10} {'N/A':>10} {'BUILD':>8}"
            print(line)
            results.append((size_str, best_spec.name, False, 0, 0))
            cleanup_gemm()
            continue

        dispatcher = setup.dispatcher
        if not dispatcher.is_supported(M, N, K):
            line += f" {'N/A':>10} {'N/A':>10} {'UNSUP':>8}"
            print(line)
            results.append((size_str, best_spec.name, False, 0, 0))
            cleanup_gemm()
            continue

        np.random.seed(42)
        A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
        B = (np.random.randn(K, N) * 0.1).astype(np_dtype)

        result = dispatcher.run(A, B, M, N, K)
        if result.success:
            C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(
                np_dtype
            )
            max_err = np.max(np.abs(result.output - C_ref))
            passed = max_err < 1e-2
            status = "PASS" if passed else "FAIL"
            line += f" {result.time_ms:>10.4f} {result.tflops:>10.2f} {status:<8}"
            results.append(
                (size_str, best_spec.name, passed, result.time_ms, result.tflops)
            )
        else:
            line += f" {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
            results.append((size_str, best_spec.name, False, 0, 0))
        print(line)
        cleanup_gemm()

    # Summary
    print("\n" + "=" * 75)
    print(" SUMMARY")
    print("=" * 75)
    passed = sum(1 for r in results if r[2])
    print(f"\n Results: {passed}/{len(results)} tests passed")
    valid = [r for r in results if r[2] and r[4] > 0]
    if valid:
        avg = sum(r[4] for r in valid) / len(valid)
        print(f" Average TFLOPS: {avg:.2f}")
    if passed == len(results):
        print("\n *** ALL TESTS PASSED ***")
    print("=" * 75)
    return 0 if passed == len(results) else 1


if __name__ == "__main__":
    sys.exit(main())