composable_kernel/dispatcher/heuristics/LEARNINGS.md
Yaswanth Raparti 91dbdfa476 [CK][CK TILE]Autotuning heuristics infra for universal GEMM kernel selection (#5676)
## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile
dispatcher, enabling fast and accurate automatic kernel selection for
Universal Gemm kernels. Instead of requiring exhaustive search through
4600+ kernel configurations (taking ~46 seconds per problem shape), the
ML heuristic predicts optimal kernels in microseconds while achieving
>98% of oracle-best performance.

## Technical Details

**ML Infrastructure:**

https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics
* Feature Engine
([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)):
55-feature extraction including problem dimensions, kernel
configuration, tile efficiency, and hardware profile
* Training Pipeline
([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)):
LightGBM regression with log-transform, GroupKFold cross-validation,
warm-start support
* Predictor
([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)):
Kernel ranking and TFLOPS prediction for problem shapes
* Evaluation
([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)):
Comprehensive metrics including efficiency, NDCG@k, shape family
analysis
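
NDCG@k scores how well the predicted ranking orders the kernels that actually run fastest. The sketch below is minimal and assumes each kernel's measured TFLOPS is used directly as its relevance. The gain definition in `evaluate.py` may differ, and `dcg` and `ndcg_at_k` are hypothetical helper names.

```python
import math


def dcg(gains: list[float]) -> float:
    """Discounted cumulative gain: the item at 0-based position i is divided by log2(i + 2)."""
    return sum(g / math.log2(i + 2) for i, g in enumerate(gains))


def ndcg_at_k(predicted: list[float], measured: list[float], k: int) -> float:
    """NDCG@k for one problem shape.

    predicted[i] and measured[i] are the predicted and measured TFLOPS of kernel i.
    Measured TFLOPS is used directly as the relevance (gain) of each kernel.
    """
    # Order kernels by the model's prediction, best first, and score the top k.
    order = sorted(range(len(predicted)), key=lambda i: predicted[i], reverse=True)
    actual = dcg([measured[i] for i in order[:k]])
    # Ideal DCG: the k truly fastest kernels in their true order.
    ideal = dcg(sorted(measured, reverse=True)[:k])
    return actual / ideal if ideal > 0 else 0.0


# Example: three kernels for one shape.
perfect = ndcg_at_k(predicted=[90.0, 120.0, 60.0], measured=[100.0, 110.0, 50.0], k=2)
swapped = ndcg_at_k(predicted=[120.0, 90.0, 60.0], measured=[100.0, 110.0, 50.0], k=2)
```

In `perfect` the predicted order matches the measured order, so the score is 1.0. In `swapped` the top two kernels are inverted, so the score drops slightly below 1.0.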

**Data Generation Tools:**

* [generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py):
Build and benchmark kernels across diverse problem shapes
* [convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py):
Convert benchmark JSON to training-ready parquet format
* [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py):
Parse streaming benchmark logs into canonical datasets

**Examples:**

* [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp):
C++ example demonstrating ML-based kernel selection
* [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py):
Python example with validation


**Pre-trained Models
(projects/composablekernel/dispatcher/heuristics/models/):**
* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean
efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean
efficiency)


## Test Plan

* Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M
(M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**
* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (90% of shapes achieve ≥98.05% of oracle
best)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**
* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%
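
For a single shape, efficiency is the measured TFLOPS of the kernel the heuristic selects, divided by the oracle-best measured TFLOPS for that shape. P10 is the 10th percentile of those per-shape efficiencies. The sketch below computes both. It assumes linear interpolation between sorted values for the percentile, since the percentile method used by `evaluate.py` is not stated. The function names are hypothetical.

```python
def selection_efficiency(predicted: list[float], measured: list[float]) -> float:
    """Measured TFLOPS of the top-predicted kernel over the oracle (best measured) TFLOPS."""
    chosen = max(range(len(predicted)), key=lambda i: predicted[i])
    return measured[chosen] / max(measured)


def percentile(values: list[float], q: float) -> float:
    """q-th percentile (0-100), linearly interpolating between sorted values."""
    xs = sorted(values)
    pos = (len(xs) - 1) * q / 100.0
    lo = int(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


# One shape: the model picks kernel 1, which reaches 95% of the best kernel.
eff = selection_efficiency(predicted=[1.0, 3.0, 2.0], measured=[100.0, 95.0, 80.0])

# Across shapes: mean and P10 of per-shape efficiencies (made-up values).
per_shape = [0.90, 0.95, 1.0, 1.0, 1.0]
mean_eff = sum(per_shape) / len(per_shape)
p10_eff = percentile(per_shape, 10)
```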

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Vidyasagar Ananthan <vidyasagar.ananthan@amd.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-01 19:25:55 -07:00

# Learnings and Design Decisions
Empirical findings from building the CK Tile kernel performance prediction system.
These inform the current defaults and explain why certain approaches were chosen.
## 1. Log-Transform is Essential for Cross-Scale Accuracy
**Problem**: GEMM TFLOPS spans 5 orders of magnitude across different problem
sizes. When training on raw TFLOPS, the regression loss (RMSE) is dominated by
large shapes where absolute errors are biggest. The model learns to predict
large shapes accurately but ignores tiny shapes where the TFLOPS values are
much lower.

**Evidence** (168 shapes, 626K rows, 5-fold GroupKFold CV):

| Model                         | Mean Eff   | P10 Eff    | tiny_m Eff | Min Eff    |
| ----------------------------- | ---------- | ---------- | ---------- | ---------- |
| Raw TFLOPS (500 trees)        | 92.73%     | 80.24%     | 84.55%     | 4.26%      |
| **log1p(TFLOPS)** (500 trees) | **96.92%** | **94.34%** | **94.89%** | **60.27%** |
| log1p(TFLOPS) (2000 trees)    | 97.51%     | 93.89%     | 96.04%     | 63.56%     |

**Solution**: Train on `log1p(measured_tflops)` and apply `expm1()` to
predictions. This is now the default in `train.py`. Pass `--no_log_transform`
to revert to raw regression (not recommended).

**Why log1p, not log**: `log1p(x) = log(1 + x)` handles zero and near-zero
TFLOPS gracefully, whereas `log(x)` produces `-inf` for x=0.
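
To see why raw-TFLOPS regression neglects small shapes, compare the squared error from the same 10% relative miss on a small shape and on a large one. The minimal sketch below uses made-up TFLOPS values for illustration only; they are not taken from the benchmark data.

```python
import math


def squared_error(pred: float, true: float) -> float:
    return (pred - true) ** 2


# Hypothetical measured TFLOPS for a tiny-M shape and a large shape.
small_true, large_true = 0.5, 500.0
# Suppose the model over-predicts both by the same 10%.
small_pred, large_pred = small_true * 1.1, large_true * 1.1

# Raw TFLOPS: the large shape's error outweighs the small one by (500 / 0.5) ** 2 = 1e6.
raw_ratio = squared_error(large_pred, large_true) / squared_error(small_pred, small_true)

# log1p space: the same relative misses give errors within one order of magnitude.
log_ratio = squared_error(math.log1p(large_pred), math.log1p(large_true)) / squared_error(
    math.log1p(small_pred), math.log1p(small_true)
)

# Predictions made in log1p space are mapped back to TFLOPS with expm1.
recovered = math.expm1(math.log1p(large_true))
```

With an RMSE-style loss, a millionfold imbalance means the optimizer effectively ignores the small shape. In log1p space both shapes pull on the fit with similar force.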
## 2. Tiny-M Shapes are the Hardest Case
M=1 (single-token inference) shapes are fundamentally different from batch shapes:

- Most kernel configurations produce very low TFLOPS
- The "best" kernel is often only marginally better than the rest
- The oracle performance itself is very low, so any prediction error tanks efficiency
- Many kernels fail outright or waste most of their tile (tile_m=128 with M=1 wastes 127/128 of the tile)

The bottom shapes in our evaluation are all M=1, with efficiencies in the
63-70% range. These shapes have such low absolute performance that the model's
noise floor exceeds the performance difference between kernels.

**Mitigation**: Log-transform helps significantly (tiny_m efficiency improved
from 84.55% with raw TFLOPS to 96.04% with log1p and 2000 trees). For production
use with M=1, consider a dedicated fallback (e.g., hardcoded kernel selection
for M < 4 based on known-good configs).
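
The suggested fallback can sit as a thin wrapper in front of the model. The sketch below is minimal and rests on several assumptions. `KernelConfig` is a hypothetical record, `predict_best` is a generic callable standing in for the ML ranking, and the fallback tile sizes are placeholders rather than benchmarked values. None of these names are the dispatcher's actual API.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class KernelConfig:
    # Hypothetical kernel descriptor; field names are illustrative only.
    tile_m: int
    tile_n: int
    tile_k: int
    pipeline: str


# Placeholder for a benchmarked, known-good tiny-M kernel. These values are
# NOT measured results; replace them with the winner from your own sweeps.
TINY_M_FALLBACK = KernelConfig(tile_m=16, tile_n=64, tile_k=256, pipeline="mem")
TINY_M_CUTOFF = 4  # hardcoded selection for M < 4


def select_kernel(m: int, n: int, k: int,
                  predict_best: Callable[[int, int, int], KernelConfig]) -> KernelConfig:
    """Send tiny-M shapes to the fixed fallback; everything else to the ML ranking."""
    if m < TINY_M_CUTOFF:
        return TINY_M_FALLBACK
    return predict_best(m, n, k)


def fake_model(m: int, n: int, k: int) -> KernelConfig:
    # Stand-in for the ML predictor, used only to exercise the wrapper.
    return KernelConfig(tile_m=256, tile_n=256, tile_k=64, pipeline="compv4")
```

Keeping the cutoff outside the model means the hard cases never depend on the model's noisy low-TFLOPS predictions.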
## 3. IHEM (Hard Example Mining) Hurts When Scale is the Issue
We tried Iterative Hard Example Mining with sample reweighting (2x-5x weight
on hard shapes). Result: it made things **worse**, degrading mean efficiency
from 94.31% to 92.90% over 3 iterations.

**Why**: The hard shapes are hard because of scale mismatch, not because the
model lacks capacity. Reweighting amplifies the small-TFLOPS rows, which
distorts the learned relationship between features and performance for the
majority of shapes. The log-transform was the correct fix -- it addresses the
root cause (scale) rather than the symptom (bad predictions on tiny shapes).

**Lesson**: IHEM is useful when the model has capacity gaps (e.g., certain
pipeline types are underrepresented). It is counterproductive when the issue
is target-variable scale. Always try target transforms before reweighting.
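
The reweighting step described above can be pictured roughly as follows. Shapes whose previous-iteration efficiency falls below a threshold get a sample weight between 2x and 5x, scaled by how badly they were predicted. Only the 2x-5x range comes from the experiment. The 95% threshold and the linear mapping are assumptions for illustration, and `hard_example_weights` is a hypothetical name.

```python
Shape = tuple[int, int, int]


def hard_example_weights(shape_efficiency: dict[Shape, float],
                         threshold: float = 0.95,
                         min_weight: float = 2.0,
                         max_weight: float = 5.0) -> dict[Shape, float]:
    """Per-shape sample weights for the next IHEM iteration (hypothetical scheme).

    Shapes at or above `threshold` keep weight 1.0. Below it, the weight rises
    linearly from `min_weight` (just under the threshold) to `max_weight`
    (0% efficiency). Every benchmark row of a shape would inherit its weight.
    """
    weights = {}
    for shape, eff in shape_efficiency.items():
        if eff >= threshold:
            weights[shape] = 1.0
        else:
            shortfall = (threshold - eff) / threshold  # in (0, 1] for eff in [0, threshold)
            weights[shape] = min_weight + (max_weight - min_weight) * shortfall
    return weights


weights = hard_example_weights({(4096, 4096, 4096): 0.99, (1, 4096, 4096): 0.0})
```

Because the worst shapes are the tiny-M ones, a scheme like this mostly boosts the small-TFLOPS rows. That matches the distortion described above.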
## 4. GroupKFold Key = (M, N, K) Forces Generalization
The validation uses `GroupKFold` where the group key is `(M, N, K)` -- all
kernels for the same shape go to the same fold. This means:

- The model is always evaluated on shapes it has **never seen** during training
- Layout is excluded from the key, forcing the model to generalize across layouts
- Since models are per-arch, `arch` is implicit (constant within one training run)

This is much stricter than random row splitting, where the model would see some
kernels for each shape during training. Our efficiency numbers are conservative
estimates of real-world performance on unseen shapes.
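
The key property is that every row of a given (M, N, K) lands in exactly one validation fold. scikit-learn's `GroupKFold` provides this when given the shape tuple as the group label. The dependency-free sketch below reproduces the idea by assigning each shape to the fold with the fewest rows so far. This is similar in spirit to scikit-learn's balancing but not guaranteed to produce identical folds. The row field names are assumptions.

```python
from collections import Counter, defaultdict


def group_k_fold(rows: list[dict], n_splits: int = 5) -> list[tuple[list[int], list[int]]]:
    """Return (train_indices, val_indices) pairs grouped by the (M, N, K) key.

    Layout and kernel fields are deliberately left out of the key, so every
    layout and every kernel of a shape lands in the same validation fold.
    """
    keys = [(r["M"], r["N"], r["K"]) for r in rows]
    group_sizes = Counter(keys)

    # Place the largest groups first, each into the fold with the fewest rows so far.
    fold_of_group = {}
    fold_rows = [0] * n_splits
    for key, size in sorted(group_sizes.items(), key=lambda kv: -kv[1]):
        fold = fold_rows.index(min(fold_rows))
        fold_of_group[key] = fold
        fold_rows[fold] += size

    val_indices = defaultdict(list)
    for idx, key in enumerate(keys):
        val_indices[fold_of_group[key]].append(idx)

    splits = []
    for fold in range(n_splits):
        val = val_indices[fold]
        val_set = set(val)
        train = [i for i in range(len(rows)) if i not in val_set]
        splits.append((train, val))
    return splits


# Example: 10 shapes x 2 layouts x 3 kernels = 60 rows.
rows = [
    {"M": m, "N": 1024, "K": 512, "layout": layout, "kernel": kid}
    for m in (1, 2, 4, 8, 16, 32, 64, 128, 256, 512)
    for layout in ("RCR", "RRR")
    for kid in range(3)
]
splits = group_k_fold(rows, n_splits=5)
```

No shape appears on both sides of any split, so each validation score measures performance on genuinely unseen problem sizes.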
## 5. Model Size vs Accuracy Tradeoff
| Config             | Trees    | Leaves  | LR       | Mean Eff   | P10 Eff    | Train Time    |
| ------------------ | -------- | ------- | -------- | ---------- | ---------- | ------------- |
| Small (default v1) | 500      | 127     | 0.05     | 96.92%     | 94.34%     | ~20s          |
| **Big (current)**  | **2000** | **255** | **0.02** | **97.51%** | **93.89%** | **~25s/fold** |

The bigger model improved mean efficiency by 0.6 percentage points (96.92% to
97.51%), but P10 didn't improve (it got slightly worse). The extra capacity
helps on medium shapes but doesn't crack the tiny-M floor. This suggests the
feature set, not model capacity, is the limiting factor for the hardest shapes.

For C++ deployment, the bigger model (2000 trees, 255 leaves) is still fast
enough -- LightGBM inference follows one root-to-leaf path per tree, so the cost
is O(trees * depth) per sample (about O(trees * log(leaves)) for balanced trees),
which is ~microseconds even at 2000 trees.
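
The two configurations map onto standard LightGBM parameter names roughly as follows. These dictionaries are a sketch, not a copy of `train.py`. The table does not give `max_depth` for the small model, so that key is omitted there. The helper gives a worst-case count of split comparisons per prediction, assuming every tree is walked to its maximum depth.

```python
# Standard LightGBM parameter names; values come from the table above.
SMALL_MODEL = {
    "objective": "regression",
    "num_iterations": 500,
    "num_leaves": 127,
    "learning_rate": 0.05,
}

BIG_MODEL = {
    "objective": "regression",
    "num_iterations": 2000,
    "num_leaves": 255,
    "learning_rate": 0.02,
    "max_depth": 15,  # current default per the summary of defaults
}


def worst_case_comparisons(num_trees: int, max_depth: int) -> int:
    """Upper bound on split comparisons for one sample.

    Each tree contributes at most max_depth comparisons on its root-to-leaf path.
    LightGBM grows trees leaf-wise, so a 255-leaf tree can be deeper than
    log2(255) (about 8); max_depth is what actually bounds the path length.
    """
    return num_trees * max_depth


big_cost = worst_case_comparisons(BIG_MODEL["num_iterations"], BIG_MODEL["max_depth"])
```

Even the bound of 30,000 comparisons is a trivial amount of work per kernel-selection call.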
## 6. N=1 and K=1 Shapes are Degenerate
We generated benchmark data for 546 edge-case shapes (N=1, K=1, small N/K).
Result: **zero valid kernel results** across 94 shapes. All 4608 kernels either
fail or produce 0 TFLOPS for these degenerate dimensions.

This means:

- The tile engine kernels have hard minimum dimension requirements
- N=1 / K=1 shapes cannot be handled by the current kernel set
- These shapes need dedicated kernels (e.g., BLAS-1/BLAS-2 fallbacks)
- The ML model should not be expected to handle them -- they should be filtered
  out before reaching the heuristic
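
This filtering can happen before the heuristic is ever consulted. The minimal sketch below uses a hypothetical `min_dim` threshold. The findings above establish only that N=1 and K=1 produce no valid results, not the exact minimum the tile engine requires, so the default of 2 rejects just those two cases.

```python
def reaches_ml_heuristic(m: int, n: int, k: int, min_dim: int = 2) -> bool:
    """Return False for shapes the current GEMM kernel set cannot run.

    N=1 is a matrix-vector product and K=1 is an outer product; both should be
    routed to dedicated BLAS-1/BLAS-2 style kernels instead of the GEMM heuristic.
    M=1 stays in scope: it is hard for the model, but kernels do run.
    `min_dim` is a placeholder; tune it to the real tile-engine limits.
    """
    return m >= 1 and n >= min_dim and k >= min_dim
```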
## 7. Feature Engineering Insights
From LightGBM feature importances on the log-target model:

**Top features** (by split count):

- `M, N, K` -- raw dimensions are always the most important
- `tile_m, tile_n, tile_k` -- the tile shape is the primary kernel differentiator
- `overall_tile_efficiency` -- how well the shape fits the tile (the interaction)
- `num_tiles_m, total_output_tiles` -- work decomposition
- `arithmetic_intensity` -- compute vs memory bound regime
- `pipeline` -- pipeline type (compv3 vs compv4 vs mem) significantly affects perf

**Low-importance features**:

- Hardware constants (CUs, clock, caches) -- they're constant within one arch
  model, so they provide no discriminative signal. They'll become important when
  training cross-arch models.
- `split_k` -- always 1 in current data
- `persistent` -- rarely True in current kernel set
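
The tile-fit features listed above can be illustrated with simple definitions. The formulas below are assumptions; the exact definitions live in `feature_engine.py` and may differ. Here, per-dimension efficiency is the fraction of the padded tile extent that holds real data. Arithmetic intensity is GEMM FLOPs divided by the bytes of A, B, and C counted once, assuming 2-byte (fp16) elements.

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def tile_features(m: int, n: int, k: int, tile_m: int, tile_n: int, tile_k: int,
                  bytes_per_elem: int = 2) -> dict:
    """Illustrative tile-fit and roofline features for one (shape, kernel) pair."""
    num_tiles_m = ceil_div(m, tile_m)
    num_tiles_n = ceil_div(n, tile_n)
    num_tiles_k = ceil_div(k, tile_k)
    # Fraction of each padded dimension that holds real data (1.0 = perfect fit).
    eff_m = m / (num_tiles_m * tile_m)
    eff_n = n / (num_tiles_n * tile_n)
    eff_k = k / (num_tiles_k * tile_k)
    # GEMM does 2*M*N*K FLOPs; count A (MxK), B (KxN) and C (MxN) as moved once.
    flops = 2.0 * m * n * k
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return {
        "num_tiles_m": num_tiles_m,
        "total_output_tiles": num_tiles_m * num_tiles_n,
        "overall_tile_efficiency": eff_m * eff_n * eff_k,
        "arithmetic_intensity": flops / bytes_moved,
    }


# M=1 with tile_m=128: 1 of 128 tile rows does useful work, ~1 FLOP/byte (memory-bound).
decode = tile_features(1, 4096, 4096, tile_m=128, tile_n=128, tile_k=64)
# Large square GEMM: perfect tile fit and far higher arithmetic intensity.
square = tile_features(4096, 4096, 4096, tile_m=256, tile_n=256, tile_k=64)
```

The M=1 case yields an `overall_tile_efficiency` of 1/128, the same 127/128 waste described for tiny-M shapes.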
## 8. Warm-Start Works for Incremental Updates
LightGBM's `init_model` parameter successfully continues training from an
existing model. New trees are added on top of existing ones. Key considerations:
- Feature schema must match exactly (enforced by `check_feature_compatibility`)
- Use fewer new trees (200-500) since we're refining, not starting fresh
- The `train_manifest.json` tracks the full lineage (total trees, data sizes)
- Quality should be at least as good as the base model (tested)
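
In LightGBM, warm-starting goes through the `init_model` argument of `lightgbm.train`. The schema check is what keeps the added trees meaningful: a tree that splits on feature column 7 must see the same feature in column 7. The dependency-free sketch below shows that check and some manifest bookkeeping. `features_compatible` and `update_lineage` are hypothetical stand-ins, not the repository's `check_feature_compatibility` or the actual `train_manifest.json` format.

```python
def features_compatible(base_features: list[str], new_features: list[str]) -> bool:
    """Warm start is only safe if feature names AND column order match exactly."""
    return base_features == new_features


def update_lineage(manifest: dict, new_trees: int, new_rows: int) -> dict:
    """Return an updated copy of a manifest recording one warm-start round."""
    updated = dict(manifest)
    updated["total_trees"] = manifest.get("total_trees", 0) + new_trees
    updated["data_sizes"] = list(manifest.get("data_sizes", [])) + [new_rows]
    return updated


# With LightGBM installed, the training call itself would look like:
#   lightgbm.train(params, train_set, num_boost_round=500, init_model=base_model_path)
```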
## 9. Data Volume Matters More Than Model Complexity
| Dataset                  | Shapes | Rows | Mean Eff (log, 500 trees)     |
| ------------------------ | ------ | ---- | ----------------------------- |
| Original (DeepSeek only) | 108    | 418K | 98.28% (on seen distribution) |
| + Wide coverage          | 168    | 626K | 96.92% (harder distribution)  |

The original 108-shape model looked great (98.28%) but was overfitting to the
DeepSeek LLM inference … M=1 distribution. Adding 60 diverse shapes (many M=1)
exposed the model's weakness on tiny shapes. More diverse training data is
always better than a bigger model on narrow data.

## Summary of Defaults

Based on these findings, the current defaults in `train.py` are:

- **Target transform**: `log1p` for TFLOPS and bandwidth (scale normalization)
- **Model**: 2000 trees, 255 leaves, max depth 15, LR 0.02
- **Validation**: 5-fold GroupKFold, key = (M, N, K)
- **Early stopping**: patience 100 (let trees fully converge)
- **Warm start**: 500 new trees (was 200, increased for bigger base model)