mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation

The ML heuristic in the dispatcher does not yet support the grouped-conv operator. This PR adds support for fwd, bwd-data, and bwd-weight grouped-conv kernels. A tile_engine utility has also been added to compile and run any selected kernel configuration through the dispatcher infrastructure.

## Technical Details

1. A tile_engine utility is added to benchmark each shape with all possible kernel + tile_size combinations: https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv are added to the models directory. There are three separate models, for fwd, bwd-data, and bwd-weight: https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python):
   - **Issue**: `ProcessPoolExecutor` fork() + GPU context caused memory access faults
   - **Solution**: Mirror the FMHA pattern and defer GPU initialization until the first `run()`
   - **Changes**:
     - `setup_multiple_grouped_conv_dispatchers()` returns `List[Path]`, not loaded libs
     - `GpuGroupedConvRunner.__init__()` no longer calls `ctypes.CDLL`
     - Added an `_ensure_initialized()` method for lazy GPU loading
     - GPU context is created only on the first `run()` call
   - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to expand kernel selection
   - Added automatic padding support for unsupported shapes in the dispatcher runner
   - Created a single source of truth between tile_engine and dispatcher for the architecture and tile_size details
   - Built a validation script to compare oracle_best vs. ml_heuristic results

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets of up to 300 problems.
2. Ensured that test cases are added in both dispatcher and tile_engine to validate the heuristic.

## Test Results

Results on unseen shapes, validated on gfx950.

#### Forward Pass Model

- **Training Data**: 48,845 measurements across 1,372 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model

- **Training Data**: 18,773 measurements across 891 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model

- **Training Data**: 34,900 measurements across 1,508 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests
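The lazy-initialization pattern from item 3 above can be sketched as follows. This is an illustrative outline only, not the actual implementation in dispatcher/python: the class and method names (`GpuGroupedConvRunner`, `_ensure_initialized`, `run`, `conv_dispatcher_init`) come from the PR description, while the bodies are assumptions about how such a pattern is typically wired up.

```python
import ctypes
from pathlib import Path


class GpuGroupedConvRunner:
    """Sketch of the lazy GPU initialization pattern described above."""

    def __init__(self, lib_path: Path):
        # Store only the path; do NOT call ctypes.CDLL here, so that
        # fork()-based ProcessPoolExecutor workers never inherit a GPU context.
        self._lib_path = lib_path
        self._lib = None

    def _ensure_initialized(self):
        # Load the library and create the GPU context on first use only.
        if self._lib is None:
            self._lib = ctypes.CDLL(str(self._lib_path))
            self._lib.conv_dispatcher_init()
        return self._lib

    def run(self, *args):
        # The GPU is touched here, on the first run() call, after any fork().
        lib = self._ensure_initialized()
        # ... launch the convolution through `lib` ...
        return lib
```

Because `__init__` never touches the GPU, runner objects can be constructed (and pickled library paths passed around) freely during parallel compilation.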
# CK Tile Dispatcher - Language Bindings

This directory contains language bindings for the CK Tile Dispatcher.
## Structure

```text
bindings/
|---- ctypes/                        # Python ctypes bindings (C API)
|     |---- gemm_ctypes_lib.cpp      # GEMM dispatcher C API
|     |---- conv_ctypes_lib.cpp      # Grouped conv dispatcher C API (fwd + bwd_data)
|     |---- conv_bwdw_ctypes_lib.cpp # Grouped conv backward weight C API (separate library)
|     |---- gpu_helper.cpp           # CLI helper for Python
|     +---- CMakeLists.txt
+---- README.md
```
## ctypes Bindings

The ctypes bindings provide a C API that Python can load via `ctypes.CDLL()`.
### Building

```bash
cd build
cmake .. -DCMAKE_PREFIX_PATH=/opt/rocm
make dispatcher_gemm_lib dispatcher_conv_lib gpu_helper
```
### Usage from Python

```python
import ctypes

# Load the library
lib = ctypes.CDLL("path/to/libdispatcher_gemm_lib.so")

# Initialize
lib.dispatcher_init()

# Check if the problem is supported
is_supported = lib.dispatcher_is_supported(M, N, K)

# Run GEMM
time_ms = ctypes.c_float()
result = lib.dispatcher_run_gemm(
    A_ptr, B_ptr, C_ptr,
    M, N, K,
    ctypes.byref(time_ms)
)

# Cleanup
lib.dispatcher_cleanup()
```
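Since `ctypes` assumes every argument and return value is an `int` unless told otherwise, it is worth declaring prototypes before calling into the library. The sketch below shows one way to do this; the signatures are assumptions inferred from the usage example above, so verify them against the C header before relying on them. It also demonstrates how the `time_ms` out-parameter behaves: the C side writes through the pointer created by `ctypes.byref()`, and the Python `c_float` sees the new value.

```python
import ctypes


def declare_gemm_api(lib):
    """Attach assumed prototypes to the loaded dispatcher library.

    Signatures are inferred from the usage example, not taken from the
    header; check gemm_ctypes_lib.cpp before using this in earnest.
    """
    lib.dispatcher_init.restype = ctypes.c_int
    lib.dispatcher_is_supported.argtypes = [ctypes.c_int] * 3  # M, N, K
    lib.dispatcher_is_supported.restype = ctypes.c_int
    lib.dispatcher_run_gemm.argtypes = [
        ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p,  # A, B, C device pointers
        ctypes.c_int, ctypes.c_int, ctypes.c_int,           # M, N, K
        ctypes.POINTER(ctypes.c_float),                     # out-param: kernel time (ms)
    ]
    lib.dispatcher_run_gemm.restype = ctypes.c_int


# The out-parameter pattern: writing through a pointer to the c_float
# updates the Python object, which is how dispatcher_run_gemm reports timing.
time_ms = ctypes.c_float()
ctypes.pointer(time_ms).contents.value = 0.5  # stand-in for the C side writing
print(time_ms.value)  # 0.5
```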
### GEMM API

| Function | Description |
|---|---|
| `dispatcher_init()` | Initialize the dispatcher |
| `dispatcher_is_supported(M, N, K)` | Check if the problem size is supported |
| `dispatcher_select_kernel(M, N, K, name_buf, buf_size)` | Get the kernel name for a problem |
| `dispatcher_run_gemm(A, B, C, M, N, K, time_ms)` | Execute GEMM |
| `dispatcher_get_kernel_count()` | Get the number of registered kernels |
| `dispatcher_export_registry_json()` | Export the registry as JSON |
| `dispatcher_cleanup()` | Release resources |
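`dispatcher_select_kernel` fills a caller-owned string buffer, which from Python is created with `ctypes.create_string_buffer`. The snippet below sketches the pattern; the buffer size of 256 and the kernel name are illustrative assumptions, and the simulated write stands in for the actual library call.

```python
import ctypes

# Caller-owned buffer for the kernel name (256 bytes is an assumed size).
name_buf = ctypes.create_string_buffer(256)

# After a successful call such as:
#   lib.dispatcher_select_kernel(M, N, K, name_buf, len(name_buf))
# the NUL-terminated name can be read back as a Python string.
name_buf.value = b"gemm_fp16_rcr_example"  # stand-in for what the C side writes
print(name_buf.value.decode())  # gemm_fp16_rcr_example
```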
### Grouped Convolution API

| Function | Description |
|---|---|
| `conv_dispatcher_init()` | Initialize the dispatcher |
| `conv_dispatcher_is_supported(prob)` | Check if the problem is supported |
| `conv_dispatcher_select_kernel(prob, name_buf, buf_size)` | Get the kernel name |
| `conv_dispatcher_run(input, weight, output, prob, stream)` | Execute convolution |
| `conv_dispatcher_get_kernel_count()` | Get the number of registered kernels |
| `conv_dispatcher_cleanup()` | Release resources |
## GPU Helper

The `gpu_helper` executable provides a CLI interface for Python:

```bash
./gpu_helper 1024 1024 1024 --validate
```

Output is JSON for easy parsing:

```json
{
  "problem": {"M": 1024, "N": 1024, "K": 1024},
  "kernel": "gemm_fp16_rcr_...",
  "execution": {
    "time_ms": 0.5,
    "tflops": 4.2
  },
  "validation": {
    "accuracy": 100.0
  },
  "status": "success"
}
```
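Because the helper emits JSON, driving it from Python is a `subprocess` call plus `json.loads`. The wrapper below is a sketch: the binary path and the exact report schema beyond the fields shown above are assumptions, and the offline demonstration parses the sample output rather than invoking the binary.

```python
import json
import subprocess


def run_gpu_helper(m, n, k, binary="./gpu_helper"):
    """Invoke the gpu_helper CLI and parse its JSON report.

    The binary path is an assumption; point it at your build directory.
    """
    out = subprocess.run(
        [binary, str(m), str(n), str(k), "--validate"],
        capture_output=True, text=True, check=True,
    )
    return json.loads(out.stdout)


# Offline demonstration on the sample output shown above:
sample = (
    '{"problem": {"M": 1024, "N": 1024, "K": 1024},'
    ' "kernel": "gemm_fp16_rcr_...",'
    ' "execution": {"time_ms": 0.5, "tflops": 4.2},'
    ' "validation": {"accuracy": 100.0},'
    ' "status": "success"}'
)
report = json.loads(sample)
print(report["status"], report["execution"]["tflops"])  # success 4.2
```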
## Examples

See the examples that use these bindings:

- GEMM: `dispatcher/examples/gemm/python/`
## Grouped Convolution

Grouped convolution C++ headers and Python utilities are in:

- C++ Headers: `dispatcher/include/ck_tile/dispatcher/grouped_conv_*.hpp`
- Python Utils: `dispatcher/python/grouped_conv_utils.py`
- Build Script: `dispatcher/scripts/compile_grouped_conv_examples.py`