# [CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation
The ML heuristic in the dispatcher does not yet support the grouped-conv operator. This PR adds support for fwd, bwd-data, and bwd-weight grouped-conv kernels. A tile_engine utility has also been added to compile and run any selected kernel configuration through the dispatcher infrastructure.

## Technical Details

1. A tile_engine utility is added to benchmark each shape with all possible kernel+tile_size combinations: https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv are added to the models directory; there are three separate models for fwd, bwd-data, and bwd-weight: https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python); see the sketch after this list.
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults
   - **Solution**: Mirror the FMHA pattern and defer GPU initialization until the first run()
   - **Changes**:
     - setup_multiple_grouped_conv_dispatchers() returns List[Path] instead of loaded libs
     - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
     - Added an _ensure_initialized() method for lazy GPU loading
     - GPU context is created only on the first run() call
   - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to expand kernel selection
   - Added automatic padding support for unsupported shapes in the dispatcher runner
   - Created a single source of truth between tile_engine and the dispatcher for architecture and tile_size details
   - Built a validation script to compare oracle_best vs. ml_heuristic selections
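
As a reference for item 3, here is a minimal sketch of the lazy-initialization pattern, assuming a runner that wraps a compiled dispatcher shared library. The constructor argument and the `run_grouped_conv` entry point are illustrative placeholders, not the actual dispatcher API:

```python
import ctypes
from pathlib import Path


class GpuGroupedConvRunner:
    """Illustrative runner: the GPU library is loaded lazily, not in __init__."""

    def __init__(self, lib_path: Path):
        # Safe to construct (and fork via ProcessPoolExecutor) before any GPU
        # context exists, because nothing GPU-related happens here.
        self._lib_path = lib_path
        self._lib = None

    def _ensure_initialized(self):
        # The first call creates the GPU context inside the worker process
        # itself, avoiding the fork()-after-GPU-init memory access faults.
        if self._lib is None:
            self._lib = ctypes.CDLL(str(self._lib_path))

    def run(self, *args):
        self._ensure_initialized()
        # Hypothetical entry point; the real symbol and signature live in the
        # generated dispatcher library.
        return self._lib.run_grouped_conv(*args)
```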

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels on both known and unseen data sets of up to 300 problems.
2. Added test cases in both the dispatcher and tile_engine to validate the heuristic.

## Test Result
Results on unseen shapes, validated on gfx950. Per-shape efficiency vs. the oracle is computed as sketched at the end of this section.
#### Forward Pass Model
- **Training Data**: 48,845 measurements across 1,372 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model
- **Training Data**: 18,773 measurements across 891 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model
- **Training Data**: 34,900 measurements across 1,508 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**
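
For reference, a minimal sketch of how these per-shape efficiency statistics can be derived from a benchmark table; the column and shape-key names are illustrative, not the exact schema used by the validation script. The heuristic's pick for a shape is the candidate with the highest predicted TFLOPS, and its efficiency is that candidate's measured TFLOPS divided by the oracle-best measured TFLOPS for the shape:

```python
import numpy as np
import pandas as pd


def efficiency_stats(df: pd.DataFrame, shape_cols: list) -> dict:
    """Per-shape efficiency of the ML pick vs. the oracle-best kernel.

    Assumes one row per (shape, kernel config) with illustrative
    'measured_tflops' and 'pred_tflops' columns.
    """
    effs = []
    for _, grp in df.groupby(shape_cols):
        oracle = grp["measured_tflops"].max()
        if oracle <= 0:
            continue  # skip shapes with no valid measurement
        # Kernel the heuristic would pick: highest predicted TFLOPS.
        pick = grp.loc[grp["pred_tflops"].idxmax(), "measured_tflops"]
        effs.append(pick / oracle)
    effs = np.asarray(effs)
    return {
        "mean": effs.mean(),
        "median": np.median(effs),
        "p10": np.percentile(effs, 10),
    }
```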

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Tests for train.py.
Covers: group key computation, TFLOPS efficiency calculation, edge cases
(single group, all-invalid data, tied predictions), and warm-start
incremental training (feature compat, lineage, quality).
"""
import json
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import pytest

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from feature_engine import GemmUniversalFeatureEngine
from train import (
    compute_group_keys,
    compute_tflops_efficiency,
    check_feature_compatibility,
    load_warm_start_model,
    train_final_model,
    DEFAULT_PARAMS,
)
class TestComputeGroupKeys:
    def test_basic(self):
        df = pd.DataFrame(
            {"m": [16, 16, 32], "n": [1536, 1536, 1536], "k": [7168, 7168, 7168]}
        )
        keys = compute_group_keys(df, "gemm_universal")
        assert keys[0] == keys[1]
        assert keys[0] != keys[2]

    def test_unique_shapes(self):
        df = pd.DataFrame({"m": [1, 2, 3], "n": [4, 5, 6], "k": [7, 8, 9]})
        keys = compute_group_keys(df, "gemm_universal")
        assert len(set(keys)) == 3


class TestComputeTflopsEfficiency:
    def test_perfect_prediction(self):
        """Model predicts highest TFLOPS kernel => efficiency = 1.0."""
        df = pd.DataFrame(
            {
                "m": [1024, 1024, 1024],
                "n": [1024, 1024, 1024],
                "k": [1024, 1024, 1024],
                "measured_tflops": [100, 200, 150],
                "pred_tflops": [50, 300, 100],  # correctly ranks kernel 1 highest
            }
        )
        eff = compute_tflops_efficiency(df, "gemm_universal", "pred_tflops")
        assert len(eff) == 1
        assert eff["efficiency"].iloc[0] == pytest.approx(1.0)

    def test_worst_prediction(self):
        """Model picks the worst kernel."""
        df = pd.DataFrame(
            {
                "m": [1024, 1024, 1024],
                "n": [1024, 1024, 1024],
                "k": [1024, 1024, 1024],
                "measured_tflops": [100, 200, 150],
                "pred_tflops": [999, 1, 1],  # incorrectly ranks kernel 0 highest
            }
        )
        eff = compute_tflops_efficiency(df, "gemm_universal", "pred_tflops")
        assert eff["efficiency"].iloc[0] == pytest.approx(100 / 200)

    def test_multiple_shapes(self):
        df = pd.DataFrame(
            {
                "m": [16, 16, 32, 32],
                "n": [1536, 1536, 1536, 1536],
                "k": [7168, 7168, 7168, 7168],
                "measured_tflops": [10, 20, 100, 200],
                "pred_tflops": [5, 25, 150, 190],
            }
        )
        eff = compute_tflops_efficiency(df, "gemm_universal", "pred_tflops")
        assert len(eff) == 2
        assert eff.iloc[0]["efficiency"] == pytest.approx(1.0)
        assert eff.iloc[1]["efficiency"] == pytest.approx(1.0)

    def test_zero_tflops_shape_skipped(self):
        df = pd.DataFrame(
            {
                "m": [16, 16],
                "n": [16, 16],
                "k": [16, 16],
                "measured_tflops": [0, 0],
                "pred_tflops": [1, 2],
            }
        )
        eff = compute_tflops_efficiency(df, "gemm_universal", "pred_tflops")
        assert len(eff) == 0

    def test_single_kernel_per_shape(self):
        df = pd.DataFrame(
            {
                "m": [1024],
                "n": [1024],
                "k": [1024],
                "measured_tflops": [150],
                "pred_tflops": [100],
            }
        )
        eff = compute_tflops_efficiency(df, "gemm_universal", "pred_tflops")
        assert len(eff) == 1
        assert eff["efficiency"].iloc[0] == pytest.approx(1.0)

    def test_tied_predictions(self):
        """When multiple kernels have the same predicted TFLOPS, pandas idxmax picks the first."""
        df = pd.DataFrame(
            {
                "m": [1024, 1024, 1024],
                "n": [1024, 1024, 1024],
                "k": [1024, 1024, 1024],
                "measured_tflops": [100, 200, 200],
                "pred_tflops": [50, 50, 50],
            }
        )
        eff = compute_tflops_efficiency(df, "gemm_universal", "pred_tflops")
        assert len(eff) == 1
        assert eff["efficiency"].iloc[0] >= 0.5
# ---------------------------------------------------------------------------
# Helpers for warm-start tests
# ---------------------------------------------------------------------------
def _make_dummy_data(n_rows=200, n_shapes=5):
    """Create a small synthetic benchmark DataFrame for testing training."""
    rng = np.random.RandomState(42)
    rows = []
    for _ in range(n_rows):
        m = rng.choice([64, 128, 256, 512, 1024])
        n = rng.choice([64, 128, 256, 512, 1024])
        k = rng.choice([64, 128, 256, 512, 1024])
        rows.append(
            {
                "m": m,
                "n": n,
                "k": k,
                "split_k": 1,
                "dtype": "fp8",
                "layout": "rcr",
                "op_type": "gemm_universal",
                "tile_m": rng.choice([64, 128, 256]),
                "tile_n": rng.choice([64, 128, 256]),
                "tile_k": rng.choice([32, 64, 128]),
                "warp_m": rng.choice([1, 2, 4]),
                "warp_n": rng.choice([1, 2, 4]),
                "warp_k": 1,
                "warp_tile_m": 32,
                "warp_tile_n": 32,
                "warp_tile_k": 16,
                "pipeline": rng.choice(["compv3", "compv4", "mem"]),
                "scheduler": rng.choice(["intrawave", "interwave"]),
                "epilogue": "cshuffle",
                "pad_m": False,
                "pad_n": False,
                "pad_k": False,
                "persistent": False,
                "measured_tflops": float(rng.uniform(10, 500)),
                "latency_ms": float(rng.uniform(0.01, 1.0)),
                "bandwidth_gb_s": float(rng.uniform(50, 1500)),
                "is_valid": True,
                "kernel_name": f"test_kernel_{rng.randint(0, 100)}",
            }
        )
    return pd.DataFrame(rows)


def _save_feature_spec(model_dir, fe):
    """Save a feature_spec.json matching the given feature engine."""
    spec = {
        "feature_names": fe.get_feature_names(),
        "categorical_features": fe.get_categorical_features(),
    }
    with open(model_dir / "feature_spec.json", "w") as f:
        json.dump(spec, f)


def _train_and_save_base_model(model_dir, df, fe, target="tflops"):
    """Train a small base model and save it to model_dir."""
    params = dict(DEFAULT_PARAMS)
    params["n_estimators"] = 20
    params["n_jobs"] = 1
    model = train_final_model(df, fe, target, params, "gemm_universal")
    model.booster_.save_model(str(model_dir / f"model_{target}.lgbm"))
    _save_feature_spec(model_dir, fe)
    return model
# ---------------------------------------------------------------------------
# Warm-start tests
# ---------------------------------------------------------------------------
class TestCheckFeatureCompatibility:
    def test_compatible_passes(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        _save_feature_spec(tmp_path, fe)
        check_feature_compatibility(tmp_path, fe)

    def test_missing_spec_raises(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        with pytest.raises(FileNotFoundError, match="feature_spec.json"):
            check_feature_compatibility(tmp_path, fe)

    def test_added_feature_raises(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        spec = {
            "feature_names": fe.get_feature_names()[:-1],
            "categorical_features": fe.get_categorical_features(),
        }
        with open(tmp_path / "feature_spec.json", "w") as f:
            json.dump(spec, f)
        with pytest.raises(ValueError, match="Feature schema mismatch"):
            check_feature_compatibility(tmp_path, fe)

    def test_removed_feature_raises(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        spec = {
            "feature_names": fe.get_feature_names() + ["extra_feature"],
            "categorical_features": fe.get_categorical_features(),
        }
        with open(tmp_path / "feature_spec.json", "w") as f:
            json.dump(spec, f)
        with pytest.raises(ValueError, match="Feature schema mismatch"):
            check_feature_compatibility(tmp_path, fe)

    def test_categorical_mismatch_raises(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        spec = {
            "feature_names": fe.get_feature_names(),
            "categorical_features": ["layout", "pipeline"],
        }
        with open(tmp_path / "feature_spec.json", "w") as f:
            json.dump(spec, f)
        with pytest.raises(ValueError, match="Categorical feature mismatch"):
            check_feature_compatibility(tmp_path, fe)


class TestLoadWarmStartModel:
    def test_loads_existing_model(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        df = _make_dummy_data()
        _train_and_save_base_model(tmp_path, df, fe)
        path = load_warm_start_model(tmp_path, "tflops")
        assert path is not None
        assert Path(path).exists()

    def test_returns_none_for_missing_target(self, tmp_path):
        assert load_warm_start_model(tmp_path, "tflops") is None

    def test_returns_none_for_wrong_target(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        df = _make_dummy_data()
        _train_and_save_base_model(tmp_path, df, fe, target="tflops")
        assert load_warm_start_model(tmp_path, "bandwidth") is None
class TestWarmStartTraining:
    def test_warm_start_produces_more_trees(self, tmp_path):
        """A warm-started model should have more trees than the base."""
        fe = GemmUniversalFeatureEngine()
        df = _make_dummy_data(n_rows=300)
        base_dir = tmp_path / "base"
        base_dir.mkdir()
        base_model = _train_and_save_base_model(base_dir, df, fe)
        base_n_trees = base_model.booster_.num_trees()
        init_model_path = load_warm_start_model(base_dir, "tflops")
        params = dict(DEFAULT_PARAMS)
        params["n_estimators"] = 15
        params["n_jobs"] = 1
        warm_model = train_final_model(
            df, fe, "tflops", params, "gemm_universal", init_model=init_model_path
        )
        warm_n_trees = warm_model.booster_.num_trees()
        assert warm_n_trees > base_n_trees

    def test_warm_start_does_not_degrade(self, tmp_path):
        """Warm-started model on the same data should not be significantly worse."""
        fe = GemmUniversalFeatureEngine()
        df = _make_dummy_data(n_rows=300)
        base_dir = tmp_path / "base"
        base_dir.mkdir()
        base_model = _train_and_save_base_model(base_dir, df, fe)
        X = fe.extract_batch(df[df["is_valid"]].reset_index(drop=True))
        y = df[df["is_valid"]]["measured_tflops"].values
        base_rmse = np.sqrt(np.mean((base_model.predict(X) - y) ** 2))
        init_model_path = load_warm_start_model(base_dir, "tflops")
        params = dict(DEFAULT_PARAMS)
        params["n_estimators"] = 15
        params["n_jobs"] = 1
        warm_model = train_final_model(
            df, fe, "tflops", params, "gemm_universal", init_model=init_model_path
        )
        warm_rmse = np.sqrt(np.mean((warm_model.predict(X) - y) ** 2))
        assert warm_rmse <= base_rmse * 1.1

    def test_warm_start_from_nonexistent_dir(self):
        with pytest.raises(FileNotFoundError):
            check_feature_compatibility(
                Path("/nonexistent/model/dir"), GemmUniversalFeatureEngine()
            )


if __name__ == "__main__":
    pytest.main([__file__, "-v"])