mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327) ## Motivation The ML heuristic in dispatcher does not support grouped-conv operator yet. In this PR, the support for fwd, bwd-data, and bwd-weight grouped-conv kernels has been added. A tile_engine utility has also been added to compile and run any selected kernel configuration through dispatcher infrastructure. ## Technical Details 1. Tile engine utility is added to benchmark each shape with all the possible kernel+tile_size combinations here - [https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py](url) 2. New LGBM regressor models for grouped conv are added to models directory. We have 3 separate models for fwd, bwd-data, and bwd-weights [https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models](url) 3. Implemented lazy GPU initialization (dispatcher/python) - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults - **Solution**: Mirror FMHA pattern - defer GPU initialization until first run() - **Changes**: - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL - Added _ensure_initialized() method for lazy GPU loading - GPU context created only on first run() call - **Benefit**: Parallel compilation now works without GPU conflicts 4. 
Addressed a few miscellaneous issues such as: - Fixed BF16->FP16 naming bug in the dispatcher wrapper - Added new tile sizes, and comp_v5 pipeline to the arch spec to expand the kernel selection - Added automatic padding support for unsupported shapes in dispatcher runner - Created a single source of truth between tile_engine and dispatcher about the architecture and tile_size details - Built validation scripts to compare oracle_best vs ml_heuristic ## Test Plan 1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets with up to 300 problems. 2. Ensured that test cases are added in both dispatcher and tile_engine to validate the heuristic. ## Test Result Results on Unseen shapes validated on gfx950 #### Forward Pass Model - **Training Data**: 48,845 measurements across 1,372 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.05%** - Median Efficiency: **96.8%** - P10 Efficiency: **79.9%** #### Backward Data Gradient (bwd_data) Model - **Training Data**: 18,773 measurements across 891 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.8%** - Median Efficiency: **96.5%** - P10 Efficiency: **82.9%** #### Backward Weight Gradient (bwd_weight) Model - **Training Data**: 34,900 measurements across 1,508 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **96.1%** - Median Efficiency: **99.2%** - P10 Efficiency: **89.4%** ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
330 lines
12 KiB
Python
330 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Tests for train.py.
|
|
|
|
Covers: group key computation, TFLOPS efficiency calculation, edge cases
|
|
(single group, all-invalid data, tied predictions), and warm-start
|
|
incremental training (feature compat, lineage, quality).
|
|
"""
|
|
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import pytest
|
|
|
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
|
|
from feature_engine import GemmUniversalFeatureEngine
|
|
from train import (
|
|
compute_group_keys,
|
|
compute_tflops_efficiency,
|
|
check_feature_compatibility,
|
|
load_warm_start_model,
|
|
train_final_model,
|
|
DEFAULT_PARAMS,
|
|
)
|
|
|
|
|
|
class TestComputeGroupKeys:
    """compute_group_keys: identical shapes share a key, distinct shapes do not."""

    def test_basic(self):
        # Two rows with identical (m, n, k) must map to the same key; a
        # third row with a different m must get its own key.
        frame = pd.DataFrame(
            {"m": [16, 16, 32], "n": [1536, 1536, 1536], "k": [7168, 7168, 7168]}
        )
        group_keys = compute_group_keys(frame, "gemm_universal")
        assert group_keys[0] == group_keys[1]
        assert group_keys[0] != group_keys[2]

    def test_unique_shapes(self):
        # Three fully distinct shapes => three distinct keys.
        frame = pd.DataFrame({"m": [1, 2, 3], "n": [4, 5, 6], "k": [7, 8, 9]})
        group_keys = compute_group_keys(frame, "gemm_universal")
        assert len(set(group_keys)) == 3
|
|
|
|
|
|
class TestComputeTflopsEfficiency:
    """compute_tflops_efficiency across normal, degenerate, and tied inputs."""

    @staticmethod
    def _efficiency(frame):
        # Every test scores the same op type and prediction column.
        return compute_tflops_efficiency(frame, "gemm_universal", "pred_tflops")

    def test_perfect_prediction(self):
        """Model predicts highest TFLOPS kernel => efficiency = 1.0."""
        frame = pd.DataFrame(
            {
                "m": [1024] * 3,
                "n": [1024] * 3,
                "k": [1024] * 3,
                "measured_tflops": [100, 200, 150],
                "pred_tflops": [50, 300, 100],  # correctly ranks kernel 1 highest
            }
        )
        scored = self._efficiency(frame)
        assert len(scored) == 1
        assert scored["efficiency"].iloc[0] == pytest.approx(1.0)

    def test_worst_prediction(self):
        """Model picks the worst kernel."""
        frame = pd.DataFrame(
            {
                "m": [1024] * 3,
                "n": [1024] * 3,
                "k": [1024] * 3,
                "measured_tflops": [100, 200, 150],
                "pred_tflops": [999, 1, 1],  # incorrectly ranks kernel 0 highest
            }
        )
        scored = self._efficiency(frame)
        # Picked kernel measures 100 TFLOPS against a 200 TFLOPS optimum.
        assert scored["efficiency"].iloc[0] == pytest.approx(100 / 200)

    def test_multiple_shapes(self):
        # Two shapes; the model ranks the truly-fastest kernel first in both.
        frame = pd.DataFrame(
            {
                "m": [16, 16, 32, 32],
                "n": [1536] * 4,
                "k": [7168] * 4,
                "measured_tflops": [10, 20, 100, 200],
                "pred_tflops": [5, 25, 150, 190],
            }
        )
        scored = self._efficiency(frame)
        assert len(scored) == 2
        assert scored.iloc[0]["efficiency"] == pytest.approx(1.0)
        assert scored.iloc[1]["efficiency"] == pytest.approx(1.0)

    def test_zero_tflops_shape_skipped(self):
        # A shape whose measured TFLOPS are all zero cannot be scored and
        # must be dropped from the result.
        frame = pd.DataFrame(
            {
                "m": [16, 16],
                "n": [16, 16],
                "k": [16, 16],
                "measured_tflops": [0, 0],
                "pred_tflops": [1, 2],
            }
        )
        assert len(self._efficiency(frame)) == 0

    def test_single_kernel_per_shape(self):
        # With only one candidate, the pick is trivially the optimum.
        frame = pd.DataFrame(
            {
                "m": [1024],
                "n": [1024],
                "k": [1024],
                "measured_tflops": [150],
                "pred_tflops": [100],
            }
        )
        scored = self._efficiency(frame)
        assert len(scored) == 1
        assert scored["efficiency"].iloc[0] == pytest.approx(1.0)

    def test_tied_predictions(self):
        """When multiple kernels have the same predicted TFLOPS, pandas idxmax picks the first."""
        frame = pd.DataFrame(
            {
                "m": [1024] * 3,
                "n": [1024] * 3,
                "k": [1024] * 3,
                "measured_tflops": [100, 200, 200],
                "pred_tflops": [50] * 3,
            }
        )
        scored = self._efficiency(frame)
        # First-row tiebreak yields 100/200; any tiebreak yields >= 0.5 here.
        assert len(scored) == 1
        assert scored["efficiency"].iloc[0] >= 0.5
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers for warm-start tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_dummy_data(n_rows=200, n_shapes=5):
    """Create a small synthetic benchmark DataFrame for testing training.

    NOTE(review): ``n_shapes`` is accepted for call-site compatibility but
    is currently unused; shape dimensions are drawn independently per row.
    """
    rng = np.random.RandomState(42)  # fixed seed => reproducible frames
    dim_pool = [64, 128, 256, 512, 1024]

    def _one_row():
        # Draw the problem shape first so the RNG stream matches row order.
        m = rng.choice(dim_pool)
        n = rng.choice(dim_pool)
        k = rng.choice(dim_pool)
        return {
            "m": m,
            "n": n,
            "k": k,
            "split_k": 1,
            "dtype": "fp8",
            "layout": "rcr",
            "op_type": "gemm_universal",
            "tile_m": rng.choice([64, 128, 256]),
            "tile_n": rng.choice([64, 128, 256]),
            "tile_k": rng.choice([32, 64, 128]),
            "warp_m": rng.choice([1, 2, 4]),
            "warp_n": rng.choice([1, 2, 4]),
            "warp_k": 1,
            "warp_tile_m": 32,
            "warp_tile_n": 32,
            "warp_tile_k": 16,
            "pipeline": rng.choice(["compv3", "compv4", "mem"]),
            "scheduler": rng.choice(["intrawave", "interwave"]),
            "epilogue": "cshuffle",
            "pad_m": False,
            "pad_n": False,
            "pad_k": False,
            "persistent": False,
            "measured_tflops": float(rng.uniform(10, 500)),
            "latency_ms": float(rng.uniform(0.01, 1.0)),
            "bandwidth_gb_s": float(rng.uniform(50, 1500)),
            "is_valid": True,
            "kernel_name": f"test_kernel_{rng.randint(0, 100)}",
        }

    return pd.DataFrame([_one_row() for _ in range(n_rows)])
|
|
|
|
|
|
def _save_feature_spec(model_dir, fe):
    """Save a feature_spec.json matching the given feature engine."""
    payload = {
        "feature_names": fe.get_feature_names(),
        "categorical_features": fe.get_categorical_features(),
    }
    # One-shot write; the spec is tiny so no streaming is needed.
    (model_dir / "feature_spec.json").write_text(json.dumps(payload))
|
|
|
|
|
|
def _train_and_save_base_model(model_dir, df, fe, target="tflops"):
    """Train a small base model and save it to model_dir."""
    # Shrink the forest and force single-threaded fitting so tests stay
    # fast and deterministic.
    params = {**DEFAULT_PARAMS, "n_estimators": 20, "n_jobs": 1}
    model = train_final_model(df, fe, target, params, "gemm_universal")
    model.booster_.save_model(str(model_dir / f"model_{target}.lgbm"))
    _save_feature_spec(model_dir, fe)
    return model
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Warm-start tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCheckFeatureCompatibility:
    """check_feature_compatibility against matching, missing, and drifted specs."""

    @staticmethod
    def _write_spec(directory, names, categoricals):
        # Persist a feature_spec.json with the given schema.
        with open(directory / "feature_spec.json", "w") as f:
            json.dump(
                {"feature_names": names, "categorical_features": categoricals}, f
            )

    def test_compatible_passes(self, tmp_path):
        engine = GemmUniversalFeatureEngine()
        _save_feature_spec(tmp_path, engine)
        # A spec saved from the same engine must pass without raising.
        check_feature_compatibility(tmp_path, engine)

    def test_missing_spec_raises(self, tmp_path):
        with pytest.raises(FileNotFoundError, match="feature_spec.json"):
            check_feature_compatibility(tmp_path, GemmUniversalFeatureEngine())

    def test_added_feature_raises(self, tmp_path):
        engine = GemmUniversalFeatureEngine()
        # Saved spec lacks the engine's last feature => schema drift.
        self._write_spec(
            tmp_path, engine.get_feature_names()[:-1], engine.get_categorical_features()
        )
        with pytest.raises(ValueError, match="Feature schema mismatch"):
            check_feature_compatibility(tmp_path, engine)

    def test_removed_feature_raises(self, tmp_path):
        engine = GemmUniversalFeatureEngine()
        # Saved spec carries a feature the engine no longer produces.
        self._write_spec(
            tmp_path,
            engine.get_feature_names() + ["extra_feature"],
            engine.get_categorical_features(),
        )
        with pytest.raises(ValueError, match="Feature schema mismatch"):
            check_feature_compatibility(tmp_path, engine)

    def test_categorical_mismatch_raises(self, tmp_path):
        engine = GemmUniversalFeatureEngine()
        # Same feature names, different categorical subset.
        self._write_spec(tmp_path, engine.get_feature_names(), ["layout", "pipeline"])
        with pytest.raises(ValueError, match="Categorical feature mismatch"):
            check_feature_compatibility(tmp_path, engine)
|
|
|
|
|
|
class TestLoadWarmStartModel:
    """load_warm_start_model: returns a path only when a matching model exists."""

    def test_loads_existing_model(self, tmp_path):
        engine = GemmUniversalFeatureEngine()
        _train_and_save_base_model(tmp_path, _make_dummy_data(), engine)
        found = load_warm_start_model(tmp_path, "tflops")
        assert found is not None
        assert Path(found).exists()

    def test_returns_none_for_missing_target(self, tmp_path):
        # Empty directory: nothing to warm-start from.
        assert load_warm_start_model(tmp_path, "tflops") is None

    def test_returns_none_for_wrong_target(self, tmp_path):
        engine = GemmUniversalFeatureEngine()
        _train_and_save_base_model(
            tmp_path, _make_dummy_data(), engine, target="tflops"
        )
        # Only a tflops model was saved; a bandwidth lookup must miss.
        assert load_warm_start_model(tmp_path, "bandwidth") is None
|
|
|
|
|
|
class TestWarmStartTraining:
    """End-to-end warm-start behavior: tree counts grow and quality holds."""

    @staticmethod
    def _continue_training(frame, engine, init_model_path):
        # Retrain on the same data, resuming from the saved booster.
        params = {**DEFAULT_PARAMS, "n_estimators": 15, "n_jobs": 1}
        return train_final_model(
            frame, engine, "tflops", params, "gemm_universal",
            init_model=init_model_path,
        )

    def test_warm_start_produces_more_trees(self, tmp_path):
        """A warm-started model should have more trees than the base."""
        engine = GemmUniversalFeatureEngine()
        frame = _make_dummy_data(n_rows=300)

        base_dir = tmp_path / "base"
        base_dir.mkdir()
        base_model = _train_and_save_base_model(base_dir, frame, engine)
        base_n_trees = base_model.booster_.num_trees()

        warm_model = self._continue_training(
            frame, engine, load_warm_start_model(base_dir, "tflops")
        )
        assert warm_model.booster_.num_trees() > base_n_trees

    def test_warm_start_does_not_degrade(self, tmp_path):
        """Warm-started model on the same data should not be significantly worse."""
        engine = GemmUniversalFeatureEngine()
        frame = _make_dummy_data(n_rows=300)

        base_dir = tmp_path / "base"
        base_dir.mkdir()
        base_model = _train_and_save_base_model(base_dir, frame, engine)

        # Score both models on the valid rows of the training frame.
        valid = frame[frame["is_valid"]]
        X = engine.extract_batch(valid.reset_index(drop=True))
        y = valid["measured_tflops"].values
        base_rmse = np.sqrt(np.mean((base_model.predict(X) - y) ** 2))

        warm_model = self._continue_training(
            frame, engine, load_warm_start_model(base_dir, "tflops")
        )
        warm_rmse = np.sqrt(np.mean((warm_model.predict(X) - y) ** 2))

        # Allow 10% slack for stochastic variation in fitting.
        assert warm_rmse <= base_rmse * 1.1

    def test_warm_start_from_nonexistent_dir(self):
        with pytest.raises(FileNotFoundError):
            check_feature_compatibility(
                Path("/nonexistent/model/dir"), GemmUniversalFeatureEngine()
            )
|
|
|
|
|
|
if __name__ == "__main__":
|
|
pytest.main([__file__, "-v"])
|