mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
## Motivation This PR adds ML-based kernel selection heuristics to the CK Tile dispatcher, enabling fast and accurate automatic kernel selection for Universal Gemm kernels. Instead of requiring exhaustive search through 4600+ kernel configurations (taking ~46 seconds per problem shape), the ML heuristic predicts optimal kernels in microseconds while achieving >98% of oracle-best performance. ## Technical Details **ML infrastructure** https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics * Feature Engine ([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)): 55-feature extraction including problem dimensions, kernel configuration, tile efficiency, and hardware profile * Training Pipeline ([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)): LightGBM regression with log-transform, GroupKFold cross-validation, warm-start support * Predictor ([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)): Kernel ranking and TFLOPS prediction for problem shapes * Evaluation ([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)): Comprehensive metrics including efficiency, NDCG@k, shape family analysis **Data Generation Tools:** * [generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py): Build and benchmark kernels across diverse problem shapes * [convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py): Convert benchmark JSON to training-ready parquet format * [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py): Parse streaming benchmark logs into canonical datasets **Examples** * [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp): C++ example demonstrating ML-based kernel selection * [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py): Python example with validation **Pre-trained Models (projects/composablekernel/dispatcher/heuristics/models/):** * gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean efficiency) * gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean efficiency) ## Test Plan * Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8 * All shape families tested: tiny M (M<8), small M, medium M, large M (M≥1024) * All pipeline types: compv3, compv4, mem ## Test Result **fp16 Model (gfx950, RCR layout)** * Mean Efficiency: 99.36% * P10 Efficiency: 98.05% (90th percentile of shapes achieve ≥98% of oracle best) * Min Efficiency: 95.45% **fp8 Model (gfx950, RCR layout)** * Mean Efficiency: 98.28% (original), 97.51% (wide coverage) * P10 Efficiency: 94.64% (original), 93.89% (wide coverage) * Min Efficiency: 84.5% ## Submission Checklist - [x ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Vidyasagar Ananthan <vidyasagar.ananthan@amd.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
61 lines
842 B
Plaintext
61 lines
842 B
Plaintext
# Python bytecode and caches
|
|
__pycache__/
|
|
*.pyc
|
|
*.pyo
|
|
*.pyd
|
|
.Python
|
|
|
|
# Jupyter notebooks
|
|
*.ipynb
|
|
.ipynb_checkpoints/
|
|
|
|
# Virtual environments
|
|
.venv/
|
|
venv/
|
|
ENV/
|
|
|
|
# IDE and editor files
|
|
.vscode/
|
|
.idea/
|
|
*.swp
|
|
*.swo
|
|
*~
|
|
|
|
# Test output and logs
|
|
*.log
|
|
test_output.log
|
|
custom_shapes_gpu_test.log
|
|
|
|
# Benchmark and analysis output files
|
|
*.csv
|
|
*.json
|
|
!models/*/feature_spec.json
|
|
!models/*/train_manifest.json
|
|
|
|
# Data files (parquet, arrow)
|
|
*.parquet
|
|
*.arrow
|
|
|
|
# Temporary and NFS files
|
|
.nfs*
|
|
*.tmp
|
|
*.bak
|
|
|
|
# Decompressed model files (compressed .lgbm.gz versions are tracked)
|
|
models/**/*.lgbm
|
|
|
|
# User-specific test and analysis scripts
|
|
test_*.py
|
|
!tests/test_*.py
|
|
find_*.py
|
|
oracle_*.json
|
|
validation_results_*.csv
|
|
custom_shapes_*.csv
|
|
fp16_bf16_*.csv
|
|
|
|
# Ignore all markdown files except tracked documentation
|
|
*.md
|
|
!DATA_GENERATION.md
|
|
!LEARNINGS.md
|
|
!README.md
|