Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-14 02:02:46 +00:00).
## Motivation This PR adds ML-based kernel selection heuristics to the CK Tile dispatcher, enabling fast and accurate automatic kernel selection for Universal Gemm kernels. Instead of requiring exhaustive search through 4600+ kernel configurations (taking ~46 seconds per problem shape), the ML heuristic predicts optimal kernels in microseconds while achieving >98% of oracle-best performance. ## Technical Details **ML infrastructure** https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics * Feature Engine ([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)): 55-feature extraction including problem dimensions, kernel configuration, tile efficiency, and hardware profile * Training Pipeline ([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)): LightGBM regression with log-transform, GroupKFold cross-validation, warm-start support * Predictor ([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)): Kernel ranking and TFLOPS prediction for problem shapes * Evaluation ([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)): Comprehensive metrics including efficiency, NDCG@k, shape family analysis **Data Generation Tools:** * [generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py): Build and benchmark kernels across diverse problem shapes * 
[convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py): Convert benchmark JSON to training-ready parquet format * [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py): Parse streaming benchmark logs into canonical datasets **Examples** * [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp): C++ example demonstrating ML-based kernel selection * [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py): Python example with validation **Pre-trained Models (projects/composablekernel/dispatcher/heuristics/models/):** * gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean efficiency) * gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean efficiency) ## Test Plan * Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8 * All shape families tested: tiny M (M<8), small M, medium M, large M (M≥1024) * All pipeline types: compv3, compv4, mem ## Test Result **fp16 Model (gfx950, RCR layout)** * Mean Efficiency: 99.36% * P10 Efficiency: 98.05% (90th percentile of shapes achieve ≥98% of oracle best) * Min Efficiency: 95.45% **fp8 Model (gfx950, RCR layout)** * Mean Efficiency: 98.28% (original), 97.51% (wide coverage) * P10 Efficiency: 94.64% (original), 93.89% (wide coverage) * Min Efficiency: 84.5% ## Submission Checklist - [x ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
--------- Co-authored-by: Vidyasagar Ananthan <vidyasagar.ananthan@amd.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
273 lines
9.6 KiB
Python
273 lines
9.6 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Surrogate search for CK Tile kernel configuration optimization.
|
|
|
|
Uses a trained LGBMRegressor as a cheap surrogate function to search the
|
|
discrete kernel parameter space (tile sizes, warp config, pipeline, etc.)
|
|
without running actual GPU benchmarks.
|
|
|
|
Strategies:
|
|
- 'random': Sample N random valid configs, score all, return top-K.
|
|
- 'de': Discrete Differential Evolution with mutation over valid parameter choices.
|
|
|
|
Usage:
|
|
from search import SurrogateSearch
|
|
from predict import Predictor
|
|
|
|
predictor = Predictor("models/gemm_universal_fp8_gfx950")
|
|
searcher = SurrogateSearch(predictor, strategy='random')
|
|
results = searcher.search(
|
|
problem={"m": 128, "n": 1536, "k": 7168, "dtype": "fp8", "layout": "rcr"},
|
|
budget=500,
|
|
)
|
|
# results: [(config_dict, predicted_tflops), ...] sorted descending
|
|
"""
|
|
|
|
import random
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
|
|
from feature_engine import GemmUniversalFeatureEngine
|
|
from predict import Predictor
|
|
|
|
|
|
class SurrogateSearch:
    """Search kernel parameter space using ML regressor as surrogate objective.

    Parameters
    ----------
    predictor : Predictor
        Trained predictor with a TFLOPS model.
    feature_engine : GemmUniversalFeatureEngine, optional
        Feature engine for parameter space and validation. If None, uses default.
    strategy : str
        Search strategy: 'random' or 'de' (Discrete Differential Evolution).
    seed : int
        Random seed for reproducibility.
    """

    def __init__(
        self,
        predictor: Predictor,
        feature_engine: Optional[GemmUniversalFeatureEngine] = None,
        strategy: str = "random",
        seed: int = 42,
    ):
        self._predictor = predictor
        self._fe = feature_engine or GemmUniversalFeatureEngine()
        self._strategy = strategy
        self._rng = random.Random(seed)
        # NOTE(review): _np_rng is not used by either strategy in this module;
        # kept so seeding side effects stay stable for any external user of it.
        self._np_rng = np.random.RandomState(seed)
        self._param_space = self._fe.get_parameter_space()

    def _sample_random_config(self) -> dict:
        """Sample a single random config from the parameter space."""
        return {
            param: self._rng.choice(values)
            for param, values in self._param_space.items()
        }

    def _sample_valid_config(self, max_attempts: int = 50) -> Optional[dict]:
        """Sample a random config that passes all validation constraints.

        Returns None if no valid config is found within ``max_attempts``.
        """
        for _ in range(max_attempts):
            config = self._sample_random_config()
            if self._fe.validate_config(config):
                return config
        return None

    def _score_config(self, problem: dict, config: dict) -> float:
        """Score a config using the predictor (predicted TFLOPS; higher is better)."""
        return self._predictor.predict_tflops(problem, config)

    @staticmethod
    def _config_key(config: dict) -> str:
        """Hashable identity for a config, used to deduplicate results."""
        # repr of sorted items works even if some parameter values are
        # unhashable (e.g. lists).
        return repr(sorted(config.items()))

    def _search_random(
        self, problem: dict, budget: int, top_k: int
    ) -> list[tuple[dict, float]]:
        """Random search: sample valid configs, score all, return top-K.

        Sampling is with replacement, so duplicates are dropped before
        scoring: the predictor is called once per distinct config and the
        returned top-K contains distinct configurations.
        """
        configs = []
        seen = set()
        for _ in range(budget):
            cfg = self._sample_valid_config()
            if cfg is None:
                continue
            key = self._config_key(cfg)
            if key in seen:
                continue
            seen.add(key)
            configs.append(cfg)

        if not configs:
            return []

        scored = []
        for cfg in configs:
            try:
                scored.append((cfg, self._score_config(problem, cfg)))
            except Exception:
                # Best-effort: a config the predictor cannot score is skipped.
                continue

        scored.sort(key=lambda x: -x[1])
        return scored[:top_k]

    def _search_de(
        self,
        problem: dict,
        budget: int,
        top_k: int,
        pop_size: int = 20,
        mutation_rate: float = 0.3,
        crossover_rate: float = 0.7,
    ) -> list[tuple[dict, float]]:
        """Discrete Differential Evolution.

        Uses discrete mutation: randomly swap parameters to other valid values
        from the parameter space (no continuous relaxation + snap).

        Each generation:
        1. For each member of the population, create a trial vector by:
           - Selecting 3 random other members (a, b, c)
           - For each parameter, with probability mutation_rate, take the value
             from a, b, or c (uniform choice among the three donors)
           - With probability crossover_rate, take the trial value; otherwise
             keep original
        2. Validate the trial; if invalid, resample a random subset of its
           parameters from the space and re-validate; if still invalid, keep
           the parent
        3. Score the trial; if better than parent, replace
        """
        param_names = list(self._param_space.keys())

        # --- Initialize population with valid random configs ---
        population = []
        for _ in range(pop_size):
            cfg = self._sample_valid_config()
            if cfg is not None:
                population.append((cfg, self._score_config(problem, cfg)))

        # DE needs one parent plus three distinct donors; with fewer members
        # fall back to plain random search.
        if len(population) < 4:
            return self._search_random(problem, budget, top_k)

        evals_used = len(population)
        # Each generation performs at most pop_size new surrogate evaluations.
        max_gens = (budget - evals_used) // pop_size

        for _ in range(max_gens):
            new_pop = []
            for i, (parent, parent_score) in enumerate(population):
                candidates = [j for j in range(len(population)) if j != i]
                if len(candidates) < 3:
                    new_pop.append((parent, parent_score))
                    continue

                a_idx, b_idx, c_idx = self._rng.sample(candidates, 3)
                donors = (
                    population[a_idx][0],
                    population[b_idx][0],
                    population[c_idx][0],
                )

                # --- Mutation + crossover ---
                trial = dict(parent)
                for param in param_names:
                    if self._rng.random() < mutation_rate:
                        donor = self._rng.choice(donors)
                        trial[param] = donor.get(param, parent.get(param))

                    # Crossover: with probability (1 - crossover_rate) revert
                    # this parameter to the parent's value.
                    if self._rng.random() > crossover_rate:
                        trial[param] = parent.get(param)

                # --- Repair invalid trials ---
                if not self._fe.validate_config(trial):
                    # BUGFIX: the old repair only resampled values missing
                    # from the parameter space, but every trial value already
                    # originates from configs sampled from that space, so it
                    # could never repair anything (invalidity comes from the
                    # *combination* of values). Instead, resample a random
                    # subset of parameters directly from the space.
                    for param in param_names:
                        if self._rng.random() < mutation_rate:
                            trial[param] = self._rng.choice(
                                self._param_space[param]
                            )
                    if not self._fe.validate_config(trial):
                        new_pop.append((parent, parent_score))
                        continue

                try:
                    trial_score = self._score_config(problem, trial)
                    evals_used += 1
                except Exception:
                    new_pop.append((parent, parent_score))
                    continue

                # Greedy selection: trial replaces parent only on improvement.
                if trial_score > parent_score:
                    new_pop.append((trial, trial_score))
                else:
                    new_pop.append((parent, parent_score))

            population = new_pop

        # Members may have converged to identical configs; dedupe so the
        # returned top-K contains distinct configurations.
        population.sort(key=lambda x: -x[1])
        unique = []
        seen = set()
        for cfg, score in population:
            key = self._config_key(cfg)
            if key in seen:
                continue
            seen.add(key)
            unique.append((cfg, score))
        return unique[:top_k]

    def search(
        self,
        problem: dict,
        budget: int = 500,
        top_k: int = 10,
        **kwargs,
    ) -> list[tuple[dict, float]]:
        """Search the kernel parameter space for the best configuration.

        Parameters
        ----------
        problem : dict
            Problem specification: m, n, k, dtype, layout, split_k.
        budget : int
            Maximum number of surrogate evaluations.
        top_k : int
            Number of top configurations to return.
        **kwargs
            Strategy-specific parameters (pop_size, mutation_rate, etc.).

        Returns
        -------
        list of (config_dict, predicted_tflops), sorted descending by TFLOPS.

        Raises
        ------
        ValueError
            If the configured strategy is not 'random' or 'de'.
        """
        if self._strategy == "random":
            return self._search_random(problem, budget, top_k)
        elif self._strategy == "de":
            return self._search_de(problem, budget, top_k, **kwargs)
        else:
            raise ValueError(f"Unknown strategy: {self._strategy}")
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse
    import time

    # CLI driver: run a surrogate search for a single GEMM problem shape and
    # print the top predicted configurations.
    cli = argparse.ArgumentParser(
        description="Surrogate search for optimal kernel config"
    )
    cli.add_argument("--model_dir", required=True)
    for flag in ("--m", "--n", "--k"):
        cli.add_argument(flag, type=int, required=True)
    cli.add_argument("--dtype", default="fp8")
    cli.add_argument("--layout", default="rcr")
    cli.add_argument("--strategy", default="random", choices=["random", "de"])
    cli.add_argument("--budget", type=int, default=500)
    cli.add_argument("--top_k", type=int, default=10)
    opts = cli.parse_args()

    shape = {
        "m": opts.m,
        "n": opts.n,
        "k": opts.k,
        "dtype": opts.dtype,
        "layout": opts.layout,
        "split_k": 1,
    }
    engine = SurrogateSearch(Predictor(opts.model_dir), strategy=opts.strategy)

    print(f"Searching with strategy={opts.strategy}, budget={opts.budget}...")
    start = time.time()
    ranked = engine.search(shape, budget=opts.budget, top_k=opts.top_k)
    elapsed = time.time() - start

    print(f"\nTop {len(ranked)} configs found in {elapsed * 1000:.1f}ms:")
    for rank, (cfg, tflops) in enumerate(ranked, start=1):
        tile_str = f"{cfg.get('tile_m', '?')}x{cfg.get('tile_n', '?')}x{cfg.get('tile_k', '?')}"
        warp_str = f"{cfg.get('warp_m', '?')}x{cfg.get('warp_n', '?')}x{cfg.get('warp_k', '?')}"
        print(
            f" #{rank}: {tflops:8.2f} TFLOPS tile={tile_str} warp={warp_str} "
            f"pipeline={cfg.get('pipeline', '?')} scheduler={cfg.get('scheduler', '?')}"
        )
|