Files
composable_kernel/dispatcher/python/grouped_conv_utils.py
Yaswanth Raparti 6989cf800c [rocm-libraries] ROCm/rocm-libraries#6327 (commit 1e7a12e)
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation
The ML heuristic in the dispatcher does not yet support the grouped-conv
operator. This PR adds support for fwd, bwd-data, and bwd-weight
grouped-conv kernels. A tile_engine utility has also been added to compile
and run any selected kernel configuration through the dispatcher
infrastructure.

## Technical Details

1. A tile_engine utility is added to benchmark each shape with all possible
kernel+tile_size combinations:
https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv are added to the models
directory. There are three separate models, for fwd, bwd-data, and bwd-weight:
https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python)
- **Issue**: ProcessPoolExecutor fork() + GPU context caused memory
access faults
- **Solution**: Mirror FMHA pattern - defer GPU initialization until
first run()
  - **Changes**:
- setup_multiple_grouped_conv_dispatchers() returns List[Path], not
loaded libs
    - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
    - Added _ensure_initialized() method for lazy GPU loading
    - GPU context created only on first run() call
  - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
  - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
  - Added new tile sizes and the comp_v5 pipeline to the arch spec to expand
kernel selection
  - Added automatic padding support for unsupported shapes in the dispatcher
runner
  - Created a single source of truth between tile_engine and dispatcher for
architecture and tile_size details
  - Built validation scripts to compare oracle_best vs. ml_heuristic
selections
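The lazy-initialization pattern from item 3 can be sketched in a few lines. This is an illustrative toy, not the actual `GpuGroupedConvRunner` API: `LazyRunner` and its string "load" are hypothetical stand-ins for the real `ctypes.CDLL` call.

```python
# Hypothetical sketch of the lazy GPU-initialization pattern: the constructor
# only records configuration, and the fork-unsafe library load is deferred
# until the first run() call.
class LazyRunner:
    def __init__(self, lib_path):
        # No GPU context is created here, so ProcessPoolExecutor can fork()
        # worker processes safely before first use.
        self._lib_path = lib_path
        self._lib = None
        self._initialized = False

    def _ensure_initialized(self):
        # Idempotent: the expensive load happens exactly once.
        if self._initialized:
            return
        self._lib = f"loaded:{self._lib_path}"  # stand-in for ctypes.CDLL(...)
        self._initialized = True

    def run(self):
        self._ensure_initialized()  # GPU context created here, not in __init__
        return self._lib

runner = LazyRunner("libdispatcher_conv_lib.so")
assert runner._lib is None  # still safe to fork at this point
result = runner.run()
```

Because `_ensure_initialized()` is idempotent, repeated `run()` calls reuse the same handle, which is the property that lets parallel compilation and GPU execution coexist.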

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels on both known and
unseen data sets with up to 300 problems.
2. Ensured that test cases are added in both dispatcher and tile_engine
to validate the heuristic.

## Test Result
Results on unseen shapes, validated on gfx950:
#### Forward Pass Model
- **Training Data**: 48,845 measurements across 1,372 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model
- **Training Data**: 18,773 measurements across 891 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model
- **Training Data**: 34,900 measurements across 1,508 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**
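The mean/median/P10 efficiency statistics above follow a standard computation against the oracle's best time. A minimal sketch with synthetic timings (the arrays below are made up for illustration, not the validation data):

```python
import numpy as np

# Per-problem efficiency = oracle_best_time / ml_heuristic_time, in percent.
# It is 100% when the heuristic picks the same kernel as the oracle.
oracle_ms = np.array([1.00, 2.00, 0.50, 4.00])     # synthetic oracle timings
heuristic_ms = np.array([1.00, 2.20, 0.55, 5.00])  # synthetic heuristic timings
eff = oracle_ms / heuristic_ms * 100.0

mean_eff = float(eff.mean())
median_eff = float(np.median(eff))
p10_eff = float(np.percentile(eff, 10))  # 10th percentile: worst-case tail
```

P10 is the more demanding number: it bounds how badly the heuristic does on its worst decile of problems, which the means above can hide.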

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-05-08 20:48:42 +00:00


#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Grouped Convolution Dispatcher Utilities
Typed Python API for grouped convolution kernels, matching the patterns from
the old conv_utils.py and the GEMM ctypes_utils.py.
Classes:
GroupedConvKernelConfig - Kernel configuration (tile, wave, pipeline, arch)
GroupedConvProblem - Runtime problem specification (N,C,K,H,W,etc.)
GroupedConvProblemC - ctypes struct matching C++ ConvProblemC
GroupedConvDispatcherLib - Wrapper for libdispatcher_conv_lib.so
GpuGroupedConvRunner - High-level GPU execution runner
GroupedConvResult - Result of GPU execution (output, time, tflops)
GroupedConvRegistry - Collection of kernel configs with JSON export
Usage:
from grouped_conv_utils import (
GroupedConvKernelConfig,
GroupedConvProblem,
GpuGroupedConvRunner,
)
config = GroupedConvKernelConfig(variant="forward", ndim_spatial=2)
problem = GroupedConvProblem(N=1, C=64, K=128, Hi=28, Wi=28, Y=3, X=3,
stride_h=1, pad_h=1, direction="forward")
runner = GpuGroupedConvRunner()
if runner.is_available():
result = runner.run(input_np, weight_np, problem)
print(f"Time: {result.time_ms:.4f} ms, TFLOPS: {result.tflops:.2f}")
"""
import ctypes
import json
import copy
import subprocess
import threading
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from dispatcher_common import (
ValidationResultBase,
auto_correct_trait,
auto_correct_wave,
get_arch_filter_data,
validate_trait_combo,
validate_wave_config,
validate_warp_tile_config,
)
# =============================================================================
# Constants
# =============================================================================
VALID_VARIANTS = ("forward", "bwd_data", "bwd_weight")
VALID_NDIM_SPATIAL = (1, 2, 3)
BACKWARD_VARIANTS = ("bwd_data", "bwd_weight")
BACKWARD_PIPELINES = ("compv3", "mem")
VARIANT_ALIASES = {
"2d_fwd": "forward",
"2d_bwdd": "bwd_data",
"2d_bwdw": "bwd_weight",
"fwd": "forward",
"bwdd": "bwd_data",
"bwdw": "bwd_weight",
}
DIRECTION_MAP = {"forward": 0, "bwd_data": 1, "bwd_weight": 2}
def _resolve_variant(v: str) -> str:
return VARIANT_ALIASES.get(v, v)
# =============================================================================
# GroupedConvDataType
# =============================================================================
class GroupedConvDataType(Enum):
FP16 = "fp16"
BF16 = "bf16"
FP32 = "fp32"
FP8 = "fp8"
BF8 = "bf8"
INT8 = "int8"
# =============================================================================
# GroupedConvKernelConfig
# =============================================================================
@dataclass
class GroupedConvKernelConfig:
"""Complete kernel configuration for grouped convolution.
Captures all parameters needed to identify and run a specific kernel.
Mirrors the C++ GroupedConvSignature + GroupedConvAlgorithm.
"""
# What: signature
variant: str = "forward"
ndim_spatial: int = 2
dtype: str = "fp16"
layout: str = "nhwgc"
arch: str = "gfx942"
# How: algorithm - tile shape
tile_m: int = 1
tile_n: int = 128
tile_k: int = 128
# How: wave config
wave_m: int = 2
wave_n: int = 2
wave_k: int = 1
# How: warp tile
warp_tile_m: int = 32
warp_tile_n: int = 32
warp_tile_k: int = 16
# How: pipeline traits
pipeline: str = "compv4"
epilogue: str = "cshuffle"
scheduler: str = "intrawave"
# ConvConfigBase parity fields
vector_size_a: int = 4
vector_size_b: int = 8
vector_size_c: int = 8
block_per_cu: int = 1
num_wave_groups: int = 1
num_groups_to_merge: int = 1
# Padding (enables arbitrary problem sizes)
pad_m: bool = True
pad_n: bool = True
pad_k: bool = True
# Additional trait config options
double_smem_buffer: bool = False
split_image: bool = False
explicit_gemm: bool = False
two_stage: bool = False
def __post_init__(self):
self.variant = _resolve_variant(self.variant)
if (
self.variant in BACKWARD_VARIANTS
and self.pipeline not in BACKWARD_PIPELINES
):
self.pipeline = "compv3"
@property
def tile_str(self) -> str:
return f"{self.tile_m}x{self.tile_n}x{self.tile_k}"
@property
def wave_str(self) -> str:
return f"{self.wave_m}x{self.wave_n}x{self.wave_k}"
@property
def warp_str(self) -> str:
return f"{self.warp_tile_m}x{self.warp_tile_n}x{self.warp_tile_k}"
@property
def vec_str(self) -> str:
return f"{self.vector_size_a}x{self.vector_size_b}x{self.vector_size_c}"
@property
def name(self) -> str:
parts = [
f"grouped_conv_{self.variant}_{self.dtype}_{self.ndim_spatial}d",
self.tile_str,
self.pipeline,
self.scheduler, # NEW: Include scheduler
]
if self.num_groups_to_merge != 1:
parts.append(f"gm{self.num_groups_to_merge}") # NEW: Group merge
if self.double_smem_buffer:
parts.append("dsb") # NEW: Double SMEM buffer
if self.split_image:
parts.append("si") # NEW: Split image
if self.two_stage:
parts.append("2stage") # NEW: Two-stage
return "_".join(parts)
def to_dict(self) -> dict:
"""Convert to legacy dict format for codegen compatibility."""
return {
"tile_config": {
"tile_m": [self.tile_m],
"tile_n": [self.tile_n],
"tile_k": [self.tile_k],
"wave_m": [self.wave_m],
"wave_n": [self.wave_n],
"wave_k": [self.wave_k],
"warp_tile_m": [self.warp_tile_m],
"warp_tile_n": [self.warp_tile_n],
"warp_tile_k": [self.warp_tile_k],
},
"trait_config": {
"pipeline": [self.pipeline],
"epilogue": [self.epilogue],
"scheduler": [self.scheduler],
"pad_m": [self.pad_m],
"pad_n": [self.pad_n],
"pad_k": [self.pad_k],
"vector_size_a": [self.vector_size_a],
"vector_size_b": [self.vector_size_b],
"vector_size_c": [self.vector_size_c],
"block_per_cu": [self.block_per_cu],
"num_wave_groups": [self.num_wave_groups],
"num_groups_to_merge": [self.num_groups_to_merge],
"double_smem_buffer": [self.double_smem_buffer],
"split_image": [self.split_image],
"explicit_gemm": [self.explicit_gemm],
"two_stage": [self.two_stage],
},
"variant": self.variant,
"ndim_spatial": self.ndim_spatial,
"arch": self.arch,
"layout": self.layout,
"dtype": self.dtype,
}
def to_json_obj(self) -> dict:
"""Serializable dict for JSON export."""
return {
"name": self.name,
"signature": {
"variant": self.variant,
"dtype": self.dtype,
"ndim_spatial": self.ndim_spatial,
"layout": self.layout,
},
"algorithm": {
"tile_m": self.tile_m,
"tile_n": self.tile_n,
"tile_k": self.tile_k,
"wave": self.wave_str,
"warp": self.warp_str,
"pipeline": self.pipeline,
"epilogue": self.epilogue,
"scheduler": self.scheduler,
"vector_sizes": [
self.vector_size_a,
self.vector_size_b,
self.vector_size_c,
],
"block_per_cu": self.block_per_cu,
"num_wave_groups": self.num_wave_groups,
"num_groups_to_merge": self.num_groups_to_merge,
},
"arch": self.arch,
}
def print_config(self, indent: str = " "):
print(f"{indent}GroupedConvKernelConfig:")
print(f"{indent} Variant: {self.variant} {self.ndim_spatial}D")
print(f"{indent} Dtype: {self.dtype}")
print(f"{indent} Layout: {self.layout}")
print(f"{indent} Arch: {self.arch}")
print(f"{indent} Tile: {self.tile_str}")
print(f"{indent} Wave: {self.wave_str}")
print(f"{indent} Warp: {self.warp_str}")
print(f"{indent} Pipeline: {self.pipeline}/{self.scheduler}/{self.epilogue}")
print(f"{indent} VecSizes: {self.vec_str}")
print(
f"{indent} BlockCU: {self.block_per_cu} WaveGroups: {self.num_wave_groups} MergeGroups: {self.num_groups_to_merge}"
)
# =============================================================================
# GroupedConvProblem
# =============================================================================
@dataclass
class GroupedConvProblem:
"""Runtime convolution problem specification.
Describes the actual sizes of a convolution to be computed.
Matches the old ConvProblem from conv_utils.py.
"""
N: int = 1
C: int = 64
K: int = 128
G: int = 1
Hi: int = 28
Wi: int = 28
Di: int = 1
Y: int = 3
X: int = 3
Z: int = 1
stride_h: int = 1
stride_w: int = 1
stride_d: int = 1
pad_h: int = 0
pad_w: int = 0
pad_d: int = 0
dilation_h: int = 1
dilation_w: int = 1
dilation_d: int = 1
direction: str = "forward"
split_k: int = 1
def __post_init__(self):
"""Validate grouped convolution constraints."""
if self.C % self.G != 0:
raise ValueError(
f"C must be divisible by G for grouped convolution: C={self.C}, G={self.G}"
)
if self.K % self.G != 0:
raise ValueError(
f"K must be divisible by G for grouped convolution: K={self.K}, G={self.G}"
)
@property
def Ho(self) -> int:
eff_y = (self.Y - 1) * self.dilation_h + 1
return (self.Hi + 2 * self.pad_h - eff_y) // self.stride_h + 1
@property
def Wo(self) -> int:
eff_x = (self.X - 1) * self.dilation_w + 1
return (self.Wi + 2 * self.pad_w - eff_x) // self.stride_w + 1
@property
def Do(self) -> int:
eff_z = (self.Z - 1) * self.dilation_d + 1
return (self.Di + 2 * self.pad_d - eff_z) // self.stride_d + 1
@property
def is_3d(self) -> bool:
return self.Di > 1 or self.Z > 1 or self.pad_d > 0
@property
def ndim_spatial(self) -> int:
return 3 if self.is_3d else 2
@property
def flops(self) -> float:
"""Total FLOPs for this convolution (any direction, same count).
Uses float division C/G to match canonical formula (validated C % G == 0 in __post_init__).
"""
c_per_group = self.C / self.G # Float division (validated C % G == 0)
if self.is_3d:
return (
2.0
* self.N
* self.K
* self.Do
* self.Ho
* self.Wo
* c_per_group
* self.Z
* self.Y
* self.X
)
return 2.0 * self.N * self.K * self.Ho * self.Wo * c_per_group * self.Y * self.X
@property
def gflops(self) -> float:
return self.flops / 1e9
def input_shape(self) -> tuple:
"""NHWGC or NDHWGC layout."""
c_per_g = self.C // self.G
if self.is_3d:
return (self.N, self.Di, self.Hi, self.Wi, self.G, c_per_g)
return (self.N, self.Hi, self.Wi, self.G, c_per_g)
def weight_shape(self) -> tuple:
"""GKYXC or GKZYXC layout."""
c_per_g = self.C // self.G
k_per_g = self.K // self.G
if self.is_3d:
return (self.G, k_per_g, self.Z, self.Y, self.X, c_per_g)
return (self.G, k_per_g, self.Y, self.X, c_per_g)
def output_shape(self) -> tuple:
"""NHWGK or NDHWGK layout."""
k_per_g = self.K // self.G
if self.is_3d:
return (self.N, self.Do, self.Ho, self.Wo, self.G, k_per_g)
return (self.N, self.Ho, self.Wo, self.G, k_per_g)
def print_problem(self, indent: str = " "):
dim_str = "3D" if self.is_3d else "2D"
print(f"{indent}GroupedConvProblem ({dim_str} {self.direction}):")
print(f"{indent} Batch: N={self.N}, G={self.G}")
print(f"{indent} Channels: C={self.C}, K={self.K}")
if self.is_3d:
print(f"{indent} Input: Di={self.Di}, Hi={self.Hi}, Wi={self.Wi}")
print(f"{indent} Filter: Z={self.Z}, Y={self.Y}, X={self.X}")
print(f"{indent} Output: Do={self.Do}, Ho={self.Ho}, Wo={self.Wo}")
else:
print(f"{indent} Input: Hi={self.Hi}, Wi={self.Wi}")
print(f"{indent} Filter: Y={self.Y}, X={self.X}")
print(f"{indent} Output: Ho={self.Ho}, Wo={self.Wo}")
print(f"{indent} GFLOPs: {self.gflops:.2f}")
# =============================================================================
# GroupedConvProblemC (ctypes struct matching C++)
# =============================================================================
class GroupedConvProblemC(ctypes.Structure):
"""C structure matching ConvProblemC in conv_ctypes_lib.cpp."""
_fields_ = [
("N", ctypes.c_int),
("G", ctypes.c_int),
("C", ctypes.c_int),
("K", ctypes.c_int),
("input_d", ctypes.c_int),
("input_h", ctypes.c_int),
("input_w", ctypes.c_int),
("filter_z", ctypes.c_int),
("filter_y", ctypes.c_int),
("filter_x", ctypes.c_int),
("stride_d", ctypes.c_int),
("stride_h", ctypes.c_int),
("stride_w", ctypes.c_int),
("pad_d", ctypes.c_int),
("pad_h", ctypes.c_int),
("pad_w", ctypes.c_int),
("dilation_d", ctypes.c_int),
("dilation_h", ctypes.c_int),
("dilation_w", ctypes.c_int),
("direction", ctypes.c_int),
("split_k", ctypes.c_int),
]
@classmethod
def from_problem(cls, p: GroupedConvProblem) -> "GroupedConvProblemC":
c = cls()
c.N, c.G, c.C, c.K = p.N, p.G, p.C, p.K
c.input_d, c.input_h, c.input_w = p.Di, p.Hi, p.Wi
c.filter_z, c.filter_y, c.filter_x = p.Z, p.Y, p.X
c.stride_d, c.stride_h, c.stride_w = p.stride_d, p.stride_h, p.stride_w
c.pad_d, c.pad_h, c.pad_w = p.pad_d, p.pad_h, p.pad_w
c.dilation_d, c.dilation_h, c.dilation_w = (
p.dilation_d,
p.dilation_h,
p.dilation_w,
)
c.direction = DIRECTION_MAP.get(p.direction, 0)
c.split_k = getattr(p, "split_k", 1)
return c
# =============================================================================
# GroupedConvResult
# =============================================================================
@dataclass
class GroupedConvResult:
"""Result of GPU convolution execution."""
success: bool = False
time_ms: float = 0.0
tflops: float = 0.0
output: Optional[np.ndarray] = None
error: str = ""
# =============================================================================
# GroupedConvDispatcherLib
# =============================================================================
class GroupedConvDispatcherLib:
"""Wrapper for the compiled convolution dispatcher library.
Provides Python interface to the C API in conv_ctypes_lib.cpp.
"""
SEARCH_PATHS = [
"build/examples/libdispatcher_conv_lib.so",
"build/bindings/libdispatcher_conv_lib.so",
"build/lib/libdispatcher_conv_lib.so",
]
def __init__(self, lib: ctypes.CDLL, path: Path):
self._lib = lib
self._path = path
self._setup_functions()
def _setup_functions(self):
self._lib.conv_dispatcher_init.argtypes = []
self._lib.conv_dispatcher_init.restype = ctypes.c_int
self._lib.conv_dispatcher_cleanup.argtypes = []
self._lib.conv_dispatcher_cleanup.restype = ctypes.c_int
self._lib.conv_dispatcher_version.argtypes = []
self._lib.conv_dispatcher_version.restype = ctypes.c_char_p
self._lib.conv_dispatcher_has_kernels.argtypes = []
self._lib.conv_dispatcher_has_kernels.restype = ctypes.c_int
self._lib.conv_dispatcher_has_bwd_data.argtypes = []
self._lib.conv_dispatcher_has_bwd_data.restype = ctypes.c_int
self._lib.conv_dispatcher_has_bwd_weight.argtypes = []
self._lib.conv_dispatcher_has_bwd_weight.restype = ctypes.c_int
self._lib.conv_dispatcher_get_kernel_count.argtypes = []
self._lib.conv_dispatcher_get_kernel_count.restype = ctypes.c_int
self._lib.conv_dispatcher_get_kernel_name.argtypes = [
ctypes.c_int,
ctypes.c_char_p,
ctypes.c_int,
]
self._lib.conv_dispatcher_get_kernel_name.restype = ctypes.c_int
self._lib.conv_dispatcher_is_supported.argtypes = [
ctypes.POINTER(GroupedConvProblemC),
]
self._lib.conv_dispatcher_is_supported.restype = ctypes.c_int
self._lib.conv_dispatcher_run.argtypes = [
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.POINTER(GroupedConvProblemC),
ctypes.c_void_p,
]
self._lib.conv_dispatcher_run.restype = ctypes.c_float
@classmethod
def find(cls) -> Optional["GroupedConvDispatcherLib"]:
"""Search standard paths for the conv library."""
root = Path(__file__).parent.parent
for rel in cls.SEARCH_PATHS:
path = root / rel
if path.exists():
try:
lib = ctypes.CDLL(str(path))
return cls(lib, path)
except OSError:
continue
return None
@property
def path(self) -> Path:
return self._path
def initialize(self):
self._lib.conv_dispatcher_init()
def cleanup(self):
self._lib.conv_dispatcher_cleanup()
def version(self) -> str:
return self._lib.conv_dispatcher_version().decode()
def has_forward(self) -> bool:
return self._lib.conv_dispatcher_has_kernels() != 0
def has_bwd_data(self) -> bool:
return self._lib.conv_dispatcher_has_bwd_data() != 0
def has_bwd_weight(self) -> bool:
return self._lib.conv_dispatcher_has_bwd_weight() != 0
def kernel_count(self) -> int:
return self._lib.conv_dispatcher_get_kernel_count()
def kernel_names(self) -> List[str]:
names = []
for i in range(self.kernel_count()):
buf = ctypes.create_string_buffer(256)
if self._lib.conv_dispatcher_get_kernel_name(i, buf, 256) == 0:
names.append(buf.value.decode())
return names
def is_supported(self, problem: GroupedConvProblem) -> bool:
pc = GroupedConvProblemC.from_problem(problem)
return self._lib.conv_dispatcher_is_supported(ctypes.byref(pc)) != 0
def run(
self, a_ptr: int, b_ptr: int, c_ptr: int, problem: GroupedConvProblem
) -> float:
"""Run convolution. Returns time_ms (>0 success, <0 error)."""
pc = GroupedConvProblemC.from_problem(problem)
return self._lib.conv_dispatcher_run(
a_ptr, b_ptr, c_ptr, ctypes.byref(pc), None
)
# =============================================================================
# GpuGroupedConvRunner
# =============================================================================
class GpuGroupedConvRunner:
"""High-level GPU convolution runner.
Handles library loading, HIP memory management, and kernel execution.
Follows the same pattern as the old GpuConvRunner from conv_utils.py.
Usage:
runner = GpuGroupedConvRunner()
if runner.is_available():
result = runner.run(input_np, weight_np, problem)
print(f"Time: {result.time_ms:.4f} ms, TFLOPS: {result.tflops:.2f}")
"""
HIP_MEMCPY_H2D = 1
HIP_MEMCPY_D2H = 2
def __init__(self, lib_path: Optional[str] = None):
"""Initialize runner WITHOUT loading GPU libraries.
GPU context is created lazily on first run() call, avoiding fork() issues
during parallel compilation. This mirrors FMHA design.
Args:
lib_path: Path to dispatcher .so file (or None to auto-detect)
"""
self._lib_path = lib_path
self._dispatch_lib: Optional[GroupedConvDispatcherLib] = None
self._hip = None
self._initialized = False
self._init_error = None
self._init_traceback = None
def _ensure_initialized(self):
"""Lazy initialization - only load GPU libraries when actually needed."""
if self._initialized:
return
try:
# Load dispatcher library
if self._lib_path:
lib = ctypes.CDLL(self._lib_path)
self._dispatch_lib = GroupedConvDispatcherLib(lib, Path(self._lib_path))
else:
self._dispatch_lib = GroupedConvDispatcherLib.find()
if self._dispatch_lib is None:
return
# Load HIP library - THIS creates GPU context
self._hip = ctypes.CDLL("libamdhip64.so")
self._hip.hipMalloc.argtypes = [
ctypes.POINTER(ctypes.c_void_p),
ctypes.c_size_t,
]
self._hip.hipMalloc.restype = ctypes.c_int
self._hip.hipFree.argtypes = [ctypes.c_void_p]
self._hip.hipFree.restype = ctypes.c_int
self._hip.hipMemcpy.argtypes = [
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_size_t,
ctypes.c_int,
]
self._hip.hipMemcpy.restype = ctypes.c_int
self._hip.hipDeviceSynchronize.argtypes = []
self._hip.hipDeviceSynchronize.restype = ctypes.c_int
# Initialize dispatcher
self._dispatch_lib.initialize()
self._initialized = True
except Exception as e:
self._initialized = False
self._init_error = str(e)
self._init_traceback = traceback.format_exc()
def is_available(self) -> bool:
return self._initialized and self._dispatch_lib is not None
def get_init_error(self) -> Optional[str]:
"""Get initialization error message if initialization failed."""
return self._init_error
def get_init_traceback(self) -> Optional[str]:
"""Get full initialization traceback for debugging."""
return self._init_traceback
@property
def library_path(self) -> Optional[str]:
if self._dispatch_lib:
return str(self._dispatch_lib.path)
return None
@property
def lib(self) -> Optional[GroupedConvDispatcherLib]:
return self._dispatch_lib
def run(
self,
input_np: np.ndarray,
weight_np: np.ndarray,
problem: GroupedConvProblem,
output_np: Optional[np.ndarray] = None,
verbose: bool = False,
) -> GroupedConvResult:
"""Run convolution on GPU.
Args:
input_np: For forward: X (NHWGC). For bwd_data: dY. For bwd_weight: X.
weight_np: For forward: W (GKYXC). For bwd_data: W. For bwd_weight: dY.
problem: Problem specification.
output_np: Optional pre-allocated output buffer.
verbose: If True, print full traceback on initialization failure.
Returns:
GroupedConvResult with success, time_ms, tflops, output.
"""
# Lazy initialization - load GPU libraries only on first run
self._ensure_initialized()
if not self.is_available():
# Surface the actual initialization error for diagnosability
if self._init_error:
error_msg = f"GPU initialization failed: {self._init_error}"
if verbose and self._init_traceback:
print("=" * 80)
print("GPU Initialization Traceback:")
print("=" * 80)
print(self._init_traceback)
print("=" * 80)
else:
error_msg = "GPU not available"
return GroupedConvResult(error=error_msg)
try:
# Determine output shape based on direction
d = problem.direction
if d == "bwd_data":
out_shape = problem.input_shape()
elif d == "bwd_weight":
out_shape = problem.weight_shape()
else:
out_shape = problem.output_shape()
if output_np is None:
output_np = np.zeros(out_shape, dtype=input_np.dtype)
output_size = output_np.nbytes
# Allocate GPU memory with error checking
d_a = ctypes.c_void_p()
d_b = ctypes.c_void_p()
d_c = ctypes.c_void_p()
allocated_ptrs = [] # Track successfully allocated pointers
try:
# Allocate input
ret = self._hip.hipMalloc(ctypes.byref(d_a), input_np.nbytes)
if ret != 0:
raise RuntimeError(
f"hipMalloc failed for input (code {ret}, size {input_np.nbytes})"
)
allocated_ptrs.append(d_a)
# Allocate weight
ret = self._hip.hipMalloc(ctypes.byref(d_b), weight_np.nbytes)
if ret != 0:
raise RuntimeError(
f"hipMalloc failed for weight (code {ret}, size {weight_np.nbytes})"
)
allocated_ptrs.append(d_b)
# Allocate output
ret = self._hip.hipMalloc(ctypes.byref(d_c), output_size)
if ret != 0:
raise RuntimeError(
f"hipMalloc failed for output (code {ret}, size {output_size})"
)
allocated_ptrs.append(d_c)
# Host to device
ret = self._hip.hipMemcpy(
d_a, input_np.ctypes.data, input_np.nbytes, self.HIP_MEMCPY_H2D
)
if ret != 0:
raise RuntimeError(f"hipMemcpy H2D failed for input (code {ret})")
ret = self._hip.hipMemcpy(
d_b, weight_np.ctypes.data, weight_np.nbytes, self.HIP_MEMCPY_H2D
)
if ret != 0:
raise RuntimeError(f"hipMemcpy H2D failed for weight (code {ret})")
self._hip.hipDeviceSynchronize()
# Launch kernel
time_ms = self._dispatch_lib.run(
d_a.value, d_b.value, d_c.value, problem
)
self._hip.hipDeviceSynchronize()
result = GroupedConvResult()
if time_ms > 0:
# Device to host
ret = self._hip.hipMemcpy(
output_np.ctypes.data, d_c, output_size, self.HIP_MEMCPY_D2H
)
if ret != 0:
raise RuntimeError(
f"hipMemcpy D2H failed for output (code {ret})"
)
self._hip.hipDeviceSynchronize()
result.success = True
result.time_ms = time_ms
result.tflops = problem.flops / (time_ms * 1e9)
result.output = output_np
else:
result.error = (
"unsupported"
if time_ms == -3.0
else "no kernel"
if time_ms == -2.0
else f"error (code {time_ms})"
)
return result
finally:
# CRITICAL: Only free successfully allocated pointers
for ptr in allocated_ptrs:
if ptr.value: # Only free non-null pointers
self._hip.hipFree(ptr)
except Exception as e:
return GroupedConvResult(error=str(e))
def cleanup(self):
if self._dispatch_lib:
try:
self._dispatch_lib.cleanup()
except Exception:
pass
# =============================================================================
# GroupedConvRegistry
# =============================================================================
class GroupedConvRegistry:
"""Collection of grouped conv kernel configs with JSON export/import."""
def __init__(self, name: str = "default"):
self.name = name
self._kernels: List[GroupedConvKernelConfig] = []
def add(self, config: GroupedConvKernelConfig):
self._kernels.append(config)
@property
def kernels(self) -> List[GroupedConvKernelConfig]:
return list(self._kernels)
def __len__(self) -> int:
return len(self._kernels)
def select(
self, problem: "GroupedConvProblem", heuristic=None
) -> Optional[GroupedConvKernelConfig]:
"""Select the best kernel for a problem.
Args:
problem: The convolution problem.
heuristic: Optional callable(problem) -> List[str] returning
ranked kernel name substrings. The registry tries
each in order; falls back to first matching kernel.
Returns:
The best matching GroupedConvKernelConfig, or None.
"""
matching = [k for k in self._kernels if k.variant == problem.direction]
if not matching:
return None
if heuristic is not None:
ranked = heuristic(problem)
for hint in ranked:
for k in matching:
if hint in k.name:
return k
return matching[0]  # matching is non-empty here (checked above)
def filter_by_variant(self, variant: str) -> "GroupedConvRegistry":
variant = _resolve_variant(variant)
reg = GroupedConvRegistry(f"{self.name}_{variant}")
for k in self._kernels:
if k.variant == variant:
reg.add(k)
return reg
def filter_by_arch(self, arch: str) -> "GroupedConvRegistry":
reg = GroupedConvRegistry(f"{self.name}_{arch}")
for k in self._kernels:
if k.arch == arch:
reg.add(k)
return reg
def to_json(self, indent: int = 2) -> str:
return json.dumps(
{
"name": self.name,
"kernels": [k.to_json_obj() for k in self._kernels],
},
indent=indent,
)
@classmethod
def from_json(cls, json_str: str) -> "GroupedConvRegistry":
data = json.loads(json_str)
reg = cls(data.get("name", "imported"))
for kd in data.get("kernels", []):
sig = kd.get("signature", {})
algo = kd.get("algorithm", {})
wave = algo.get("wave", "2x2x1").split("x")
warp = algo.get("warp", "32x32x16").split("x")
vec = algo.get("vector_sizes", [4, 8, 8])
reg.add(
GroupedConvKernelConfig(
variant=sig.get("variant", "forward"),
ndim_spatial=sig.get("ndim_spatial", 2),
dtype=sig.get("dtype", "fp16"),
layout=sig.get("layout", "nhwgc"),
arch=kd.get("arch", "gfx942"),
tile_m=algo.get("tile_m", 1),
tile_n=algo.get("tile_n", 128),
tile_k=algo.get("tile_k", 128),
wave_m=int(wave[0]),
wave_n=int(wave[1]),
wave_k=int(wave[2]),
warp_tile_m=int(warp[0]),
warp_tile_n=int(warp[1]),
warp_tile_k=int(warp[2]),
pipeline=algo.get("pipeline", "compv3"),
epilogue=algo.get("epilogue", "cshuffle"),
scheduler=algo.get("scheduler", "intrawave"),
vector_size_a=vec[0] if len(vec) > 0 else 4,
vector_size_b=vec[1] if len(vec) > 1 else 8,
vector_size_c=vec[2] if len(vec) > 2 else 8,
block_per_cu=algo.get("block_per_cu", 1),
num_wave_groups=algo.get("num_wave_groups", 1),
num_groups_to_merge=algo.get("num_groups_to_merge", 1),
)
)
return reg
def build(
self,
verbose: bool = False,
max_workers: Optional[int] = None,
) -> Dict[Tuple[str, int], "GpuGroupedConvRunner"]:
"""Parallel JIT compile all kernels in this registry.
Args:
verbose: Print progress during build.
max_workers: Max parallel codegen/compile processes (default: cpu_count capped at 8).
Returns a dict mapping (variant, ndim_spatial) to a ready-to-use
GpuGroupedConvRunner.
"""
if not self._kernels:
return {}
libs = setup_multiple_grouped_conv_dispatchers(
self._kernels,
verbose=verbose,
max_workers=max_workers,
)
runners: Dict[Tuple[str, int], GpuGroupedConvRunner] = {}
for cfg, lib in zip(self._kernels, libs):
if lib is None:
continue
key = (cfg.variant, cfg.ndim_spatial)
if key in runners:
continue
runner = GpuGroupedConvRunner(lib_path=str(lib))
runner._ensure_initialized()
if runner.is_available():
runners[key] = runner
return runners
def print_registry(self, indent: str = " "):
print(f"{indent}Registry '{self.name}': {len(self)} kernels")
for i, k in enumerate(self._kernels):
print(
f"{indent} [{i}] {k.name} (valid={validate_grouped_conv_config(k.to_dict()).is_valid})"
)
# =============================================================================
# GroupedConvValidationResult
# =============================================================================
@dataclass
class GroupedConvValidationResult(ValidationResultBase):
"""Result of grouped conv kernel config validation."""
variant: str = "forward"
def __init__(
self,
is_valid=True,
errors=None,
warnings=None,
suggested_fixes=None,
variant="forward",
):
super().__init__(
is_valid=is_valid,
errors=errors or [],
warnings=warnings or [],
suggested_fixes=suggested_fixes or {},
)
self.variant = variant
# =============================================================================
# Validation helpers (extracted from the original config extraction code)
# =============================================================================
def _first(val):
if isinstance(val, list) and len(val) > 0:
return val[0]
return val
def _get_tile_config(config: dict) -> dict:
return config.get("tile_config") or {}
def _get_trait_config(config: dict) -> dict:
return config.get("trait_config") or {}
def _extract_wave_config(tile_config: dict) -> List[int]:
wm = tile_config.get("wave_m") or tile_config.get("warp_m")
wn = tile_config.get("wave_n") or tile_config.get("warp_n")
wk = tile_config.get("wave_k") or tile_config.get("warp_k")
if wm is not None and wn is not None and wk is not None:
return [_first(wm), _first(wn), _first(wk)]
return [2, 2, 1]
def _extract_warp_tile_config(tile_config: dict) -> List[int]:
wtm = tile_config.get("warp_tile_m") or tile_config.get("warp_m")
wtn = tile_config.get("warp_tile_n") or tile_config.get("warp_n")
wtk = tile_config.get("warp_tile_k") or tile_config.get("warp_k")
if wtm is not None and wtn is not None and wtk is not None:
return [_first(wtm), _first(wtn), _first(wtk)]
return [32, 32, 16]
def _extract_trait_values(trait_config: dict) -> Tuple[str, str, str]:
p = _first(trait_config.get("pipeline", "compv4"))
e = _first(trait_config.get("epilogue", "cshuffle"))
s = _first(trait_config.get("scheduler", "intrawave"))
    # _first leaves empty or nested lists untouched; coerce those to defaults.
if isinstance(p, list):
p = p[0] if p else "compv4"
if isinstance(e, list):
e = e[0] if e else "cshuffle"
if isinstance(s, list):
s = s[0] if s else "intrawave"
return (str(p), str(e), str(s))
# =============================================================================
# validate_grouped_conv_config / auto_correct_grouped_conv_config
# =============================================================================
def validate_grouped_conv_config(config: dict) -> GroupedConvValidationResult:
"""Validate a grouped conv kernel config dict.
Accepts either a raw dict (legacy) or GroupedConvKernelConfig.to_dict() output.
"""
errors: List[str] = []
warnings: List[str] = []
suggested_fixes: Dict[str, Any] = {}
required = (
"tile_config",
"trait_config",
"variant",
"ndim_spatial",
"arch",
"layout",
)
for key in required:
if key not in config:
errors.append(f"Missing required key: {key}")
if errors:
return GroupedConvValidationResult(
is_valid=False,
errors=errors,
warnings=warnings,
suggested_fixes=suggested_fixes,
variant=config.get("variant", "forward"),
)
tile_config = _get_tile_config(config)
trait_config = _get_trait_config(config)
variant = _first(config.get("variant", "forward"))
if isinstance(variant, list):
variant = variant[0] if variant else "forward"
variant = _resolve_variant(str(variant))
ndim_spatial = config.get("ndim_spatial")
arch = config.get("arch", "gfx942")
dtype = config.get("dtype", "fp16")
if variant not in VALID_VARIANTS:
errors.append(f"Invalid variant: {variant}. Valid: {', '.join(VALID_VARIANTS)}")
suggested_fixes["variant"] = "forward"
if ndim_spatial is not None:
ndim = ndim_spatial
if isinstance(ndim, list):
ndim = ndim[0] if ndim else 2
if ndim not in VALID_NDIM_SPATIAL:
errors.append(
f"Invalid ndim_spatial: {ndim}. Valid: {', '.join(map(str, VALID_NDIM_SPATIAL))}"
)
suggested_fixes["ndim_spatial"] = 2
pipeline, epilogue, scheduler = _extract_trait_values(trait_config)
if variant in BACKWARD_VARIANTS and pipeline not in BACKWARD_PIPELINES:
errors.append(
            f"Backward variant '{variant}' requires one of: {', '.join(BACKWARD_PIPELINES)}, got '{pipeline}'"
)
suggested_fixes["pipeline"] = "compv3"
ok, msg = validate_trait_combo(pipeline, epilogue, scheduler)
if not ok:
errors.append(msg)
suggested_fixes["scheduler"] = "intrawave"
wave_cfg = _extract_wave_config(tile_config)
ok, msg = validate_wave_config(wave_cfg, arch)
if not ok:
errors.append(msg)
arch_data = get_arch_filter_data()
valid_waves = arch_data["warp_combos"].get(arch, [[2, 2, 1]])
if valid_waves:
suggested_fixes["wave_m"] = valid_waves[0][0]
suggested_fixes["wave_n"] = valid_waves[0][1]
suggested_fixes["wave_k"] = valid_waves[0][2]
warp_cfg = _extract_warp_tile_config(tile_config)
ok, msg = validate_warp_tile_config(warp_cfg, arch, dtype)
if not ok:
errors.append(msg)
arch_data = get_arch_filter_data()
acc = "int32" if dtype == "int8" else "fp32"
dtype_key = f"{dtype}_{dtype}_{acc}"
valid_tiles = (
arch_data["warp_tile_combos"]
.get(arch, {})
.get(dtype_key, [[32, 32, 16], [16, 16, 16]])
)
if valid_tiles:
suggested_fixes["warp_tile_m"] = valid_tiles[0][0]
suggested_fixes["warp_tile_n"] = valid_tiles[0][1]
suggested_fixes["warp_tile_k"] = valid_tiles[0][2]
arch_data = get_arch_filter_data()
if arch not in arch_data["supported_archs"]:
errors.append(
f"Unsupported architecture: {arch}. Supported: {', '.join(arch_data['supported_archs'])}"
)
return GroupedConvValidationResult(
is_valid=len(errors) == 0,
errors=errors,
warnings=warnings,
suggested_fixes=suggested_fixes,
variant=variant,
)
def auto_correct_grouped_conv_config(
config: dict,
) -> Tuple[dict, GroupedConvValidationResult]:
"""Auto-correct invalid grouped conv config. Returns (corrected, result)."""
result = validate_grouped_conv_config(config)
corrected = copy.deepcopy(config)
if result.is_valid:
return corrected, result
tile_config = corrected.setdefault("tile_config", {})
trait_config = corrected.setdefault("trait_config", {})
wave_cfg = _extract_wave_config(tile_config)
arch = config.get("arch", "gfx942")
fixed_wave = auto_correct_wave(wave_cfg, arch)
tile_config["wave_m"] = fixed_wave[0]
tile_config["wave_n"] = fixed_wave[1]
tile_config["wave_k"] = fixed_wave[2]
pipeline, epilogue, scheduler = _extract_trait_values(trait_config)
fixed_pipeline, fixed_scheduler = auto_correct_trait(pipeline, scheduler)
trait_config["pipeline"] = fixed_pipeline
trait_config["scheduler"] = fixed_scheduler
variant = _first(config.get("variant", "forward"))
if isinstance(variant, list):
variant = variant[0] if variant else "forward"
variant = _resolve_variant(str(variant))
if variant in BACKWARD_VARIANTS and fixed_pipeline not in BACKWARD_PIPELINES:
trait_config["pipeline"] = "compv3"
if "warp_tile_m" in result.suggested_fixes:
tile_config["warp_tile_m"] = result.suggested_fixes["warp_tile_m"]
tile_config["warp_tile_n"] = result.suggested_fixes["warp_tile_n"]
tile_config["warp_tile_k"] = result.suggested_fixes["warp_tile_k"]
result = validate_grouped_conv_config(corrected)
return corrected, result
def _run_hipcc_subprocess(args: dict) -> Tuple[bool, Optional[Path], str]:
    """Run one hipcc compile+link job by shelling out from a thread-pool worker."""
import subprocess
from pathlib import Path
compile_cmd = args["compile_cmd"]
link_cmd = args["link_cmd"]
lib_path = Path(args["lib_path"])
try:
res_c = subprocess.run(compile_cmd, capture_output=True, text=True, timeout=300)
if res_c.returncode != 0:
err = (res_c.stderr or res_c.stdout or "").rstrip()
return False, None, f"Compile failed (rc={res_c.returncode}):\n{err}"
res_l = subprocess.run(link_cmd, capture_output=True, text=True, timeout=300)
if res_l.returncode != 0:
err = (res_l.stderr or res_l.stdout or "").rstrip()
return False, None, f"Link failed (rc={res_l.returncode}):\n{err}"
return True, lib_path, ""
except subprocess.TimeoutExpired:
return False, None, "Timeout"
except Exception as e:
return False, None, f"Error: {e}"
def _run_conv_codegen_subprocess(args: dict) -> Tuple[bool, Optional[str], str]:
    """Run grouped-conv codegen once and return the generated kernel header path."""
import subprocess
from pathlib import Path
out_dir = Path(args["output_dir"])
out_dir.mkdir(parents=True, exist_ok=True)
# Remove stale kernels so header discovery is exact for this invocation.
for stale in out_dir.glob("grouped_conv_*.hpp"):
stale.unlink(missing_ok=True)
for stale in out_dir.glob("include_all_grouped_conv_*.hpp"):
stale.unlink(missing_ok=True)
try:
res = subprocess.run(args["cmd"], capture_output=True, text=True, timeout=300)
if res.returncode != 0:
err = (res.stderr or res.stdout or "").rstrip()
return False, None, f"Codegen failed (rc={res.returncode}):\n{err}"
generated = sorted(
out_dir.glob("grouped_conv_*.hpp"),
key=lambda p: p.stat().st_mtime,
reverse=True,
)
if not generated:
return False, None, "Codegen produced no grouped_conv_*.hpp header"
return True, str(generated[0]), ""
except subprocess.TimeoutExpired:
return False, None, "Codegen timed out"
except Exception as e:
return False, None, f"Codegen error: {e}"
def _config_key(c: GroupedConvKernelConfig) -> Tuple[Any, ...]:
return (
c.variant,
c.ndim_spatial,
c.dtype,
c.layout,
c.arch,
c.tile_m,
c.tile_n,
c.tile_k,
c.wave_m,
c.wave_n,
c.wave_k,
c.warp_tile_m,
c.warp_tile_n,
c.warp_tile_k,
c.pipeline,
c.epilogue,
c.scheduler,
c.num_groups_to_merge,
c.double_smem_buffer,
c.split_image,
c.two_stage,
)
def _parse_triplet(value: str) -> Tuple[int, int, int]:
    """Parse an 'MxNxK' size string (e.g. '128x128x32') into an int 3-tuple."""
    parts = value.split("x")
    if len(parts) != 3:
        raise ValueError(f"Invalid triplet: {value}")
    return int(parts[0]), int(parts[1]), int(parts[2])
def _list_arch_valid_grouped_conv_configs(
codegen_script: Path,
arch: str,
dtype: str,
variant: str,
ndim_spatial: int,
) -> List[GroupedConvKernelConfig]:
"""Query codegen defaults for this (arch, dtype, variant, ndim) tuple."""
import re
import sys
cmd = [
sys.executable,
str(codegen_script),
"--list-configs",
"--arch",
arch,
"--datatype",
dtype,
"--variant",
variant,
"--ndim",
str(ndim_spatial),
]
res = subprocess.run(cmd, capture_output=True, text=True, timeout=180)
if res.returncode != 0:
return []
# Example:
# grouped_conv_fwd_fp16_nhwgc_2d_compv3_cshuffle_intrawave_128x128x32_2x2x1_32x32x16
name_re = re.compile(
r"^grouped_conv_(fwd|bwd_data|bwd_weight|bwdd|bwdw)_([a-z0-9]+)_([a-z0-9]+)_([123])d_"
r"([a-z0-9]+)_([a-z0-9]+)_([a-z0-9]+)_"
r"([0-9]+x[0-9]+x[0-9]+)_([0-9]+x[0-9]+x[0-9]+)_([0-9]+x[0-9]+x[0-9]+)"
r"(?:_.*)?$"
)
short_to_variant = {
"fwd": "forward",
"bwd_data": "bwd_data",
"bwd_weight": "bwd_weight",
"bwdd": "bwd_data",
"bwdw": "bwd_weight",
}
out: List[GroupedConvKernelConfig] = []
seen = set()
for raw in res.stdout.splitlines():
line = raw.strip()
if not line.startswith("- grouped_conv_"):
continue
name = line[2:].strip()
m = name_re.match(name)
if not m:
continue
v_short, dt, layout, ndim, pipe, epi, sched, tile_s, wave_s, warp_s = m.groups()
tm, tn, tk = _parse_triplet(tile_s)
wm, wn, wk = _parse_triplet(wave_s)
wtm, wtn, wtk = _parse_triplet(warp_s)
cfg = GroupedConvKernelConfig(
variant=short_to_variant[v_short],
ndim_spatial=int(ndim),
dtype=dt,
layout=layout,
arch=arch,
tile_m=tm,
tile_n=tn,
tile_k=tk,
wave_m=wm,
wave_n=wn,
wave_k=wk,
warp_tile_m=wtm,
warp_tile_n=wtn,
warp_tile_k=wtk,
pipeline=pipe,
epilogue=epi,
scheduler=sched,
)
key = _config_key(cfg)
if key not in seen:
out.append(cfg)
seen.add(key)
return out
def _select_best_arch_valid_conv_config(
requested: GroupedConvKernelConfig,
candidates: List[GroupedConvKernelConfig],
) -> GroupedConvKernelConfig:
    """Pick the nearest arch-valid config, preferring exact trait matches."""
def score(c: GroupedConvKernelConfig) -> Tuple[int, int, int, int, int, int]:
tile_delta = (
abs(c.tile_m - requested.tile_m)
+ abs(c.tile_n - requested.tile_n)
+ abs(c.tile_k - requested.tile_k)
)
wave_delta = (
abs(c.wave_m - requested.wave_m)
+ abs(c.wave_n - requested.wave_n)
+ abs(c.wave_k - requested.wave_k)
)
warp_tile_delta = (
abs(c.warp_tile_m - requested.warp_tile_m)
+ abs(c.warp_tile_n - requested.warp_tile_n)
+ abs(c.warp_tile_k - requested.warp_tile_k)
)
return (
0 if c.pipeline == requested.pipeline else 1,
0 if c.scheduler == requested.scheduler else 1,
0 if c.epilogue == requested.epilogue else 1,
tile_delta,
wave_delta,
warp_tile_delta,
)
best = min(candidates, key=score)
selected = copy.deepcopy(best)
selected.arch = requested.arch
return selected
def _write_single_conv_dispatch_header(
config: GroupedConvKernelConfig,
kernel_header: Path,
dispatch_header: Path,
) -> None:
"""Create a tiny dispatch header consumed by conv_ctypes_lib.cpp."""
macros: List[str] = []
aliases: List[str] = []
if config.variant == "forward":
kernel_name_symbol = "CONV_FWD_KERNEL_NAME"
if config.ndim_spatial == 3:
macros.append("#define CONV_FWD_3D_AVAILABLE 1")
aliases.append("using ConvFwd3dLauncher = SelectedConvKernelLauncher;")
else:
macros.append("#define CONV_FWD_2D_AVAILABLE 1")
elif config.variant == "bwd_data":
kernel_name_symbol = "CONV_BWD_DATA_KERNEL_NAME"
if config.ndim_spatial == 3:
macros.append("#define CONV_BWD_DATA_3D_AVAILABLE 1")
aliases.append("using ConvBwdData3dLauncher = SelectedConvBwdDataLauncher;")
else:
macros.append("#define CONV_BWD_DATA_2D_AVAILABLE 1")
else:
kernel_name_symbol = "CONV_BWD_WEIGHT_KERNEL_NAME"
if config.ndim_spatial == 3:
macros.append("#define CONV_BWD_WEIGHT_3D_AVAILABLE 1")
aliases.append(
"using ConvBwdWeight3dLauncher = SelectedConvBwdWeightLauncher;"
)
else:
macros.append("#define CONV_BWD_WEIGHT_2D_AVAILABLE 1")
content = (
"// Auto-generated single-kernel dispatch header for Python JIT\n"
"#pragma once\n\n"
f'#include "{kernel_header.name}"\n\n'
+ "\n".join(macros)
+ "\n\n"
+ "\n".join(aliases)
+ "\n\n"
+ f"static const char* CONV_KERNEL_NAMES[] = {{{kernel_name_symbol}}};\n"
+ "static constexpr int CONV_KERNEL_COUNT = 1;\n"
)
dispatch_header.write_text(content)
class GroupedConvCodegenRunner:
"""Generate and compile grouped-conv JIT libraries in parallel."""
def __init__(self, max_workers: Optional[int] = None):
import multiprocessing
self.max_workers = max_workers or min(multiprocessing.cpu_count(), 8)
self.root = Path(__file__).parent.parent
self.build_dir = self.root / "build"
self.codegen_script = self.root / "codegen" / "unified_grouped_conv_codegen.py"
def generate_and_compile_parallel(
self,
configs: List[GroupedConvKernelConfig],
verbose: bool = True,
) -> List[Optional[Path]]:
import sys
if not configs:
return []
if not self.build_dir.exists():
self.build_dir.mkdir(parents=True, exist_ok=True)
ctypes_source = self.root / "bindings" / "ctypes" / "conv_ctypes_lib.cpp"
static_lib = self.build_dir / "libck_tile_dispatcher.a"
jit_root = self.build_dir / "generated_kernels" / "python_jit"
jit_root.mkdir(parents=True, exist_ok=True)
(self.build_dir / "examples").mkdir(parents=True, exist_ok=True)
if not self.codegen_script.exists():
if verbose:
print(f"Codegen script missing: {self.codegen_script}")
return [None] * len(configs)
if not ctypes_source.exists() or not static_lib.exists():
if verbose:
print("Missing conv ctypes source or static dispatcher library")
return [None] * len(configs)
if verbose:
print(
f"Generating {len(configs)} grouped-conv kernels with "
f"{self.max_workers} threads (out-of-order)..."
)
gen_jobs: List[Dict[str, Any]] = []
job_dirs: List[Path] = []
for i, c in enumerate(configs):
cfg_dir = jit_root / f"cfg_{i}"
cfg_dir.mkdir(parents=True, exist_ok=True)
job_dirs.append(cfg_dir)
cmd = [
sys.executable,
str(self.codegen_script),
"--output",
str(cfg_dir),
"--datatype",
c.dtype,
"--variant",
c.variant,
"--ndim",
str(c.ndim_spatial),
"--arch",
c.arch,
"--tile-m",
str(c.tile_m),
"--tile-n",
str(c.tile_n),
"--tile-k",
str(c.tile_k),
"--warp-m",
str(c.wave_m),
"--warp-n",
str(c.wave_n),
"--warp-k",
str(c.wave_k),
"--warp-tile-m",
str(c.warp_tile_m),
"--warp-tile-n",
str(c.warp_tile_n),
"--warp-tile-k",
str(c.warp_tile_k),
"--pipeline",
c.pipeline,
"--scheduler",
c.scheduler,
"--epilogue",
c.epilogue,
"--num-groups-to-merge",
str(c.num_groups_to_merge),
"--double-smem-buffer",
"true" if c.double_smem_buffer else "false",
]
if c.split_image:
cmd.append("--split-image")
if c.two_stage:
cmd.append("--two-stage")
gen_jobs.append({"cmd": cmd, "output_dir": str(cfg_dir)})
generated_headers: List[Optional[Path]] = [None] * len(configs)
# Phase 1 codegen: each worker just calls subprocess.run() to invoke the
# codegen script. The wait releases the GIL, so threads give true parallelism
# without the fork-after-HIP risk of ProcessPoolExecutor.
print_lock = threading.Lock()
with ThreadPoolExecutor(max_workers=self.max_workers) as ex:
futures = {
ex.submit(_run_conv_codegen_subprocess, job): idx
for idx, job in enumerate(gen_jobs)
}
for fut in as_completed(futures):
idx = futures[fut]
ok, header_path, err = fut.result()
if ok and header_path:
generated_headers[idx] = Path(header_path)
if verbose:
with print_lock:
print(f" OK [{idx}] codegen: {Path(header_path).name}")
else:
if verbose:
with print_lock:
print(f" FAIL [{idx}] codegen: {err}")
if verbose:
compile_count = sum(1 for h in generated_headers if h is not None)
print(
f"Compiling {compile_count} grouped-conv libraries with "
f"{self.max_workers} threads (out-of-order)..."
)
compile_jobs: List[Dict[str, Any]] = []
compile_to_input_index: Dict[int, int] = {}
for i, c in enumerate(configs):
hdr_path = generated_headers[i]
if hdr_path is None:
continue
cfg_dir = job_dirs[i]
dispatch_header = cfg_dir / "conv_python_dispatch.hpp"
_write_single_conv_dispatch_header(c, hdr_path, dispatch_header)
# Build suffix with all distinguishing config options
suffix = ""
if c.num_groups_to_merge != 1:
suffix += f"_gm{c.num_groups_to_merge}"
if c.double_smem_buffer:
suffix += "_dsb"
if c.split_image:
suffix += "_si"
if c.two_stage:
suffix += "_2stage"
lib_name = (
f"libdispatcher_conv_{c.variant}_{c.ndim_spatial}d_{c.dtype}_"
f"{c.tile_str}_{c.wave_str}_{c.warp_str}_{c.pipeline}_{c.scheduler}{suffix}.so"
)
lib_path = self.build_dir / "examples" / lib_name
obj_file = lib_path.with_suffix(".o")
compile_cmd = [
"/opt/rocm/bin/hipcc",
"-c",
"-fPIC",
"-O3",
f"-I{self.root / 'include'}",
f"-I{self.root.parent / 'include'}",
f"-I{self.root.parent}",
f"-I{cfg_dir}",
"-DCK_TILE_SINGLE_KERNEL_INCLUDE",
f"-include{dispatch_header}",
"-D__HIP_PLATFORM_AMD__",
f"--offload-arch={c.arch}",
f'-DGFX_ARCH="{c.arch}"',
"-mllvm",
"-enable-noalias-to-md-conversion=0",
"-Wno-undefined-func-template",
"-Wno-float-equal",
str(ctypes_source),
"-o",
str(obj_file),
]
link_cmd = [
"/opt/rocm/bin/hipcc",
"-shared",
"-fPIC",
f"--offload-arch={c.arch}",
"--hip-link",
str(obj_file),
str(static_lib),
"-o",
str(lib_path),
]
compile_to_input_index[len(compile_jobs)] = i
compile_jobs.append(
{
"compile_cmd": compile_cmd,
"link_cmd": link_cmd,
"lib_path": str(lib_path),
"config_name": c.name,
}
)
results_map: Dict[int, Optional[Path]] = {i: None for i in range(len(configs))}
        # Phase 2 compile: workers shell out to hipcc, releasing the GIL while
        # waiting. Threads give true parallelism here; ProcessPoolExecutor would
        # risk fork() corrupting any HIP state the parent might have loaded.
with ThreadPoolExecutor(max_workers=self.max_workers) as ex:
futures = {
ex.submit(_run_hipcc_subprocess, job): j
for j, job in enumerate(compile_jobs)
}
for fut in as_completed(futures):
j = futures[fut]
idx = compile_to_input_index[j]
success, lib_path, err = fut.result()
if success and lib_path:
results_map[idx] = Path(lib_path)
if verbose:
name = (
Path(lib_path).name
if success and lib_path
else compile_jobs[j]["config_name"]
)
with print_lock:
if success:
print(f" OK {name}")
else:
# Print the full multi-line error indented for readability
# so users don't have to monkey-patch to see real compile output.
print(f" FAIL {name}")
for line in (err or "").splitlines() or [""]:
print(f" {line}")
return [results_map.get(i) for i in range(len(configs))]
# =============================================================================
# Convenience functions
# =============================================================================
def get_grouped_conv_default_config(
variant: str = "forward",
ndim_spatial: int = 2,
arch: str = "gfx942",
dtype: str = "fp16",
) -> GroupedConvKernelConfig:
"""Return a valid default GroupedConvKernelConfig."""
return GroupedConvKernelConfig(
variant=variant,
ndim_spatial=ndim_spatial,
arch=arch,
dtype=dtype,
)
def format_grouped_conv_summary(config) -> str:
"""Format a config (dict or GroupedConvKernelConfig) into a human-readable string."""
if isinstance(config, GroupedConvKernelConfig):
lines = [
f"Grouped Conv Config: {config.variant} {config.ndim_spatial}D",
f" Arch: {config.arch}",
f" Layout: {config.layout}",
f" Dtype: {config.dtype}",
f" Tile: {config.tile_str}",
f" Wave: {config.wave_str}",
f" Warp: {config.warp_str}",
f" Traits: pipeline={config.pipeline} epilogue={config.epilogue} scheduler={config.scheduler}",
]
return "\n".join(lines)
# Legacy dict support
tile_config = _get_tile_config(config) if isinstance(config, dict) else {}
trait_config = _get_trait_config(config) if isinstance(config, dict) else {}
variant = config.get("variant", "?") if isinstance(config, dict) else "?"
ndim = config.get("ndim_spatial", "?") if isinstance(config, dict) else "?"
arch = config.get("arch", "?") if isinstance(config, dict) else "?"
layout = config.get("layout", "?") if isinstance(config, dict) else "?"
dtype = config.get("dtype", "fp16") if isinstance(config, dict) else "fp16"
lines = [f"Grouped Conv Config: {variant} {ndim}D"]
lines.append(f" Arch: {arch}")
lines.append(f" Layout: {layout}")
lines.append(f" Dtype: {dtype}")
if tile_config:
wave = _extract_wave_config(tile_config)
warp = _extract_warp_tile_config(tile_config)
lines.append(
            f" Tile: M={_first(tile_config.get('tile_m', 128))} N={_first(tile_config.get('tile_n', 128))} K={_first(tile_config.get('tile_k', 128))}"
)
lines.append(f" Wave: {wave[0]}x{wave[1]}x{wave[2]}")
lines.append(f" Warp: {warp[0]}x{warp[1]}x{warp[2]}")
if trait_config:
pipeline = _first(trait_config.get("pipeline", "?"))
epilogue = _first(trait_config.get("epilogue", "?"))
scheduler = _first(trait_config.get("scheduler", "?"))
lines.append(
f" Traits: pipeline={pipeline} epilogue={epilogue} scheduler={scheduler}"
)
return "\n".join(lines) if lines else "(empty config)"
def setup_multiple_grouped_conv_dispatchers(
configs: List[GroupedConvKernelConfig],
verbose: bool = True,
max_workers: Optional[int] = None,
) -> List[Optional[Path]]:
"""
    Set up multiple grouped-conv dispatchers.
    Returns library paths WITHOUT loading them, to avoid creating a GPU context
    during compilation. This mirrors the FMHA design: keep the GPU context out
    of the JIT phase entirely.
Architecture filtering workflow:
1. Validate each requested config via validate_grouped_conv_config; if invalid,
attempt auto_correct_grouped_conv_config. Drop configs that remain invalid.
2. Trust the (possibly auto-corrected) config as-is. Knobs such as scheduler,
num_groups_to_merge, double_smem_buffer, split_image, two_stage are preserved
exactly as requested -- no remap to a hardcoded "default" set.
3. Threaded codegen + threaded compile (workers shell out via subprocess,
which releases the GIL; threads avoid the fork-after-HIP risk that
ProcessPoolExecutor would have).
4. Return paths (NOT loaded libraries).
Returns:
List of paths to compiled .so files (or None for failed configs)
"""
if not configs:
return []
selected_configs: List[Optional[GroupedConvKernelConfig]] = []
for i, original in enumerate(configs):
c = copy.deepcopy(original)
val = validate_grouped_conv_config(c.to_dict())
if not val.is_valid:
corrected, corrected_result = auto_correct_grouped_conv_config(c.to_dict())
if not corrected_result.is_valid:
if verbose:
print(f" FAIL [{i}] config remains invalid after auto-correct")
selected_configs.append(None)
continue
tile_cfg = corrected.get("tile_config", {})
trait_cfg = corrected.get("trait_config", {})
c.variant = _resolve_variant(
str(_first(corrected.get("variant", c.variant)))
)
c.ndim_spatial = int(_first(corrected.get("ndim_spatial", c.ndim_spatial)))
c.arch = str(corrected.get("arch", c.arch))
c.layout = str(corrected.get("layout", c.layout))
c.dtype = str(corrected.get("dtype", c.dtype))
c.tile_m = int(_first(tile_cfg.get("tile_m", c.tile_m)))
c.tile_n = int(_first(tile_cfg.get("tile_n", c.tile_n)))
c.tile_k = int(_first(tile_cfg.get("tile_k", c.tile_k)))
c.wave_m = int(_first(tile_cfg.get("wave_m", c.wave_m)))
c.wave_n = int(_first(tile_cfg.get("wave_n", c.wave_n)))
c.wave_k = int(_first(tile_cfg.get("wave_k", c.wave_k)))
c.warp_tile_m = int(_first(tile_cfg.get("warp_tile_m", c.warp_tile_m)))
c.warp_tile_n = int(_first(tile_cfg.get("warp_tile_n", c.warp_tile_n)))
c.warp_tile_k = int(_first(tile_cfg.get("warp_tile_k", c.warp_tile_k)))
c.pipeline = str(_first(trait_cfg.get("pipeline", c.pipeline)))
c.scheduler = str(_first(trait_cfg.get("scheduler", c.scheduler)))
c.epilogue = str(_first(trait_cfg.get("epilogue", c.epilogue)))
# Trust the validated config -- no remap to a hardcoded arch-valid set.
# Knobs (num_groups_to_merge, double_smem_buffer, split_image, two_stage)
# and scheduler choice are preserved exactly as requested.
selected_configs.append(c)
unique_configs: List[GroupedConvKernelConfig] = []
unique_index_by_key: Dict[Tuple[Any, ...], int] = {}
input_to_unique: List[Optional[int]] = []
for cfg in selected_configs:
if cfg is None:
input_to_unique.append(None)
continue
key = _config_key(cfg)
if key not in unique_index_by_key:
unique_index_by_key[key] = len(unique_configs)
unique_configs.append(cfg)
input_to_unique.append(unique_index_by_key[key])
runner = GroupedConvCodegenRunner(max_workers=max_workers)
unique_lib_paths = runner.generate_and_compile_parallel(
unique_configs, verbose=verbose
)
# Map unique lib paths back to input order
# DO NOT load libraries here - just return paths
lib_paths: List[Optional[Path]] = []
path_cache: Dict[int, Optional[Path]] = {}
for input_idx, unique_idx in enumerate(input_to_unique):
if unique_idx is None:
lib_paths.append(None)
continue
if unique_idx in path_cache:
lib_paths.append(path_cache[unique_idx])
continue
path = (
unique_lib_paths[unique_idx] if unique_idx < len(unique_lib_paths) else None
)
# Validate path exists but don't load it
if path and not path.exists():
if verbose:
print(f" FAIL [{input_idx}] library not found: {path}")
path = None
path_cache[unique_idx] = path
lib_paths.append(path)
return lib_paths
def detect_gpu_arch() -> str:
"""Detect GPU architecture using rocminfo."""
try:
out = subprocess.check_output(
["rocminfo"], stderr=subprocess.DEVNULL, text=True
)
for line in out.split("\n"):
if "gfx" in line.lower() and "name:" in line.lower():
for part in line.split():
if part.startswith("gfx"):
return part
    except Exception:
        # rocminfo missing or output unparsable; fall back to a common default.
        pass
    return "gfx942"