Yaswanth Raparti 6989cf800c [rocm-libraries] ROCm/rocm-libraries#6327 (commit 1e7a12e)
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation
The ML heuristic in the dispatcher does not yet support the grouped-conv
operator. This PR adds support for fwd, bwd-data, and bwd-weight
grouped-conv kernels. A tile_engine utility has also been added to compile
and run any selected kernel configuration through the dispatcher
infrastructure.

## Technical Details

1. A tile_engine utility was added to benchmark each shape with all
possible kernel + tile_size combinations:
https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv were added to the models
directory; there are three separate models for fwd, bwd-data, and
bwd-weight:
https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python)
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults
   - **Solution**: Mirror the FMHA pattern - defer GPU initialization until the first run()
   - **Changes**:
     - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs
     - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
     - Added _ensure_initialized() method for lazy GPU loading
     - GPU context is created only on the first run() call
   - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to expand kernel selection
   - Added automatic padding support for unsupported shapes in the dispatcher runner
   - Created a single source of truth between tile_engine and the dispatcher for architecture and tile_size details
   - Built a validation script to compare oracle_best vs. ml_heuristic selections
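The lazy GPU initialization described in item 3 can be sketched as follows. This is a minimal illustration of the pattern only; `LazyGpuRunner` and the library path are hypothetical stand-ins for the actual dispatcher classes:

```python
import ctypes
from pathlib import Path


class LazyGpuRunner:
    """Defers loading the compiled kernel library (and therefore GPU
    context creation) until the first run() call, so fork()-based
    parallel compilation never inherits a live GPU context."""

    def __init__(self, lib_path: Path):
        # Only record the path here; do NOT dlopen or touch the GPU yet.
        self.lib_path = lib_path
        self._lib = None

    def _ensure_initialized(self):
        # Load the shared library on first use, inside the process that
        # will actually launch kernels.
        if self._lib is None:
            self._lib = ctypes.CDLL(str(self.lib_path))

    def run(self, *args):
        self._ensure_initialized()
        # A real runner would now resolve and invoke a kernel entry point;
        # returning the handle is enough for this sketch.
        return self._lib
```

Because `__init__` does no loading, a parent process can construct runners for many compiled libraries and hand them to workers; each worker pays the `CDLL` cost once, on its first `run()`.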

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels on both known and
unseen data sets of up to 300 problems.
2. Ensured that test cases are added in both dispatcher and tile_engine
to validate the heuristic.

## Test Result
Results on unseen shapes, validated on gfx950:
#### Forward Pass Model
- **Training Data**: 48,845 measurements across 1,372 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model
- **Training Data**: 18,773 measurements across 891 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model
- **Training Data**: 34,900 measurements across 1,508 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-05-08 20:48:42 +00:00


#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Training script for CK Tile kernel performance prediction.
Trains LGBMRegressor models (TFLOPS, latency, bandwidth) with:
- Log-space regression (log1p transform) for scale-invariant accuracy
- GroupKFold cross-validation (operation-specific group keys)
- Iterative Hard Example Mining (IHEM)
- Model complexity bounds for C++ deployability
- Optional Optuna hyperparameter tuning
- Warm-start incremental training from a previous model via --warm_start
Supports multiple operation types:
- gemm_universal: GEMM operations (group by M, N, K)
- grouped_conv: Grouped convolution (group by problem config)
- fmha: Fused multi-head attention (future)
Log-transform rationale:
GEMM TFLOPS spans 5 orders of magnitude (0.02 for M=1 to 2230 for large
shapes). Raw regression optimizes for absolute RMSE, which means the model
spends all its capacity predicting large shapes accurately and ignores tiny
shapes where TFLOPS is < 10. Training on log1p(TFLOPS) puts all shapes on
equal footing, improving tiny_m efficiency from 84% to 96%.
"""
import argparse
import json
import time
from pathlib import Path
import lightgbm as lgb
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupKFold
from data_pipeline import build_training_dataset
# Operation-specific target column mappings
TARGET_COLUMNS = {
    "gemm_universal": {
        "tflops": "measured_tflops",
        "latency": "latency_ms",
        "bandwidth": "bandwidth_gb_s",
    },
    "grouped_conv": {
        "tflops": "tflops",
        "latency": "latency_ms",
        "bandwidth": "bandwidth_gb_s",
    },
    "fmha": {
        "tflops": "tflops",
        "latency": "latency_ms",
        "bandwidth": "bandwidth_gb_s",
    },
}

# Targets where log1p transform is applied by default.
# TFLOPS and bandwidth span orders of magnitude; latency is already small-scale.
LOG_TARGETS = {"tflops", "bandwidth"}

DEFAULT_PARAMS = {
    "objective": "regression",
    "metric": ["rmse", "mae"],
    "num_leaves": 255,
    "max_depth": 15,
    "n_estimators": 2000,
    "learning_rate": 0.02,
    "min_child_samples": 10,
    "subsample": 0.85,
    "colsample_bytree": 0.85,
    "reg_alpha": 0.05,
    "reg_lambda": 0.5,
    "verbose": -1,
    "n_jobs": 8,
    "seed": 42,
}

MAX_ESTIMATORS = 5000
WARM_START_N_ESTIMATORS = 500

def get_feature_engine(operation: str, **hw_kwargs):
    """Get the appropriate feature engine for the operation type."""
    if operation == "gemm_universal":
        from feature_engine import GemmUniversalFeatureEngine

        return GemmUniversalFeatureEngine(**hw_kwargs)
    elif operation == "grouped_conv":
        from feature_engine_grouped_conv import GroupedConvFeatureEngine

        return GroupedConvFeatureEngine(**hw_kwargs)
    elif operation == "fmha":
        raise NotImplementedError("FMHA feature engine not yet implemented")
    else:
        raise ValueError(f"Unknown operation type: {operation}")

def check_feature_compatibility(
    prev_model_dir: Path,
    feature_engine,
) -> None:
    """Verify that the previous model's feature spec matches the current engine.

    Raises ValueError with a detailed message on mismatch. This prevents silent
    corruption when warm-starting from a model trained with a different feature
    schema (e.g., after adding a new feature or changing an encoding).

    Parameters
    ----------
    prev_model_dir : Path
        Directory containing the previous model
    feature_engine : FeatureEngine
        Current feature engine instance (any operation type)
    """
    spec_path = prev_model_dir / "feature_spec.json"
    if not spec_path.exists():
        raise FileNotFoundError(
            f"No feature_spec.json in {prev_model_dir}. "
            "Cannot verify feature compatibility for warm start."
        )
    with open(spec_path) as f:
        prev_spec = json.load(f)
    prev_names = prev_spec.get("feature_names", [])
    curr_names = feature_engine.get_feature_names()
    if prev_names != curr_names:
        added = set(curr_names) - set(prev_names)
        removed = set(prev_names) - set(curr_names)
        parts = ["Feature schema mismatch between previous model and current engine."]
        if added:
            parts.append(f" Added features: {sorted(added)}")
        if removed:
            parts.append(f" Removed features: {sorted(removed)}")
        if not added and not removed:
            parts.append(" Feature order changed (names match but order differs).")
        raise ValueError("\n".join(parts))
    prev_cats = prev_spec.get("categorical_features", [])
    curr_cats = feature_engine.get_categorical_features()
    if sorted(prev_cats) != sorted(curr_cats):
        raise ValueError(
            f"Categorical feature mismatch.\n"
            f" Previous: {sorted(prev_cats)}\n"
            f" Current: {sorted(curr_cats)}"
        )

def load_warm_start_model(prev_model_dir: Path, target: str) -> str | None:
    """Load the path to a previous model file for warm-start, or None if absent.

    Automatically decompresses .lgbm.gz files if the .lgbm file doesn't exist.
    The decompressed file is cached to disk for subsequent loads.

    Returns the string path (what LightGBM's init_model expects) rather than
    a loaded Booster, because LGBMRegressor.fit(init_model=...) accepts both
    path strings and Booster objects and path strings avoid keeping the old
    model in memory.
    """
    import gzip

    model_path = prev_model_dir / f"model_{target}.lgbm"
    gz_path = prev_model_dir / f"model_{target}.lgbm.gz"
    # Auto-decompress if needed
    if not model_path.exists() and gz_path.exists():
        print(f" Decompressing {gz_path.name}...")
        with gzip.open(gz_path, "rb") as f_in:
            with open(model_path, "wb") as f_out:
                f_out.write(f_in.read())
    if not model_path.exists():
        return None
    return str(model_path)

def compute_group_keys(df: pd.DataFrame, operation: str) -> np.ndarray:
    """Create GroupKFold group keys based on operation type.

    Parameters
    ----------
    df : pd.DataFrame
        Training data
    operation : str
        Operation type (gemm_universal, grouped_conv, fmha)

    Returns
    -------
    np.ndarray
        Group keys for GroupKFold cross-validation
    """
    if operation == "gemm_universal":
        # Group by (M, N, K)
        return (
            df["m"].astype(str) + "_" + df["n"].astype(str) + "_" + df["k"].astype(str)
        ).values
    elif operation == "grouped_conv":
        # Group by problem configuration (including 3D and dilation for
        # FWD/BWD_DATA/BWD_WEIGHT)
        return df.apply(
            lambda r: f"{r['N']}_{r['C']}_{r['K']}_{r['G']}_{r['Hi']}_{r['Wi']}_{r['Y']}_{r['X']}_"
            f"{r.get('Di', 1)}_{r.get('Z', 1)}_"
            f"{r.get('dilation_h', 1)}_{r.get('dilation_w', 1)}",
            axis=1,
        ).values
    elif operation == "fmha":
        raise NotImplementedError("FMHA group key computation not yet implemented")
    else:
        raise ValueError(f"Unknown operation type: {operation}")

def compute_tflops_efficiency(
    df: pd.DataFrame, operation: str, pred_col: str = "pred_tflops"
) -> pd.DataFrame:
    """Compute per-shape efficiency: predicted-best TFLOPS / oracle-best TFLOPS.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe with predictions and actual TFLOPS
    operation : str
        Operation type to determine grouping columns
    pred_col : str
        Column name for predicted TFLOPS

    Returns
    -------
    pd.DataFrame
        Per-shape efficiency metrics
    """
    results = []
    if operation == "gemm_universal":
        groupby_cols = ["m", "n", "k"]
        tflops_col = "measured_tflops"
    elif operation == "grouped_conv":
        # Group by all problem parameters including 3D and dilation
        base_cols = ["N", "C", "K", "G", "Hi", "Wi", "Y", "X"]
        optional_cols = ["Di", "Z", "dilation_h", "dilation_w"]
        groupby_cols = base_cols + [col for col in optional_cols if col in df.columns]
        tflops_col = "tflops"
    elif operation == "fmha":
        raise NotImplementedError("FMHA efficiency computation not yet implemented")
    else:
        raise ValueError(f"Unknown operation type: {operation}")
    for shape_key, group in df.groupby(groupby_cols):
        oracle_best = group[tflops_col].max()
        if oracle_best <= 0:
            continue
        pred_best_idx = group[pred_col].idxmax()
        selected_tflops = group.loc[pred_best_idx, tflops_col]
        efficiency = selected_tflops / oracle_best
        result = {
            "oracle_best_tflops": oracle_best,
            "selected_tflops": selected_tflops,
            "efficiency": efficiency,
        }
        # Add shape-specific keys
        if operation == "gemm_universal":
            result.update({"m": shape_key[0], "n": shape_key[1], "k": shape_key[2]})
        elif operation == "grouped_conv":
            result.update(
                {
                    "N": shape_key[0],
                    "C": shape_key[1],
                    "K": shape_key[2],
                    "G": shape_key[3],
                    "Hi": shape_key[4],
                    "Wi": shape_key[5],
                    "Y": shape_key[6],
                    "X": shape_key[7],
                }
            )
        results.append(result)
    return pd.DataFrame(results)

def train_single_target(
    X_train,
    y_train,
    X_val,
    y_val,
    params: dict,
    categorical_features: list[str],
    feature_names: list[str],
    init_model=None,
) -> lgb.LGBMRegressor:
    """Train a single LGBMRegressor with early stopping.

    Parameters
    ----------
    init_model : str, Path, lgb.Booster, lgb.LGBMModel, or None
        If provided, training continues from this model (warm start).
        Accepts a file path to a .lgbm file, a Booster instance, or an
        LGBMModel instance. The new model adds n_estimators trees on top
        of the existing ones.
    """
    cat_indices = [
        feature_names.index(c) for c in categorical_features if c in feature_names
    ]
    model = lgb.LGBMRegressor(**params)
    model.fit(
        X_train,
        y_train,
        eval_set=[(X_val, y_val)],
        eval_metric=["rmse"],
        callbacks=[
            lgb.early_stopping(50, verbose=False),
            lgb.log_evaluation(0),
        ],
        categorical_feature=cat_indices if cat_indices else "auto",
        init_model=init_model,
    )
    return model

def run_cv(
    df: pd.DataFrame,
    feature_engine,
    target: str,
    params: dict,
    operation: str,
    n_splits: int = 5,
    use_log: bool = True,
) -> dict:
    """Run GroupKFold cross-validation and return OOF predictions + metrics.

    Parameters
    ----------
    df : pd.DataFrame
        Training data
    feature_engine : FeatureEngine
        Feature engine instance (operation-specific)
    target : str
        Target metric (tflops, latency, bandwidth)
    params : dict
        LightGBM parameters
    operation : str
        Operation type (gemm_universal, grouped_conv, fmha)
    n_splits : int
        Number of CV folds
    use_log : bool
        If True and target is in LOG_TARGETS, train on log1p(y) and invert
        predictions with expm1 for efficiency calculation. This normalizes
        the scale so that tiny-M shapes (TFLOPS ~ 1) get equal attention
        as large-M shapes (TFLOPS ~ 2000).
    """
    target_col = TARGET_COLUMNS[operation][target]
    # Handle is_valid column (present in GEMM, not in grouped_conv)
    if "is_valid" in df.columns:
        valid_mask = df["is_valid"].fillna(False) & (df[target_col] > 0)
    else:
        valid_mask = df[target_col] > 0
    df_valid = df[valid_mask].reset_index(drop=True)
    apply_log = use_log and target in LOG_TARGETS
    print(
        f" Training on {len(df_valid)} valid rows for target={target}"
        f"{' (log-space)' if apply_log else ''}"
    )
    X = feature_engine.extract_batch(df_valid)
    y_raw = df_valid[target_col].values
    y = np.log1p(y_raw) if apply_log else y_raw
    groups = compute_group_keys(df_valid, operation)
    feature_names = feature_engine.get_feature_names()
    cat_features = feature_engine.get_categorical_features()
    unique_groups = np.unique(groups)
    actual_splits = min(n_splits, len(unique_groups))
    if actual_splits < 2:
        print(f" WARNING: Only {len(unique_groups)} unique groups, skipping CV")
        return {}
    gkf = GroupKFold(n_splits=actual_splits)
    oof_preds = np.zeros(len(df_valid))
    fold_metrics = []
    for fold_idx, (train_idx, val_idx) in enumerate(gkf.split(X, y, groups)):
        X_tr, X_val = X[train_idx], X[val_idx]
        y_tr, y_val = y[train_idx], y[val_idx]
        model = train_single_target(
            X_tr, y_tr, X_val, y_val, params, cat_features, feature_names
        )
        preds = model.predict(X_val)
        oof_preds[val_idx] = preds
        rmse = np.sqrt(np.mean((preds - y_val) ** 2))
        r2 = 1 - np.sum((preds - y_val) ** 2) / max(
            np.sum((y_val - y_val.mean()) ** 2), 1e-10
        )
        if target == "tflops":
            val_df = df_valid.iloc[val_idx].copy()
            preds_raw = np.expm1(preds) if apply_log else preds
            val_df["pred_tflops"] = preds_raw
            eff_df = compute_tflops_efficiency(val_df, operation)
            mean_eff = eff_df["efficiency"].mean() if len(eff_df) > 0 else 0
            p10_eff = eff_df["efficiency"].quantile(0.1) if len(eff_df) > 0 else 0
        else:
            mean_eff, p10_eff = None, None
        fold_metrics.append(
            {
                "fold": fold_idx,
                "rmse": rmse,
                "r2": r2,
                "mean_efficiency": mean_eff,
                "p10_efficiency": p10_eff,
                "train_size": len(train_idx),
                "val_size": len(val_idx),
                "val_groups": len(np.unique(groups[val_idx])),
            }
        )
        eff_str = (
            f", eff={mean_eff:.4f}, p10={p10_eff:.4f}" if mean_eff is not None else ""
        )
        print(f" Fold {fold_idx}: RMSE={rmse:.4f}, R2={r2:.4f}{eff_str}")
    df_valid[f"oof_pred_{target}"] = oof_preds
    return {
        "fold_metrics": fold_metrics,
        "oof_df": df_valid,
        "feature_names": feature_names,
        "log_transform": apply_log,
    }

def train_final_model(
    df: pd.DataFrame,
    feature_engine,
    target: str,
    params: dict,
    operation: str,
    init_model=None,
    use_log: bool = True,
) -> lgb.LGBMRegressor:
    """Train the final model on all valid data.

    Parameters
    ----------
    df : pd.DataFrame
        Training data
    feature_engine : FeatureEngine
        Feature engine instance (operation-specific)
    target : str
        Target metric (tflops, latency, bandwidth)
    params : dict
        LightGBM parameters
    operation : str
        Operation type (gemm_universal, grouped_conv, fmha)
    init_model : str, Path, lgb.Booster, lgb.LGBMModel, or None
        If provided, training continues from this model (warm start).
    use_log : bool
        If True and target is in LOG_TARGETS, train on log1p(y).
        The saved model then predicts in log-space; callers must apply
        expm1() to get raw values.
    """
    target_col = TARGET_COLUMNS[operation][target]
    # Handle is_valid column (present in GEMM, not in grouped_conv)
    if "is_valid" in df.columns:
        valid_mask = df["is_valid"].fillna(False) & (df[target_col] > 0)
    else:
        valid_mask = df[target_col] > 0
    df_valid = df[valid_mask].reset_index(drop=True)
    apply_log = use_log and target in LOG_TARGETS
    X = feature_engine.extract_batch(df_valid)
    y_raw = df_valid[target_col].values
    y = np.log1p(y_raw) if apply_log else y_raw
    feature_names = feature_engine.get_feature_names()
    cat_features = feature_engine.get_categorical_features()
    cat_indices = [feature_names.index(c) for c in cat_features if c in feature_names]
    model = lgb.LGBMRegressor(**params)
    model.fit(
        X,
        y,
        categorical_feature=cat_indices if cat_indices else "auto",
        init_model=init_model,
    )
    return model

def main():
    parser = argparse.ArgumentParser(
        description="Train CK Tile kernel performance models (GEMM, Grouped Conv, FMHA)"
    )
    parser.add_argument(
        "--data_dir", required=True, help="Directory with parquet files"
    )
    parser.add_argument("--out_dir", required=True, help="Output directory for models")
    parser.add_argument(
        "--operation",
        default="gemm_universal",
        choices=["gemm_universal", "grouped_conv", "fmha"],
        help="Operation type (gemm_universal, grouped_conv, fmha)",
    )
    parser.add_argument(
        "--op",
        default=None,
        help="Deprecated: use --operation instead. Kept for backward compatibility.",
    )
    parser.add_argument("--dtype", default="fp8", help="Data type filter")
    parser.add_argument("--arch", default="gfx950", help="Architecture")
    parser.add_argument(
        "--targets", default="tflops,latency,bandwidth", help="Comma-separated targets"
    )
    parser.add_argument("--n_splits", type=int, default=5, help="Number of CV folds")
    parser.add_argument(
        "--tune", action="store_true", help="Run Optuna hyperparameter tuning"
    )
    parser.add_argument(
        "--no_log_transform",
        action="store_true",
        help="Disable log1p transform on targets. By default, TFLOPS and bandwidth "
        "are trained in log-space for scale-invariant accuracy across shape sizes.",
    )
    parser.add_argument(
        "--warm_start",
        default=None,
        help="Path to previous model directory to continue training from. "
        "Uses LightGBM's init_model to add new trees on top of the "
        "existing model. Feature schemas must match exactly.",
    )
    parser.add_argument(
        "--warm_start_n_estimators",
        type=int,
        default=WARM_START_N_ESTIMATORS,
        help=f"Number of new trees to add when warm-starting (default: {WARM_START_N_ESTIMATORS}). "
        "Lower than a full train since we're refining, not starting from scratch.",
    )
    args = parser.parse_args()

    # Handle backward compatibility for --op flag
    operation = args.operation
    if args.op is not None:
        print("WARNING: --op is deprecated, use --operation instead")
        operation = args.op

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    targets = [t.strip() for t in args.targets.split(",")]
    print(f"{'=' * 80}")
    print(f"Training {operation} model")
    print(f"{'=' * 80}")
    print()
    print(f"Loading data from {args.data_dir}...")
    df = build_training_dataset(args.data_dir, op_type=operation, dtype=args.dtype)
    print(f" Total rows: {len(df)}")
    # Print unique shapes based on operation type
    if operation == "gemm_universal":
        print(f" Unique shapes: {df.groupby(['m', 'n', 'k']).ngroups}")
    elif operation == "grouped_conv":
        print(
            f" Unique shapes: {df.groupby(['N', 'C', 'K', 'G', 'Hi', 'Wi', 'Y', 'X']).ngroups}"
        )
    print(f" Unique kernels: {df['kernel_name'].nunique()}")
    print()

    # Extract hardware parameters from data (if available)
    hw_cols = [c for c in df.columns if c.startswith("hw_")]
    hw_kwargs = {}
    if hw_cols:
        row0 = df.iloc[0]
        if "hw_num_cus" in df.columns:
            hw_kwargs["num_cus"] = int(row0.get("hw_num_cus", 256))
        if "hw_max_clock_mhz" in df.columns:
            hw_kwargs["max_clock_mhz"] = int(row0.get("hw_max_clock_mhz", 2400))
        if "hw_simds_per_cu" in df.columns:
            hw_kwargs["simds_per_cu"] = int(row0.get("hw_simds_per_cu", 4))
        if "hw_shader_engines" in df.columns:
            hw_kwargs["shader_engines"] = int(row0.get("hw_shader_engines", 32))
        if "hw_max_waves_per_cu" in df.columns:
            hw_kwargs["max_waves_per_cu"] = int(row0.get("hw_max_waves_per_cu", 32))
        if "hw_wavefront_size" in df.columns:
            hw_kwargs["wavefront_size"] = int(row0.get("hw_wavefront_size", 64))
        if "hw_l1_cache_kb" in df.columns:
            hw_kwargs["l1_cache_kb"] = int(row0.get("hw_l1_cache_kb", 32))
        if "hw_l2_cache_kb" in df.columns:
            hw_kwargs["l2_cache_kb"] = int(row0.get("hw_l2_cache_kb", 4096))
        if "hw_l3_cache_kb" in df.columns:
            hw_kwargs["l3_cache_kb"] = int(row0.get("hw_l3_cache_kb", 262144))

    # Get operation-specific feature engine
    print(f"Initializing {operation} feature engine...")
    fe = get_feature_engine(operation, **hw_kwargs)
    print(f" Feature count: {len(fe.get_feature_names())}")
    print(f" Categorical features: {len(fe.get_categorical_features())}")
    print()

    params = dict(DEFAULT_PARAMS)
    use_log = not args.no_log_transform
    prev_model_dir = None
    prev_manifest = {}
    if args.warm_start:
        prev_model_dir = Path(args.warm_start)
        if not prev_model_dir.exists():
            raise FileNotFoundError(f"Warm-start directory not found: {prev_model_dir}")
        print(f" Warm-starting from {prev_model_dir}")
        check_feature_compatibility(prev_model_dir, fe)
        print(" Feature compatibility: OK")
        params["n_estimators"] = args.warm_start_n_estimators
        print(f" New trees to add: {args.warm_start_n_estimators}")
        prev_manifest_path = prev_model_dir / "train_manifest.json"
        if prev_manifest_path.exists():
            with open(prev_manifest_path) as f:
                prev_manifest = json.load(f)

    all_cv_results = {}
    for target in targets:
        if target not in TARGET_COLUMNS[operation]:
            print(f" Skipping unknown target: {target}")
            continue
        print(f"\n{'=' * 60}")
        print(f"Training {target} model")
        print(f"{'=' * 60}")
        init_model_path = None
        if prev_model_dir is not None:
            init_model_path = load_warm_start_model(prev_model_dir, target)
            if init_model_path:
                print(f" Warm-starting from {init_model_path}")
            else:
                print(f" No previous {target} model found, training from scratch")
        t0 = time.time()
        cv_result = run_cv(
            df, fe, target, params, operation, n_splits=args.n_splits, use_log=use_log
        )
        cv_time = time.time() - t0
        if cv_result and cv_result["fold_metrics"]:
            all_cv_results[target] = cv_result["fold_metrics"]
            metrics_path = out_dir / f"cv_metrics_{target}.json"
            with open(metrics_path, "w") as f:
                json.dump(cv_result["fold_metrics"], f, indent=2)
            print(f" CV completed in {cv_time:.1f}s, saved to {metrics_path}")
            if target == "tflops" and cv_result.get("oof_df") is not None:
                oof_df = cv_result["oof_df"]
                oof_df.to_parquet(out_dir / "oof_predictions.parquet", index=False)
                eff_df = compute_tflops_efficiency(oof_df, operation, "oof_pred_tflops")
                if len(eff_df) > 0:
                    print("\n OOF TFLOPS Efficiency:")
                    print(f" Mean: {eff_df['efficiency'].mean():.4f}")
                    print(f" P10: {eff_df['efficiency'].quantile(0.1):.4f}")
                    print(f" P50: {eff_df['efficiency'].quantile(0.5):.4f}")
                    print(f" Min: {eff_df['efficiency'].min():.4f}")
        print(f"\n Training final {target} model on all data...")
        t0 = time.time()
        model = train_final_model(
            df,
            fe,
            target,
            params,
            operation,
            init_model=init_model_path,
            use_log=use_log,
        )
        train_time = time.time() - t0
        model_path = out_dir / f"model_{target}.lgbm"
        model.booster_.save_model(str(model_path))
        print(f" Saved {model_path} ({train_time:.1f}s)")
        importances = dict(
            zip(
                fe.get_feature_names(),
                model.feature_importances_.tolist(),
            )
        )
        imp_path = out_dir / f"feature_importances_{target}.json"
        with open(imp_path, "w") as f:
            json.dump(importances, f, indent=2)

    log_targets_used = sorted(LOG_TARGETS & set(targets)) if use_log else []
    spec = {
        "op_type": operation,
        "dtype": args.dtype,
        "arch": args.arch,
        "feature_names": fe.get_feature_names(),
        "categorical_features": fe.get_categorical_features(),
        "targets": targets,
        "log_targets": log_targets_used,
        "params": params,
    }
    with open(out_dir / "feature_spec.json", "w") as f:
        json.dump(spec, f, indent=2)

    # Compute unique shapes based on operation type
    if operation == "gemm_universal":
        unique_shapes = int(df.groupby(["m", "n", "k"]).ngroups)
    elif operation == "grouped_conv":
        unique_shapes = int(
            df.groupby(["N", "C", "K", "G", "Hi", "Wi", "Y", "X"]).ngroups
        )
    else:
        unique_shapes = 0  # Unknown operation
    manifest = {
        "warm_start_from": str(prev_model_dir) if prev_model_dir else None,
        "prev_n_estimators": (
            prev_manifest.get("total_n_estimators", params.get("n_estimators"))
            if prev_model_dir
            else 0
        ),
        "new_n_estimators": params["n_estimators"],
        "total_n_estimators": (
            prev_manifest.get("total_n_estimators", 0) + params["n_estimators"]
            if prev_model_dir
            else params["n_estimators"]
        ),
        "data_rows": len(df),
        # Guard the is_valid column: it exists for GEMM data but not for
        # grouped_conv (see run_cv/train_final_model).
        "valid_rows": (
            int(df["is_valid"].fillna(False).sum())
            if "is_valid" in df.columns
            else len(df)
        ),
        "unique_shapes": unique_shapes,
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
    }
    with open(out_dir / "train_manifest.json", "w") as f:
        json.dump(manifest, f, indent=2)
    print(f"\nAll models saved to {out_dir}")
    if prev_model_dir:
        print(f" Warm-started from: {prev_model_dir}")
        print(f" Total estimators: {manifest['total_n_estimators']}")


if __name__ == "__main__":
    main()