mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK][CK TILE]Autotuning heuristics infra for universal GEMM kernel selection (#5676) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation This PR adds ML-based kernel selection heuristics to the CK Tile dispatcher, enabling fast and accurate automatic kernel selection for Universal Gemm kernels. Instead of requiring exhaustive search through 4600+ kernel configurations (taking ~46 seconds per problem shape), the ML heuristic predicts optimal kernels in microseconds while achieving >98% of oracle-best performance. ## Technical Details **ML infrastructure** https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics * Feature Engine ([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)): 55-feature extraction including problem dimensions, kernel configuration, tile efficiency, and hardware profile * Training Pipeline ([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)): LightGBM regression with log-transform, GroupKFold cross-validation, warm-start support * Predictor ([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)): Kernel ranking and TFLOPS prediction for problem shapes * Evaluation ([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)): Comprehensive metrics including efficiency, NDCG@k, shape family analysis **Data Generation Tools:** * 
[generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py): Build and benchmark kernels across diverse problem shapes * [convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py): Convert benchmark JSON to training-ready parquet format * [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py): Parse streaming benchmark logs into canonical datasets **Examples** * [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp): C++ example demonstrating ML-based kernel selection * [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py): Python example with validation **Pre-trained Models (projects/composablekernel/dispatcher/heuristics/models/):** * gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean efficiency) * gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean efficiency) ## Test Plan * Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8 * All shape families tested: tiny M (M<8), small M, medium M, large M (M≥1024) * All pipeline types: compv3, compv4, mem ## Test Result **fp16 Model (gfx950, RCR layout)** * Mean Efficiency: 99.36% * P10 Efficiency: 98.05% (90th percentile of shapes achieve ≥98% of oracle best) * Min Efficiency: 95.45% **fp8 Model (gfx950, RCR layout)** * Mean Efficiency: 98.28% (original), 97.51% (wide coverage) * P10 Efficiency: 94.64% (original), 93.89% (wide coverage) * Min Efficiency: 84.5% ## 
Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
182 lines
5.0 KiB
Python
182 lines
5.0 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Tests for predict.py.
|
|
|
|
Covers: Predictor initialization, single prediction, ranking, select_best,
|
|
missing model handling, and edge cases (single kernel, empty list).
|
|
"""
|
|
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import lightgbm as lgb
|
|
import numpy as np
|
|
import pytest
|
|
|
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
|
|
from feature_engine import GemmUniversalFeatureEngine
|
|
from predict import Predictor
|
|
|
|
|
|
@pytest.fixture
def model_dir(tmp_path):
    """Write a minimal trained model directory for Predictor tests.

    Produces two tiny LightGBM regressors (TFLOPS and latency targets)
    trained on random data, saved alongside a feature_spec.json, and
    returns the directory path.
    """
    engine = GemmUniversalFeatureEngine()
    feature_names = engine.get_feature_names()

    # Deterministic random training data sized to the feature schema.
    np.random.seed(42)
    X = np.random.rand(200, len(feature_names))
    y_tflops = np.random.rand(200) * 100

    # Throughput model: 10 trees keeps the fixture fast.
    tflops_model = lgb.LGBMRegressor(n_estimators=10, verbose=-1)
    tflops_model.fit(X, y_tflops)
    tflops_model.booster_.save_model(str(tmp_path / "model_tflops.lgbm"))

    # Latency model: separate targets over the same features.
    y_latency = np.random.rand(200) * 0.1
    latency_model = lgb.LGBMRegressor(n_estimators=10, verbose=-1)
    latency_model.fit(X, y_latency)
    latency_model.booster_.save_model(str(tmp_path / "model_latency.lgbm"))

    # Feature spec lets the Predictor rebuild the input layout.
    spec = {
        "feature_names": feature_names,
        "categorical_features": engine.get_categorical_features(),
    }
    (tmp_path / "feature_spec.json").write_text(json.dumps(spec))

    return tmp_path
|
|
|
|
|
|
@pytest.fixture
def predictor(model_dir):
    """Predictor backed by the minimal trained models in *model_dir*."""
    pred = Predictor(model_dir)
    return pred
|
|
|
|
|
|
def _problem():
|
|
return {
|
|
"m": 1024,
|
|
"n": 1024,
|
|
"k": 1024,
|
|
"dtype": "fp8",
|
|
"layout": "rcr",
|
|
"split_k": 1,
|
|
}
|
|
|
|
|
|
def _kernel(tile_m=128, pipeline="compv3"):
|
|
return {
|
|
"kernel_name": f"test_kernel_{tile_m}_{pipeline}",
|
|
"tile_m": tile_m,
|
|
"tile_n": 128,
|
|
"tile_k": 64,
|
|
"warp_m": 2,
|
|
"warp_n": 2,
|
|
"warp_k": 1,
|
|
"warp_tile_m": 32,
|
|
"warp_tile_n": 32,
|
|
"warp_tile_k": 16,
|
|
"pipeline": pipeline,
|
|
"scheduler": "intrawave",
|
|
"epilogue": "cshuffle",
|
|
"pad_m": False,
|
|
"pad_n": False,
|
|
"pad_k": False,
|
|
"persistent": False,
|
|
}
|
|
|
|
|
|
class TestPredictor:
    """Behavioral tests for Predictor against the tiny trained models."""

    def test_predict_tflops_returns_float(self, predictor):
        tflops = predictor.predict_tflops(_problem(), _kernel())
        assert isinstance(tflops, float)

    def test_predict_latency_returns_float(self, predictor):
        latency = predictor.predict_latency(_problem(), _kernel())
        assert isinstance(latency, float)

    def test_predict_all_returns_dict(self, predictor):
        out = predictor.predict_all(_problem(), _kernel())
        for key in ("tflops", "latency_ms"):
            assert key in out

    def test_rank_kernels_sorted_descending(self, predictor):
        candidates = [
            _kernel(64, "compv3"),
            _kernel(128, "compv4"),
            _kernel(256, "mem"),
        ]
        ranking = predictor.rank_kernels(_problem(), candidates)
        assert len(ranking) == 3
        # Scores must come back in non-increasing order.
        predicted = [score for _, score in ranking]
        assert all(a >= b for a, b in zip(predicted, predicted[1:]))

    def test_select_best_returns_name(self, predictor):
        candidates = [_kernel(64), _kernel(128)]
        chosen = predictor.select_best(_problem(), candidates)
        assert isinstance(chosen, str)
        assert chosen in {c["kernel_name"] for c in candidates}

    def test_single_kernel(self, predictor):
        only = [_kernel(128)]
        assert len(predictor.rank_kernels(_problem(), only)) == 1

    def test_missing_bandwidth_model(self, model_dir):
        # The fixture writes no bandwidth model, so this must fail.
        pred = Predictor(model_dir)
        with pytest.raises(FileNotFoundError):
            pred.predict_bandwidth(_problem(), _kernel())

    def test_empty_kernel_list(self, predictor):
        with pytest.raises(ValueError):
            predictor.select_best(_problem(), [])

    def test_corner_case_m1(self, predictor):
        # Degenerate skinny GEMM (M=1) should still yield a finite value.
        skinny = dict(_problem(), m=1, n=4096, k=4096)
        value = predictor.predict_tflops(skinny, _kernel())
        assert np.isfinite(value)

    def test_different_shapes_give_different_results(self, predictor):
        kernel = _kernel()
        small = dict(_problem(), m=16, n=1536, k=7168)
        large = dict(_problem(), m=20480, n=7168, k=256)
        assert predictor.predict_tflops(small, kernel) != predictor.predict_tflops(
            large, kernel
        )
|
|
|
|
|
|
class TestPredictorEdgeCases:
    """Error-path coverage for Predictor construction."""

    def test_nonexistent_model_dir(self):
        # Either construction or the first prediction must fail loudly.
        with pytest.raises(Exception):
            bad = Predictor("/nonexistent/path")
            bad.predict_tflops(_problem(), _kernel())
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status; the bare pytest.main(...) call
    # discarded the return code, so direct runs exited 0 even on failure.
    sys.exit(pytest.main([__file__, "-v"]))
|