mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
[CK][CK TILE]Autotuning heuristics infra for universal GEMM kernel selection (#5676) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation This PR adds ML-based kernel selection heuristics to the CK Tile dispatcher, enabling fast and accurate automatic kernel selection for Universal Gemm kernels. Instead of requiring exhaustive search through 4600+ kernel configurations (taking ~46 seconds per problem shape), the ML heuristic predicts optimal kernels in microseconds while achieving >98% of oracle-best performance. ## Technical Details **ML infrastructure** https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics * Feature Engine ([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)): 55-feature extraction including problem dimensions, kernel configuration, tile efficiency, and hardware profile * Training Pipeline ([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)): LightGBM regression with log-transform, GroupKFold cross-validation, warm-start support * Predictor ([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)): Kernel ranking and TFLOPS prediction for problem shapes * Evaluation ([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)): Comprehensive metrics including efficiency, NDCG@k, shape family analysis **Data Generation Tools:** * 
[generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py): Build and benchmark kernels across diverse problem shapes * [convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py): Convert benchmark JSON to training-ready parquet format * [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py): Parse streaming benchmark logs into canonical datasets **Examples** * [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp): C++ example demonstrating ML-based kernel selection * [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py): Python example with validation **Pre-trained Models (projects/composablekernel/dispatcher/heuristics/models/):** * gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean efficiency) * gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean efficiency) ## Test Plan * Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8 * All shape families tested: tiny M (M<8), small M, medium M, large M (M≥1024) * All pipeline types: compv3, compv4, mem ## Test Result **fp16 Model (gfx950, RCR layout)** * Mean Efficiency: 99.36% * P10 Efficiency: 98.05% (90th percentile of shapes achieve ≥98% of oracle best) * Min Efficiency: 95.45% **fp8 Model (gfx950, RCR layout)** * Mean Efficiency: 98.28% (original), 97.51% (wide coverage) * P10 Efficiency: 94.64% (original), 93.89% (wide coverage) * Min Efficiency: 84.5% ## 
Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
369 lines
11 KiB
Python
369 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Tests for data_pipeline.py.
|
|
|
|
Covers: kernel name parsing, layout derivation, streaming log parsing,
|
|
schema validation, and corner cases (empty logs, malformed JSON, single-shape).
|
|
"""
|
|
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
import pandas as pd
|
|
import pytest
|
|
|
|
import sys
|
|
|
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
|
|
from data_pipeline import (
|
|
parse_kernel_name,
|
|
_layout_from_problem,
|
|
parse_streaming_log,
|
|
save_parquet,
|
|
load_parquet,
|
|
CANONICAL_COLUMNS,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# parse_kernel_name
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestParseKernelName:
    """Unit tests for parse_kernel_name()."""

    def test_standard_name(self):
        """A fully-formed kernel name yields every expected field."""
        name = (
            "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_"
            "False_False_False_False_128x128x128_1x4x1_16x16x128"
        )
        result = parse_kernel_name(name)
        expected_values = {
            "dtype": "fp8",
            "layout": "rcr",
            "pipeline": "compv3",
            "epilogue": "cshuffle",
            "scheduler": "intrawave",
            "tile_m": 128,
            "tile_n": 128,
            "tile_k": 128,
            "warp_m": 1,
            "warp_n": 4,
            "warp_k": 1,
            "warp_tile_m": 16,
            "warp_tile_n": 16,
            "warp_tile_k": 128,
        }
        for key, value in expected_values.items():
            assert result[key] == value
        # Flag fields must be genuine booleans, not merely falsy.
        for flag in ("pad_m", "pad_n", "pad_k", "persistent"):
            assert result[flag] is False

    def test_with_padding_and_persistent(self):
        """All four pad/persistent flags set to True parse as booleans."""
        name = (
            "gemm_universal_fp16_rrr_compv4_default_interwave_"
            "True_True_True_True_256x256x64_2x2x1_32x32x16"
        )
        result = parse_kernel_name(name)
        assert result["dtype"] == "fp16"
        assert result["layout"] == "rrr"
        for flag in ("pad_m", "pad_n", "pad_k", "persistent"):
            assert result[flag] is True
        assert result["tile_m"] == 256

    def test_empty_name(self):
        """An empty string yields an empty dict."""
        assert parse_kernel_name("") == {}

    def test_malformed_name(self):
        """A string without the expected structure yields an empty dict."""
        assert parse_kernel_name("not_a_kernel_name") == {}

    def test_partial_name(self):
        """A truncated name keeps the leading fields but no tile info."""
        result = parse_kernel_name("gemm_universal_fp8_rcr_compv3")
        assert result.get("dtype") == "fp8"
        assert result.get("layout") == "rcr"
        assert "tile_m" not in result  # not enough parts

    def test_all_layouts(self):
        """Each supported layout token round-trips through the parser."""
        template = (
            "gemm_universal_fp8_{}_compv3_cshuffle_intrawave_"
            "False_False_False_False_128x128x128_1x4x1_16x16x128"
        )
        for layout in ("rcr", "rrr", "crr", "ccr"):
            assert parse_kernel_name(template.format(layout))["layout"] == layout
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _layout_from_problem
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestLayoutFromProblem:
    """Unit tests for _layout_from_problem()."""

    def test_rcr(self):
        """RowMajor/ColumnMajor/RowMajor maps to the 'rcr' code."""
        problem = {
            "layout_a": "RowMajor",
            "layout_b": "ColumnMajor",
            "layout_c": "RowMajor",
        }
        assert _layout_from_problem(problem) == "rcr"

    def test_rrr(self):
        """All-RowMajor operands map to 'rrr'."""
        problem = dict.fromkeys(("layout_a", "layout_b", "layout_c"), "RowMajor")
        assert _layout_from_problem(problem) == "rrr"

    def test_empty(self):
        """Missing layout keys produce the '???' placeholder."""
        assert _layout_from_problem({}) == "???"

    def test_case_insensitive(self):
        """Layout strings are matched regardless of case."""
        problem = {
            "layout_a": "rowmajor",
            "layout_b": "COLUMNMAJOR",
            "layout_c": "RowMajor",
        }
        assert _layout_from_problem(problem) == "rcr"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# parse_streaming_log
|
|
# ---------------------------------------------------------------------------
|
|
|
|
SAMPLE_LOG = """\
|
|
================================================================================
|
|
LOG FILE: test.log
|
|
================================================================================
|
|
CK Tile Profiling Run
|
|
GPU ID: 0
|
|
|
|
--- Running CK Tile benchmarks on GPU 0 ---
|
|
|
|
========================================
|
|
Shape 1: M=16 N=1536 K=7168 dtype=fp8 layout=rcr
|
|
========================================
|
|
Found 2 kernels
|
|
{
|
|
"name": "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128",
|
|
"problem": {
|
|
"split_k":1,
|
|
"m":16,
|
|
"n":1536,
|
|
"k":7168,
|
|
"stride_a":7168,
|
|
"stride_b":7168,
|
|
"stride_c":1536,
|
|
"dtype_a":"fp8",
|
|
"dtype_b":"fp8",
|
|
"dtype_acc":"fp32",
|
|
"dtype_c":"fp16",
|
|
"layout_a":"RowMajor",
|
|
"layout_b":"ColumnMajor",
|
|
"layout_c":"RowMajor",
|
|
"structured_sparsity":false
|
|
},
|
|
"perf_result": {
|
|
"latency(ms)": 0.04,
|
|
"tflops(TFlops)": 8.81,
|
|
"bandwidth(GB/s)": 279.51
|
|
}
|
|
}
|
|
{
|
|
"name": "gemm_universal_fp8_rcr_compv4_default_intrawave_False_False_False_False_128x128x64_2x2x1_32x32x16",
|
|
"problem": {
|
|
"split_k":1,
|
|
"m":16,
|
|
"n":1536,
|
|
"k":7168,
|
|
"stride_a":7168,
|
|
"stride_b":7168,
|
|
"stride_c":1536,
|
|
"dtype_a":"fp8",
|
|
"dtype_b":"fp8",
|
|
"dtype_acc":"fp32",
|
|
"dtype_c":"fp16",
|
|
"layout_a":"RowMajor",
|
|
"layout_b":"ColumnMajor",
|
|
"layout_c":"RowMajor",
|
|
"structured_sparsity":false
|
|
},
|
|
"perf_result": {
|
|
"latency(ms)": 0.05,
|
|
"tflops(TFlops)": 7.22,
|
|
"bandwidth(GB/s)": 228.85
|
|
}
|
|
}
|
|
|
|
========================================
|
|
Shape 2: M=20480 N=7168 K=256 dtype=fp8 layout=rcr
|
|
========================================
|
|
Found 1 kernels
|
|
{
|
|
"name": "gemm_universal_fp8_rcr_mem_cshuffle_intrawave_False_False_False_True_64x64x128_1x4x1_16x16x32",
|
|
"problem": {
|
|
"split_k":1,
|
|
"m":20480,
|
|
"n":7168,
|
|
"k":256,
|
|
"stride_a":256,
|
|
"stride_b":256,
|
|
"stride_c":7168,
|
|
"dtype_a":"fp8",
|
|
"dtype_b":"fp8",
|
|
"dtype_acc":"fp32",
|
|
"dtype_c":"fp16",
|
|
"layout_a":"RowMajor",
|
|
"layout_b":"ColumnMajor",
|
|
"layout_c":"RowMajor",
|
|
"structured_sparsity":false
|
|
},
|
|
"perf_result": {
|
|
"latency(ms)": 0.15,
|
|
"tflops(TFlops)": 505.00,
|
|
"bandwidth(GB/s)": 1200.50
|
|
}
|
|
}
|
|
"""
|
|
|
|
|
|
class TestParseStreamingLog:
    """End-to-end tests for parse_streaming_log() on synthetic log files."""

    def _write_log(self, content: str) -> Path:
        """Persist *content* to a throwaway .log file and return its path."""
        with tempfile.NamedTemporaryFile(mode="w", suffix=".log", delete=False) as handle:
            handle.write(content)
        return Path(handle.name)

    def test_basic_parse(self):
        """All three kernel records are extracted with their shape dims."""
        df = parse_streaming_log(self._write_log(SAMPLE_LOG), arch="gfx950")
        assert len(df) == 3
        assert df["arch"].iloc[0] == "gfx950"
        expected_dims = {
            "m": [16, 16, 20480],
            "n": [1536, 1536, 7168],
            "k": [7168, 7168, 256],
        }
        for column, values in expected_dims.items():
            assert df[column].tolist() == values

    def test_tflops_values(self):
        """Measured TFLOPS come straight from the perf_result records."""
        df = parse_streaming_log(self._write_log(SAMPLE_LOG))
        assert df["measured_tflops"].tolist() == pytest.approx([8.81, 7.22, 505.0])

    def test_kernel_config_parsed(self):
        """Kernel-name fields (tile sizes, pipeline) land in the frame."""
        df = parse_streaming_log(self._write_log(SAMPLE_LOG))
        assert df["tile_m"].iloc[0] == 128
        assert df["pipeline"].tolist()[:2] == ["compv3", "compv4"]

    def test_layout_derived_from_json(self):
        """Layout is derived from the JSON problem record, not the header."""
        df = parse_streaming_log(self._write_log(SAMPLE_LOG))
        assert (df["layout"] == "rcr").all()

    def test_empty_log(self):
        """A log without kernel records yields an empty, fully-schemed frame."""
        df = parse_streaming_log(self._write_log("No shapes here\nJust noise\n"))
        assert len(df) == 0
        missing = [col for col in CANONICAL_COLUMNS if col not in df.columns]
        assert missing == []

    def test_single_kernel(self):
        """A minimal one-kernel log parses into a single valid row."""
        content = """\
Shape 1: M=1 N=1 K=1 dtype=fp8 layout=rcr
{
"name": "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128",
"problem": {"split_k":1, "m":1, "n":1, "k":1, "dtype_a":"fp8", "dtype_b":"fp8", "layout_a":"RowMajor", "layout_b":"ColumnMajor", "layout_c":"RowMajor"},
"perf_result": {"latency(ms)": 0.001, "tflops(TFlops)": 0.002, "bandwidth(GB/s)": 0.01}
}
"""
        df = parse_streaming_log(self._write_log(content))
        assert len(df) == 1
        assert df["m"].iloc[0] == 1
        assert bool(df["is_valid"].iloc[0]) is True

    def test_zero_tflops_marked_invalid(self):
        """Rows reporting 0 TFLOPS are flagged invalid."""
        content = """\
Shape 1: M=16 N=16 K=16 dtype=fp8 layout=rcr
{
"name": "test_kernel",
"problem": {"split_k":1, "m":16, "n":16, "k":16, "dtype_a":"fp8"},
"perf_result": {"latency(ms)": 0.0, "tflops(TFlops)": 0.0, "bandwidth(GB/s)": 0.0}
}
"""
        df = parse_streaming_log(self._write_log(content))
        assert len(df) == 1
        assert bool(df["is_valid"].iloc[0]) is False

    def test_malformed_json_skipped(self):
        """Unparseable JSON objects are dropped; valid neighbors survive."""
        content = """\
Shape 1: M=16 N=16 K=16 dtype=fp8 layout=rcr
{
"name": "good_kernel",
"problem": {"split_k":1, "m":16, "n":16, "k":16, "dtype_a":"fp8"},
"perf_result": {"latency(ms)": 0.01, "tflops(TFlops)": 1.0, "bandwidth(GB/s)": 10.0}
}
{ this is not valid json }
{
"name": "another_good",
"problem": {"split_k":1, "m":16, "n":16, "k":16, "dtype_a":"fp8"},
"perf_result": {"latency(ms)": 0.02, "tflops(TFlops)": 2.0, "bandwidth(GB/s)": 20.0}
}
"""
        df = parse_streaming_log(self._write_log(content))
        assert len(df) == 2

    def test_extreme_shapes(self):
        """Tiny M=1 (single token) and very large M=20480."""
        df = parse_streaming_log(self._write_log(SAMPLE_LOG))
        observed = set(df["m"].values)
        assert 1 not in observed  # sample has M=16, M=20480
        assert 16 in observed
        assert 20480 in observed

    def test_run_id_assigned(self):
        """The supplied run_id is stamped onto every row."""
        df = parse_streaming_log(self._write_log(SAMPLE_LOG), run_id="test_run_123")
        assert (df["run_id"] == "test_run_123").all()

    def test_op_type_assigned(self):
        """The supplied op_type is stamped onto every row."""
        df = parse_streaming_log(self._write_log(SAMPLE_LOG), op_type="gemm_streamk")
        assert (df["op_type"] == "gemm_streamk").all()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Parquet round-trip
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestParquetIO:
    """Round-trip and path-handling tests for save_parquet/load_parquet."""

    def test_round_trip(self, tmp_path):
        """A DataFrame survives a save/load cycle with values intact."""
        original = pd.DataFrame(
            {
                "m": [16, 32],
                "n": [1536, 1536],
                "k": [7168, 7168],
                "measured_tflops": [8.81, 15.0],
            }
        )
        target = tmp_path / "test.parquet"
        save_parquet(original, target)
        restored = load_parquet(target)
        assert len(restored) == 2
        assert restored["m"].tolist() == [16, 32]

    def test_creates_parent_dirs(self, tmp_path):
        """save_parquet creates missing intermediate directories."""
        nested = tmp_path / "sub" / "dir" / "test.parquet"
        save_parquet(pd.DataFrame({"x": [1]}), nested)
        assert nested.exists()
|
|
|
|
|
|
# Allow running this test module directly (outside a pytest invocation):
# `python <this_file>` runs the suite in verbose mode.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|