Yaswanth Raparti c1127a36f5 [rocm-libraries] ROCm/rocm-libraries#5676 (commit 1d18339)
[CK][CK TILE] Autotuning heuristics infra for universal GEMM kernel selection (#5676)

## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile
dispatcher, enabling fast and accurate automatic kernel selection for
Universal GEMM kernels. Instead of requiring an exhaustive search through
4600+ kernel configurations (taking ~46 seconds per problem shape), the
ML heuristic predicts optimal kernels in microseconds while achieving
>98% of oracle-best performance.

## Technical Details

**ML Infrastructure:**

https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics
* Feature Engine
([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)):
55-feature extraction including problem dimensions, kernel
configuration, tile efficiency, and hardware profile
* Training Pipeline
([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)):
LightGBM regression with log-transform, GroupKFold cross-validation,
warm-start support
* Predictor
([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)):
Kernel ranking and TFLOPS prediction for problem shapes
* Evaluation
([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)):
Comprehensive metrics including efficiency, NDCG@k, shape family
analysis
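The training recipe above (log-transformed target, group-aware cross-validation) can be sketched in a few lines. This is a minimal illustration, not the actual `train.py`: the dataset is synthetic, and the helper `group_kfold` only mimics the key property of GroupKFold — all rows of one problem shape land in the same fold, so validation measures generalization to unseen shapes.

```python
import numpy as np

# Hypothetical mini-dataset: each row is one (problem shape, kernel config)
# benchmark result. Column names are illustrative, not the real schema.
rng = np.random.default_rng(0)
n_rows = 200
shape_ids = rng.integers(0, 10, size=n_rows)   # 10 distinct problem shapes
tflops = rng.uniform(1.0, 900.0, size=n_rows)  # measured throughput

# 1) Log-transform the regression target to compress its dynamic range.
y = np.log1p(tflops)

# 2) Group-aware K-fold: every row of a given shape stays in one fold,
#    so no problem shape leaks between train and validation.
def group_kfold(groups, n_splits=5):
    unique = np.unique(groups)
    for fold_groups in np.array_split(unique, n_splits):
        val_mask = np.isin(groups, fold_groups)
        yield np.where(~val_mask)[0], np.where(val_mask)[0]

for train_idx, val_idx in group_kfold(shape_ids, n_splits=5):
    # Verify the no-leakage property of the split.
    assert not (set(shape_ids[train_idx]) & set(shape_ids[val_idx]))
```

In the real pipeline the split indices would feed a LightGBM regressor; the group key would be the (m, n, k) shape rather than an integer id.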

**Data Generation Tools:**

*
[generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py):
Build and benchmark kernels across diverse problem shapes
*
[convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py):
Convert benchmark JSON to training-ready parquet format
*
[data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py):
Parse streaming benchmark logs into canonical datasets
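The JSON-to-parquet step boils down to flattening nested benchmark records into one row per measurement. A hedged sketch — the field names (`problem`, `config`, `tflops`) are assumptions for illustration, not the actual schema emitted by `generate_benchmark_data.py`:

```python
import json

# Illustrative benchmark record; the real JSON layout may differ.
raw = json.loads("""
[{"problem": {"m": 128, "n": 1536, "k": 7168, "dtype": "fp8", "layout": "rcr"},
  "config": {"tile_m": 256, "tile_n": 256, "tile_k": 64, "pipeline": "compv3"},
  "tflops": 812.4}]
""")

def flatten(records):
    """Flatten nested problem/config dicts into one flat row per benchmark."""
    rows = []
    for rec in records:
        row = {}
        row.update({f"problem_{k}": v for k, v in rec["problem"].items()})
        row.update({f"config_{k}": v for k, v in rec["config"].items()})
        row["tflops"] = rec["tflops"]
        rows.append(row)
    return rows

rows = flatten(raw)
# rows is a flat table, ready for pandas.DataFrame(rows).to_parquet(...)
```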

**Examples**
*
[09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp):
C++ example demonstrating ML-based kernel selection
*
[09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py):
Python example with validation

**Pre-trained Models
(projects/composablekernel/dispatcher/heuristics/models/):**
* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean
efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean
efficiency)

## Test Plan

* Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M
(M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**
* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (90% of shapes achieve ≥98.05% of oracle best)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**
* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%
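For concreteness, the metrics above can be computed as follows. The definitions here are assumptions consistent with the reported numbers: per-shape efficiency is the picked kernel's TFLOPS divided by the oracle-best TFLOPS for that shape, and P10 is the 10th percentile of per-shape efficiencies. The numbers below are toy values, not the PR's measurements:

```python
def efficiency_stats(picked, oracle):
    """Mean, P10 (simple index-based 10th percentile), and min efficiency."""
    eff = sorted(p / o for p, o in zip(picked, oracle))
    mean = sum(eff) / len(eff)
    p10 = eff[int(0.10 * len(eff))]  # crude percentile; fine for illustration
    return mean, p10, eff[0]

# Toy per-shape throughputs (picked kernel vs. oracle-best kernel).
picked = [95.0, 99.0, 97.0, 100.0, 98.0]
oracle = [100.0, 100.0, 100.0, 100.0, 100.0]
mean, p10, worst = efficiency_stats(picked, oracle)
```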

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-02 02:26:32 +00:00


#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Surrogate search for CK Tile kernel configuration optimization.

Uses a trained LGBMRegressor as a cheap surrogate function to search the
discrete kernel parameter space (tile sizes, warp config, pipeline, etc.)
without running actual GPU benchmarks.

Strategies:
- 'random': Sample N random valid configs, score all, return top-K.
- 'de': Discrete Differential Evolution with mutation over valid parameter choices.

Usage:
    from search import SurrogateSearch
    from predict import Predictor

    predictor = Predictor("models/gemm_universal_fp8_gfx950")
    searcher = SurrogateSearch(predictor, strategy='random')
    results = searcher.search(
        problem={"m": 128, "n": 1536, "k": 7168, "dtype": "fp8", "layout": "rcr"},
        budget=500,
    )
    # results: [(config_dict, predicted_tflops), ...] sorted descending
"""
import random
from typing import Optional

import numpy as np

from feature_engine import GemmUniversalFeatureEngine
from predict import Predictor


class SurrogateSearch:
    """Search kernel parameter space using ML regressor as surrogate objective.

    Parameters
    ----------
    predictor : Predictor
        Trained predictor with a TFLOPS model.
    feature_engine : GemmUniversalFeatureEngine, optional
        Feature engine for parameter space and validation. If None, uses default.
    strategy : str
        Search strategy: 'random' or 'de' (Discrete Differential Evolution).
    seed : int
        Random seed for reproducibility.
    """

    def __init__(
        self,
        predictor: Predictor,
        feature_engine: Optional[GemmUniversalFeatureEngine] = None,
        strategy: str = "random",
        seed: int = 42,
    ):
        self._predictor = predictor
        self._fe = feature_engine or GemmUniversalFeatureEngine()
        self._strategy = strategy
        self._rng = random.Random(seed)
        self._np_rng = np.random.RandomState(seed)
        self._param_space = self._fe.get_parameter_space()

    def _sample_random_config(self) -> dict:
        """Sample a single random config from the parameter space."""
        config = {}
        for param, values in self._param_space.items():
            config[param] = self._rng.choice(values)
        return config

    def _sample_valid_config(self, max_attempts: int = 50) -> Optional[dict]:
        """Sample a random config that passes all validation constraints."""
        for _ in range(max_attempts):
            config = self._sample_random_config()
            if self._fe.validate_config(config):
                return config
        return None

    def _score_config(self, problem: dict, config: dict) -> float:
        """Score a config using the predictor."""
        return self._predictor.predict_tflops(problem, config)

    def _search_random(
        self, problem: dict, budget: int, top_k: int
    ) -> list[tuple[dict, float]]:
        """Random search: sample valid configs, score all, return top-K."""
        configs = []
        for _ in range(budget):
            cfg = self._sample_valid_config()
            if cfg is not None:
                configs.append(cfg)
        if not configs:
            return []
        scored = []
        for cfg in configs:
            try:
                score = self._score_config(problem, cfg)
                scored.append((cfg, score))
            except Exception:
                continue
        scored.sort(key=lambda x: -x[1])
        return scored[:top_k]

    def _search_de(
        self,
        problem: dict,
        budget: int,
        top_k: int,
        pop_size: int = 20,
        mutation_rate: float = 0.3,
        crossover_rate: float = 0.7,
    ) -> list[tuple[dict, float]]:
        """Discrete Differential Evolution.

        Uses discrete mutation: randomly swap parameters to other valid values
        from the parameter space (no continuous relaxation + snap).

        Each generation:
        1. For each member of the population, create a trial vector by:
           - Selecting 3 random other members (a, b, c)
           - For each parameter, with probability mutation_rate, take the value
             from a, b, or c (uniform choice among the three donors)
           - With probability crossover_rate, take the trial value; otherwise
             keep original
        2. Validate the trial; if invalid, resample that parameter from the space
        3. Score the trial; if better than parent, replace
        """
        param_names = list(self._param_space.keys())
        population = []
        for _ in range(pop_size):
            cfg = self._sample_valid_config()
            if cfg is not None:
                score = self._score_config(problem, cfg)
                population.append((cfg, score))
        if len(population) < 4:
            return self._search_random(problem, budget, top_k)
        evals_used = len(population)
        max_gens = (budget - evals_used) // pop_size
        for _gen in range(max_gens):
            new_pop = []
            for i, (parent, parent_score) in enumerate(population):
                candidates = [j for j in range(len(population)) if j != i]
                if len(candidates) < 3:
                    new_pop.append((parent, parent_score))
                    continue
                a_idx, b_idx, c_idx = self._rng.sample(candidates, 3)
                a, b, c = (
                    population[a_idx][0],
                    population[b_idx][0],
                    population[c_idx][0],
                )
                trial = dict(parent)
                for param in param_names:
                    if self._rng.random() < mutation_rate:
                        donor = self._rng.choice([a, b, c])
                        trial[param] = donor.get(param, parent.get(param))
                        if self._rng.random() > crossover_rate:
                            trial[param] = parent.get(param)
                if not self._fe.validate_config(trial):
                    # Repair: resample any parameter that fell outside the space.
                    for param in param_names:
                        if param in trial and trial[param] not in self._param_space.get(
                            param, [trial[param]]
                        ):
                            trial[param] = self._rng.choice(self._param_space[param])
                    if not self._fe.validate_config(trial):
                        new_pop.append((parent, parent_score))
                        continue
                try:
                    trial_score = self._score_config(problem, trial)
                    evals_used += 1
                except Exception:
                    new_pop.append((parent, parent_score))
                    continue
                if trial_score > parent_score:
                    new_pop.append((trial, trial_score))
                else:
                    new_pop.append((parent, parent_score))
            population = new_pop
        population.sort(key=lambda x: -x[1])
        return population[:top_k]

    def search(
        self,
        problem: dict,
        budget: int = 500,
        top_k: int = 10,
        **kwargs,
    ) -> list[tuple[dict, float]]:
        """Search the kernel parameter space for the best configuration.

        Parameters
        ----------
        problem : dict
            Problem specification: m, n, k, dtype, layout, split_k.
        budget : int
            Maximum number of surrogate evaluations.
        top_k : int
            Number of top configurations to return.
        **kwargs
            Strategy-specific parameters (pop_size, mutation_rate, etc.).

        Returns
        -------
        list of (config_dict, predicted_tflops), sorted descending by TFLOPS.
        """
        if self._strategy == "random":
            return self._search_random(problem, budget, top_k)
        elif self._strategy == "de":
            return self._search_de(problem, budget, top_k, **kwargs)
        else:
            raise ValueError(f"Unknown strategy: {self._strategy}")


if __name__ == "__main__":
    import argparse
    import time

    parser = argparse.ArgumentParser(
        description="Surrogate search for optimal kernel config"
    )
    parser.add_argument("--model_dir", required=True)
    parser.add_argument("--m", type=int, required=True)
    parser.add_argument("--n", type=int, required=True)
    parser.add_argument("--k", type=int, required=True)
    parser.add_argument("--dtype", default="fp8")
    parser.add_argument("--layout", default="rcr")
    parser.add_argument("--strategy", default="random", choices=["random", "de"])
    parser.add_argument("--budget", type=int, default=500)
    parser.add_argument("--top_k", type=int, default=10)
    args = parser.parse_args()

    predictor = Predictor(args.model_dir)
    searcher = SurrogateSearch(predictor, strategy=args.strategy)
    problem = {
        "m": args.m,
        "n": args.n,
        "k": args.k,
        "dtype": args.dtype,
        "layout": args.layout,
        "split_k": 1,
    }

    print(f"Searching with strategy={args.strategy}, budget={args.budget}...")
    t0 = time.time()
    results = searcher.search(problem, budget=args.budget, top_k=args.top_k)
    elapsed = time.time() - t0

    print(f"\nTop {len(results)} configs found in {elapsed * 1000:.1f}ms:")
    for i, (cfg, tflops) in enumerate(results):
        tile_str = f"{cfg.get('tile_m', '?')}x{cfg.get('tile_n', '?')}x{cfg.get('tile_k', '?')}"
        warp_str = f"{cfg.get('warp_m', '?')}x{cfg.get('warp_n', '?')}x{cfg.get('warp_k', '?')}"
        print(
            f"  #{i + 1}: {tflops:8.2f} TFLOPS  tile={tile_str}  warp={warp_str} "
            f"pipeline={cfg.get('pipeline', '?')} scheduler={cfg.get('scheduler', '?')}"
        )