mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
## [CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

### Motivation

The ML heuristic in the dispatcher does not yet support the grouped-conv operator. This PR adds support for fwd, bwd-data, and bwd-weight grouped-conv kernels. A tile_engine utility has also been added to compile and run any selected kernel configuration through the dispatcher infrastructure.

### Technical Details

1. A tile_engine utility is added to benchmark each shape with all possible kernel + tile_size combinations:
   https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv are added to the models directory. There are three separate models for fwd, bwd-data, and bwd-weight:
   https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python):
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults
   - **Solution**: Mirror the FMHA pattern and defer GPU initialization until the first run()
   - **Changes**:
     - setup_multiple_grouped_conv_dispatchers() returns List[Path] instead of loaded libraries
     - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
     - Added an _ensure_initialized() method for lazy GPU loading
     - GPU context is created only on the first run() call
   - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to expand kernel selection
   - Added automatic padding support for unsupported shapes in the dispatcher runner
   - Created a single source of truth between tile_engine and the dispatcher for architecture and tile_size details
   - Built a validation script to compare oracle_best vs. ml_heuristic results

### Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets of up to 300 problems.
2. Ensured that test cases are added in both dispatcher and tile_engine to validate the heuristic.

### Test Result

Results on unseen shapes, validated on gfx950.

#### Forward Pass Model
- **Training Data**: 48,845 measurements across 1,372 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model
- **Training Data**: 18,773 measurements across 891 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model
- **Training Data**: 34,900 measurements across 1,508 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**

### Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
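The lazy-initialization change in point 3 of the technical details can be sketched as follows. This is a minimal, self-contained illustration of the pattern (store only the library path at construction time and defer the ctypes/GPU load until the first run() call), not the real GpuGroupedConvRunner; the `LazyRunner` class and the string standing in for a loaded library are hypothetical.

```python
# Sketch of the lazy-initialization pattern, assuming the real runner
# replaces the placeholder string with ctypes.CDLL(self.lib_path).
from pathlib import Path


class LazyRunner:
    """Defers expensive (GPU/ctypes) setup until the first run() call."""

    def __init__(self, lib_path: Path):
        # Store only the path; do NOT load the library here, so that a
        # ProcessPoolExecutor fork() does not inherit a GPU context.
        self.lib_path = lib_path
        self._lib = None

    def _ensure_initialized(self):
        if self._lib is None:
            # Placeholder for the real ctypes.CDLL(self.lib_path) call.
            self._lib = f"loaded:{self.lib_path.name}"

    def run(self):
        self._ensure_initialized()
        return self._lib


runner = LazyRunner(Path("libkernel.so"))
assert runner._lib is None  # nothing loaded at construction time
assert runner.run() == "loaded:libkernel.so"
```

Because `__init__` touches no GPU state, many such runners can be constructed (and their kernels compiled) in parallel worker processes; the GPU context exists only in the process that actually calls run().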
# CK Tile Dispatcher Python Utilities
This directory contains Python utilities used by the dispatcher examples.
## Contents

### Shared Utilities (used by both GEMM and Grouped Conv)

- `dispatcher_common.py` - Shared dispatcher infrastructure
  - Path helpers (`get_dispatcher_root`, `get_build_dir`, etc.)
  - `ValidationResultBase` - Structured validation feedback
  - `validate_wave_config`, `validate_warp_tile_config`, `validate_trait_combo`
  - `auto_correct_wave`, `auto_correct_trait` - Auto-correction helpers
  - `Colors` - Cross-platform ANSI color support
  - `print_phase`, `print_success`, `print_error`, `print_info` - Phased output
  - `cleanup_generated_kernels` - Cleanup helper
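A structured validation result like `ValidationResultBase` can be sketched with a dataclass. The field names (`is_valid`, `errors`, `warnings`) and the example wave rule below are assumptions for illustration only; the actual API in `dispatcher_common.py` may differ.

```python
# Hedged sketch of structured validation feedback; field names and the
# "at most 4 waves" rule are hypothetical, not the real implementation.
from dataclasses import dataclass, field
from typing import List


@dataclass
class ValidationResult:
    is_valid: bool = True
    errors: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)

    def add_error(self, msg: str) -> None:
        self.errors.append(msg)
        self.is_valid = False


def validate_wave_sketch(waves_m: int, waves_n: int) -> ValidationResult:
    # Hypothetical rule: total waves per workgroup must not exceed 4.
    result = ValidationResult()
    if waves_m * waves_n > 4:
        result.add_error(f"wave config {waves_m}x{waves_n} exceeds 4 waves")
    return result


assert validate_wave_sketch(2, 2).is_valid
assert not validate_wave_sketch(4, 2).is_valid
```

Returning a result object rather than raising lets callers collect all problems at once and feed them to the auto-correction helpers.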
### GEMM Utilities

- `ctypes_utils.py` - Core ctypes utilities for the GEMM Python examples
  - `KernelConfig` - Kernel configuration dataclass
  - `setup_gemm_dispatcher()` - Set up the dispatcher with auto-correction
  - `cleanup_gemm()` - Clean up dispatcher resources
  - `GemmRunner` - GPU execution helper
  - Auto-correction and validation utilities
### Grouped Convolution Utilities

- `grouped_conv_utils.py` - Utilities for grouped convolution
  - `GroupedConvValidationResult` - Validation result (extends `ValidationResultBase`)
  - `validate_grouped_conv_config` - Validate a grouped conv config
  - `auto_correct_grouped_conv_config` - Auto-correct invalid configs
  - `get_grouped_conv_default_config` - Get the default config for a variant
  - `GroupedConvDataType` - Data type enum (FP16, BF16, FP32, FP8, BF8, INT8)
  - `format_grouped_conv_summary` - Human-readable config summary
## Usage

### GEMM Examples

The GEMM Python examples in `dispatcher/examples/gemm/python/` import:

```python
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))

from ctypes_utils import (
    KernelConfig,
    setup_gemm_dispatcher,
    cleanup_gemm,
    GemmRunner,
)
```
### Grouped Conv Usage

```python
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))

from grouped_conv_utils import (
    validate_grouped_conv_config,
    auto_correct_grouped_conv_config,
    get_grouped_conv_default_config,
    GroupedConvDataType,
)

# Get a default config
config = get_grouped_conv_default_config(variant="forward", arch="gfx942")

# Validate it
result = validate_grouped_conv_config(config)
print(f"Valid: {result.is_valid}")
```
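When validation fails, `auto_correct_grouped_conv_config` can repair the config. One common way such helpers work is to snap an unsupported value to the nearest supported one; the sketch below illustrates only that idea. The supported-tile list and function name are hypothetical, not the actual implementation.

```python
# Illustrative only: snap a requested tile size to the nearest value in a
# (hypothetical) supported list, as an auto-correction helper might do.
SUPPORTED_TILE_M = [32, 64, 128, 256]


def auto_correct_tile_m(requested: int) -> int:
    # Pick the supported tile size with the smallest absolute distance.
    return min(SUPPORTED_TILE_M, key=lambda t: abs(t - requested))


assert auto_correct_tile_m(64) == 64    # already supported: unchanged
assert auto_correct_tile_m(100) == 128  # snapped to the nearest option
```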
## Requirements
- Python 3.8+
- NumPy
- HIP runtime (for GPU execution)