# CK Tile Unified Code Generators
Single source of truth for GEMM and Grouped Convolution kernel generation.
See also: Main Dispatcher README for installation and core concepts.
## Shared Infrastructure

Both GEMM and Grouped Conv generators share common code via `codegen_common.py` (sketched below):

- `TileConfig` - Dataclass for tile dimensions
- `TraitConfigBase` - Base for kernel trait configurations with arch-aware validation
- `CommonTypeMappings` - Dtype-to-C++ type mappings
- `parallel_generate()` - Parallel kernel generation with per-kernel progress logging
- Arch-aware expansion helpers (`valid_wave_configs`, `valid_warp_configs`, etc.)
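As a rough illustration of how these pieces fit together, here is a minimal sketch; the field names and call signatures are assumptions for illustration, not the exact `codegen_common.py` API:

```python
from codegen_common import TileConfig, parallel_generate, valid_warp_configs

# Describe one tile shape (field names assumed, mirroring KernelConfig below).
tile = TileConfig(tile_m=256, tile_n=256, tile_k=64)

# Expand only the warp configurations that are valid on the target arch.
warps = valid_warp_configs(arch="gfx942")

# Generate one kernel per (tile, warp) pair in parallel, with progress logging.
parallel_generate(
    configs=[(tile, w) for w in warps],
    output_dir="../build/generated_kernels",
)
```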
## Quick Start

### GEMM
```bash
cd dispatcher/codegen

# Generate standard FP16 kernels
python3 unified_gemm_codegen.py \
    --output-dir ../build/generated_kernels \
    --datatype fp16 \
    --layout rcr \
    --variants standard

# Generate all variants
python3 unified_gemm_codegen.py \
    --output-dir ../build/generated_kernels \
    --variants standard preshuffle multi_d
```
### Grouped Convolution
```bash
cd dispatcher/codegen

# Generate forward FP16 grouped conv kernels
python3 unified_grouped_conv_codegen.py \
    --output-dir ../build/generated_kernels \
    --datatype fp16 \
    --variant forward \
    --ndim-spatial 2

# Generate backward data kernels
python3 unified_grouped_conv_codegen.py \
    --output-dir ../build/generated_kernels \
    --variant backward_data \
    --ndim-spatial 2
```
## Using from Python
```python
from ctypes_utils import CodegenRunner, KernelConfig

# Generate from specific config
config = KernelConfig(tile_m=256, tile_n=256, tile_k=64)
codegen = CodegenRunner()
result = codegen.generate_from_config(config)

# Generate variant
result = codegen.generate("preshuffle")

# Generate all
results = codegen.generate_all()
```
## Command Line Options
| Option | Values | Description |
|---|---|---|
| `--output-dir` | path | Output directory |
| `--datatype` | `fp16`, `bf16`, `fp32`, `int8` | Data type |
| `--layout` | `rcr`, `rrr`, `crr`, `ccr` | Matrix layouts |
| `--gpu-target` | `gfx942`, `gfx90a`, `gfx950` | Target GPU |
| `--variants` | `standard`, `preshuffle`, `multi_d` | Kernel variants |
| `--preselected` | `fp16_rcr_essential`, etc. | Predefined kernel set |
## Layout Notation
- `R` = Row-major, `C` = Column-major
- Order: A, B, C (e.g., `rcr` = A row-major, B column-major, C row-major; see the sketch below)
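To make the decoding explicit, here is a tiny helper (hypothetical, not part of the codegen):

```python
def decode_layout(layout: str) -> dict:
    """Decode a 3-letter layout string, e.g. 'rcr', into per-matrix majors."""
    majors = {"r": "row-major", "c": "column-major"}
    a, b, c = layout.lower()
    return {"A": majors[a], "B": majors[b], "C": majors[c]}

assert decode_layout("rcr") == {"A": "row-major", "B": "column-major", "C": "row-major"}
```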
## Variants
### Standard

Basic GEMM: `C = A x B`
### PreShuffle
Optimized weight access with LDS pre-shuffling. Best for large matrices.
### Multi-D

Element-wise fusion: `C = op(A x B + D0 + D1 + ...)`

Supported ops: `PassThrough`, `MultiDAdd`, `Relu`, `Gelu`, `Sigmoid`, `Tanh`
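For reference, the fused semantics match this NumPy sketch (illustrative only; the shapes, dtype, and choice of Relu are placeholders):

```python
import numpy as np

# C = op(A x B + D0 + D1 + ...), here with Relu as the element-wise op
A = np.random.rand(128, 64).astype(np.float32)
B = np.random.rand(64, 256).astype(np.float32)
D0 = np.random.rand(128, 256).astype(np.float32)
D1 = np.random.rand(128, 256).astype(np.float32)

C = np.maximum(A @ B + D0 + D1, 0.0)  # Relu(A x B + D0 + D1)
```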
## Output Structure
```
generated_kernels/
|---- gemm_fp16_rcr_compv4_..._128x128x32_....hpp          # GEMM kernels
|---- gemm_fp16_rcr_compv4_..._preshuffle.hpp
|---- gemm_fp16_rcr_compv4_..._multid_Relu_d1.hpp
|---- grouped_conv_fwd_fp16_nhwgc_..._128x128x32_....hpp   # Grouped conv kernels
+---- ...
```
## Configuration Files

### `arch_specs.json`

GPU architecture specifications (single source of truth):
```json
{
  "architectures": {
    "gfx942": {
      "family": "cdna3",
      "warp_size": 64,
      "warp_configs": [[2, 2, 1], [4, 4, 1]],
      ...
    }
  }
}
```
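Because the file is plain JSON, other tooling can consume it directly, for example (a sketch, assuming it runs from the directory containing `arch_specs.json`):

```python
import json

# Load the architecture specs and inspect one entry.
with open("arch_specs.json") as f:
    specs = json.load(f)

gfx942 = specs["architectures"]["gfx942"]
print(gfx942["warp_size"])     # 64
print(gfx942["warp_configs"])  # [[2, 2, 1], [4, 4, 1]]
```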
### `preselected_kernels.py`

Curated kernel sets for common use cases.
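For example, generating a curated set through the GEMM generator with the `--preselected` option from the table above:

```bash
cd dispatcher/codegen
python3 unified_gemm_codegen.py \
    --output-dir ../build/generated_kernels \
    --preselected fp16_rcr_essential
```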
## Adding New GPU Support

See ADDING_NEW_GPU.md for the complete guide.
Quick steps (see the sketch after this list):

1. Edit `arch_specs.json`
2. Run `python generate_arch_specs.py`
3. Rebuild
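In shell terms, the steps look roughly like this (a sketch; the rebuild command and build directory are assumptions and may differ for your setup):

```bash
cd dispatcher/codegen
$EDITOR arch_specs.json          # 1. add or extend the target architecture entry
python generate_arch_specs.py    # 2. regenerate the derived specs
cmake --build ../../build        # 3. rebuild (assumed build directory/command)
```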
## Troubleshooting
| Issue | Solution |
|---|---|
| "Arguments not supported" | Check tile config validity |
| Missing element-wise op | Check `elementwise_ops.hpp` |
| Compilation errors | Verify C++17, include paths |
More info: see `../README.md` for full documentation.