mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
[CK][CK TILE] Autotuning heuristics infra for universal GEMM kernel selection (#5676)

## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile dispatcher, enabling fast and accurate automatic kernel selection for Universal GEMM kernels. Instead of requiring an exhaustive search through 4600+ kernel configurations (taking ~46 seconds per problem shape), the ML heuristic predicts optimal kernels in microseconds while achieving >98% of oracle-best performance.

## Technical Details

**ML infrastructure**
https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics

* Feature Engine ([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)): 55-feature extraction including problem dimensions, kernel configuration, tile efficiency, and hardware profile
* Training Pipeline ([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)): LightGBM regression with log-transform, GroupKFold cross-validation, warm-start support
* Predictor ([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)): Kernel ranking and TFLOPS prediction for problem shapes
* Evaluation ([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)): Comprehensive metrics including efficiency, NDCG@k, shape family analysis

**Data Generation Tools:**

* [generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py): Build and benchmark kernels across diverse problem shapes
* [convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py): Convert benchmark JSON to training-ready parquet format
* [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py): Parse streaming benchmark logs into canonical datasets

**Examples**

* [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp): C++ example demonstrating ML-based kernel selection
* [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py): Python example with validation

**Pre-trained Models (projects/composablekernel/dispatcher/heuristics/models/):**

* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean efficiency)

## Test Plan

* Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M (M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**

* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (90% of shapes achieve ≥98% of oracle best)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**

* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
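The efficiency figures in the test results can be read as the throughput of the kernel the heuristic selected divided by the throughput of the best kernel found by exhaustive search (the "oracle"). The following is a minimal sketch of how mean, P10, and minimum efficiency could be computed under that reading. The function names, the sample numbers, and the linear-interpolation percentile are assumptions made for illustration; the PR does not state how `evaluate.py` computes these values.

```python
import statistics


def selection_efficiency(chosen_tflops, oracle_tflops):
    """Fraction of oracle-best throughput achieved by the selected kernel."""
    if oracle_tflops <= 0:
        raise ValueError("oracle_tflops must be positive")
    return chosen_tflops / oracle_tflops


def percentile(values, q):
    """q-th percentile (0..100), linear interpolation between sorted samples.

    Assumption: the interpolation rule is chosen for illustration only.
    """
    ordered = sorted(values)
    if not ordered:
        raise ValueError("values must be non-empty")
    pos = (len(ordered) - 1) * q / 100.0
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


def summarize(efficiencies):
    """Mean, 10th-percentile, and worst-case efficiency over all shapes."""
    return {
        "mean": statistics.fmean(efficiencies),
        "p10": percentile(efficiencies, 10),
        "min": min(efficiencies),
    }


# Hypothetical per-shape measurements:
# (TFLOPS of the heuristic's pick, TFLOPS of the oracle-best kernel).
results = [
    (98.0, 100.0),
    (50.0, 50.0),
    (190.0, 200.0),
    (75.0, 80.0),
    (120.0, 120.0),
]

effs = [selection_efficiency(chosen, oracle) for chosen, oracle in results]
summary = summarize(effs)
print(summary)
```

With this definition, a P10 of 98% means that roughly 90% of the evaluated shapes reach at least 98% of the oracle-best throughput, which makes P10 a tail metric that complements the mean.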
142 lines
2.7 KiB
Plaintext
# Compiled Object files
*.slo
*.lo
*.o
*.obj

# Precompiled Headers
*.gch
*.pch
*.ipch

# Compiled Dynamic libraries
*.so
*.dylib
*.dll

# Fortran module files
*.mod

# Compiled Static libraries
*.lai
*.la
*.a
*.lib

# Executables
*.exe
*.out
*.app

# vim tags
tags
.tags
.*.swp

# Editors
.vscode

# CMake formatting configuration (local)
.cmake-format.yaml

# Cline
.cline*

# build-in-source directory (see exceptions below)
build*

# emacs temporary/backup files
.\#*
\#*\#
*~

# GDB temporary files
.gdb_history
install.dir*

# documentation artifacts
_build/
_images/
_static/
_templates/
_toc.yml
_doxygen/
docs/doxygen/html
docs/doxygen/xml

# JetBrains IDE (see build* exceptions below)
.idea/
cmake-build*/
build*/

# LSP configuration
.clangd

# User-defined CMake presets
CMakeUserPresets.json

# Python virtualenv
.venv/

# Python cache
__pycache__/

# Cache directories
.cache/
.ck_tile_cache/
ck_tile_cache/
**/kernel_cache/
**/.kernel_cache/

# Dispatcher kernel cache (user-generated, can be large)
dispatcher/**/kernel_cache/
dispatcher/**/.kernel_cache/
dispatcher/**/cached_kernels/
dispatcher/**/*.hsaco
dispatcher/**/*.co

# Dispatcher generated JSON exports
dispatcher/**/*_kernels.json
dispatcher/**/dispatcher_kernels.json

# Generated test data
test_data/*
!test_data/*.py
!test_data/*.sh
!test_data/requirements.txt

# Exceptions to build* patterns above
# The experimental/builder directory should be tracked despite matching build*
!experimental/builder
!experimental/builder/**
experimental/grouped_convolution_tile_instances/instances/*
!experimental/grouped_convolution_tile_instances/instances/*.in
!experimental/grouped_convolution_tile_instances/instances/*.inc
!experimental/grouped_convolution_tile_instances/instances/*.hpp
experimental/grouped_convolution_tile_instances/*.inc
# Heuristics: benchmark data (never in git)
dispatcher/heuristics/data/

# Heuristics: experimental/training artifacts (exclude from git)
dispatcher/heuristics/models/**/oof_predictions.parquet
dispatcher/heuristics/models/**/cv_metrics_*.json
dispatcher/heuristics/models/**/eval_report.json
dispatcher/heuristics/models/**/feature_importances_*.json
dispatcher/heuristics/models/**/model_tflops_ihem.lgbm
dispatcher/heuristics/models/**/model_tflops_log.lgbm
dispatcher/heuristics/models/**/model_tflops_log_big.lgbm

# Heuristics: keep in git (production model files):
#   models/{op}_{dtype}_{arch}/model_tflops.lgbm
#   models/{op}_{dtype}_{arch}/model_latency.lgbm
#   models/{op}_{dtype}_{arch}/model_bandwidth.lgbm
#   models/{op}_{dtype}_{arch}/feature_spec.json
#   models/{op}_{dtype}_{arch}/train_manifest.json

# Heuristics: logs and caches
dispatcher/heuristics/*.log
dispatcher/heuristics/__pycache__/
dispatcher/heuristics/tests/__pycache__/
dispatcher/heuristics/.pytest_cache/
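The "Generated test data" block in the listing above ignores everything directly under `test_data/` and then re-includes scripts and `requirements.txt` with `!` patterns. This works because, among the patterns that match a path, the last one decides the outcome. The sketch below models only that ordering rule for direct children of `test_data/`. It is a simplification and not git's matcher: `fnmatch` lets `*` match `/`, which git does not, and git cannot re-include a file whose parent directory is itself excluded. The helper name `is_ignored` is hypothetical.

```python
from fnmatch import fnmatchcase

# Patterns copied from the "Generated test data" block, in file order.
PATTERNS = [
    "test_data/*",
    "!test_data/*.py",
    "!test_data/*.sh",
    "!test_data/requirements.txt",
]


def is_ignored(path, patterns=PATTERNS):
    """Simplified model: the last pattern that matches the path wins."""
    ignored = False
    for pattern in patterns:
        negated = pattern.startswith("!")
        glob = pattern[1:] if negated else pattern
        if fnmatchcase(path, glob):
            # A plain pattern ignores the path; a "!" pattern re-includes it.
            ignored = not negated
    return ignored


for candidate in [
    "test_data/generate.py",
    "test_data/run.sh",
    "test_data/requirements.txt",
    "test_data/tensor_dump.bin",
]:
    print(candidate, "ignored" if is_ignored(candidate) else "tracked")
```

For an authoritative answer on a real checkout, `git check-ignore -v <path>` reports which pattern, if any, matched a given path.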