mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327) ## Motivation The ML heuristic in dispatcher does not support grouped-conv operator yet. In this PR, the support for fwd, bdw-data, and bwd-weight grouped-conv kernels have been added. A tile_engine utility has also been added to compile and run any selected kernel configuration through dispatcher infrastructure. ## Technical Details 1. Tile engine utility is added to benchmark each shape with all the possible kernel+tile_size combinations here - [https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py](url) 2. New LGBM regressor models for grouped conv are added to models directory. We have 3 separate models for fwd, bwd-data, and bwd-weights [https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models](url) 3. Implemented lazy GPU initialization (dispatcher/python) - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults - **Solution**: Mirror FMHA pattern - defer GPU initialization until first run() - **Changes**: - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL - Added _ensure_initialized() method for lazy GPU loading - GPU context created only on first run() call - **Benefit**: Parallel compilation now works without GPU conflicts 4. Addressed few miscellaneous issues such as: - Fixed BF16->FP16 naming bug in the dispatcher wrapper - Added new tile sizes, and comp_v5 pipeline to the arch spec to expand the kernel selection - Added automatic padding support for unsupported shapes in dispatcher runner - Created a single source of truth between tile_engine and dispatcher about the architecture and tile_size details - Build a validation scripts to compare oracle_best vs ml_heuristic comparison ## Test Plan 1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets with up to 300 problems. 2. Ensured that test cases are added in both dispatcher and tile_engine to validate the heuristic. ## Test Result Results on Unseen shapes validated on gfx950 #### Forward Pass Model - **Training Data**: 48,845 measurements across 1,372 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.05%** - Median Efficiency: **96.8%** - P10 Efficiency: **79.9%** #### Backward Data Gradient (bwd_data) Model - **Training Data**: 18,773 measurements across 891 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.8%** - Median Efficiency: **96.5%** - P10 Efficiency: **82.9%** #### Backward Weight Gradient (bwd_weight) Model - **Training Data**: 34,900 measurements across 1,508 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **96.1%** - Median Efficiency: **99.2%** - P10 Efficiency: **89.4%** ## Submission Checklist - [ x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
CK Tile Dispatcher Examples
Comprehensive examples for GEMM and Grouped Convolution operations with GPU execution.
Quick Start
Step 1: Build
cd /path/to/composable_kernel/dispatcher
mkdir -p build && cd build
cmake .. \
-DCMAKE_PREFIX_PATH=/opt/rocm \
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-DCMAKE_BUILD_TYPE=Release \
-DGPU_TARGETS="gfx942" \
-DBUILD_DISPATCHER_EXAMPLES=ON
# Build everything (C++ examples + Python libraries)
make -j$(nproc)
# Or build ONLY Python libraries (faster)
make python_libs -j$(nproc)
Step 2: Run C++ Examples
cd build/examples
# GEMM
./gemm_01_basic
./gemm_02_multi_size
./gemm_03_benchmark_validation
./gemm_04_heuristics
./gemm_05_json_export
./gemm_06_multi_registry
Step 3: Run Python Examples
cd /path/to/composable_kernel/dispatcher
# GEMM
python3 examples/gemm/python/01_basic_gemm.py
python3 examples/gemm/python/04_validation.py
python3 examples/gemm/python/07_stress_test.py
python3 examples/gemm/python/08_heuristics.py
Directory Structure
examples/
|---- gemm/
| |---- cpp/ # 6 C++ GEMM examples
| +---- python/ # 11 Python GEMM examples
|
+---- README.md
GEMM Examples
C++ Examples
| # | Example | Description |
|---|---|---|
| 01 | gemm_01_basic |
Basic GEMM with declarative API, autofill, autocorrect |
| 02 | gemm_02_multi_size |
Wildcard expansion for multiple configurations |
| 03 | gemm_03_benchmark_validation |
Performance benchmarking with CPU/GPU validation |
| 04 | gemm_04_heuristics |
Heuristic-based kernel selection |
| 05 | gemm_05_json_export |
Registry JSON export for external tools |
| 06 | gemm_06_multi_registry |
Multiple registries with named kernel sets |
Details: gemm/cpp/README.md
Python Examples
| # | Example | Description |
|---|---|---|
| 01 | 01_basic_gemm.py |
Basic GEMM with multi-kernel support |
| 02 | 02_batch_gemm.py |
Batched GEMM operations |
| 03 | 03_benchmark.py |
Performance benchmarking |
| 04 | 04_validation.py |
CPU reference validation |
| 05 | 05_numpy_integration.py |
NumPy array integration |
| 06 | 06_json_export.py |
Registry JSON export |
| 07 | 07_stress_test.py |
Multi-kernel stress testing (48 configs) |
| 08 | 08_heuristics.py |
Heuristic-based kernel selection (24 configs) |
| 09 | 09_multi_registry.py |
Multiple registries |
| 10 | 10_advanced_benchmark.py |
Advanced benchmark with full control |
| 11 | 11_json_import.py |
Import kernels from JSON |
Details: gemm/python/README.md
Key Features
Declarative Kernel API
Both C++ and Python examples use a declarative approach:
C++ (DECL_KERNEL_SET macro):
DECL_KERNEL_SET(my_kernels,
.add(
Signature().dtype("fp16").layout("rcr"),
Algorithm().tile(256, 256, 32).wave(2, 2, 1).warp(32, 32, 16)
.pipeline("compv4").scheduler("intrawave"),
"gfx942"
)
);
Python (KernelConfig):
config = KernelConfig(
tile_m=256, tile_n=256, tile_k=32,
wave_m=2, wave_n=2, wave_k=1,
warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
pipeline="compv4", scheduler="intrawave"
)
Autofill and Autocorrect
The build system automatically:
- Autofills missing parameters with sensible defaults
- Autocorrects invalid parameters based on architecture constraints
- Expands wildcards (
*,-1,ANY_INT) to all valid configurations
Architecture Filtering
Kernel configurations are validated against GPU architecture constraints:
- Tile divisibility requirements
- Warp tile constraints
- Pipeline compatibility
Invalid configurations are automatically pruned during code generation.
Validation Examples
C++ Validation
./gemm_03_benchmark_validation --verify 1 # GEMM with CPU reference
./gemm_03_benchmark_validation --verify 2 # GEMM with GPU reference
Python Validation
python3 examples/gemm/python/04_validation.py
python3 examples/gemm/python/07_stress_test.py # Multi-kernel validation
Troubleshooting
Python: Library not found
# Run from dispatcher directory
cd /path/to/composable_kernel/dispatcher
python3 examples/gemm/python/01_basic_gemm.py
C++: Executables not found
# Build with examples enabled
cmake .. -DBUILD_DISPATCHER_EXAMPLES=ON
make -j$(nproc)
# Run from build/examples
cd build/examples
./gemm_01_basic
GPU not detected
rocminfo | grep "Name:"
# Should show: gfx942, gfx90a, etc.
Grouped Convolution
Grouped convolution support has been re-introduced with a unified infrastructure shared with GEMM.
Infrastructure
The grouped convolution code generation, utilities, and build scripts are available:
| Component | Location |
|---|---|
| C++ Headers | include/ck_tile/dispatcher/grouped_conv_*.hpp |
| Python Codegen | codegen/unified_grouped_conv_codegen.py |
| Python Utils | python/grouped_conv_utils.py |
| Build Script | scripts/compile_grouped_conv_examples.py |
Building Grouped Conv Kernels
# Generate grouped conv kernels
python3 codegen/unified_grouped_conv_codegen.py \
--output-dir build/generated_kernels \
--datatype fp16 --variant forward --ndim-spatial 2
# Compile a grouped conv example
python3 scripts/compile_grouped_conv_examples.py my_grouped_conv_example.cpp
See the main README for more details.