Files
composable_kernel/dispatcher/heuristics/validation/grouped_conv/validate_backward_models.py
Yaswanth Raparti 6989cf800c [rocm-libraries] ROCm/rocm-libraries#6327 (commit 1e7a12e)
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation
The ML heuristic in the dispatcher does not yet support the grouped-conv operator.
This PR adds support for fwd, bwd-data, and bwd-weight grouped-conv kernels. A
tile_engine utility has also been added to compile and run any selected kernel
configuration through the dispatcher infrastructure.

## Technical Details

1. A tile_engine utility is added to benchmark each shape with all the
possible kernel+tile_size combinations:
https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv are added to the models
directory, with three separate models for fwd, bwd-data, and bwd-weight:
https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python)
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults
   - **Solution**: Mirror the FMHA pattern - defer GPU initialization until the first run()
   - **Changes**:
     - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs
     - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
     - Added _ensure_initialized() method for lazy GPU loading
     - GPU context created only on first run() call
   - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to expand kernel selection
   - Added automatic padding support for unsupported shapes in the dispatcher runner
   - Created a single source of truth between tile_engine and dispatcher for architecture and tile_size details
   - Built a validation script to compare oracle_best vs ml_heuristic selections
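The lazy GPU initialization described in item 3 can be sketched roughly as follows. This is a minimal illustration of the pattern (class and attribute names follow the PR text, but the body is hypothetical): `__init__` only records the library path, and `ctypes.CDLL` — which implicitly creates the GPU context — is deferred to the first `run()` call, so objects can be constructed safely in forked worker processes.

```python
import ctypes
from pathlib import Path

class GpuGroupedConvRunner:
    """Sketch of the lazy-initialization pattern (illustrative, not the actual class)."""

    def __init__(self, lib_path: Path):
        # No ctypes.CDLL and no GPU context here: construction is safe
        # inside a ProcessPoolExecutor fork() worker.
        self._lib_path = lib_path
        self._lib = None

    def _ensure_initialized(self):
        # Load the shared library (and create the GPU context) once,
        # on first use rather than at construction time.
        if self._lib is None:
            self._lib = ctypes.CDLL(str(self._lib_path))

    def run(self, *args):
        self._ensure_initialized()
        # ... launch the kernel through the loaded library ...
```

Because the constructor touches no GPU state, the parent process can build the full list of runners before the pool workers ever load a library.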

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels with both known and
unseen data sets with up to 300 problems.
2. Ensured that test cases are added in both dispatcher and tile_engine
to validate the heuristic.

## Test Result
Results on Unseen shapes validated on gfx950
#### Forward Pass Model
- **Training Data**: 48,845 measurements across 1,372 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model
- **Training Data**: 18,773 measurements across 891 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model
- **Training Data**: 34,900 measurements across 1,508 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**
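The efficiency figures above are, presumably, per-problem ratios of the TFLOPS achieved by the heuristic's chosen kernel to the TFLOPS of the oracle-best kernel, summarized by mean, median, and 10th percentile. A minimal sketch of that summary (function name and P10 method are assumptions, using a simple nearest-rank percentile):

```python
import math
import statistics

def efficiency_summary(ml_tflops, oracle_tflops):
    """Per-problem efficiency (%) of the heuristic's pick vs the oracle-best kernel."""
    eff = sorted(100.0 * m / o for m, o in zip(ml_tflops, oracle_tflops))
    # P10 by nearest rank: 90% of problems are at least this efficient.
    p10_idx = max(0, math.ceil(0.10 * len(eff)) - 1)
    return {
        "mean": statistics.mean(eff),
        "median": statistics.median(eff),
        "p10": eff[p10_idx],
    }
```

Reporting P10 alongside the mean is what surfaces tail behavior: a model can have a high mean efficiency while still picking poor kernels for a minority of shapes.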

## Submission Checklist

- [x] Looked over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-05-08 20:48:42 +00:00


#!/usr/bin/env python3
"""
Validate backward pass ML models using actual training problem shapes.
Tests prediction quality on representative problems from the training set.
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent)) # heuristics
from predict import Predictor
from feature_engine_grouped_conv import GroupedConvFeatureEngine
# Representative test problems from training sets
BWD_DATA_TEST_PROBLEMS = [
    # Small problems (from bwd_data_training.py)
    {'N': 32, 'C': 1, 'K': 1, 'G': 1, 'Hi': 5, 'Wi': 5, 'Y': 3, 'X': 3, 'stride_h': 1, 'stride_w': 1, 'pad_h': 0, 'pad_w': 0},
    {'N': 64, 'C': 1, 'K': 1, 'G': 1, 'Hi': 5, 'Wi': 5, 'Y': 3, 'X': 3, 'stride_h': 1, 'stride_w': 1, 'pad_h': 0, 'pad_w': 0},
    {'N': 128, 'C': 256, 'K': 128, 'G': 1, 'Hi': 32, 'Wi': 32, 'Y': 3, 'X': 3, 'stride_h': 1, 'stride_w': 1, 'pad_h': 1, 'pad_w': 1},
    {'N': 2, 'C': 128, 'K': 256, 'G': 1, 'Hi': 32, 'Wi': 32, 'Y': 3, 'X': 3, 'stride_h': 1, 'stride_w': 1, 'pad_h': 1, 'pad_w': 1},
    {'N': 2, 'C': 256, 'K': 256, 'G': 1, 'Hi': 14, 'Wi': 14, 'Y': 1, 'X': 1, 'stride_h': 1, 'stride_w': 1, 'pad_h': 0, 'pad_w': 0},
]
BWD_WEIGHT_TEST_PROBLEMS = [
    # Small problems (from bwd_weight_synthetic.py)
    {'N': 1, 'C': 64, 'K': 64, 'G': 1, 'Hi': 7, 'Wi': 7, 'Y': 1, 'X': 1, 'stride_h': 1, 'stride_w': 1, 'pad_h': 0, 'pad_w': 0},
    {'N': 2, 'C': 64, 'K': 128, 'G': 1, 'Hi': 14, 'Wi': 14, 'Y': 1, 'X': 1, 'stride_h': 1, 'stride_w': 1, 'pad_h': 0, 'pad_w': 0},
    {'N': 8, 'C': 128, 'K': 128, 'G': 1, 'Hi': 28, 'Wi': 28, 'Y': 3, 'X': 3, 'stride_h': 1, 'stride_w': 1, 'pad_h': 1, 'pad_w': 1},
    # Medium problems
    {'N': 16, 'C': 128, 'K': 256, 'G': 1, 'Hi': 14, 'Wi': 14, 'Y': 3, 'X': 3, 'stride_h': 1, 'stride_w': 1, 'pad_h': 1, 'pad_w': 1},
    {'N': 32, 'C': 256, 'K': 512, 'G': 1, 'Hi': 28, 'Wi': 28, 'Y': 3, 'X': 3, 'stride_h': 1, 'stride_w': 1, 'pad_h': 1, 'pad_w': 1},
    # Large problems
    {'N': 64, 'C': 512, 'K': 1024, 'G': 1, 'Hi': 14, 'Wi': 14, 'Y': 3, 'X': 3, 'stride_h': 2, 'stride_w': 2, 'pad_h': 1, 'pad_w': 1},
    {'N': 128, 'C': 1024, 'K': 2048, 'G': 1, 'Hi': 28, 'Wi': 28, 'Y': 5, 'X': 5, 'stride_h': 1, 'stride_w': 1, 'pad_h': 2, 'pad_w': 2},
]
# Backward kernel configurations (compv3, mem)
BACKWARD_KERNELS = [
    {'block_size': 16, 'gemm_m_per_block': 64, 'gemm_n_per_block': 64, 'pipeline': 'compv3'},
    {'block_size': 16, 'gemm_m_per_block': 64, 'gemm_n_per_block': 64, 'pipeline': 'mem'},
    {'block_size': 32, 'gemm_m_per_block': 64, 'gemm_n_per_block': 64, 'pipeline': 'compv3'},
    {'block_size': 32, 'gemm_m_per_block': 64, 'gemm_n_per_block': 64, 'pipeline': 'mem'},
    {'block_size': 32, 'gemm_m_per_block': 128, 'gemm_n_per_block': 64, 'pipeline': 'compv3'},
    {'block_size': 32, 'gemm_m_per_block': 128, 'gemm_n_per_block': 64, 'pipeline': 'mem'},
    {'block_size': 64, 'gemm_m_per_block': 64, 'gemm_n_per_block': 64, 'pipeline': 'compv3'},
    {'block_size': 64, 'gemm_m_per_block': 64, 'gemm_n_per_block': 64, 'pipeline': 'mem'},
    {'block_size': 64, 'gemm_m_per_block': 128, 'gemm_n_per_block': 64, 'pipeline': 'compv3'},
    {'block_size': 64, 'gemm_m_per_block': 128, 'gemm_n_per_block': 64, 'pipeline': 'mem'},
    {'block_size': 128, 'gemm_m_per_block': 128, 'gemm_n_per_block': 64, 'pipeline': 'compv3'},
    {'block_size': 128, 'gemm_m_per_block': 128, 'gemm_n_per_block': 64, 'pipeline': 'mem'},
]
def format_problem(p):
    """Format problem for display."""
    Ho = (p['Hi'] + 2*p['pad_h'] - p['Y']) // p['stride_h'] + 1
    Wo = (p['Wi'] + 2*p['pad_w'] - p['X']) // p['stride_w'] + 1
    return f"N={p['N']:3d} C={p['C']:4d} K={p['K']:4d} {p['Hi']:2d}x{p['Wi']:2d}->{Ho:2d}x{Wo:2d} f{p['Y']}x{p['X']}"
def validate_variant(variant, test_problems, model_dir):
    """Validate a specific variant (bwd_data or bwd_weight)."""
    print("=" * 100)
    print(f" VALIDATING {variant.upper()} MODEL")
    print("=" * 100)
    print(f" Model: {model_dir}")
    print(f" Problems: {len(test_problems)}")
    print()
    # Load model
    feature_engine = GroupedConvFeatureEngine()
    predictor = Predictor(model_dir, feature_engine=feature_engine)
    print(" ✓ Model loaded successfully")
    print()
    # Test each problem
    print(f" {'Problem':<45} {'Best Kernel':<25} {'Pred TFLOPS':>12} {'Top-3 Kernels':<35}")
    print(" " + "-" * 117)
    all_predictions = []
    for problem in test_problems:
        # Add dtype
        problem_with_dtype = {**problem, 'dtype': 'bf16'}
        # Predict for all kernels
        predictions = []
        for kernel in BACKWARD_KERNELS:
            tflops = predictor.predict_tflops(problem_with_dtype, kernel)
            predictions.append({
                'tflops': tflops,
                'kernel': f"{kernel['block_size']}x{kernel['gemm_m_per_block']}x{kernel['gemm_n_per_block']}_{kernel['pipeline']}",
                'pipeline': kernel['pipeline']
            })
        # Sort by predicted TFLOPS, best first
        predictions.sort(key=lambda x: x['tflops'], reverse=True)
        all_predictions.append(predictions)
        # Format output
        prob_str = format_problem(problem)
        best = predictions[0]
        top3_str = f"{predictions[0]['kernel'][:18]}, {predictions[1]['kernel'][:18]}, {predictions[2]['kernel'][:18]}"
        print(f" {prob_str:<45} {best['kernel']:<25} {best['tflops']:>12.2f} {top3_str:<35}")
    print()
    print(" " + "=" * 117)
    # Summary statistics
    print()
    print(" SUMMARY STATISTICS:")
    print(f" {'Metric':<30} {'Value':>15}")
    print(" " + "-" * 47)
    # Average predicted TFLOPS
    avg_best_tflops = sum(p[0]['tflops'] for p in all_predictions) / len(all_predictions)
    print(f" {'Avg Best Predicted TFLOPS':<30} {avg_best_tflops:>15.2f}")
    # Min/max predicted TFLOPS
    min_tflops = min(p[0]['tflops'] for p in all_predictions)
    max_tflops = max(p[0]['tflops'] for p in all_predictions)
    print(f" {'Min Predicted TFLOPS':<30} {min_tflops:>15.2f}")
    print(f" {'Max Predicted TFLOPS':<30} {max_tflops:>15.2f}")
    # Pipeline preference (how often each pipeline is selected)
    compv3_count = sum(1 for p in all_predictions if p[0]['pipeline'] == 'compv3')
    mem_count = sum(1 for p in all_predictions if p[0]['pipeline'] == 'mem')
    print(f" {'Best pipeline: compv3':<30} {compv3_count:>15} ({100*compv3_count/len(all_predictions):.1f}%)")
    print(f" {'Best pipeline: mem':<30} {mem_count:>15} ({100*mem_count/len(all_predictions):.1f}%)")
    # Relative gap between the best and 3rd-best predicted kernels; a large
    # average gap means the model's top pick is clearly separated from the rest.
    gaps = []
    for preds in all_predictions:
        gap = (preds[0]['tflops'] - preds[2]['tflops']) / preds[0]['tflops'] * 100
        gaps.append(gap)
    avg_gap = sum(gaps) / len(gaps)
    print(f" {'Avg gap: best vs 3rd (%)':<30} {avg_gap:>15.1f}%")
    print()
def main():
    print()
    print("=" * 100)
    print(" BACKWARD PASS ML MODEL VALIDATION")
    print(" Testing predictions on training problem shapes")
    print("=" * 100)
    print()
    # Model directory is in heuristics/models/, not validation/grouped_conv/models/
    heuristics_dir = Path(__file__).parent.parent.parent  # Go up from validation/grouped_conv/ to heuristics/
    # Validate bwd_data
    bwd_data_model = heuristics_dir / "models" / "grouped_conv_bwd_data_bf16_gfx950"
    if bwd_data_model.exists():
        validate_variant("bwd_data", BWD_DATA_TEST_PROBLEMS, bwd_data_model)
    else:
        print(f" ⚠ BWD_DATA model not found: {bwd_data_model}")
    print()
    # Validate bwd_weight
    bwd_weight_model = heuristics_dir / "models" / "grouped_conv_bwd_weight_bf16_gfx950"
    if bwd_weight_model.exists():
        validate_variant("bwd_weight", BWD_WEIGHT_TEST_PROBLEMS, bwd_weight_model)
    else:
        print(f" ⚠ BWD_WEIGHT model not found: {bwd_weight_model}")
    print()
    print("=" * 100)
    print(" VALIDATION COMPLETE")
    print("=" * 100)
    print()

if __name__ == "__main__":
    main()