mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK][CK TILE]Autotuning heuristics infra for universal GEMM kernel selection (#5676) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation This PR adds ML-based kernel selection heuristics to the CK Tile dispatcher, enabling fast and accurate automatic kernel selection for Universal Gemm kernels. Instead of requiring exhaustive search through 4600+ kernel configurations (taking ~46 seconds per problem shape), the ML heuristic predicts optimal kernels in microseconds while achieving >98% of oracle-best performance. ## Technical Details **ML infrastructure** https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics * Feature Engine ([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)): 55-feature extraction including problem dimensions, kernel configuration, tile efficiency, and hardware profile * Training Pipeline ([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)): LightGBM regression with log-transform, GroupKFold cross-validation, warm-start support * Predictor ([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)): Kernel ranking and TFLOPS prediction for problem shapes * Evaluation ([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)): Comprehensive metrics including efficiency, NDCG@k, shape family analysis **Data Generation Tools:** * 
[generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py): Build and benchmark kernels across diverse problem shapes * [convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py): Convert benchmark JSON to training-ready parquet format * [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py): Parse streaming benchmark logs into canonical datasets **Examples** * [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp): C++ example demonstrating ML-based kernel selection * [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py): Python example with validation **Pre-trained Models (projects/composablekernel/dispatcher/heuristics/models/):** * gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean efficiency) * gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean efficiency) ## Test Plan * Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8 * All shape families tested: tiny M (M<8), small M, medium M, large M (M≥1024) * All pipeline types: compv3, compv4, mem ## Test Result **fp16 Model (gfx950, RCR layout)** * Mean Efficiency: 99.36% * P10 Efficiency: 98.05% (90th percentile of shapes achieve ≥98% of oracle best) * Min Efficiency: 95.45% **fp8 Model (gfx950, RCR layout)** * Mean Efficiency: 98.28% (original), 97.51% (wide coverage) * P10 Efficiency: 94.64% (original), 93.89% (wide coverage) * Min Efficiency: 84.5% ## 
Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
91 lines
2.9 KiB
Python
91 lines
2.9 KiB
Python
#!/usr/bin/env python3
"""Test that compressed models can be loaded and used."""

import sys
from pathlib import Path

# Make the parent directory (which contains predict.py) importable when
# this file is executed directly as a script.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from predict import Predictor
|
def test_fp16_model_decompression():
    """Test that the fp16 model is auto-decompressed and usable.

    Checks that the compressed ``model_tflops.lgbm.gz`` artifact exists,
    that ``Predictor`` loads it (auto-decompressing as a side effect),
    that a prediction on a representative fp16 RCR problem returns a
    positive float, and that the decompressed ``model_tflops.lgbm`` file
    was written next to the archive.

    Returns:
        True on success.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    model_dir = Path(__file__).parent.parent / "models" / "gemm_universal_fp16_gfx950"

    # Ensure the compressed artifact is present before exercising the loader.
    gz_file = model_dir / "model_tflops.lgbm.gz"
    assert gz_file.exists(), f"Compressed model not found: {gz_file}"

    # Load predictor - should auto-decompress.
    predictor = Predictor(model_dir)

    # Representative fp16 RCR problem/kernel pair for the prediction check.
    problem = {"m": 128, "n": 1536, "k": 7168, "dtype": "fp16", "layout": "rcr"}
    kernel_config = {
        "tile_shape": {"m0": 128, "n0": 128, "k0": 16},
        "wave_shape": {"m1": 2, "n1": 2, "k1": 1},
        "warp_tile": {"m2": 32, "n2": 32, "k2": 8},
    }

    tflops = predictor.predict_tflops(problem, kernel_config)

    assert isinstance(tflops, float), f"Expected float, got {type(tflops)}"
    assert tflops > 0, f"Expected positive TFLOPS, got {tflops}"

    # Verify the decompressed file was created as a side effect of loading.
    lgbm_file = model_dir / "model_tflops.lgbm"
    assert lgbm_file.exists(), "Model should have been decompressed"

    # Fix: drop the f-prefix on the placeholder-free literal (lint F541).
    print("✅ FP16 model decompression test passed")
    print(f" Predicted TFLOPS: {tflops:.2f}")
    print(f" Decompressed to: {lgbm_file}")
    return True
|
def test_fp8_model_decompression():
    """Test that the fp8 model is auto-decompressed and usable.

    Checks that the compressed ``model_tflops.lgbm.gz`` artifact exists,
    that ``Predictor`` loads it (auto-decompressing as a side effect),
    that a prediction on a representative fp8 RCR problem returns a
    positive float, and that the decompressed ``model_tflops.lgbm`` file
    was written next to the archive.

    Returns:
        True on success.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    model_dir = Path(__file__).parent.parent / "models" / "gemm_universal_fp8_gfx950"

    # Ensure the compressed artifact is present before exercising the loader.
    gz_file = model_dir / "model_tflops.lgbm.gz"
    assert gz_file.exists(), f"Compressed model not found: {gz_file}"

    # Load predictor - should auto-decompress.
    predictor = Predictor(model_dir)

    # Representative fp8 RCR problem/kernel pair for the prediction check.
    problem = {"m": 2048, "n": 2048, "k": 2048, "dtype": "fp8", "layout": "rcr"}
    kernel_config = {
        "tile_shape": {"m0": 256, "n0": 256, "k0": 64},
        "wave_shape": {"m1": 2, "n1": 2, "k1": 1},
        "warp_tile": {"m2": 32, "n2": 32, "k2": 16},
    }

    tflops = predictor.predict_tflops(problem, kernel_config)

    assert isinstance(tflops, float), f"Expected float, got {type(tflops)}"
    assert tflops > 0, f"Expected positive TFLOPS, got {tflops}"

    # Verify the decompressed file was created as a side effect of loading.
    lgbm_file = model_dir / "model_tflops.lgbm"
    assert lgbm_file.exists(), "Model should have been decompressed"

    # Fix: drop the f-prefix on the placeholder-free literal (lint F541).
    print("✅ FP8 model decompression test passed")
    print(f" Predicted TFLOPS: {tflops:.2f}")
    print(f" Decompressed to: {lgbm_file}")
    return True
|
if __name__ == "__main__":
|
|
print("Testing compressed model auto-decompression...")
|
|
print()
|
|
|
|
test_fp16_model_decompression()
|
|
print()
|
|
test_fp8_model_decompression()
|
|
print()
|
|
print("✅ All model compression tests passed!")
|