mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
[CK][CK TILE] Autotuning heuristics infra for universal GEMM kernel selection (#5676)

## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile dispatcher, enabling fast and accurate automatic kernel selection for Universal GEMM kernels. Instead of requiring an exhaustive search through 4600+ kernel configurations (roughly 46 seconds per problem shape), the ML heuristic predicts the optimal kernel in microseconds while achieving >98% of oracle-best performance.

## Technical Details

**ML infrastructure** (https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics)

* Feature Engine ([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)): 55-feature extraction covering problem dimensions, kernel configuration, tile efficiency, and hardware profile
* Training Pipeline ([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)): LightGBM regression with log-transform, GroupKFold cross-validation, and warm-start support (see the sketch after this list)
* Predictor ([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)): Kernel ranking and TFLOPS prediction for problem shapes
* Evaluation ([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)): Comprehensive metrics including efficiency, NDCG@k, and shape-family analysis
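For readers new to the training setup, here is a minimal sketch of the recipe named in the train.py bullet above: LightGBM regression on a log-transformed TFLOPS target, cross-validated with GroupKFold so every row of a given problem shape stays in one fold. The file name, column names, and hyperparameters below are illustrative only, not the actual train.py implementation:

```python
# Illustrative sketch of the training recipe (not the real train.py):
# LightGBM regression on log(TFLOPS), GroupKFold grouped by problem shape.
import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn.model_selection import GroupKFold

df = pd.read_parquet("benchmarks.parquet")   # hypothetical: one row per (shape, kernel)
feature_cols = [c for c in df.columns if c not in ("tflops", "shape_id")]

X = df[feature_cols].copy()
for c in X.select_dtypes("object").columns:  # LightGBM accepts pandas categoricals
    X[c] = X[c].astype("category")
y = np.log1p(df["tflops"])                   # log-transform the regression target
groups = df["shape_id"]                      # keep each shape's rows in one fold

for train_idx, val_idx in GroupKFold(n_splits=5).split(X, y, groups):
    model = lgb.LGBMRegressor(n_estimators=2000, learning_rate=0.05)
    model.fit(X.iloc[train_idx], y.iloc[train_idx],
              eval_set=[(X.iloc[val_idx], y.iloc[val_idx])])
    val_pred_tflops = np.expm1(model.predict(X.iloc[val_idx]))  # undo the transform
```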
**Data Generation Tools:**

* [generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py): Build and benchmark kernels across diverse problem shapes
* [convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py): Convert benchmark JSON to training-ready parquet format
* [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py): Parse streaming benchmark logs into canonical datasets

**Examples:**

* [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp): C++ example demonstrating ML-based kernel selection
* [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py): Python example with validation

**Pre-trained Models** (projects/composablekernel/dispatcher/heuristics/models/):

* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean efficiency)

## Test Plan

* Evaluated on 25 diverse shapes for fp16 and 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M (M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

Efficiency is measured relative to the oracle-best kernel for each shape; a short sketch of these summary statistics follows the checklist.

**fp16 Model (gfx950, RCR layout)**

* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (90% of shapes achieve at least this efficiency)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**

* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
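As a reference for how the efficiency figures above can be summarized: per-shape efficiency is the ML-selected kernel's measured TFLOPS divided by the oracle-best kernel's TFLOPS for that shape, and mean/P10/min are taken over shapes. A minimal sketch, with a hypothetical results file and column names:

```python
# Illustrative only: summarize per-shape efficiency vs. the oracle-best kernel.
# "eval_results.csv", "selected_tflops", and "best_tflops" are hypothetical names.
import numpy as np
import pandas as pd

per_shape = pd.read_csv("eval_results.csv")              # one row per problem shape
eff = per_shape["selected_tflops"] / per_shape["best_tflops"]

print(f"Mean efficiency: {eff.mean():.2%}")
print(f"P10 efficiency:  {np.percentile(eff, 10):.2%}")  # 90% of shapes are at/above this
print(f"Min efficiency:  {eff.min():.2%}")
```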
09_ml_heuristic.py
306 lines
9.7 KiB
Python
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Example 09: ML-Based Kernel Selection

Uses a trained LightGBM model to select the optimal kernel for each problem
size. The model predicts TFLOPS for every candidate in the kernel pool and
picks the highest-scoring one, which is then JIT-compiled and run.

This replaces the hand-crafted rules in 08_heuristics.py with a data-driven
approach achieving 97-98% of oracle-best TFLOPS efficiency.

Complexity: *****

Prerequisites:
- Trained model in dispatcher/heuristics/models/gemm_universal_fp8_gfx950/
- lightgbm, pandas, numpy, pyarrow installed

Usage:
    python3 09_ml_heuristic.py
    python3 09_ml_heuristic.py --dtype fp16 --arch gfx942
"""

import sys
import argparse
import time
from pathlib import Path
from dataclasses import dataclass
from typing import List

# Make the dispatcher's python/ (ctypes utilities) and heuristics/ (ML
# predictor) packages importable when running this script directly.
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "heuristics"))

import numpy as np

from ctypes_utils import (
    KernelConfig,
    setup_gemm_dispatcher,
    cleanup_gemm,
)

from predict import Predictor


@dataclass
class KernelSpec:
    """Kernel specification -- same structure as 08_heuristics.py"""

    name: str
    tile_m: int
    tile_n: int
    tile_k: int
    pipeline: str = "compv3"
    scheduler: str = "intrawave"
    wave_m: int = 2
    wave_n: int = 2
    wave_k: int = 1
    warp_m: int = 32
    warp_n: int = 32
    warp_k: int = 16


# Kernel pool: representative configs spanning small to large tiles,
# compv3/compv4/mem pipelines, and intrawave/interwave schedulers.
KERNEL_POOL = [
    # Small tiles
    KernelSpec("s_64x64_k32_v3", 64, 64, 32, "compv3", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k64_v3", 64, 64, 64, "compv3", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k128_v3", 64, 64, 128, "compv3", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k32_v4", 64, 64, 32, "compv4", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k64_mem", 64, 64, 64, "mem", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k128_mem", 64, 64, 128, "mem", warp_m=16, warp_n=16),
    # Medium tiles
    KernelSpec("m_128x128_k32_v3", 128, 128, 32, "compv3"),
    KernelSpec("m_128x128_k64_v3", 128, 128, 64, "compv3"),
    KernelSpec("m_128x128_k128_v3", 128, 128, 128, "compv3"),
    KernelSpec("m_128x128_k32_v4", 128, 128, 32, "compv4"),
    KernelSpec("m_128x128_k64_v4", 128, 128, 64, "compv4"),
    KernelSpec("m_128x128_k64_mem", 128, 128, 64, "mem"),
    KernelSpec("m_128x128_k128_mem", 128, 128, 128, "mem"),
    # Rectangular medium
    KernelSpec("r_64x128_k32", 64, 128, 32, "compv3", warp_m=16),
    KernelSpec("r_128x64_k32", 128, 64, 32, "compv3", warp_n=16),
    KernelSpec("r_64x128_k64", 64, 128, 64, "compv3", warp_m=16),
    KernelSpec("r_128x64_k64", 128, 64, 64, "compv3", warp_n=16),
    # Large tiles
    KernelSpec("l_256x128_k32", 256, 128, 32, "compv3"),
    KernelSpec("l_128x256_k32", 128, 256, 32, "compv3"),
    KernelSpec("l_256x256_k32", 256, 256, 32, "compv3"),
    KernelSpec("l_256x256_k64", 256, 256, 64, "compv3"),
    # Interwave variants
    KernelSpec("m_128x128_k64_iw", 128, 128, 64, "compv3", "interwave"),
    KernelSpec("m_128x128_k64_mem_iw", 128, 128, 64, "mem", "interwave"),
]


def spec_to_feature_dict(spec: KernelSpec, dtype: str, layout: str) -> dict:
    """Convert a KernelSpec to the dict format the feature engine expects.

    Note: pad_m/n/k default to True to match KernelConfig defaults and actual
    compiled kernels. This ensures the ML model receives the correct padding
    flags that will be used during JIT compilation.
    """
    return {
        "kernel_name": spec.name,
        "tile_m": spec.tile_m,
        "tile_n": spec.tile_n,
        "tile_k": spec.tile_k,
        "warp_m": spec.wave_m,
        "warp_n": spec.wave_n,
        "warp_k": spec.wave_k,
        "warp_tile_m": spec.warp_m,
        "warp_tile_n": spec.warp_n,
        "warp_tile_k": spec.warp_k,
        "pipeline": spec.pipeline,
        "scheduler": spec.scheduler,
        "epilogue": "cshuffle",
        "pad_m": True,  # Match KernelConfig default
        "pad_n": True,  # Match KernelConfig default
        "pad_k": True,  # Match KernelConfig default
        "persistent": False,
        "dtype": dtype,
        "layout": layout,
    }


def spec_to_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
    """Convert a KernelSpec to the dispatcher's KernelConfig for JIT compilation."""
    return KernelConfig(
        dtype_a=dtype,
        dtype_b=dtype,
        dtype_c=dtype,
        dtype_acc="fp32",
        layout_a="row",
        layout_b="col",
        layout_c="row",
        tile_m=spec.tile_m,
        tile_n=spec.tile_n,
        tile_k=spec.tile_k,
        wave_m=spec.wave_m,
        wave_n=spec.wave_n,
        wave_k=spec.wave_k,
        warp_m=spec.warp_m,
        warp_n=spec.warp_n,
        warp_k=spec.warp_k,
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        gfx_arch=arch,
    )


def ml_select_kernel(
    predictor: Predictor,
    pool: List[KernelSpec],
    M: int,
    N: int,
    K: int,
    dtype: str,
    layout: str,
) -> tuple:
    """Score all kernels in the pool and return (best_spec, predicted_tflops)."""
    problem = {"m": M, "n": N, "k": K, "dtype": dtype, "layout": layout, "split_k": 1}
    kernel_dicts = [spec_to_feature_dict(s, dtype, layout) for s in pool]

    ranked = predictor.rank_kernels(problem, kernel_dicts)
    if not ranked:
        return pool[0], 0.0

    best_name, best_tflops = ranked[0]
    best_spec = next((s for s in pool if s.name == best_name), pool[0])
    return best_spec, best_tflops


def main():
    parser = argparse.ArgumentParser(description="ML-based kernel selection for GEMM")
    parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16", "fp8"])
    parser.add_argument("--arch", default="gfx942")
    parser.add_argument(
        "--model_dir",
        default=str(
            Path(__file__).parent.parent.parent.parent
            / "heuristics"
            / "models"
            / "gemm_universal_fp8_gfx950"
        ),
    )
    parser.add_argument(
        "--no_run", action="store_true", help="Only predict, don't run GEMMs"
    )
    args = parser.parse_args()

    print("=" * 75)
    print(" Example 09: ML-Based Kernel Selection")
    print("=" * 75)
    print(f"\n Model: {args.model_dir}")
    print(f" Dtype: {args.dtype}")
    print(f" Arch: {args.arch}")
    print(f" Pool: {len(KERNEL_POOL)} kernels")

    predictor = Predictor(args.model_dir)
    print(" Model loaded successfully")

    # Host-side buffers use fp16 for every supported --dtype in this example.
    np_dtype = np.float16

    test_sizes = [
        (128, 128, 64),
        (256, 256, 128),
        (512, 512, 256),
        (1024, 1024, 512),
        (2048, 2048, 1024),
    ]

    header = f"{'Shape':<20} {'Selected Kernel':<25} {'Pred TFLOPS':>12}"
    if not args.no_run:
        header += f" {'Time (ms)':>10} {'TFLOPS':>10} {'Status':<8}"
    print(f"\n {header}")
    print(" " + "-" * len(header))

    results = []

    for M, N, K in test_sizes:
        t0 = time.time()
        best_spec, pred_tflops = ml_select_kernel(
            predictor, KERNEL_POOL, M, N, K, args.dtype, "rcr"
        )
        _ = (time.time() - t0) * 1000  # ML selection time (unused)

        size_str = f"{M}x{N}x{K}"
        line = f" {size_str:<20} {best_spec.name:<25} {pred_tflops:>12.2f}"

        if args.no_run:
            print(line)
            results.append((size_str, best_spec.name, True, 0, pred_tflops))
            continue

        config = spec_to_kernel_config(best_spec, args.dtype, args.arch)

        setup = setup_gemm_dispatcher(
            config=config,
            registry_name=f"ml_{best_spec.name}",
            verbose=False,
            auto_rebuild=True,
        )

        if not setup.success:
            line += f" {'N/A':>10} {'N/A':>10} {'BUILD':>8}"
            print(line)
            results.append((size_str, best_spec.name, False, 0, 0))
            cleanup_gemm()
            continue

        dispatcher = setup.dispatcher
        if not dispatcher.is_supported(M, N, K):
            line += f" {'N/A':>10} {'N/A':>10} {'UNSUP':>8}"
            print(line)
            results.append((size_str, best_spec.name, False, 0, 0))
            cleanup_gemm()
            continue

        np.random.seed(42)
        A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
        B = (np.random.randn(K, N) * 0.1).astype(np_dtype)

        result = dispatcher.run(A, B, M, N, K)

        if result.success:
            C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(
                np_dtype
            )
            max_err = np.max(np.abs(result.output - C_ref))
            passed = max_err < 1e-2
            status = "PASS" if passed else "FAIL"
            line += f" {result.time_ms:>10.4f} {result.tflops:>10.2f} {status:<8}"
            results.append(
                (size_str, best_spec.name, passed, result.time_ms, result.tflops)
            )
        else:
            line += f" {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
            results.append((size_str, best_spec.name, False, 0, 0))

        print(line)
        cleanup_gemm()

    # Summary
    print("\n" + "=" * 75)
    print(" SUMMARY")
    print("=" * 75)
    passed = sum(1 for r in results if r[2])
    print(f"\n Results: {passed}/{len(results)} tests passed")
    valid = [r for r in results if r[2] and r[4] > 0]
    if valid:
        avg = sum(r[4] for r in valid) / len(valid)
        print(f" Average TFLOPS: {avg:.2f}")
    if passed == len(results):
        print("\n *** ALL TESTS PASSED ***")
    print("=" * 75)
    return 0 if passed == len(results) else 1


if __name__ == "__main__":
    sys.exit(main())