# [CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation
The ML heuristic in the dispatcher does not yet support the grouped-conv operator. This PR adds support for fwd, bwd-data, and bwd-weight grouped-conv kernels. A tile_engine utility has also been added to compile and run any selected kernel configuration through the dispatcher infrastructure.

## Technical Details

1. A tile_engine utility is added to benchmark each shape with all possible kernel + tile_size combinations:
   https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv are added to the models directory, with three separate models for fwd, bwd-data, and bwd-weight:
   https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python); see the sketch after this list
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults
   - **Solution**: Mirror the FMHA pattern - defer GPU initialization until the first run()
   - **Changes**:
     - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs
     - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
     - Added an _ensure_initialized() method for lazy GPU loading
     - GPU context created only on the first run() call
   - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to expand kernel selection
   - Added automatic padding support for unsupported shapes in the dispatcher runner
   - Created a single source of truth between tile_engine and dispatcher for architecture and tile_size details
   - Built validation scripts to compare oracle_best vs. ml_heuristic results
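
A minimal sketch of the lazy-initialization pattern from item 3, assuming a C entry point exposed via ctypes; everything except `GpuGroupedConvRunner`, `_ensure_initialized()`, and `run()` is a hypothetical name.

```python
import ctypes
from pathlib import Path

class GpuGroupedConvRunner:
    def __init__(self, lib_path: Path):
        # Store only the path: no ctypes.CDLL here, so a forked
        # ProcessPoolExecutor worker never inherits a live GPU context.
        self._lib_path = lib_path
        self._lib = None

    def _ensure_initialized(self) -> None:
        # Load the compiled kernel library lazily, once, inside the
        # process that actually executes the kernel.
        if self._lib is None:
            self._lib = ctypes.CDLL(str(self._lib_path))

    def run(self, *args):
        self._ensure_initialized()       # GPU context created on first run()
        return self._lib.launch(*args)   # hypothetical entry point
```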

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels on both known and unseen data sets of up to 300 problems.
2. Ensured that test cases are added in both dispatcher and tile_engine
to validate the heuristic.

## Test Results
Results on unseen shapes, validated on gfx950.
#### Forward Pass Model
- **Training Data**: 48,845 measurements across 1,372 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model
- **Training Data**: 18,773 measurements across 891 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model
- **Training Data**: 34,900 measurements across 1,508 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
# CK Tile Heuristics: ML-Based Kernel Selection
Fast, accurate kernel selection for CK Tile operations using LightGBM regression
with Origami-augmented feature engineering.
## What This Does
Instead of running all 4608+ kernel configurations on the GPU to find the best
one (exhaustive search taking ~46 seconds per shape), this system trains an ML
model that predicts TFLOPS for any (problem, kernel) pair in microseconds. It
scores all candidates instantly and picks the best kernel -- achieving 98.28%
of oracle-best TFLOPS efficiency across 108 tested shapes.
## Quick Start
### 1. Generate and convert benchmark data
**Step 1: Generate benchmark data**
```bash
python3 generate_benchmark_data.py \
--build_dir /path/to/build \
--output_dir data/fp16_original \
--dtype fp16 \
--layout rcr \
--num_build_jobs 4 \
--warmup 10 \
--repeat 50
```
This outputs JSON with all benchmark results.
**Step 2: Convert JSON to parquet training format**
```bash
python3 convert_json_to_parquet.py \
--input data/fp16_original/benchmark_results_fp16_rcr.json \
--output data/fp16_original/fp16_training_data.parquet \
--arch gfx950
```
The converter automatically fixes pad flags for `_mem` kernels and validates data.
**Alternative: Parse existing logs**
If you have raw benchmark logs from CK Tile:
```bash
python3 data_pipeline.py ck_tile_testrun_2.log \
-o data/gemm_universal_fp8_rcr_gfx950.parquet \
--arch gfx950 --capture_hw
```
### 2. Train a model
```bash
python3 train.py \
--data_dir data/ \
--out_dir models/gemm_universal_fp8_gfx950 \
--op gemm_universal --dtype fp8 --arch gfx950
```
**Note**: Trained models are automatically compressed to `.lgbm.gz` format to save space (~67% reduction). The Python tools automatically decompress them on first use and cache the decompressed version. For warm-start training, decompression happens automatically.
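How that decompress-and-cache step can look; a sketch only, with a hypothetical helper name.
```python
import gzip
import shutil
from pathlib import Path

def ensure_decompressed(path: Path) -> Path:
    """Return a usable .lgbm path, inflating .lgbm.gz once and caching it."""
    if path.suffix != ".gz":
        return path
    target = path.with_suffix("")  # model_tflops.lgbm.gz -> model_tflops.lgbm
    if not target.exists():
        with gzip.open(path, "rb") as src, open(target, "wb") as dst:
            shutil.copyfileobj(src, dst)  # cache the decompressed copy
    return target
```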
### 3. Evaluate
```bash
python3 evaluate.py \
--model_dir models/gemm_universal_fp8_gfx950 \
--data_dir data/ --op gemm_universal --dtype fp8
```
### 4. Predict the best kernel for a problem
```bash
python3 predict.py \
--model_dir models/gemm_universal_fp8_gfx950 \
--m 128 --n 1536 --k 7168 --layout rcr
```
### 5. Search for optimal configs (optional)
```bash
python3 search.py \
--model_dir models/gemm_universal_fp8_gfx950 \
--m 128 --n 1536 --k 7168 \
--strategy random --budget 500 --top_k 10
```
### 6. Using models in C++ (requires decompression)
C++ code uses the LightGBM C API which requires uncompressed `.lgbm` files. If you have compressed models (`.lgbm.gz`), decompress them first:
```bash
cd models/gemm_universal_fp16_gfx950
gunzip model_tflops.lgbm.gz
```
Then use in C++ examples:
```bash
cd dispatcher/build
./gemm_09_ml_heuristic --model ../heuristics/models/gemm_universal_fp16_gfx950/model_tflops.lgbm
```
**Note**: Python tools automatically decompress `.lgbm.gz` files on first use, so you can run Python scripts first to trigger decompression, then use the same models in C++.
## Architecture
```
Problem (M, N, K, dtype, layout)
|
v
FeatureEngine.extract_batch() <-- 55 features: problem, kernel, interaction, hardware
|
v
LGBMRegressor.predict() <-- predicts TFLOPS for each candidate kernel
|
v
Sort by predicted TFLOPS <-- rank all candidates
|
v
Select Top-1 kernel <-- 98.28% mean efficiency, <1ms inference
```
Three models are trained per (op, dtype, arch):
- **TFLOPS model** (primary): used for kernel ranking
- **Latency model** (auxiliary): for latency-sensitive workloads
- **Bandwidth model** (auxiliary): for memory-bound analysis
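The same flow as a condensed sketch, reusing the `Predictor.rank_kernels()` interface shown in the grouped-conv usage example later in this README; the two-entry candidate list is abbreviated for illustration.
```python
from predict import Predictor

predictor = Predictor("models/gemm_universal_fp8_gfx950")

problem = {"M": 128, "N": 1536, "K": 7168, "layout": "rcr", "dtype": "fp8"}

# Real candidate pools carry every kernel-config column the feature
# engine expects; these two entries are abbreviated for illustration.
candidates = [
    {"kernel_name": "gemm_compv3_256x128x64", "tile_m": 256, "tile_n": 128, "tile_k": 64},
    {"kernel_name": "gemm_mem_64x64x128",     "tile_m": 64,  "tile_n": 64,  "tile_k": 128},
]

ranked = predictor.rank_kernels(problem, candidates)  # sorted by predicted TFLOPS
best_name, predicted_tflops = ranked[0]               # Top-1 selection, <1 ms
```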
## File Inventory
| File | Purpose |
|---|---|
| `generate_benchmark_data.py` | Build and run benchmarks across ~25 diverse problem sizes, output JSON |
| `convert_json_to_parquet.py` | Convert benchmark JSON to parquet training format, fix `_mem` pad flags |
| `data_pipeline.py` | Parse raw benchmark logs into canonical parquet datasets |
| `feature_engine.py` | 55-feature extraction: problem, kernel, interaction, hardware profile |
| `train.py` | Multi-target LGBMRegressor training with GroupKFold CV, IHEM, warm-start |
| `predict.py` | Predictor class: predict TFLOPS/latency/bandwidth, rank kernels |
| `evaluate.py` | Full evaluation: global metrics, per-shape/layout/pipeline slices |
| `search.py` | Surrogate search: discrete DE, random top-K |
| `generate_wide_coverage.py` | Generate benchmark data across 706 diverse shapes |
| `generate_edge_dims.py` | Generate N=1, K=1, and other edge-case shapes |
| `DATA_GENERATION.md` | Detailed guide for building binaries and generating data |
| `plan.md` | Full design plan with architecture, milestones, and rationale |
## Features Used (55 total)
### Problem features (13)
`M, N, K, split_k, log2(M), log2(N), log2(K), log2(MNK),
arithmetic_intensity, aspect_ratio_mn, aspect_ratio_mk, aspect_ratio_nk, layout`
### Kernel features (21)
`tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n,
warp_tile_k, pipeline, scheduler, epilogue, pad_m, pad_n, pad_k, persistent,
num_warps, tile_volume, tile_mn, lds_usage_estimate, lds_usage_ratio`
### Interaction features (9)
`num_tiles_m, num_tiles_n, num_tiles_k, total_output_tiles,
tile_eff_m, tile_eff_n, tile_eff_k, overall_tile_efficiency, cu_utilization`
### Hardware profile features (12)
`hw_num_cus, hw_simds_per_cu, hw_total_simds, hw_shader_engines,
hw_max_clock_mhz, hw_max_waves_per_cu, hw_wavefront_size, hw_lds_capacity,
hw_l1_cache_kb, hw_l2_cache_kb, hw_l3_cache_kb, hw_num_xcd`
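Plausible definitions for a few of the interaction features, reconstructed from their names and the tile/CU quantities above; the exact formulas in `feature_engine.py` may differ.
```python
import math

def interaction_features(M, N, K, tile_m, tile_n, tile_k, hw_num_cus=304):
    num_tiles_m = math.ceil(M / tile_m)
    num_tiles_n = math.ceil(N / tile_n)
    num_tiles_k = math.ceil(K / tile_k)
    total_output_tiles = num_tiles_m * num_tiles_n
    # Fraction of each tile covering real data (1.0 = no padding waste).
    tile_eff_m = M / (num_tiles_m * tile_m)
    tile_eff_n = N / (num_tiles_n * tile_n)
    tile_eff_k = K / (num_tiles_k * tile_k)
    return {
        "num_tiles_m": num_tiles_m,
        "num_tiles_n": num_tiles_n,
        "num_tiles_k": num_tiles_k,
        "total_output_tiles": total_output_tiles,
        "tile_eff_m": tile_eff_m,
        "tile_eff_n": tile_eff_n,
        "tile_eff_k": tile_eff_k,
        "overall_tile_efficiency": tile_eff_m * tile_eff_n * tile_eff_k,
        # How well one wave of output tiles fills the device.
        "cu_utilization": min(1.0, total_output_tiles / hw_num_cus),
    }
```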
## Model Performance
### fp8 RCR, gfx950
| Metric | 108 shapes (original) | 168 shapes (wide coverage) |
|---|---|---|
| Mean TFLOPS Efficiency | 98.28% | 97.51% |
| P10 TFLOPS Efficiency | 94.64% | 93.89% |
| tiny_m (M=1) Efficiency | 95.57% | 96.04% |
| R2 (TFLOPS) | 0.997 | 0.993 |
### fp16 RCR, gfx950
Trained on 25 shapes, 1,024 kernels, 21,920 valid benchmarks.
| Metric | Value |
|---|---|
| Mean TFLOPS Efficiency | 99.36% |
| P10 TFLOPS Efficiency | 98.05% |
| P50 TFLOPS Efficiency | 100.00% |
| Min Efficiency | 95.45% |
| NDCG@1 | 64.00% |
| Top-5 Hit Rate | 88.00% |
**Shape Family Breakdown:**
| Shape Family | Mean Eff | P10 Eff | Shapes |
|---|---|---|---|
| Large M (M≥1024) | 99.54% | 99.07% | 4 |
| Medium M (128≤M<1024) | 99.62% | 98.74% | 7 |
| Small M (8≤M<128) | 98.82% | 96.22% | 8 |
| Tiny M (M<8) | 99.65% | 98.96% | 6 |
**Pipeline Breakdown:**
| Pipeline | Mean Eff | P10 Eff |
|---|---|---|
| compv3 | 99.75% | 99.09% |
| compv4 | 99.40% | 98.54% |
| mem | 99.08% | 96.59% |
Training uses `log1p(TFLOPS)` as the target by default, which normalizes the
scale across shapes spanning 0.02 to 2230 TFLOPS. This was the key finding
that improved tiny-M shapes from 84% to 96% efficiency. See
[LEARNINGS.md](LEARNINGS.md) for details.
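The transform in miniature: train on `log1p(TFLOPS)` and invert with `expm1` at prediction time.
```python
import numpy as np

# Measured throughput spans ~5 orders of magnitude across shapes.
tflops = np.array([0.02, 1.5, 80.0, 2230.0])

y = np.log1p(tflops)       # training target: compresses the scale
recovered = np.expm1(y)    # prediction time: invert the transform
assert np.allclose(recovered, tflops)
```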
## Validation
Training uses `GroupKFold(n_splits=5)` with group key `(M, N, K)` to ensure
the model is evaluated on shapes it has never seen during training. Layout is
excluded from the group key to force cross-layout generalization.
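A sketch of that grouping with scikit-learn, assuming the parquet column names used elsewhere in this README.
```python
import pandas as pd
from sklearn.model_selection import GroupKFold

df = pd.read_parquet("data/gemm_universal_fp8_rcr_gfx950.parquet")

# Group key is the problem shape only; layout is deliberately excluded
# so every fold must generalize across layouts.
groups = df[["M", "N", "K"]].astype(str).agg("_".join, axis=1)

for train_idx, val_idx in GroupKFold(n_splits=5).split(df, groups=groups):
    # Each (M, N, K) shape lands entirely on one side of the split.
    train_df, val_df = df.iloc[train_idx], df.iloc[val_idx]
```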
## Incremental Training (Warm Start)
When new benchmark data arrives, update the model without retraining from scratch:
```bash
python3 train.py \
--data_dir data/ \
--out_dir models/v2 \
--warm_start models/gemm_universal_fp8_gfx950 \
--warm_start_n_estimators 200
```
This adds 200 new trees on top of the existing model. Feature schemas must
match exactly (automatically enforced).
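What warm-start amounts to at the LightGBM level; a sketch with stand-in data, not the `train.py` internals.
```python
import numpy as np
import lightgbm as lgb

# Stand-in data; real usage feeds the new parquet rows through the feature engine.
X_new, y_new = np.random.rand(256, 55), np.random.rand(256)

booster = lgb.Booster(model_file="models/gemm_universal_fp8_gfx950/model_tflops.lgbm")
updated = lgb.train(
    {"objective": "regression"},
    lgb.Dataset(X_new, label=y_new),  # feature schema must match exactly
    num_boost_round=200,              # corresponds to --warm_start_n_estimators 200
    init_model=booster,               # continue from the existing trees
)
updated.save_model("models/v2/model_tflops.lgbm")
```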
## Extending to New Ops
Adding support for a new operation (e.g., `gemm_streamk`, `grouped_conv`):
1. **Build binaries**: `ninja -C build benchmark_gemm_streamk_fp8_rcr`
2. **Subclass `FeatureEngine`**: add op-specific features (e.g., StreamK split factor)
3. **Generate data**: run benchmarks across diverse shapes
4. **Train**: `python3 train.py --op gemm_streamk --dtype fp8 --data_dir data/ --out_dir models/`
The training, evaluation, prediction, and search infrastructure is fully
op-agnostic -- only the feature engine needs a new subclass.
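A hypothetical subclass illustrating step 2; the base-class method name `extract_batch()` comes from the architecture diagram above, but its exact signature is an assumption.
```python
from feature_engine import FeatureEngine  # base 55-feature extractor

class StreamKFeatureEngine(FeatureEngine):
    """Hypothetical engine adding StreamK-specific features."""

    def extract_batch(self, problems, kernels):
        feats = super().extract_batch(problems, kernels)
        # Append op-specific columns, e.g. the StreamK split factor.
        feats["streamk_split_factor"] = [k.get("split_factor", 1) for k in kernels]
        return feats
```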
## Tests
102 tests covering all modules:
```bash
python3 -m pytest tests/ -v
```
Test coverage includes:
- Log parsing with malformed JSON, empty logs, single-kernel shapes
- Feature formula correctness (tile efficiency, LDS usage, arithmetic intensity)
- Corner-case shapes: M=1, N=1, K=1, prime dimensions, 20480x7168x256
- Batch vs single extraction parity
- Parameter space validation and projection
- Predictor: single/batch prediction, ranking, missing models, empty inputs
- Training: group keys, efficiency computation, warm-start, feature compatibility
- Search: random, DE, config validity, determinism
## Documentation
- **[README.md](README.md)**: This file -- quick start, architecture, performance
- **[DATA_GENERATION.md](DATA_GENERATION.md)**: Complete guide for building tile engine
binaries, running benchmarks, managing datasets, and troubleshooting
- **[LEARNINGS.md](LEARNINGS.md)**: Empirical findings and design decisions (log-transform,
IHEM results, tiny-M analysis, feature importance, N=1/K=1 edge cases)
## Grouped Convolution ML Heuristics
### Overview
ML-based kernel selection for grouped convolution operations (forward, bwd_data, bwd_weight) on gfx950 with bf16 precision.
### Results
#### Forward Pass Model
- **Training Data**: 48,845 measurements across 1,372 unique problem shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**
#### Backward Data Gradient (bwd_data) Model
- **Training Data**: 18,773 measurements across 891 unique problem shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**
  - Top-1 Accuracy: **25.2%** (37/147 problems)
#### Backward Weight Gradient (bwd_weight) Model
- **Training Data**: 34,900 measurements across 1,508 unique problem shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**
  - Top-1 Accuracy: **32.7%** (51/156 problems)
### Training Data Generation
Extended synthetic problem sets for backward passes cover diverse scenarios:
- Small spatial (7×7, 14×14) + various channels (64-1024)
- Medium spatial (28×28, 32×32, 56×56) + various channels (32-512)
- Large spatial (112×112) + small/medium channels (16-256)
- Asymmetric C/K combinations
- Small and large batch sizes (N=1 to 128)
- Grouped convolutions (G=2, 4, 8)
- Depthwise convolutions (G=C=K)
- Stride-2 downsampling
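An illustrative generator for part of that coverage; the specific combinations below are examples, not the shipped problem sets.
```python
import itertools

def sample_backward_problems():
    problems = []
    # Small spatial sizes x various channel counts x group counts.
    for (h, w), c, g in itertools.product([(7, 7), (14, 14)],
                                          [64, 256, 1024],
                                          [1, 2, 4, 8]):
        if c % g == 0:
            problems.append(dict(N=16, C=c, K=c, G=g, Hi=h, Wi=w, Y=3, X=3,
                                 stride_h=1, stride_w=1, pad_h=1, pad_w=1))
    # Depthwise (G = C = K) with stride-2 downsampling.
    problems.append(dict(N=16, C=64, K=64, G=64, Hi=56, Wi=56, Y=3, X=3,
                         stride_h=2, stride_w=2, pad_h=1, pad_w=1))
    return problems
```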
### Model Files
Trained models stored in:
- `models/grouped_conv_forward_bf16_gfx950/`
- `models/grouped_conv_bwd_data_bf16_gfx950/`
- `models/grouped_conv_bwd_weight_bf16_gfx950/`
Each contains:
- `model_tflops.lgbm.gz` - LightGBM model (gzip-compressed; decompressed to `.lgbm` on first use)
- `feature_spec.json` - Feature configuration
- `cv_metrics_tflops.json` - Cross-validation metrics
- `feature_importances_tflops.json` - Feature importance rankings
Models are automatically decompressed on first use.
### Usage
```python
import pandas as pd
from predict import Predictor
from feature_engine_grouped_conv import GroupedConvFeatureEngine

# Define problem
problem = {
    'N': 16, 'C': 256, 'K': 128, 'G': 1,
    'Hi': 28, 'Wi': 28, 'Y': 3, 'X': 3,
    'stride_h': 1, 'stride_w': 1,
    'pad_h': 1, 'pad_w': 1,
    'dtype': 'bf16'
}

# Load model with the grouped-conv feature engine
predictor = Predictor(
    "models/grouped_conv_bwd_data_bf16_gfx950",
    feature_engine=GroupedConvFeatureEngine(),
)

# Build the candidate kernel pool from a training/holdout parquet
# (each row carries kernel_name + every kernel-config column the engine needs).
df = pd.read_parquet("data/grouped_conv_bwd_data/bwd_data.parquet")
configs = [df[df["kernel_name"] == kn].iloc[0].to_dict()
           for kn in df["kernel_name"].unique()]

# Rank candidates by predicted TFLOPS
ranked = predictor.rank_kernels(problem, configs)
best_name, best_tflops = ranked[0]
print(f"Best kernel: {best_name}")
print(f"Predicted TFLOPS: {best_tflops:.2f}")
```
### Validation
Run validation against oracle benchmarks:
```bash
cd projects/composablekernel/tile_engine/ops/grouped_conv
python3 validate_ml_vs_oracle.py --variant bwd_data
python3 validate_ml_vs_oracle.py --variant bwd_weight
```
### Solution Architecture (Grouped Conv)
```
Problem Config (N, C, K, G, H, W, Y, X)
        |
        v
Feature Engineering (83 features)
  - Problem features (38)
  - Kernel features (12)
  - Interactions (21)
  - Hardware (12)
        |
        v
LightGBM Model (trained on 48K samples, 1,372 shapes)
        |
        v
Predict TFLOPS -> Select Best Kernel (<1 ms total latency)
```
### Feature Engineering (`feature_engine_grouped_conv.py`)
**83 engineered features**:
- **Problem Features (38)**: Raw params (N,C,K,G,Hi,Wi,Y,X,strides,pads), derived (Ho,Wo), log-scale transforms, arithmetic intensity, aspect ratios, channel/group metrics
- **Kernel Features (12)**: Block size, GEMM tiles (M,N), pipeline type, num warps, tile volume, LDS usage
- **Interaction Features (21)**: Tile efficiency (M,N,K), block-tile ratios, CU utilization, problem-tile comparisons, output tile counts
- **Hardware Features (12)**: GFX950 specs - CUs (304), SIMDs, clocks, wavefront size, cache sizes (L1/L2/L3), XCD count
### Latency
- **Selection Time**: <1ms
- **vs Oracle**: 30-60 seconds
- **Speedup**: 30,000-60,000×
### Model Size
- **Compressed**: 2-8 MB (.lgbm.gz)
- **Runtime Memory**: ~50 MB
- **Feature Array**: <6 KB per problem
### Training Pipeline
```bash
# 1. Collect data: Run all kernels on GPU for diverse problem set
python grouped_conv_full_benchmark.py --problem_set forward_training_miopen
# 2. Preprocess: Convert CSV to Parquet
python convert_csv_to_parquet.py --input train.csv --output train.parquet
# 3. Train model: LightGBM with cross-validation
python train.py --operation grouped_conv --direction forward --dtype bf16
# 4. Validate: Sanity-check on training shapes
python validation/grouped_conv/validate_training_shapes.py
```
### Validation Framework
| Test | Purpose | Shapes | Runtime | Target |
|------|---------|--------|---------|--------|
| `validate_training_shapes.py` | Sanity check on training data | 5 | 5-10 min | >95% efficiency |
| `validate_backward_models.py` | Backward pass prediction quality | 7 | <1 min | Reasonable predictions |
### File Structure (Grouped Conv)
```
dispatcher/heuristics/
├── train.py # Training script
├── feature_engine_grouped_conv.py # Feature engineering
├── predict.py # Generic Predictor (use with GroupedConvFeatureEngine)
├── models/
│ ├── grouped_conv_forward_bf16_gfx950/
│ │ ├── model_tflops.lgbm.gz # Compressed model
│ │ ├── feature_spec.json # Feature definitions
│ │ └── train_manifest.json # Training metadata
│ ├── grouped_conv_bwd_data_bf16_gfx950/
│ └── grouped_conv_bwd_weight_bf16_gfx950/
└── validation/
├── validate_ml_heuristic.py # GEMM validation
└── grouped_conv/
├── validate_training_shapes.py
└── validate_backward_models.py
tile_engine/ops/grouped_conv/
├── grouped_conv_full_benchmark.py # Data collection
├── run_one_grouped_conv_kernel.py # Single kernel runner
├── compare_ml_vs_oracle.py # Analysis tool
└── problems/
├── forward_training_miopen.py # Training problem sets
└── forward_validation_300.py # Test problem sets
```
### C++/Python Integration
- **C++ API**: `GroupedConvRegistry::get_solution(problem)`
- **Python API**: `registry.run(problem, input, weight)`
- Automatic fallback to exhaustive search if ML unavailable
```python
from ck_tile.dispatcher import GroupedConvRegistry, GroupedConvProblem

# Define problem
problem = GroupedConvProblem(
    N=2, C=128, K=256, G=1,
    Hi=28, Wi=28, Y=3, X=3,
    stride_h=1, stride_w=1, pad_h=1, pad_w=1,
    dtype='bf16', direction='forward'
)

# ML heuristic automatically selects the best kernel.
# input_tensor / weight_tensor: device tensors prepared by the caller.
registry = GroupedConvRegistry(arch='gfx950')
result = registry.run(problem, input_tensor, weight_tensor)
```
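The fallback path in sketch form; the exception type and `exhaustive_search()` name are assumptions about the registry's behavior, not its actual API.
```python
def get_solution_with_fallback(registry, problem):
    try:
        return registry.get_solution(problem)  # ML heuristic path, <1 ms
    except RuntimeError:
        # Hypothetical fallback: benchmark every candidate kernel on the
        # GPU and keep the fastest (the 30-60 s oracle path).
        return registry.exhaustive_search(problem)
```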
### Key Innovations
1. **Comprehensive Feature Engineering**: 83 features capture problem-kernel-hardware interactions
2. **Tier-1 Extended Training**: 1,372 shapes (vs 185 baseline) for better edge case coverage
3. **Compressed Models**: LGBM.gz reduces size 8-10× without accuracy loss
4. **Operation-Specific Models**: Separate optimizations for forward/backward passes
5. **Validation Framework**: Automated testing on unseen production workloads
## Verifying Training Quality
To quickly verify that a refactored `train.py` produces models with equivalent quality to the production training script:
```bash
cd /workspace/rocm-libraries/projects/composablekernel/dispatcher/heuristics
# Run automated test (uses 3-fold CV for speed)
./test_model_quality.sh
```
This script will:
1. Validate current production model on 300 validation shapes
2. Train a new model using refactored `train.py`
3. Validate the new model on the same 300 shapes
4. Compare predictions between old and new models
**Expected Output:**
```
Step 4: Comparing predictions...
================================================================================
PREDICTION COMPARISON: bwd_data
================================================================================
Kernel Selection Agreement: 215/300 (71.7%)
Metric               Old Model    New Model    Delta
----------------------------------------------------------------------
Mean Efficiency      0.9380       0.9380       +0.0000
Median Efficiency    0.9650       0.9650       +0.0000
P10 Efficiency       0.8290       0.8290       +0.0000
Per-Problem Changes:
Improved: 0 (0.0%)
Same: 300 (100.0%)
Degraded: 0 (0.0%)
================================================================================
✓ PASS: New model maintains quality!
================================================================================
```
### Model Selection Process
The validation script (`validate_ml_vs_oracle.py`) automatically selects the model based on:
**Variant:** `--variant {forward|bwd_data|bwd_weight}`
**Model Path:** `dispatcher/heuristics/models/grouped_conv_{variant}_bf16_gfx950/`
For example:
- `--variant bwd_data` → uses `models/grouped_conv_bwd_data_bf16_gfx950/model_tflops.lgbm`
- `--variant bwd_weight` → uses `models/grouped_conv_bwd_weight_bf16_gfx950/model_tflops.lgbm`
### Manual Step-by-Step Comparison
If you want to run each step manually:
#### Step 1: Validate Current Model
```bash
cd tile_engine/ops/grouped_conv
python3 validate_ml_vs_oracle.py \
--operation grouped_conv \
--variant bwd_data \
--problem-set bwd_data_model_crawler_validation \
--oracle-csv bwd_data_model_crawler_oracle.csv \
--save-predictions /tmp/bwd_data_old_predictions.csv
```
This uses the model at: `dispatcher/heuristics/models/grouped_conv_bwd_data_bf16_gfx950/`
#### Step 2: Train New Model
```bash
cd ../../dispatcher/heuristics
python3 train.py \
--operation grouped_conv \
--data_dir data/bwd_data_training \
--out_dir /tmp/grouped_conv_bwd_data_bf16_gfx950_new \
--dtype bf16 \
--arch gfx950 \
--targets tflops \
--n_splits 5
```
#### Step 3: Temporarily Swap Models
```bash
# Backup current model
mv models/grouped_conv_bwd_data_bf16_gfx950 /tmp/backup
# Use new model for validation
cp -r /tmp/grouped_conv_bwd_data_bf16_gfx950_new models/grouped_conv_bwd_data_bf16_gfx950
```
#### Step 4: Validate New Model
```bash
cd ../../tile_engine/ops/grouped_conv
python3 validate_ml_vs_oracle.py \
--operation grouped_conv \
--variant bwd_data \
--problem-set bwd_data_model_crawler_validation \
--oracle-csv bwd_data_model_crawler_oracle.csv \
--save-predictions /tmp/bwd_data_new_predictions.csv
```
#### Step 5: Restore Original Model
```bash
cd ../../dispatcher/heuristics
rm -rf models/grouped_conv_bwd_data_bf16_gfx950
mv /tmp/backup models/grouped_conv_bwd_data_bf16_gfx950
```
#### Step 6: Compare Predictions
```bash
cd ../../tile_engine/ops/grouped_conv
python3 compare_model_predictions.py \
--old-predictions /tmp/bwd_data_old_predictions.csv \
--new-predictions /tmp/bwd_data_new_predictions.csv \
--variant bwd_data
```
### Acceptance Criteria
A new model passes quality validation if:
1. ✓ Mean efficiency is within 0.5% of baseline
2. ✓ Median efficiency is within 0.5% of baseline
3. ✓ P10 efficiency is within 2% of baseline
4. ✓ No catastrophic regressions (efficiency drops >10% on any problem)
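The same criteria as a predicate; the metric key names are assumptions for illustration.
```python
def passes_quality(old, new, per_problem_delta):
    """old/new: summary metrics; per_problem_delta: new - old efficiency per problem."""
    return (
        abs(new["mean_eff"] - old["mean_eff"]) <= 0.005        # within 0.5%
        and abs(new["median_eff"] - old["median_eff"]) <= 0.005
        and new["p10_eff"] >= old["p10_eff"] - 0.02            # within 2%
        and all(d >= -0.10 for d in per_problem_delta)         # no >10% drop
    )
```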
### Troubleshooting
#### Different Predictions on Same Model
This is unlikely, but if the same model file produces different predictions, check:
- Feature engine version (should be 83 features)
- Problem encoding (verify problem_to_dict matches)
- Predictor initialization (check log transform handling)
#### Quality Regression
If the new model has lower efficiency:
1. Check CV metrics in training log - should be similar to baseline
2. Verify identical training data (check parquet row counts)
3. Compare feature importance - should be similar patterns
4. Inspect specific regression cases in comparison output