mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327) ## Motivation The ML heuristic in dispatcher does not support grouped-conv operator yet. In this PR, the support for fwd, bdw-data, and bwd-weight grouped-conv kernels have been added. A tile_engine utility has also been added to compile and run any selected kernel configuration through dispatcher infrastructure. ## Technical Details 1. Tile engine utility is added to benchmark each shape with all the possible kernel+tile_size combinations here - [https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py](url) 2. New LGBM regressor models for grouped conv are added to models directory. We have 3 separate models for fwd, bwd-data, and bwd-weights [https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models](url) 3. Implemented lazy GPU initialization (dispatcher/python) - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults - **Solution**: Mirror FMHA pattern - defer GPU initialization until first run() - **Changes**: - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL - Added _ensure_initialized() method for lazy GPU loading - GPU context created only on first run() call - **Benefit**: Parallel compilation now works without GPU conflicts 4. Addressed few miscellaneous issues such as: - Fixed BF16->FP16 naming bug in the dispatcher wrapper - Added new tile sizes, and comp_v5 pipeline to the arch spec to expand the kernel selection - Added automatic padding support for unsupported shapes in dispatcher runner - Created a single source of truth between tile_engine and dispatcher about the architecture and tile_size details - Build a validation scripts to compare oracle_best vs ml_heuristic comparison ## Test Plan 1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets with up to 300 problems. 2. Ensured that test cases are added in both dispatcher and tile_engine to validate the heuristic. ## Test Result Results on Unseen shapes validated on gfx950 #### Forward Pass Model - **Training Data**: 48,845 measurements across 1,372 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.05%** - Median Efficiency: **96.8%** - P10 Efficiency: **79.9%** #### Backward Data Gradient (bwd_data) Model - **Training Data**: 18,773 measurements across 891 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.8%** - Median Efficiency: **96.5%** - P10 Efficiency: **82.9%** #### Backward Weight Gradient (bwd_weight) Model - **Training Data**: 34,900 measurements across 1,508 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **96.1%** - Median Efficiency: **99.2%** - P10 Efficiency: **89.4%** ## Submission Checklist - [ x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
62 lines
869 B
Plaintext
62 lines
869 B
Plaintext
# Python bytecode and caches
|
|
__pycache__/
|
|
*.pyc
|
|
*.pyo
|
|
*.pyd
|
|
.Python
|
|
|
|
# Jupyter notebooks
|
|
*.ipynb
|
|
.ipynb_checkpoints/
|
|
|
|
# Virtual environments
|
|
.venv/
|
|
venv/
|
|
ENV/
|
|
|
|
# IDE and editor files
|
|
.vscode/
|
|
.idea/
|
|
*.swp
|
|
*.swo
|
|
*~
|
|
|
|
# Test output and logs
|
|
*.log
|
|
test_output.log
|
|
custom_shapes_gpu_test.log
|
|
|
|
# Benchmark and analysis output files
|
|
*.csv
|
|
*.json
|
|
!models/*/feature_spec.json
|
|
!models/*/train_manifest.json
|
|
|
|
# Data files (parquet, arrow)
|
|
*.parquet
|
|
*.arrow
|
|
|
|
# Temporary and NFS files
|
|
.nfs*
|
|
*.tmp
|
|
*.bak
|
|
|
|
# Decompressed model files (compressed .lgbm.gz versions are tracked)
|
|
models/**/*.lgbm
|
|
|
|
# User-specific test and analysis scripts
|
|
test_*.py
|
|
!tests/test_*.py
|
|
find_*.py
|
|
oracle_*.json
|
|
validation_results_*.csv
|
|
custom_shapes_*.csv
|
|
fp16_bf16_*.csv
|
|
|
|
# Ignore all markdown files except tracked documentation
|
|
*.md
|
|
!DATA_GENERATION.md
|
|
!LEARNINGS.md
|
|
!LEARNINGS_GROUPED_CONV.md
|
|
!README.md
|