mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile dispatcher, enabling fast and accurate automatic kernel selection for Universal Gemm kernels. Instead of requiring exhaustive search through 4600+ kernel configurations (taking ~46 seconds per problem shape), the ML heuristic predicts optimal kernels in microseconds while achieving >98% of oracle-best performance.

## Technical Details

**ML infrastructure** https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics

* Feature Engine ([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)): 55-feature extraction including problem dimensions, kernel configuration, tile efficiency, and hardware profile
* Training Pipeline ([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)): LightGBM regression with log-transform, GroupKFold cross-validation, warm-start support
* Predictor ([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)): Kernel ranking and TFLOPS prediction for problem shapes
* Evaluation ([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)): Comprehensive metrics including efficiency, NDCG@k, shape family analysis

**Data Generation Tools:**

* [generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py): Build and benchmark kernels across diverse problem shapes
* [convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py): Convert benchmark JSON to training-ready parquet format
* [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py): Parse streaming benchmark logs into canonical datasets

**Examples**

* [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp): C++ example demonstrating ML-based kernel selection
* [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py): Python example with validation

**Pre-trained Models (projects/composablekernel/dispatcher/heuristics/models/):**

* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean efficiency)

## Test Plan

* Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M (M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**

* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (90% of shapes achieve ≥98% of oracle best)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**

* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Vidyasagar Ananthan <vidyasagar.ananthan@amd.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
152 lines
9.2 KiB
Markdown
# Learnings and Design Decisions

Empirical findings from building the CK Tile kernel performance prediction system.
These inform the current defaults and explain why certain approaches were chosen.

## 1. Log-Transform is Essential for Cross-Scale Accuracy

**Problem**: GEMM TFLOPS spans 5 orders of magnitude across different problem
sizes. When training on raw TFLOPS, the regression loss (RMSE) is dominated by
large shapes, where absolute errors are biggest. The model learns to predict
large shapes accurately but ignores tiny shapes, whose TFLOPS values are
much lower.

**Evidence** (168 shapes, 626K rows, 5-fold GroupKFold CV):

| Model                         | Mean Eff   | P10 Eff    | tiny_m Eff | Min Eff    |
| ----------------------------- | ---------- | ---------- | ---------- | ---------- |
| Raw TFLOPS (500 trees)        | 92.73%     | 80.24%     | 84.55%     | 4.26%      |
| **log1p(TFLOPS)** (500 trees) | **96.92%** | **94.34%** | **94.89%** | **60.27%** |
| log1p(TFLOPS) (2000 trees)    | 97.51%     | 93.89%     | 96.04%     | 63.56%     |

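Assuming "efficiency" means the measured TFLOPS of the kernel the model ranks first, divided by the best measured TFLOPS for that shape (the "% of oracle best" framing in the commit message), the Mean, P10, and Min columns can be sketched as below. The data layout, the helper names, and the nearest-rank percentile definition are assumptions for illustration; `evaluate.py` may compute these differently.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical layout: per shape, one (predicted_tflops, measured_tflops) pair per kernel.
Candidates = List[Tuple[float, float]]


def shape_efficiency(cands: Candidates) -> float:
    """Measured TFLOPS of the model's top pick relative to the oracle-best kernel."""
    picked_measured = max(cands, key=lambda pair: pair[0])[1]
    oracle = max(measured for _, measured in cands)
    return picked_measured / oracle  # assumes oracle > 0 (degenerate shapes filtered out)


def percentile_nearest_rank(values: List[float], pct: float) -> float:
    """Nearest-rank percentile; other interpolation rules give slightly different P10s."""
    ordered = sorted(values)
    rank = max(1, math.ceil(pct / 100 * len(ordered)))
    return ordered[rank - 1]


def summarize(per_shape: Dict[str, Candidates]) -> Dict[str, float]:
    """Aggregate per-shape efficiencies into Mean / P10 / Min columns."""
    effs = [shape_efficiency(c) for c in per_shape.values()]
    return {
        "mean_eff": sum(effs) / len(effs),
        "p10_eff": percentile_nearest_rank(effs, 10),
        "min_eff": min(effs),
    }
```

P10 is the efficiency that 90% of shapes meet or exceed, which is why it tracks the hard tail (tiny M) more closely than the mean does.
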
**Solution**: Train on `log1p(measured_tflops)` and apply `expm1()` to
predictions. This is now the default in `train.py`. Pass `--no_log_transform`
to revert to raw regression (not recommended).

**Why log1p, not log**: `log1p(x) = log(1 + x)` handles zero and near-zero
TFLOPS gracefully, whereas `log(x)` produces -inf for x=0.

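A minimal sketch of the transform using only the standard library. The helper names are hypothetical (not the actual `train.py` functions), and the two TFLOPS values are made-up illustrations rather than measured data. It compares the squared error of the same 10% over-prediction on a large and a tiny shape, in raw space and in log1p space.

```python
import math


def to_target(tflops: float) -> float:
    """Training target: log1p is defined at 0 TFLOPS, where log() would fail."""
    return math.log1p(tflops)


def from_prediction(pred: float) -> float:
    """Map a model output back to TFLOPS (inverse of to_target)."""
    return math.expm1(pred)


def squared_error(true_tflops: float, pred_tflops: float, log_space: bool) -> float:
    """Per-row contribution to an MSE/RMSE loss, in raw or log1p space."""
    if log_space:
        return (to_target(true_tflops) - to_target(pred_tflops)) ** 2
    return (true_tflops - pred_tflops) ** 2


# Hypothetical shapes: a large GEMM and a tiny-M GEMM, both over-predicted by 10%.
large_true, tiny_true = 1000.0, 0.5
raw_ratio = squared_error(large_true, 1.1 * large_true, False) / squared_error(
    tiny_true, 1.1 * tiny_true, False
)
log_ratio = squared_error(large_true, 1.1 * large_true, True) / squared_error(
    tiny_true, 1.1 * tiny_true, True
)

print(f"large/tiny loss ratio, raw TFLOPS: {raw_ratio:.3g}")
print(f"large/tiny loss ratio, log1p:      {log_ratio:.3g}")
```

In raw space the large shape's error outweighs the tiny shape's by a factor of about four million, so the tiny shape barely registers in the loss; in log1p space the ratio drops below 10. Plain `log` would weight the two 10% misses equally, but it is undefined at 0 TFLOPS; `log1p` gives up a little of that scale invariance below roughly 1 TFLOPS in exchange for being defined everywhere the data can land.
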
## 2. Tiny-M Shapes are the Hardest Case

M=1 (single-token inference) shapes are fundamentally different from batch shapes:

- Most kernel configurations produce very low TFLOPS
- The "best" kernel is often only marginally better than the rest
- The oracle performance itself is very low, so any prediction error tanks efficiency
- Many kernels perform poorly or fail outright (e.g., tile_m=128 with M=1 wastes 127/128 of the tile)

The bottom shapes in our evaluation are all M=1, with efficiencies in the
63-70% range. These shapes have such low absolute performance that the model's
noise floor exceeds the performance difference between kernels.

**Mitigation**: Log-transform helps significantly (tiny_m efficiency improved
from 84.55% to 94.89% at 500 trees, and to 96.04% with 2000 trees). For
production use with M=1, consider a dedicated fallback (e.g., hardcoded kernel
selection for M < 4 based on known-good configs).

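One way to wire in the dedicated fallback suggested above, sketched under assumptions: `KernelConfig`, the `predict_tflops` callable, and the fallback config are hypothetical placeholders rather than the dispatcher's API, and the `M < 4` cutoff is the example threshold from the text. The ML ranking is consulted only when M is large enough for its predictions to be meaningful.

```python
from dataclasses import dataclass
from typing import Callable, Sequence


@dataclass(frozen=True)
class KernelConfig:
    """Hypothetical kernel descriptor; real configs carry many more fields."""
    tile_m: int
    tile_n: int
    tile_k: int
    pipeline: str


TINY_M_CUTOFF = 4  # example threshold from the text: hardcode selection for M < 4

Predictor = Callable[[int, int, int, KernelConfig], float]


def select_kernel(
    m: int,
    n: int,
    k: int,
    candidates: Sequence[KernelConfig],
    predict_tflops: Predictor,
    tiny_m_fallback: KernelConfig,
) -> KernelConfig:
    """Return a known-good config for tiny M, otherwise the top ML-ranked kernel."""
    if m < TINY_M_CUTOFF:
        # The model's noise floor exceeds the gaps between kernels here,
        # so a curated choice is more reliable than the ranking.
        return tiny_m_fallback
    return max(candidates, key=lambda cfg: predict_tflops(m, n, k, cfg))
```

Keeping the fallback outside the model means it can be tuned per architecture without retraining.
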
## 3. IHEM (Hard Example Mining) Hurts When Scale is the Issue

We tried Iterative Hard Example Mining (IHEM) with sample reweighting (2x-5x
weight on hard shapes). Result: it made things **worse**, degrading mean
efficiency from 94.31% to 92.90% over 3 iterations.

**Why**: The hard shapes are hard because of scale mismatch, not because the
model lacks capacity. Reweighting amplifies the small-TFLOPS rows, which
distorts the learned relationship between features and performance for the
majority of shapes. The log-transform was the correct fix -- it addresses the
root cause (scale) rather than the symptom (bad predictions on tiny shapes).

**Lesson**: IHEM is useful when the model has capacity gaps (e.g., certain
pipeline types are underrepresented). It is counterproductive when the issue
is target-variable scale. Always try target transforms before reweighting.

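For reference, a sketch of the kind of reweighting step described above. The 2x to 5x range comes from the text; the linear mapping from efficiency to weight, the 0.9 hard-shape threshold, and all names are assumptions made for illustration. Because every kernel row of a hard shape receives the same multiplier, the small-TFLOPS rows of tiny shapes gain influence over the whole fit, which is the distortion the "Why" paragraph describes.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, int, int]  # (M, N, K)


def ihem_weights(
    shape_efficiency: Dict[Shape, float],
    row_shapes: List[Shape],
    hard_threshold: float = 0.9,  # assumed cutoff for a "hard" shape
    min_boost: float = 2.0,
    max_boost: float = 5.0,
) -> List[float]:
    """Per-row sample weights for one IHEM iteration (hypothetical scheme).

    Shapes at or above the threshold keep weight 1.0. Harder shapes get a
    boost that grows linearly from min_boost (just below the threshold)
    to max_boost (efficiency 0).
    """
    shape_weight: Dict[Shape, float] = {}
    for shape, eff in shape_efficiency.items():
        if eff >= hard_threshold:
            shape_weight[shape] = 1.0
        else:
            badness = (hard_threshold - eff) / hard_threshold  # in (0, 1]
            shape_weight[shape] = min_boost + (max_boost - min_boost) * badness
    # Every kernel row of a shape inherits that shape's weight.
    return [shape_weight[s] for s in row_shapes]
```
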
## 4. GroupKFold Key = (M, N, K) Forces Generalization

The validation uses `GroupKFold` where the group key is `(M, N, K)` -- all
kernels for the same shape go to the same fold. This means:

- The model is always evaluated on shapes it has **never seen** during training
- Layout is excluded from the key, so every layout of a shape lands in the same fold;
  the model cannot lean on another layout of the same shape and must generalize across layouts
- Since models are per-arch, `arch` is implicit (constant within one training run)

This is much stricter than random row splitting, where the model would see some
kernels for each shape during training. Our efficiency numbers are therefore
conservative estimates of real-world performance on unseen shapes.

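The grouping rule can be illustrated without any ML library. This sketch is not scikit-learn's `GroupKFold` implementation, though it is similar in spirit: whole groups are assigned greedily, largest first, to the currently smallest fold. The row format is an assumption. The property that matters is that no (M, N, K) ever appears in more than one fold, regardless of layout or kernel.

```python
from collections import Counter
from typing import Dict, List, Tuple

Row = Dict[str, object]  # assumed row format: M, N, K, layout, kernel, tflops, ...


def group_kfold(rows: List[Row], n_splits: int = 5) -> List[List[int]]:
    """Assign row indices to folds so each (M, N, K) lands in exactly one fold.

    Layout is deliberately not part of the key, so all layouts of a shape
    share a fold. Groups are placed largest-first into the smallest fold.
    """
    key_of = [(r["M"], r["N"], r["K"]) for r in rows]
    group_sizes = Counter(key_of)
    fold_sizes = [0] * n_splits
    fold_of_group: Dict[Tuple[object, object, object], int] = {}
    for key, size in group_sizes.most_common():
        target = fold_sizes.index(min(fold_sizes))
        fold_of_group[key] = target
        fold_sizes[target] += size
    folds: List[List[int]] = [[] for _ in range(n_splits)]
    for i, key in enumerate(key_of):
        folds[fold_of_group[key]].append(i)
    return folds
```

Each fold then serves once as the validation set while the model trains on the remaining folds.
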
## 5. Model Size vs Accuracy Tradeoff

| Config             | Trees    | Leaves  | LR       | Mean Eff   | P10 Eff    | Train Time    |
| ------------------ | -------- | ------- | -------- | ---------- | ---------- | ------------- |
| Small (default v1) | 500      | 127     | 0.05     | 96.92%     | 94.34%     | ~20s          |
| **Big (current)**  | **2000** | **255** | **0.02** | **97.51%** | **93.89%** | **~25s/fold** |

The bigger model improved mean efficiency by about 0.6 percentage points
(96.92% to 97.51%), but P10 didn't improve (it is actually slightly worse,
94.34% to 93.89%). The extra capacity helps on medium shapes but doesn't crack
the tiny-M floor. This suggests the feature set, not model capacity, is the
limiting factor for the hardest shapes.

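For C++ deployment, the bigger model (2000 trees, 255 leaves) is still fast
enough. LightGBM inference walks one root-to-leaf path per tree, so the cost is
O(trees * depth) per sample, where depth is about log2(leaves) for a balanced
tree and is capped by the max depth setting (15 in the current defaults). That
is still on the order of microseconds per sample, even at 2000 trees.
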
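To see why inference stays cheap, here is a toy evaluator for a tree ensemble stored as flat arrays. The layout (`feature`, `threshold`, `left`, `right`, `leaf_value` lists, with a negative child index meaning a leaf) is a hypothetical format chosen for illustration; it is not LightGBM's model file format or C API. Each tree costs one comparison per level on the path to its leaf, so a sample costs at most trees * max depth comparisons: with 2000 trees and max depth 15 that is at most 30,000.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class FlatTree:
    """Hypothetical array-encoded regression tree.

    Internal node i tests x[feature[i]] <= threshold[i]. A child index
    c >= 0 is another internal node; c < 0 means leaf ~c (bitwise NOT).
    """
    feature: List[int]
    threshold: List[float]
    left: List[int]
    right: List[int]
    leaf_value: List[float]

    def predict(self, x: Sequence[float]) -> float:
        node = 0
        while True:  # one comparison per level: O(depth)
            go_left = x[self.feature[node]] <= self.threshold[node]
            child = self.left[node] if go_left else self.right[node]
            if child < 0:
                return self.leaf_value[~child]
            node = child


def predict_ensemble(trees: Sequence[FlatTree], x: Sequence[float]) -> float:
    """Boosted ensembles sum the outputs of all trees: O(trees * depth)."""
    return sum(tree.predict(x) for tree in trees)
```

Since the model is trained on `log1p(TFLOPS)`, the summed score would go through `math.expm1` before being reported as TFLOPS; for ranking kernels that step can be skipped because `expm1` is monotonic.
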
## 6. N=1 and K=1 Shapes are Degenerate

We generated benchmark data for 546 edge-case shapes (N=1, K=1, small N/K).
Result: **zero valid kernel results** across 94 shapes. All 4608 kernels either
fail or produce 0 TFLOPS for these degenerate dimensions.

This means:

- The tile engine kernels have hard minimum dimension requirements
- N=1 / K=1 shapes cannot be handled by the current kernel set
- These shapes need dedicated kernels (e.g., BLAS-1/BLAS-2 fallbacks)
- The ML model should not be expected to handle them -- they should be filtered
  out before reaching the heuristic

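A guard of this kind could sit in front of the heuristic. `route_shape`, `Route`, and the `min_dim` default are hypothetical; the true minimum for the tile-engine kernels is not stated here, so the default only marks N=1 and K=1 as out of scope. Only N and K are checked because tiny M is still served by the tile kernels (at reduced efficiency, per section 2), whereas N=1 and K=1 produced no valid results.

```python
from enum import Enum


class Route(Enum):
    ML_HEURISTIC = "ml_heuristic"
    DEDICATED_FALLBACK = "dedicated_fallback"  # e.g. a BLAS-1/BLAS-2 style path


def route_shape(m: int, n: int, k: int, min_dim: int = 2) -> Route:
    """Send degenerate shapes away from the ML heuristic.

    min_dim is a placeholder: the benchmarks show N=1 and K=1 produce no valid
    results, but the real minimum for the kernel set may be higher.
    """
    if min(m, n, k) <= 0:
        raise ValueError(f"invalid GEMM shape {(m, n, k)}")
    if n < min_dim or k < min_dim:
        return Route.DEDICATED_FALLBACK
    return Route.ML_HEURISTIC
```
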
## 7. Feature Engineering Insights

From LightGBM feature importances on the log-target model:

**Top features** (by split count):

- `M, N, K` -- raw dimensions are always the most important
- `tile_m, tile_n, tile_k` -- the tile shape is the primary kernel differentiator
- `overall_tile_efficiency` -- how well the shape fits the tile (the interaction)
- `num_tiles_m, total_output_tiles` -- work decomposition
- `arithmetic_intensity` -- compute vs memory bound regime
- `pipeline` -- pipeline type (compv3 vs compv4 vs mem) significantly affects perf

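The shape/tile interaction features above can be approximated as follows. These formulas are plausible definitions written for illustration, not the exact ones in `feature_engine.py`: defining `overall_tile_efficiency` as the product of per-dimension utilizations, and counting A, B, and C traffic once each at a single element size in `arithmetic_intensity`, are assumptions (fp8 GEMMs, for example, typically write C at a wider type). The M=1, tile_m=128 case reproduces the 1/128 utilization mentioned in section 2.

```python
from typing import Dict


def tile_features(
    m: int, n: int, k: int, tile_m: int, tile_n: int, tile_k: int, bytes_per_elem: float
) -> Dict[str, float]:
    """Hypothetical versions of a few shape/tile interaction features."""
    num_tiles_m = -(-m // tile_m)  # ceiling division
    num_tiles_n = -(-n // tile_n)
    num_tiles_k = -(-k // tile_k)
    # Fraction of each padded dimension that holds real data.
    eff_m = m / (num_tiles_m * tile_m)
    eff_n = n / (num_tiles_n * tile_n)
    eff_k = k / (num_tiles_k * tile_k)
    flops = 2.0 * m * n * k
    # Minimum traffic: read A (MxK) and B (KxN), write C (MxN), once each.
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return {
        "num_tiles_m": num_tiles_m,
        "total_output_tiles": num_tiles_m * num_tiles_n,
        "overall_tile_efficiency": eff_m * eff_n * eff_k,
        "arithmetic_intensity": flops / bytes_moved,
    }
```
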
**Low-importance features**:

- Hardware constants (CUs, clock, caches) -- they're constant within one arch
  model, so they provide no discriminative signal. They are expected to become
  important once cross-arch models are trained.
- `split_k` -- always 1 in current data
- `persistent` -- rarely True in current kernel set

## 8. Warm-Start Works for Incremental Updates

LightGBM's `init_model` parameter successfully continues training from an
existing model. New trees are added on top of existing ones. Key considerations:

|
||
- Feature schema must match exactly (enforced by `check_feature_compatibility`)
|
||
- Use fewer new trees (200-500) since we're refining, not starting fresh
|
||
- The `train_manifest.json` tracks the full lineage (total trees, data sizes)
|
||
- Quality should be at least as good as the base model (tested)
|
||
|
||
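The bookkeeping around a warm start can be sketched in plain Python. `features_compatible` and `extend_manifest` are hypothetical stand-ins (the real check is `check_feature_compatibility`, whose signature is not shown here), and the manifest keys are assumptions. The continuation itself would be done by passing the existing model as `init_model` to `lightgbm.train`, for example with `num_boost_round=500`.

```python
from typing import Dict, Sequence


def features_compatible(base_features: Sequence[str], new_features: Sequence[str]) -> bool:
    """Warm start requires the exact same feature names in the same order."""
    return list(base_features) == list(new_features)


def extend_manifest(manifest: Dict, new_trees: int, new_rows: int) -> Dict:
    """Return an updated lineage record; the input manifest is not modified."""
    updated = dict(manifest)
    updated["total_trees"] = manifest.get("total_trees", 0) + new_trees
    updated["data_rows"] = list(manifest.get("data_rows", [])) + [new_rows]
    updated["warm_starts"] = manifest.get("warm_starts", 0) + 1
    return updated
```

Checking the schema before training, rather than after, avoids silently appending trees that split on the wrong feature columns.
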
## 9. Data Volume Matters More Than Model Complexity

| Dataset                    | Shapes | Rows | Mean Eff (log, 500 trees)     |
| -------------------------- | ------ | ---- | ----------------------------- |
| Original (DeepSeek only)   | 108    | 418K | 98.28% (on seen distribution) |
| + Wide coverage (many M=1) | 168    | 626K | 96.92% (harder distribution)  |

The original 108-shape model looked great (98.28%) but was overfitting to the
DeepSeek LLM inference distribution. Adding 60 diverse shapes (many M=1) exposed
the model's weakness on tiny shapes. More diverse training data is always better
than a bigger model on narrow data.

## Summary of Defaults

Based on these findings, the current defaults in `train.py` are:

- **Target transform**: `log1p` for TFLOPS and bandwidth (scale normalization)
- **Model**: 2000 trees, 255 leaves, max depth 15, LR 0.02
- **Validation**: 5-fold GroupKFold, key = (M, N, K)
- **Early stopping**: patience 100 (let trees fully converge)
- **Warm start**: 500 new trees (was 200, increased for the bigger base model)