mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
[CK][CK TILE]Autotuning heuristics infra for universal GEMM kernel selection (#5676) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation This PR adds ML-based kernel selection heuristics to the CK Tile dispatcher, enabling fast and accurate automatic kernel selection for Universal Gemm kernels. Instead of requiring exhaustive search through 4600+ kernel configurations (taking ~46 seconds per problem shape), the ML heuristic predicts optimal kernels in microseconds while achieving >98% of oracle-best performance. ## Technical Details **ML infrastructure** https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics * Feature Engine ([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)): 55-feature extraction including problem dimensions, kernel configuration, tile efficiency, and hardware profile * Training Pipeline ([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)): LightGBM regression with log-transform, GroupKFold cross-validation, warm-start support * Predictor ([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)): Kernel ranking and TFLOPS prediction for problem shapes * Evaluation ([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)): Comprehensive metrics including efficiency, NDCG@k, shape family analysis **Data Generation Tools:** * 
[generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py): Build and benchmark kernels across diverse problem shapes * [convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py): Convert benchmark JSON to training-ready parquet format * [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py): Parse streaming benchmark logs into canonical datasets **Examples** * [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp): C++ example demonstrating ML-based kernel selection * [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py): Python example with validation **Pre-trained Models (projects/composablekernel/dispatcher/heuristics/models/):** * gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean efficiency) * gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean efficiency) ## Test Plan * Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8 * All shape families tested: tiny M (M<8), small M, medium M, large M (M≥1024) * All pipeline types: compv3, compv4, mem ## Test Result **fp16 Model (gfx950, RCR layout)** * Mean Efficiency: 99.36% * P10 Efficiency: 98.05% (90th percentile of shapes achieve ≥98% of oracle best) * Min Efficiency: 95.45% **fp8 Model (gfx950, RCR layout)** * Mean Efficiency: 98.28% (original), 97.51% (wide coverage) * P10 Efficiency: 94.64% (original), 93.89% (wide coverage) * Min Efficiency: 84.5% ## 
Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
193 lines
6.2 KiB
Python
193 lines
6.2 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Tests for search.py.
|
|
|
|
Covers: random search, DE search, config validity, result ordering,
|
|
budget compliance, and edge cases.
|
|
"""
|
|
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import lightgbm as lgb
|
|
import numpy as np
|
|
import pytest
|
|
|
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
|
|
from feature_engine import GemmUniversalFeatureEngine
|
|
from predict import Predictor
|
|
from search import SurrogateSearch
|
|
|
|
|
|
@pytest.fixture
def model_dir(tmp_path):
    """Write a small throwaway model and its feature spec into ``tmp_path``.

    A 10-estimator LightGBM regressor is fitted on 200 rows of seeded random
    data (one column per feature name reported by the feature engine) and its
    booster is saved as ``model_tflops.lgbm``. The engine's feature names and
    categorical feature names are dumped to ``feature_spec.json`` beside it.

    Returns:
        The ``tmp_path`` directory, now holding both files.
    """
    engine = GemmUniversalFeatureEngine()
    feature_count = len(engine.get_feature_names())

    # Seed the global NumPy RNG so the synthetic training set is repeatable.
    np.random.seed(42)
    samples = np.random.rand(200, feature_count)
    targets = np.random.rand(200) * 500

    regressor = lgb.LGBMRegressor(n_estimators=10, verbose=-1)
    regressor.fit(samples, targets)
    model_file = tmp_path / "model_tflops.lgbm"
    regressor.booster_.save_model(str(model_file))

    feature_spec = {
        "feature_names": engine.get_feature_names(),
        "categorical_features": engine.get_categorical_features(),
    }
    with open(tmp_path / "feature_spec.json", "w") as spec_file:
        json.dump(feature_spec, spec_file)

    return tmp_path
|
|
|
|
|
|
@pytest.fixture
def predictor(model_dir):
    """Provide a ``Predictor`` constructed from the ``model_dir`` fixture's directory."""
    loaded_predictor = Predictor(model_dir)
    return loaded_predictor
|
|
|
|
|
|
def _problem():
|
|
return {
|
|
"m": 1024,
|
|
"n": 1024,
|
|
"k": 1024,
|
|
"dtype": "fp8",
|
|
"layout": "rcr",
|
|
"split_k": 1,
|
|
}
|
|
|
|
|
|
class TestRandomSearch:
    """Tests for ``SurrogateSearch`` run with ``strategy="random"``."""

    def test_returns_results(self, predictor):
        """A search returns at least one and at most ``top_k`` results."""
        searcher = SurrogateSearch(predictor, strategy="random")
        results = searcher.search(_problem(), budget=50, top_k=5)
        assert len(results) > 0
        assert len(results) <= 5

    def test_results_sorted_descending(self, predictor):
        """Result scores are ordered from highest to lowest."""
        searcher = SurrogateSearch(predictor, strategy="random")
        results = searcher.search(_problem(), budget=100, top_k=10)
        scores = [s for _, s in results]
        assert scores == sorted(scores, reverse=True)

    def test_configs_are_valid(self, predictor):
        """Each returned config value belongs to the engine's parameter space.

        Only keys that appear in the parameter space are checked; other keys
        in a config are ignored.
        """
        fe = GemmUniversalFeatureEngine()
        searcher = SurrogateSearch(predictor, feature_engine=fe, strategy="random")
        results = searcher.search(_problem(), budget=50, top_k=5)

        # Without this guard an empty result list would skip every check
        # below and the test would pass without validating anything.
        assert len(results) > 0, "search returned no configs to validate"

        # The parameter space does not depend on the config being checked,
        # so fetch it once instead of once per result.
        param_space = fe.get_parameter_space()
        for cfg, _ in results:
            for key, value in cfg.items():
                if key in param_space:
                    assert value in param_space[key], (
                        f"{key}={value} not in {param_space[key]}"
                    )

    def test_respects_top_k(self, predictor):
        """No more than ``top_k`` results are returned."""
        searcher = SurrogateSearch(predictor, strategy="random")
        results = searcher.search(_problem(), budget=100, top_k=3)
        assert len(results) <= 3

    def test_different_problems_produce_results(self, predictor):
        """Both problem sizes should produce valid search results."""
        searcher = SurrogateSearch(predictor, strategy="random", seed=42)
        r1 = searcher.search(
            {
                "m": 16,
                "n": 1536,
                "k": 7168,
                "dtype": "fp8",
                "layout": "rcr",
                "split_k": 1,
            },
            budget=50,
            top_k=3,
        )
        searcher2 = SurrogateSearch(predictor, strategy="random", seed=42)
        r2 = searcher2.search(
            {
                "m": 20480,
                "n": 7168,
                "k": 256,
                "dtype": "fp8",
                "layout": "rcr",
                "split_k": 1,
            },
            budget=50,
            top_k=3,
        )
        assert len(r1) > 0
        assert len(r2) > 0
        for _, score in r1 + r2:
            assert np.isfinite(score)

    def test_m1_corner_case(self, predictor):
        """A problem with ``m`` of 1 still yields results with finite scores."""
        searcher = SurrogateSearch(predictor, strategy="random")
        results = searcher.search(
            {
                "m": 1,
                "n": 4096,
                "k": 4096,
                "dtype": "fp8",
                "layout": "rcr",
                "split_k": 1,
            },
            budget=50,
            top_k=5,
        )
        assert len(results) > 0
        for _, score in results:
            assert np.isfinite(score)
|
|
|
|
|
|
class TestDESearch:
    """Tests for ``SurrogateSearch`` run with ``strategy="de"``."""

    def test_returns_results(self, predictor):
        """A DE search with a budget of 100 returns at least one result."""
        de_search = SurrogateSearch(predictor, strategy="de")
        found = de_search.search(_problem(), budget=100, top_k=5)
        assert len(found) > 0

    def test_results_sorted_descending(self, predictor):
        """Result scores are ordered from highest to lowest."""
        de_search = SurrogateSearch(predictor, strategy="de")
        found = de_search.search(_problem(), budget=100, top_k=5)
        ranked_scores = [score for _, score in found]
        expected_order = sorted(ranked_scores, reverse=True)
        assert ranked_scores == expected_order

    def test_de_improves_over_initial(self, predictor):
        """DE should generally find at least as good as random initialization."""
        random_search = SurrogateSearch(predictor, strategy="random", seed=42)
        random_best = random_search.search(_problem(), budget=100, top_k=1)
        de_search = SurrogateSearch(predictor, strategy="de", seed=42)
        de_best = de_search.search(_problem(), budget=100, top_k=1)
        # The comparison only runs when both strategies returned something;
        # DE's top score is allowed to fall to 0.9x the random top score.
        if random_best and de_best:
            random_score = random_best[0][1]
            de_score = de_best[0][1]
            assert de_score >= random_score * 0.9

    def test_small_budget(self, predictor):
        """A DE search with a budget of only 30 still returns results."""
        de_search = SurrogateSearch(predictor, strategy="de")
        found = de_search.search(_problem(), budget=30, top_k=5)
        assert len(found) > 0
|
|
|
|
|
|
class TestSearchEdgeCases:
    """Edge-case tests: unknown strategy, zero budget, and seeded repeatability."""

    def test_unknown_strategy_raises(self, predictor):
        """Searching with an unrecognised strategy name raises ``ValueError``."""
        # Construction happens outside the ``raises`` block, so the error
        # must come from ``search`` itself rather than from the constructor.
        bad_search = SurrogateSearch(predictor, strategy="unknown")
        with pytest.raises(ValueError):
            bad_search.search(_problem(), budget=10)

    def test_zero_budget(self, predictor):
        """A budget of zero yields an empty result list."""
        random_search = SurrogateSearch(predictor, strategy="random")
        found = random_search.search(_problem(), budget=0, top_k=5)
        assert len(found) == 0

    def test_deterministic_with_same_seed(self, predictor):
        """Two searchers sharing a seed return equally many, equally scored results."""
        first_search = SurrogateSearch(predictor, strategy="random", seed=123)
        second_search = SurrogateSearch(predictor, strategy="random", seed=123)
        first_run = first_search.search(_problem(), budget=50, top_k=5)
        second_run = second_search.search(_problem(), budget=50, top_k=5)
        assert len(first_run) == len(second_run)
        # Only the scores are compared; the configs themselves are not.
        for (_, first_score), (_, second_score) in zip(first_run, second_run):
            assert first_score == pytest.approx(second_score)
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly
    # reports test failures to the calling shell instead of exiting with
    # status 0 regardless of the outcome.
    raise SystemExit(pytest.main([__file__, "-v"]))
|