Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-13 09:45:56 +00:00
[rocm-libraries] ROCm/rocm-libraries#6327 (commit 1e7a12e)
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation

The ML heuristic in the dispatcher does not yet support grouped-conv operators. This PR adds support for fwd, bwd-data, and bwd-weight grouped-conv kernels. A tile_engine utility has also been added to compile and run any selected kernel configuration through the dispatcher infrastructure.

## Technical Details

1. A tile_engine utility is added to benchmark each shape with all possible kernel + tile_size combinations:
   https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv are added to the models directory, with three separate models for fwd, bwd-data, and bwd-weight (a selection sketch follows this description):
   https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python); see the sketch after this description.
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults
   - **Solution**: Mirror the FMHA pattern: defer GPU initialization until the first run()
   - **Changes**:
     - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs
     - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
     - Added an _ensure_initialized() method for lazy GPU loading
     - The GPU context is created only on the first run() call
   - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to expand kernel selection
   - Added automatic padding support for unsupported shapes in the dispatcher runner
   - Created a single source of truth between tile_engine and the dispatcher for architecture and tile_size details
   - Built validation scripts to compare oracle_best vs. ml_heuristic results

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets of up to 300 problems.
2. Ensured that test cases are added in both dispatcher and tile_engine to validate the heuristic.

## Test Result

Results on unseen shapes, validated on gfx950 (efficiency is measured against the oracle best; see the metrics sketch after this description).

#### Forward Pass Model

- **Training Data**: 48,845 measurements across 1,372 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model

- **Training Data**: 18,773 measurements across 891 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model

- **Training Data**: 34,900 measurements across 1,508 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
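The lazy-initialization change in Technical Details item 3 follows a standard defer-until-first-use pattern. Below is a minimal Python sketch of it, using only the names given above (`GpuGroupedConvRunner`, `_ensure_initialized()`, `run()`, `setup_multiple_grouped_conv_dispatchers()`); the exported symbol, argument handling, and compile helper are hypothetical stand-ins, not the actual dispatcher code.

```python
import ctypes
from pathlib import Path
from typing import List, Optional


def compile_dispatcher_lib(cfg) -> Path:
    """Hypothetical stand-in for the real JIT-compile step."""
    raise NotImplementedError


def setup_multiple_grouped_conv_dispatchers(configs) -> List[Path]:
    # Per the change above: return library *paths* only, so parallel
    # compilation workers never own a GPU context.
    return [compile_dispatcher_lib(cfg) for cfg in configs]


class GpuGroupedConvRunner:
    """Sketch of the lazy-initialization pattern described above."""

    def __init__(self, lib_path: Path):
        # Only record the path; no ctypes.CDLL here, so instances are
        # safe to create around a ProcessPoolExecutor fork().
        self._lib_path = lib_path
        self._lib: Optional[ctypes.CDLL] = None

    def _ensure_initialized(self) -> None:
        # Deferred: loading the HIP-backed library creates the GPU
        # context, so it happens in the process that actually runs.
        if self._lib is None:
            self._lib = ctypes.CDLL(str(self._lib_path))

    def run(self) -> None:
        self._ensure_initialized()
        # launch_grouped_conv is a hypothetical exported symbol.
        self._lib.launch_grouped_conv()
```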
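Technical Details item 2 adds LGBM regressors that drive kernel selection. Here is a minimal sketch of how a regressor-based heuristic can rank candidate configurations; the feature layout, dictionary keys, and model file name are illustrative assumptions, not the real models' interface (the actual feature engineering lives in the models directory linked above).

```python
import lightgbm as lgb
import numpy as np


def select_kernel(problem_features, candidate_configs, model_path):
    """Pick the config with the best predicted performance score."""
    booster = lgb.Booster(model_file=model_path)  # e.g. the fwd model
    # One row per candidate: problem-shape features + tile-size features.
    rows = np.array(
        [
            list(problem_features) + [cfg["tile_m"], cfg["tile_n"], cfg["tile_k"]]
            for cfg in candidate_configs
        ],
        dtype=np.float64,
    )
    scores = booster.predict(rows)
    return candidate_configs[int(np.argmax(scores))]
```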
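The efficiency figures in Test Result compare the heuristic's pick against the oracle (the fastest measured config) per problem. A sketch of that computation, assuming per-problem timing dictionaries and the ratio definition of efficiency; the actual validation scripts ship with the PR.

```python
import numpy as np


def efficiency_stats(timings_ms, ml_choice):
    """timings_ms: {problem_id: {config_name: time_ms}}
    ml_choice:  {problem_id: config_name picked by the heuristic}
    Per-problem efficiency = oracle_best_time / ml_chosen_time, so 100%
    means the heuristic picked the fastest measured config.
    """
    eff = np.array(
        [
            min(per_cfg.values()) / per_cfg[ml_choice[pid]]
            for pid, per_cfg in timings_ms.items()
        ]
    )
    return {
        "mean_pct": 100.0 * eff.mean(),
        "median_pct": 100.0 * np.median(eff),
        "p10_pct": 100.0 * np.percentile(eff, 10),  # worst-tail summary
    }
```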
Committed by assistant-librarian[bot]; parent b05040b919; commit 6989cf800c
```diff
@@ -76,16 +76,17 @@ def main():
     print("\n--- Step 1: Declare Forward Kernels ---")
     reg = GroupedConvRegistry("forward_conv")
 
-    # Forward 2D: compv4, 128x128 tile, wave 2x2x1, warp 32x32x16
+    # Forward 2D: compv4, 64x128x64 tile (LDS 24 KiB <= 32 KiB), wave 2x2x1, warp 32x32x16
+    # Constraint: tile_m == wave_m * warp_tile_m (small M handled by kPadM=True)
     reg.add(
         GroupedConvKernelConfig(
             variant="forward",
             ndim_spatial=2,
             arch=arch,
             dtype=args.dtype,
-            tile_m=1,
+            tile_m=64,  # = wave_m(2) * warp_tile_m(32)
             tile_n=128,
-            tile_k=128,
+            tile_k=64,
             wave_m=2,
             wave_n=2,
             wave_k=1,
@@ -99,18 +100,19 @@ def main():
             vector_size_b=8,
             vector_size_c=8,
             block_per_cu=1,
+            double_smem_buffer=True,  # required by compv4 pipeline
         )
     )
-    # Forward 3D: compv3, 64x64 tile, wave 1x4x1, warp 16x16x32
+    # Forward 3D: compv3, 16x64x128 tile, wave 1x4x1, warp 16x16x32
     reg.add(
         GroupedConvKernelConfig(
             variant="forward",
             ndim_spatial=3,
             arch=arch,
             dtype=args.dtype,
-            tile_m=1,
+            tile_m=16,  # = wave_m(1) * warp_tile_m(16)
             tile_n=64,
-            tile_k=64,
+            tile_k=128,
             wave_m=1,
             wave_n=4,
             wave_k=1,
```
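The new comments in this diff encode two rules: `tile_m == wave_m * warp_tile_m`, and an LDS budget for the compv4 config ("LDS 24 KiB <= 32 KiB"). The sketch below reproduces both checks for the two configs above; the 2-byte element size (fp16/bf16) and the single-buffer estimate `(tile_m + tile_n) * tile_k * bytes` are assumptions that happen to match the 24 KiB figure in the comment, not the kernel's actual LDS accounting.

```python
def check_tile(tile_m, tile_n, tile_k, wave_m, warp_tile_m, bytes_per_elem=2):
    # Rule from the diff comment: tile_m must equal wave_m * warp_tile_m
    # (shapes with small M are handled by kPadM=True instead).
    assert tile_m == wave_m * warp_tile_m, "tile_m constraint violated"
    # Assumed single-buffer LDS estimate: A tile (m x k) + B tile (n x k).
    return (tile_m + tile_n) * tile_k * bytes_per_elem / 1024.0


# Forward 2D config from the diff: 64x128x64 tile, wave 2x2x1, warp 32x32x16
print(check_tile(64, 128, 64, wave_m=2, warp_tile_m=32))  # 24.0 KiB <= 32 KiB
# Forward 3D config from the diff: 16x64x128 tile, wave 1x4x1, warp 16x16x32
print(check_tile(16, 64, 128, wave_m=1, warp_tile_m=16))  # 20.0 KiB
```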