composable_kernel/dispatcher/heuristics/data_pipeline.py
Yaswanth Raparti c1127a36f5 [rocm-libraries] ROCm/rocm-libraries#5676 (commit 1d18339)
[CK][CK TILE] Autotuning heuristics infra for universal GEMM kernel selection (#5676)

## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile
dispatcher, enabling fast and accurate automatic kernel selection for
Universal GEMM kernels. Instead of requiring an exhaustive search through
4600+ kernel configurations (taking ~46 seconds per problem shape), the
ML heuristic predicts optimal kernels in microseconds while achieving
>98% of oracle-best performance.

## Technical Details

**ML Infrastructure:**

https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics
* Feature Engine
([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)):
55-feature extraction including problem dimensions, kernel
configuration, tile efficiency, and hardware profile
* Training Pipeline
([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)):
LightGBM regression with log-transform, GroupKFold cross-validation, and
warm-start support (a minimal training sketch follows this list)
* Predictor
([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)):
Kernel ranking and TFLOPS prediction for problem shapes
* Evaluation
([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)):
Comprehensive metrics including efficiency, NDCG@k, and shape-family analysis
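
To make the training recipe above concrete, here is a minimal sketch of the described setup (LightGBM regression on a log-transformed TFLOPS target with shape-grouped cross-validation), assuming the canonical schema produced by `data_pipeline.py`. The feature subset, hyperparameters, and file path are illustrative assumptions, not the actual 55-feature configuration in `train.py`:

```python
import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn.model_selection import GroupKFold

# Canonical dataset from data_pipeline.py (path is hypothetical)
df = pd.read_parquet("data/gemm_universal_fp16_gfx950.parquet")
df = df[df["is_valid"]]

# Illustrative feature subset; train.py uses the full 55-feature engine
features = ["m", "n", "k", "split_k", "tile_m", "tile_n", "tile_k"]
X = df[features]
y = np.log1p(df["measured_tflops"])  # log-transformed regression target

# Group rows by problem shape so one shape never spans train/validation folds
groups = df["m"].astype(str) + "x" + df["n"].astype(str) + "x" + df["k"].astype(str)

for train_idx, val_idx in GroupKFold(n_splits=5).split(X, y, groups):
    model = lgb.LGBMRegressor(n_estimators=500, learning_rate=0.05)
    model.fit(X.iloc[train_idx], y.iloc[train_idx])
    # Rank candidate kernels for each shape by predicted TFLOPS
    val_pred_tflops = np.expm1(model.predict(X.iloc[val_idx]))
```

LightGBM's `init_model` argument to `fit()` is one way to realize warm-starting; whether `train.py` uses exactly that mechanism is not shown here.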

**Data Generation Tools:**

*
[generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py):
Build and benchmark kernels across diverse problem shapes
*
[convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py):
Convert benchmark JSON to training-ready parquet format
*
[data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py):
Parse streaming benchmark logs into canonical datasets (see the usage
sketch after this list)
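
As a quick orientation, a typical flow through `data_pipeline.py` uses the functions defined in the file itself; the log and output paths below are hypothetical:

```python
from data_pipeline import parse_streaming_log, save_parquet, build_training_dataset

# Parse a raw streaming benchmark log into the canonical schema
df = parse_streaming_log("bench_fp8_gfx950.log", arch="gfx950")
save_parquet(df, "data/bench_fp8_gfx950.parquet")

# Merge all matching parquet files in a directory into one training set
train_df = build_training_dataset("data/", op_type="gemm_universal", dtype="fp8")
```

The same parsing is also available from the command line via the script's `__main__` entry point (`python data_pipeline.py <log> -o <out.parquet>`).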

**Examples:**
*
[09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp):
C++ example demonstrating ML-based kernel selection
*
[09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py):
Python example with validation

**Pre-trained Models
(projects/composablekernel/dispatcher/heuristics/models/):**
* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean
efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean
efficiency)

## Test Plan

* Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M
(M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**
* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (90% of shapes achieve ≥98% of the oracle best)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**
* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%
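
For reading these numbers: efficiency for a shape $s$ is understood here as the ratio of the selected kernel's measured throughput to the oracle-best kernel's throughput on the same shape, and P10 is the 10th percentile of that ratio across the evaluation shapes (so 90% of shapes score at or above it):

```math
\mathrm{eff}(s) = \frac{\mathrm{TFLOPS}_{\mathrm{selected}}(s)}{\mathrm{TFLOPS}_{\mathrm{oracle}}(s)}
```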

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Data pipeline for CK Tile heuristics.

Parses benchmark logs and structured JSON into a canonical parquet dataset.

Supports:
- Streaming log format (Shape N: headers + inline JSON) from ck_tile profiling runs
- Structured JSON from generate_benchmark_data.py
- Direct parquet passthrough
"""
import json
import re
import subprocess
import hashlib
from pathlib import Path
from typing import Optional

import pandas as pd

CANONICAL_COLUMNS = [
    # Problem identification
    "op_type",
    "dtype",
    "layout",
    "arch",
    "kernel_name",
    # Problem shape
    "m",
    "n",
    "k",
    "split_k",
    # Measured performance
    "measured_tflops",
    "latency_ms",
    "bandwidth_gb_s",
    "is_valid",
    # Kernel configuration (parsed from the kernel name)
    "tile_m",
    "tile_n",
    "tile_k",
    "warp_m",
    "warp_n",
    "warp_k",
    "warp_tile_m",
    "warp_tile_n",
    "warp_tile_k",
    "pipeline",
    "scheduler",
    "epilogue",
    "pad_m",
    "pad_n",
    "pad_k",
    "persistent",
    # Provenance
    "run_id",
]


def parse_kernel_name(name: str) -> dict:
    """Extract kernel config fields from a gemm_universal kernel name.

    Name format:
        gemm_universal_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}
        _{padM}_{padN}_{padK}_{persistent}_{tileM}x{tileN}x{tileK}
        _{warpM}x{warpN}x{warpK}_{warpTileM}x{warpTileN}x{warpTileK}
    """
    result = {}
    try:
        prefix_match = re.match(
            r"gemm_universal_(\w+?)_((?:rcr|rrr|crr|ccr))_(.*)", name
        )
        if not prefix_match:
            return result
        result["dtype"] = prefix_match.group(1)
        result["layout"] = prefix_match.group(2)
        remainder = prefix_match.group(3)
        parts = remainder.split("_")
        if len(parts) < 10:
            return result
        result["pipeline"] = parts[0]
        result["epilogue"] = parts[1]
        result["scheduler"] = parts[2]
        result["pad_m"] = parts[3] == "True"
        result["pad_n"] = parts[4] == "True"
        result["pad_k"] = parts[5] == "True"
        result["persistent"] = parts[6] == "True"
        tile_dims = parts[7].split("x")
        warp_dims = parts[8].split("x")
        warp_tile_dims = parts[9].split("x")
        result["tile_m"] = int(tile_dims[0])
        result["tile_n"] = int(tile_dims[1])
        result["tile_k"] = int(tile_dims[2])
        result["warp_m"] = int(warp_dims[0])
        result["warp_n"] = int(warp_dims[1])
        result["warp_k"] = int(warp_dims[2])
        result["warp_tile_m"] = int(warp_tile_dims[0])
        result["warp_tile_n"] = int(warp_tile_dims[1])
        result["warp_tile_k"] = int(warp_tile_dims[2])
    except (IndexError, ValueError):
        pass
    return result
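

# A worked example (hypothetical name; the "compv3"/"default"/"intrawave" tags
# are assumptions consistent with the format documented above):
#
#   parse_kernel_name(
#       "gemm_universal_fp8_rcr_compv3_default_intrawave_"
#       "False_False_False_True_256x128x64_2x2x1_32x32x16"
#   )
#   -> {"dtype": "fp8", "layout": "rcr", "pipeline": "compv3",
#       "epilogue": "default", "scheduler": "intrawave",
#       "pad_m": False, "pad_n": False, "pad_k": False, "persistent": True,
#       "tile_m": 256, "tile_n": 128, "tile_k": 64,
#       "warp_m": 2, "warp_n": 2, "warp_k": 1,
#       "warp_tile_m": 32, "warp_tile_n": 32, "warp_tile_k": 16}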


def _layout_from_problem(problem: dict) -> str:
    """Derive layout shorthand (rcr/rrr/etc.) from problem JSON fields."""
    la = problem.get("layout_a", "")
    lb = problem.get("layout_b", "")
    lc = problem.get("layout_c", "")

    def _tag(s):
        s = s.lower()
        if "row" in s:
            return "r"
        if "col" in s:
            return "c"
        return "?"

    return _tag(la) + _tag(lb) + _tag(lc)
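

# Example (hypothetical field values):
#   _layout_from_problem({"layout_a": "Row", "layout_b": "Column",
#                         "layout_c": "Row"}) returns "rcr"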


def parse_streaming_log(
    path: str | Path,
    arch: str = "unknown",
    run_id: Optional[str] = None,
    op_type: str = "gemm_universal",
) -> pd.DataFrame:
    """Parse a CK Tile streaming benchmark log into a canonical DataFrame.

    The log alternates between shape headers and JSON result blocks:

        Shape N: M=16 N=1536 K=7168 dtype=fp8 layout=rcr
        {
            "name": "gemm_universal_...",
            "problem": { ... },
            "perf_result": { "latency(ms)": ..., "tflops(TFlops)": ..., "bandwidth(GB/s)": ... }
        }
    """
    path = Path(path)
    if run_id is None:
        run_id = hashlib.md5(path.name.encode()).hexdigest()[:12]
    shape_re = re.compile(
        r"Shape\s+\d+:\s+M=(\d+)\s+N=(\d+)\s+K=(\d+)\s+dtype=(\w+)\s+layout=(\w+)"
    )
    rows = []
    current_m, current_n, current_k = 0, 0, 0
    current_dtype, current_layout = "", ""
    json_buf = []
    brace_depth = 0
    with open(path, "r") as f:
        for line in f:
            stripped = line.strip()
            shape_match = shape_re.search(stripped)
            if shape_match:
                current_m = int(shape_match.group(1))
                current_n = int(shape_match.group(2))
                current_k = int(shape_match.group(3))
                current_dtype = shape_match.group(4)
                current_layout = shape_match.group(5)
                continue
            if brace_depth == 0 and stripped.startswith("{"):
                # Start of a JSON block; it may open and close on the same line
                json_buf = [stripped]
                brace_depth = stripped.count("{") - stripped.count("}")
                if brace_depth == 0:
                    raw = "\n".join(json_buf)
                    try:
                        obj = json.loads(raw)
                    except json.JSONDecodeError:
                        continue
                else:
                    continue
            elif brace_depth > 0:
                # Accumulate lines until the braces balance out
                json_buf.append(stripped)
                brace_depth += stripped.count("{") - stripped.count("}")
                if brace_depth <= 0:
                    brace_depth = 0
                    raw = "\n".join(json_buf)
                    try:
                        obj = json.loads(raw)
                    except json.JSONDecodeError:
                        continue
                else:
                    continue
            else:
                continue
            # If we get here, obj was successfully parsed
            kernel_name = obj.get("name", "")
            problem = obj.get("problem", {})
            perf = obj.get("perf_result", {})
            m = problem.get("m", current_m)
            n = problem.get("n", current_n)
            k = problem.get("k", current_k)
            split_k = problem.get("split_k", 1)
            dtype = problem.get("dtype_a", current_dtype)
            layout = (
                _layout_from_problem(problem)
                if problem.get("layout_a")
                else current_layout
            )
            tflops = perf.get("tflops(TFlops)", 0.0)
            latency = perf.get("latency(ms)", 0.0)
            bandwidth = perf.get("bandwidth(GB/s)", 0.0)
            kp = parse_kernel_name(kernel_name)
            row = {
                "op_type": op_type,
                "dtype": dtype,
                "layout": layout,
                "arch": arch,
                "kernel_name": kernel_name,
                "m": m,
                "n": n,
                "k": k,
                "split_k": split_k,
                "measured_tflops": tflops,
                "latency_ms": latency,
                "bandwidth_gb_s": bandwidth,
                "is_valid": tflops > 0 and latency > 0,
                "run_id": run_id,
            }
            row.update(kp)
            rows.append(row)
    df = pd.DataFrame(rows)
    # Guarantee every canonical column exists even if absent from this log
    for col in CANONICAL_COLUMNS:
        if col not in df.columns:
            df[col] = None
    return df


def get_hardware_profile() -> dict:
    """Capture GPU hardware profile from rocminfo."""
    profile = {}
    mapping = {
        "Name": "gfx_name",
        "Marketing Name": "marketing_name",
        "Compute Unit": "num_cus",
        "SIMDs per CU": "simds_per_cu",
        "Shader Engines": "shader_engines",
        "Shader Arrs. per Eng.": "shader_arrays_per_engine",
        "Max Clock Freq. (MHz)": "max_clock_mhz",
        "Wavefront Size": "wavefront_size",
        "Max Waves Per CU": "max_waves_per_cu",
        "Chip ID": "chip_id",
    }
    try:
        result = subprocess.run(
            ["rocminfo"], capture_output=True, text=True, timeout=30
        )
        output = result.stdout
        gpu_section = False
        for line in output.split("\n"):
            line = line.strip()
            if "Device Type:" in line and "GPU" in line:
                gpu_section = True
                continue
            if gpu_section and "Device Type:" in line and "GPU" not in line:
                break  # stop at the next non-GPU agent
            if not gpu_section:
                continue
            if ":" not in line:
                continue
            key, _, val = line.partition(":")
            key = key.strip()
            val = val.strip()
            if key in mapping:
                raw = val.split("(")[0].strip()
                try:
                    profile[mapping[key]] = int(raw)
                except ValueError:
                    profile[mapping[key]] = raw
        # Cache sizes are reported on separate lines; scan again for them
        for line in output.split("\n"):
            line = line.strip()
            if line.startswith("L1:") and "num_cus" in profile:
                raw = line.split(":")[1].strip().split("(")[0].strip()
                try:
                    profile["l1_cache_kb"] = int(raw)
                except ValueError:
                    pass
            elif line.startswith("L2:"):
                raw = line.split(":")[1].strip().split("(")[0].strip()
                try:
                    profile["l2_cache_kb"] = int(raw)
                except ValueError:
                    pass
            elif line.startswith("L3:"):
                raw = line.split(":")[1].strip().split("(")[0].strip()
                try:
                    profile["l3_cache_kb"] = int(raw)
                except ValueError:
                    pass
    except (subprocess.TimeoutExpired, FileNotFoundError):
        # rocminfo missing or hung; return whatever was collected
        pass
    return profile


def load_parquet(path: str | Path) -> pd.DataFrame:
    """Load a canonical parquet dataset."""
    return pd.read_parquet(path)


def save_parquet(df: pd.DataFrame, path: str | Path):
    """Save a DataFrame in canonical parquet format."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    df.to_parquet(path, index=False, engine="pyarrow")


def build_training_dataset(
    data_dir: str | Path,
    op_type: str = "gemm_universal",
    dtype: str = "fp8",
) -> pd.DataFrame:
    """Load and merge all parquet files matching the given op/dtype from a directory."""
    data_dir = Path(data_dir)
    frames = []
    for f in sorted(data_dir.glob("*.parquet")):
        df = pd.read_parquet(f)
        if "op_type" in df.columns:
            df = df[df["op_type"] == op_type]
        if "dtype" in df.columns:
            df = df[df["dtype"] == dtype]
        if len(df) > 0:
            frames.append(df)
    if not frames:
        raise FileNotFoundError(
            f"No parquet files with op_type={op_type}, dtype={dtype} in {data_dir}"
        )
    return pd.concat(frames, ignore_index=True)
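

# Example (hypothetical data directory):
#   train_df = build_training_dataset("data/", op_type="gemm_universal",
#                                     dtype="fp16")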


if __name__ == "__main__":
    import argparse
    import time

    parser = argparse.ArgumentParser(description="Parse CK Tile benchmark data")
    parser.add_argument("input", help="Input file (log or parquet)")
    parser.add_argument("--output", "-o", required=True, help="Output parquet path")
    parser.add_argument("--arch", default="gfx950", help="GPU architecture")
    parser.add_argument("--op_type", default="gemm_universal", help="Operation type")
    parser.add_argument(
        "--capture_hw",
        action="store_true",
        help="Capture hardware profile from rocminfo",
    )
    args = parser.parse_args()

    input_path = Path(args.input)
    print(f"Parsing {input_path}...")
    t0 = time.time()
    if input_path.suffix == ".parquet":
        df = load_parquet(input_path)
    else:
        df = parse_streaming_log(input_path, arch=args.arch, op_type=args.op_type)
    elapsed = time.time() - t0

    print(f"Parsed {len(df)} rows in {elapsed:.1f}s")
    print(f" Unique shapes: {df.groupby(['m', 'n', 'k']).ngroups}")
    print(f" Unique kernels: {df['kernel_name'].nunique()}")
    print(f" Valid rows: {df['is_valid'].sum()} / {len(df)}")
    if df["measured_tflops"].max() > 0:
        print(
            f" TFLOPS range: {df['measured_tflops'].min():.2f} - {df['measured_tflops'].max():.2f}"
        )

    if args.capture_hw:
        hw = get_hardware_profile()
        print(f" Hardware profile: {hw}")
        for k, v in hw.items():
            df[f"hw_{k}"] = v

    save_parquet(df, args.output)
    print(f"Saved to {args.output}")