Files
composable_kernel/dispatcher/codegen/grouped_config_rules.py
Yaswanth Raparti 6989cf800c [rocm-libraries] ROCm/rocm-libraries#6327 (commit 1e7a12e)
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation
The ML heuristic in the dispatcher does not yet support the grouped-conv
operator. This PR adds support for fwd, bwd-data, and bwd-weight grouped-conv
kernels. A tile_engine utility has also been added to compile and run any
selected kernel configuration through the dispatcher infrastructure.

## Technical Details

1. A tile_engine utility has been added to benchmark each shape with all
possible kernel+tile_size combinations:
https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv have been added to the models
directory, with three separate models for fwd, bwd-data, and bwd-weight:
https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python)
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory
     access faults
   - **Solution**: Mirror the FMHA pattern - defer GPU initialization until
     the first run()
   - **Changes**:
     - setup_multiple_grouped_conv_dispatchers() returns List[Path], not
       loaded libs
     - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
     - Added _ensure_initialized() method for lazy GPU loading
     - GPU context created only on first run() call
   - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to expand
     kernel selection
   - Added automatic padding support for unsupported shapes in the dispatcher
     runner
   - Created a single source of truth between tile_engine and the dispatcher
     for architecture and tile_size details
   - Built a validation script to compare oracle_best vs. ml_heuristic results
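
For illustration, the selection step that the per-variant regressor models drive can be sketched as below. The feature set, helper names, and the `predict` interface here are assumptions for the sketch, not the actual dispatcher API; in the PR the scoring is done by LGBM regressors.

```python
from typing import Callable, List, Tuple

def problem_features(n: int, c: int, k: int, hw: Tuple[int, int],
                     filt: Tuple[int, int], groups: int) -> List[float]:
    """Hypothetical feature vector for one grouped-conv problem shape."""
    h, w = hw
    fy, fx = filt
    return [n, c, k, h, w, fy, fx, groups,
            n * h * w,                 # output-pixel-count proxy
            (c // groups) * fy * fx]   # per-group reduction length

def pick_config(features: List[float],
                candidates: List[Tuple[int, int, int]],
                predict: Callable[[List[float], Tuple[int, int, int]], float],
                ) -> Tuple[int, int, int]:
    """Return the candidate tile with the highest predicted efficiency."""
    return max(candidates, key=lambda tile: predict(features, tile))
```

A trained regressor would be plugged in as `predict`, scoring each (features, tile) pair; the dispatcher then compiles and runs only the winning configuration.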
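
The lazy GPU initialization described in item 3 follows a common defer-until-first-use pattern. A minimal sketch (the class name and internals here are simplified stand-ins for the real GpuGroupedConvRunner):

```python
import ctypes
from pathlib import Path
from typing import List, Optional

class GpuRunnerSketch:
    """Defers library/GPU loading until the first run() call, so a parent
    process can fork workers (e.g. via ProcessPoolExecutor) before any
    GPU context exists."""

    def __init__(self, lib_paths: List[Path]):
        # Only remember the paths; do NOT call ctypes.CDLL here.
        self._lib_paths = lib_paths
        self._libs: Optional[list] = None

    def _ensure_initialized(self) -> None:
        if self._libs is None:
            # First touch: load the libraries (and, in the real runner,
            # create the GPU context) inside the worker process itself.
            self._libs = [ctypes.CDLL(str(p)) for p in self._lib_paths]

    def run(self, *args):
        self._ensure_initialized()
        # ... dispatch into the loaded kernels here ...
```

Because `__init__` touches no GPU state, the object can be constructed and pickled freely; only the process that actually calls run() pays the initialization cost.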

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels on both known and unseen
data sets of up to 300 problems.
2. Ensured that test cases are added in both dispatcher and tile_engine
to validate the heuristic.

## Test Results
Results on unseen shapes, validated on gfx950:
#### Forward Pass Model
- **Training Data**: 48,845 measurements across 1,372 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model
- **Training Data**: 18,773 measurements across 891 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model
- **Training Data**: 34,900 measurements across 1,508 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**
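
The efficiency figures above can be computed from per-problem timings roughly as follows. This is a sketch, not the PR's validation script: efficiency is taken here as oracle best time over heuristic-picked time (so 1.0 means the heuristic matched the oracle), and the index-based P10 convention is an assumption.

```python
import statistics
from typing import Dict, Tuple

def efficiency_stats(oracle_us: Dict[str, float],
                     heuristic_us: Dict[str, float]) -> Tuple[float, float, float]:
    """Return (mean, median, P10) efficiency of the heuristic vs. the oracle."""
    effs = sorted(oracle_us[p] / heuristic_us[p] for p in oracle_us)
    # Simple index-based 10th percentile over the sorted efficiencies.
    p10 = effs[max(0, round(0.10 * (len(effs) - 1)))]
    return statistics.mean(effs), statistics.median(effs), p10
```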

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-05-08 20:48:42 +00:00


#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Single Source of Truth for Grouped Convolution Tile Configurations
This module defines all valid tile configurations for grouped convolution kernels.
Both codegen and instance_builder import from here to ensure consistency.
Architecture:
grouped_conv_tile_configs.py (SOURCE OF TRUTH)
├── Used by unified_grouped_conv_codegen.py
└── Used by grouped_conv_instance_builder.py
"""
from typing import Dict, List, Tuple
# =============================================================================
# Tile Configurations (Single Source of Truth)
# =============================================================================
# Common tile configurations used across variants
# Format: (tile_m, tile_n, tile_k)
# CRITICAL: tile_m MUST equal wave_m × warp_tile_m (TileGemmShape constraint)
# Only tiles that successfully compile are included
COMMON_TILES: List[Tuple[int, int, int]] = [
    # Using warp_tile [16,16,16]: tile_m = wave_m × 16
    (16, 64, 64),    # 1 × 16 = 16,  wave=(1,4,1)
    (32, 64, 64),    # 2 × 16 = 32,  wave=(2,2,1)
    (64, 64, 64),    # 4 × 16 = 64,  wave=(4,1,1)
    # (128, 64, 64), # 8 × 16 = 128, wave=(8,2,1) - EXCLUDED: compile error
    # Using warp_tile [32,32,16]: tile_m = wave_m × 32
    (32, 128, 64),   # 1 × 32 = 32,  wave=(1,4,1)
    (64, 128, 64),   # 2 × 32 = 64,  wave=(2,2,1)
    (128, 128, 64),  # 4 × 32 = 128, wave=(4,4,1) - NEW!
    # Note: 256x64x64 excluded - compilation issues
    # Using warp_tile [16,16,32]: tile_m = wave_m × 16
    (16, 64, 128),   # 1 × 16 = 16,  wave=(1,4,1)
    (32, 64, 128),   # 2 × 16 = 32,  wave=(2,2,1)
    (64, 64, 128),   # 4 × 16 = 64,  wave=(4,1,1)
    (128, 64, 128),  # 8 × 16 = 128, wave=(8,2,1) - NEW!
    # Note: Excluded tiles:
    # - 128x64x64: wave=8x2x1, warp=16x16x16 - compile error
    # - 32x128x128, 64x128x128, 128x128x128, 256x128x128 (warp_tile 32x32x32) - compv4 issues
    # - 256x64x64, 256x128x128 - arch filter rejection
]
# Wave configurations per tile
# Key: (tile_m, tile_n, tile_k) -> (wave_m, wave_n, wave_k)
# Constraint: tile_m == wave_m × warp_tile_m
# Only use approved wave configs from arch_specs.json: [1,4,1], [2,2,1], [4,1,1], [8,2,1], [4,4,1]
TILE_TO_WAVE: Dict[Tuple[int, int, int], Tuple[int, int, int]] = {
    # warp_tile [16,16,16]
    (16, 64, 64): (1, 4, 1),
    (32, 64, 64): (2, 2, 1),
    (64, 64, 64): (4, 1, 1),
    # warp_tile [32,32,16]
    (32, 128, 64): (1, 4, 1),
    (64, 128, 64): (2, 2, 1),
    (128, 128, 64): (4, 4, 1),  # NEW - balanced 4x4 wave
    # warp_tile [16,16,32]
    (16, 64, 128): (1, 4, 1),
    (32, 64, 128): (2, 2, 1),
    (64, 64, 128): (4, 1, 1),
    (128, 64, 128): (8, 2, 1),  # NEW
}
# Warp tile configurations (must match arch_specs.json gfx950 bf16 approved list)
# Key: (tile_m, tile_n, tile_k) -> (warp_m, warp_n, warp_k)
TILE_TO_WARP: Dict[Tuple[int, int, int], Tuple[int, int, int]] = {
    # warp_tile [16,16,16]
    (16, 64, 64): (16, 16, 16),
    (32, 64, 64): (16, 16, 16),
    (64, 64, 64): (16, 16, 16),
    # warp_tile [32,32,16]
    (32, 128, 64): (32, 32, 16),
    (64, 128, 64): (32, 32, 16),
    (128, 128, 64): (32, 32, 16),  # NEW
    # warp_tile [16,16,32]
    (16, 64, 128): (16, 16, 32),
    (32, 64, 128): (16, 16, 32),
    (64, 64, 128): (16, 16, 32),
    (128, 64, 128): (16, 16, 32),  # NEW
}
# Vector sizes per tile (for memory operations)
# Key: (tile_m, tile_n, tile_k) -> (vec_a, vec_b, vec_c)
TILE_TO_VECTOR: Dict[Tuple[int, int, int], Tuple[int, int, int]] = {
    (16, 64, 64): (4, 8, 8),
    (32, 64, 64): (4, 8, 8),
    (64, 64, 64): (4, 8, 8),
    (32, 128, 64): (4, 8, 8),
    (64, 128, 64): (4, 8, 8),
    (128, 128, 64): (4, 8, 8),
    (16, 64, 128): (4, 8, 8),
    (32, 64, 128): (4, 8, 8),
    (64, 64, 128): (4, 8, 8),
    (128, 64, 128): (4, 8, 8),
}
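

# Illustrative sanity check (a sketch; the helper name is hypothetical, not
# part of the original module): every registered tile must satisfy the
# TileGemmShape constraint tile_m == wave_m × warp_tile_m documented above.
def check_tile_wave_consistency(
    tile_to_wave: Dict[Tuple[int, int, int], Tuple[int, int, int]],
    tile_to_warp: Dict[Tuple[int, int, int], Tuple[int, int, int]],
) -> List[Tuple[int, int, int]]:
    """Return the tile keys that violate tile_m == wave_m * warp_m."""
    return [
        tile
        for tile, (wave_m, _, _) in tile_to_wave.items()
        if tile[0] != wave_m * tile_to_warp[tile][0]
    ]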
# =============================================================================
# Pipeline Variant Suffixes (single source of truth)
# =============================================================================
# Empirically verified valid (pipeline, wave_mode, has_dsb, has_si) combinations
# observed in the 2D and 3D bf16 gfx950 benchmark CSVs. 30 entries total per ndim.
# Each tuple: (pipeline, wave_mode, has_dsb, has_si)
# wave_mode: "intrawave" | "interwave"
# has_dsb: 1 if "_dsb" suffix present (double smem buffer), else 0
# has_si: 1 if "_si" suffix present (store immediate), else 0
PIPELINE_VARIANTS: List[Tuple[str, str, int, int]] = [
    # basic_v1: both intra/inter × {∅, dsb, si, dsb_si} = 8 combos
    ("basic_v1", "intrawave", 0, 0),
    ("basic_v1", "intrawave", 1, 0),
    ("basic_v1", "intrawave", 0, 1),
    ("basic_v1", "intrawave", 1, 1),
    ("basic_v1", "interwave", 0, 0),
    ("basic_v1", "interwave", 1, 0),
    ("basic_v1", "interwave", 0, 1),
    ("basic_v1", "interwave", 1, 1),
    # compv3: intrawave × {∅, dsb, si, dsb_si} = 4 combos
    ("compv3", "intrawave", 0, 0),
    ("compv3", "intrawave", 1, 0),
    ("compv3", "intrawave", 0, 1),
    ("compv3", "intrawave", 1, 1),
    # compv4: intrawave × {dsb, dsb_si} only = 2 combos
    ("compv4", "intrawave", 1, 0),
    ("compv4", "intrawave", 1, 1),
    # compv5: intrawave × {∅, dsb, si, dsb_si} = 4 combos
    ("compv5", "intrawave", 0, 0),
    ("compv5", "intrawave", 1, 0),
    ("compv5", "intrawave", 0, 1),
    ("compv5", "intrawave", 1, 1),
    # compv6: intrawave × {∅, dsb, si, dsb_si} = 4 combos
    ("compv6", "intrawave", 0, 0),
    ("compv6", "intrawave", 1, 0),
    ("compv6", "intrawave", 0, 1),
    ("compv6", "intrawave", 1, 1),
    # mem: both intra/inter × {∅, dsb, si, dsb_si} = 8 combos
    ("mem", "intrawave", 0, 0),
    ("mem", "intrawave", 1, 0),
    ("mem", "intrawave", 0, 1),
    ("mem", "intrawave", 1, 1),
    ("mem", "interwave", 0, 0),
    ("mem", "interwave", 1, 0),
    ("mem", "interwave", 0, 1),
    ("mem", "interwave", 1, 1),
]


def iter_pipeline_variants(pipelines: List[str] = None):
    """Iterate (pipeline, wave_mode, has_dsb, has_si) tuples, optionally filtered.

    Args:
        pipelines: optional list of pipeline names to keep. If None, yield all.
    """
    if pipelines is None:
        yield from PIPELINE_VARIANTS
        return
    keep = set(pipelines)
    for entry in PIPELINE_VARIANTS:
        if entry[0] in keep:
            yield entry

# Valid pipelines per variant
# All 8 pipelines (basic_v1, mem, compv3-6, comp_async, basic_async_v1) successfully
# build and run for all variants in both 2D and 3D (verified via 10_test_all_pipelines.py)
VARIANT_PIPELINES: Dict[str, List[str]] = {
    "forward": [
        "basic_v1",
        "mem",
        "compv3",
        "compv4",
        "compv5",
        "compv6",
        "comp_async",
        "basic_async_v1",
    ],
    "bwd_data": [
        "basic_v1",
        "mem",
        "compv3",
        "compv4",
        "compv5",
        "compv6",
        "comp_async",
        "basic_async_v1",
    ],
    "bwd_weight": [
        "basic_v1",
        "mem",
        "compv3",
        "compv4",
        "compv5",
        "compv6",
        "comp_async",
        "basic_async_v1",
    ],
}
# Tiles that support compv4 pipeline
# compv4 has stricter requirements due to double buffering and LDS constraints
# Pattern: only warp_tile [16,16,16] or [16,16,32] work with compv4
# Large warp_tile [32,32,16] and wave [8,2,1] fail arch validation for compv4
COMPV4_COMPATIBLE_TILES: List[Tuple[int, int, int]] = [
    # warp_tile [16,16,16] - all work with compv4
    (16, 64, 64),
    (32, 64, 64),
    (64, 64, 64),
    # (128, 64, 64),  # Excluded: wave=8x2x1 fails for compv4
    # warp_tile [16,16,32] - all work with compv4
    (16, 64, 128),
    (32, 64, 128),
    (64, 64, 128),
    # (128, 64, 128),  # Excluded: wave=8x2x1 fails for compv4
]
# Backward weight tiles (very restricted due to transpose_tile2d constraints)
# Testing all tiles to verify which ones actually work
BWD_WEIGHT_TILES: List[Tuple[int, int, int]] = [
    # warp_tile [16,16,16]
    (16, 64, 64),    # Known working config
    (32, 64, 64),    # Test
    (64, 64, 64),    # Test
    # warp_tile [32,32,16]
    (32, 128, 64),   # Test
    (64, 128, 64),   # Test
    (128, 128, 64),  # Test
    # warp_tile [16,16,32]
    (16, 64, 128),   # Test
    (32, 64, 128),   # Test
    (64, 64, 128),   # Test
    (128, 64, 128),  # Test
]
# =============================================================================
# Validation
# =============================================================================
def validate_tile_config(tile_m: int, tile_n: int, tile_k: int) -> bool:
    """Check if a tile configuration is valid and registered."""
    tile_key = (tile_m, tile_n, tile_k)
    return (
        tile_key in TILE_TO_WAVE
        and tile_key in TILE_TO_WARP
        and tile_key in TILE_TO_VECTOR
    )


def get_tile_full_config(tile_m: int, tile_n: int, tile_k: int) -> dict:
    """Get complete configuration for a tile size.

    Returns:
        dict with keys: wave_m, wave_n, wave_k, warp_m, warp_n, warp_k,
        vec_a, vec_b, vec_c, or None if the tile is not found.
    """
    tile_key = (tile_m, tile_n, tile_k)
    if not validate_tile_config(tile_m, tile_n, tile_k):
        return None
    wave_m, wave_n, wave_k = TILE_TO_WAVE[tile_key]
    warp_m, warp_n, warp_k = TILE_TO_WARP[tile_key]
    vec_a, vec_b, vec_c = TILE_TO_VECTOR[tile_key]
    return {
        "tile_m": tile_m,
        "tile_n": tile_n,
        "tile_k": tile_k,
        "wave_m": wave_m,
        "wave_n": wave_n,
        "wave_k": wave_k,
        "warp_m": warp_m,
        "warp_n": warp_n,
        "warp_k": warp_k,
        "vec_a": vec_a,
        "vec_b": vec_b,
        "vec_c": vec_c,
    }

# =============================================================================
# Summary Statistics
# =============================================================================
def print_summary():
    """Print a summary of available tile configurations."""
    print("=" * 80)
    print("Grouped Convolution Tile Configurations (Single Source of Truth)")
    print("=" * 80)
    print(f"Total tiles: {len(COMMON_TILES)}")
    print(f"Backward weight tiles: {len(BWD_WEIGHT_TILES)}")
    print()
    print("Tile sizes (M×N×K):")
    for tile in COMMON_TILES:
        m, n, k = tile
        wave = TILE_TO_WAVE[tile]
        warp = TILE_TO_WARP[tile]
        print(
            f" {m:3}×{n:3}×{k:3} wave={wave[0]}×{wave[1]}×{wave[2]} warp={warp[0]}×{warp[1]}×{warp[2]}"
        )
    print("=" * 80)


if __name__ == "__main__":
    print_summary()