composable_kernel/dispatcher/heuristics/tests/test_model_compression.py
Yaswanth Raparti c1127a36f5 [rocm-libraries] ROCm/rocm-libraries#5676 (commit 1d18339)
[CK][CK TILE] Autotuning heuristics infra for universal GEMM kernel selection (#5676)

## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile
dispatcher, enabling fast and accurate automatic kernel selection for
Universal GEMM kernels. Instead of requiring an exhaustive search through
4600+ kernel configurations (taking ~46 seconds per problem shape), the
ML heuristic predicts optimal kernels in microseconds while achieving
>98% of oracle-best performance.

## Technical Details

**ML infrastructure**

https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics
* Feature Engine
([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)):
55-feature extraction including problem dimensions, kernel
configuration, tile efficiency, and hardware profile
* Training Pipeline
([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)):
LightGBM regression with log-transform, GroupKFold cross-validation,
warm-start support
* Predictor
([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)):
Kernel ranking and TFLOPS prediction for problem shapes
* Evaluation
([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)):
Comprehensive metrics including efficiency, NDCG@k, shape family
analysis
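
The GroupKFold step in the training pipeline matters because many benchmark rows share one problem shape; if rows for the same shape landed in both train and validation folds, the validation score would leak. A minimal stdlib sketch of shape-grouped folding (the real pipeline presumably uses scikit-learn's `GroupKFold`; the row fields below are illustrative):

```python
from collections import defaultdict

def group_kfold(rows, key, n_splits=3):
    """Yield (train_idx, valid_idx) pairs; all rows sharing key(row)
    land in the same fold, so a problem shape never leaks across folds."""
    groups = defaultdict(list)
    for i, row in enumerate(rows):
        groups[key(row)].append(i)
    folds = [[] for _ in range(n_splits)]
    # Greedily assign the largest groups first to balance fold sizes.
    for _, idxs in sorted(groups.items(), key=lambda kv: -len(kv[1])):
        min(folds, key=len).extend(idxs)
    for f in folds:
        valid = sorted(f)
        train = sorted(i for g in folds if g is not f for i in g)
        yield train, valid

# Illustrative benchmark rows: two measurements per problem shape.
rows = [
    {"m": 128, "n": 1536, "k": 7168, "tflops": 310.0},
    {"m": 128, "n": 1536, "k": 7168, "tflops": 295.5},
    {"m": 2048, "n": 2048, "k": 2048, "tflops": 880.2},
    {"m": 2048, "n": 2048, "k": 2048, "tflops": 901.7},
    {"m": 4, "n": 4096, "k": 4096, "tflops": 55.1},
    {"m": 4, "n": 4096, "k": 4096, "tflops": 60.4},
]
shape_key = lambda r: (r["m"], r["n"], r["k"])
for train, valid in group_kfold(rows, shape_key, n_splits=3):
    valid_shapes = {shape_key(rows[i]) for i in valid}
    train_shapes = {shape_key(rows[i]) for i in train}
    assert valid_shapes.isdisjoint(train_shapes)
```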

**Data Generation Tools:**

*
[generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py):
Build and benchmark kernels across diverse problem shapes
*
[convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py):
Convert benchmark JSON to training-ready parquet format
*
[data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py):
Parse streaming benchmark logs into canonical datasets
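
The conversion step boils down to turning each timed benchmark record into a training row with a throughput label, using the GEMM FLOP count 2·M·N·K. A stdlib sketch (the JSON field names are assumptions, not the tool's actual schema):

```python
import json

def tflops(m, n, k, time_ms):
    """A GEMM performs 2*M*N*K FLOPs; convert a kernel time to TFLOPS."""
    return (2.0 * m * n * k) / (time_ms * 1e-3) / 1e12

def parse_benchmark_line(line):
    """Flatten one JSON benchmark record into a training row.
    Field names ("kernel_id", "time_ms") are illustrative only."""
    rec = json.loads(line)
    return {
        "m": rec["m"], "n": rec["n"], "k": rec["k"],
        "kernel_id": rec["kernel_id"],
        "tflops": tflops(rec["m"], rec["n"], rec["k"], rec["time_ms"]),
    }

line = '{"m": 2048, "n": 2048, "k": 2048, "kernel_id": "cv3_256x256x64", "time_ms": 0.02}'
row = parse_benchmark_line(line)
# 2 * 2048^3 FLOPs in 0.02 ms ≈ 859.0 TFLOPS
```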

**Examples**
*
[09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp):
C++ example demonstrating ML-based kernel selection
*
[09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py):
Python example with validation
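
Selection itself reduces to scoring every candidate kernel for the problem shape and taking the best. The sketch below uses the `predict_tflops(problem, kernel_config)` shape from the tests in this file, but with a hypothetical stand-in scorer so it runs standalone (tile-waste scoring is illustrative, not the model's logic):

```python
def rank_kernels(problem, kernels, predict_tflops):
    """Score every candidate with the model and return (score, kernel)
    pairs best-first."""
    scored = [(predict_tflops(problem, k), k) for k in kernels]
    scored.sort(key=lambda s: -s[0])
    return scored

def fake_predict(problem, cfg):
    """Stand-in scorer: favor tiles that evenly divide the problem."""
    m0, n0 = cfg["tile_shape"]["m0"], cfg["tile_shape"]["n0"]
    waste_m = (-problem["m"]) % m0   # padding rows in the last M tile
    waste_n = (-problem["n"]) % n0   # padding cols in the last N tile
    return 1000.0 / (1 + waste_m + waste_n)

problem = {"m": 128, "n": 1536, "k": 7168, "dtype": "fp16", "layout": "rcr"}
kernels = [
    {"tile_shape": {"m0": 128, "n0": 128, "k0": 16}},
    {"tile_shape": {"m0": 256, "n0": 256, "k0": 32}},
]
best_score, best_kernel = rank_kernels(problem, kernels, fake_predict)[0]
# 128x128 tiles divide M=128, N=1536 exactly, so they rank first here
```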

**Pre-trained Models
(projects/composablekernel/dispatcher/heuristics/models/):**
* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean
efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean
efficiency)
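
The shipped models are stored gzipped and inflated on first load (the tests at the bottom of this file exercise that behavior). A minimal sketch of the auto-decompression idea, assuming the `model_tflops.lgbm.gz` naming from the tests; the actual `Predictor` internals may differ:

```python
import gzip
import shutil
import tempfile
from pathlib import Path

def ensure_decompressed(model_dir):
    """If only model_tflops.lgbm.gz is present, inflate it alongside
    the archive and return the path to the uncompressed model."""
    lgbm = Path(model_dir) / "model_tflops.lgbm"
    gz = Path(model_dir) / "model_tflops.lgbm.gz"
    if not lgbm.exists():
        with gzip.open(gz, "rb") as src, open(lgbm, "wb") as dst:
            shutil.copyfileobj(src, dst)
    return lgbm

# Demo on a throwaway directory with a fake compressed "model".
with tempfile.TemporaryDirectory() as d:
    payload = b"tree\nversion=v4\n"  # stand-in for LightGBM model text
    with gzip.open(Path(d) / "model_tflops.lgbm.gz", "wb") as f:
        f.write(payload)
    model_file = ensure_decompressed(d)
    assert model_file.read_bytes() == payload
```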

## Test Plan

* Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M
(M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**
* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (90% of shapes achieve ≥98% of the oracle best)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**
* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%
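
For reference, the efficiency of a shape is the TFLOPS of the model-selected kernel divided by the oracle-best TFLOPS for that shape; mean/P10/min are then taken over shapes. A stdlib sketch with invented numbers (nearest-rank percentile; the evaluation script's exact percentile method is not quoted here):

```python
def efficiency_stats(per_shape):
    """per_shape maps shape -> (picked_tflops, oracle_best_tflops).
    Efficiency of a shape = picked / oracle-best."""
    effs = sorted(picked / best for picked, best in per_shape.values())
    n = len(effs)
    p10 = effs[int(0.10 * (n - 1))]  # nearest-rank 10th percentile
    return {"mean": sum(effs) / n, "p10": p10, "min": effs[0]}

# Invented example numbers, not the PR's measurements.
per_shape = {
    (128, 1536, 7168): (305.0, 310.0),
    (2048, 2048, 2048): (880.0, 901.7),
    (4, 4096, 4096): (58.0, 60.4),
}
stats = efficiency_stats(per_shape)
```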

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-02 02:26:32 +00:00

#!/usr/bin/env python3
"""Test that compressed models can be loaded and used."""
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from predict import Predictor


def test_fp16_model_decompression():
    """Test that the fp16 model is auto-decompressed and usable."""
    model_dir = Path(__file__).parent.parent / "models" / "gemm_universal_fp16_gfx950"

    # Ensure the compressed .lgbm.gz model exists
    gz_file = model_dir / "model_tflops.lgbm.gz"
    assert gz_file.exists(), f"Compressed model not found: {gz_file}"

    # Loading the predictor should auto-decompress the model
    predictor = Predictor(model_dir)

    # Test prediction
    problem = {"m": 128, "n": 1536, "k": 7168, "dtype": "fp16", "layout": "rcr"}
    kernel_config = {
        "tile_shape": {"m0": 128, "n0": 128, "k0": 16},
        "wave_shape": {"m1": 2, "n1": 2, "k1": 1},
        "warp_tile": {"m2": 32, "n2": 32, "k2": 8},
    }
    tflops = predictor.predict_tflops(problem, kernel_config)
    assert isinstance(tflops, float), f"Expected float, got {type(tflops)}"
    assert tflops > 0, f"Expected positive TFLOPS, got {tflops}"

    # Verify the decompressed file was created
    lgbm_file = model_dir / "model_tflops.lgbm"
    assert lgbm_file.exists(), "Model should have been decompressed"

    print("✅ FP16 model decompression test passed")
    print(f"   Predicted TFLOPS: {tflops:.2f}")
    print(f"   Decompressed to: {lgbm_file}")
    return True


def test_fp8_model_decompression():
    """Test that the fp8 model is auto-decompressed and usable."""
    model_dir = Path(__file__).parent.parent / "models" / "gemm_universal_fp8_gfx950"

    # Ensure the compressed .lgbm.gz model exists
    gz_file = model_dir / "model_tflops.lgbm.gz"
    assert gz_file.exists(), f"Compressed model not found: {gz_file}"

    # Loading the predictor should auto-decompress the model
    predictor = Predictor(model_dir)

    # Test prediction
    problem = {"m": 2048, "n": 2048, "k": 2048, "dtype": "fp8", "layout": "rcr"}
    kernel_config = {
        "tile_shape": {"m0": 256, "n0": 256, "k0": 64},
        "wave_shape": {"m1": 2, "n1": 2, "k1": 1},
        "warp_tile": {"m2": 32, "n2": 32, "k2": 16},
    }
    tflops = predictor.predict_tflops(problem, kernel_config)
    assert isinstance(tflops, float), f"Expected float, got {type(tflops)}"
    assert tflops > 0, f"Expected positive TFLOPS, got {tflops}"

    # Verify the decompressed file was created
    lgbm_file = model_dir / "model_tflops.lgbm"
    assert lgbm_file.exists(), "Model should have been decompressed"

    print("✅ FP8 model decompression test passed")
    print(f"   Predicted TFLOPS: {tflops:.2f}")
    print(f"   Decompressed to: {lgbm_file}")
    return True


if __name__ == "__main__":
    print("Testing compressed model auto-decompression...")
    print()
    test_fp16_model_decompression()
    print()
    test_fp8_model_decompression()
    print()
    print("✅ All model compression tests passed!")