Files
composable_kernel/dispatcher/heuristics/tests/test_predict.py
Yaswanth Raparti c1127a36f5 [rocm-libraries] ROCm/rocm-libraries#5676 (commit 1d18339)
[CK][CK TILE]Autotuning heuristics infra for universal GEMM
 kernel selection (#5676)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile
dispatcher, enabling fast and accurate automatic kernel selection for
Universal Gemm kernels. Instead of requiring exhaustive search through
4600+ kernel configurations (taking ~46 seconds per problem shape), the
ML heuristic predicts optimal kernels in microseconds while achieving
>98% of oracle-best performance.

## Technical Details

**ML infrastructure**

https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics
* Feature Engine
([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)):
55-feature extraction including problem dimensions, kernel
configuration, tile efficiency, and hardware profile
* Training Pipeline
([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)):
LightGBM regression with log-transform, GroupKFold cross-validation,
warm-start support
* Predictor
([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)):
Kernel ranking and TFLOPS prediction for problem shapes
* Evaluation
([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)):
Comprehensive metrics including efficiency, NDCG@k, shape family
analysis

**Data Generation Tools:**

*
[generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py):
Build and benchmark kernels across diverse problem shapes
*
[convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py):
Convert benchmark JSON to training-ready parquet format
*
[data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py):
Parse streaming benchmark logs into canonical datasets

**Examples**
*
[09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp):
C++ example demonstrating ML-based kernel selection
*
[09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py):
Python example with validation

**Pre-trained Models
(projects/composablekernel/dispatcher/heuristics/models/):**
* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean
efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean
efficiency)

## Test Plan

* Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M
(M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**
* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (90% of shapes achieve ≥98% of
oracle-best performance)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**
* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-02 02:26:32 +00:00

182 lines
5.0 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Tests for predict.py.
Covers: Predictor initialization, single prediction, ranking, select_best,
missing model handling, and edge cases (single kernel, empty list).
"""
import json
import sys
from pathlib import Path
import lightgbm as lgb
import numpy as np
import pytest
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from feature_engine import GemmUniversalFeatureEngine
from predict import Predictor
@pytest.fixture
def model_dir(tmp_path):
    """Build a throwaway model directory consumable by ``Predictor``.

    Trains two tiny LightGBM boosters (TFLOPS and latency targets) on
    random data, saves them next to a ``feature_spec.json``, and returns
    the directory. Deliberately omits a bandwidth model so the
    missing-model path can be exercised.
    """
    engine = GemmUniversalFeatureEngine()
    feature_names = engine.get_feature_names()
    np.random.seed(42)  # deterministic fixture data across test runs
    X = np.random.rand(200, len(feature_names))
    # Draw targets in the same order as the RNG consumption: tflops, then latency.
    targets = {
        "model_tflops.lgbm": np.random.rand(200) * 100,
        "model_latency.lgbm": np.random.rand(200) * 0.1,
    }
    for filename, y in targets.items():
        regressor = lgb.LGBMRegressor(n_estimators=10, verbose=-1)
        regressor.fit(X, y)
        regressor.booster_.save_model(str(tmp_path / filename))
    spec = {
        "feature_names": feature_names,
        "categorical_features": engine.get_categorical_features(),
    }
    (tmp_path / "feature_spec.json").write_text(json.dumps(spec))
    return tmp_path
@pytest.fixture
def predictor(model_dir):
    """A ``Predictor`` wired to the temporary trained-model directory."""
    return Predictor(model_dir)
def _problem():
return {
"m": 1024,
"n": 1024,
"k": 1024,
"dtype": "fp8",
"layout": "rcr",
"split_k": 1,
}
def _kernel(tile_m=128, pipeline="compv3"):
return {
"kernel_name": f"test_kernel_{tile_m}_{pipeline}",
"tile_m": tile_m,
"tile_n": 128,
"tile_k": 64,
"warp_m": 2,
"warp_n": 2,
"warp_k": 1,
"warp_tile_m": 32,
"warp_tile_n": 32,
"warp_tile_k": 16,
"pipeline": pipeline,
"scheduler": "intrawave",
"epilogue": "cshuffle",
"pad_m": False,
"pad_n": False,
"pad_k": False,
"persistent": False,
}
class TestPredictor:
    """Behavioral tests for ``Predictor`` against a tiny freshly-trained model."""

    def test_predict_tflops_returns_float(self, predictor):
        assert isinstance(predictor.predict_tflops(_problem(), _kernel()), float)

    def test_predict_latency_returns_float(self, predictor):
        assert isinstance(predictor.predict_latency(_problem(), _kernel()), float)

    def test_predict_all_returns_dict(self, predictor):
        out = predictor.predict_all(_problem(), _kernel())
        assert "tflops" in out
        assert "latency_ms" in out

    def test_rank_kernels_sorted_descending(self, predictor):
        candidates = [
            _kernel(64, "compv3"),
            _kernel(128, "compv4"),
            _kernel(256, "mem"),
        ]
        ranking = predictor.rank_kernels(_problem(), candidates)
        assert len(ranking) == 3
        # Highest predicted throughput must come first.
        predicted = [score for _, score in ranking]
        assert predicted == sorted(predicted, reverse=True)

    def test_select_best_returns_name(self, predictor):
        candidates = [_kernel(64), _kernel(128)]
        chosen = predictor.select_best(_problem(), candidates)
        assert isinstance(chosen, str)
        assert chosen in [c["kernel_name"] for c in candidates]

    def test_single_kernel(self, predictor):
        assert len(predictor.rank_kernels(_problem(), [_kernel(128)])) == 1

    def test_missing_bandwidth_model(self, model_dir):
        # The fixture writes only tflops/latency models, so bandwidth is absent.
        fresh = Predictor(model_dir)
        with pytest.raises(FileNotFoundError):
            fresh.predict_bandwidth(_problem(), _kernel())

    def test_empty_kernel_list(self, predictor):
        with pytest.raises(ValueError):
            predictor.select_best(_problem(), [])

    def test_corner_case_m1(self, predictor):
        # Degenerate skinny GEMM (M=1) must still yield a finite prediction.
        skinny = {
            "m": 1,
            "n": 4096,
            "k": 4096,
            "dtype": "fp8",
            "layout": "rcr",
            "split_k": 1,
        }
        assert np.isfinite(predictor.predict_tflops(skinny, _kernel()))

    def test_different_shapes_give_different_results(self, predictor):
        cfg = _kernel()
        small = {
            "m": 16,
            "n": 1536,
            "k": 7168,
            "dtype": "fp8",
            "layout": "rcr",
            "split_k": 1,
        }
        large = {
            "m": 20480,
            "n": 7168,
            "k": 256,
            "dtype": "fp8",
            "layout": "rcr",
            "split_k": 1,
        }
        assert predictor.predict_tflops(small, cfg) != predictor.predict_tflops(large, cfg)
class TestPredictorEdgeCases:
    """Failure-mode coverage for ``Predictor`` construction."""

    def test_nonexistent_model_dir(self):
        # NOTE(review): Exception is deliberately broad — whether the failure
        # surfaces at construction or on first predict (and as which concrete
        # type) is an implementation detail of Predictor. Consider narrowing.
        with pytest.raises(Exception):
            Predictor("/nonexistent/path").predict_tflops(_problem(), _kernel())
if __name__ == "__main__":
    # Allow running this file directly (python test_predict.py) without
    # invoking the pytest CLI.
    pytest.main([__file__, "-v"])