mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327) ## Motivation The ML heuristic in dispatcher does not support grouped-conv operator yet. In this PR, the support for fwd, bdw-data, and bwd-weight grouped-conv kernels have been added. A tile_engine utility has also been added to compile and run any selected kernel configuration through dispatcher infrastructure. ## Technical Details 1. Tile engine utility is added to benchmark each shape with all the possible kernel+tile_size combinations here - [https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py](url) 2. New LGBM regressor models for grouped conv are added to models directory. We have 3 separate models for fwd, bwd-data, and bwd-weights [https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models](url) 3. Implemented lazy GPU initialization (dispatcher/python) - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults - **Solution**: Mirror FMHA pattern - defer GPU initialization until first run() - **Changes**: - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL - Added _ensure_initialized() method for lazy GPU loading - GPU context created only on first run() call - **Benefit**: Parallel compilation now works without GPU conflicts 4. 
Addressed a few miscellaneous issues such as: - Fixed BF16->FP16 naming bug in the dispatcher wrapper - Added new tile sizes, and comp_v5 pipeline to the arch spec to expand the kernel selection - Added automatic padding support for unsupported shapes in dispatcher runner - Created a single source of truth between tile_engine and dispatcher about the architecture and tile_size details - Built a validation script to compare oracle_best vs ml_heuristic results ## Test Plan 1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets with up to 300 problems. 2. Ensured that test cases are added in both dispatcher and tile_engine to validate the heuristic. ## Test Result Results on Unseen shapes validated on gfx950 #### Forward Pass Model - **Training Data**: 48,845 measurements across 1,372 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.05%** - Median Efficiency: **96.8%** - P10 Efficiency: **79.9%** #### Backward Data Gradient (bwd_data) Model - **Training Data**: 18,773 measurements across 891 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.8%** - Median Efficiency: **96.5%** - P10 Efficiency: **82.9%** #### Backward Weight Gradient (bwd_weight) Model - **Training Data**: 34,900 measurements across 1,508 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **96.1%** - Median Efficiency: **99.2%** - P10 Efficiency: **89.4%** ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
756 lines
26 KiB
Python
756 lines
26 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Training script for CK Tile kernel performance prediction.
|
|
|
|
Trains LGBMRegressor models (TFLOPS, latency, bandwidth) with:
|
|
- Log-space regression (log1p transform) for scale-invariant accuracy
|
|
- GroupKFold cross-validation (operation-specific group keys)
|
|
- Iterative Hard Example Mining (IHEM)
|
|
- Model complexity bounds for C++ deployability
|
|
- Optional Optuna hyperparameter tuning
|
|
- Warm-start incremental training from a previous model via --warm_start
|
|
|
|
Supports multiple operation types:
|
|
- gemm_universal: GEMM operations (group by M, N, K)
|
|
- grouped_conv: Grouped convolution (group by problem config)
|
|
- fmha: Fused multi-head attention (future)
|
|
|
|
Log-transform rationale:
|
|
GEMM TFLOPS spans 5 orders of magnitude (0.02 for M=1 to 2230 for large
|
|
shapes). Raw regression optimizes for absolute RMSE, which means the model
|
|
spends all its capacity predicting large shapes accurately and ignores tiny
|
|
shapes where TFLOPS is < 10. Training on log1p(TFLOPS) puts all shapes on
|
|
equal footing, improving tiny_m efficiency from 84% to 96%.
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import lightgbm as lgb
|
|
import numpy as np
|
|
import pandas as pd
|
|
from sklearn.model_selection import GroupKFold
|
|
|
|
from data_pipeline import build_training_dataset
|
|
|
|
|
|
# Operation-specific target column mappings.
# Maps each operation type to the dataframe column that holds each target
# metric. Note the asymmetry: gemm_universal data stores TFLOPS under
# "measured_tflops" while the other operations use plain "tflops".
TARGET_COLUMNS = {
    "gemm_universal": {
        "tflops": "measured_tflops",
        "latency": "latency_ms",
        "bandwidth": "bandwidth_gb_s",
    },
    "grouped_conv": {
        "tflops": "tflops",
        "latency": "latency_ms",
        "bandwidth": "bandwidth_gb_s",
    },
    "fmha": {
        "tflops": "tflops",
        "latency": "latency_ms",
        "bandwidth": "bandwidth_gb_s",
    },
}

# Targets where log1p transform is applied by default.
# TFLOPS and bandwidth span orders of magnitude; latency is already small-scale.
LOG_TARGETS = {"tflops", "bandwidth"}

# Default LightGBM hyperparameters, shared by cross-validation and the final
# full-data fit (see run_cv / train_final_model).
DEFAULT_PARAMS = {
    "objective": "regression",
    "metric": ["rmse", "mae"],
    "num_leaves": 255,
    "max_depth": 15,
    "n_estimators": 2000,
    "learning_rate": 0.02,
    "min_child_samples": 10,
    "subsample": 0.85,
    "colsample_bytree": 0.85,
    "reg_alpha": 0.05,
    "reg_lambda": 0.5,
    "verbose": -1,
    "n_jobs": 8,
    "seed": 42,
}

# Upper bound on total tree count for deployability.
# NOTE(review): not referenced anywhere in this file — presumably enforced by
# a consumer of these models; confirm before removing.
MAX_ESTIMATORS = 5000
# Default number of new trees added per target during a --warm_start run
# (a refinement pass needs fewer trees than a from-scratch train).
WARM_START_N_ESTIMATORS = 500
|
|
|
|
def get_feature_engine(operation: str, **hw_kwargs):
    """Instantiate the feature engine matching *operation*.

    Imports are deferred so only the engine module that is actually needed
    gets loaded. ``hw_kwargs`` (hardware parameters such as CU count) are
    forwarded verbatim to the engine constructor.

    Raises
    ------
    NotImplementedError
        For the "fmha" operation (engine not written yet).
    ValueError
        For any operation name that is not recognized.
    """
    if operation == "fmha":
        raise NotImplementedError("FMHA feature engine not yet implemented")

    if operation == "gemm_universal":
        from feature_engine import GemmUniversalFeatureEngine as _Engine
    elif operation == "grouped_conv":
        from feature_engine_grouped_conv import GroupedConvFeatureEngine as _Engine
    else:
        raise ValueError(f"Unknown operation type: {operation}")

    return _Engine(**hw_kwargs)
|
|
|
|
|
def check_feature_compatibility(
    prev_model_dir: Path,
    feature_engine,
) -> None:
    """Verify that the previous model's feature spec matches the current engine.

    Guards warm-start runs: continuing training from a model that was built
    with a different feature schema (new feature, changed encoding, different
    column order) would silently corrupt predictions.

    Parameters
    ----------
    prev_model_dir : Path
        Directory containing the previous model
    feature_engine : FeatureEngine
        Current feature engine instance (any operation type)

    Raises
    ------
    FileNotFoundError
        If ``feature_spec.json`` is missing from *prev_model_dir*.
    ValueError
        On any feature-name, ordering, or categorical-feature mismatch,
        with a detailed message.
    """
    spec_path = prev_model_dir / "feature_spec.json"
    if not spec_path.exists():
        raise FileNotFoundError(
            f"No feature_spec.json in {prev_model_dir}. "
            "Cannot verify feature compatibility for warm start."
        )

    prev_spec = json.loads(spec_path.read_text())

    prev_names = prev_spec.get("feature_names", [])
    curr_names = feature_engine.get_feature_names()
    if prev_names != curr_names:
        added = set(curr_names) - set(prev_names)
        removed = set(prev_names) - set(curr_names)
        msg_lines = [
            "Feature schema mismatch between previous model and current engine."
        ]
        if added:
            msg_lines.append(f" Added features: {sorted(added)}")
        if removed:
            msg_lines.append(f" Removed features: {sorted(removed)}")
        if not added and not removed:
            # Same name sets but unequal lists: only the ordering changed.
            msg_lines.append(" Feature order changed (names match but order differs).")
        raise ValueError("\n".join(msg_lines))

    # Categorical features are compared as sets: ordering does not matter here.
    prev_cats = prev_spec.get("categorical_features", [])
    curr_cats = feature_engine.get_categorical_features()
    if sorted(prev_cats) != sorted(curr_cats):
        raise ValueError(
            f"Categorical feature mismatch.\n"
            f" Previous: {sorted(prev_cats)}\n"
            f" Current: {sorted(curr_cats)}"
        )
|
|
|
|
|
|
def load_warm_start_model(prev_model_dir: Path, target: str) -> str | None:
|
|
"""Load the path to a previous model file for warm-start, or None if absent.
|
|
|
|
Automatically decompresses .lgbm.gz files if the .lgbm file doesn't exist.
|
|
The decompressed file is cached to disk for subsequent loads.
|
|
|
|
Returns the string path (what LightGBM's init_model expects) rather than
|
|
a loaded Booster, because LGBMRegressor.fit(init_model=...) accepts both
|
|
path strings and Booster objects and path strings avoid keeping the old
|
|
model in memory.
|
|
"""
|
|
import gzip
|
|
|
|
model_path = prev_model_dir / f"model_{target}.lgbm"
|
|
gz_path = prev_model_dir / f"model_{target}.lgbm.gz"
|
|
|
|
# Auto-decompress if needed
|
|
if not model_path.exists() and gz_path.exists():
|
|
print(f" Decompressing {gz_path.name}...")
|
|
with gzip.open(gz_path, "rb") as f_in:
|
|
with open(model_path, "wb") as f_out:
|
|
f_out.write(f_in.read())
|
|
|
|
if not model_path.exists():
|
|
return None
|
|
return str(model_path)
|
|
|
|
|
|
def compute_group_keys(df: pd.DataFrame, operation: str) -> np.ndarray:
    """Create GroupKFold group keys based on operation type.

    Every row belonging to the same problem shape gets the same key, so
    GroupKFold never splits one shape across train and validation folds.

    Parameters
    ----------
    df : pd.DataFrame
        Training data
    operation : str
        Operation type (gemm_universal, grouped_conv, fmha)

    Returns
    -------
    np.ndarray
        Group keys for GroupKFold cross-validation
    """
    if operation == "gemm_universal":
        # Key is "M_N_K".
        return df[["m", "n", "k"]].astype(str).agg("_".join, axis=1).values

    if operation == "grouped_conv":
        # Key covers the full problem config, including the 3D (Di, Z) and
        # dilation dimensions, which default to 1 when the columns are absent.
        def conv_key(row):
            dims = (
                row["N"], row["C"], row["K"], row["G"],
                row["Hi"], row["Wi"], row["Y"], row["X"],
                row.get("Di", 1), row.get("Z", 1),
                row.get("dilation_h", 1), row.get("dilation_w", 1),
            )
            return "_".join(str(d) for d in dims)

        return df.apply(conv_key, axis=1).values

    if operation == "fmha":
        raise NotImplementedError("FMHA group key computation not yet implemented")

    raise ValueError(f"Unknown operation type: {operation}")
|
|
|
|
|
|
def compute_tflops_efficiency(
    df: pd.DataFrame, operation: str, pred_col: str = "pred_tflops"
) -> pd.DataFrame:
    """Compute per-shape efficiency: predicted-best TFLOPS / oracle-best TFLOPS.

    For every problem shape, rank the candidate kernels by the model's
    predicted TFLOPS, pick the top-ranked one, and compare its *measured*
    TFLOPS against the best measured TFLOPS for that shape.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe with predictions and actual TFLOPS
    operation : str
        Operation type to determine grouping columns
    pred_col : str
        Column name for predicted TFLOPS

    Returns
    -------
    pd.DataFrame
        Per-shape efficiency metrics
    """
    if operation == "gemm_universal":
        groupby_cols = ["m", "n", "k"]
        tflops_col = "measured_tflops"
        key_cols = ["m", "n", "k"]
    elif operation == "grouped_conv":
        # Group by all problem parameters, including 3D and dilation columns
        # when present; only the base eight dims are echoed into the output.
        base_cols = ["N", "C", "K", "G", "Hi", "Wi", "Y", "X"]
        extras = [
            c for c in ("Di", "Z", "dilation_h", "dilation_w") if c in df.columns
        ]
        groupby_cols = base_cols + extras
        tflops_col = "tflops"
        key_cols = base_cols
    elif operation == "fmha":
        raise NotImplementedError("FMHA efficiency computation not yet implemented")
    else:
        raise ValueError(f"Unknown operation type: {operation}")

    records = []
    for shape_key, group in df.groupby(groupby_cols):
        oracle_best = group[tflops_col].max()
        # Shapes with no positive measurement cannot be scored.
        if oracle_best <= 0:
            continue
        chosen = group.loc[group[pred_col].idxmax(), tflops_col]
        record = {
            "oracle_best_tflops": oracle_best,
            "selected_tflops": chosen,
            "efficiency": chosen / oracle_best,
        }
        # Attach the identifying shape dimensions from the group key tuple.
        record.update(dict(zip(key_cols, shape_key)))
        records.append(record)

    return pd.DataFrame(records)
|
|
|
|
|
|
def train_single_target(
    X_train,
    y_train,
    X_val,
    y_val,
    params: dict,
    categorical_features: list[str],
    feature_names: list[str],
    init_model=None,
) -> lgb.LGBMRegressor:
    """Fit one LGBMRegressor with early stopping on the validation set.

    Parameters
    ----------
    init_model : str, Path, lgb.Booster, lgb.LGBMModel, or None
        If provided, training continues from this model (warm start).
        Accepts a file path to a .lgbm file, a Booster instance, or an
        LGBMModel instance. The new model adds n_estimators trees on top
        of the existing ones.
    """
    # Translate categorical feature names into column indices; fall back to
    # LightGBM's automatic detection when none of them are present.
    cat_indices = [
        feature_names.index(name)
        for name in categorical_features
        if name in feature_names
    ]
    categorical = cat_indices if cat_indices else "auto"

    regressor = lgb.LGBMRegressor(**params)
    regressor.fit(
        X_train,
        y_train,
        eval_set=[(X_val, y_val)],
        eval_metric=["rmse"],
        callbacks=[
            # Stop after 50 rounds without validation improvement; stay quiet.
            lgb.early_stopping(50, verbose=False),
            lgb.log_evaluation(0),
        ],
        categorical_feature=categorical,
        init_model=init_model,
    )
    return regressor
|
|
|
|
|
|
def run_cv(
    df: pd.DataFrame,
    feature_engine,
    target: str,
    params: dict,
    operation: str,
    n_splits: int = 5,
    use_log: bool = True,
) -> dict:
    """Run GroupKFold cross-validation and return OOF predictions + metrics.

    Parameters
    ----------
    df : pd.DataFrame
        Training data
    feature_engine : FeatureEngine
        Feature engine instance (operation-specific)
    target : str
        Target metric (tflops, latency, bandwidth)
    params : dict
        LightGBM parameters
    operation : str
        Operation type (gemm_universal, grouped_conv, fmha)
    n_splits : int
        Number of CV folds
    use_log : bool
        If True and target is in LOG_TARGETS, train on log1p(y) and invert
        predictions with expm1 for efficiency calculation. This normalizes
        the scale so that tiny-M shapes (TFLOPS ~ 1) get equal attention
        as large-M shapes (TFLOPS ~ 2000).

    Returns
    -------
    dict
        Empty dict when there are fewer than 2 unique groups; otherwise keys
        "fold_metrics", "oof_df", "feature_names", "log_transform". The
        "oof_pred_{target}" column in "oof_df" holds predictions in the
        training space (i.e. log1p space when the log transform is active).
    """
    target_col = TARGET_COLUMNS[operation][target]

    # Handle is_valid column (present in GEMM, not in grouped_conv).
    # Rows with non-positive targets are always dropped.
    if "is_valid" in df.columns:
        valid_mask = df["is_valid"].fillna(False) & (df[target_col] > 0)
    else:
        valid_mask = df[target_col] > 0

    df_valid = df[valid_mask].reset_index(drop=True)

    apply_log = use_log and target in LOG_TARGETS

    print(
        f" Training on {len(df_valid)} valid rows for target={target}"
        f"{' (log-space)' if apply_log else ''}"
    )

    X = feature_engine.extract_batch(df_valid)
    y_raw = df_valid[target_col].values
    y = np.log1p(y_raw) if apply_log else y_raw
    groups = compute_group_keys(df_valid, operation)
    feature_names = feature_engine.get_feature_names()
    cat_features = feature_engine.get_categorical_features()

    # GroupKFold needs at least 2 groups; cap n_splits at the group count.
    unique_groups = np.unique(groups)
    actual_splits = min(n_splits, len(unique_groups))
    if actual_splits < 2:
        print(f" WARNING: Only {len(unique_groups)} unique groups, skipping CV")
        return {}

    gkf = GroupKFold(n_splits=actual_splits)
    oof_preds = np.zeros(len(df_valid))
    fold_metrics = []

    for fold_idx, (train_idx, val_idx) in enumerate(gkf.split(X, y, groups)):
        X_tr, X_val = X[train_idx], X[val_idx]
        y_tr, y_val = y[train_idx], y[val_idx]

        model = train_single_target(
            X_tr, y_tr, X_val, y_val, params, cat_features, feature_names
        )
        preds = model.predict(X_val)
        oof_preds[val_idx] = preds

        # RMSE/R2 are computed in the (possibly log) training space. The
        # max(..., 1e-10) denominator guards against a constant y_val fold.
        rmse = np.sqrt(np.mean((preds - y_val) ** 2))
        r2 = 1 - np.sum((preds - y_val) ** 2) / max(
            np.sum((y_val - y_val.mean()) ** 2), 1e-10
        )

        if target == "tflops":
            # Kernel-selection efficiency needs raw-scale predictions, so
            # invert the log1p transform before ranking kernels per shape.
            val_df = df_valid.iloc[val_idx].copy()
            preds_raw = np.expm1(preds) if apply_log else preds
            val_df["pred_tflops"] = preds_raw
            eff_df = compute_tflops_efficiency(val_df, operation)
            mean_eff = eff_df["efficiency"].mean() if len(eff_df) > 0 else 0
            p10_eff = eff_df["efficiency"].quantile(0.1) if len(eff_df) > 0 else 0
        else:
            # Efficiency only makes sense for the TFLOPS target.
            mean_eff, p10_eff = None, None

        fold_metrics.append(
            {
                "fold": fold_idx,
                "rmse": rmse,
                "r2": r2,
                "mean_efficiency": mean_eff,
                "p10_efficiency": p10_eff,
                "train_size": len(train_idx),
                "val_size": len(val_idx),
                "val_groups": len(np.unique(groups[val_idx])),
            }
        )

        eff_str = (
            f", eff={mean_eff:.4f}, p10={p10_eff:.4f}" if mean_eff is not None else ""
        )
        print(f" Fold {fold_idx}: RMSE={rmse:.4f}, R2={r2:.4f}{eff_str}")

    # OOF predictions stay in training space; log1p is monotonic, so the
    # argmax-based kernel selection done downstream is unaffected.
    df_valid[f"oof_pred_{target}"] = oof_preds

    return {
        "fold_metrics": fold_metrics,
        "oof_df": df_valid,
        "feature_names": feature_names,
        "log_transform": apply_log,
    }
|
|
|
|
|
|
def train_final_model(
    df: pd.DataFrame,
    feature_engine,
    target: str,
    params: dict,
    operation: str,
    init_model=None,
    use_log: bool = True,
) -> lgb.LGBMRegressor:
    """Fit the production model for *target* on every valid row.

    Parameters
    ----------
    df : pd.DataFrame
        Training data
    feature_engine : FeatureEngine
        Feature engine instance (operation-specific)
    target : str
        Target metric (tflops, latency, bandwidth)
    params : dict
        LightGBM parameters
    operation : str
        Operation type (gemm_universal, grouped_conv, fmha)
    init_model : str, Path, lgb.Booster, lgb.LGBMModel, or None
        If provided, training continues from this model (warm start).
    use_log : bool
        If True and target is in LOG_TARGETS, train on log1p(y).
        The saved model then predicts in log-space; callers must apply
        expm1() to get raw values.
    """
    target_col = TARGET_COLUMNS[operation][target]

    # Keep only rows with a positive target; additionally respect the
    # is_valid flag when the dataset provides one (GEMM does,
    # grouped_conv does not).
    keep = df[target_col] > 0
    if "is_valid" in df.columns:
        keep &= df["is_valid"].fillna(False)
    rows = df[keep].reset_index(drop=True)

    log_space = use_log and target in LOG_TARGETS

    features = feature_engine.extract_batch(rows)
    raw_targets = rows[target_col].values
    labels = np.log1p(raw_targets) if log_space else raw_targets

    names = feature_engine.get_feature_names()
    categorical = [
        names.index(col)
        for col in feature_engine.get_categorical_features()
        if col in names
    ]

    final_model = lgb.LGBMRegressor(**params)
    final_model.fit(
        features,
        labels,
        categorical_feature=categorical if categorical else "auto",
        init_model=init_model,
    )
    return final_model
|
|
|
|
|
|
def main():
    """CLI entry point: load data, run CV, then train and save final models.

    Writes into --out_dir:
      - model_{target}.lgbm                final model per target
      - cv_metrics_{target}.json           per-fold CV metrics
      - feature_importances_{target}.json  per-target feature importances
      - oof_predictions.parquet            (tflops target only)
      - feature_spec.json                  schema used for warm-start checks
      - train_manifest.json                training provenance
    """
    parser = argparse.ArgumentParser(
        description="Train CK Tile kernel performance models (GEMM, Grouped Conv, FMHA)"
    )
    parser.add_argument(
        "--data_dir", required=True, help="Directory with parquet files"
    )
    parser.add_argument("--out_dir", required=True, help="Output directory for models")
    parser.add_argument(
        "--operation",
        default="gemm_universal",
        choices=["gemm_universal", "grouped_conv", "fmha"],
        help="Operation type (gemm_universal, grouped_conv, fmha)",
    )
    parser.add_argument(
        "--op",
        default=None,
        help="Deprecated: use --operation instead. Kept for backward compatibility.",
    )
    parser.add_argument("--dtype", default="fp8", help="Data type filter")
    parser.add_argument("--arch", default="gfx950", help="Architecture")
    parser.add_argument(
        "--targets", default="tflops,latency,bandwidth", help="Comma-separated targets"
    )
    parser.add_argument("--n_splits", type=int, default=5, help="Number of CV folds")
    parser.add_argument(
        "--tune", action="store_true", help="Run Optuna hyperparameter tuning"
    )
    parser.add_argument(
        "--no_log_transform",
        action="store_true",
        help="Disable log1p transform on targets. By default, TFLOPS and bandwidth "
        "are trained in log-space for scale-invariant accuracy across shape sizes.",
    )
    parser.add_argument(
        "--warm_start",
        default=None,
        help="Path to previous model directory to continue training from. "
        "Uses LightGBM's init_model to add new trees on top of the "
        "existing model. Feature schemas must match exactly.",
    )
    parser.add_argument(
        "--warm_start_n_estimators",
        type=int,
        default=WARM_START_N_ESTIMATORS,
        help=f"Number of new trees to add when warm-starting (default: {WARM_START_N_ESTIMATORS}). "
        "Lower than a full train since we're refining, not starting from scratch.",
    )
    args = parser.parse_args()

    # Handle backward compatibility for --op flag. Note: --op bypasses the
    # choices= validation; unknown values fail later in get_feature_engine.
    operation = args.operation
    if args.op is not None:
        print("WARNING: --op is deprecated, use --operation instead")
        operation = args.op

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    targets = [t.strip() for t in args.targets.split(",")]

    print(f"{'=' * 80}")
    print(f"Training {operation} model")
    print(f"{'=' * 80}")
    print()

    print(f"Loading data from {args.data_dir}...")
    df = build_training_dataset(args.data_dir, op_type=operation, dtype=args.dtype)
    print(f" Total rows: {len(df)}")

    # Print unique shapes based on operation type
    if operation == "gemm_universal":
        print(f" Unique shapes: {df.groupby(['m', 'n', 'k']).ngroups}")
    elif operation == "grouped_conv":
        print(
            f" Unique shapes: {df.groupby(['N', 'C', 'K', 'G', 'Hi', 'Wi', 'Y', 'X']).ngroups}"
        )

    print(f" Unique kernels: {df['kernel_name'].nunique()}")
    print()

    # Extract hardware parameters from data (if available); the literal
    # fallbacks only apply if the column exists but the first row lacks it.
    hw_cols = [c for c in df.columns if c.startswith("hw_")]
    hw_kwargs = {}
    if hw_cols:
        row0 = df.iloc[0]
        if "hw_num_cus" in df.columns:
            hw_kwargs["num_cus"] = int(row0.get("hw_num_cus", 256))
        if "hw_max_clock_mhz" in df.columns:
            hw_kwargs["max_clock_mhz"] = int(row0.get("hw_max_clock_mhz", 2400))
        if "hw_simds_per_cu" in df.columns:
            hw_kwargs["simds_per_cu"] = int(row0.get("hw_simds_per_cu", 4))
        if "hw_shader_engines" in df.columns:
            hw_kwargs["shader_engines"] = int(row0.get("hw_shader_engines", 32))
        if "hw_max_waves_per_cu" in df.columns:
            hw_kwargs["max_waves_per_cu"] = int(row0.get("hw_max_waves_per_cu", 32))
        if "hw_wavefront_size" in df.columns:
            hw_kwargs["wavefront_size"] = int(row0.get("hw_wavefront_size", 64))
        if "hw_l1_cache_kb" in df.columns:
            hw_kwargs["l1_cache_kb"] = int(row0.get("hw_l1_cache_kb", 32))
        if "hw_l2_cache_kb" in df.columns:
            hw_kwargs["l2_cache_kb"] = int(row0.get("hw_l2_cache_kb", 4096))
        if "hw_l3_cache_kb" in df.columns:
            hw_kwargs["l3_cache_kb"] = int(row0.get("hw_l3_cache_kb", 262144))

    # Get operation-specific feature engine
    print(f"Initializing {operation} feature engine...")
    fe = get_feature_engine(operation, **hw_kwargs)
    print(f" Feature count: {len(fe.get_feature_names())}")
    print(f" Categorical features: {len(fe.get_categorical_features())}")
    print()

    params = dict(DEFAULT_PARAMS)
    use_log = not args.no_log_transform

    prev_model_dir = None
    prev_manifest = {}
    if args.warm_start:
        prev_model_dir = Path(args.warm_start)
        if not prev_model_dir.exists():
            raise FileNotFoundError(f"Warm-start directory not found: {prev_model_dir}")
        print(f" Warm-starting from {prev_model_dir}")
        # Refuse to continue from a model trained with a different schema.
        check_feature_compatibility(prev_model_dir, fe)
        print(" Feature compatibility: OK")
        params["n_estimators"] = args.warm_start_n_estimators
        print(f" New trees to add: {args.warm_start_n_estimators}")

        prev_manifest_path = prev_model_dir / "train_manifest.json"
        if prev_manifest_path.exists():
            with open(prev_manifest_path) as f:
                prev_manifest = json.load(f)

    all_cv_results = {}
    for target in targets:
        if target not in TARGET_COLUMNS[operation]:
            print(f" Skipping unknown target: {target}")
            continue

        print(f"\n{'=' * 60}")
        print(f"Training {target} model")
        print(f"{'=' * 60}")

        init_model_path = None
        if prev_model_dir is not None:
            init_model_path = load_warm_start_model(prev_model_dir, target)
            if init_model_path:
                print(f" Warm-starting from {init_model_path}")
            else:
                print(f" No previous {target} model found, training from scratch")

        t0 = time.time()
        cv_result = run_cv(
            df, fe, target, params, operation, n_splits=args.n_splits, use_log=use_log
        )
        cv_time = time.time() - t0

        if cv_result and cv_result["fold_metrics"]:
            all_cv_results[target] = cv_result["fold_metrics"]
            metrics_path = out_dir / f"cv_metrics_{target}.json"
            with open(metrics_path, "w") as f:
                json.dump(cv_result["fold_metrics"], f, indent=2)
            print(f" CV completed in {cv_time:.1f}s, saved to {metrics_path}")

            if target == "tflops" and cv_result.get("oof_df") is not None:
                oof_df = cv_result["oof_df"]
                oof_df.to_parquet(out_dir / "oof_predictions.parquet", index=False)

                eff_df = compute_tflops_efficiency(oof_df, operation, "oof_pred_tflops")
                if len(eff_df) > 0:
                    print("\n OOF TFLOPS Efficiency:")
                    print(f" Mean: {eff_df['efficiency'].mean():.4f}")
                    print(f" P10: {eff_df['efficiency'].quantile(0.1):.4f}")
                    print(f" P50: {eff_df['efficiency'].quantile(0.5):.4f}")
                    print(f" Min: {eff_df['efficiency'].min():.4f}")

        print(f"\n Training final {target} model on all data...")
        t0 = time.time()
        model = train_final_model(
            df,
            fe,
            target,
            params,
            operation,
            init_model=init_model_path,
            use_log=use_log,
        )
        train_time = time.time() - t0

        model_path = out_dir / f"model_{target}.lgbm"
        model.booster_.save_model(str(model_path))
        print(f" Saved {model_path} ({train_time:.1f}s)")

        importances = dict(
            zip(
                fe.get_feature_names(),
                model.feature_importances_.tolist(),
            )
        )
        imp_path = out_dir / f"feature_importances_{target}.json"
        with open(imp_path, "w") as f:
            json.dump(importances, f, indent=2)

    log_targets_used = sorted(LOG_TARGETS & set(targets)) if use_log else []
    spec = {
        "op_type": operation,
        "dtype": args.dtype,
        "arch": args.arch,
        "feature_names": fe.get_feature_names(),
        "categorical_features": fe.get_categorical_features(),
        "targets": targets,
        "log_targets": log_targets_used,
        "params": params,
    }
    with open(out_dir / "feature_spec.json", "w") as f:
        json.dump(spec, f, indent=2)

    # Compute unique shapes based on operation type
    if operation == "gemm_universal":
        unique_shapes = int(df.groupby(["m", "n", "k"]).ngroups)
    elif operation == "grouped_conv":
        unique_shapes = int(
            df.groupby(["N", "C", "K", "G", "Hi", "Wi", "Y", "X"]).ngroups
        )
    else:
        unique_shapes = 0  # Unknown operation

    # BUGFIX: grouped_conv datasets have no "is_valid" column (see run_cv /
    # train_final_model, which both guard on it); indexing it unconditionally
    # raised KeyError here. Without the flag, every row counts as valid.
    if "is_valid" in df.columns:
        valid_rows = int(df["is_valid"].fillna(False).sum())
    else:
        valid_rows = len(df)

    manifest = {
        "warm_start_from": str(prev_model_dir) if prev_model_dir else None,
        "prev_n_estimators": prev_manifest.get(
            "total_n_estimators", params.get("n_estimators")
        )
        if prev_model_dir
        else 0,
        "new_n_estimators": params["n_estimators"],
        "total_n_estimators": (
            prev_manifest.get("total_n_estimators", 0) + params["n_estimators"]
            if prev_model_dir
            else params["n_estimators"]
        ),
        "data_rows": len(df),
        "valid_rows": valid_rows,
        "unique_shapes": unique_shapes,
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
    }
    with open(out_dir / "train_manifest.json", "w") as f:
        json.dump(manifest, f, indent=2)

    print(f"\nAll models saved to {out_dir}")
    if prev_model_dir:
        print(f" Warm-started from: {prev_model_dir}")
        print(f" Total estimators: {manifest['total_n_estimators']}")
|
|
|
# Script entry point: `python <this file> --data_dir ... --out_dir ...`.
if __name__ == "__main__":
    main()
|