mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
[CK][CK TILE] Autotuning heuristics infra for universal GEMM kernel selection (#5676)

## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile dispatcher, enabling fast and accurate automatic kernel selection for Universal GEMM kernels. Instead of requiring exhaustive search through 4600+ kernel configurations (taking ~46 seconds per problem shape), the ML heuristic predicts optimal kernels in microseconds while achieving >98% of oracle-best performance.

## Technical Details

**ML infrastructure** https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics

* Feature Engine ([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)): 55-feature extraction including problem dimensions, kernel configuration, tile efficiency, and hardware profile
* Training Pipeline ([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)): LightGBM regression with log-transform, GroupKFold cross-validation, warm-start support
* Predictor ([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)): Kernel ranking and TFLOPS prediction for problem shapes
* Evaluation ([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)): Comprehensive metrics including efficiency, NDCG@k, shape family analysis

**Data Generation Tools:**

* [generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py): Build and benchmark kernels across diverse problem shapes
* [convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py): Convert benchmark JSON to training-ready parquet format
* [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py): Parse streaming benchmark logs into canonical datasets

**Examples**

* [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp): C++ example demonstrating ML-based kernel selection
* [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py): Python example with validation

**Pre-trained Models (projects/composablekernel/dispatcher/heuristics/models/):**

* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean efficiency)

## Test Plan

* Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M (M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**

* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (90% of shapes achieve ≥98% of oracle best)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**

* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
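The efficiency figures in the test results can be read as the selected kernel's throughput divided by the best measured throughput for the same shape. The sketch below shows one way such a summary could be computed. It is not taken from evaluate.py: the function names, the sample numbers, and the nearest-rank percentile rule are all assumptions made for illustration.

```python
def efficiency(selected_tflops, oracle_tflops):
    """Fraction of the oracle-best throughput achieved by the selected kernel."""
    return selected_tflops / oracle_tflops


def percentile_nearest_rank(values, pct):
    """Nearest-rank percentile (an assumed convention, not confirmed by the source)."""
    ordered = sorted(values)
    # Integer ceiling of pct * n / 100, clamped to at least rank 1.
    rank = max(1, (pct * len(ordered) + 99) // 100)
    return ordered[rank - 1]


def summarize(per_shape_efficiency):
    """Mean, P10 and minimum efficiency over a set of problem shapes."""
    return {
        "mean": sum(per_shape_efficiency) / len(per_shape_efficiency),
        "p10": percentile_nearest_rank(per_shape_efficiency, 10),
        "min": min(per_shape_efficiency),
    }


# Hypothetical measurements: (TFLOPS of the predicted kernel, TFLOPS of the best kernel).
measurements = [
    (95.0, 100.0), (100.0, 100.0), (98.0, 100.0), (99.0, 100.0), (100.0, 100.0),
    (97.0, 100.0), (96.0, 100.0), (100.0, 100.0), (99.0, 100.0), (90.0, 100.0),
]
per_shape = [efficiency(sel, best) for sel, best in measurements]
summary = summarize(per_shape)
print(summary)
```

With this reading, a P10 of 0.98 means roughly 90% of shapes reach at least 98% of the oracle-best throughput.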
CK Tile Dispatcher Examples
Comprehensive examples for GEMM operations with GPU execution.
Note: Convolution examples have been moved to ck-2/conv_archive/ for reference.
Quick Start
Step 1: Build
cd /path/to/composable_kernel/dispatcher
mkdir -p build && cd build
cmake .. \
  -DCMAKE_PREFIX_PATH=/opt/rocm \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DCMAKE_BUILD_TYPE=Release \
  -DGPU_TARGETS="gfx942" \
  -DBUILD_DISPATCHER_EXAMPLES=ON
# Build everything (C++ examples + Python libraries)
make -j$(nproc)
# Or build ONLY Python libraries (faster)
make python_libs -j$(nproc)
Step 2: Run C++ Examples
cd /path/to/composable_kernel/dispatcher/build/examples
# GEMM
./gemm_01_basic
./gemm_02_multi_size
./gemm_03_benchmark_validation
./gemm_04_heuristics
./gemm_05_json_export
./gemm_06_multi_registry
Step 3: Run Python Examples
cd /path/to/composable_kernel/dispatcher
# GEMM
python3 examples/gemm/python/01_basic_gemm.py
python3 examples/gemm/python/04_validation.py
python3 examples/gemm/python/07_stress_test.py
python3 examples/gemm/python/08_heuristics.py
Directory Structure
examples/
├── gemm/
│ ├── cpp/ # 6 C++ GEMM examples
│ └── python/ # 11 Python GEMM examples
│
└── README.md
GEMM Examples
C++ Examples
| # | Example | Description |
|---|---|---|
| 01 | gemm_01_basic | Basic GEMM with declarative API, autofill, autocorrect |
| 02 | gemm_02_multi_size | Wildcard expansion for multiple configurations |
| 03 | gemm_03_benchmark_validation | Performance benchmarking with CPU/GPU validation |
| 04 | gemm_04_heuristics | Heuristic-based kernel selection |
| 05 | gemm_05_json_export | Registry JSON export for external tools |
| 06 | gemm_06_multi_registry | Multiple registries with named kernel sets |
Details: gemm/cpp/README.md
Python Examples
| # | Example | Description |
|---|---|---|
| 01 | 01_basic_gemm.py | Basic GEMM with multi-kernel support |
| 02 | 02_batch_gemm.py | Batched GEMM operations |
| 03 | 03_benchmark.py | Performance benchmarking |
| 04 | 04_validation.py | CPU reference validation |
| 05 | 05_numpy_integration.py | NumPy array integration |
| 06 | 06_json_export.py | Registry JSON export |
| 07 | 07_stress_test.py | Multi-kernel stress testing (48 configs) |
| 08 | 08_heuristics.py | Heuristic-based kernel selection (24 configs) |
| 09 | 09_multi_registry.py | Multiple registries |
| 10 | 10_advanced_benchmark.py | Advanced benchmark with full control |
| 11 | 11_json_import.py | Import kernels from JSON |
Details: gemm/python/README.md
Key Features
Declarative Kernel API
Both C++ and Python examples use a declarative approach:
C++ (DECL_KERNEL_SET macro):
DECL_KERNEL_SET(my_kernels,
    .add(
        Signature().dtype("fp16").layout("rcr"),
        Algorithm().tile(256, 256, 32).wave(2, 2, 1).warp(32, 32, 16)
                   .pipeline("compv4").scheduler("intrawave"),
        "gfx942"
    )
);
Python (KernelConfig):
config = KernelConfig(
    tile_m=256, tile_n=256, tile_k=32,
    wave_m=2, wave_n=2, wave_k=1,
    warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
    pipeline="compv4", scheduler="intrawave"
)
Autofill and Autocorrect
The build system automatically:
- Autofills missing parameters with sensible defaults
- Autocorrects invalid parameters based on architecture constraints
- Expands wildcards (*, -1, ANY_INT) to all valid configurations
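The autofill and wildcard steps listed above can be pictured as a dictionary merge followed by a Cartesian product. The following is a minimal sketch of that idea, not the dispatcher's code generator: `DEFAULTS`, `CANDIDATES`, `autofill`, and `expand` are hypothetical names, the candidate values are made up, and treating `ANY_INT` as equal to -1 is an assumption.

```python
from itertools import product

ANY_INT = -1  # assumed sentinel; the document lists *, -1 and ANY_INT as wildcards

# Hypothetical defaults and per-parameter candidate values, for illustration only.
DEFAULTS = {"pipeline": "compv4", "scheduler": "intrawave"}
CANDIDATES = {
    "tile_m": [128, 256],
    "tile_n": [128, 256],
    "tile_k": [32, 64],
}


def is_wildcard(value):
    """True if the value asks for every valid choice of that parameter."""
    return value in ("*", ANY_INT)


def autofill(config):
    """Fill in parameters the caller left out; explicit values win."""
    return {**DEFAULTS, **config}


def expand(config):
    """Replace each wildcard with its candidates and return every combination."""
    keys = list(config)
    choices = [CANDIDATES[k] if is_wildcard(config[k]) else [config[k]] for k in keys]
    return [dict(zip(keys, combo)) for combo in product(*choices)]


configs = expand(autofill({"tile_m": "*", "tile_n": 256, "tile_k": ANY_INT}))
for cfg in configs:
    print(cfg)
```

Here two wildcards with two candidates each produce four concrete configurations, all sharing the autofilled pipeline and scheduler.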
Architecture Filtering
Kernel configurations are validated against GPU architecture constraints:
- Tile divisibility requirements
- Warp tile constraints
- Pipeline compatibility
Invalid configurations are automatically pruned during code generation.
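As a rough illustration of this kind of pruning, the sketch below checks the three constraint families named above. The exact rules are not spelled out in this document, so the checks are assumptions: each tile dimension must be a multiple of (waves x warp tile) in that dimension, the warp tile must appear in a hypothetical supported list, and the pipeline must be one of the names mentioned in this document (compv3, compv4, mem).

```python
# Pipeline names mentioned in this document; the warp-tile list is hypothetical.
SUPPORTED_PIPELINES = {"compv3", "compv4", "mem"}
SUPPORTED_WARP_TILES = {(32, 32, 16), (16, 16, 32)}


def is_valid(cfg):
    """Return True if the configuration passes all assumed constraints."""
    # Tile divisibility: the block tile must be covered evenly by waves * warp tile.
    for dim in ("m", "n", "k"):
        per_block = cfg[f"wave_{dim}"] * cfg[f"warp_tile_{dim}"]
        if cfg[f"tile_{dim}"] % per_block != 0:
            return False
    # Warp tile and pipeline must both be in the supported sets.
    warp_tile = (cfg["warp_tile_m"], cfg["warp_tile_n"], cfg["warp_tile_k"])
    return warp_tile in SUPPORTED_WARP_TILES and cfg["pipeline"] in SUPPORTED_PIPELINES


def prune(configs):
    """Keep only the configurations that pass validation."""
    return [cfg for cfg in configs if is_valid(cfg)]


good = dict(
    tile_m=256, tile_n=256, tile_k=32,
    wave_m=2, wave_n=2, wave_k=1,
    warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
    pipeline="compv4",
)
bad_tile = dict(good, tile_m=200)            # 200 is not a multiple of 2 * 32
bad_pipeline = dict(good, pipeline="unknown")

kept = prune([good, bad_tile, bad_pipeline])
print(len(kept))
```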
Validation Examples
C++ Validation
./gemm_03_benchmark_validation --verify 1 # GEMM with CPU reference
./gemm_03_benchmark_validation --verify 2 # GEMM with GPU reference
Python Validation
python3 examples/gemm/python/04_validation.py
python3 examples/gemm/python/07_stress_test.py # Multi-kernel validation
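Validation against a CPU reference amounts to recomputing the GEMM with a simple loop and comparing within a tolerance. The sketch below is a pure-Python illustration and is not taken from 04_validation.py. It assumes that the "rcr" layout means A is row-major, B is column-major, and C is row-major; the function names and the tolerance values are hypothetical.

```python
def reference_gemm_rcr(a, b, m, n, k):
    """Naive C = A x B on flat buffers.

    Assumed layout: A is row-major (m x k), B is column-major (k x n),
    C is row-major (m x n).
    """
    c = [0.0] * (m * n)
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                # A[i][p] at i*k + p; B[p][j] at j*k + p for column-major B.
                acc += a[i * k + p] * b[j * k + p]
            c[i * n + j] = acc
    return c


def allclose(actual, expected, rtol=1e-3, atol=1e-3):
    """Element-wise check with illustrative tolerances for low-precision types."""
    return len(actual) == len(expected) and all(
        abs(x - y) <= atol + rtol * abs(y) for x, y in zip(actual, expected)
    )


m, n, k = 2, 2, 3
a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]   # rows of A: [1, 2, 3] and [4, 5, 6]
b = [1.0, 0.0, 1.0, 0.0, 1.0, 1.0]   # columns of B: [1, 0, 1] and [0, 1, 1]
expected = reference_gemm_rcr(a, b, m, n, k)

# Stand-in for a GPU result that carries a small rounding error.
gpu_result = [value + 1e-4 for value in expected]
print(expected, allclose(gpu_result, expected))
```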
Troubleshooting
Python: Library not found
# Run from dispatcher directory
cd /path/to/composable_kernel/dispatcher
python3 examples/gemm/python/01_basic_gemm.py
C++: Executables not found
# Build with examples enabled
cmake .. -DBUILD_DISPATCHER_EXAMPLES=ON
make -j$(nproc)
# Run from build/examples under the dispatcher directory
cd /path/to/composable_kernel/dispatcher/build/examples
./gemm_01_basic
GPU not detected
rocminfo | grep "Name:"
# Should show: gfx942, gfx90a, etc.
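The same check can be scripted, for example to fail early before building. The sketch below shells out to `rocminfo` and collects gfx names from its `Name:` lines. The helper names are hypothetical, and `SAMPLE` is hand-written text that only imitates the shape of the output, not captured from a real machine.

```python
import re
import shutil
import subprocess


def parse_gfx_targets(rocminfo_output):
    """Return unique gfx names found on 'Name:' lines, in order of appearance."""
    targets = []
    for line in rocminfo_output.splitlines():
        match = re.match(r"\s*Name:\s+(gfx[0-9a-z]+)\s*$", line)
        if match and match.group(1) not in targets:
            targets.append(match.group(1))
    return targets


def detect_gfx_targets():
    """Run rocminfo if it is on PATH; return an empty list otherwise."""
    if shutil.which("rocminfo") is None:
        return []
    result = subprocess.run(["rocminfo"], capture_output=True, text=True, check=False)
    return parse_gfx_targets(result.stdout)


# Hand-written sample, not real output.
SAMPLE = """\
  Name:                    Example CPU
  Marketing Name:          Example CPU
  Name:                    gfx942
  Name:                    gfx942
"""

print(parse_gfx_targets(SAMPLE))
```

An empty list from `detect_gfx_targets()` would suggest either that `rocminfo` is missing from PATH or that no GPU agent was reported.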
Archived Examples
Convolution examples have been archived to ck-2/conv_archive/dispatcher/:
- examples/conv/cpp/ - 11 C++ convolution examples
- examples/conv/python/ - 14 Python convolution examples
See the archive for convolution functionality reference.