mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
## Motivation This PR adds ML-based kernel selection heuristics to the CK Tile dispatcher, enabling fast and accurate automatic kernel selection for Universal Gemm kernels. Instead of requiring exhaustive search through 4600+ kernel configurations (taking ~46 seconds per problem shape), the ML heuristic predicts optimal kernels in microseconds while achieving >98% of oracle-best performance. ## Technical Details **ML infrastructure** https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics * Feature Engine ([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)): 55-feature extraction including problem dimensions, kernel configuration, tile efficiency, and hardware profile * Training Pipeline ([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)): LightGBM regression with log-transform, GroupKFold cross-validation, warm-start support * Predictor ([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)): Kernel ranking and TFLOPS prediction for problem shapes * Evaluation ([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)): Comprehensive metrics including efficiency, NDCG@k, shape family analysis **Data Generation Tools:** * [generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py): Build and benchmark kernels across diverse problem shapes * 
[convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py): Convert benchmark JSON to training-ready parquet format * [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py): Parse streaming benchmark logs into canonical datasets **Examples** * [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp): C++ example demonstrating ML-based kernel selection * [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py): Python example with validation **Pre-trained Models (projects/composablekernel/dispatcher/heuristics/models/):** * gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean efficiency) * gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean efficiency) ## Test Plan * Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8 * All shape families tested: tiny M (M<8), small M, medium M, large M (M≥1024) * All pipeline types: compv3, compv4, mem ## Test Result **fp16 Model (gfx950, RCR layout)** * Mean Efficiency: 99.36% * P10 Efficiency: 98.05% (90th percentile of shapes achieve ≥98% of oracle best) * Min Efficiency: 95.45% **fp8 Model (gfx950, RCR layout)** * Mean Efficiency: 98.28% (original), 97.51% (wide coverage) * P10 Efficiency: 94.64% (original), 93.89% (wide coverage) * Min Efficiency: 84.5% ## Submission Checklist - [x ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
--------- Co-authored-by: Vidyasagar Ananthan <vidyasagar.ananthan@amd.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
244 lines
8.7 KiB
Python
244 lines
8.7 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Predictor for CK Tile kernel performance.
|
|
|
|
Loads trained LightGBM models and provides:
|
|
- predict_tflops(): predicted TFLOPS for a single (problem, kernel) pair
|
|
- predict_latency(): predicted latency in ms
|
|
- predict_bandwidth(): predicted bandwidth in GB/s
|
|
- predict_all(): all three predictions at once
|
|
- rank_kernels(): rank all candidate kernels by predicted TFLOPS
|
|
- select_best(): return the best kernel ID
|
|
|
|
Usage:
|
|
predictor = Predictor("models/gemm_universal_fp8_gfx950")
|
|
best_kernel = predictor.select_best(
|
|
problem={"m": 128, "n": 1536, "k": 7168, "dtype": "fp8", "layout": "rcr"},
|
|
kernel_configs=[...],
|
|
)
|
|
"""
|
|
|
|
import gzip
import json
import shutil
from pathlib import Path
from typing import Optional

import lightgbm as lgb
import numpy as np
import pandas as pd

from feature_engine import GemmUniversalFeatureEngine
|
|
|
|
|
|
class Predictor:
    """Loads trained models and feature spec for kernel performance prediction.

    Parameters
    ----------
    model_dir : str or Path
        Directory containing model artifacts:
        - model_tflops.lgbm (required)
        - model_latency.lgbm (optional)
        - model_bandwidth.lgbm (optional)
        - feature_spec.json (required)

    feature_engine : FeatureEngine, optional
        Override the feature engine. If None, constructs one from feature_spec.json.
    """

    def __init__(self, model_dir: str | Path, feature_engine=None):
        self._model_dir = Path(model_dir)
        # Lazily populated cache: target name -> loaded Booster.
        self._models: dict[str, lgb.Booster] = {}

        spec_path = self._model_dir / "feature_spec.json"
        if spec_path.exists():
            with open(spec_path) as f:
                self._spec = json.load(f)
        else:
            # Missing spec is tolerated; log_targets then defaults to empty.
            self._spec = {}

        # Targets trained in log1p-space; predictions need expm1 to invert.
        self._log_targets = set(self._spec.get("log_targets", []))

        if feature_engine is not None:
            self._feature_engine = feature_engine
        else:
            self._feature_engine = GemmUniversalFeatureEngine()

    def _load_model(self, target: str) -> Optional[lgb.Booster]:
        """Lazy-load a model for the given target.

        Automatically decompresses .lgbm.gz files if the .lgbm file doesn't exist.
        The decompressed file is cached to disk for subsequent loads.

        Returns None when neither a plain nor a gzipped model file exists.
        """
        if target in self._models:
            return self._models[target]

        path = self._model_dir / f"model_{target}.lgbm"
        gz_path = self._model_dir / f"model_{target}.lgbm.gz"

        # Auto-decompress if needed. Stream through a temp file and rename
        # atomically: an interrupted decompression must never leave a
        # truncated model_{target}.lgbm behind, because later calls would
        # see path.exists() and try to load the corrupt file.
        if not path.exists() and gz_path.exists():
            tmp_path = path.with_name(path.name + ".tmp")
            try:
                with gzip.open(gz_path, "rb") as f_in, open(tmp_path, "wb") as f_out:
                    # copyfileobj streams in chunks instead of reading the
                    # whole compressed model into memory at once.
                    shutil.copyfileobj(f_in, f_out)
                tmp_path.replace(path)
            finally:
                # Remove a leftover partial temp file if decompression failed.
                tmp_path.unlink(missing_ok=True)

        if not path.exists():
            return None

        model = lgb.Booster(model_file=str(path))
        self._models[target] = model
        return model

    def _postprocess(self, target: str, raw: float) -> float:
        """Invert the training-time target transform for one raw prediction.

        Applies expm1 for targets trained in log1p-space; otherwise clamps to
        non-negative, since TFLOPS/latency/bandwidth are physical quantities.
        """
        if target in self._log_targets:
            return float(np.expm1(raw))
        return float(max(0.0, raw))

    def _predict_single(self, target: str, problem: dict, kernel_config: dict) -> float:
        """Predict a single target value, applying inverse log transform if needed.

        Raises
        ------
        FileNotFoundError
            If no model for ``target`` exists in the model directory.
        """
        model = self._load_model(target)
        if model is None:
            raise FileNotFoundError(f"No model_{target}.lgbm in {self._model_dir}")
        features = self._feature_engine.extract(problem, kernel_config)
        raw = float(model.predict(features.reshape(1, -1))[0])
        return self._postprocess(target, raw)

    def predict_tflops(self, problem: dict, kernel_config: dict) -> float:
        """Predict TFLOPS for a single (problem, kernel) pair.

        Returns a real TFLOPS estimate (interpretable, usable as DE surrogate).
        If the model was trained in log-space, the inverse transform is applied
        automatically.
        """
        return self._predict_single("tflops", problem, kernel_config)

    def predict_latency(self, problem: dict, kernel_config: dict) -> float:
        """Predict latency in milliseconds for a single (problem, kernel) pair."""
        return self._predict_single("latency", problem, kernel_config)

    def predict_bandwidth(self, problem: dict, kernel_config: dict) -> float:
        """Predict bandwidth in GB/s for a single (problem, kernel) pair."""
        return self._predict_single("bandwidth", problem, kernel_config)

    def predict_all(self, problem: dict, kernel_config: dict) -> dict[str, float]:
        """Predict all available targets for a single (problem, kernel) pair.

        Returns dict with keys 'tflops', 'latency_ms', 'bandwidth_gb_s' (if models exist).

        Note: Applies inverse log transform for targets in log_targets and clamps
        negatives to 0.0, consistent with _predict_single().
        """
        # Extract features once and reuse for every available model.
        features = self._feature_engine.extract(problem, kernel_config).reshape(1, -1)
        result = {}
        for target, key in [
            ("tflops", "tflops"),
            ("latency", "latency_ms"),
            ("bandwidth", "bandwidth_gb_s"),
        ]:
            model = self._load_model(target)
            if model is not None:
                raw = float(model.predict(features)[0])
                result[key] = self._postprocess(target, raw)
        return result

    def rank_kernels(
        self, problem: dict, kernel_configs: list[dict]
    ) -> list[tuple[str, float]]:
        """Rank candidate kernels by predicted TFLOPS (descending).

        Parameters
        ----------
        problem : dict
            Problem specification with keys: m, n, k, dtype, layout, split_k.
        kernel_configs : list of dict
            Each dict must have a 'kernel_name' key plus kernel parameters.

        Returns
        -------
        list of (kernel_name, predicted_tflops) tuples, sorted descending.

        Raises
        ------
        FileNotFoundError
            If no TFLOPS model exists in the model directory.
        """
        if not kernel_configs:
            return []

        model = self._load_model("tflops")
        if model is None:
            raise FileNotFoundError(f"No model_tflops.lgbm in {self._model_dir}")

        # Merge problem + config per candidate; config values win on key
        # collision, so configs should not carry problem-shape keys.
        rows = [{**problem, **kc} for kc in kernel_configs]
        df = pd.DataFrame(rows)
        X = self._feature_engine.extract_batch(df)
        preds = model.predict(X)
        if "tflops" in self._log_targets:
            preds = np.expm1(preds)

        results = [
            (kc.get("kernel_name", f"kernel_{i}"), float(preds[i]))
            for i, kc in enumerate(kernel_configs)
        ]
        results.sort(key=lambda pair: pair[1], reverse=True)
        return results

    def select_best(self, problem: dict, kernel_configs: list[dict]) -> str:
        """Return the kernel_name of the best predicted kernel.

        Raises
        ------
        ValueError
            If ``kernel_configs`` is empty.
        """
        ranked = self.rank_kernels(problem, kernel_configs)
        if not ranked:
            raise ValueError("No kernel configs provided")
        return ranked[0][0]
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Predict kernel performance")
    parser.add_argument(
        "--model_dir", required=True, help="Directory with trained models"
    )
    parser.add_argument("--m", type=int, required=True)
    parser.add_argument("--n", type=int, required=True)
    parser.add_argument("--k", type=int, required=True)
    parser.add_argument("--layout", default="rcr")
    parser.add_argument("--dtype", default="fp8")
    args = parser.parse_args()

    predictor = Predictor(args.model_dir)
    problem = {
        "m": args.m,
        "n": args.n,
        "k": args.k,
        "dtype": args.dtype,
        "layout": args.layout,
        "split_k": 1,
    }

    print(f"Loading models from {args.model_dir}...")
    print(
        f"Problem: M={args.m} N={args.n} K={args.k} dtype={args.dtype} layout={args.layout}"
    )

    # Look for benchmark parquet files next to the model dir as a source of
    # candidate kernel configurations; only the first parquet with usable
    # configs is ranked.
    data_dir = Path(args.model_dir).parent.parent / "data"
    ranked_any = False
    if data_dir.exists():
        for pq in data_dir.glob("*.parquet"):
            df = pd.read_parquet(pq)
            kernel_names = df["kernel_name"].unique()
            configs = []
            for kn in kernel_names[:10]:
                row = df[df["kernel_name"] == kn].iloc[0].to_dict()
                # Strip problem-shape columns from the benchmark row:
                # rank_kernels merges {**problem, **config} with config last,
                # so leaving m/n/k/dtype/... in would silently override the
                # user-requested problem with the benchmark's shape.
                configs.append(
                    {col: val for col, val in row.items() if col not in problem}
                )
            if configs:
                ranked = predictor.rank_kernels(problem, configs)
                print(f"\nTop 5 kernels (from {len(configs)} candidates):")
                for name, tflops in ranked[:5]:
                    print(f"  {tflops:8.2f} TFLOPS {name}")
                ranked_any = True
                break
    if not ranked_any:
        # Previously this path printed nothing, which looked like a hang/noop.
        print(
            f"\nNo benchmark parquet data found under {data_dir}; "
            "supply kernel configs programmatically via rank_kernels()."
        )