Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-13 17:55:48 +00:00)
# [CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation

The ML heuristic in the dispatcher does not yet support the grouped-conv operator. This PR adds support for fwd, bwd-data, and bwd-weight grouped-conv kernels. A tile_engine utility has also been added to compile and run any selected kernel configuration through the dispatcher infrastructure.

## Technical Details

1. A tile_engine utility is added to benchmark each shape with all possible kernel + tile_size combinations: https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv are added to the models directory, with three separate models for fwd, bwd-data, and bwd-weight: https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python); a minimal sketch follows this description.
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults
   - **Solution**: Mirror the FMHA pattern and defer GPU initialization until the first run()
   - **Changes**:
     - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs
     - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
     - Added _ensure_initialized() method for lazy GPU loading
     - GPU context created only on first run() call
   - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to expand kernel selection
   - Added automatic padding support for unsupported shapes in the dispatcher runner
   - Created a single source of truth between tile_engine and the dispatcher for architecture and tile_size details
   - Built a validation script to compare oracle_best vs ml_heuristic selections

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets of up to 300 problems.
2. Ensured that test cases are added in both the dispatcher and tile_engine to validate the heuristic.

## Test Result

Results on unseen shapes, validated on gfx950.

#### Forward Pass Model

- **Training Data**: 48,845 measurements across 1,372 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model

- **Training Data**: 18,773 measurements across 891 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model

- **Training Data**: 34,900 measurements across 1,508 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
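To make the lazy GPU initialization in Technical Details item 3 concrete, here is a minimal sketch of the pattern. Only the names `GpuGroupedConvRunner`, `_ensure_initialized()`, `setup_multiple_grouped_conv_dispatchers()`, `run()`, and the deferred `ctypes.CDLL` call come from the description above; the library-path plumbing and the exported symbol `run_grouped_conv` are illustrative assumptions, not the actual dispatcher code.

```python
import ctypes
from pathlib import Path
from typing import List, Optional


class GpuGroupedConvRunner:
    """Runs a compiled grouped-conv kernel, deferring GPU setup until first use."""

    def __init__(self, lib_path: Path):
        # Only record the path here. Loading the library would create a GPU
        # context, which faults when this object is built in a parent process
        # that later fork()s workers via ProcessPoolExecutor.
        self._lib_path = lib_path
        self._lib: Optional[ctypes.CDLL] = None

    def _ensure_initialized(self) -> None:
        # Lazy step: the first call loads the shared library (and with it the
        # GPU context) inside the process that actually executes the kernel.
        if self._lib is None:
            self._lib = ctypes.CDLL(str(self._lib_path))

    def run(self, *args) -> int:
        self._ensure_initialized()
        # "run_grouped_conv" is a placeholder symbol name for illustration.
        return self._lib.run_grouped_conv(*args)


def setup_multiple_grouped_conv_dispatchers(lib_paths: List[str]) -> List[Path]:
    # Per the PR: return library paths only, not loaded libraries, so that
    # compilation can proceed in parallel workers without touching the GPU.
    return [Path(p) for p in lib_paths]
```

Because `__init__()` no longer touches the GPU, runner objects can be constructed before `ProcessPoolExecutor` forks its workers; each worker then creates its own GPU context on its first `run()` call.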
318 lines · 11 KiB · Python
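The script below implements the oracle comparison behind the efficiency numbers in the Test Result section, here for GEMM shapes. For each $(M, N, K)$ problem it computes a per-shape efficiency

$$\text{efficiency} = 100 \times \frac{\text{measured TFLOPS of the ML-selected kernel}}{\text{measured TFLOPS of the oracle-best kernel}}$$

and the reported Mean, Median, and P10 statistics are taken over these per-shape values.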
#!/usr/bin/env python3

# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
ML Heuristic Validation: Test ML predictions against oracle-best from training data

This script validates ML-based kernel selection by:
1. Loading benchmark data (oracle-best results for each shape)
2. Using ML model to predict best kernel for each shape
3. Comparing ML selection with oracle-best to compute efficiency

Usage:
    python validate_ml_heuristic.py --dtype fp16 --model_dir models/gemm_universal_fp16_gfx950
    python validate_ml_heuristic.py --dtype fp8 --layout rcr
"""

import sys
import argparse
import pandas as pd
import numpy as np
from pathlib import Path

from predict import Predictor


def validate_ml_heuristic(dtype: str, layout: str, model_dir: str, data_dir: str):
    """Validate ML heuristic predictions against oracle-best"""

    print("=" * 100)
    print(f" ML Heuristic Validation: {dtype.upper()} {layout.upper()}")
    print("=" * 100)
    print()

    # Load training data
    print(f"Loading training data from {data_dir}...")

    # Try dtype-specific parquet first, then fall back to combined
    dtype_specific = (
        Path(data_dir) / f"{dtype}_original" / f"{dtype}_training_data.parquet"
    )
    combined = Path(data_dir) / "all_training_data_fixed.parquet"

    if dtype_specific.exists():
        training_data = pd.read_parquet(dtype_specific)
        print(f"✓ Loaded {len(training_data):,} benchmark runs from {dtype_specific}")
    elif combined.exists():
        training_data = pd.read_parquet(combined)
        training_data = training_data[
            (training_data["dtype"] == dtype) & (training_data["layout"] == layout)
        ]
        print(f"✓ Loaded {len(training_data):,} benchmark runs from {combined}")
    else:
        print(f"❌ Error: No training data found at {dtype_specific} or {combined}")
        return

    if len(training_data) == 0:
        print(f"❌ Error: No data found for dtype={dtype}, layout={layout}")
        return

    # Get unique shapes with oracle-best
    shape_groups = training_data.groupby(["m", "n", "k"])
    print(f"Unique shapes: {len(shape_groups)}")
    print()

    # Load ML predictor
    print(f"Loading ML predictor from {model_dir}...")
    try:
        predictor = Predictor(model_dir)
        print("✓ Loaded ML predictor")
        print(f" Log targets: {predictor._log_targets}")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return

    print()
    print("=" * 100)
    print(" Computing Oracle-Best Efficiency for Each Shape")
    print("=" * 100)
    print()

    results = []

    for shape_idx, ((m, n, k), group) in enumerate(shape_groups):
        # Find oracle-best (max TFLOPS across all kernels tested)
        oracle_best_row = group.loc[group["measured_tflops"].idxmax()]
        oracle_best_tflops = oracle_best_row["measured_tflops"]
        oracle_best_kernel = oracle_best_row["kernel_name"]

        # Get all kernel configs tested for this shape
        kernel_configs = []
        for _, row in group.iterrows():
            kernel_dict = {
                "tile_m": row["tile_m"],
                "tile_n": row["tile_n"],
                "tile_k": row["tile_k"],
                "warp_m": row["warp_m"],
                "warp_n": row["warp_n"],
                "warp_k": row["warp_k"],
                "warp_tile_m": row["warp_tile_m"],
                "warp_tile_n": row["warp_tile_n"],
                "warp_tile_k": row["warp_tile_k"],
                "pipeline": row["pipeline"],
                "scheduler": row["scheduler"],
                "epilogue": row["epilogue"],
                "pad_m": row["pad_m"],
                "pad_n": row["pad_n"],
                "pad_k": row["pad_k"],
                "persistent": row["persistent"],
                "kernel_name": row["kernel_name"],
            }
            kernel_configs.append(kernel_dict)

        # Use ML model to rank kernels
        problem = {
            "m": m,
            "n": n,
            "k": k,
            "dtype": dtype,
            "layout": layout,
            "split_k": 1,
        }

        try:
            ranked = predictor.rank_kernels(problem, kernel_configs)

            if ranked:
                ml_best_kernel, ml_predicted_tflops = ranked[0]

                # Find actual TFLOPS for the ML-predicted kernel
                ml_kernel_row = group[group["kernel_name"] == ml_best_kernel]
                if len(ml_kernel_row) > 0:
                    ml_actual_tflops = ml_kernel_row["measured_tflops"].values[0]

                    # Calculate efficiency
                    efficiency_pct = 100.0 * (ml_actual_tflops / oracle_best_tflops)

                    # Determine if ML picked oracle-best
                    is_oracle_best = ml_best_kernel == oracle_best_kernel

                    results.append(
                        {
                            "m": m,
                            "n": n,
                            "k": k,
                            "oracle_best_tflops": oracle_best_tflops,
                            "oracle_best_kernel": oracle_best_kernel,
                            "ml_predicted_tflops": ml_predicted_tflops,
                            "ml_selected_kernel": ml_best_kernel,
                            "ml_actual_tflops": ml_actual_tflops,
                            "efficiency_pct": efficiency_pct,
                            "is_oracle_best": is_oracle_best,
                            "num_kernels": len(group),
                        }
                    )

                    if (shape_idx + 1) % 20 == 0:
                        status = "✓" if is_oracle_best else f"{efficiency_pct:.1f}%"
                        print(
                            f" [{shape_idx + 1:3d}/{len(shape_groups)}] "
                            f"M={m:4d} N={n:5d} K={k:5d}: {status}"
                        )
        except Exception as e:
            print(f" Error on shape M={m} N={n} K={k}: {e}")
            continue

    print()
    print("=" * 100)
    print(" Results Summary")
    print("=" * 100)
    print()

    if results:
        df_results = pd.DataFrame(results)
        efficiencies = df_results["efficiency_pct"].values
        oracle_matches = df_results["is_oracle_best"].sum()

        print(f"Total shapes tested: {len(results)}")
        print()
        print("Efficiency Statistics (% of Oracle-Best TFLOPS):")
        print(f" Mean: {np.mean(efficiencies):.2f}%")
        print(f" Median: {np.median(efficiencies):.2f}%")
        print(f" Min: {np.min(efficiencies):.2f}%")
        print(f" Max: {np.max(efficiencies):.2f}%")
        print(f" P10: {np.percentile(efficiencies, 10):.2f}%")
        print(f" P50: {np.percentile(efficiencies, 50):.2f}%")
        print(f" P90: {np.percentile(efficiencies, 90):.2f}%")
        print()
        print(
            f"Oracle-best matches: {oracle_matches}/{len(results)} ({100 * oracle_matches / len(results):.1f}%)"
        )
        print()

        # Classify by M size
        df_results["m_class"] = pd.cut(
            df_results["m"],
            bins=[0, 8, 128, 1024, float("inf")],
            labels=[
                "Tiny (M<8)",
                "Small (8≤M<128)",
                "Medium (128≤M<1024)",
                "Large (M≥1024)",
            ],
        )

        print("Efficiency by M size:")
        for m_class in [
            "Tiny (M<8)",
            "Small (8≤M<128)",
            "Medium (128≤M<1024)",
            "Large (M≥1024)",
        ]:
            subset = df_results[df_results["m_class"] == m_class]
            if len(subset) > 0:
                print(
                    f" {m_class:25s}: {subset['efficiency_pct'].mean():6.2f}% "
                    f"(n={len(subset)}, P10={subset['efficiency_pct'].quantile(0.1):.2f}%)"
                )

        print()

        # Save results
        output_file = f"validation_results_{dtype}_{layout}.csv"
        df_results.to_csv(output_file, index=False)
        print(f"✓ Results saved to {output_file}")

        # Show best and worst shapes
        print()
        print("Top 5 shapes (best efficiency):")
        top5 = df_results.nlargest(5, "efficiency_pct")[
            ["m", "n", "k", "efficiency_pct", "oracle_best_tflops", "is_oracle_best"]
        ]
        for _, row in top5.iterrows():
            match = "✓" if row["is_oracle_best"] else " "
            print(
                f" {match} M={row['m']:5d} N={row['n']:5d} K={row['k']:5d}: "
                f"{row['efficiency_pct']:.2f}% ({row['oracle_best_tflops']:.2f} TFLOPS)"
            )

        print()
        print("Bottom 5 shapes (worst efficiency):")
        bottom5 = df_results.nsmallest(5, "efficiency_pct")[
            ["m", "n", "k", "efficiency_pct", "oracle_best_tflops", "is_oracle_best"]
        ]
        for _, row in bottom5.iterrows():
            match = "✓" if row["is_oracle_best"] else " "
            print(
                f" {match} M={row['m']:5d} N={row['n']:5d} K={row['k']:5d}: "
                f"{row['efficiency_pct']:.2f}% ({row['oracle_best_tflops']:.2f} TFLOPS)"
            )

    else:
        print("No results to display")

    print()
    print("=" * 100)


def main():
    parser = argparse.ArgumentParser(
        description="Validate ML heuristic predictions against oracle-best from training data"
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp8"],
        help="Data type to validate",
    )
    parser.add_argument(
        "--layout",
        default="rcr",
        choices=["rcr", "rrr", "crr", "ccr"],
        help="Matrix layout",
    )
    parser.add_argument(
        "--model_dir",
        default=None,
        help="Path to model directory (auto-detect if not specified)",
    )
    parser.add_argument(
        "--data_dir",
        default=None,
        help="Path to training data directory (auto-detect if not specified)",
    )

    args = parser.parse_args()

    # Auto-detect model directory if not specified
    if args.model_dir is None:
        heuristics_dir = Path(__file__).parent
        model_candidates = [
            heuristics_dir / "models" / f"gemm_universal_{args.dtype}_gfx950",
            heuristics_dir / "models" / f"gemm_universal_{args.dtype}_gfx942",
        ]
        for candidate in model_candidates:
            if candidate.exists():
                args.model_dir = str(candidate)
                break

        if args.model_dir is None:
            print(f"❌ Error: Could not find model directory for {args.dtype}")
            print(f" Searched: {[str(c) for c in model_candidates]}")
            print(" Please specify --model_dir explicitly")
            return 1

    # Auto-detect data directory if not specified
    if args.data_dir is None:
        heuristics_dir = Path(__file__).parent
        args.data_dir = str(heuristics_dir / "data")

    validate_ml_heuristic(args.dtype, args.layout, args.model_dir, args.data_dir)
    return 0


if __name__ == "__main__":
    sys.exit(main())