[CK Tile] Rule-based configuration generation in CK Dispatcher codegen (#8157)

## Motivation

The CK Tile Dispatcher code generation for the CK Tile Profiler relies on flat JSON files to list the generated configurations. This approach has the following problems:

- The JSON files are verbose.
- The JSON files easily get out of sync with the CK Builder .config files from which they were generated.
- The JSON-file-based configuration makes it hard to state explicitly the rules that govern instance generation.

## Technical Details

Replaced the JSON files with a rule-based configuration. To preserve the existing functionality, the `profiler` and `tests` instance sets are generated directly from the CK Builder config files. The JSON config files are removed from source control, and the "on-the-fly" generation guarantees that the Dispatcher codegen uses up-to-date configurations.

This PR introduces six rule sets for the CK Tile Dispatcher code generation:

1. `profiler`: matches the old JSON set of profiler configurations.
2. `tests`: matches the old JSON set of test configurations.
3. `full`: the full configuration set, created by rule-based configuration selection.
4. `full-tests`: a subset of `full` for generating configurations for convolution integration tests.
5. `tiny`: a subset of `full-tests` that produces the minimal set of configurations needed to test the Dispatcher codegen.
6. `default`: the default rules, which correspond to the existing heuristic rules for configuration selection. This ensures that ML-based kernel selection does not break.

The main purpose of the `full` rule set is to define a reasonable solution space for the possible implicit GEMM configurations. We start from the configurations allowed by the device architecture. The `full` rule set defines the relevant tile sizes for each convolution direction.
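To make the subset relationships between the rule sets concrete, here is a minimal Python sketch that treats a rule set as a filter over a product space of tile sizes and pipelines. The tile sizes, pipeline names, and the helpers `full_rule_set` and `subset` are hypothetical illustrations, not the Dispatcher codegen's actual rules or API.

```python
from itertools import product

# Hypothetical per-direction tile sizes (M, N, K); illustrative values only.
FULL_TILES = {
    "forward": [(256, 128, 32), (128, 128, 64), (64, 64, 64)],
    "bwd_data": [(128, 128, 64), (64, 64, 64)],
    "bwd_weight": [(256, 256, 32), (128, 128, 64)],
}
# Hypothetical pipeline variants, crossed with every tile size.
PIPELINES = ("compv3", "compv4", "mem")


def full_rule_set(direction):
    """Enumerate every (tile, pipeline) combination the 'full'-style rules allow."""
    return [
        {"direction": direction, "tile": tile, "pipeline": pipeline}
        for tile, pipeline in product(FULL_TILES[direction], PIPELINES)
    ]


def subset(configs, keep):
    """Derive a smaller rule set by filtering a larger one."""
    return [cfg for cfg in configs if keep(cfg)]


full = full_rule_set("forward")
# A 'full-tests'-style subset: drop the memory-bound pipeline.
full_tests = subset(full, lambda cfg: cfg["pipeline"] != "mem")
# A 'tiny'-style subset: keep only the smallest tile.
tiny = subset(full_tests, lambda cfg: cfg["tile"] == (64, 64, 64))
print(len(full), len(full_tests), len(tiny))  # 9 6 2
```

Because each smaller set is derived by filtering a larger one, a set like `tiny` can never contain a configuration that `full` lacks, which mirrors the subset relationships listed above.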
From the tile size, we have a curated mapping to the number of waves along the different GEMM axes, i.e., we describe how many waves each GEMM dimension corresponds to. The GEMM-K wave tile dimension can be computed from the other parameters and does not need to be listed explicitly.

An axis orthogonal to the tiling strategy is the vectorization strategy. This is mainly defined by the data type and hardware since, in general, we want to use the maximum possible load widths. The maximum sizes for each convolution direction variant are defined by the implicit GEMM matrix dimensions. For cases where we have a low number of channels per convolution group, we need smaller vector load sizes. These are captured by the `VecStrategy` enumeration in the codegen rules.

The problem with rule-based configuration selection is that we "over-generate" configurations. The old JSON configurations comprise approximately 25% of all configurations that the `full` rule set creates. The additional configurations are valid, but they may not provide any performance benefit. Hence, we keep the `profiler` and `tests` rule sets for now to avoid building an excessive number of configurations by default. The `full` rule set can be enabled with the CMake configuration flag `-D DISPATCHER_RULE_SET=full`. By default, the `tests` rule set is used, i.e., the existing behaviour does not change.

## Test Plan

Added a new stage in the CI/CD pipeline that ensures the Dispatcher codegen rules are up to date. Otherwise, the functionality is covered by the existing CI/CD tests. There are no functional changes to the convolution kernels, only to how the different instances are generated.

## Test Result

If the CK Tile conv instances build without errors, the Dispatcher codegen is generating valid code. If all tests in the CI/CD pipeline pass, the Dispatcher codegen generates valid instances.
## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
# CK Tile Heuristics: ML-Based Kernel Selection

Fast, accurate kernel selection for CK Tile operations using LightGBM regression with Origami-augmented feature engineering.

## What This Does
Instead of running all 4608+ kernel configurations on the GPU to find the best one (an exhaustive search taking ~46 seconds per shape), this system trains an ML model that predicts TFLOPS for any (problem, kernel) pair in microseconds. It scores all candidates in one pass and picks the best kernel, achieving a mean of 98.28% of oracle-best TFLOPS across 108 tested shapes (fp8 RCR, gfx950).
## Quick Start

### 1. Generate and convert benchmark data

#### Step 1: Generate benchmark data

```bash
python3 generate_benchmark_data.py \
    --build_dir /path/to/build \
    --output_dir data/fp16_original \
    --dtype fp16 \
    --layout rcr \
    --num_build_jobs 4 \
    --warmup 10 \
    --repeat 50
```

This outputs JSON with all benchmark results.
#### Step 2: Convert JSON to parquet training format

```bash
python3 convert_json_to_parquet.py \
    --input data/fp16_original/benchmark_results_fp16_rcr.json \
    --output data/fp16_original/fp16_training_data.parquet \
    --arch gfx950
```

The converter automatically fixes pad flags for `_mem` kernels and validates the data.
#### Alternative: Parse existing logs

If you have raw benchmark logs from CK Tile:

```bash
python3 data_pipeline.py ck_tile_testrun_2.log \
    -o data/gemm_universal_fp8_rcr_gfx950.parquet \
    --arch gfx950 --capture_hw
```
### 2. Train a model

```bash
python3 train.py \
    --data_dir data/ \
    --out_dir models/gemm_universal_fp8_gfx950 \
    --op gemm_universal --dtype fp8 --arch gfx950
```

Note: Trained models are automatically compressed to `.lgbm.gz` format to save space (~67% reduction). The Python tools automatically decompress them on first use and cache the decompressed version. For warm-start training, decompression happens automatically.
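This README does not show how the Python tools implement decompress-on-first-use. A minimal sketch, assuming the cache is simply the `.lgbm` file written next to the `.lgbm.gz` archive and reused while it is at least as new as the archive (the function name `ensure_decompressed` is hypothetical):

```python
import gzip
import shutil
from pathlib import Path


def ensure_decompressed(model_path):
    """Return a path to an uncompressed .lgbm file, decompressing a .lgbm.gz once.

    Accepts either 'model_tflops.lgbm' or 'model_tflops.lgbm.gz'. An existing
    uncompressed file is reused as a cache unless the archive is newer.
    """
    path = Path(model_path)
    if path.suffix == ".gz":
        gz_path, plain_path = path, path.with_suffix("")
    else:
        gz_path, plain_path = path.with_name(path.name + ".gz"), path

    if plain_path.exists() and (
        not gz_path.exists()
        or plain_path.stat().st_mtime >= gz_path.stat().st_mtime
    ):
        return plain_path  # cached copy is up to date

    if not gz_path.exists():
        raise FileNotFoundError(f"neither {plain_path} nor {gz_path} exists")

    # Stream-decompress so large models are never fully held in memory.
    with gzip.open(gz_path, "rb") as src, open(plain_path, "wb") as dst:
        shutil.copyfileobj(src, dst)
    return plain_path
```

The uncompressed file this produces is the kind of `.lgbm` file the C++ examples in step 6 expect.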
### 3. Evaluate

```bash
python3 evaluate.py \
    --model_dir models/gemm_universal_fp8_gfx950 \
    --data_dir data/ --op gemm_universal --dtype fp8
```
### 4. Predict the best kernel for a problem

```bash
python3 predict.py \
    --model_dir models/gemm_universal_fp8_gfx950 \
    --m 128 --n 1536 --k 7168 --layout rcr
```
### 5. Search for optimal configs (optional)

```bash
python3 search.py \
    --model_dir models/gemm_universal_fp8_gfx950 \
    --m 128 --n 1536 --k 7168 \
    --strategy random --budget 500 --top_k 10
```
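`search.py` supports discrete DE and random top-K search. As a hedged illustration of the random strategy only, the sketch below samples `budget` configurations from a made-up discrete space, scores them with a stand-in surrogate (in the real tool, the trained TFLOPS model plays this role), and keeps the `top_k` best. `SPACE`, `toy_surrogate`, and `random_top_k` are hypothetical names.

```python
import random

# Hypothetical discrete parameter space; real spaces come from the kernel codegen.
SPACE = {
    "tile_m": [16, 32, 64, 128, 256],
    "tile_n": [32, 64, 128, 256],
    "tile_k": [32, 64, 128],
    "pipeline": ["compv3", "compv4", "mem"],
}


def toy_surrogate(problem, cfg):
    """Stand-in for a trained model: rewards tiles that divide M and N evenly."""
    m, n = problem
    eff_m = m / (-(-m // cfg["tile_m"]) * cfg["tile_m"])  # -(-a // b) is ceil(a / b)
    eff_n = n / (-(-n // cfg["tile_n"]) * cfg["tile_n"])
    bonus = {"compv3": 1.0, "compv4": 1.05, "mem": 0.8}[cfg["pipeline"]]
    return 100.0 * eff_m * eff_n * bonus


def random_top_k(problem, score, budget, top_k, seed=0):
    """Sample `budget` configs uniformly, score each unique one, keep the best `top_k`."""
    rng = random.Random(seed)
    seen = {}
    for _ in range(budget):
        cfg = {name: rng.choice(values) for name, values in SPACE.items()}
        key = tuple(sorted(cfg.items()))  # hashable identity for de-duplication
        if key not in seen:
            seen[key] = (score(problem, cfg), cfg)
    ranked = sorted(seen.values(), key=lambda item: item[0], reverse=True)
    return ranked[:top_k]


best = random_top_k((128, 1536), toy_surrogate, budget=500, top_k=10)
for predicted, cfg in best[:3]:
    print(f"{predicted:7.2f}  {cfg}")
```

Scoring with a surrogate costs microseconds per candidate, which is what makes a budget of hundreds of samples cheap compared with benchmarking each one on the GPU.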
### 6. Using models in C++ (requires decompression)

C++ code uses the LightGBM C API, which requires uncompressed `.lgbm` files. If you have compressed models (`.lgbm.gz`), decompress them first:

```bash
cd models/gemm_universal_fp16_gfx950
gunzip model_tflops.lgbm.gz
```

Then use them in the C++ examples:

```bash
cd dispatcher/build
./gemm_09_ml_heuristic --model ../heuristics/models/gemm_universal_fp16_gfx950/model_tflops.lgbm
```

Note: Python tools automatically decompress `.lgbm.gz` files on first use, so you can run a Python script first to trigger decompression and then use the same models in C++.
## Architecture

```text
Problem (M, N, K, dtype, layout)
    |
    v
FeatureEngine.extract_batch()   <-- 55 features: problem, kernel, interaction, hardware
    |
    v
LGBMRegressor.predict()         <-- predicts TFLOPS for each candidate kernel
    |
    v
Sort by predicted TFLOPS        <-- rank all candidates
    |
    v
Select Top-1 kernel             <-- 98.28% mean efficiency, <1ms inference
```
Three models are trained per (op, dtype, arch):
- TFLOPS model (primary): used for kernel ranking
- Latency model (auxiliary): for latency-sensitive workloads
- Bandwidth model (auxiliary): for memory-bound analysis
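The selection flow above can be sketched in a few lines of Python. This is a hedged illustration: `Kernel`, `extract_batch`, `ToyModel`, and `rank_candidates` are hypothetical stand-ins (the real `FeatureEngine.extract_batch()` produces 55 features and the real model is an `LGBMRegressor`). The point is that every candidate is scored in one batch call and Top-1 is simply the highest prediction.

```python
import math
from dataclasses import dataclass


@dataclass
class Kernel:
    """Hypothetical candidate description: a name and an output tile shape."""
    name: str
    tile_m: int
    tile_n: int


def extract_batch(problem, kernels):
    """Hypothetical stand-in for FeatureEngine.extract_batch(): one row per kernel."""
    m, n, _k = problem
    rows = []
    for kern in kernels:
        eff_m = m / (math.ceil(m / kern.tile_m) * kern.tile_m)
        eff_n = n / (math.ceil(n / kern.tile_n) * kern.tile_n)
        rows.append([eff_m, eff_n])
    return rows


class ToyModel:
    """Stand-in for LGBMRegressor: anything with .predict(rows) -> scores works."""

    def predict(self, rows):
        return [100.0 * eff_m * eff_n for eff_m, eff_n in rows]


def rank_candidates(problem, kernels, model):
    """Score all candidates in one batch and sort by predicted TFLOPS, best first."""
    predictions = model.predict(extract_batch(problem, kernels))
    names = [kern.name for kern in kernels]
    return sorted(zip(names, predictions), key=lambda pair: pair[1], reverse=True)


candidates = [
    Kernel("k256x128", 256, 128),
    Kernel("k128x128", 128, 128),
    Kernel("k192x128", 192, 128),
]
ranking = rank_candidates((128, 1536, 7168), candidates, ToyModel())
best_name, best_tflops = ranking[0]  # the Top-1 selection
print(best_name)  # k128x128
```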
## File Inventory

| File | Purpose |
|---|---|
| `generate_benchmark_data.py` | Build and run benchmarks across ~25 diverse problem sizes, output JSON |
| `convert_json_to_parquet.py` | Convert benchmark JSON to parquet training format, fix `_mem` pad flags |
| `data_pipeline.py` | Parse raw benchmark logs into canonical parquet datasets |
| `feature_engine.py` | 55-feature extraction: problem, kernel, interaction, hardware profile |
| `train.py` | Multi-target LGBMRegressor training with GroupKFold CV, IHEM, warm-start |
| `predict.py` | Predictor class: predict TFLOPS/latency/bandwidth, rank kernels |
| `evaluate.py` | Full evaluation: global metrics, per-shape/layout/pipeline slices |
| `search.py` | Surrogate search: discrete DE, random top-K |
| `generate_wide_coverage.py` | Generate benchmark data across 706 diverse shapes |
| `generate_edge_dims.py` | Generate N=1, K=1, and other edge-case shapes |
| `DATA_GENERATION.md` | Detailed guide for building binaries and generating data |
| `plan.md` | Full design plan with architecture, milestones, and rationale |
## Features Used (55 total)

### Problem features (13)

M, N, K, split_k, log2(M), log2(N), log2(K), log2(MNK), arithmetic_intensity, aspect_ratio_mn, aspect_ratio_mk, aspect_ratio_nk, layout

### Kernel features (21)

tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k, pipeline, scheduler, epilogue, pad_m, pad_n, pad_k, persistent, num_warps, tile_volume, tile_mn, lds_usage_estimate, lds_usage_ratio

### Interaction features (9)

num_tiles_m, num_tiles_n, num_tiles_k, total_output_tiles, tile_eff_m, tile_eff_n, tile_eff_k, overall_tile_efficiency, cu_utilization

### Hardware profile features (12)

hw_num_cus, hw_simds_per_cu, hw_total_simds, hw_shader_engines, hw_max_clock_mhz, hw_max_waves_per_cu, hw_wavefront_size, hw_lds_capacity, hw_l1_cache_kb, hw_l2_cache_kb, hw_l3_cache_kb, hw_num_xcd
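The README does not spell out the formulas in `feature_engine.py`. A minimal sketch, assuming the conventional definitions (tiles per axis by ceiling division, efficiency as useful extent over the extent covered by whole tiles), shows why these interaction features matter: with M=128 and a 256-wide M tile, half of every M tile is padding. The `cu_utilization` definition here is a guess, and `num_cus=304` is only an example value.

```python
import math


def interaction_features(m, n, k, tile_m, tile_n, tile_k, num_cus):
    """Tile-count and tile-efficiency features under conventional (assumed) definitions."""
    num_tiles_m = math.ceil(m / tile_m)
    num_tiles_n = math.ceil(n / tile_n)
    num_tiles_k = math.ceil(k / tile_k)
    # Efficiency = useful extent / extent covered by whole tiles (1.0 means no padding).
    tile_eff_m = m / (num_tiles_m * tile_m)
    tile_eff_n = n / (num_tiles_n * tile_n)
    tile_eff_k = k / (num_tiles_k * tile_k)
    total_output_tiles = num_tiles_m * num_tiles_n
    return {
        "num_tiles_m": num_tiles_m,
        "num_tiles_n": num_tiles_n,
        "num_tiles_k": num_tiles_k,
        "total_output_tiles": total_output_tiles,
        "tile_eff_m": tile_eff_m,
        "tile_eff_n": tile_eff_n,
        "tile_eff_k": tile_eff_k,
        "overall_tile_efficiency": tile_eff_m * tile_eff_n * tile_eff_k,
        # Guessed definition: share of CUs that receive at least one output tile.
        "cu_utilization": min(1.0, total_output_tiles / num_cus),
    }


# M=128, N=1536, K=7168 (the predict.py example) with a 256x128x64 tile.
feats = interaction_features(128, 1536, 7168, 256, 128, 64, num_cus=304)
print(feats["tile_eff_m"], feats["total_output_tiles"])  # 0.5 12
```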
## Model Performance

### fp8 RCR, gfx950
| Metric | 108 shapes (original) | 168 shapes (wide coverage) |
|---|---|---|
| Mean TFLOPS Efficiency | 98.28% | 97.51% |
| P10 TFLOPS Efficiency | 94.64% | 93.89% |
| tiny_m (M=1) Efficiency | 95.57% | 96.04% |
| R² (TFLOPS) | 0.997 | 0.993 |
### fp16 RCR, gfx950
Trained on 25 shapes, 1,024 kernels, 21,920 valid benchmarks.
| Metric | Value |
|---|---|
| Mean TFLOPS Efficiency | 99.36% |
| P10 TFLOPS Efficiency | 98.05% |
| P50 TFLOPS Efficiency | 100.00% |
| Min Efficiency | 95.45% |
| NDCG@1 | 64.00% |
| Top-5 Hit Rate | 88.00% |
Shape Family Breakdown:
| Shape Family | Mean Eff | P10 Eff | Shapes |
|---|---|---|---|
| Large M (M≥1024) | 99.54% | 99.07% | 4 |
| Medium M (128≤M<1024) | 99.62% | 98.74% | 7 |
| Small M (8≤M<128) | 98.82% | 96.22% | 8 |
| Tiny M (M<8) | 99.65% | 98.96% | 6 |
Pipeline Breakdown:
| Pipeline | Mean Eff | P10 Eff |
|---|---|---|
| compv3 | 99.75% | 99.09% |
| compv4 | 99.40% | 98.54% |
| mem | 99.08% | 96.59% |
Training uses log1p(TFLOPS) as the target by default, which normalizes the
scale across shapes spanning 0.02 to 2230 TFLOPS. This was the key finding
that improved tiny-M shapes from 84% to 96% efficiency. See
LEARNINGS.md for details.
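A minimal standard-library sketch of the log1p target transform, with illustrative numbers (not training data). It also shows a useful property: because `expm1` is strictly increasing, ranking candidates by log-space predictions gives the same order as ranking by TFLOPS, so only the reported values need the inverse transform (the log-transform handling the Troubleshooting section asks you to check).

```python
import math

# Illustrative TFLOPS values spanning the range quoted above (0.02 to 2230).
tflops = [0.02, 1.5, 40.0, 2230.0]

# Training target: log1p squeezes about five orders of magnitude into roughly [0, 7.7].
# For values well above 1, a fixed error in target space is a roughly constant
# relative error, so large shapes no longer dominate the loss.
targets = [math.log1p(t) for t in tflops]

# The model predicts in log space; expm1 maps predictions back to TFLOPS.
recovered = [math.expm1(y) for y in targets]

# expm1 is strictly increasing, so both orderings agree.
order_log = sorted(range(len(targets)), key=lambda i: targets[i], reverse=True)
order_tflops = sorted(range(len(recovered)), key=lambda i: recovered[i], reverse=True)
print([round(y, 2) for y in targets])
```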
## Validation
Training uses GroupKFold(n_splits=5) with group key (M, N, K) to ensure
the model is evaluated on shapes it has never seen during training. Layout is
excluded from the group key to force cross-layout generalization.
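The invariant GroupKFold enforces here is that every row of a given (M, N, K) shape, in every layout, lands in the same fold. The sketch below reimplements that idea in plain Python so the invariant is visible. It is similar in spirit to scikit-learn's `GroupKFold`, but the helper `group_k_fold` and its greedy balancing rule are illustrative, not the library's exact algorithm.

```python
from collections import defaultdict


def group_k_fold(rows, n_splits=5):
    """Yield (train_idx, val_idx) pairs such that no (M, N, K) shape is in both.

    Groups are placed greedily, largest first, into the currently smallest fold.
    Layout is deliberately NOT part of the group key.
    """
    groups = defaultdict(list)
    for idx, row in enumerate(rows):
        groups[(row["M"], row["N"], row["K"])].append(idx)

    folds = [[] for _ in range(n_splits)]
    for members in sorted(groups.values(), key=len, reverse=True):
        min(folds, key=len).extend(members)

    for i, val_idx in enumerate(folds):
        train_idx = [j for f, fold in enumerate(folds) if f != i for j in fold]
        yield sorted(train_idx), sorted(val_idx)


# Six shapes, each benchmarked in two layouts: both layouts of a shape share a fold.
rows = [
    {"M": m, "N": 1536, "K": 7168, "layout": layout}
    for m in (1, 16, 128, 1024, 4096, 8192)
    for layout in ("rcr", "rrr")
]
for train_idx, val_idx in group_k_fold(rows):
    print(len(train_idx), len(val_idx))
```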
## Incremental Training (Warm Start)

When new benchmark data arrives, update the model without retraining from scratch:

```bash
python3 train.py \
    --data_dir data/ \
    --out_dir models/v2 \
    --warm_start models/gemm_universal_fp8_gfx950 \
    --warm_start_n_estimators 200
```

This adds 200 new trees on top of the existing model. Feature schemas must match exactly (this is enforced automatically).
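In LightGBM itself, continuing from an existing booster is done through the `init_model` argument of `lightgbm.train`; how `train.py` wires `--warm_start` to it is not shown here. The schema check can be sketched as follows, assuming a hypothetical `feature_spec.json` layout of `{"features": [...]}` and a hypothetical helper `check_feature_schema`. Order matters as well as names: when given plain arrays, a booster reads features by column position.

```python
import json


def check_feature_schema(base_spec, new_features):
    """Raise ValueError unless new_features matches the base model's list exactly.

    Both names and order must match; the same names in a different order
    would silently map values onto the wrong features.
    """
    base = list(base_spec["features"])
    new = list(new_features)
    if new != base:
        missing = [name for name in base if name not in new]
        extra = [name for name in new if name not in base]
        reordered = not missing and not extra
        raise ValueError(
            f"feature schema mismatch: missing={missing} extra={extra} reordered={reordered}"
        )


# Hypothetical feature_spec.json content.
spec = json.loads('{"features": ["M", "N", "K", "tile_m"]}')
check_feature_schema(spec, ["M", "N", "K", "tile_m"])  # identical: no error
```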
## Extending to New Ops

Adding support for a new operation (e.g., gemm_streamk, grouped_conv):

- Build binaries:
  ```bash
  ninja -C build benchmark_gemm_streamk_fp8_rcr
  ```
- Subclass `FeatureEngine`: add op-specific features (e.g., StreamK split factor).
- Generate data: run benchmarks across diverse shapes.
- Train:
  ```bash
  python3 train.py --op gemm_streamk --dtype fp8 --data_dir data/ --out_dir models/
  ```

The training, evaluation, prediction, and search infrastructure is fully op-agnostic; only the feature engine needs a new subclass.
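The `FeatureEngine` base-class interface is not documented in this README, so the following is a hypothetical sketch: the minimal base class, the method names `extract` and `feature_names`, and the `streamk_split_factor` key are all illustrative assumptions. It shows the intended pattern of reusing the shared features via `super()` and appending op-specific ones at the end, so the feature order stays stable for saved feature specs.

```python
class FeatureEngine:
    """Hypothetical minimal stand-in for the base class in feature_engine.py."""

    def extract(self, problem, kernel):
        return {
            "M": problem["M"],
            "N": problem["N"],
            "K": problem["K"],
            "tile_m": kernel["tile_m"],
            "tile_n": kernel["tile_n"],
        }

    def feature_names(self):
        return ["M", "N", "K", "tile_m", "tile_n"]


class StreamKFeatureEngine(FeatureEngine):
    """Adds a StreamK-specific feature after the shared ones."""

    def extract(self, problem, kernel):
        feats = super().extract(problem, kernel)
        # Hypothetical feature: how many workgroups share one tile's K loop.
        feats["streamk_split_factor"] = kernel.get("streamk_split_factor", 1)
        return feats

    def feature_names(self):
        return super().feature_names() + ["streamk_split_factor"]


engine = StreamKFeatureEngine()
row = engine.extract(
    {"M": 128, "N": 256, "K": 512},
    {"tile_m": 64, "tile_n": 64, "streamk_split_factor": 4},
)
print(row)
```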
## Tests

102 tests covering all modules:

```bash
python3 -m pytest tests/ -v
```
Test coverage includes:
- Log parsing with malformed JSON, empty logs, single-kernel shapes
- Feature formula correctness (tile efficiency, LDS usage, arithmetic intensity)
- Corner-case shapes: M=1, N=1, K=1, prime dimensions, 20480x7168x256
- Batch vs single extraction parity
- Parameter space validation and projection
- Predictor: single/batch prediction, ranking, missing models, empty inputs
- Training: group keys, efficiency computation, warm-start, feature compatibility
- Search: random, DE, config validity, determinism
## Documentation
- README.md: This file -- quick start, architecture, performance
- DATA_GENERATION.md: Complete guide for building tile engine binaries, running benchmarks, managing datasets, and troubleshooting
- LEARNINGS.md: Empirical findings and design decisions (log-transform, IHEM results, tiny-M analysis, feature importance, N=1/K=1 edge cases)
# Grouped Convolution ML Heuristics

## Overview

ML-based kernel selection for grouped convolution operations (forward, bwd_data, bwd_weight) on gfx950 with bf16 precision.

## Results

### Forward Pass Model

- Training Data: 48,845 measurements across 1,372 unique problem shapes
- Validation Set: 300 unseen problems from the model crawler
- Validation Performance (vs. oracle):
  - Mean Efficiency: 93.05%
  - Median Efficiency: 96.8%
  - P10 Efficiency: 79.9%
### Backward Data Gradient (bwd_data) Model

- Training Data: 18,773 measurements across 891 unique problem shapes
- Validation Set: 300 unseen problems from the model crawler
- Validation Performance (vs. oracle):
  - Mean Efficiency: 93.8%
  - Median Efficiency: 96.5%
  - P10 Efficiency: 82.9%
  - Top-1 Accuracy: 25.2% (37/147 problems)
### Backward Weight Gradient (bwd_weight) Model

- Training Data: 34,900 measurements across 1,508 unique problem shapes
- Validation Set: 300 unseen problems from the model crawler
- Validation Performance (vs. oracle):
  - Mean Efficiency: 96.1%
  - Median Efficiency: 99.2%
  - P10 Efficiency: 89.4%
  - Top-1 Accuracy: 32.7% (51/156 problems)
## Training Data Generation
Extended synthetic problem sets for backward passes cover diverse scenarios:
- Small spatial (7×7, 14×14) + various channels (64-1024)
- Medium spatial (28×28, 32×32, 56×56) + various channels (32-512)
- Large spatial (112×112) + small/medium channels (16-256)
- Asymmetric C/K combinations
- Small and large batch sizes (N=1 to 128)
- Grouped convolutions (G=2, 4, 8)
- Depthwise convolutions (G=C=K)
- Stride-2 downsampling
## Model Files

Trained models are stored in:

- `models/grouped_conv_forward_bf16_gfx950/`
- `models/grouped_conv_bwd_data_bf16_gfx950/`
- `models/grouped_conv_bwd_weight_bf16_gfx950/`

Each contains:

- `model_tflops.lgbm`: LightGBM model (stored gzip-compressed as `model_tflops.lgbm.gz`)
- `feature_spec.json`: Feature configuration
- `cv_metrics_tflops.json`: Cross-validation metrics
- `feature_importances_tflops.json`: Feature importance rankings

Models are automatically decompressed on first use.
## Usage

```python
import pandas as pd

from predict import Predictor
from feature_engine_grouped_conv import GroupedConvFeatureEngine

# Define problem
problem = {
    'N': 16, 'C': 256, 'K': 128, 'G': 1,
    'Hi': 28, 'Wi': 28, 'Y': 3, 'X': 3,
    'stride_h': 1, 'stride_w': 1,
    'pad_h': 1, 'pad_w': 1,
    'dtype': 'bf16'
}

# Load model with the grouped-conv feature engine
predictor = Predictor(
    "models/grouped_conv_bwd_data_bf16_gfx950",
    feature_engine=GroupedConvFeatureEngine(),
)

# Build the candidate kernel pool from a training/holdout parquet
# (each row carries kernel_name + every kernel-config column the engine needs).
df = pd.read_parquet("data/grouped_conv_bwd_data/bwd_data.parquet")
configs = [
    df[df["kernel_name"] == kn].iloc[0].to_dict()
    for kn in df["kernel_name"].unique()
]

# Rank candidates by predicted TFLOPS
ranked = predictor.rank_kernels(problem, configs)
best_name, best_tflops = ranked[0]
print(f"Best kernel: {best_name}")
print(f"Predicted TFLOPS: {best_tflops:.2f}")
```
## Validation

Run validation against oracle benchmarks:

```bash
cd projects/composablekernel/tile_engine/ops/grouped_conv
python3 validate_ml_vs_oracle.py --variant bwd_data
python3 validate_ml_vs_oracle.py --variant bwd_weight
```
## Solution Architecture (Grouped Conv)

```text
Problem Config → Feature Engineering (83 features) → LightGBM Model → Predict TFLOPS → Select Best Kernel
      ↓                                                    ↓                                 ↓
(N,C,K,G,H,W,Y,X)    - Problem features (38)        Trained on 48K samples,         <1ms total latency
                     - Kernel features (12)         1372 shapes
                     - Interactions (21)
                     - Hardware (12)
```
### Feature Engineering (feature_engine_grouped_conv.py)
83 engineered features:
- Problem Features (38): Raw params (N,C,K,G,Hi,Wi,Y,X,strides,pads), derived (Ho,Wo), log-scale transforms, arithmetic intensity, aspect ratios, channel/group metrics
- Kernel Features (12): Block size, GEMM tiles (M,N), pipeline type, num warps, tile volume, LDS usage
- Interaction Features (21): Tile efficiency (M,N,K), block-tile ratios, CU utilization, problem-tile comparisons, output tile counts
- Hardware Features (12): GFX950 specs - CUs (304), SIMDs, clocks, wavefront size, cache sizes (L1/L2/L3), XCD count
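The derived features (Ho, Wo) and the tile-efficiency interactions rest on viewing the convolution as an implicit GEMM. A minimal sketch for the forward direction only, assuming C and K denote total channels (so depthwise means G=C=K, as in the training-data list above) and unit dilation by default. The helper names `conv_out_size` and `forward_implicit_gemm` are hypothetical, the exact definitions in `feature_engine_grouped_conv.py` may differ, and the backward directions map the GEMM axes differently.

```python
def conv_out_size(size, kernel, stride, pad, dilation=1):
    """Standard convolution output-size formula."""
    return (size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1


def forward_implicit_gemm(p):
    """Per-group forward implicit GEMM sizes, assuming C and K are total channel counts."""
    ho = conv_out_size(p["Hi"], p["Y"], p["stride_h"], p["pad_h"])
    wo = conv_out_size(p["Wi"], p["X"], p["stride_w"], p["pad_w"])
    gemm_m = p["N"] * ho * wo                  # one GEMM row per output pixel
    gemm_n = p["K"] // p["G"]                  # output channels per group
    gemm_k = (p["C"] // p["G"]) * p["Y"] * p["X"]  # reduction over input channels and filter taps
    flops = 2 * p["G"] * gemm_m * gemm_n * gemm_k
    return {"Ho": ho, "Wo": wo, "gemm_m": gemm_m, "gemm_n": gemm_n,
            "gemm_k": gemm_k, "flops": flops}


# The problem from the Usage example: 28x28 input, 3x3 filter, stride 1, pad 1.
problem = {
    "N": 16, "C": 256, "K": 128, "G": 1,
    "Hi": 28, "Wi": 28, "Y": 3, "X": 3,
    "stride_h": 1, "stride_w": 1, "pad_h": 1, "pad_w": 1,
}
print(forward_implicit_gemm(problem))
```

The depthwise case shows why such problems are hard to tile: with G=C=K, the per-group GEMM has N=1 and K=Y*X, so most of a GEMM tile is padding.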
### Latency
- Selection Time: <1ms
- vs Oracle: 30-60 seconds
- Speedup: 30,000-60,000×
### Model Size
- Compressed: 2-8 MB (.lgbm.gz)
- Runtime Memory: ~50 MB
- Feature Array: <6 KB per problem
## Training Pipeline

```bash
# 1. Collect data: Run all kernels on GPU for diverse problem set
python grouped_conv_full_benchmark.py --problem_set forward_training_miopen

# 2. Preprocess: Convert CSV to Parquet
python convert_csv_to_parquet.py --input train.csv --output train.parquet

# 3. Train model: LightGBM with cross-validation
python train.py --operation grouped_conv --direction forward --dtype bf16

# 4. Validate: Sanity-check on training shapes
python validation/grouped_conv/validate_training_shapes.py
```
## Validation Framework

| Test | Purpose | Shapes | Runtime | Target |
|---|---|---|---|---|
| `validate_training_shapes.py` | Sanity check on training data | 5 | 5-10 min | >95% efficiency |
| `validate_backward_models.py` | Backward pass prediction quality | 7 | <1 min | Reasonable predictions |
## File Structure (Grouped Conv)

```text
dispatcher/heuristics/
├── train.py                           # Training script
├── feature_engine_grouped_conv.py     # Feature engineering
├── predict.py                         # Generic Predictor (use with GroupedConvFeatureEngine)
├── models/
│   ├── grouped_conv_forward_bf16_gfx950/
│   │   ├── model_tflops.lgbm.gz       # Compressed model
│   │   ├── feature_spec.json          # Feature definitions
│   │   └── train_manifest.json        # Training metadata
│   ├── grouped_conv_bwd_data_bf16_gfx950/
│   └── grouped_conv_bwd_weight_bf16_gfx950/
└── validation/
    ├── validate_ml_heuristic.py       # GEMM validation
    └── grouped_conv/
        ├── validate_training_shapes.py
        └── validate_backward_models.py

tile_engine/ops/grouped_conv/
├── grouped_conv_full_benchmark.py     # Data collection
├── run_one_grouped_conv_kernel.py     # Single kernel runner
├── compare_ml_vs_oracle.py            # Analysis tool
└── problems/
    ├── forward_training_miopen.py     # Training problem sets
    └── forward_validation_300.py      # Test problem sets
```
## C++/Python Integration

- C++ API: `GroupedConvRegistry::get_solution(problem)`
- Python API: `registry.run(problem, input, weight)`
- Automatic fallback to exhaustive search if the ML model is unavailable

```python
from ck_tile.dispatcher import GroupedConvRegistry, GroupedConvProblem

# Define problem
problem = GroupedConvProblem(
    N=2, C=128, K=256, G=1,
    Hi=28, Wi=28, Y=3, X=3,
    stride_h=1, stride_w=1, pad_h=1, pad_w=1,
    dtype='bf16', direction='forward'
)

# ML heuristic automatically selects best kernel
# (input_tensor and weight_tensor are device tensors prepared by the caller)
registry = GroupedConvRegistry(arch='gfx950')
result = registry.run(problem, input_tensor, weight_tensor)
```
## Key Innovations
- Comprehensive Feature Engineering: 83 features capture problem-kernel-hardware interactions
- Tier-1 Extended Training: 1,372 shapes (vs 185 baseline) for better edge case coverage
- Compressed Models: `.lgbm.gz` reduces size 8-10× without accuracy loss
- Operation-Specific Models: Separate optimizations for forward/backward passes
- Validation Framework: Automated testing on unseen production workloads
## Verifying Training Quality

To quickly verify that a refactored `train.py` produces models of equivalent quality to the production training script:

```bash
cd /workspace/rocm-libraries/projects/composablekernel/dispatcher/heuristics
# Run automated test (uses 3-fold CV for speed)
./test_model_quality.sh
```

This script will:

- Validate the current production model on 300 validation shapes
- Train a new model using the refactored `train.py`
- Validate the new model on the same 300 shapes
- Compare predictions between the old and new models
Expected Output:

```text
Step 4: Comparing predictions...
================================================================================
PREDICTION COMPARISON: bwd_data
================================================================================
Kernel Selection Agreement: 215/300 (71.7%)

Metric                  Old Model    New Model    Delta
----------------------------------------------------------------------
Mean Efficiency         0.9380       0.9380       +0.0000
Median Efficiency       0.9650       0.9650       +0.0000
P10 Efficiency          0.8290       0.8290       +0.0000

Per-Problem Changes:
  Improved: 0 (0.0%)
  Same: 300 (100.0%)
  Degraded: 0 (0.0%)
================================================================================
✓ PASS: New model maintains quality!
================================================================================
```
### Model Selection Process

The validation script (`validate_ml_vs_oracle.py`) automatically selects the model based on:

- Variant: `--variant {forward|bwd_data|bwd_weight}`
- Model Path: `dispatcher/heuristics/models/grouped_conv_{variant}_bf16_gfx950/`

For example:

- `--variant bwd_data` → uses `models/grouped_conv_bwd_data_bf16_gfx950/model_tflops.lgbm`
- `--variant bwd_weight` → uses `models/grouped_conv_bwd_weight_bf16_gfx950/model_tflops.lgbm`
### Manual Step-by-Step Comparison

If you want to run each step manually:

#### Step 1: Validate Current Model

```bash
cd tile_engine/ops/grouped_conv
python3 validate_ml_vs_oracle.py \
    --operation grouped_conv \
    --variant bwd_data \
    --problem-set bwd_data_model_crawler_validation \
    --oracle-csv bwd_data_model_crawler_oracle.csv \
    --save-predictions /tmp/bwd_data_old_predictions.csv
```

This uses the model at `dispatcher/heuristics/models/grouped_conv_bwd_data_bf16_gfx950/`.
#### Step 2: Train New Model

```bash
cd ../../dispatcher/heuristics
python3 train.py \
    --operation grouped_conv \
    --data_dir data/bwd_data_training \
    --out_dir /tmp/grouped_conv_bwd_data_bf16_gfx950_new \
    --dtype bf16 \
    --arch gfx950 \
    --targets tflops \
    --n_splits 5
```
#### Step 3: Temporarily Swap Models

```bash
# Backup current model
mv models/grouped_conv_bwd_data_bf16_gfx950 /tmp/backup
# Use new model for validation
cp -r /tmp/grouped_conv_bwd_data_bf16_gfx950_new models/grouped_conv_bwd_data_bf16_gfx950
```
#### Step 4: Validate New Model

```bash
cd ../../tile_engine/ops/grouped_conv
python3 validate_ml_vs_oracle.py \
    --operation grouped_conv \
    --variant bwd_data \
    --problem-set bwd_data_model_crawler_validation \
    --oracle-csv bwd_data_model_crawler_oracle.csv \
    --save-predictions /tmp/bwd_data_new_predictions.csv
```
#### Step 5: Restore Original Model

```bash
cd ../../dispatcher/heuristics
rm -rf models/grouped_conv_bwd_data_bf16_gfx950
mv /tmp/backup models/grouped_conv_bwd_data_bf16_gfx950
```
#### Step 6: Compare Predictions

```bash
cd ../../tile_engine/ops/grouped_conv
python3 compare_model_predictions.py \
    --old-predictions /tmp/bwd_data_old_predictions.csv \
    --new-predictions /tmp/bwd_data_new_predictions.csv \
    --variant bwd_data
```
### Acceptance Criteria
A new model passes quality validation if:
- ✓ Mean efficiency is within 0.5% of baseline
- ✓ Median efficiency is within 0.5% of baseline
- ✓ P10 efficiency is within 2% of baseline
- ✓ No catastrophic regressions (efficiency drops >10% on any problem)
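The criteria above can be expressed as a small check. A hedged sketch: it assumes efficiencies are fractions in [0, 1], reads "within 0.5%" and "within 2%" as absolute differences of 0.005 and 0.02 in either direction (as "within" is written), and reads "drops >10%" as an absolute drop of more than 0.10 on a single problem. A problem missing from the new results is treated as a full regression. `passes_acceptance` is a hypothetical helper, not part of `compare_model_predictions.py`.

```python
def passes_acceptance(old, new, old_eff, new_eff):
    """Return (passed, failures) for the acceptance criteria.

    old/new: dicts with aggregate 'mean', 'median', 'p10' efficiency.
    old_eff/new_eff: per-problem efficiency keyed by problem id.
    """
    failures = []
    for metric, tol in (("mean", 0.005), ("median", 0.005), ("p10", 0.02)):
        delta = new[metric] - old[metric]
        if abs(delta) > tol:
            failures.append(f"{metric} moved by {delta:+.4f} (tolerance {tol})")
    for pid, base in old_eff.items():
        drop = base - new_eff.get(pid, 0.0)
        if drop > 0.10:
            failures.append(f"problem {pid} dropped by {drop:.3f}")
    return not failures, failures


# Aggregate numbers from the expected output above, with an unchanged new model.
baseline = {"mean": 0.9380, "median": 0.9650, "p10": 0.8290}
ok, why = passes_acceptance(baseline, dict(baseline), {"p1": 0.95}, {"p1": 0.95})
print(ok, why)  # True []
```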
### Troubleshooting

#### Different Predictions on Same Model

This is unlikely. If the same model file produces different predictions, check:

- Feature engine version (should be 83 features)
- Problem encoding (verify `problem_to_dict` matches)
- Predictor initialization (check log transform handling)

#### Quality Regression

If the new model has lower efficiency:

- Check CV metrics in the training log; they should be similar to the baseline
- Verify identical training data (check parquet row counts)
- Compare feature importance; the patterns should be similar
- Inspect specific regression cases in the comparison output