Yaswanth Raparti 6989cf800c [rocm-libraries] ROCm/rocm-libraries#6327 (commit 1e7a12e)
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation
The ML heuristic in the dispatcher does not yet support the grouped-conv
operator. This PR adds support for fwd, bwd-data, and bwd-weight
grouped-conv kernels. It also adds a tile_engine utility that compiles
and runs any selected kernel configuration through the dispatcher
infrastructure.

## Technical Details

1. A tile engine utility is added to benchmark each shape with all
possible kernel + tile_size combinations:
[grouped_conv_full_benchmark.py](https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py)
2. New LGBM regressor models for grouped conv are added to the models
directory, with 3 separate models for fwd, bwd-data, and bwd-weight:
[dispatcher/heuristics/models](https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models)
3. Implemented lazy GPU initialization (dispatcher/python)
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory
     access faults
   - **Solution**: Mirror the FMHA pattern and defer GPU initialization
     until the first run()
   - **Changes**:
     - setup_multiple_grouped_conv_dispatchers() returns List[Path], not
       loaded libs
     - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
     - Added _ensure_initialized() method for lazy GPU loading
     - GPU context created only on first run() call
   - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to
     expand the kernel selection
   - Added automatic padding support for unsupported shapes in the
     dispatcher runner
   - Created a single source of truth for architecture and tile_size
     details, shared by tile_engine and the dispatcher
   - Built validation scripts that compare oracle_best vs. ml_heuristic
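
The lazy-initialization change in item 3 can be sketched roughly as follows. This is a minimal illustration that assumes a runner stores only a library path until its first `run()`. `LazyKernelRunner`, the exported symbol `run_kernel`, and the injectable `loader` parameter are hypothetical names, not the actual `GpuGroupedConvRunner` code. Because `__init__()` loads nothing, worker processes forked by `ProcessPoolExecutor` during parallel compilation do not inherit a GPU context created by the runner.

```python
import ctypes
from pathlib import Path
from typing import Callable, Optional


class LazyKernelRunner:
    """Hypothetical runner that defers loading the kernel library.

    Only the library path is stored at construction time, so creating the
    runner before fork() is safe; the shared library (and any GPU context it
    creates) is loaded on the first run() call.
    """

    def __init__(self, lib_path: Path, loader: Callable[[str], object] = ctypes.CDLL):
        self.lib_path = Path(lib_path)
        self._loader = loader  # ctypes.CDLL by default; injectable for testing
        self._lib: Optional[object] = None

    def _ensure_initialized(self) -> object:
        # Load the library exactly once, on first use.
        if self._lib is None:
            self._lib = self._loader(str(self.lib_path))
        return self._lib

    def run(self, *args):
        lib = self._ensure_initialized()
        # `run_kernel` is a hypothetical exported symbol name.
        return lib.run_kernel(*args)
```

In this sketch, a parallel build step would return plain `Path` objects, in the same way that `setup_multiple_grouped_conv_dispatchers()` returns `List[Path]`. Each runner is then constructed from a path and touches the GPU only when `run()` is first called.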

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels on both known and
unseen datasets of up to 300 problems.
2. Added test cases to both the dispatcher and tile_engine to validate
the heuristic.

## Test Result
Results on unseen shapes, validated on gfx950:
#### Forward Pass Model
- **Training Data**: 48,845 measurements across 1,372 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model
- **Training Data**: 18,773 measurements across 891 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model
- **Training Data**: 34,900 measurements across 1,508 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-05-08 20:48:42 +00:00


# ML Heuristic Validation Tools

This directory contains validation scripts for testing ML-based kernel selection heuristics.

## Directory Structure

```
validation/
├── README.md                        # This file
├── validate_ml_heuristic.py         # GEMM universal validation
└── grouped_conv/                    # Grouped convolution specific
    ├── validate_training_shapes.py  # Training data sanity check
    └── validate_backward_models.py  # Backward pass prediction quality
```

## Scripts Overview

### 1. `validate_ml_heuristic.py` - GEMM Universal Validation

**Purpose**: Validate the ML heuristic for GEMM universal operations (not grouped conv).

**Usage**:

```bash
python validate_ml_heuristic.py --dtype fp16 --layout rcr
python validate_ml_heuristic.py --dtype bf16 --model_dir models/gemm_universal_bf16_gfx950
```

**What it does**:

- Loads benchmark data (oracle-best results for each GEMM shape)
- Uses the ML model to predict the best kernel for each shape
- Compares the ML selection with the oracle-best to compute efficiency
- Outputs mean/median/P10/P90 efficiency statistics
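
Here is a minimal sketch of the efficiency statistics. It assumes per-shape efficiency is the measured TFLOPS of the ML-selected kernel divided by the oracle-best TFLOPS. The function names and the linear-interpolation percentile (the same rule as NumPy's default `np.percentile`) are illustrative choices, not necessarily what `validate_ml_heuristic.py` does internally.

```python
from statistics import mean, median
from typing import Dict, List


def percentile(values: List[float], p: float) -> float:
    """Linear-interpolation percentile (same rule as NumPy's default)."""
    xs = sorted(values)
    k = (len(xs) - 1) * p / 100.0
    lo = int(k)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (k - lo)


def efficiency_stats(ml_tflops: List[float], oracle_tflops: List[float]) -> Dict[str, float]:
    """Per-shape efficiency = ML-selected TFLOPS / oracle-best TFLOPS (assumed definition)."""
    effs = [ml / best for ml, best in zip(ml_tflops, oracle_tflops)]
    return {
        "mean": mean(effs),
        "median": median(effs),
        "p10": percentile(effs, 10),
        "p90": percentile(effs, 90),
    }
```

For example, `efficiency_stats([90.0, 100.0, 80.0], [100.0, 100.0, 100.0])` reports a median efficiency of 0.9.
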

**When to use**: Testing GEMM universal ML models on new training data or architectures.

---

## Grouped Convolution Validation

### 2. `grouped_conv/validate_training_shapes.py` - Training Data Sanity Check

**Purpose**: Quick sanity check on shapes WITH multiple kernels in the training data.

**Usage**:

```bash
cd dispatcher/heuristics/validation/grouped_conv
python validate_training_shapes.py
```

**What it does**:

1. Selects 5 random training shapes with ≥5 kernels each
2. For each shape:
   - Gets the oracle-best kernel from the training data
   - Uses the ML model to predict the best kernel
   - Builds BOTH kernels (oracle + ML)
   - Runs both on hardware
   - Compares actual TFLOPS
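
Step 1 and the oracle lookup in step 2 can be sketched as below. The sketch assumes the training data is available as `(shape_key, kernel_name, tflops)` records. The record layout, `select_test_shapes`, and the fixed seed are hypothetical and may differ from what `validate_training_shapes.py` actually reads.

```python
import random
from collections import defaultdict
from typing import Dict, Iterable, Tuple

Record = Tuple[str, str, float]  # (shape_key, kernel_name, measured TFLOPS)


def select_test_shapes(records: Iterable[Record], n_shapes: int = 5,
                       min_kernels: int = 5, seed: int = 0) -> Dict[str, Tuple[str, float]]:
    """Pick random shapes with at least `min_kernels` benchmarked kernels.

    Returns {shape_key: (oracle_best_kernel, oracle_best_tflops)}.
    """
    by_shape: Dict[str, Dict[str, float]] = defaultdict(dict)
    for shape, kernel, tflops in records:
        # Keep the best measurement if a kernel appears more than once.
        by_shape[shape][kernel] = max(tflops, by_shape[shape].get(kernel, float("-inf")))

    # Sort before sampling so the selection is reproducible for a given seed.
    eligible = sorted(s for s, kernels in by_shape.items() if len(kernels) >= min_kernels)
    chosen = random.Random(seed).sample(eligible, min(n_shapes, len(eligible)))

    result = {}
    for shape in chosen:
        best_kernel = max(by_shape[shape], key=by_shape[shape].get)
        result[shape] = (best_kernel, by_shape[shape][best_kernel])
    return result
```
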

**Output**:

- Per-shape efficiency (ML vs. oracle on hardware)
- Prediction accuracy (ML-predicted TFLOPS vs. actual)
- Mean efficiency across test shapes

**Runtime**: ~5-10 minutes (builds 10 kernels, runs on hardware)

**When to use**:

- Quick sanity check after model training
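- Check that the model reproduces oracle-best choices on shapes it was trained on (training shapes alone cannot reveal overfitting; that requires unseen shapes)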
- Debug prediction accuracy issues
---

### 3. `grouped_conv/validate_backward_models.py` - Backward Pass Prediction Quality

**Purpose**: Quick prediction quality check for the bwd_data and bwd_weight ML models.

**Usage**:

```bash
cd dispatcher/heuristics/validation/grouped_conv
python validate_backward_models.py
```

**What it does**:

1. Loads the bwd_data and bwd_weight ML models
2. Tests on 5-7 hardcoded representative problems
3. For each problem:
   - Predicts TFLOPS for all backward kernels (compv3 and mem pipelines)
   - Shows the top-3 predicted kernels
   - Reports prediction statistics

**Output**:

- Top-3 predicted kernels for each problem
- Average predicted TFLOPS
- Pipeline preference (compv3 vs. mem)
- Prediction confidence (gap in predicted TFLOPS between the best and 3rd-best kernel)
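
The top-3 ranking, average, pipeline preference, and confidence gap listed above can be sketched as follows. The sketch makes three assumptions: the model's predictions for one problem are available as a `{kernel_name: predicted_tflops}` mapping, the pipeline name appears as a substring of the kernel name, and the average is taken over all kernels. All three assumptions, the `summarize_predictions` helper, and the kernel names in the example are hypothetical.

```python
from typing import Dict


def summarize_predictions(preds: Dict[str, float], top_k: int = 3) -> dict:
    """Rank kernels by predicted TFLOPS and report simple confidence statistics."""
    ranked = sorted(preds.items(), key=lambda kv: kv[1], reverse=True)
    top = ranked[:top_k]
    best_name, best_tflops = top[0]
    # Confidence: gap between the best and the k-th best prediction.
    gap = best_tflops - top[-1][1]
    # Pipeline preference: pipeline of the top-ranked kernel (assumed naming scheme).
    if "compv3" in best_name:
        pipeline = "compv3"
    elif "mem" in best_name:
        pipeline = "mem"
    else:
        pipeline = "other"
    return {
        "top": [name for name, _ in top],
        "avg_pred_tflops": sum(preds.values()) / len(preds),
        "confidence_gap": gap,
        "preferred_pipeline": pipeline,
    }
```

A large gap suggests the model clearly prefers one kernel. A gap near zero means the top candidates are predicted to perform similarly, so a wrong pick should cost little.
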

**Runtime**: <1 minute (NO hardware, prediction only)

**When to use**:

- Quick check after training backward models
- Verify model predictions are reasonable
- Debug backward pass heuristic issues

**Note**: This does NOT run on hardware; it only checks prediction quality.

---
## Comparison Matrix
| Script | Operation | Hardware? | Shapes Tested | Runtime | Use Case |
|--------|-----------|-----------|---------------|---------|----------|
| `validate_ml_heuristic.py` | GEMM universal | ✗ | All training | <1 min | GEMM model validation |
| `validate_training_shapes.py` | Grouped conv fwd | ✓ | 5 training | 5-10 min | Quick sanity check |
| `validate_backward_models.py` | Grouped conv bwd | ✗ | 5-7 hardcoded | <1 min | Backward prediction quality |

## Typical Workflow

1. **After training the forward model**:

   ```bash
   # Quick check
   python grouped_conv/validate_training_shapes.py
   ```

2. **After training the backward models**:

   ```bash
   python grouped_conv/validate_backward_models.py
   ```
## Target Metrics
### Forward Pass (Tier-1 Model)
- **Mean efficiency**: >90% (currently 93.05%)
- **P10 efficiency**: >75% (currently 79.21%)
- **Kernel match rate**: >70%
### Backward Pass
- **Mean efficiency**: >85%
- **Prediction accuracy**: >90%
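
Here is a minimal sketch of checking the forward-pass targets. It assumes "kernel match rate" means the fraction of shapes where the ML-selected kernel is exactly the oracle-best kernel, and it takes metrics as fractions (0.93 for 93%). The threshold table and the function names are illustrative, not part of the existing scripts.

```python
from typing import Dict, List

# Forward-pass targets from the list above, expressed as fractions.
FORWARD_TARGETS = {"mean": 0.90, "p10": 0.75, "match_rate": 0.70}


def kernel_match_rate(ml_kernels: List[str], oracle_kernels: List[str]) -> float:
    """Fraction of shapes where the ML pick equals the oracle-best kernel (assumed definition)."""
    matches = sum(ml == oracle for ml, oracle in zip(ml_kernels, oracle_kernels))
    return matches / len(oracle_kernels)


def check_forward_targets(metrics: Dict[str, float]) -> Dict[str, bool]:
    """Return pass/fail per metric; each target is a strict lower bound ('>')."""
    return {name: metrics[name] > threshold for name, threshold in FORWARD_TARGETS.items()}
```
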
## Dependencies
All scripts require:
- Trained ML models in `../models/`
- Training data in `../data/`
- Python packages: pandas, numpy, LightGBM, matplotlib (for plotting)
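
Before running the scripts, you can confirm that the Python packages above are importable with a quick check like the one below. It checks only the import names (`pandas`, `numpy`, `lightgbm`, `matplotlib`) using `importlib.util.find_spec`. It does not verify package versions, trained models, or training data.

```python
import importlib.util
from typing import Iterable, List

# Import names of the packages listed above (LightGBM imports as `lightgbm`).
REQUIRED = ("pandas", "numpy", "lightgbm", "matplotlib")


def missing_packages(names: Iterable[str] = REQUIRED) -> List[str]:
    """Return the subset of `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("missing:", ", ".join(missing) if missing else "none")
```
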

Grouped conv hardware validation scripts additionally require:
- GPU hardware (gfx950 default)
- Compiled kernels or JIT compilation support
- `tile_engine/ops/grouped_conv/` utilities