# ML Heuristic Validation Tools
This directory contains validation scripts for testing ML-based kernel selection heuristics.

## Directory Structure

```
validation/
├── README.md                        # This file
├── validate_ml_heuristic.py         # GEMM universal validation
└── grouped_conv/                    # Grouped convolution specific
    ├── validate_training_shapes.py  # Training data sanity check
    └── validate_backward_models.py  # Backward pass prediction quality
```

## Scripts Overview
### 1. `validate_ml_heuristic.py` - GEMM Universal Validation
**Purpose**: Validate the ML heuristic for GEMM universal operations (not grouped conv).

**Usage**:

```bash
python validate_ml_heuristic.py --dtype fp16 --layout rcr
python validate_ml_heuristic.py --dtype bf16 --model_dir models/gemm_universal_bf16_gfx950
```

**What it does**:

- Loads benchmark data (oracle-best results for each GEMM shape)
- Uses the ML model to predict the best kernel for each shape
- Compares the ML selection with the oracle best to compute efficiency
- Outputs mean/median/P10/P90 efficiency statistics

**When to use**: Testing GEMM universal ML models on new training data or architectures.
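The efficiency statistics above can be summarized as follows. This is a minimal sketch of the metric, not the actual script; the function name and data layout are illustrative:

```python
import numpy as np

def efficiency_stats(ml_tflops, oracle_tflops):
    """Per-shape efficiency: TFLOPS of the ML-selected kernel divided by
    TFLOPS of the oracle-best kernel (in percent), then summary statistics."""
    eff = 100.0 * np.asarray(ml_tflops) / np.asarray(oracle_tflops)
    return {
        "mean": float(np.mean(eff)),
        "median": float(np.median(eff)),
        "p10": float(np.percentile(eff, 10)),
        "p90": float(np.percentile(eff, 90)),
    }

# Example: three shapes where the ML pick reached 90%, 95%, and 80%
# of the oracle-best throughput.
stats = efficiency_stats([90.0, 95.0, 80.0], [100.0, 100.0, 100.0])
```

An efficiency of 100% for a shape means the ML heuristic picked the same (or an equally fast) kernel as the oracle.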
---

## Grouped Convolution Validation

### 2. `grouped_conv/validate_training_shapes.py` - Training Data Sanity Check

**Purpose**: Quick sanity check on shapes WITH multiple kernels in training data.

**Usage**:

```bash
cd dispatcher/heuristics/validation/grouped_conv
python validate_training_shapes.py
```
**What it does**:

1. Selects 5 random training shapes with ≥5 kernels each
2. For each shape:
   - Gets the oracle-best kernel from training data
   - Uses the ML model to predict the best kernel
   - Builds BOTH kernels (oracle + ML)
   - Runs both on hardware
   - Compares actual TFLOPS

**Output**:

- Per-shape efficiency (ML vs. oracle on hardware)
- Prediction accuracy (ML-predicted TFLOPS vs. actual)
- Mean efficiency across test shapes

**Runtime**: ~5-10 minutes (builds 10 kernels, runs on hardware)
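The per-shape loop above can be sketched in Python. This is illustrative only: `predict_best` and `measure` stand in for the script's model interface and hardware timing path, which are not shown here:

```python
import random

def mean_efficiency(training_data, predict_best, measure, n_shapes=5, min_kernels=5):
    """Sketch of the validation loop.

    training_data: {shape: {kernel_name: measured_tflops}}
    predict_best:  callable(shape) -> kernel_name chosen by the ML model
    measure:       callable(kernel_name, shape) -> TFLOPS measured on hardware
    """
    # Keep only shapes that were benchmarked with enough kernels.
    eligible = [s for s, kernels in training_data.items() if len(kernels) >= min_kernels]
    effs = []
    for shape in random.sample(eligible, min(n_shapes, len(eligible))):
        kernels = training_data[shape]
        oracle = max(kernels, key=kernels.get)   # best kernel in the training data
        ml = predict_best(shape)                 # ML-selected kernel
        # Both kernels are built and timed on hardware, then compared.
        effs.append(100.0 * measure(ml, shape) / measure(oracle, shape))
    return sum(effs) / len(effs)                 # mean efficiency in percent
```

Measuring the oracle kernel again on hardware (rather than trusting the stored training number) guards against stale or noisy training measurements.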
**When to use**:

- Quick sanity check after model training
- Verify the model isn't overfitting to training data
- Debug prediction accuracy issues
---

### 3. `grouped_conv/validate_backward_models.py` - Backward Pass Prediction Quality

**Purpose**: Quick prediction quality check for the bwd_data and bwd_weight ML models.

**Usage**:

```bash
cd dispatcher/heuristics/validation/grouped_conv
python validate_backward_models.py
```
**What it does**:

1. Loads the bwd_data and bwd_weight ML models
2. Tests on 5-7 hardcoded representative problems
3. For each problem:
   - Predicts TFLOPS for all backward kernels (compv3 and mem pipelines)
   - Shows the top-3 predicted kernels
   - Reports prediction statistics

**Output**:

- Top-3 predicted kernels for each problem
- Average predicted TFLOPS
- Pipeline preference (compv3 vs. mem)
- Prediction confidence (gap between the best and 3rd-best predictions)

**Runtime**: <1 minute (NO hardware - prediction only)
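The prediction-only ranking can be sketched like this. The model interface (`predict_tflops`) and kernel names are assumptions, not the script's actual API:

```python
def rank_kernels(predict_tflops, kernels, problem, top_k=3):
    """Rank candidate kernels by predicted TFLOPS for one problem and
    report a simple confidence measure: the gap between the best and
    the k-th best prediction. No hardware is touched."""
    preds = sorted(((predict_tflops(k, problem), k) for k in kernels), reverse=True)
    top = preds[:top_k]
    confidence_gap = top[0][0] - top[-1][0]   # best minus k-th-best predicted TFLOPS
    return [name for _, name in top], confidence_gap

# Example with hypothetical per-kernel predictions for one problem.
preds = {"compv3_a": 120.0, "compv3_b": 110.0, "mem_a": 90.0, "mem_b": 85.0}
top, gap = rank_kernels(lambda k, p: preds[k], list(preds), problem=None)
```

A large gap suggests the model is confident in its top pick; a small gap means several kernels are predicted to perform similarly, so a mis-ranking is cheap.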
**When to use**:

- Quick check after training the backward models
- Verify model predictions are reasonable
- Debug backward pass heuristic issues

**Note**: This does NOT run on hardware - it only checks prediction quality.

---
## Comparison Matrix

| Script | Operation | Hardware? | Shapes Tested | Runtime | Use Case |
|--------|-----------|-----------|---------------|---------|----------|
| `validate_ml_heuristic.py` | GEMM universal | ✗ | All training | <1 min | GEMM model validation |
| `validate_training_shapes.py` | Grouped conv fwd | ✓ | 5 training | 5-10 min | Quick sanity check |
| `validate_backward_models.py` | Grouped conv bwd | ✗ | 5-7 hardcoded | <1 min | Backward prediction quality |
## Typical Workflow

1. **After training the forward model**:

   ```bash
   # Quick check
   python grouped_conv/validate_training_shapes.py
   ```

2. **After training the backward models**:

   ```bash
   python grouped_conv/validate_backward_models.py
   ```
## Target Metrics
### Forward Pass (Tier-1 Model)

- **Mean efficiency**: >90% (currently 93.05%)
- **P10 efficiency**: >75% (currently 79.21%)
- **Kernel match rate**: >70%

### Backward Pass

- **Mean efficiency**: >85%
- **Prediction accuracy**: >90%
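A model release could be gated on these targets with a simple check. This is a sketch; the metric names are illustrative, not the validation scripts' actual output keys:

```python
# Target floors for the forward-pass (tier-1) model, in percent.
TARGETS = {"mean_efficiency": 90.0, "p10_efficiency": 75.0, "kernel_match_rate": 70.0}

def failing_metrics(metrics, targets=TARGETS):
    """Return the names of metrics that fall below their target floor."""
    return [name for name, floor in targets.items() if metrics.get(name, 0.0) < floor]

# Current forward-pass numbers from this README all clear their floors.
failures = failing_metrics({"mean_efficiency": 93.05,
                            "p10_efficiency": 79.21,
                            "kernel_match_rate": 72.0})
```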
## Dependencies
All scripts require:

- Trained ML models in `../models/`
- Training data in `../data/`
- Python packages: pandas, numpy, LightGBM, matplotlib (for plotting)

Grouped conv hardware validation scripts additionally require:

- GPU hardware (gfx950 by default)
- Compiled kernels or JIT compilation support
- `tile_engine/ops/grouped_conv/` utilities