# [CK][CK TILE] Autotuning heuristics infra for universal GEMM kernel selection (#5676)

## Motivation

This PR adds ML-based kernel-selection heuristics to the CK Tile dispatcher, enabling fast and accurate automatic kernel selection for Universal GEMM kernels. Instead of requiring an exhaustive search through 4600+ kernel configurations (~46 seconds per problem shape), the ML heuristic predicts optimal kernels in microseconds while achieving >98% of oracle-best performance.

## Technical Details

**ML infrastructure** (https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics):

* Feature Engine ([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)): 55-feature extraction covering problem dimensions, kernel configuration, tile efficiency, and hardware profile
* Training Pipeline ([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)): LightGBM regression with log-transform, GroupKFold cross-validation, and warm-start support
* Predictor ([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)): kernel ranking and TFLOPS prediction for problem shapes
* Evaluation ([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)): comprehensive metrics including efficiency, NDCG@k, and shape-family analysis

**Data Generation Tools:**

* [generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py): build and benchmark kernels across diverse problem shapes
* [convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py): convert benchmark JSON to training-ready Parquet format
* [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py): parse streaming benchmark logs into canonical datasets

**Examples:**

* [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp): C++ example demonstrating ML-based kernel selection
* [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py): Python example with validation

**Pre-trained Models** (projects/composablekernel/dispatcher/heuristics/models/):

* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean efficiency)

## Test Plan

* Evaluated on 25 diverse shapes for fp16 and 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M (M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**

* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (90% of shapes achieve ≥98% of oracle-best performance)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**

* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
# Learnings and Design Decisions
Empirical findings from building the CK Tile kernel performance prediction system. These inform the current defaults and explain why certain approaches were chosen.
## 1. Log-Transform is Essential for Cross-Scale Accuracy
Problem: GEMM TFLOPS spans 5 orders of magnitude across different problem sizes. When training on raw TFLOPS, the regression loss (RMSE) is dominated by large shapes where absolute errors are biggest. The model learns to predict large shapes accurately but ignores tiny shapes where the TFLOPS values are much lower.
Evidence (168 shapes, 626K rows, 5-fold GroupKFold CV):
| Model | Mean Eff | P10 Eff | tiny_m Eff | Min Eff |
|---|---|---|---|---|
| Raw TFLOPS (500 trees) | 92.73% | 80.24% | 84.55% | 4.26% |
| log1p(TFLOPS) (500 trees) | 96.92% | 94.34% | 94.89% | 60.27% |
| log1p(TFLOPS) (2000 trees) | 97.51% | 93.89% | 96.04% | 63.56% |
Solution: Train on `log1p(measured_tflops)` and apply `expm1()` to predictions. This is now the default in `train.py`. Pass `--no_log_transform` to revert to raw regression (not recommended).

Why `log1p`, not `log`: `log1p(x) = log(1 + x)` handles zero and near-zero TFLOPS gracefully, whereas `log(x)` produces `-inf` for `x = 0`.
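As a concrete illustration, here is a minimal sketch of the transform round-trip with a LightGBM regressor. The data file and `measured_tflops` column name are assumptions for the example; the real pipeline lives in `train.py`.

```python
# Minimal sketch of log-target regression: fit in log space, predict in TFLOPS.
import lightgbm as lgb
import numpy as np
import pandas as pd

df = pd.read_parquet("benchmarks.parquet")        # assumed training data
X = df.drop(columns=["measured_tflops"])          # assumed numeric feature columns
y = np.log1p(df["measured_tflops"])               # compress 5 orders of magnitude

model = lgb.LGBMRegressor(n_estimators=2000, num_leaves=255,
                          max_depth=15, learning_rate=0.02)
model.fit(X, y)

pred_tflops = np.expm1(model.predict(X))          # invert: expm1(log1p(x)) == x
```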
## 2. Tiny-M Shapes are the Hardest Case
M=1 (single-token inference) shapes are fundamentally different from batch shapes:
- Most kernel configurations produce very low TFLOPS
- The "best" kernel is often only marginally better than the rest
- The oracle performance itself is very low, so any prediction error tanks efficiency
- Many kernels fail outright (tile_m=128 with M=1 wastes 127/128 of the tile)
The bottom shapes in our evaluation are all M=1, with efficiencies in the 63-70% range. These shapes have such low absolute performance that the model's noise floor exceeds the performance difference between kernels.
Mitigation: Log-transform helps significantly (tiny_m improved from 84% to 96%). For production use with M=1, consider a dedicated fallback (e.g., hardcoded kernel selection for M < 4 based on known-good configs).
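A sketch of what such a fallback could look like; the threshold, kernel name, and predictor API below are illustrative assumptions, not part of the shipped dispatcher:

```python
# Hypothetical guard: route tiny-M problems to a known-good kernel instead of
# trusting the ML ranking near its noise floor.
TINY_M_THRESHOLD = 4
TINY_M_FALLBACK = "gemm_universal_mem_16x64"    # illustrative kernel name

def select_kernel(m: int, n: int, k: int, predictor) -> str:
    if m < TINY_M_THRESHOLD:
        return TINY_M_FALLBACK
    return predictor.rank_kernels(m, n, k)[0]   # assumed predict.py-style API
```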
## 3. IHEM (Iterative Hard Example Mining) Hurts When Scale is the Issue
We tried Iterative Hard Example Mining with sample reweighting (2x-5x weight on hard shapes). Result: it made things worse, degrading mean efficiency from 94.31% to 92.90% over 3 iterations.
Why: The hard shapes are hard because of scale mismatch, not because the model lacks capacity. Reweighting amplifies the small-TFLOPS rows, which distorts the learned relationship between features and performance for the majority of shapes. The log-transform was the correct fix -- it addresses the root cause (scale) rather than the symptom (bad predictions on tiny shapes).
Lesson: IHEM is useful when the model has capacity gaps (e.g., certain pipeline types are underrepresented). It is counterproductive when the issue is target-variable scale. Always try target transforms before reweighting.
## 4. GroupKFold Key = (M, N, K) Forces Generalization
The validation uses GroupKFold where the group key is (M, N, K): all kernels for the same shape go to the same fold. This means:
- The model is always evaluated on shapes it has never seen during training
- Layout is excluded from the key, forcing the model to generalize across layouts
- Since models are per-arch, arch is implicit (constant within one training run)
This is much stricter than random row splitting, where the model would see some kernels for each shape during training. Our efficiency numbers are conservative estimates of real-world performance on unseen shapes.
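In scikit-learn terms, the split looks roughly like this (data file and column names are assumed for the sketch):

```python
# Grouped CV sketch: all rows sharing an (M, N, K) shape stay in one fold.
from sklearn.model_selection import GroupKFold
import pandas as pd

df = pd.read_parquet("benchmarks.parquet")      # assumed training data
groups = df["M"].astype(str) + "x" + df["N"].astype(str) + "x" + df["K"].astype(str)

for train_idx, val_idx in GroupKFold(n_splits=5).split(df, groups=groups):
    train_df, val_df = df.iloc[train_idx], df.iloc[val_idx]
    # val_df contains only shapes the model never saw during training
```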
## 5. Model Size vs. Accuracy Tradeoff
| Config | Trees | Leaves | LR | Mean Eff | P10 Eff | Train Time |
|---|---|---|---|---|---|---|
| Small (default v1) | 500 | 127 | 0.05 | 96.92% | 94.34% | ~20s |
| Big (current) | 2000 | 255 | 0.02 | 97.51% | 93.89% | ~25s/fold |
The bigger model improved mean efficiency by 0.6% but P10 didn't improve (actually slightly worse). The extra capacity helps on medium shapes but doesn't crack the tiny-M floor. This suggests the feature set, not model capacity, is the limiting factor for the hardest shapes.
For C++ deployment, the bigger model (2000 trees, 255 leaves) is still fast enough -- LightGBM inference is O(trees * log(leaves)) per sample, which is ~microseconds even at 2000 trees.
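That latency claim is easy to sanity-check against a trained booster; a rough sketch (the model path is illustrative):

```python
# Rough single-sample inference latency for a 2000-tree LightGBM model.
import time
import numpy as np
import lightgbm as lgb

booster = lgb.Booster(model_file="model.txt")   # illustrative path
row = np.random.rand(1, booster.num_feature())  # one random feature vector

start = time.perf_counter()
for _ in range(1000):
    booster.predict(row)
elapsed = time.perf_counter() - start
print(f"~{elapsed / 1000 * 1e6:.1f} us per prediction")
```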
## 6. N=1 and K=1 Shapes are Degenerate
We generated benchmark data for 546 edge-case shapes (N=1, K=1, small N/K). Result: zero valid kernel results across 94 shapes. All 4608 kernels either fail or produce 0 TFLOPS for these degenerate dimensions.
This means:
- The tile engine kernels have hard minimum dimension requirements
- N=1 / K=1 shapes cannot be handled by the current kernel set
- These shapes need dedicated kernels (e.g., BLAS-1/BLAS-2 fallbacks)
- The ML model should not be expected to handle them -- they should be filtered out before reaching the heuristic
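The filtering can be as simple as a guard in front of the heuristic; a sketch, not the dispatcher's actual code:

```python
# Hypothetical pre-filter: degenerate shapes never reach the ML heuristic.
def heuristic_can_handle(m: int, n: int, k: int) -> bool:
    # N=1 / K=1 produced zero valid kernels in benchmarking; send these to a
    # dedicated fallback (e.g., a BLAS-1/BLAS-2 path) instead.
    return n > 1 and k > 1
```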
## 7. Feature Engineering Insights
From LightGBM feature importances on the log-target model:
Top features (by split count):
- `M`, `N`, `K` -- raw dimensions are always the most important
- `tile_m`, `tile_n`, `tile_k` -- the tile shape is the primary kernel differentiator
- `overall_tile_efficiency` -- how well the shape fits the tile (the interaction)
- `num_tiles_m`, `total_output_tiles` -- work decomposition
- `arithmetic_intensity` -- compute-bound vs memory-bound regime
- `pipeline` -- pipeline type (compv3 vs compv4 vs mem) significantly affects performance
Low-importance features:
- Hardware constants (CUs, clock, caches) -- they're constant within one arch model, so they provide no discriminative signal. They'll become important when training cross-arch models.
- `split_k` -- always 1 in current data
- `persistent` -- rarely True in current kernel set
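These importances come straight out of the trained booster; a sketch of the extraction (model path illustrative):

```python
# Inspect split-count feature importances of a trained LightGBM model.
import lightgbm as lgb
import pandas as pd

booster = lgb.Booster(model_file="model.txt")   # illustrative path
importances = pd.Series(booster.feature_importance(importance_type="split"),
                        index=booster.feature_name())
print(importances.sort_values(ascending=False).head(10))
```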
## 8. Warm-Start Works for Incremental Updates
LightGBM's `init_model` parameter successfully continues training from an existing model: new trees are added on top of the existing ones. Key considerations:
- Feature schema must match exactly (enforced by `check_feature_compatibility`)
- Use fewer new trees (200-500) since we're refining, not starting fresh
- `train_manifest.json` tracks the full lineage (total trees, data sizes)
- Quality should be at least as good as the base model (tested)
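A sketch of the warm-start call, with the paths, data file, and column name assumed for the example:

```python
# Continue training from an existing booster: new trees stack on the old ones.
import lightgbm as lgb
import numpy as np
import pandas as pd

df = pd.read_parquet("new_benchmarks.parquet")  # assumed incremental data
train_set = lgb.Dataset(df.drop(columns=["measured_tflops"]),
                        label=np.log1p(df["measured_tflops"]))

booster = lgb.train(
    {"objective": "regression", "num_leaves": 255, "learning_rate": 0.02},
    train_set,
    num_boost_round=500,                 # new trees only, refining the base
    init_model="base_model.txt",         # illustrative path to the base model
)
booster.save_model("updated_model.txt")
```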
## 9. Data Volume Matters More Than Model Complexity
| Dataset | Shapes | Rows | Mean Eff (log, 500 trees) |
|---|---|---|---|
| Original (DeepSeek only) | 108 | 418K | 98.28% (on seen distribution) |
| + Wide coverage | 168 | 626K | 96.92% (harder distribution) |

The original 108-shape model looked great (98.28%) but was overfitting to the DeepSeek LLM inference distribution. Adding 60 diverse shapes (many M=1) exposed the model's weakness on tiny shapes. More diverse training data is always better than a bigger model on narrow data.

## Summary of Defaults

Based on these findings, the current defaults in `train.py` are (sketched as LightGBM parameters below):
- Target transform: `log1p` for TFLOPS and bandwidth (scale normalization)
- Model: 2000 trees, 255 leaves, max depth 15, learning rate 0.02
- Validation: 5-fold GroupKFold, key = (M, N, K)
- Early stopping: patience 100 (let trees fully converge)
- Warm start: 500 new trees (was 200, increased for the bigger base model)
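Expressed as LightGBM parameters, the configuration is roughly the following; this is a sketch of the equivalent setup, not a copy of `train.py`:

```python
# Approximate LightGBM setup matching the defaults listed above.
import lightgbm as lgb

params = {
    "objective": "regression",
    "num_leaves": 255,
    "max_depth": 15,
    "learning_rate": 0.02,
}
# booster = lgb.train(params, train_set, num_boost_round=2000,
#                     valid_sets=[valid_set],
#                     callbacks=[lgb.early_stopping(stopping_rounds=100)])
```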