composable_kernel/dispatcher/heuristics/tests/test_feature_parity.py
Yaswanth Raparti 91dbdfa476 [CK][CK TILE]Autotuning heuristics infra for universal GEMM kernel selection (#5676)
## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile
dispatcher, enabling fast and accurate automatic kernel selection for
Universal GEMM kernels. Instead of requiring an exhaustive search through
4600+ kernel configurations (taking ~46 seconds per problem shape), the
ML heuristic predicts optimal kernels in microseconds while achieving
>98% of oracle-best performance.

## Technical Details

**ML Infrastructure:**

https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics
* Feature Engine
([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)):
55-feature extraction including problem dimensions, kernel
configuration, tile efficiency, and hardware profile
* Training Pipeline
([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)):
LightGBM regression with log-transform, GroupKFold cross-validation,
warm-start support
* Predictor
([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)):
Kernel ranking and TFLOPS prediction for problem shapes
* Evaluation
([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)):
Comprehensive metrics including efficiency, NDCG@k, shape family
analysis
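
The training recipe listed above (regression on a log-transformed target with group-aware cross-validation) can be sketched with a NumPy-only stand-in. The dataset, group ids, and split below are made up for illustration; the real pipeline in train.py fits a LightGBM model on the extracted features.

```python
import numpy as np

# Toy stand-in dataset: 20 problem shapes x 10 kernel configs per shape.
# In the real pipeline each row carries the extracted features and the
# measured TFLOPS for one (shape, kernel) pair.
rng = np.random.default_rng(0)
group_ids = np.repeat(np.arange(20), 10)         # group = problem shape
tflops = np.exp(rng.normal(3.0, 1.0, size=200))  # skewed, positive target

# Log-transform: throughput spans orders of magnitude, so regressing
# log1p(TFLOPS) keeps the loss from being dominated by the fastest shapes.
y = np.log1p(tflops)

# Group-aware split (the core idea of GroupKFold): every row of a given
# problem shape lands on the same side, so validation measures
# generalization to shapes the model has never seen.
val_mask = np.isin(group_ids, np.arange(4))      # hold out shapes 0..3
assert not (set(group_ids[val_mask]) & set(group_ids[~val_mask]))

# Predictions made in log space are inverted before ranking kernels.
recovered = np.expm1(y)
assert np.allclose(recovered, tflops)
```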

**Data Generation Tools:**

*
[generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py):
Build and benchmark kernels across diverse problem shapes
*
[convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py):
Convert benchmark JSON to training-ready parquet format
*
[data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py):
Parse streaming benchmark logs into canonical datasets
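
As a rough illustration of the JSON-to-parquet conversion step, the flattening below uses a hypothetical record schema (the field names are illustrative, not the actual output of generate_benchmark_data.py); the real convert_json_to_parquet.py writes the resulting rows out as parquet.

```python
import json

# Hypothetical benchmark record -- the real JSON schema may differ.
raw = json.loads("""
[{"problem": {"m": 1024, "n": 1024, "k": 1024, "dtype": "fp8", "layout": "rcr"},
  "kernel": "tile_256x256x64_compv3",
  "tflops": 812.4}]
""")


def flatten(rec):
    # One flat row per (problem, kernel) measurement: problem fields become
    # columns and the measured TFLOPS becomes the regression target.
    row = {f"problem_{k}": v for k, v in rec["problem"].items()}
    row["kernel"] = rec["kernel"]
    row["tflops"] = rec["tflops"]
    return row


rows = [flatten(r) for r in raw]
# A table built from `rows` (e.g. pandas.DataFrame(rows)) can then be
# written with DataFrame.to_parquet() for a training-ready dataset.
```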

**Examples:**
*
[09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp):
C++ example demonstrating ML-based kernel selection
*
[09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py):
Python example with validation


**Pre-trained Models
(projects/composablekernel/dispatcher/heuristics/models/):**
* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean
efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean
efficiency)


## Test Plan

* Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M
(M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**
* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (90% of shapes achieve ≥98.05% of
oracle best)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**
* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%
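
To make the metrics above concrete, here is a small sketch of how mean/P10/min efficiency can be computed; the TFLOPS values are made up, and the actual metric computation lives in evaluate.py.

```python
import numpy as np

# Made-up per-shape throughput: the kernel the heuristic picked vs. the
# oracle-best kernel found by exhaustive search.
picked = np.array([980.0, 450.0, 120.0, 310.0, 870.0])
oracle = np.array([1000.0, 460.0, 125.0, 330.0, 880.0])

# Per-shape efficiency: fraction of oracle-best throughput achieved.
eff = picked / oracle

mean_eff = eff.mean()             # "Mean Efficiency"
p10_eff = np.percentile(eff, 10)  # "P10": 90% of shapes do at least this well
min_eff = eff.min()               # "Min Efficiency": the single worst shape
```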

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Vidyasagar Ananthan <vidyasagar.ananthan@amd.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-01 19:25:55 -07:00


#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Test that the C++ extract_features() in ml_heuristic.hpp produces identical
values to the Python GemmUniversalFeatureEngine.extract().

This test uses ctypes to call the C++ feature extraction compiled into a
small shared library, then compares against Python output. If compilation
fails (no HIP/ROCm), it falls back to verifying the Python feature engine
against manually computed expected values for specific test cases.
"""
import math
import sys
from pathlib import Path

import numpy as np
import pytest

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from feature_engine import (
    GemmUniversalFeatureEngine,
    PIPELINE_MAP,
    SCHEDULER_MAP,
    EPILOGUE_MAP,
    LAYOUT_MAP,
)


def _compute_features_manually(
    M,
    N,
    K,
    split_k,
    dtype,
    layout,
    tile_m,
    tile_n,
    tile_k,
    warp_m,
    warp_n,
    warp_k,
    warp_tile_m,
    warp_tile_n,
    warp_tile_k,
    pipeline,
    scheduler,
    epilogue,
    pad_m,
    pad_n,
    pad_k,
    persistent,
    hw,
):
    """Recompute features independently to verify the Python engine."""
    bpe_map = {"fp8": 1.0, "fp16": 2.0, "bf16": 2.0, "fp32": 4.0}
    bpe = bpe_map.get(dtype, 1.0)
    log2_M = math.log2(max(M, 1))
    log2_N = math.log2(max(N, 1))
    log2_K = math.log2(max(K, 1))
    log2_MNK = math.log2(max(M * N * K, 1))
    mem = (M * K + K * N + M * N) * bpe
    ai = (2.0 * M * N * K) / max(mem, 1)
    lds_est = (tile_m * tile_k + tile_n * tile_k) * bpe
    lds_cap = 32768 if pipeline == "compv4" else hw["lds_capacity"]
    ntm = math.ceil(M / max(tile_m, 1))
    ntn = math.ceil(N / max(tile_n, 1))
    ntk = math.ceil(K / max(tile_k, 1))

    def eff(d, t):
        if t <= 0:
            return 1.0
        r = d % t
        return r / t if r > 0 else 1.0

    # Problem-to-tile ratios
    ratio_M_to_tile_m = M / max(tile_m, 1)
    ratio_N_to_tile_n = N / max(tile_n, 1)
    ratio_K_to_tile_k = K / max(tile_k, 1)
    # Binary features: problem smaller than tile
    problem_smaller_than_tile_m = float(M < tile_m)
    problem_smaller_than_tile_n = float(N < tile_n)
    problem_smaller_than_tile_k = float(K < tile_k)
    any_dim_too_small = float((M < tile_m) or (N < tile_n) or (K < tile_k))
    # Padding requirement features
    needs_padding_m = float(tile_m > 0 and M % tile_m != 0)
    needs_padding_n = float(tile_n > 0 and N % tile_n != 0)
    needs_padding_k = float(tile_k > 0 and K % tile_k != 0)
    # Interaction features
    has_padding_when_needed_m = float(needs_padding_m and pad_m)
    has_padding_when_needed_n = float(needs_padding_n and pad_n)
    has_padding_when_needed_k = float(needs_padding_k and pad_k)
    # Missing padding features
    missing_required_padding_m = float(needs_padding_m and not pad_m)
    missing_required_padding_n = float(needs_padding_n and not pad_n)
    missing_required_padding_k = float(needs_padding_k and not pad_k)
    missing_any_required_padding = float(
        missing_required_padding_m or missing_required_padding_n or missing_required_padding_k
    )
    return [
        M,  # 0
        N,  # 1
        K,  # 2
        split_k,  # 3
        log2_M,  # 4
        log2_N,  # 5
        log2_K,  # 6
        log2_MNK,  # 7
        ai,  # 8
        M / max(N, 1),  # 9 (aspect_ratio_mn)
        M / max(K, 1),  # 10 (aspect_ratio_mk)
        N / max(K, 1),  # 11 (aspect_ratio_nk)
        LAYOUT_MAP.get(layout, 0),  # 12
        tile_m,  # 13
        tile_n,  # 14
        tile_k,  # 15
        warp_m,  # 16
        warp_n,  # 17
        warp_k,  # 18
        warp_tile_m,  # 19
        warp_tile_n,  # 20
        warp_tile_k,  # 21
        PIPELINE_MAP.get(pipeline, 0),  # 22
        SCHEDULER_MAP.get(scheduler, 0),  # 23
        EPILOGUE_MAP.get(epilogue, 0),  # 24
        float(pad_m),  # 25
        float(pad_n),  # 26
        float(pad_k),  # 27
        float(persistent),  # 28
        warp_m * warp_n * warp_k,  # 29 (num_warps)
        tile_m * tile_n * tile_k,  # 30 (tile_volume)
        tile_m * tile_n,  # 31 (tile_mn)
        lds_est,  # 32 (lds_usage_estimate)
        lds_est / max(lds_cap, 1),  # 33 (lds_usage_ratio)
        ntm,  # 34 (num_tiles_m)
        ntn,  # 35 (num_tiles_n)
        ntk,  # 36 (num_tiles_k)
        ntm * ntn,  # 37 (total_output_tiles)
        eff(M, tile_m),  # 38 (tile_eff_m)
        eff(N, tile_n),  # 39 (tile_eff_n)
        eff(K, tile_k),  # 40 (tile_eff_k)
        eff(M, tile_m) * eff(N, tile_n) * eff(K, tile_k),  # 41 (overall_tile_efficiency)
        ntm * ntn / max(hw["num_cus"], 1),  # 42 (cu_utilization)
        ratio_M_to_tile_m,  # 43
        ratio_N_to_tile_n,  # 44
        ratio_K_to_tile_k,  # 45
        problem_smaller_than_tile_m,  # 46
        problem_smaller_than_tile_n,  # 47
        problem_smaller_than_tile_k,  # 48
        any_dim_too_small,  # 49
        needs_padding_m,  # 50
        needs_padding_n,  # 51
        needs_padding_k,  # 52
        has_padding_when_needed_m,  # 53
        has_padding_when_needed_n,  # 54
        has_padding_when_needed_k,  # 55
        missing_required_padding_m,  # 56
        missing_required_padding_n,  # 57
        missing_required_padding_k,  # 58
        missing_any_required_padding,  # 59
        hw["num_cus"],  # 60
        hw["simds_per_cu"],  # 61
        hw["num_cus"] * hw["simds_per_cu"],  # 62 (total_simds)
        hw["shader_engines"],  # 63
        hw["max_clock_mhz"],  # 64
        hw["max_waves_per_cu"],  # 65
        hw["wavefront_size"],  # 66
        hw["lds_capacity"],  # 67
        hw["l1_cache_kb"],  # 68
        hw["l2_cache_kb"],  # 69
        hw["l3_cache_kb"],  # 70
        hw["num_xcd"],  # 71
    ]


TEST_CASES = [
    {
        "problem": {
            "m": 1024,
            "n": 1024,
            "k": 1024,
            "split_k": 1,
            "dtype": "fp8",
            "layout": "rcr",
        },
        "kernel": {
            "tile_m": 128,
            "tile_n": 128,
            "tile_k": 64,
            "warp_m": 2,
            "warp_n": 2,
            "warp_k": 1,
            "warp_tile_m": 32,
            "warp_tile_n": 32,
            "warp_tile_k": 16,
            "pipeline": "compv3",
            "scheduler": "intrawave",
            "epilogue": "cshuffle",
            "pad_m": False,
            "pad_n": False,
            "pad_k": False,
            "persistent": False,
        },
    },
    {
        "problem": {
            "m": 1,
            "n": 4096,
            "k": 4096,
            "split_k": 1,
            "dtype": "fp8",
            "layout": "rcr",
        },
        "kernel": {
            "tile_m": 64,
            "tile_n": 64,
            "tile_k": 128,
            "warp_m": 1,
            "warp_n": 4,
            "warp_k": 1,
            "warp_tile_m": 16,
            "warp_tile_n": 16,
            "warp_tile_k": 128,
            "pipeline": "compv4",
            "scheduler": "interwave",
            "epilogue": "default",
            "pad_m": True,
            "pad_n": True,
            "pad_k": True,
            "persistent": True,
        },
    },
    {
        "problem": {
            "m": 20480,
            "n": 7168,
            "k": 256,
            "split_k": 1,
            "dtype": "fp16",
            "layout": "rrr",
        },
        "kernel": {
            "tile_m": 256,
            "tile_n": 256,
            "tile_k": 32,
            "warp_m": 4,
            "warp_n": 1,
            "warp_k": 1,
            "warp_tile_m": 32,
            "warp_tile_n": 32,
            "warp_tile_k": 16,
            "pipeline": "mem",
            "scheduler": "interwave",
            "epilogue": "cshuffle",
            "pad_m": False,
            "pad_n": False,
            "pad_k": False,
            "persistent": False,
        },
    },
]

HW = {
    "num_cus": 256,
    "simds_per_cu": 4,
    "shader_engines": 32,
    "max_clock_mhz": 2400,
    "max_waves_per_cu": 32,
    "wavefront_size": 64,
    "lds_capacity": 65536,
    "l1_cache_kb": 32,
    "l2_cache_kb": 4096,
    "l3_cache_kb": 262144,
    "num_xcd": 8,
}


class TestFeatureParity:
    """Verify Python feature engine matches manual computation (C++ uses same logic)."""

    @pytest.fixture
    def fe(self):
        return GemmUniversalFeatureEngine(**HW)

    @pytest.mark.parametrize("case_idx", range(len(TEST_CASES)))
    def test_python_matches_manual(self, fe, case_idx):
        case = TEST_CASES[case_idx]
        prob = case["problem"]
        kern = case["kernel"]
        py_features = fe.extract(prob, kern)
        manual = _compute_features_manually(
            prob["m"],
            prob["n"],
            prob["k"],
            prob["split_k"],
            prob["dtype"],
            prob["layout"],
            kern["tile_m"],
            kern["tile_n"],
            kern["tile_k"],
            kern["warp_m"],
            kern["warp_n"],
            kern["warp_k"],
            kern["warp_tile_m"],
            kern["warp_tile_n"],
            kern["warp_tile_k"],
            kern["pipeline"],
            kern["scheduler"],
            kern["epilogue"],
            kern["pad_m"],
            kern["pad_n"],
            kern["pad_k"],
            kern["persistent"],
            HW,
        )
        manual_arr = np.array(manual, dtype=np.float64)
        assert len(py_features) == len(manual_arr) == 72
        for i in range(72):
            assert py_features[i] == pytest.approx(
                manual_arr[i], rel=1e-10, abs=1e-15
            ), (
                f"Feature {i} ({fe.get_feature_names()[i]}): Python={py_features[i]}, Manual={manual_arr[i]}"
            )

    def test_feature_count(self, fe):
        assert len(fe.get_feature_names()) == 72

    def test_encoding_maps_match_cpp(self):
        """The C++ encode_* functions must use the same mapping as Python."""
        assert PIPELINE_MAP == {
            "compv3": 0,
            "compv4": 1,
            "compv5": 2,
            "mem": 3,
            "preshufflev2": 4,
        }
        assert SCHEDULER_MAP == {"intrawave": 0, "interwave": 1}
        assert EPILOGUE_MAP == {"default": 0, "cshuffle": 1}
        assert LAYOUT_MAP == {"rcr": 0, "rrr": 1, "crr": 2, "ccr": 3}


if __name__ == "__main__":
    pytest.main([__file__, "-v"])