Files
composable_kernel/dispatcher/heuristics/dispatcher_integration.py
Yaswanth Raparti 91dbdfa476 [CK][CK TILE]Autotuning heuristics infra for universal GEMM kernel selection (#5676)
## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile
dispatcher, enabling fast and accurate automatic kernel selection for
Universal GEMM kernels. Instead of requiring an exhaustive search through
4600+ kernel configurations (~46 seconds per problem shape), the
ML heuristic predicts optimal kernels in microseconds while achieving
>98% of oracle-best performance.

## Technical Details

**ML Infrastructure:**

https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics
* Feature Engine
([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)):
55-feature extraction including problem dimensions, kernel
configuration, tile efficiency, and hardware profile
* Training Pipeline
([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)):
LightGBM regression with log-transform, GroupKFold cross-validation,
warm-start support
* Predictor
([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)):
Kernel ranking and TFLOPS prediction for problem shapes
* Evaluation
([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)):
Comprehensive metrics including efficiency, NDCG@k, shape family
analysis
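
The NDCG@k metric mentioned for the evaluation step can be sketched in a few lines. This is a minimal stdlib-only illustration with made-up kernel names and TFLOPS values; evaluate.py's exact formulation may differ:

```python
import math

def dcg_at_k(gains: list[float], k: int) -> float:
    # Discounted cumulative gain: rank i contributes gain / log2(i + 2).
    return sum(g / math.log2(i + 2) for i, g in enumerate(gains[:k]))

def ndcg_at_k(predicted_order: list[str], true_tflops: dict, k: int = 5) -> float:
    # Gain of each predicted rank is that kernel's measured TFLOPS;
    # normalize by the DCG of the ideal (oracle) ordering.
    gains = [true_tflops[name] for name in predicted_order]
    ideal = sorted(true_tflops.values(), reverse=True)
    idcg = dcg_at_k(ideal, k)
    return dcg_at_k(gains, k) / idcg if idcg > 0 else 0.0

# Hypothetical data: the model ranks kern_b first, but kern_a is truly fastest.
measured = {"kern_a": 900.0, "kern_b": 850.0, "kern_c": 400.0}
score = ndcg_at_k(["kern_b", "kern_a", "kern_c"], measured, k=3)
```

A perfect ranking scores 1.0; swapping the top two kernels costs only a little because their measured TFLOPS are close, which is what makes NDCG a useful complement to top-1 efficiency.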

**Data Generation Tools:**

*
[generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py):
Build and benchmark kernels across diverse problem shapes
*
[convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py):
Convert benchmark JSON to training-ready parquet format
*
[data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py):
Parse streaming benchmark logs into canonical datasets
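
The JSON-to-training-row conversion can be sketched as follows. The record schema here is hypothetical (the real converter targets parquet, which additionally needs a columnar writer such as pyarrow):

```python
import json

def flatten_record(rec: dict) -> dict:
    # Flatten one nested benchmark record into a single flat training row
    # (hypothetical JSON schema; the real converter writes parquet columns).
    row = {
        "m": rec["problem"]["m"],
        "n": rec["problem"]["n"],
        "k": rec["problem"]["k"],
        "kernel_name": rec["kernel"]["name"],
        "tflops": rec["result"]["tflops"],
    }
    for d in "mnk":  # carry tile sizes through under their original names
        row[f"tile_{d}"] = rec["kernel"][f"tile_{d}"]
    return row

raw = json.loads("""
{"problem": {"m": 1024, "n": 1024, "k": 1024},
 "kernel": {"name": "gemm_universal_x", "tile_m": 128, "tile_n": 128, "tile_k": 64},
 "result": {"tflops": 812.5}}
""")
row = flatten_record(raw)
```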

**Examples:**
*
[09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp):
C++ example demonstrating ML-based kernel selection
*
[09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py):
Python example with validation


**Pre-trained Models (projects/composablekernel/dispatcher/heuristics/models/):**
* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean
efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean
efficiency)


## Test Plan

* Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M
(M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**
* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (90% of shapes achieve ≥98% of oracle-best
performance)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**
* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%
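
For reference, summary statistics of this kind can be computed from per-shape efficiencies like so (a stdlib-only sketch using a nearest-rank percentile; the numbers below are illustrative, not the actual evaluation data):

```python
import math

def efficiency_summary(per_shape_eff: list[float]) -> dict:
    # per_shape_eff: achieved / oracle-best TFLOPS per shape, in percent.
    xs = sorted(per_shape_eff)
    mean = sum(xs) / len(xs)
    # Nearest-rank 10th percentile: 90% of shapes score at or above it.
    p10 = xs[max(0, math.ceil(0.10 * len(xs)) - 1)]
    return {"mean": mean, "p10": p10, "min": xs[0]}

# Illustrative per-shape efficiencies (not the real benchmark results).
stats = efficiency_summary(
    [95.45, 96.8, 97.2, 98.05, 98.3, 98.6, 99.0, 99.2, 99.4, 99.7, 100.0]
)
```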

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Vidyasagar Ananthan <vidyasagar.ananthan@amd.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-01 19:25:55 -07:00


#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Dispatcher integration for ML-based kernel selection.
Bridges the trained LightGBM Predictor with the CK Tile dispatcher's
kernel selection flow. Provides heuristic functions compatible with
both the Python pre-selection pattern (08_heuristics.py style) and
the C++ HeuristicFunction signature.
Name mapping between feature engine and dispatcher KernelConfig:
Feature engine Dispatcher KernelConfig
--------------------- ----------------------
warp_m (warps/block) wave_m
warp_n wave_n
warp_k wave_k
warp_tile_m warp_m
warp_tile_n warp_n
warp_tile_k warp_k
Usage:
from dispatcher_integration import create_ml_heuristic
heuristic = create_ml_heuristic("models/gemm_universal_fp8_gfx950")
best_spec = heuristic(M=1024, N=1024, K=1024, kernel_pool=KERNEL_POOL)
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from data_pipeline import parse_kernel_name
from predict import Predictor

LAYOUT_TO_DISPATCHER = {
    "rcr": ("row", "col", "row"),
    "rrr": ("row", "row", "row"),
    "crr": ("col", "row", "row"),
    "ccr": ("col", "col", "row"),
}

DTYPE_TO_C_DTYPE = {
    "fp8": "fp16",
    "fp16": "fp16",
    "bf16": "bf16",
    "fp32": "fp32",
}


@dataclass
class MLKernelSpec:
    """Kernel spec returned by the ML heuristic, compatible with the dispatcher
    example pattern. Carries both the feature-engine-space config and the
    dispatcher-space KernelConfig fields."""

    kernel_name: str
    predicted_tflops: float
    tile_m: int
    tile_n: int
    tile_k: int
    wave_m: int
    wave_n: int
    wave_k: int
    warp_m: int
    warp_n: int
    warp_k: int
    pipeline: str
    scheduler: str
    epilogue: str
    pad_m: bool
    pad_n: bool
    pad_k: bool
    persistent: bool


def kernel_config_to_feature_dict(kernel_name: str) -> dict:
    """Parse a tile-engine kernel name into a feature-engine-compatible dict.

    Returns a dict with fields matching what GemmUniversalFeatureEngine.extract()
    expects for the kernel parameter: tile_m/n/k, warp_m/n/k (warps per block),
    warp_tile_m/n/k, pipeline, scheduler, epilogue, pad_m/n/k, persistent.
    """
    parsed = parse_kernel_name(kernel_name)
    if not parsed:
        return {}
    parsed["kernel_name"] = kernel_name
    return parsed


def feature_dict_to_dispatcher_config(
    feat: dict, dtype: str = "fp8", arch: str = "gfx950"
) -> dict:
    """Convert a feature-engine kernel dict to dispatcher KernelConfig fields.

    Handles the naming inversion:
        feature engine warp_m      -> KernelConfig wave_m (warps per block)
        feature engine warp_tile_m -> KernelConfig warp_m (elements per warp)
    """
    layout = feat.get("layout", "rcr")
    la, lb, lc = LAYOUT_TO_DISPATCHER.get(layout, ("row", "col", "row"))
    c_dtype = DTYPE_TO_C_DTYPE.get(dtype, dtype)
    return {
        "dtype_a": dtype,
        "dtype_b": dtype,
        "dtype_c": c_dtype,
        "dtype_acc": "fp32",
        "layout_a": la,
        "layout_b": lb,
        "layout_c": lc,
        "tile_m": feat.get("tile_m", 128),
        "tile_n": feat.get("tile_n", 128),
        "tile_k": feat.get("tile_k", 64),
        "wave_m": feat.get("warp_m", 2),
        "wave_n": feat.get("warp_n", 2),
        "wave_k": feat.get("warp_k", 1),
        "warp_m": feat.get("warp_tile_m", 32),
        "warp_n": feat.get("warp_tile_n", 32),
        "warp_k": feat.get("warp_tile_k", 16),
        "pipeline": feat.get("pipeline", "compv3"),
        "scheduler": feat.get("scheduler", "intrawave"),
        "epilogue": feat.get("epilogue", "cshuffle"),
        "pad_m": feat.get("pad_m", True),
        "pad_n": feat.get("pad_n", True),
        "pad_k": feat.get("pad_k", True),
        "gfx_arch": arch,
    }


def feature_dict_to_ml_spec(feat: dict, predicted_tflops: float = 0.0) -> MLKernelSpec:
    """Convert a feature-engine kernel dict + prediction to an MLKernelSpec."""
    return MLKernelSpec(
        kernel_name=feat.get("kernel_name", "unknown"),
        predicted_tflops=predicted_tflops,
        tile_m=feat.get("tile_m", 128),
        tile_n=feat.get("tile_n", 128),
        tile_k=feat.get("tile_k", 64),
        wave_m=feat.get("warp_m", 2),
        wave_n=feat.get("warp_n", 2),
        wave_k=feat.get("warp_k", 1),
        warp_m=feat.get("warp_tile_m", 32),
        warp_n=feat.get("warp_tile_n", 32),
        warp_k=feat.get("warp_tile_k", 16),
        pipeline=feat.get("pipeline", "compv3"),
        scheduler=feat.get("scheduler", "intrawave"),
        epilogue=feat.get("epilogue", "cshuffle"),
        pad_m=feat.get("pad_m", False),
        pad_n=feat.get("pad_n", False),
        pad_k=feat.get("pad_k", False),
        persistent=feat.get("persistent", False),
    )


def load_kernel_pool_from_binaries(bin_dir: str | Path) -> list[dict]:
    """Discover benchmark executables and parse their names into feature dicts.

    Each executable name encodes the full kernel config. This creates the
    candidate pool for the ML heuristic without needing a registry JSON export.
    """
    bin_dir = Path(bin_dir)
    configs = []
    for exe in sorted(bin_dir.glob("benchmark_gemm_universal_*")):
        name = exe.stem.replace("benchmark_", "")
        feat = kernel_config_to_feature_dict(name)
        if feat and "tile_m" in feat:
            configs.append(feat)
    return configs


def create_ml_heuristic(
    model_dir: str | Path,
    dtype: str = "fp8",
    arch: str = "gfx950",
    layout: str = "rcr",
    kernel_pool: Optional[list[dict]] = None,
    bin_dir: Optional[str | Path] = None,
):
    """Create an ML heuristic function for kernel selection.

    Returns a callable with signature:
        (M: int, N: int, K: int) -> MLKernelSpec

    The returned function scores all candidate kernels using the trained
    LightGBM regressor and returns the best one as an MLKernelSpec.

    Parameters
    ----------
    model_dir : str or Path
        Path to trained model directory (must contain model_tflops.lgbm or
        model_tflops_log_big.lgbm and feature_spec.json).
    dtype : str
        Data type for the problem (fp8, fp16, bf16).
    arch : str
        GPU architecture (gfx942, gfx950).
    layout : str
        Matrix layout (rcr, rrr, crr, ccr).
    kernel_pool : list of dict, optional
        Pre-parsed kernel configs. If None, loads from bin_dir.
    bin_dir : str or Path, optional
        Directory with benchmark executables. Used to build kernel_pool if
        kernel_pool is not provided. Defaults to /workspace/ck_tile/bin.
    """
    model_dir = Path(model_dir)
    predictor = Predictor(model_dir)
    if kernel_pool is None:
        if bin_dir is None:
            bin_dir = Path("/workspace/ck_tile/bin")
        kernel_pool = load_kernel_pool_from_binaries(bin_dir)
    if not kernel_pool:
        raise ValueError(
            "No kernel configs found. Check bin_dir or provide kernel_pool."
        )

    def heuristic(M: int, N: int, K: int) -> MLKernelSpec:
        problem = {
            "m": M,
            "n": N,
            "k": K,
            "dtype": dtype,
            "layout": layout,
            "split_k": 1,
        }
        ranked = predictor.rank_kernels(problem, kernel_pool)
        if not ranked:
            feat = kernel_pool[0]
            return feature_dict_to_ml_spec(feat, 0.0)
        best_name, best_tflops = ranked[0]
        best_feat = next(
            (kp for kp in kernel_pool if kp.get("kernel_name") == best_name),
            kernel_pool[0],
        )
        return feature_dict_to_ml_spec(best_feat, best_tflops)

    return heuristic


def create_ranked_heuristic(
    model_dir: str | Path,
    dtype: str = "fp8",
    arch: str = "gfx950",
    layout: str = "rcr",
    kernel_pool: Optional[list[dict]] = None,
    bin_dir: Optional[str | Path] = None,
    top_k: int = 5,
):
    """Create an ML heuristic that returns the top-K ranked kernel specs.

    Returns a callable with signature:
        (M: int, N: int, K: int) -> list[MLKernelSpec]

    Useful when you want fallback options if the top-1 kernel fails to build.
    """
    model_dir = Path(model_dir)
    predictor = Predictor(model_dir)
    if kernel_pool is None:
        if bin_dir is None:
            bin_dir = Path("/workspace/ck_tile/bin")
        kernel_pool = load_kernel_pool_from_binaries(bin_dir)
    name_to_feat = {kp.get("kernel_name", ""): kp for kp in kernel_pool}

    def heuristic(M: int, N: int, K: int) -> list[MLKernelSpec]:
        problem = {
            "m": M,
            "n": N,
            "k": K,
            "dtype": dtype,
            "layout": layout,
            "split_k": 1,
        }
        ranked = predictor.rank_kernels(problem, kernel_pool)
        results = []
        for name, tflops in ranked[:top_k]:
            feat = name_to_feat.get(name, kernel_pool[0])
            results.append(feature_dict_to_ml_spec(feat, tflops))
        return results

    return heuristic


def ml_spec_to_dispatcher_config(
    spec: MLKernelSpec, dtype: str = "fp8", arch: str = "gfx950"
) -> dict:
    """Convert an MLKernelSpec to a dict compatible with ctypes_utils.KernelConfig."""
    layout_a, layout_b, layout_c = "row", "col", "row"
    c_dtype = DTYPE_TO_C_DTYPE.get(dtype, dtype)
    return {
        "dtype_a": dtype,
        "dtype_b": dtype,
        "dtype_c": c_dtype,
        "dtype_acc": "fp32",
        "layout_a": layout_a,
        "layout_b": layout_b,
        "layout_c": layout_c,
        "tile_m": spec.tile_m,
        "tile_n": spec.tile_n,
        "tile_k": spec.tile_k,
        "wave_m": spec.wave_m,
        "wave_n": spec.wave_n,
        "wave_k": spec.wave_k,
        "warp_m": spec.warp_m,
        "warp_n": spec.warp_n,
        "warp_k": spec.warp_k,
        "pipeline": spec.pipeline,
        "scheduler": spec.scheduler,
        "epilogue": spec.epilogue,
        "pad_m": spec.pad_m,
        "pad_n": spec.pad_n,
        "pad_k": spec.pad_k,
        "gfx_arch": arch,
    }