[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation

The ML heuristic in the dispatcher does not yet support the grouped-conv operator. This PR adds support for fwd, bwd-data, and bwd-weight grouped-conv kernels. It also adds a tile_engine utility that compiles and runs any selected kernel configuration through the dispatcher infrastructure.

## Technical Details

1. Added a tile engine utility that benchmarks each shape with all possible kernel + tile_size combinations: https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. Added new LGBM regressor models for grouped conv to the models directory, with 3 separate models for fwd, bwd-data, and bwd-weight: https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python)
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults
   - **Solution**: Mirror FMHA pattern - defer GPU initialization until first run()
   - **Changes**:
     - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs
     - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
     - Added _ensure_initialized() method for lazy GPU loading
     - GPU context created only on first run() call
   - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
   - Fixed BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to expand the kernel selection
   - Added automatic padding support for unsupported shapes in the dispatcher runner
   - Created a single source of truth shared by tile_engine and dispatcher for architecture and tile_size details
   - Built validation scripts to compare oracle_best against ml_heuristic

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets of up to 300 problems.
2. Ensured that test cases are added in both dispatcher and tile_engine to validate the heuristic.

## Test Result

Results on unseen shapes, validated on gfx950.

#### Forward Pass Model
- **Training Data**: 48,845 measurements across 1,372 unique problem shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model
- **Training Data**: 18,773 measurements across 891 unique problem shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model
- **Training Data**: 34,900 measurements across 1,508 unique problem shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
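The lazy GPU initialization described in item 3 of the Technical Details can be illustrated with a minimal sketch. This is not the actual GpuGroupedConvRunner code: the class name `LazyLibraryRunner`, the injectable `loader` parameter, and the file name `libfake_conv.so` are hypothetical and exist only so the pattern can be shown without a GPU. The idea is that the constructor stores the library path only, and the shared library is loaded on the first `run()` call, so worker processes forked for compilation never inherit a GPU context.

```python
import ctypes
from pathlib import Path


class LazyLibraryRunner:
    """Hypothetical runner that defers loading its shared library."""

    def __init__(self, lib_path, loader=ctypes.CDLL):
        # Store the path only; do not touch the library (or the GPU) here.
        self.lib_path = Path(lib_path)
        self._loader = loader  # injectable so the sketch runs without a GPU
        self._lib = None

    def _ensure_initialized(self):
        # Load the library once, on first use.
        if self._lib is None:
            self._lib = self._loader(str(self.lib_path))
        return self._lib

    def run(self):
        # A real runner would call into the library here; the sketch
        # just returns the handle to show that loading happened.
        return self._ensure_initialized()


# Demonstration with a fake loader that records each load request.
load_calls = []


def fake_loader(path):
    load_calls.append(path)
    return object()


runner = LazyLibraryRunner("libfake_conv.so", loader=fake_loader)
calls_before_run = list(load_calls)  # nothing loaded at construction
first_handle = runner.run()
second_handle = runner.run()  # reuses the handle, no second load
```

Because construction is side-effect free, a list of such runners (or just their paths) can be built in the parent process and handed to a process pool without creating a GPU context before the fork.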
ML Heuristic Validation Tools
This directory contains validation scripts for testing ML-based kernel selection heuristics.
Directory Structure
validation/
├── README.md                        # This file
├── validate_ml_heuristic.py         # GEMM universal validation
└── grouped_conv/                    # Grouped convolution specific
    ├── validate_training_shapes.py  # Training data sanity check
    └── validate_backward_models.py  # Backward pass prediction quality
Scripts Overview
1. validate_ml_heuristic.py - GEMM Universal Validation
Purpose: Validate ML heuristic for GEMM universal operations (not grouped conv).
Usage:
python validate_ml_heuristic.py --dtype fp16 --layout rcr
python validate_ml_heuristic.py --dtype bf16 --model_dir models/gemm_universal_bf16_gfx950
What it does:
- Loads benchmark data (oracle-best results for each GEMM shape)
- Uses ML model to predict best kernel for each shape
- Compares ML selection with oracle-best to compute efficiency
- Outputs mean/median/P10/P90 efficiency statistics
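As a rough illustration of the efficiency statistics listed above, the sketch below assumes that efficiency is the TFLOPS of the ML-selected kernel divided by the TFLOPS of the oracle-best kernel for the same shape, expressed as a percentage. The function names and the sample data are hypothetical, and the percentile uses plain linear interpolation, which may differ from what the script does.

```python
from statistics import mean, median


def percentile(values, q):
    """Percentile with linear interpolation; q is in [0, 100]."""
    ordered = sorted(values)
    pos = (len(ordered) - 1) * q / 100.0
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


def efficiency_stats(oracle_tflops, selected_tflops):
    """Per-shape efficiency (%) of the selected kernel vs. the oracle-best."""
    eff = [100.0 * selected_tflops[s] / oracle_tflops[s] for s in oracle_tflops]
    return {
        "mean": mean(eff),
        "median": median(eff),
        "p10": percentile(eff, 10),
        "p90": percentile(eff, 90),
    }


# Hypothetical measurements: shape -> TFLOPS.
oracle = {"shape_a": 100.0, "shape_b": 200.0, "shape_c": 50.0}
selected = {"shape_a": 100.0, "shape_b": 180.0, "shape_c": 40.0}
stats = efficiency_stats(oracle, selected)
```

P10 is the value that the worst 10% of shapes fall below, so it is a more sensitive indicator of bad selections than the mean.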
When to use: Testing GEMM universal ML models on new training data or architectures.
Grouped Convolution Validation
2. grouped_conv/validate_training_shapes.py - Training Data Sanity Check
Purpose: Quick sanity check on shapes WITH multiple kernels in training data.
Usage:
cd dispatcher/heuristics/validation/grouped_conv
python validate_training_shapes.py
What it does:
- Selects 5 random training shapes with ≥5 kernels each
- For each shape:
  - Gets oracle-best from training data
  - Uses ML to predict best kernel
  - Builds BOTH kernels (oracle + ML)
  - Runs both on hardware
  - Compares actual TFLOPS
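The shape selection and oracle lookup steps above can be sketched as follows. This is a simplified illustration, not the script's code: it assumes the training data can be read as (shape, kernel, TFLOPS) rows, and the function names, the fixed seed, and the sample rows are all hypothetical.

```python
import random
from collections import defaultdict


def oracle_best_per_shape(measurements, min_kernels=5):
    """Map each shape with enough distinct kernels to its best (kernel, tflops)."""
    by_shape = defaultdict(dict)
    for shape, kernel, tflops in measurements:
        # Keep the best measurement if a kernel appears more than once.
        previous = by_shape[shape].get(kernel, 0.0)
        by_shape[shape][kernel] = max(tflops, previous)
    return {
        shape: max(kernels.items(), key=lambda kv: kv[1])
        for shape, kernels in by_shape.items()
        if len(kernels) >= min_kernels
    }


def pick_shapes(oracle, count=5, seed=0):
    """Randomly pick up to `count` shapes; seeded so runs are repeatable."""
    rng = random.Random(seed)
    return rng.sample(sorted(oracle), min(count, len(oracle)))


# Hypothetical rows: "s1" has 5 kernels, "s2" has only 3 and is filtered out.
rows = [("s1", f"k{i}", 10.0 + i) for i in range(5)]
rows += [("s2", f"k{i}", 20.0 + i) for i in range(3)]
oracle = oracle_best_per_shape(rows)
chosen = pick_shapes(oracle)
```

Requiring several kernels per shape matters because a shape with a single measured kernel has a trivial oracle and says nothing about selection quality.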
Output:
- Per-shape efficiency (ML vs Oracle on hardware)
- Prediction accuracy (ML predicted TFLOPS vs actual)
- Mean efficiency across test shapes
Runtime: ~5-10 minutes (builds 10 kernels, runs on hardware)
When to use:
- Quick sanity check after model training
- Verify model isn't overfitting to training data
- Debug prediction accuracy issues
3. grouped_conv/validate_backward_models.py - Backward Pass Prediction Quality
Purpose: Quick prediction quality check for bwd_data and bwd_weight ML models.
Usage:
cd dispatcher/heuristics/validation/grouped_conv
python validate_backward_models.py
What it does:
- Loads bwd_data and bwd_weight ML models
- Tests on 5-7 hardcoded representative problems
- For each problem:
  - Predicts TFLOPS for all backward kernels (compv3, mem pipelines)
  - Shows top-3 predicted kernels
- Reports prediction statistics
Output:
- Top-3 predicted kernels for each problem
- Average predicted TFLOPS
- Pipeline preference (compv3 vs mem)
- Prediction confidence (gap between best and 3rd)
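The top-3 ranking and the confidence gap from the output list can be sketched as below. The README does not say whether the gap is absolute or relative, so this sketch assumes a relative gap in percent of the best prediction. The function name, kernel names, and predicted values are hypothetical.

```python
def top_k_with_gap(predicted_tflops, k=3):
    """Return the k best (kernel, tflops) pairs and the best-to-kth gap in %."""
    ranked = sorted(
        predicted_tflops.items(), key=lambda kv: kv[1], reverse=True
    )[:k]
    best = ranked[0][1]
    kth = ranked[-1][1]
    # Assumption: gap is relative to the best prediction.
    gap_pct = 100.0 * (best - kth) / best
    return ranked, gap_pct


# Hypothetical model predictions for one problem: kernel name -> TFLOPS.
predictions = {
    "compv3_tile_a": 120.0,
    "mem_tile_a": 90.0,
    "compv3_tile_b": 108.0,
    "mem_tile_b": 60.0,
}
top3, gap = top_k_with_gap(predictions)
```

A small gap means the model sees the top candidates as nearly equivalent, so a wrong pick among them costs little. A large gap means the model is strongly committed to its first choice.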
Runtime: <1 minute (NO hardware - prediction only)
When to use:
- Quick check after training backward models
- Verify model predictions are reasonable
- Debug backward pass heuristic issues
Note: This does NOT run on hardware - it only checks prediction quality.
Comparison Matrix
| Script | Operation | Hardware? | Shapes Tested | Runtime | Use Case |
|---|---|---|---|---|---|
| validate_ml_heuristic.py | GEMM universal | ✗ | All training | <1 min | GEMM model validation |
| validate_training_shapes.py | Grouped conv fwd | ✓ | 5 training | 5-10 min | Quick sanity check |
| validate_backward_models.py | Grouped conv bwd | ✗ | 5-7 hardcoded | <1 min | Backward prediction quality |
Typical Workflow
- After training forward model:
  # Quick check
  python grouped_conv/validate_training_shapes.py
- After training backward models:
  python grouped_conv/validate_backward_models.py
Target Metrics
Forward Pass (Tier-1 Model)
- Mean efficiency: >90% (currently 93.05%)
- P10 efficiency: >75% (currently 79.21%)
- Kernel match rate: >70%
Backward Pass
- Mean efficiency: >85%
- Prediction accuracy: >90%
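One way to check measured results against the forward-pass targets is sketched below. It assumes that kernel match rate is the share of shapes where the ML-selected kernel is exactly the oracle-best kernel; that definition, the function names, and the sample kernel names are hypothetical. The mean and P10 values are the "currently" figures quoted above, and the match rate is made-up sample data.

```python
# Forward-pass thresholds from the targets above (all in percent).
FORWARD_TARGETS = {
    "mean_efficiency": 90.0,
    "p10_efficiency": 75.0,
    "kernel_match_rate": 70.0,
}


def kernel_match_rate(oracle_kernels, predicted_kernels):
    """Percentage of shapes where the prediction equals the oracle-best kernel."""
    matches = sum(
        1
        for shape, kernel in oracle_kernels.items()
        if predicted_kernels.get(shape) == kernel
    )
    return 100.0 * matches / len(oracle_kernels)


def check_targets(metrics, targets):
    """Return metric name -> True if the metric exceeds its threshold."""
    return {name: metrics[name] > limit for name, limit in targets.items()}


# Hypothetical selections for four shapes; three of them match.
oracle_kernels = {"s1": "k_a", "s2": "k_b", "s3": "k_c", "s4": "k_d"}
predicted_kernels = {"s1": "k_a", "s2": "k_b", "s3": "k_x", "s4": "k_d"}

metrics = {
    "mean_efficiency": 93.05,
    "p10_efficiency": 79.21,
    "kernel_match_rate": kernel_match_rate(oracle_kernels, predicted_kernels),
}
report = check_targets(metrics, FORWARD_TARGETS)
```

Match rate can sit well below efficiency because several kernels often perform almost identically on a given shape, so picking a different kernel is not always a meaningful loss.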
Dependencies
All scripts require:
- Trained ML models in ../models/
- Training data in ../data/
- Python packages: pandas, numpy, LightGBM, matplotlib (for plotting)
Grouped conv hardware validation scripts additionally require:
- GPU hardware (gfx950 default)
- Compiled kernels or JIT compilation support
tile_engine/ops/grouped_conv/utilities