Yaswanth Raparti 6989cf800c [rocm-libraries] ROCm/rocm-libraries#6327 (commit 1e7a12e)
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation
The ML heuristic in the dispatcher does not yet support the grouped-conv
operator. This PR adds support for fwd, bwd-data, and bwd-weight
grouped-conv kernels. A tile_engine utility has also been added to
compile and run any selected kernel configuration through the
dispatcher infrastructure.

## Technical Details

1. A tile_engine utility is added to benchmark each shape with all
possible kernel+tile_size combinations:
https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv are added to the models
directory, with 3 separate models for fwd, bwd-data, and bwd-weight:
https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python)
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory
     access faults
   - **Solution**: Mirror the FMHA pattern - defer GPU initialization
     until the first run()
   - **Changes**:
     - setup_multiple_grouped_conv_dispatchers() returns List[Path], not
       loaded libs
     - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
     - Added _ensure_initialized() method for lazy GPU loading
     - GPU context created only on first run() call
   - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to
     expand kernel selection
   - Added automatic padding support for unsupported shapes in the
     dispatcher runner
   - Created a single source of truth between tile_engine and the
     dispatcher for architecture and tile_size details
   - Built validation scripts to compare oracle_best vs ml_heuristic
     selections
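
The lazy-initialization change in item 3 can be sketched as below. The class and method names mirror the PR description, but the body is an illustrative stub, not the actual implementation; a real runner would call `ctypes.CDLL` on the stored path inside `_load_library()`.

```python
from pathlib import Path

class GpuGroupedConvRunner:
    """Sketch of the lazy GPU initialization pattern; loading is stubbed."""

    def __init__(self, lib_path: Path):
        # Only record the path: no ctypes.CDLL and no GPU context here,
        # so constructing the runner is safe before ProcessPoolExecutor
        # fork()s worker processes.
        self._lib_path = lib_path
        self._lib = None

    def _ensure_initialized(self):
        # Defer loading the shared library (and creating the GPU context)
        # until the runner is actually used.
        if self._lib is None:
            self._lib = self._load_library(self._lib_path)

    def _load_library(self, path: Path):
        # Stand-in for ctypes.CDLL(str(path)); stubbed for illustration.
        return ("loaded", str(path))

    def run(self):
        self._ensure_initialized()  # GPU context created only on first run()
        return self._lib
```

With this shape, worker processes can be spawned with fully constructed runners and only pay the GPU-initialization cost in the process that actually calls run().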

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels on both known and
unseen data sets of up to 300 problems.
2. Ensured that test cases are added in both dispatcher and tile_engine
to validate the heuristic.

## Test Results
Results on unseen shapes, validated on gfx950.
#### Forward Pass Model
- **Training Data**: 48,845 measurements across 1,372 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model
- **Training Data**: 18,773 measurements across 891 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model
- **Training Data**: 34,900 measurements across 1,508 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-05-08 20:48:42 +00:00

ML Heuristic Validation Tools

This directory contains validation scripts for testing ML-based kernel selection heuristics.

Directory Structure

validation/
├── README.md                          # This file
├── validate_ml_heuristic.py           # GEMM universal validation
└── grouped_conv/                      # Grouped convolution specific
    ├── validate_training_shapes.py    # Training data sanity check
    └── validate_backward_models.py    # Backward pass prediction quality

Scripts Overview

1. validate_ml_heuristic.py - GEMM Universal Validation

Purpose: Validate ML heuristic for GEMM universal operations (not grouped conv).

Usage:

python validate_ml_heuristic.py --dtype fp16 --layout rcr
python validate_ml_heuristic.py --dtype bf16 --model_dir models/gemm_universal_bf16_gfx950

What it does:

  • Loads benchmark data (oracle-best results for each GEMM shape)
  • Uses ML model to predict best kernel for each shape
  • Compares ML selection with oracle-best to compute efficiency
  • Outputs mean/median/P10/P90 efficiency statistics

When to use: Testing GEMM universal ML models on new training data or architectures.
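
The efficiency statistics this script reports can be computed as in the following sketch, where per-shape efficiency is the ML-selected kernel's TFLOPS divided by the oracle-best TFLOPS. The function and argument names are illustrative, not the script's actual API.

```python
import numpy as np

def efficiency_stats(ml_tflops, oracle_tflops):
    # Per-shape efficiency as a percentage of the oracle-best kernel.
    eff = (100.0 * np.asarray(ml_tflops, dtype=float)
           / np.asarray(oracle_tflops, dtype=float))
    return {
        "mean": float(np.mean(eff)),
        "median": float(np.median(eff)),
        "p10": float(np.percentile(eff, 10)),
        "p90": float(np.percentile(eff, 90)),
    }
```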


Grouped Convolution Validation

2. grouped_conv/validate_training_shapes.py - Training Data Sanity Check

Purpose: Quick sanity check on shapes WITH multiple kernels in training data.

Usage:

cd dispatcher/heuristics/validation/grouped_conv
python validate_training_shapes.py

What it does:

  1. Selects 5 random training shapes with ≥5 kernels each
  2. For each shape:
    • Gets oracle-best from training data
    • Uses ML to predict best kernel
    • Builds BOTH kernels (oracle + ML)
    • Runs both on hardware
    • Compares actual TFLOPS

Output:

  • Per-shape efficiency (ML vs Oracle on hardware)
  • Prediction accuracy (ML predicted TFLOPS vs actual)
  • Mean efficiency across test shapes

Runtime: ~5-10 minutes (builds 10 kernels, runs on hardware)

When to use:

  • Quick sanity check after model training
  • Verify model isn't overfitting to training data
  • Debug prediction accuracy issues
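
Step 1 of this script (selecting random training shapes with enough kernel coverage) might look roughly like the sketch below; the DataFrame column names are assumptions, not the script's actual schema.

```python
import pandas as pd

def pick_validation_shapes(df, n_shapes=5, min_kernels=5, seed=0):
    # Count distinct kernels measured per shape, keep well-covered shapes,
    # then sample a few of them for the hardware sanity check.
    counts = df.groupby("shape_id")["kernel_id"].nunique()
    eligible = counts[counts >= min_kernels].index.to_series()
    return eligible.sample(n=min(n_shapes, len(eligible)),
                           random_state=seed).tolist()
```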

3. grouped_conv/validate_backward_models.py - Backward Pass Prediction Quality

Purpose: Quick prediction quality check for bwd_data and bwd_weight ML models.

Usage:

cd dispatcher/heuristics/validation/grouped_conv
python validate_backward_models.py

What it does:

  1. Loads bwd_data and bwd_weight ML models
  2. Tests on 5-7 hardcoded representative problems
  3. For each problem:
    • Predicts TFLOPS for all backward kernels (compv3, mem pipelines)
    • Shows top-3 predicted kernels
    • Reports prediction statistics

Output:

  • Top-3 predicted kernels for each problem
  • Average predicted TFLOPS
  • Pipeline preference (compv3 vs mem)
  • Prediction confidence (gap between best and 3rd)

Runtime: <1 minute (NO hardware - prediction only)

When to use:

  • Quick check after training backward models
  • Verify model predictions are reasonable
  • Debug backward pass heuristic issues

Note: This does NOT run on hardware - it only checks prediction quality.
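
The top-3 ranking and confidence gap described above reduce to a few lines, as in this sketch; `predict` stands in for the trained LightGBM model's predict call, and the kernel identifiers are illustrative.

```python
def top3_with_gap(kernels, predict):
    # Rank candidate kernels by predicted TFLOPS, highest first.
    ranked = sorted(((k, float(predict(k))) for k in kernels),
                    key=lambda kv: kv[1], reverse=True)[:3]
    # Confidence: gap between the best and the 3rd-best prediction.
    gap = ranked[0][1] - ranked[-1][1]
    return ranked, gap
```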


Comparison Matrix

Script                       Operation         Hardware?  Shapes Tested  Runtime   Use Case
validate_ml_heuristic.py     GEMM universal    No         All training   <1 min    GEMM model validation
validate_training_shapes.py  Grouped conv fwd  Yes        5 training     5-10 min  Quick sanity check
validate_backward_models.py  Grouped conv bwd  No         5-7 hardcoded  <1 min    Backward prediction quality

Typical Workflow

  1. After training forward model:

    # Quick check
    python grouped_conv/validate_training_shapes.py
    
  2. After training backward models:

    python grouped_conv/validate_backward_models.py
    

Target Metrics

Forward Pass (Tier-1 Model)

  • Mean efficiency: >90% (currently 93.05%)
  • P10 efficiency: >75% (currently 79.21%)
  • Kernel match rate: >70%

Backward Pass

  • Mean efficiency: >85%
  • Prediction accuracy: >90%
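
A minimal sketch of gating a trained model on the targets above; the threshold values are copied from this section, while the metric keys and function name are illustrative.

```python
# Target thresholds from this section (percentages).
FWD_TARGETS = {"mean_eff": 90.0, "p10_eff": 75.0, "match_rate": 70.0}
BWD_TARGETS = {"mean_eff": 85.0, "pred_acc": 90.0}

def meets_targets(stats, targets):
    # Every tracked metric must exceed its target threshold.
    return all(stats.get(name, 0.0) > threshold
               for name, threshold in targets.items())
```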

Dependencies

All scripts require:

  • Trained ML models in ../models/
  • Training data in ../data/
  • Python packages: pandas, numpy, LightGBM, matplotlib (for plotting)

Grouped conv hardware validation scripts additionally require:

  • GPU hardware (gfx950 default)
  • Compiled kernels or JIT compilation support
  • tile_engine/ops/grouped_conv/ utilities