mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327) ## Motivation The ML heuristic in dispatcher does not support the grouped-conv operator yet. In this PR, the support for fwd, bwd-data, and bwd-weight grouped-conv kernels has been added. A tile_engine utility has also been added to compile and run any selected kernel configuration through dispatcher infrastructure. ## Technical Details 1. Tile engine utility is added to benchmark each shape with all the possible kernel+tile_size combinations here - [https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py](url) 2. New LGBM regressor models for grouped conv are added to models directory. We have 3 separate models for fwd, bwd-data, and bwd-weights [https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models](url) 3. Implemented lazy GPU initialization (dispatcher/python) - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults - **Solution**: Mirror FMHA pattern - defer GPU initialization until first run() - **Changes**: - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL - Added _ensure_initialized() method for lazy GPU loading - GPU context created only on first run() call - **Benefit**: Parallel compilation now works without GPU conflicts 4. 
Addressed a few miscellaneous issues such as: - Fixed BF16->FP16 naming bug in the dispatcher wrapper - Added new tile sizes, and comp_v5 pipeline to the arch spec to expand the kernel selection - Added automatic padding support for unsupported shapes in dispatcher runner - Created a single source of truth between tile_engine and dispatcher about the architecture and tile_size details - Built a validation script to compare oracle_best vs ml_heuristic results ## Test Plan 1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets with up to 300 problems. 2. Ensured that test cases are added in both dispatcher and tile_engine to validate the heuristic. ## Test Result Results on Unseen shapes validated on gfx950 #### Forward Pass Model - **Training Data**: 48,845 measurements across 1,372 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.05%** - Median Efficiency: **96.8%** - P10 Efficiency: **79.9%** #### Backward Data Gradient (bwd_data) Model - **Training Data**: 18,773 measurements across 891 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.8%** - Median Efficiency: **96.5%** - P10 Efficiency: **82.9%** #### Backward Weight Gradient (bwd_weight) Model - **Training Data**: 34,900 measurements across 1,508 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **96.1%** - Median Efficiency: **99.2%** - P10 Efficiency: **89.4%** ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
311 lines
10 KiB
Python
311 lines
10 KiB
Python
#!/usr/bin/env python3
|
||
|
||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||
# SPDX-License-Identifier: MIT
|
||
|
||
"""
|
||
Single Source of Truth for Grouped Convolution Tile Configurations
|
||
|
||
This module defines all valid tile configurations for grouped convolution kernels.
|
||
Both codegen and instance_builder import from here to ensure consistency.
|
||
|
||
Architecture:
|
||
grouped_conv_tile_configs.py (SOURCE OF TRUTH)
|
||
├── Used by unified_grouped_conv_codegen.py
|
||
└── Used by grouped_conv_instance_builder.py
|
||
"""
|
||
|
||
from typing import Dict, List, Optional, Tuple
|
||
|
||
# =============================================================================
# Tile Configurations (Single Source of Truth)
# =============================================================================

# Shared tile configurations, each expressed as (tile_m, tile_n, tile_k).
# CRITICAL: tile_m MUST equal wave_m × warp_tile_m (TileGemmShape constraint).
# Only tiles known to compile successfully are registered here.
COMMON_TILES: List[Tuple[int, int, int]] = [
    # warp_tile [16,16,16]: tile_m = wave_m × 16
    (16, 64, 64),    # 1 × 16 = 16,  wave=(1,4,1)
    (32, 64, 64),    # 2 × 16 = 32,  wave=(2,2,1)
    (64, 64, 64),    # 4 × 16 = 64,  wave=(4,1,1)
    # (128, 64, 64) omitted: 8 × 16 = 128, wave=(8,2,1) - compile error
    # warp_tile [32,32,16]: tile_m = wave_m × 32
    (32, 128, 64),   # 1 × 32 = 32,  wave=(1,4,1)
    (64, 128, 64),   # 2 × 32 = 64,  wave=(2,2,1)
    (128, 128, 64),  # 4 × 32 = 128, wave=(4,4,1) - NEW!
    # 256x64x64 omitted: compilation issues
    # warp_tile [16,16,32]: tile_m = wave_m × 16
    (16, 64, 128),   # 1 × 16 = 16,  wave=(1,4,1)
    (32, 64, 128),   # 2 × 16 = 32,  wave=(2,2,1)
    (64, 64, 128),   # 4 × 16 = 64,  wave=(4,1,1)
    (128, 64, 128),  # 8 × 16 = 128, wave=(8,2,1) - NEW!
    # Other exclusions:
    # - 128x64x64: wave=8x2x1, warp=16x16x16 - compile error
    # - 32x128x128, 64x128x128, 128x128x128, 256x128x128 (warp_tile 32x32x32) - compv4 issues
    # - 256x64x64, 256x128x128 - arch filter rejection
]
|
||
|
||
# Wave shape for every registered tile.
# Key: (tile_m, tile_n, tile_k) -> (wave_m, wave_n, wave_k).
# Constraint: tile_m == wave_m × warp_tile_m.
# Only wave configs approved in arch_specs.json are used:
# [1,4,1], [2,2,1], [4,1,1], [8,2,1], [4,4,1].
TILE_TO_WAVE: Dict[Tuple[int, int, int], Tuple[int, int, int]] = {
    # tiles built on warp_tile [16,16,16]
    (16, 64, 64): (1, 4, 1),
    (32, 64, 64): (2, 2, 1),
    (64, 64, 64): (4, 1, 1),
    # tiles built on warp_tile [32,32,16]
    (32, 128, 64): (1, 4, 1),
    (64, 128, 64): (2, 2, 1),
    (128, 128, 64): (4, 4, 1),  # NEW - balanced 4x4 wave
    # tiles built on warp_tile [16,16,32]
    (16, 64, 128): (1, 4, 1),
    (32, 64, 128): (2, 2, 1),
    (64, 64, 128): (4, 1, 1),
    (128, 64, 128): (8, 2, 1),  # NEW
}
|
||
|
||
# Warp tile shape for every registered tile (must match the arch_specs.json
# gfx950 bf16 approved list).
# Key: (tile_m, tile_n, tile_k) -> (warp_m, warp_n, warp_k).
# Tiles are grouped by their shared warp tile; insertion order matches the
# registration order of the other lookup tables.
TILE_TO_WARP: Dict[Tuple[int, int, int], Tuple[int, int, int]] = {
    # warp_tile [16,16,16]
    **{t: (16, 16, 16) for t in ((16, 64, 64), (32, 64, 64), (64, 64, 64))},
    # warp_tile [32,32,16] (includes the NEW 128x128x64 tile)
    **{t: (32, 32, 16) for t in ((32, 128, 64), (64, 128, 64), (128, 128, 64))},
    # warp_tile [16,16,32] (includes the NEW 128x64x128 tile)
    **{
        t: (16, 16, 32)
        for t in ((16, 64, 128), (32, 64, 128), (64, 64, 128), (128, 64, 128))
    },
}
|
||
|
||
# Vector widths per tile for memory operations.
# Key: (tile_m, tile_n, tile_k) -> (vec_a, vec_b, vec_c).
# Every registered tile currently uses the same (4, 8, 8) vector widths, so
# the table is generated from the tile list; add per-tile overrides here if
# a tile ever needs different widths.
TILE_TO_VECTOR: Dict[Tuple[int, int, int], Tuple[int, int, int]] = {
    tile: (4, 8, 8)
    for tile in (
        (16, 64, 64), (32, 64, 64), (64, 64, 64),
        (32, 128, 64), (64, 128, 64), (128, 128, 64),
        (16, 64, 128), (32, 64, 128), (64, 64, 128), (128, 64, 128),
    )
}
|
||
|
||
# =============================================================================
# Pipeline Variant Suffixes (single source of truth)
# =============================================================================
# Empirically verified valid (pipeline, wave_mode, has_dsb, has_si)
# combinations observed in the 2D and 3D bf16 gfx950 benchmark CSVs.
# 30 entries total per ndim.
# Each tuple: (pipeline, wave_mode, has_dsb, has_si)
# wave_mode: "intrawave" | "interwave"
# has_dsb: 1 if "_dsb" suffix present (double smem buffer), else 0
# has_si: 1 if "_si" suffix present (store immediate), else 0

# The four suffix combinations {∅, dsb, si, dsb_si}, in canonical order.
_FLAG_COMBOS: List[Tuple[int, int]] = [(0, 0), (1, 0), (0, 1), (1, 1)]

PIPELINE_VARIANTS: List[Tuple[str, str, int, int]] = (
    # basic_v1: both intra/inter × all four suffix combos = 8
    [
        ("basic_v1", mode, dsb, si)
        for mode in ("intrawave", "interwave")
        for dsb, si in _FLAG_COMBOS
    ]
    # compv3: intrawave × all four suffix combos = 4
    + [("compv3", "intrawave", dsb, si) for dsb, si in _FLAG_COMBOS]
    # compv4: intrawave × {dsb, dsb_si} only = 2
    + [("compv4", "intrawave", 1, 0), ("compv4", "intrawave", 1, 1)]
    # compv5: intrawave × all four suffix combos = 4
    + [("compv5", "intrawave", dsb, si) for dsb, si in _FLAG_COMBOS]
    # compv6: intrawave × all four suffix combos = 4
    + [("compv6", "intrawave", dsb, si) for dsb, si in _FLAG_COMBOS]
    # mem: both intra/inter × all four suffix combos = 8
    + [
        ("mem", mode, dsb, si)
        for mode in ("intrawave", "interwave")
        for dsb, si in _FLAG_COMBOS
    ]
)
|
||
|
||
|
||
def iter_pipeline_variants(pipelines: Optional[List[str]] = None):
    """Iterate (pipeline, wave_mode, has_dsb, has_si) tuples, optionally filtered.

    Args:
        pipelines: optional list of pipeline names to keep. If None, yield all
            entries of PIPELINE_VARIANTS in registration order.

    Yields:
        Tuples from PIPELINE_VARIANTS whose pipeline name is in ``pipelines``
        (every entry when no filter is given).
    """
    # None means "no filter"; an empty list is a legitimate filter that
    # yields nothing, so it must not be treated as "keep everything".
    keep = None if pipelines is None else set(pipelines)
    for entry in PIPELINE_VARIANTS:
        if keep is None or entry[0] in keep:
            yield entry
|
||
|
||
|
||
# Valid pipelines per variant.
# All 8 pipelines (basic_v1, mem, compv3-6, comp_async, basic_async_v1)
# successfully build and run for all variants in both 2D and 3D
# (verified via 10_test_all_pipelines.py).
_ALL_GROUPED_CONV_PIPELINES: List[str] = [
    "basic_v1",
    "mem",
    "compv3",
    "compv4",
    "compv5",
    "compv6",
    "comp_async",
    "basic_async_v1",
]

# Each variant gets its own list copy so callers may mutate one entry
# without affecting the others.
VARIANT_PIPELINES: Dict[str, List[str]] = {
    variant: list(_ALL_GROUPED_CONV_PIPELINES)
    for variant in ("forward", "bwd_data", "bwd_weight")
}
|
||
|
||
# Tiles compatible with the compv4 pipeline.
# compv4 has stricter requirements due to double buffering and LDS
# constraints: only warp_tile [16,16,16] or [16,16,32] work with compv4,
# while the large warp_tile [32,32,16] and wave [8,2,1] configurations fail
# arch validation for compv4.
COMPV4_COMPATIBLE_TILES: List[Tuple[int, int, int]] = [
    # warp_tile [16,16,16] - every registered tile works with compv4
    (16, 64, 64),
    (32, 64, 64),
    (64, 64, 64),
    # (128, 64, 64) excluded: wave=8x2x1 fails for compv4
    # warp_tile [16,16,32] - every registered tile works with compv4
    (16, 64, 128),
    (32, 64, 128),
    (64, 64, 128),
    # (128, 64, 128) excluded: wave=8x2x1 fails for compv4
]
|
||
|
||
# Backward-weight tiles (very restricted due to transpose_tile2d constraints).
# Every candidate tile is listed so each one can be verified to actually work.
BWD_WEIGHT_TILES: List[Tuple[int, int, int]] = [
    # warp_tile [16,16,16]
    (16, 64, 64),    # known-working config
    (32, 64, 64),    # under test
    (64, 64, 64),    # under test
    # warp_tile [32,32,16]
    (32, 128, 64),   # under test
    (64, 128, 64),   # under test
    (128, 128, 64),  # under test
    # warp_tile [16,16,32]
    (16, 64, 128),   # under test
    (32, 64, 128),   # under test
    (64, 64, 128),   # under test
    (128, 64, 128),  # under test
]
|
||
|
||
# =============================================================================
# Validation
# =============================================================================


def validate_tile_config(tile_m: int, tile_n: int, tile_k: int) -> bool:
    """Return True when the tile size is registered in every lookup table
    (wave, warp, and vector)."""
    key = (tile_m, tile_n, tile_k)
    return all(
        key in table for table in (TILE_TO_WAVE, TILE_TO_WARP, TILE_TO_VECTOR)
    )
|
||
|
||
|
||
def get_tile_full_config(tile_m: int, tile_n: int, tile_k: int) -> Optional[dict]:
    """Get complete configuration for a tile size.

    Args:
        tile_m: block tile size along M.
        tile_n: block tile size along N.
        tile_k: block tile size along K.

    Returns:
        dict with keys: tile_m, tile_n, tile_k, wave_m, wave_n, wave_k,
        warp_m, warp_n, warp_k, vec_a, vec_b, vec_c
        or None if the tile is not registered in all lookup tables.
    """
    tile_key = (tile_m, tile_n, tile_k)
    # Reject tiles missing from any of the three tables so the lookups
    # below can never raise KeyError.
    if not validate_tile_config(tile_m, tile_n, tile_k):
        return None

    wave_m, wave_n, wave_k = TILE_TO_WAVE[tile_key]
    warp_m, warp_n, warp_k = TILE_TO_WARP[tile_key]
    vec_a, vec_b, vec_c = TILE_TO_VECTOR[tile_key]

    return {
        "tile_m": tile_m,
        "tile_n": tile_n,
        "tile_k": tile_k,
        "wave_m": wave_m,
        "wave_n": wave_n,
        "wave_k": wave_k,
        "warp_m": warp_m,
        "warp_n": warp_n,
        "warp_k": warp_k,
        "vec_a": vec_a,
        "vec_b": vec_b,
        "vec_c": vec_c,
    }
|
||
|
||
|
||
# =============================================================================
# Summary Statistics
# =============================================================================


def print_summary():
    """Print a human-readable summary of the registered tile configurations."""
    separator = "=" * 80
    print(separator)
    print("Grouped Convolution Tile Configurations (Single Source of Truth)")
    print(separator)
    print(f"Total tiles: {len(COMMON_TILES)}")
    print(f"Backward weight tiles: {len(BWD_WEIGHT_TILES)}")
    print()
    print("Tile sizes (M×N×K):")
    # One line per tile with its wave and warp shapes.
    for tile in COMMON_TILES:
        m, n, k = tile
        wave = TILE_TO_WAVE[tile]
        warp = TILE_TO_WARP[tile]
        print(
            f"  {m:3}×{n:3}×{k:3}  wave={wave[0]}×{wave[1]}×{wave[2]}  warp={warp[0]}×{warp[1]}×{warp[2]}"
        )
    print(separator)
|
||
|
||
|
||
# Script entry point: dump the tile-configuration summary when the module is
# executed directly (importing it has no side effects).
if __name__ == "__main__":
    print_summary()
|