mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
[CK][CK TILE]Autotuning heuristics infra for universal GEMM kernel selection (#5676)
## Motivation This PR adds ML-based kernel selection heuristics to the CK Tile dispatcher, enabling fast and accurate automatic kernel selection for Universal Gemm kernels. Instead of requiring exhaustive search through 4600+ kernel configurations (taking ~46 seconds per problem shape), the ML heuristic predicts optimal kernels in microseconds while achieving >98% of oracle-best performance. ## Technical Details **ML infrastructure** https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics * Feature Engine ([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)): 55-feature extraction including problem dimensions, kernel configuration, tile efficiency, and hardware profile * Training Pipeline ([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)): LightGBM regression with log-transform, GroupKFold cross-validation, warm-start support * Predictor ([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)): Kernel ranking and TFLOPS prediction for problem shapes * Evaluation ([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)): Comprehensive metrics including efficiency, NDCG@k, shape family analysis **Data Generation Tools:** * [generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py): Build and benchmark kernels across diverse problem shapes * [convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py): Convert benchmark JSON to training-ready parquet format * [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py): Parse streaming benchmark logs into canonical datasets **Examples** * [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp): C++ example demonstrating ML-based kernel selection * [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py): Python example with validation **Pre-trained Models (projects/composablekernel/dispatcher/heuristics/models/):** * gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean efficiency) * gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean efficiency) ## Test Plan * Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8 * All shape families tested: tiny M (M<8), small M, medium M, large M (M≥1024) * All pipeline types: compv3, compv4, mem ## Test Result **fp16 Model (gfx950, RCR layout)** * Mean Efficiency: 99.36% * P10 Efficiency: 98.05% (90th percentile of shapes achieve ≥98% of oracle best) * Min Efficiency: 95.45% **fp8 Model (gfx950, RCR layout)** * Mean Efficiency: 98.28% (original), 97.51% (wide coverage) * P10 Efficiency: 94.64% (original), 93.89% (wide coverage) * Min Efficiency: 84.5% ## Submission Checklist - [x ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Vidyasagar Ananthan <vidyasagar.ananthan@amd.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
305
dispatcher/examples/gemm/python/09_ml_heuristic.py
Normal file
305
dispatcher/examples/gemm/python/09_ml_heuristic.py
Normal file
@@ -0,0 +1,305 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 09: ML-Based Kernel Selection
|
||||
|
||||
Uses a trained LightGBM model to select the optimal kernel for each problem
|
||||
size. The model predicts TFLOPS for every candidate in the kernel pool and
|
||||
picks the highest-scoring one, which is then JIT-compiled and run.
|
||||
|
||||
This replaces the hand-crafted rules in 08_heuristics.py with a data-driven
|
||||
approach achieving 97-98% of oracle-best TFLOPS efficiency.
|
||||
|
||||
Complexity: *****
|
||||
|
||||
Prerequisites:
|
||||
- Trained model in dispatcher/heuristics/models/gemm_universal_fp8_gfx950/
|
||||
- lightgbm, pandas, numpy, pyarrow installed
|
||||
|
||||
Usage:
|
||||
python3 09_ml_heuristic.py
|
||||
python3 09_ml_heuristic.py --dtype fp16 --arch gfx942
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "heuristics"))
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
)
|
||||
|
||||
from predict import Predictor
|
||||
|
||||
|
||||
@dataclass
|
||||
class KernelSpec:
|
||||
"""Kernel specification -- same structure as 08_heuristics.py"""
|
||||
|
||||
name: str
|
||||
tile_m: int
|
||||
tile_n: int
|
||||
tile_k: int
|
||||
pipeline: str = "compv3"
|
||||
scheduler: str = "intrawave"
|
||||
wave_m: int = 2
|
||||
wave_n: int = 2
|
||||
wave_k: int = 1
|
||||
warp_m: int = 32
|
||||
warp_n: int = 32
|
||||
warp_k: int = 16
|
||||
|
||||
|
||||
# Kernel pool: representative configs spanning small to large tiles,
|
||||
# compv3/compv4/mem pipelines, and intrawave/interwave schedulers.
|
||||
KERNEL_POOL = [
|
||||
# Small tiles
|
||||
KernelSpec("s_64x64_k32_v3", 64, 64, 32, "compv3", warp_m=16, warp_n=16),
|
||||
KernelSpec("s_64x64_k64_v3", 64, 64, 64, "compv3", warp_m=16, warp_n=16),
|
||||
KernelSpec("s_64x64_k128_v3", 64, 64, 128, "compv3", warp_m=16, warp_n=16),
|
||||
KernelSpec("s_64x64_k32_v4", 64, 64, 32, "compv4", warp_m=16, warp_n=16),
|
||||
KernelSpec("s_64x64_k64_mem", 64, 64, 64, "mem", warp_m=16, warp_n=16),
|
||||
KernelSpec("s_64x64_k128_mem", 64, 64, 128, "mem", warp_m=16, warp_n=16),
|
||||
# Medium tiles
|
||||
KernelSpec("m_128x128_k32_v3", 128, 128, 32, "compv3"),
|
||||
KernelSpec("m_128x128_k64_v3", 128, 128, 64, "compv3"),
|
||||
KernelSpec("m_128x128_k128_v3", 128, 128, 128, "compv3"),
|
||||
KernelSpec("m_128x128_k32_v4", 128, 128, 32, "compv4"),
|
||||
KernelSpec("m_128x128_k64_v4", 128, 128, 64, "compv4"),
|
||||
KernelSpec("m_128x128_k64_mem", 128, 128, 64, "mem"),
|
||||
KernelSpec("m_128x128_k128_mem", 128, 128, 128, "mem"),
|
||||
# Rectangular medium
|
||||
KernelSpec("r_64x128_k32", 64, 128, 32, "compv3", warp_m=16),
|
||||
KernelSpec("r_128x64_k32", 128, 64, 32, "compv3", warp_n=16),
|
||||
KernelSpec("r_64x128_k64", 64, 128, 64, "compv3", warp_m=16),
|
||||
KernelSpec("r_128x64_k64", 128, 64, 64, "compv3", warp_n=16),
|
||||
# Large tiles
|
||||
KernelSpec("l_256x128_k32", 256, 128, 32, "compv3"),
|
||||
KernelSpec("l_128x256_k32", 128, 256, 32, "compv3"),
|
||||
KernelSpec("l_256x256_k32", 256, 256, 32, "compv3"),
|
||||
KernelSpec("l_256x256_k64", 256, 256, 64, "compv3"),
|
||||
# Interwave variants
|
||||
KernelSpec("m_128x128_k64_iw", 128, 128, 64, "compv3", "interwave"),
|
||||
KernelSpec("m_128x128_k64_mem_iw", 128, 128, 64, "mem", "interwave"),
|
||||
]
|
||||
|
||||
|
||||
def spec_to_feature_dict(spec: KernelSpec, dtype: str, layout: str) -> dict:
|
||||
"""Convert a KernelSpec to the dict format the feature engine expects.
|
||||
|
||||
Note: pad_m/n/k default to True to match KernelConfig defaults and actual
|
||||
compiled kernels. This ensures the ML model receives the correct padding
|
||||
flags that will be used during JIT compilation.
|
||||
"""
|
||||
return {
|
||||
"kernel_name": spec.name,
|
||||
"tile_m": spec.tile_m,
|
||||
"tile_n": spec.tile_n,
|
||||
"tile_k": spec.tile_k,
|
||||
"warp_m": spec.wave_m,
|
||||
"warp_n": spec.wave_n,
|
||||
"warp_k": spec.wave_k,
|
||||
"warp_tile_m": spec.warp_m,
|
||||
"warp_tile_n": spec.warp_n,
|
||||
"warp_tile_k": spec.warp_k,
|
||||
"pipeline": spec.pipeline,
|
||||
"scheduler": spec.scheduler,
|
||||
"epilogue": "cshuffle",
|
||||
"pad_m": True, # Match KernelConfig default
|
||||
"pad_n": True, # Match KernelConfig default
|
||||
"pad_k": True, # Match KernelConfig default
|
||||
"persistent": False,
|
||||
"dtype": dtype,
|
||||
"layout": layout,
|
||||
}
|
||||
|
||||
|
||||
def spec_to_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
|
||||
"""Convert a KernelSpec to the dispatcher's KernelConfig for JIT compilation."""
|
||||
return KernelConfig(
|
||||
dtype_a=dtype,
|
||||
dtype_b=dtype,
|
||||
dtype_c=dtype,
|
||||
dtype_acc="fp32",
|
||||
layout_a="row",
|
||||
layout_b="col",
|
||||
layout_c="row",
|
||||
tile_m=spec.tile_m,
|
||||
tile_n=spec.tile_n,
|
||||
tile_k=spec.tile_k,
|
||||
wave_m=spec.wave_m,
|
||||
wave_n=spec.wave_n,
|
||||
wave_k=spec.wave_k,
|
||||
warp_m=spec.warp_m,
|
||||
warp_n=spec.warp_n,
|
||||
warp_k=spec.warp_k,
|
||||
pipeline=spec.pipeline,
|
||||
scheduler=spec.scheduler,
|
||||
epilogue="cshuffle",
|
||||
gfx_arch=arch,
|
||||
)
|
||||
|
||||
|
||||
def ml_select_kernel(
|
||||
predictor: Predictor,
|
||||
pool: List[KernelSpec],
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
dtype: str,
|
||||
layout: str,
|
||||
) -> tuple:
|
||||
"""Score all kernels in the pool and return (best_spec, predicted_tflops)."""
|
||||
problem = {"m": M, "n": N, "k": K, "dtype": dtype, "layout": layout, "split_k": 1}
|
||||
kernel_dicts = [spec_to_feature_dict(s, dtype, layout) for s in pool]
|
||||
|
||||
ranked = predictor.rank_kernels(problem, kernel_dicts)
|
||||
if not ranked:
|
||||
return pool[0], 0.0
|
||||
|
||||
best_name, best_tflops = ranked[0]
|
||||
best_spec = next((s for s in pool if s.name == best_name), pool[0])
|
||||
return best_spec, best_tflops
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ML-based kernel selection for GEMM")
|
||||
parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16", "fp8"])
|
||||
parser.add_argument("--arch", default="gfx942")
|
||||
parser.add_argument(
|
||||
"--model_dir",
|
||||
default=str(
|
||||
Path(__file__).parent.parent.parent.parent
|
||||
/ "heuristics"
|
||||
/ "models"
|
||||
/ "gemm_universal_fp8_gfx950"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_run", action="store_true", help="Only predict, don't run GEMMs"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("=" * 75)
|
||||
print(" Example 09: ML-Based Kernel Selection")
|
||||
print("=" * 75)
|
||||
print(f"\n Model: {args.model_dir}")
|
||||
print(f" Dtype: {args.dtype}")
|
||||
print(f" Arch: {args.arch}")
|
||||
print(f" Pool: {len(KERNEL_POOL)} kernels")
|
||||
|
||||
predictor = Predictor(args.model_dir)
|
||||
print(" Model loaded successfully")
|
||||
|
||||
np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float16
|
||||
|
||||
test_sizes = [
|
||||
(128, 128, 64),
|
||||
(256, 256, 128),
|
||||
(512, 512, 256),
|
||||
(1024, 1024, 512),
|
||||
(2048, 2048, 1024),
|
||||
]
|
||||
|
||||
header = f"{'Shape':<20} {'Selected Kernel':<25} {'Pred TFLOPS':>12}"
|
||||
if not args.no_run:
|
||||
header += f" {'Time (ms)':>10} {'TFLOPS':>10} {'Status':<8}"
|
||||
print(f"\n {header}")
|
||||
print(" " + "-" * len(header))
|
||||
|
||||
results = []
|
||||
|
||||
for M, N, K in test_sizes:
|
||||
t0 = time.time()
|
||||
best_spec, pred_tflops = ml_select_kernel(
|
||||
predictor, KERNEL_POOL, M, N, K, args.dtype, "rcr"
|
||||
)
|
||||
_ = (time.time() - t0) * 1000 # ML selection time (unused)
|
||||
|
||||
size_str = f"{M}x{N}x{K}"
|
||||
line = f" {size_str:<20} {best_spec.name:<25} {pred_tflops:>12.2f}"
|
||||
|
||||
if args.no_run:
|
||||
print(line)
|
||||
results.append((size_str, best_spec.name, True, 0, pred_tflops))
|
||||
continue
|
||||
|
||||
config = spec_to_kernel_config(best_spec, args.dtype, args.arch)
|
||||
|
||||
setup = setup_gemm_dispatcher(
|
||||
config=config,
|
||||
registry_name=f"ml_{best_spec.name}",
|
||||
verbose=False,
|
||||
auto_rebuild=True,
|
||||
)
|
||||
|
||||
if not setup.success:
|
||||
line += f" {'N/A':>10} {'N/A':>10} {'BUILD':>8}"
|
||||
print(line)
|
||||
results.append((size_str, best_spec.name, False, 0, 0))
|
||||
cleanup_gemm()
|
||||
continue
|
||||
|
||||
dispatcher = setup.dispatcher
|
||||
if not dispatcher.is_supported(M, N, K):
|
||||
line += f" {'N/A':>10} {'N/A':>10} {'UNSUP':>8}"
|
||||
print(line)
|
||||
results.append((size_str, best_spec.name, False, 0, 0))
|
||||
cleanup_gemm()
|
||||
continue
|
||||
|
||||
np.random.seed(42)
|
||||
A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
|
||||
B = (np.random.randn(K, N) * 0.1).astype(np_dtype)
|
||||
|
||||
result = dispatcher.run(A, B, M, N, K)
|
||||
|
||||
if result.success:
|
||||
C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(
|
||||
np_dtype
|
||||
)
|
||||
max_err = np.max(np.abs(result.output - C_ref))
|
||||
passed = max_err < 1e-2
|
||||
status = "PASS" if passed else "FAIL"
|
||||
line += f" {result.time_ms:>10.4f} {result.tflops:>10.2f} {status:<8}"
|
||||
results.append(
|
||||
(size_str, best_spec.name, passed, result.time_ms, result.tflops)
|
||||
)
|
||||
else:
|
||||
line += f" {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
|
||||
results.append((size_str, best_spec.name, False, 0, 0))
|
||||
|
||||
print(line)
|
||||
cleanup_gemm()
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 75)
|
||||
print(" SUMMARY")
|
||||
print("=" * 75)
|
||||
passed = sum(1 for r in results if r[2])
|
||||
print(f"\n Results: {passed}/{len(results)} tests passed")
|
||||
valid = [r for r in results if r[2] and r[4] > 0]
|
||||
if valid:
|
||||
avg = sum(r[4] for r in valid) / len(valid)
|
||||
print(f" Average TFLOPS: {avg:.2f}")
|
||||
if passed == len(results):
|
||||
print("\n *** ALL TESTS PASSED ***")
|
||||
print("=" * 75)
|
||||
return 0 if passed == len(results) else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user