composable_kernel/dispatcher/heuristics/tests/test_train.py
Yaswanth Raparti c1127a36f5 [rocm-libraries] ROCm/rocm-libraries#5676 (commit 1d18339)
[CK][CK TILE] Autotuning heuristics infra for universal GEMM kernel selection (#5676)

## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile
dispatcher, enabling fast and accurate automatic kernel selection for
Universal GEMM kernels. Instead of requiring an exhaustive search through
4600+ kernel configurations (taking ~46 seconds per problem shape), the
ML heuristic predicts optimal kernels in microseconds while achieving >98%
of oracle-best performance.

## Technical Details

**ML Infrastructure:**

https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics
* Feature Engine
([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)):
55-feature extraction including problem dimensions, kernel
configuration, tile efficiency, and hardware profile
* Training Pipeline
([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)):
LightGBM regression with log-transform, GroupKFold cross-validation,
warm-start support
* Predictor
([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)):
Kernel ranking and TFLOPS prediction for problem shapes
* Evaluation
([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)):
Comprehensive metrics including efficiency, NDCG@k, shape family
analysis
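
The efficiency metric reported by the evaluation (and checked by the `compute_tflops_efficiency` tests in this file) can be sketched in pure Python. This is a minimal sketch, assuming efficiency is defined per (m, n, k) shape as the measured TFLOPS of the kernel the model ranks highest, divided by the best measured TFLOPS for that shape. The function name `tflops_efficiency` and the dict-based row layout are hypothetical; the real `train.py` operates on pandas DataFrames.

```python
from collections import defaultdict


def tflops_efficiency(rows):
    """Score each (m, n, k) shape by the kernel the model ranks highest.

    `rows` is a list of dicts with keys m, n, k, measured_tflops and
    pred_tflops (a hypothetical layout mirroring the test DataFrame columns).
    Returns {(m, n, k): efficiency}, where
    efficiency = measured TFLOPS of the top-predicted kernel
                 / best measured TFLOPS for that shape.
    """
    groups = defaultdict(list)
    for row in rows:
        groups[(row["m"], row["n"], row["k"])].append(row)

    result = {}
    for shape, candidates in groups.items():
        best = max(c["measured_tflops"] for c in candidates)
        if best <= 0:
            continue  # nothing valid to normalize against; skip this shape
        # max() keeps the first maximal element, so ties go to the earliest row.
        picked = max(candidates, key=lambda c: c["pred_tflops"])
        result[shape] = picked["measured_tflops"] / best
    return result


rows = [
    {"m": 16, "n": 1536, "k": 7168, "measured_tflops": 10, "pred_tflops": 5},
    {"m": 16, "n": 1536, "k": 7168, "measured_tflops": 20, "pred_tflops": 25},
    {"m": 1024, "n": 1024, "k": 1024, "measured_tflops": 100, "pred_tflops": 999},
    {"m": 1024, "n": 1024, "k": 1024, "measured_tflops": 200, "pred_tflops": 1},
]
print(tflops_efficiency(rows))
# {(16, 1536, 7168): 1.0, (1024, 1024, 1024): 0.5}
```

Grouping by shape before taking the argmax matters. Each problem shape is a separate selection decision, so a model is only judged on how it ranks candidate kernels for the same (m, n, k).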

**Data Generation Tools:**

* [generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py):
Build and benchmark kernels across diverse problem shapes
* [convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py):
Convert benchmark JSON to training-ready parquet format
* [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py):
Parse streaming benchmark logs into canonical datasets

**Examples:**

* [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp):
C++ example demonstrating ML-based kernel selection
* [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py):
Python example with validation

**Pre-trained Models
(projects/composablekernel/dispatcher/heuristics/models/):**
* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean
efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean
efficiency)

## Test Plan

* Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M
(M≥1024)
* All pipeline types: compv3, compv4, mem
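
Shape-family analysis buckets each problem by its M dimension and then aggregates efficiency per bucket. The sketch below assumes that structure. The tiny (M<8) and large (M≥1024) bounds come from the test plan above. The small/medium cut-off `SMALL_M_LIMIT` is **hypothetical**, because the PR does not state it. The function names are also illustrative and do not come from `evaluate.py`.

```python
from collections import defaultdict

TINY_M_LIMIT = 8      # test plan: "tiny M" means M < 8
LARGE_M_LIMIT = 1024  # test plan: "large M" means M >= 1024
SMALL_M_LIMIT = 128   # HYPOTHETICAL: the small/medium boundary is not stated


def shape_family(m):
    """Bucket a GEMM problem by its M dimension."""
    if m < TINY_M_LIMIT:
        return "tiny"
    if m < SMALL_M_LIMIT:
        return "small"
    if m < LARGE_M_LIMIT:
        return "medium"
    return "large"


def mean_efficiency_by_family(shape_eff):
    """Average per-shape efficiencies ({(m, n, k): eff}) within each M family."""
    buckets = defaultdict(list)
    for (m, _n, _k), eff in shape_eff.items():
        buckets[shape_family(m)].append(eff)
    return {family: sum(v) / len(v) for family, v in buckets.items()}


print(mean_efficiency_by_family({(4, 1536, 7168): 1.0, (2, 1536, 7168): 0.5, (2048, 4096, 4096): 0.9}))
# {'tiny': 0.75, 'large': 0.9}
```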

## Test Result

**fp16 Model (gfx950, RCR layout)**
* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (10th percentile, i.e. about 90% of shapes achieve
≥98% of oracle best)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**
* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%
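
The three summary numbers above (mean, P10, min) can be computed from the list of per-shape efficiencies roughly as follows. This is a sketch assuming P10 uses linear interpolation between sorted values, which is NumPy's default `percentile` method. The PR does not say which percentile convention `evaluate.py` uses, and the sample values here are illustrative, not measurements from this PR.

```python
import math


def percentile(values, q):
    """q-th percentile with linear interpolation (numpy.percentile's default method)."""
    xs = sorted(values)
    pos = (len(xs) - 1) * q / 100
    lo, hi = math.floor(pos), math.ceil(pos)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def summarize(efficiencies):
    """Mean, P10 and min of per-shape efficiencies."""
    return {
        "mean": sum(efficiencies) / len(efficiencies),
        "p10": percentile(efficiencies, 10),
        "min": min(efficiencies),
    }


effs = [0.90, 0.95, 0.96, 0.97, 0.98, 0.99, 1.0, 1.0, 1.0, 1.0, 1.0]
stats = summarize(effs)
print(round(stats["mean"], 4), stats["p10"], stats["min"])
# 0.9773 0.95 0.9
```

In this sample, 10 of the 11 shapes (about 91%) reach at least the P10 value of 0.95. That is how a P10 figure should be read: it is a floor that roughly 90% of shapes meet, not the value achieved by "the 90th percentile".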

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-02 02:26:32 +00:00


#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Tests for train.py.

Covers: group key computation, TFLOPS efficiency calculation, edge cases
(single group, all-invalid data, tied predictions), and warm-start
incremental training (feature compat, lineage, quality).
"""

import json
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import pytest

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from feature_engine import GemmUniversalFeatureEngine
from train import (
    compute_group_keys,
    compute_tflops_efficiency,
    check_feature_compatibility,
    load_warm_start_model,
    train_final_model,
    DEFAULT_PARAMS,
)


class TestComputeGroupKeys:
    def test_basic(self):
        df = pd.DataFrame(
            {"m": [16, 16, 32], "n": [1536, 1536, 1536], "k": [7168, 7168, 7168]}
        )
        keys = compute_group_keys(df)
        # Rows 0 and 1 share the same (m, n, k) shape; row 2 differs in m.
        assert keys[0] == keys[1]
        assert keys[0] != keys[2]

    def test_unique_shapes(self):
        df = pd.DataFrame({"m": [1, 2, 3], "n": [4, 5, 6], "k": [7, 8, 9]})
        keys = compute_group_keys(df)
        assert len(set(keys)) == 3


class TestComputeTflopsEfficiency:
    def test_perfect_prediction(self):
        """Model predicts highest TFLOPS kernel => efficiency = 1.0."""
        df = pd.DataFrame(
            {
                "m": [1024, 1024, 1024],
                "n": [1024, 1024, 1024],
                "k": [1024, 1024, 1024],
                "measured_tflops": [100, 200, 150],
                "pred_tflops": [50, 300, 100],  # correctly ranks kernel 1 highest
            }
        )
        eff = compute_tflops_efficiency(df, "pred_tflops")
        assert len(eff) == 1
        assert eff["efficiency"].iloc[0] == pytest.approx(1.0)

    def test_worst_prediction(self):
        """Model picks the worst kernel => efficiency = worst / best."""
        df = pd.DataFrame(
            {
                "m": [1024, 1024, 1024],
                "n": [1024, 1024, 1024],
                "k": [1024, 1024, 1024],
                "measured_tflops": [100, 200, 150],
                "pred_tflops": [999, 1, 1],  # incorrectly ranks kernel 0 highest
            }
        )
        eff = compute_tflops_efficiency(df, "pred_tflops")
        assert eff["efficiency"].iloc[0] == pytest.approx(100 / 200)

    def test_multiple_shapes(self):
        """Each (m, n, k) shape gets its own efficiency row."""
        df = pd.DataFrame(
            {
                "m": [16, 16, 32, 32],
                "n": [1536, 1536, 1536, 1536],
                "k": [7168, 7168, 7168, 7168],
                "measured_tflops": [10, 20, 100, 200],
                "pred_tflops": [5, 25, 150, 190],
            }
        )
        eff = compute_tflops_efficiency(df, "pred_tflops")
        assert len(eff) == 2
        assert eff.iloc[0]["efficiency"] == pytest.approx(1.0)
        assert eff.iloc[1]["efficiency"] == pytest.approx(1.0)

    def test_zero_tflops_shape_skipped(self):
        """A shape whose measured TFLOPS are all zero yields no efficiency row."""
        df = pd.DataFrame(
            {
                "m": [16, 16],
                "n": [16, 16],
                "k": [16, 16],
                "measured_tflops": [0, 0],
                "pred_tflops": [1, 2],
            }
        )
        eff = compute_tflops_efficiency(df, "pred_tflops")
        assert len(eff) == 0

    def test_single_kernel_per_shape(self):
        """With a single candidate kernel, the pick is trivially the best one."""
        df = pd.DataFrame(
            {
                "m": [1024],
                "n": [1024],
                "k": [1024],
                "measured_tflops": [150],
                "pred_tflops": [100],
            }
        )
        eff = compute_tflops_efficiency(df, "pred_tflops")
        assert len(eff) == 1
        assert eff["efficiency"].iloc[0] == pytest.approx(1.0)

    def test_tied_predictions(self):
        """When multiple kernels have the same predicted TFLOPS, pandas idxmax picks the first."""
        df = pd.DataFrame(
            {
                "m": [1024, 1024, 1024],
                "n": [1024, 1024, 1024],
                "k": [1024, 1024, 1024],
                "measured_tflops": [100, 200, 200],
                "pred_tflops": [50, 50, 50],
            }
        )
        eff = compute_tflops_efficiency(df, "pred_tflops")
        assert len(eff) == 1
        # Whichever tied kernel is picked, efficiency is at least 100 / 200.
        assert eff["efficiency"].iloc[0] >= 0.5


# ---------------------------------------------------------------------------
# Helpers for warm-start tests
# ---------------------------------------------------------------------------


def _make_dummy_data(n_rows=200, n_shapes=5):
    """Create a small synthetic benchmark DataFrame for testing training.

    Note: `n_shapes` is currently unused; m, n and k are sampled
    independently for every row.
    """
    rng = np.random.RandomState(42)  # fixed seed keeps the tests deterministic
    rows = []
    for _ in range(n_rows):
        m = rng.choice([64, 128, 256, 512, 1024])
        n = rng.choice([64, 128, 256, 512, 1024])
        k = rng.choice([64, 128, 256, 512, 1024])
        rows.append(
            {
                "m": m,
                "n": n,
                "k": k,
                "split_k": 1,
                "dtype": "fp8",
                "layout": "rcr",
                "op_type": "gemm_universal",
                "tile_m": rng.choice([64, 128, 256]),
                "tile_n": rng.choice([64, 128, 256]),
                "tile_k": rng.choice([32, 64, 128]),
                "warp_m": rng.choice([1, 2, 4]),
                "warp_n": rng.choice([1, 2, 4]),
                "warp_k": 1,
                "warp_tile_m": 32,
                "warp_tile_n": 32,
                "warp_tile_k": 16,
                "pipeline": rng.choice(["compv3", "compv4", "mem"]),
                "scheduler": rng.choice(["intrawave", "interwave"]),
                "epilogue": "cshuffle",
                "pad_m": False,
                "pad_n": False,
                "pad_k": False,
                "persistent": False,
                "measured_tflops": float(rng.uniform(10, 500)),
                "latency_ms": float(rng.uniform(0.01, 1.0)),
                "bandwidth_gb_s": float(rng.uniform(50, 1500)),
                "is_valid": True,
                "kernel_name": f"test_kernel_{rng.randint(0, 100)}",
            }
        )
    return pd.DataFrame(rows)


def _save_feature_spec(model_dir, fe):
    """Save a feature_spec.json matching the given feature engine."""
    spec = {
        "feature_names": fe.get_feature_names(),
        "categorical_features": fe.get_categorical_features(),
    }
    with open(model_dir / "feature_spec.json", "w") as f:
        json.dump(spec, f)


def _train_and_save_base_model(model_dir, df, fe, target="tflops"):
    """Train a small base model and save it to model_dir."""
    params = dict(DEFAULT_PARAMS)
    params["n_estimators"] = 20
    params["n_jobs"] = 1
    model = train_final_model(df, fe, target, params)
    model.booster_.save_model(str(model_dir / f"model_{target}.lgbm"))
    _save_feature_spec(model_dir, fe)
    return model


# ---------------------------------------------------------------------------
# Warm-start tests
# ---------------------------------------------------------------------------


class TestCheckFeatureCompatibility:
    def test_compatible_passes(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        _save_feature_spec(tmp_path, fe)
        # Must not raise: the saved spec matches the current feature engine.
        check_feature_compatibility(tmp_path, fe)

    def test_missing_spec_raises(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        with pytest.raises(FileNotFoundError, match="feature_spec.json"):
            check_feature_compatibility(tmp_path, fe)

    def test_added_feature_raises(self, tmp_path):
        # The saved spec lacks the last feature, i.e. the engine has added one.
        fe = GemmUniversalFeatureEngine()
        spec = {
            "feature_names": fe.get_feature_names()[:-1],
            "categorical_features": fe.get_categorical_features(),
        }
        with open(tmp_path / "feature_spec.json", "w") as f:
            json.dump(spec, f)
        with pytest.raises(ValueError, match="Feature schema mismatch"):
            check_feature_compatibility(tmp_path, fe)

    def test_removed_feature_raises(self, tmp_path):
        # The saved spec has an extra feature, i.e. the engine has removed one.
        fe = GemmUniversalFeatureEngine()
        spec = {
            "feature_names": fe.get_feature_names() + ["extra_feature"],
            "categorical_features": fe.get_categorical_features(),
        }
        with open(tmp_path / "feature_spec.json", "w") as f:
            json.dump(spec, f)
        with pytest.raises(ValueError, match="Feature schema mismatch"):
            check_feature_compatibility(tmp_path, fe)

    def test_categorical_mismatch_raises(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        spec = {
            "feature_names": fe.get_feature_names(),
            "categorical_features": ["layout", "pipeline"],
        }
        with open(tmp_path / "feature_spec.json", "w") as f:
            json.dump(spec, f)
        with pytest.raises(ValueError, match="Categorical feature mismatch"):
            check_feature_compatibility(tmp_path, fe)


class TestLoadWarmStartModel:
    def test_loads_existing_model(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        df = _make_dummy_data()
        _train_and_save_base_model(tmp_path, df, fe)
        path = load_warm_start_model(tmp_path, "tflops")
        assert path is not None
        assert Path(path).exists()

    def test_returns_none_for_missing_target(self, tmp_path):
        assert load_warm_start_model(tmp_path, "tflops") is None

    def test_returns_none_for_wrong_target(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        df = _make_dummy_data()
        _train_and_save_base_model(tmp_path, df, fe, target="tflops")
        assert load_warm_start_model(tmp_path, "bandwidth") is None


class TestWarmStartTraining:
    def test_warm_start_produces_more_trees(self, tmp_path):
        """A warm-started model should have more trees than the base."""
        fe = GemmUniversalFeatureEngine()
        df = _make_dummy_data(n_rows=300)
        base_dir = tmp_path / "base"
        base_dir.mkdir()
        base_model = _train_and_save_base_model(base_dir, df, fe)
        base_n_trees = base_model.booster_.num_trees()

        # Continue boosting from the saved base model.
        init_model_path = load_warm_start_model(base_dir, "tflops")
        params = dict(DEFAULT_PARAMS)
        params["n_estimators"] = 15
        params["n_jobs"] = 1
        warm_model = train_final_model(
            df, fe, "tflops", params, init_model=init_model_path
        )
        warm_n_trees = warm_model.booster_.num_trees()
        assert warm_n_trees > base_n_trees

    def test_warm_start_does_not_degrade(self, tmp_path):
        """Warm-started model on the same data should not be significantly worse."""
        fe = GemmUniversalFeatureEngine()
        df = _make_dummy_data(n_rows=300)
        base_dir = tmp_path / "base"
        base_dir.mkdir()
        base_model = _train_and_save_base_model(base_dir, df, fe)

        X = fe.extract_batch(df[df["is_valid"]].reset_index(drop=True))
        y = df[df["is_valid"]]["measured_tflops"].values
        base_rmse = np.sqrt(np.mean((base_model.predict(X) - y) ** 2))

        init_model_path = load_warm_start_model(base_dir, "tflops")
        params = dict(DEFAULT_PARAMS)
        params["n_estimators"] = 15
        params["n_jobs"] = 1
        warm_model = train_final_model(
            df, fe, "tflops", params, init_model=init_model_path
        )
        warm_rmse = np.sqrt(np.mean((warm_model.predict(X) - y) ** 2))
        # Tolerate at most a 10% increase in RMSE over the base model.
        assert warm_rmse <= base_rmse * 1.1

    def test_warm_start_from_nonexistent_dir(self):
        with pytest.raises(FileNotFoundError):
            check_feature_compatibility(
                Path("/nonexistent/model/dir"), GemmUniversalFeatureEngine()
            )


if __name__ == "__main__":
    # Propagate pytest's exit code so CI sees failures.
    sys.exit(pytest.main([__file__, "-v"]))