[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation

The ML heuristic in the dispatcher does not yet support the grouped-conv operator. This PR adds support for fwd, bwd-data, and bwd-weight grouped-conv kernels. A tile_engine utility has also been added to compile and run any selected kernel configuration through the dispatcher infrastructure.

## Technical Details

1. A tile_engine utility is added to benchmark each shape with all possible kernel + tile_size combinations: https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv are added to the models directory. There are three separate models for fwd, bwd-data, and bwd-weight: https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python); a minimal sketch of the pattern follows the checklist below.
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults
   - **Solution**: Mirror the FMHA pattern and defer GPU initialization until the first run()
   - **Changes**:
     - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs
     - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
     - Added an _ensure_initialized() method for lazy GPU loading
     - GPU context is created only on the first run() call
   - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to expand kernel selection
   - Added automatic padding support for unsupported shapes in the dispatcher runner
   - Created a single source of truth between tile_engine and dispatcher for the architecture and tile_size details
   - Built validation scripts to compare oracle_best vs. ml_heuristic results

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets of up to 300 problems.
2. Ensured that test cases are added in both dispatcher and tile_engine to validate the heuristic.

## Test Result

Results on unseen shapes, validated on gfx950.

#### Forward Pass Model
- **Training Data**: 48,845 measurements across 1,372 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model
- **Training Data**: 18,773 measurements across 891 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model
- **Training Data**: 34,900 measurements across 1,508 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**

## Submission Checklist
- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
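As a minimal sketch of the lazy-initialization pattern from Technical Details item 3, assuming a `GpuGroupedConvRunner` that wraps a compiled dispatcher shared library (the actual class layout and library entry points in the PR may differ):

```python
import ctypes
from pathlib import Path


class GpuGroupedConvRunner:
    """Hypothetical sketch: defer all GPU/ctypes work until the first run().

    __init__ only records the path to the compiled dispatcher library, so the
    object can be created in a forked ProcessPoolExecutor worker before any
    GPU context exists.
    """

    def __init__(self, lib_path: Path):
        self._lib_path = lib_path  # store the path only; do NOT load the .so here
        self._lib = None           # loaded lazily in _ensure_initialized()

    def _ensure_initialized(self) -> None:
        # Load the shared library (and implicitly create the GPU context)
        # exactly once, on the first call to run().
        if self._lib is None:
            self._lib = ctypes.CDLL(str(self._lib_path))

    def run(self, *args):
        self._ensure_initialized()
        # Illustrative entry point; not the actual dispatcher ABI.
        return self._lib.run_grouped_conv(*args)
```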
# Learnings — Grouped-Conv Heuristic (Forward, 2D + 3D)

Empirical findings from building the grouped-convolution kernel performance predictor for gfx950. Specific to the forward path (NHWGC × GKYXC → NHWGK); backward variants share the same architecture but have not been re-trained against the latest feature schema (see §6). These notes inform the current defaults in `feature_engine_grouped_conv.py`, `predict.py`, and `train.py`, and explain why certain approaches were chosen.
## 1. Kernel-Name Aliasing Was the Top-1 Accuracy Ceiling

Problem: Grouped-conv kernel names look like
`grouped_conv_forward_bf16_2d_64x64x64_compv3_intrawave_dsb_si`, but the
original parser in `convert_csv_to_parquet.py` matched only up to the
pipeline token and discarded the wave-mode / dsb / si suffix. Every
(tile, pipeline) bucket aliased to a single feature row, even though the
benchmark contained up to 8 distinct kernels per bucket
({intrawave, interwave} × {∅, dsb, si, dsb_si}). With the 2D vs. 3D ndim
split, up to 16 physical kernels collapsed into one feature signature.
Evidence (forward 2D+3D holdout, ~80 unique physical problems):
| Model | Features | Mean Eff | Top-1 | Top-5 |
|---|---|---|---|---|
| Pre-suffix (aliased) | 91 | 88.0% | ~5–10% | ~30% |
| Suffix-aware (current) | 97 | 92.5% | 27.9% | 70.6% |
Solution: Three new kernel-side numeric flags (mirroring `is_compv*`):
`is_intrawave`, `has_dsb`, `has_si`. Plus three pipeline one-hots that were
missing (`is_basic`, `is_compv6`, `is_mem`). Total feature count went from
83 → 91 → 97 in two stages (3D + dilation in the 91-step; suffix-aware
flags in the 97-step). The 30 valid (pipeline, wave_mode, dsb, si)
combinations live in `dispatcher/codegen/grouped_config_rules.py::PIPELINE_VARIANTS`
as the single source of truth used by both the candidate-pool generator and
the codegen harness.
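A minimal sketch of suffix-aware parsing, assuming kernel names follow the `..._<tile>_<pipeline>[_intrawave|_interwave][_dsb][_si]` pattern shown above; the actual regex in `convert_csv_to_parquet.py` may differ:

```python
import re

# Hypothetical pattern: tile + pipeline token, then optional wave-mode / dsb / si suffix.
KERNEL_RE = re.compile(
    r"_(?P<tile>\d+x\d+x\d+)"
    r"_(?P<pipeline>compv\d+|basic|mem)"
    r"(?:_(?P<wave>intrawave|interwave))?"
    r"(?P<dsb>_dsb)?"
    r"(?P<si>_si)?$"
)

def kernel_flags(name: str) -> dict:
    """Turn one kernel name into the numeric flags fed to the model."""
    m = KERNEL_RE.search(name)
    if m is None:
        raise ValueError(f"unrecognized kernel name: {name}")
    return {
        "tile": m.group("tile"),
        "pipeline": m.group("pipeline"),
        "is_intrawave": int(m.group("wave") == "intrawave"),
        "has_dsb": int(m.group("dsb") is not None),
        "has_si": int(m.group("si") is not None),
    }

print(kernel_flags("grouped_conv_forward_bf16_2d_64x64x64_compv3_intrawave_dsb_si"))
# -> {'tile': '64x64x64', 'pipeline': 'compv3', 'is_intrawave': 1, 'has_dsb': 1, 'has_si': 1}
```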
Why log-target alone wasn't enough: log-transform fixes scale, not
discrimination. With aliased kernels the model literally cannot rank the 8
intra/inter × dsb/si variants of one tile against each other, no matter
what loss you train against. Top-1 accuracy was bounded by 1/8 = 12.5%
even with a perfect regressor on the aliased schema.
## 2. Combined 2D+3D Beats Per-Dim Models
We trained three forward models in sequence:
| Model | Features | Training data | Status |
|---|---|---|---|
| `grouped_conv_forward_bf16_gfx950` | 83 | 2D only, no suffix | Legacy. Kept for back-compat. |
| `grouped_conv_forward_2d3d_bf16_gfx950` | 91 | 2D + 3D, no suffix | Pre-suffix baseline. |
| `grouped_conv_forward_2d3d_suffix_bf16_gfx950` | 97 | 2D + 3D + suffix | Current best. |
Finding: The combined-2D+3D model does not hurt 2D performance — both
share the same feature engine and the model learns to gate 3D features on
Di > 1. Don't bother training separate 2D-only and 3D-only models unless
you have a strong reason; the combined model wins on holdout.
Critical features for 3D: dilation_d/h/w in the 91/97-feature schemas
are essential for 3D shapes. Without them the model cannot distinguish
between shapes that share (N,C,K,Hi,Wi,Y,X) but differ in dilation, and
its predictions for dilated 3D problems are meaningless. Always include
dilation columns when re-converting CSVs that contain 3D shapes.
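A hypothetical sketch of the defensive check this implies when converting CSVs, assuming pandas and the column names `Di`, `dilation_d`, `dilation_h`, `dilation_w` (the actual converter may handle missing columns differently):

```python
import pandas as pd

DILATION_COLS = ["dilation_d", "dilation_h", "dilation_w"]

def ensure_dilation_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Refuse to silently drop dilation info when 3D shapes are present.

    2D-only CSVs may legitimately omit dilation columns; default missing
    columns to 1 (unit dilation) but fail loudly if 3D shapes lack them.
    """
    has_3d = "Di" in df.columns and (df["Di"] > 1).any()
    for col in DILATION_COLS:
        if col not in df.columns:
            if has_3d:
                raise ValueError(f"3D shapes present but '{col}' is missing")
            df[col] = 1  # 2D default: unit dilation
    return df
```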
## 3. Model Coexistence via Version-Aware Predictor

After the 83 → 91 → 97 feature progression, all older models would have crashed on load with
`LightGBMError: The number of features in data (97) is not the same as it was in training data (83/91)`.
We need to keep the old forward, bwd_data, and bwd_weight models
loadable because we don't have the benchmark data to re-train backward
variants from scratch.
Solution: `predict.py::Predictor.__init__` reads
`feature_spec.json["feature_names"]` and builds an index map into the
engine's emit order, so old models pull only the columns they were trained
on. If the engine matches the spec exactly (e.g. the suffix model with the
current engine, or any GEMM model), the index map is `None` and the predict
path is a no-op fast path. If a model expects features the engine no longer
supplies (renamed or removed), `__init__` raises with a clear error rather
than silently predicting garbage.
Constraint for future engine changes: the current engine must remain a superset of every deployed model's feature set, or you must retrain. Adding new features is safe; renaming or removing one is a breaking change.
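A minimal sketch of the index-map idea, assuming the engine exposes an ordered list of emitted feature names and each model ships a `feature_spec.json` with a `feature_names` list; the real `predict.py::Predictor` has more machinery around model loading:

```python
import json
from pathlib import Path

import numpy as np


class Predictor:
    def __init__(self, model_dir: Path, engine_feature_names: list[str]):
        spec = json.loads((model_dir / "feature_spec.json").read_text())
        model_features = spec["feature_names"]

        if model_features == engine_feature_names:
            # Engine matches the spec exactly: no remapping needed (fast path).
            self._index_map = None
        else:
            missing = [f for f in model_features if f not in engine_feature_names]
            if missing:
                # Engine renamed or removed a feature the model was trained on:
                # fail loudly instead of silently predicting garbage.
                raise ValueError(f"engine no longer supplies features: {missing}")
            # Old model: select only the columns it was trained on, in its order.
            self._index_map = [engine_feature_names.index(f) for f in model_features]

    def select_columns(self, features: np.ndarray) -> np.ndarray:
        """Reduce an (n_candidates, n_engine_features) matrix to the model's view."""
        if self._index_map is None:
            return features
        return features[:, self._index_map]
```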
## 4. What Did Not Matter as Much as Expected
- Hyperparameter tuning. Default LightGBM params got within ~1% of any tuned configuration we tried. The suffix-aware feature change was ~10x more impactful than any HP move.
- Number of CV folds. `n_splits=5` and `n_splits=10` gave indistinguishable holdout numbers.
- `use_log` for the tflops target on grouped-conv. Marginal (~0.5%) improvement, in contrast to the dramatic effect on GEMM (see `LEARNINGS.md` §1). Grouped-conv TFLOPS span a narrower range, so scale normalization helps less. Left on by default for stability of the warm-start path.
## 5. What Did Matter
- De-aliasing kernel names via the suffix-aware feature/parser change (§1) — by far the largest single improvement.
- Group-aware CV (`GroupKFold` keyed on the dim tuple). Without it, the same physical problem with different kernels ends up in both train and val, and the CV metric is wildly optimistic. A minimal sketch follows this list.
- Including dilation columns for 3D shapes (§2).
- Joining ML and oracle results by dimension tuple, not `problem_idx`. Index columns in benchmark CSVs are an artifact of generation order and cannot be trusted across files; always re-key on the dim tuple.
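A minimal sketch of group-aware CV and dim-tuple re-keying with scikit-learn and pandas, assuming each row is one (problem, kernel) measurement and the dim-tuple columns listed below exist; the exact group key in `train.py` may include more fields:

```python
import pandas as pd
from sklearn.model_selection import GroupKFold

# Columns that identify one physical problem (illustrative subset).
DIM_COLS = ["N", "C", "K", "Hi", "Wi", "Y", "X", "G", "dilation_h", "dilation_w"]

def grouped_cv_splits(df: pd.DataFrame, n_splits: int = 5):
    """Yield train/val index pairs that never split one physical problem
    across train and val, even though it appears with many kernels."""
    groups = df[DIM_COLS].astype(str).agg("_".join, axis=1)
    gkf = GroupKFold(n_splits=n_splits)
    yield from gkf.split(df, groups=groups)

def join_ml_oracle(ml: pd.DataFrame, oracle: pd.DataFrame) -> pd.DataFrame:
    """Re-key on the dim tuple instead of trusting problem_idx across files."""
    return ml.merge(oracle, on=DIM_COLS, suffixes=("_ml", "_oracle"))
```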
## 6. Backward Variants Not Yet Upgraded

`grouped_conv_bwd_data_bf16_gfx950` and `grouped_conv_bwd_weight_bf16_gfx950`
are still 83-feature, pre-suffix models. They load via the version-aware
Predictor but inherit the same aliasing problem the forward model used to
have. To upgrade:
- Re-benchmark (the existing CSVs do not encode wave_mode / dsb / si in the kernel names — verify before you start).
- Re-run `convert_csv_to_parquet.py` (suffix-aware regex) to get parquets with `wave_mode`, `has_dsb`, `has_si` columns.
- Train with `--op grouped_conv --targets tflops --n_splits 5`.

Expect the same magnitude of top-1 accuracy jump that the forward model saw.
## Summary of Defaults
Based on these findings, the current defaults for grouped-conv are:
- Feature engine: `GroupedConvFeatureEngine` emits 97 features (38 problem + extended kernel block with suffix flags + 18 interaction + 12 hardware).
- Pipeline variant set: `dispatcher/codegen/grouped_config_rules.PIPELINE_VARIANTS` is the single source of truth for the 30 valid (pipeline, wave_mode, dsb, si) combinations used by both codegen and the candidate-pool generator.
- Predictor loading: version-aware feature filtering in `predict.py::Predictor` allows old (83/91-feature) models to coexist with the new (97-feature) suffix model under the same engine.
- CV: 5-fold `GroupKFold` with the group key including all spatial dims and dilation.
- Target transform: log1p on tflops (consistent with GEMM defaults even though the marginal gain on grouped-conv is small). A minimal sketch of this transform follows.
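A minimal sketch of the log1p target transform around a LightGBM regressor, assuming a `tflops` target; the flags and wiring in `train.py` may differ:

```python
import numpy as np
import lightgbm as lgb

def train_tflops_model(X_train, y_tflops):
    """Train on log1p(tflops); defaults were within ~1% of any tuned params."""
    model = lgb.LGBMRegressor()
    model.fit(X_train, np.log1p(y_tflops))
    return model

def predict_tflops(model, X):
    # Invert the training-time transform so downstream ranking sees real TFLOPS.
    return np.expm1(model.predict(X))
```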