Files
composable_kernel/dispatcher/python/grouped_conv_utils.py
Vidyasagar Ananthan 920acd2c12 [rocm-libraries] ROCm/rocm-libraries#5168 (commit 8b5afcb)
[CK] [CK_Tile] Add GroupConv to Kernel Dispatcher

## Motivation

This PR adds CK Tile group convolution (forward, backward-data,
backward-weight) support to the kernel dispatcher, matching and unifying
with the existing dispatcher GEMM infrastructure in architecture and
usability. The dispatcher provides a unified kernel dispatch system with
both C++ and Python frontends, and until now only supported GEMM
operations. This PR enables framework integrators to use the same
declarative kernel workflow for convolutions as they do for GEMM:
declare kernels, build a registry JIT, select kernels within the
registry at runtime, and dispatch to GPU. Future PRs will include
runtime kernel selection heuristics for autotuning of kernel parameters
based on (problem, hardware arch).

## Technical Details

Grouped convolution support has been added to the CK Tile Dispatcher
with generated_conv_backend.hpp enabling dispatcher.run(in, wei, out,
problem) for all 6 conv variants (fwd/bwdd/bwdw x 2D/3D), runtime
heuristic kernel selection, and GroupedConvKernelKey with full
ConvConfigBase fields. Python side adds parallel JIT via
registry.build(max_workers) and heuristic registry.select(). Includes 7
C++ and 6 Python examples covering all directions with CPU reference
validation, and shared infrastructure improvements (BaseRegistry CRTP,
structured exceptions). As a sanity check, the JIT compile time for a
single kernel remains the same, and for multiple kernels there is better
parallelism:
Kernels | 1 worker | 8 workers
1 | 7.7 s | 7.7 s
2 | 15.9 s | 8.2 s
4 | 33.4 s | 9.7 s
6 | 52.3 s | 10.2 s

## Test Plan

145 ephemeral unit tests have been added to test basic functionality.
All 30 examples/integration tests run end-to-end on gfx950 (MI350): 7
C++ conv, 7 C++ GEMM, 6 Python conv, 10 Python GEMM. CPU reference
validation for forward, backward-data, and backward-weight (2D) in both
C++ and Python examples passes.

## Test Result

30 examples pass. Peak performance: 132 TFLOPS (Batch-32 forward 56x56),
53 TFLOPS (pointwise 1x1). CPU reference accuracy: max_abs_diff < 0.002
for all directions (fp16 vs fp32 reference).

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-09 17:39:35 +00:00

1807 lines
61 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Grouped Convolution Dispatcher Utilities
Typed Python API for grouped convolution kernels, matching the patterns from
the old conv_utils.py and the GEMM ctypes_utils.py.
Classes:
GroupedConvKernelConfig - Kernel configuration (tile, wave, pipeline, arch)
GroupedConvProblem - Runtime problem specification (N,C,K,H,W,etc.)
GroupedConvProblemC - ctypes struct matching C++ ConvProblemC
GroupedConvDispatcherLib - Wrapper for libdispatcher_conv_lib.so
GpuGroupedConvRunner - High-level GPU execution runner
GroupedConvResult - Result of GPU execution (output, time, tflops)
GroupedConvRegistry - Collection of kernel configs with JSON export
Usage:
from grouped_conv_utils import (
GroupedConvKernelConfig,
GroupedConvProblem,
GpuGroupedConvRunner,
)
config = GroupedConvKernelConfig(variant="forward", ndim_spatial=2)
problem = GroupedConvProblem(N=1, C=64, K=128, Hi=28, Wi=28, Y=3, X=3,
stride_h=1, pad_h=1, direction="forward")
runner = GpuGroupedConvRunner()
if runner.is_available():
result = runner.run(input_np, weight_np, problem)
print(f"Time: {result.time_ms:.4f} ms, TFLOPS: {result.tflops:.2f}")
"""
import ctypes
import json
import copy
import subprocess
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from dispatcher_common import (
ValidationResultBase,
auto_correct_trait,
auto_correct_wave,
get_arch_filter_data,
validate_trait_combo,
validate_wave_config,
validate_warp_tile_config,
)
# =============================================================================
# Constants
# =============================================================================
VALID_VARIANTS = ("forward", "bwd_data", "bwd_weight")
VALID_NDIM_SPATIAL = (1, 2, 3)
BACKWARD_VARIANTS = ("bwd_data", "bwd_weight")
BACKWARD_PIPELINES = ("compv3", "mem")
VARIANT_ALIASES = {
"2d_fwd": "forward",
"2d_bwdd": "bwd_data",
"2d_bwdw": "bwd_weight",
"fwd": "forward",
"bwdd": "bwd_data",
"bwdw": "bwd_weight",
}
DIRECTION_MAP = {"forward": 0, "bwd_data": 1, "bwd_weight": 2}
def _resolve_variant(v: str) -> str:
return VARIANT_ALIASES.get(v, v)
# =============================================================================
# GroupedConvDataType
# =============================================================================
class GroupedConvDataType(Enum):
    """Element data types for grouped convolution kernels.
    Values are the lowercase dtype strings used in configs (e.g.
    GroupedConvKernelConfig.dtype) and codegen dicts.
    """
    FP16 = "fp16"
    BF16 = "bf16"
    FP32 = "fp32"
    FP8 = "fp8"
    BF8 = "bf8"
    INT8 = "int8"
# =============================================================================
# GroupedConvKernelConfig
# =============================================================================
@dataclass
class GroupedConvKernelConfig:
    """Complete kernel configuration for grouped convolution.
    Captures all parameters needed to identify and run a specific kernel.
    Mirrors the C++ GroupedConvSignature + GroupedConvAlgorithm.
    __post_init__ resolves variant aliases and forces backward variants
    onto a backward-capable pipeline.
    """
    # What: signature
    variant: str = "forward"  # canonical: "forward" | "bwd_data" | "bwd_weight"
    ndim_spatial: int = 2  # number of spatial dimensions
    dtype: str = "fp16"
    layout: str = "nhwgc"
    arch: str = "gfx942"
    # How: algorithm - tile shape
    tile_m: int = 1
    tile_n: int = 128
    tile_k: int = 128
    # How: wave config
    wave_m: int = 2
    wave_n: int = 2
    wave_k: int = 1
    # How: warp tile
    warp_tile_m: int = 32
    warp_tile_n: int = 32
    warp_tile_k: int = 16
    # How: pipeline traits
    pipeline: str = "compv4"
    epilogue: str = "cshuffle"
    scheduler: str = "intrawave"
    # ConvConfigBase parity fields
    vector_size_a: int = 4
    vector_size_b: int = 8
    vector_size_c: int = 8
    block_per_cu: int = 1
    num_wave_groups: int = 1
    num_groups_to_merge: int = 1
    # Padding (enables arbitrary problem sizes)
    pad_m: bool = True
    pad_n: bool = True
    pad_k: bool = True
    def __post_init__(self):
        # Accept alias spellings such as "fwd" or "2d_bwdd".
        self.variant = _resolve_variant(self.variant)
        # Backward variants only run on the compv3/mem pipelines; silently
        # fall back to compv3 when an incompatible pipeline was requested.
        if (
            self.variant in BACKWARD_VARIANTS
            and self.pipeline not in BACKWARD_PIPELINES
        ):
            self.pipeline = "compv3"
    @property
    def tile_str(self) -> str:
        """Tile shape rendered as "MxNxK"."""
        return f"{self.tile_m}x{self.tile_n}x{self.tile_k}"
    @property
    def wave_str(self) -> str:
        """Wave configuration rendered as "MxNxK"."""
        return f"{self.wave_m}x{self.wave_n}x{self.wave_k}"
    @property
    def warp_str(self) -> str:
        """Warp tile rendered as "MxNxK"."""
        return f"{self.warp_tile_m}x{self.warp_tile_n}x{self.warp_tile_k}"
    @property
    def vec_str(self) -> str:
        """Vector sizes (a, b, c) rendered as "AxBxC"."""
        return f"{self.vector_size_a}x{self.vector_size_b}x{self.vector_size_c}"
    @property
    def name(self) -> str:
        """Descriptive kernel name; used for heuristic substring matching
        in GroupedConvRegistry.select()."""
        return (
            f"grouped_conv_{self.variant}_{self.dtype}_{self.ndim_spatial}d_"
            f"{self.tile_str}_{self.pipeline}"
        )
    def to_dict(self) -> dict:
        """Convert to legacy dict format for codegen compatibility."""
        # Note: every scalar is wrapped in a one-element list, matching the
        # legacy sweep-style config format consumed by codegen/validation.
        return {
            "tile_config": {
                "tile_m": [self.tile_m],
                "tile_n": [self.tile_n],
                "tile_k": [self.tile_k],
                "wave_m": [self.wave_m],
                "wave_n": [self.wave_n],
                "wave_k": [self.wave_k],
                "warp_tile_m": [self.warp_tile_m],
                "warp_tile_n": [self.warp_tile_n],
                "warp_tile_k": [self.warp_tile_k],
            },
            "trait_config": {
                "pipeline": [self.pipeline],
                "epilogue": [self.epilogue],
                "scheduler": [self.scheduler],
                "pad_m": [self.pad_m],
                "pad_n": [self.pad_n],
                "pad_k": [self.pad_k],
                "vector_size_a": [self.vector_size_a],
                "vector_size_b": [self.vector_size_b],
                "vector_size_c": [self.vector_size_c],
                "block_per_cu": [self.block_per_cu],
                "num_wave_groups": [self.num_wave_groups],
                "num_groups_to_merge": [self.num_groups_to_merge],
            },
            "variant": self.variant,
            "ndim_spatial": self.ndim_spatial,
            "arch": self.arch,
            "layout": self.layout,
            "dtype": self.dtype,
        }
    def to_json_obj(self) -> dict:
        """Serializable dict for JSON export."""
        # Schema is the inverse of GroupedConvRegistry.from_json: wave/warp
        # are packed as "MxNxK" strings, vector sizes as a 3-element list.
        return {
            "name": self.name,
            "signature": {
                "variant": self.variant,
                "dtype": self.dtype,
                "ndim_spatial": self.ndim_spatial,
                "layout": self.layout,
            },
            "algorithm": {
                "tile_m": self.tile_m,
                "tile_n": self.tile_n,
                "tile_k": self.tile_k,
                "wave": self.wave_str,
                "warp": self.warp_str,
                "pipeline": self.pipeline,
                "epilogue": self.epilogue,
                "scheduler": self.scheduler,
                "vector_sizes": [
                    self.vector_size_a,
                    self.vector_size_b,
                    self.vector_size_c,
                ],
                "block_per_cu": self.block_per_cu,
                "num_wave_groups": self.num_wave_groups,
                "num_groups_to_merge": self.num_groups_to_merge,
            },
            "arch": self.arch,
        }
    def print_config(self, indent: str = " "):
        """Pretty-print this configuration to stdout."""
        print(f"{indent}GroupedConvKernelConfig:")
        print(f"{indent} Variant: {self.variant} {self.ndim_spatial}D")
        print(f"{indent} Dtype: {self.dtype}")
        print(f"{indent} Layout: {self.layout}")
        print(f"{indent} Arch: {self.arch}")
        print(f"{indent} Tile: {self.tile_str}")
        print(f"{indent} Wave: {self.wave_str}")
        print(f"{indent} Warp: {self.warp_str}")
        print(f"{indent} Pipeline: {self.pipeline}/{self.scheduler}/{self.epilogue}")
        print(f"{indent} VecSizes: {self.vec_str}")
        print(
            f"{indent} BlockCU: {self.block_per_cu} WaveGroups: {self.num_wave_groups} MergeGroups: {self.num_groups_to_merge}"
        )
# =============================================================================
# GroupedConvProblem
# =============================================================================
@dataclass
class GroupedConvProblem:
    """Runtime convolution problem specification.
    Describes the actual sizes of a convolution to be computed.
    Matches the old ConvProblem from conv_utils.py.
    """
    N: int = 1
    C: int = 64
    K: int = 128
    G: int = 1
    Hi: int = 28
    Wi: int = 28
    Di: int = 1
    Y: int = 3
    X: int = 3
    Z: int = 1
    stride_h: int = 1
    stride_w: int = 1
    stride_d: int = 1
    pad_h: int = 0
    pad_w: int = 0
    pad_d: int = 0
    dilation_h: int = 1
    dilation_w: int = 1
    dilation_d: int = 1
    direction: str = "forward"
    split_k: int = 1
    @staticmethod
    def _out_dim(size: int, kernel: int, stride: int, pad: int, dilation: int) -> int:
        """Standard conv output-size formula for one spatial dimension."""
        window = (kernel - 1) * dilation + 1
        return (size + 2 * pad - window) // stride + 1
    @property
    def Ho(self) -> int:
        """Output height."""
        return self._out_dim(self.Hi, self.Y, self.stride_h, self.pad_h, self.dilation_h)
    @property
    def Wo(self) -> int:
        """Output width."""
        return self._out_dim(self.Wi, self.X, self.stride_w, self.pad_w, self.dilation_w)
    @property
    def Do(self) -> int:
        """Output depth (meaningful only for 3D problems)."""
        return self._out_dim(self.Di, self.Z, self.stride_d, self.pad_d, self.dilation_d)
    @property
    def is_3d(self) -> bool:
        """Heuristic: any depth extent or depth padding marks the problem 3D."""
        return self.Di > 1 or self.Z > 1 or self.pad_d > 0
    @property
    def ndim_spatial(self) -> int:
        """2 for HxW problems, 3 when a depth dimension is present."""
        return 3 if self.is_3d else 2
    @property
    def flops(self) -> float:
        """Total FLOPs for this convolution (any direction, same count)."""
        three_d = self.is_3d
        factors = [self.N, self.K]
        if three_d:
            factors.append(self.Do)
        factors += [self.Ho, self.Wo, self.C // self.G]
        if three_d:
            factors.append(self.Z)
        factors += [self.Y, self.X]
        # 2 FLOPs (multiply + add) per MAC.
        total = 2.0
        for f in factors:
            total *= f
        return total
    @property
    def gflops(self) -> float:
        """FLOP count scaled to giga-FLOPs."""
        return self.flops / 1e9
    def input_shape(self) -> tuple:
        """NHWGC or NDHWGC layout."""
        spatial = (self.Di, self.Hi, self.Wi) if self.is_3d else (self.Hi, self.Wi)
        return (self.N, *spatial, self.G, self.C // self.G)
    def weight_shape(self) -> tuple:
        """GKYXC or GKZYXC layout."""
        filters = (self.Z, self.Y, self.X) if self.is_3d else (self.Y, self.X)
        return (self.G, self.K // self.G, *filters, self.C // self.G)
    def output_shape(self) -> tuple:
        """NHWGK or NDHWGK layout."""
        spatial = (self.Do, self.Ho, self.Wo) if self.is_3d else (self.Ho, self.Wo)
        return (self.N, *spatial, self.G, self.K // self.G)
    def print_problem(self, indent: str = " "):
        """Pretty-print the problem description to stdout."""
        dim_str = "3D" if self.is_3d else "2D"
        lines = [
            f"{indent}GroupedConvProblem ({dim_str} {self.direction}):",
            f"{indent} Batch: N={self.N}, G={self.G}",
            f"{indent} Channels: C={self.C}, K={self.K}",
        ]
        if self.is_3d:
            lines += [
                f"{indent} Input: Di={self.Di}, Hi={self.Hi}, Wi={self.Wi}",
                f"{indent} Filter: Z={self.Z}, Y={self.Y}, X={self.X}",
                f"{indent} Output: Do={self.Do}, Ho={self.Ho}, Wo={self.Wo}",
            ]
        else:
            lines += [
                f"{indent} Input: Hi={self.Hi}, Wi={self.Wi}",
                f"{indent} Filter: Y={self.Y}, X={self.X}",
                f"{indent} Output: Ho={self.Ho}, Wo={self.Wo}",
            ]
        lines.append(f"{indent} GFLOPs: {self.gflops:.2f}")
        for line in lines:
            print(line)
# =============================================================================
# GroupedConvProblemC (ctypes struct matching C++)
# =============================================================================
class GroupedConvProblemC(ctypes.Structure):
    """C structure matching ConvProblemC in conv_ctypes_lib.cpp."""
    # ctypes lays out fields in declaration order; this list must stay in
    # exact sync (order and types) with the C struct definition.
    _fields_ = [
        ("N", ctypes.c_int),
        ("G", ctypes.c_int),
        ("C", ctypes.c_int),
        ("K", ctypes.c_int),
        ("input_d", ctypes.c_int),
        ("input_h", ctypes.c_int),
        ("input_w", ctypes.c_int),
        ("filter_z", ctypes.c_int),
        ("filter_y", ctypes.c_int),
        ("filter_x", ctypes.c_int),
        ("stride_d", ctypes.c_int),
        ("stride_h", ctypes.c_int),
        ("stride_w", ctypes.c_int),
        ("pad_d", ctypes.c_int),
        ("pad_h", ctypes.c_int),
        ("pad_w", ctypes.c_int),
        ("dilation_d", ctypes.c_int),
        ("dilation_h", ctypes.c_int),
        ("dilation_w", ctypes.c_int),
        ("direction", ctypes.c_int),
        ("split_k", ctypes.c_int),
    ]
    @classmethod
    def from_problem(cls, p: GroupedConvProblem) -> "GroupedConvProblemC":
        """Build the C struct from a GroupedConvProblem."""
        c = cls()
        c.N, c.G, c.C, c.K = p.N, p.G, p.C, p.K
        # Python naming (Di/Hi/Wi, Z/Y/X) maps to C input_*/filter_* fields.
        c.input_d, c.input_h, c.input_w = p.Di, p.Hi, p.Wi
        c.filter_z, c.filter_y, c.filter_x = p.Z, p.Y, p.X
        c.stride_d, c.stride_h, c.stride_w = p.stride_d, p.stride_h, p.stride_w
        c.pad_d, c.pad_h, c.pad_w = p.pad_d, p.pad_h, p.pad_w
        c.dilation_d, c.dilation_h, c.dilation_w = (
            p.dilation_d,
            p.dilation_h,
            p.dilation_w,
        )
        # Unknown direction strings encode as forward (0).
        c.direction = DIRECTION_MAP.get(p.direction, 0)
        # getattr tolerates problem objects that lack a split_k attribute.
        c.split_k = getattr(p, "split_k", 1)
        return c
# =============================================================================
# GroupedConvResult
# =============================================================================
@dataclass
class GroupedConvResult:
    """Result of GPU convolution execution."""
    # True when the kernel ran and the output was copied back to host.
    success: bool = False
    # Kernel execution time reported by the dispatcher, in milliseconds.
    time_ms: float = 0.0
    # Achieved throughput derived from the problem FLOP count and time_ms.
    tflops: float = 0.0
    # Host output array (None on failure).
    output: Optional[np.ndarray] = None
    # Human-readable failure reason when success is False.
    error: str = ""
# =============================================================================
# GroupedConvDispatcherLib
# =============================================================================
class GroupedConvDispatcherLib:
    """Wrapper for the compiled convolution dispatcher library.
    Provides Python interface to the C API in conv_ctypes_lib.cpp.
    """
    # Candidate library locations, relative to the package root
    # (the parent of the directory containing this file — see find()).
    SEARCH_PATHS = [
        "build/examples/libdispatcher_conv_lib.so",
        "build/bindings/libdispatcher_conv_lib.so",
        "build/lib/libdispatcher_conv_lib.so",
    ]
    def __init__(self, lib: ctypes.CDLL, path: Path):
        """Wrap an already-loaded CDLL and declare its C function signatures."""
        self._lib = lib
        self._path = path
        self._setup_functions()
    def _setup_functions(self):
        """Declare argtypes/restype for every C entry point used below.
        These must match conv_ctypes_lib.cpp exactly — ctypes performs no
        signature checking against the shared library itself.
        """
        # Lifecycle
        self._lib.conv_dispatcher_init.argtypes = []
        self._lib.conv_dispatcher_init.restype = ctypes.c_int
        self._lib.conv_dispatcher_cleanup.argtypes = []
        self._lib.conv_dispatcher_cleanup.restype = ctypes.c_int
        self._lib.conv_dispatcher_version.argtypes = []
        self._lib.conv_dispatcher_version.restype = ctypes.c_char_p
        # Capability queries
        self._lib.conv_dispatcher_has_kernels.argtypes = []
        self._lib.conv_dispatcher_has_kernels.restype = ctypes.c_int
        self._lib.conv_dispatcher_has_bwd_data.argtypes = []
        self._lib.conv_dispatcher_has_bwd_data.restype = ctypes.c_int
        self._lib.conv_dispatcher_has_bwd_weight.argtypes = []
        self._lib.conv_dispatcher_has_bwd_weight.restype = ctypes.c_int
        # Kernel enumeration
        self._lib.conv_dispatcher_get_kernel_count.argtypes = []
        self._lib.conv_dispatcher_get_kernel_count.restype = ctypes.c_int
        self._lib.conv_dispatcher_get_kernel_name.argtypes = [
            ctypes.c_int,
            ctypes.c_char_p,
            ctypes.c_int,
        ]
        self._lib.conv_dispatcher_get_kernel_name.restype = ctypes.c_int
        # Problem support query and execution
        self._lib.conv_dispatcher_is_supported.argtypes = [
            ctypes.POINTER(GroupedConvProblemC),
        ]
        self._lib.conv_dispatcher_is_supported.restype = ctypes.c_int
        self._lib.conv_dispatcher_run.argtypes = [
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.POINTER(GroupedConvProblemC),
            ctypes.c_void_p,
        ]
        self._lib.conv_dispatcher_run.restype = ctypes.c_float
    @classmethod
    def find(cls) -> Optional["GroupedConvDispatcherLib"]:
        """Search standard paths for the conv library.
        Returns a wrapper around the first loadable candidate, else None.
        """
        root = Path(__file__).parent.parent
        for rel in cls.SEARCH_PATHS:
            path = root / rel
            if path.exists():
                try:
                    lib = ctypes.CDLL(str(path))
                    return cls(lib, path)
                except OSError:
                    # Present but not loadable (e.g. missing deps); try next.
                    continue
        return None
    @property
    def path(self) -> Path:
        """Filesystem path the library was loaded from."""
        return self._path
    def initialize(self):
        """Initialize dispatcher-side state (called before first run)."""
        self._lib.conv_dispatcher_init()
    def cleanup(self):
        """Release dispatcher-side state."""
        self._lib.conv_dispatcher_cleanup()
    def version(self) -> str:
        """Version string reported by the C library."""
        return self._lib.conv_dispatcher_version().decode()
    def has_forward(self) -> bool:
        """True when conv_dispatcher_has_kernels reports nonzero."""
        return self._lib.conv_dispatcher_has_kernels() != 0
    def has_bwd_data(self) -> bool:
        """True when backward-data kernels are reported available."""
        return self._lib.conv_dispatcher_has_bwd_data() != 0
    def has_bwd_weight(self) -> bool:
        """True when backward-weight kernels are reported available."""
        return self._lib.conv_dispatcher_has_bwd_weight() != 0
    def kernel_count(self) -> int:
        """Number of kernels registered in the library."""
        return self._lib.conv_dispatcher_get_kernel_count()
    def kernel_names(self) -> List[str]:
        """Names of registered kernels; entries that fail lookup are skipped."""
        names = []
        for i in range(self.kernel_count()):
            buf = ctypes.create_string_buffer(256)
            # A zero return code signals success; buf holds the name.
            if self._lib.conv_dispatcher_get_kernel_name(i, buf, 256) == 0:
                names.append(buf.value.decode())
        return names
    def is_supported(self, problem: GroupedConvProblem) -> bool:
        """True when the dispatcher reports it can run the given problem."""
        pc = GroupedConvProblemC.from_problem(problem)
        return self._lib.conv_dispatcher_is_supported(ctypes.byref(pc)) != 0
    def run(
        self, a_ptr: int, b_ptr: int, c_ptr: int, problem: GroupedConvProblem
    ) -> float:
        """Run convolution. Returns time_ms (>0 success, <0 error).
        a_ptr/b_ptr/c_ptr are raw device pointers; the problem is converted
        to the C struct for the call. The trailing NULL is an unused slot in
        the C signature (stream/config — confirm against conv_ctypes_lib.cpp).
        """
        pc = GroupedConvProblemC.from_problem(problem)
        return self._lib.conv_dispatcher_run(
            a_ptr, b_ptr, c_ptr, ctypes.byref(pc), None
        )
# =============================================================================
# GpuGroupedConvRunner
# =============================================================================
class GpuGroupedConvRunner:
    """High-level GPU convolution runner.
    Handles library loading, HIP memory management, and kernel execution.
    Follows the same pattern as the old GpuConvRunner from conv_utils.py.
    Usage:
        runner = GpuGroupedConvRunner()
        if runner.is_available():
            result = runner.run(input_np, weight_np, problem)
            print(f"Time: {result.time_ms:.4f} ms, TFLOPS: {result.tflops:.2f}")
    """
    # hipMemcpyKind values (hipMemcpyHostToDevice / hipMemcpyDeviceToHost).
    HIP_MEMCPY_H2D = 1
    HIP_MEMCPY_D2H = 2
    def __init__(self, lib_path: Optional[str] = None):
        """Load the dispatcher library and bind the HIP runtime.
        Args:
            lib_path: Explicit path to libdispatcher_conv_lib.so; when None,
                the standard build locations are searched.
        Never raises: any failure simply leaves the runner unavailable
        (is_available() returns False).
        """
        self._dispatch_lib: Optional["GroupedConvDispatcherLib"] = None
        self._hip = None
        self._initialized = False
        try:
            if lib_path:
                self._dispatch_lib = GroupedConvDispatcherLib(
                    ctypes.CDLL(lib_path), Path(lib_path)
                )
            else:
                self._dispatch_lib = GroupedConvDispatcherLib.find()
            if self._dispatch_lib is None:
                return
            self._hip = self._load_hip_runtime()
            self._dispatch_lib.initialize()
            self._initialized = True
        except Exception:
            # Deliberate best-effort: a missing HIP runtime or broken library
            # leaves the runner unavailable instead of raising at construction.
            self._initialized = False
    @staticmethod
    def _load_hip_runtime() -> ctypes.CDLL:
        """Load libamdhip64 and declare the HIP entry points used by run()."""
        hip = ctypes.CDLL("libamdhip64.so")
        hip.hipMalloc.argtypes = [
            ctypes.POINTER(ctypes.c_void_p),
            ctypes.c_size_t,
        ]
        hip.hipMalloc.restype = ctypes.c_int
        hip.hipFree.argtypes = [ctypes.c_void_p]
        hip.hipFree.restype = ctypes.c_int
        hip.hipMemcpy.argtypes = [
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_size_t,
            ctypes.c_int,
        ]
        hip.hipMemcpy.restype = ctypes.c_int
        hip.hipDeviceSynchronize.argtypes = []
        hip.hipDeviceSynchronize.restype = ctypes.c_int
        return hip
    def is_available(self) -> bool:
        """True when both the dispatcher library and HIP runtime loaded."""
        return self._initialized and self._dispatch_lib is not None
    @property
    def library_path(self) -> Optional[str]:
        """Path of the loaded dispatcher library, or None."""
        if self._dispatch_lib:
            return str(self._dispatch_lib.path)
        return None
    @property
    def lib(self) -> Optional["GroupedConvDispatcherLib"]:
        """The underlying dispatcher library wrapper, or None."""
        return self._dispatch_lib
    @staticmethod
    def _result_shape(problem: "GroupedConvProblem") -> tuple:
        """Shape of the output buffer for the given direction."""
        if problem.direction == "bwd_data":
            return problem.input_shape()  # produces dX
        if problem.direction == "bwd_weight":
            return problem.weight_shape()  # produces dW
        return problem.output_shape()  # forward produces Y
    def _gpu_alloc(self, nbytes: int) -> ctypes.c_void_p:
        """hipMalloc wrapper; raises MemoryError instead of silently
        returning a null pointer on allocation failure."""
        ptr = ctypes.c_void_p()
        status = self._hip.hipMalloc(ctypes.byref(ptr), nbytes)
        if status != 0:
            raise MemoryError(f"hipMalloc({nbytes}) failed with status {status}")
        return ptr
    def run(
        self,
        input_np: np.ndarray,
        weight_np: np.ndarray,
        problem: "GroupedConvProblem",
        output_np: Optional[np.ndarray] = None,
    ) -> "GroupedConvResult":
        """Run convolution on GPU.
        Args:
            input_np: For forward: X (NHWGC). For bwd_data: dY. For bwd_weight: X.
            weight_np: For forward: W (GKYXC). For bwd_data: W. For bwd_weight: dY.
            problem: Problem specification.
            output_np: Optional pre-allocated output buffer.
        Returns:
            GroupedConvResult with success, time_ms, tflops, output.
        """
        if not self.is_available():
            return GroupedConvResult(error="GPU not available")
        try:
            if output_np is None:
                output_np = np.zeros(self._result_shape(problem), dtype=input_np.dtype)
            output_size = output_np.nbytes
            d_a = d_b = d_c = None
            try:
                # Allocate (checked — see _gpu_alloc) and upload inputs.
                d_a = self._gpu_alloc(input_np.nbytes)
                d_b = self._gpu_alloc(weight_np.nbytes)
                d_c = self._gpu_alloc(output_size)
                self._hip.hipMemcpy(
                    d_a, input_np.ctypes.data, input_np.nbytes, self.HIP_MEMCPY_H2D
                )
                self._hip.hipMemcpy(
                    d_b, weight_np.ctypes.data, weight_np.nbytes, self.HIP_MEMCPY_H2D
                )
                self._hip.hipDeviceSynchronize()
                # Launch kernel via the dispatcher.
                time_ms = self._dispatch_lib.run(
                    d_a.value, d_b.value, d_c.value, problem
                )
                self._hip.hipDeviceSynchronize()
                result = GroupedConvResult()
                if time_ms > 0:
                    # Success: copy the output back to host.
                    self._hip.hipMemcpy(
                        output_np.ctypes.data, d_c, output_size, self.HIP_MEMCPY_D2H
                    )
                    self._hip.hipDeviceSynchronize()
                    result.success = True
                    result.time_ms = time_ms
                    result.tflops = problem.flops / (time_ms * 1e9)
                    result.output = output_np
                elif time_ms == -3.0:
                    result.error = "unsupported"
                elif time_ms == -2.0:
                    result.error = "no kernel"
                else:
                    result.error = f"error (code {time_ms})"
                return result
            finally:
                # Always release device memory, even when dispatch raised.
                # (The previous version leaked all three buffers on any
                # exception between hipMalloc and hipFree.)
                for buf in (d_a, d_b, d_c):
                    if buf is not None:
                        self._hip.hipFree(buf)
        except Exception as e:
            return GroupedConvResult(error=str(e))
    def cleanup(self):
        """Release dispatcher-side resources; safe to call repeatedly."""
        if self._dispatch_lib:
            try:
                self._dispatch_lib.cleanup()
            except Exception:
                pass
# =============================================================================
# GroupedConvRegistry
# =============================================================================
class GroupedConvRegistry:
    """Collection of grouped conv kernel configs with JSON export/import.
    Also supports heuristic kernel selection (select) and parallel JIT
    compilation of all held kernels (build).
    """
    def __init__(self, name: str = "default"):
        # name labels the registry and seeds names of filtered registries.
        self.name = name
        self._kernels: List["GroupedConvKernelConfig"] = []
    def add(self, config: "GroupedConvKernelConfig"):
        """Append a kernel config (no de-duplication is performed)."""
        self._kernels.append(config)
    @property
    def kernels(self) -> List["GroupedConvKernelConfig"]:
        """Copy of the kernel list; mutating it does not affect the registry."""
        return list(self._kernels)
    def __len__(self) -> int:
        return len(self._kernels)
    def select(
        self, problem: "GroupedConvProblem", heuristic=None
    ) -> Optional["GroupedConvKernelConfig"]:
        """Select the best kernel for a problem.
        Args:
            problem: The convolution problem.
            heuristic: Optional callable(problem) -> List[str] returning
                       ranked kernel name substrings. The registry tries
                       each in order; falls back to first matching kernel.
        Returns:
            The best matching GroupedConvKernelConfig, or None.
        """
        matching = [k for k in self._kernels if k.variant == problem.direction]
        if not matching:
            return None
        if heuristic is not None:
            for hint in heuristic(problem):
                for k in matching:
                    if hint in k.name:
                        return k
        # matching is guaranteed non-empty here (checked above); the original
        # trailing "if matching else None" guard was unreachable dead code.
        return matching[0]
    def filter_by_variant(self, variant: str) -> "GroupedConvRegistry":
        """New registry containing only kernels of the given variant
        (aliases such as "fwd" are resolved first)."""
        variant = _resolve_variant(variant)
        reg = GroupedConvRegistry(f"{self.name}_{variant}")
        for k in self._kernels:
            if k.variant == variant:
                reg.add(k)
        return reg
    def filter_by_arch(self, arch: str) -> "GroupedConvRegistry":
        """New registry containing only kernels targeting the given arch."""
        reg = GroupedConvRegistry(f"{self.name}_{arch}")
        for k in self._kernels:
            if k.arch == arch:
                reg.add(k)
        return reg
    def to_json(self, indent: int = 2) -> str:
        """Serialize the registry (name + kernels) to a JSON string."""
        return json.dumps(
            {
                "name": self.name,
                "kernels": [k.to_json_obj() for k in self._kernels],
            },
            indent=indent,
        )
    @classmethod
    def from_json(cls, json_str: str) -> "GroupedConvRegistry":
        """Rebuild a registry from a to_json() string.
        Missing fields fall back to GroupedConvKernelConfig defaults.
        """
        data = json.loads(json_str)
        reg = cls(data.get("name", "imported"))
        for kd in data.get("kernels", []):
            sig = kd.get("signature", {})
            algo = kd.get("algorithm", {})
            # wave/warp are packed as "MxNxK" strings by to_json_obj.
            wave = algo.get("wave", "2x2x1").split("x")
            warp = algo.get("warp", "32x32x16").split("x")
            vec = algo.get("vector_sizes", [4, 8, 8])
            reg.add(
                GroupedConvKernelConfig(
                    variant=sig.get("variant", "forward"),
                    ndim_spatial=sig.get("ndim_spatial", 2),
                    dtype=sig.get("dtype", "fp16"),
                    layout=sig.get("layout", "nhwgc"),
                    arch=kd.get("arch", "gfx942"),
                    tile_m=algo.get("tile_m", 1),
                    tile_n=algo.get("tile_n", 128),
                    tile_k=algo.get("tile_k", 128),
                    wave_m=int(wave[0]),
                    wave_n=int(wave[1]),
                    wave_k=int(wave[2]),
                    warp_tile_m=int(warp[0]),
                    warp_tile_n=int(warp[1]),
                    warp_tile_k=int(warp[2]),
                    pipeline=algo.get("pipeline", "compv3"),
                    epilogue=algo.get("epilogue", "cshuffle"),
                    scheduler=algo.get("scheduler", "intrawave"),
                    vector_size_a=vec[0] if len(vec) > 0 else 4,
                    vector_size_b=vec[1] if len(vec) > 1 else 8,
                    vector_size_c=vec[2] if len(vec) > 2 else 8,
                    block_per_cu=algo.get("block_per_cu", 1),
                    num_wave_groups=algo.get("num_wave_groups", 1),
                    num_groups_to_merge=algo.get("num_groups_to_merge", 1),
                )
            )
        return reg
    def build(
        self,
        verbose: bool = False,
        max_workers: Optional[int] = None,
    ) -> Dict[Tuple[str, int], "GpuGroupedConvRunner"]:
        """Parallel JIT compile all kernels in this registry.
        Args:
            verbose: Print progress during build.
            max_workers: Max parallel codegen/compile processes (default: cpu_count capped at 8).
        Returns a dict mapping (variant, ndim_spatial) to a ready-to-use
        GpuGroupedConvRunner.
        """
        if not self._kernels:
            return {}
        libs = setup_multiple_grouped_conv_dispatchers(
            self._kernels,
            verbose=verbose,
            max_workers=max_workers,
        )
        runners: Dict[Tuple[str, int], "GpuGroupedConvRunner"] = {}
        for cfg, lib in zip(self._kernels, libs):
            if lib is None:
                # Build failed for this config; skip it.
                continue
            key = (cfg.variant, cfg.ndim_spatial)
            if key in runners:
                # First successful kernel per (variant, ndim) wins.
                continue
            runner = GpuGroupedConvRunner(lib_path=str(lib.path))
            if runner.is_available():
                runners[key] = runner
        return runners
    def print_registry(self, indent: str = " "):
        """Print every kernel name along with its validation status."""
        print(f"{indent}Registry '{self.name}': {len(self)} kernels")
        for i, k in enumerate(self._kernels):
            print(
                f"{indent} [{i}] {k.name} (valid={validate_grouped_conv_config(k.to_dict()).is_valid})"
            )
# =============================================================================
# GroupedConvValidationResult
# =============================================================================
@dataclass
class GroupedConvValidationResult(ValidationResultBase):
    """Result of grouped conv kernel config validation."""
    # Which conv variant ("forward"/"bwd_data"/"bwd_weight") was validated.
    variant: str = "forward"
    def __init__(
        self,
        is_valid=True,
        errors=None,
        warnings=None,
        suggested_fixes=None,
        variant="forward",
    ):
        # Explicit __init__ (dataclass keeps a user-defined one) so that
        # None defaults are normalized to fresh per-instance containers.
        super().__init__(
            is_valid=is_valid,
            errors=errors or [],
            warnings=warnings or [],
            suggested_fixes=suggested_fixes or {},
        )
        self.variant = variant
# =============================================================================
# Validation helpers (extracted from the original config extraction code)
# =============================================================================
def _first(val):
if isinstance(val, list) and len(val) > 0:
return val[0]
return val
def _get_tile_config(config: dict) -> dict:
return config.get("tile_config") or {}
def _get_trait_config(config: dict) -> dict:
return config.get("trait_config") or {}
def _extract_wave_config(tile_config: dict) -> List[int]:
wm = tile_config.get("wave_m") or tile_config.get("warp_m")
wn = tile_config.get("wave_n") or tile_config.get("warp_n")
wk = tile_config.get("wave_k") or tile_config.get("warp_k")
if wm is not None and wn is not None and wk is not None:
return [_first(wm), _first(wn), _first(wk)]
return [2, 2, 1]
def _extract_warp_tile_config(tile_config: dict) -> List[int]:
wtm = tile_config.get("warp_tile_m") or tile_config.get("warp_m")
wtn = tile_config.get("warp_tile_n") or tile_config.get("warp_n")
wtk = tile_config.get("warp_tile_k") or tile_config.get("warp_k")
if wtm is not None and wtn is not None and wtk is not None:
return [_first(wtm), _first(wtn), _first(wtk)]
return [32, 32, 16]
def _extract_trait_values(trait_config: dict) -> Tuple[str, str, str]:
p = _first(trait_config.get("pipeline", "compv4"))
e = _first(trait_config.get("epilogue", "cshuffle"))
s = _first(trait_config.get("scheduler", "intrawave"))
if isinstance(p, list):
p = p[0] if p else "compv4"
if isinstance(e, list):
e = e[0] if e else "cshuffle"
if isinstance(s, list):
s = s[0] if s else "intrawave"
return (str(p), str(e), str(s))
# =============================================================================
# validate_grouped_conv_config / auto_correct_grouped_conv_config
# =============================================================================
def validate_grouped_conv_config(config: dict) -> GroupedConvValidationResult:
    """Validate a grouped conv kernel config dict.
    Accepts either a raw dict (legacy) or GroupedConvKernelConfig.to_dict() output.
    Checks, in order: required keys, variant, ndim_spatial, trait combo,
    wave config, warp tile config, and architecture. Each failure appends
    an error message and, where possible, a suggested fix.
    Returns:
        GroupedConvValidationResult; is_valid is True iff no errors occurred.
    """
    errors: List[str] = []
    warnings: List[str] = []
    suggested_fixes: Dict[str, Any] = {}
    required = (
        "tile_config",
        "trait_config",
        "variant",
        "ndim_spatial",
        "arch",
        "layout",
    )
    for key in required:
        if key not in config:
            errors.append(f"Missing required key: {key}")
    if errors:
        # Structural problems make the remaining checks meaningless.
        return GroupedConvValidationResult(
            is_valid=False,
            errors=errors,
            warnings=warnings,
            suggested_fixes=suggested_fixes,
            variant=config.get("variant", "forward"),
        )
    tile_config = _get_tile_config(config)
    trait_config = _get_trait_config(config)
    variant = _first(config.get("variant", "forward"))
    if isinstance(variant, list):
        variant = variant[0] if variant else "forward"
    variant = _resolve_variant(str(variant))
    ndim_spatial = config.get("ndim_spatial")
    arch = config.get("arch", "gfx942")
    dtype = config.get("dtype", "fp16")
    # Fetched once and reused below; the original called
    # get_arch_filter_data() up to three times per validation.
    # NOTE(review): assumes the helper returns the same tables each call.
    arch_data = get_arch_filter_data()
    if variant not in VALID_VARIANTS:
        errors.append(f"Invalid variant: {variant}. Valid: {', '.join(VALID_VARIANTS)}")
        suggested_fixes["variant"] = "forward"
    if ndim_spatial is not None:
        ndim = ndim_spatial
        if isinstance(ndim, list):
            ndim = ndim[0] if ndim else 2
        if ndim not in VALID_NDIM_SPATIAL:
            errors.append(
                f"Invalid ndim_spatial: {ndim}. Valid: {', '.join(map(str, VALID_NDIM_SPATIAL))}"
            )
            suggested_fixes["ndim_spatial"] = 2
    pipeline, epilogue, scheduler = _extract_trait_values(trait_config)
    if variant in BACKWARD_VARIANTS and pipeline not in BACKWARD_PIPELINES:
        errors.append(
            f"Backward variant '{variant}' requires pipeline compv3 or mem, got {pipeline}"
        )
        suggested_fixes["pipeline"] = "compv3"
    ok, msg = validate_trait_combo(pipeline, epilogue, scheduler)
    if not ok:
        errors.append(msg)
        suggested_fixes["scheduler"] = "intrawave"
    wave_cfg = _extract_wave_config(tile_config)
    ok, msg = validate_wave_config(wave_cfg, arch)
    if not ok:
        errors.append(msg)
        valid_waves = arch_data["warp_combos"].get(arch, [[2, 2, 1]])
        if valid_waves:
            suggested_fixes["wave_m"] = valid_waves[0][0]
            suggested_fixes["wave_n"] = valid_waves[0][1]
            suggested_fixes["wave_k"] = valid_waves[0][2]
    warp_cfg = _extract_warp_tile_config(tile_config)
    ok, msg = validate_warp_tile_config(warp_cfg, arch, dtype)
    if not ok:
        errors.append(msg)
        # Accumulator type follows the element type (int8 accumulates in int32).
        acc = "int32" if dtype == "int8" else "fp32"
        dtype_key = f"{dtype}_{dtype}_{acc}"
        valid_tiles = (
            arch_data["warp_tile_combos"]
            .get(arch, {})
            .get(dtype_key, [[32, 32, 16], [16, 16, 16]])
        )
        if valid_tiles:
            suggested_fixes["warp_tile_m"] = valid_tiles[0][0]
            suggested_fixes["warp_tile_n"] = valid_tiles[0][1]
            suggested_fixes["warp_tile_k"] = valid_tiles[0][2]
    if arch not in arch_data["supported_archs"]:
        errors.append(
            f"Unsupported architecture: {arch}. Supported: {', '.join(arch_data['supported_archs'])}"
        )
    return GroupedConvValidationResult(
        is_valid=len(errors) == 0,
        errors=errors,
        warnings=warnings,
        suggested_fixes=suggested_fixes,
        variant=variant,
    )
def auto_correct_grouped_conv_config(
    config: dict,
) -> Tuple[dict, GroupedConvValidationResult]:
    """Auto-correct invalid grouped conv config. Returns (corrected, result).
    The input dict is never mutated: corrections are applied to a deep copy,
    which is re-validated so the returned result reflects the corrected
    config. A valid input is returned (copied) unchanged.
    """
    result = validate_grouped_conv_config(config)
    corrected = copy.deepcopy(config)
    if result.is_valid:
        return corrected, result
    tile_config = corrected.setdefault("tile_config", {})
    trait_config = corrected.setdefault("trait_config", {})
    # Snap the wave configuration to one supported by the target arch.
    wave_cfg = _extract_wave_config(tile_config)
    arch = config.get("arch", "gfx942")
    fixed_wave = auto_correct_wave(wave_cfg, arch)
    tile_config["wave_m"] = fixed_wave[0]
    tile_config["wave_n"] = fixed_wave[1]
    tile_config["wave_k"] = fixed_wave[2]
    # Fix pipeline/scheduler combinations (epilogue needs no correction here).
    pipeline, _, scheduler = _extract_trait_values(trait_config)
    fixed_pipeline, fixed_scheduler = auto_correct_trait(pipeline, scheduler)
    trait_config["pipeline"] = fixed_pipeline
    trait_config["scheduler"] = fixed_scheduler
    # Backward variants additionally require a backward-capable pipeline.
    variant = _first(config.get("variant", "forward"))
    if isinstance(variant, list):
        variant = variant[0] if variant else "forward"
    variant = _resolve_variant(str(variant))
    if variant in BACKWARD_VARIANTS and fixed_pipeline not in BACKWARD_PIPELINES:
        trait_config["pipeline"] = "compv3"
    # Apply warp-tile fixes suggested by the validator, if any.
    if "warp_tile_m" in result.suggested_fixes:
        tile_config["warp_tile_m"] = result.suggested_fixes["warp_tile_m"]
        tile_config["warp_tile_n"] = result.suggested_fixes["warp_tile_n"]
        tile_config["warp_tile_k"] = result.suggested_fixes["warp_tile_k"]
    # Re-validate so the caller sees the status of the corrected config.
    result = validate_grouped_conv_config(corrected)
    return corrected, result
def _run_hipcc_subprocess(args: dict) -> Tuple[bool, Optional[Path], str]:
"""Run one hipcc compile+link job in a subprocess worker."""
import subprocess
from pathlib import Path
compile_cmd = args["compile_cmd"]
link_cmd = args["link_cmd"]
lib_path = Path(args["lib_path"])
try:
res_c = subprocess.run(compile_cmd, capture_output=True, text=True, timeout=300)
if res_c.returncode != 0:
return False, None, f"Compile failed: {res_c.stderr[:400]}"
res_l = subprocess.run(link_cmd, capture_output=True, text=True, timeout=300)
if res_l.returncode != 0:
return False, None, f"Link failed: {res_l.stderr[:400]}"
return True, lib_path, ""
except subprocess.TimeoutExpired:
return False, None, "Timeout"
except Exception as e:
return False, None, f"Error: {e}"
def _run_conv_codegen_subprocess(args: dict) -> Tuple[bool, Optional[str], str]:
"""Run grouped-conv codegen once and return generated kernel header path."""
import subprocess
from pathlib import Path
out_dir = Path(args["output_dir"])
out_dir.mkdir(parents=True, exist_ok=True)
# Remove stale kernels so header discovery is exact for this invocation.
for stale in out_dir.glob("grouped_conv_*.hpp"):
stale.unlink(missing_ok=True)
for stale in out_dir.glob("include_all_grouped_conv_*.hpp"):
stale.unlink(missing_ok=True)
try:
res = subprocess.run(args["cmd"], capture_output=True, text=True, timeout=300)
if res.returncode != 0:
err = (res.stderr or res.stdout or "").strip()[:500]
return False, None, f"Codegen failed: {err}"
generated = sorted(
out_dir.glob("grouped_conv_*.hpp"),
key=lambda p: p.stat().st_mtime,
reverse=True,
)
if not generated:
return False, None, "Codegen produced no grouped_conv_*.hpp header"
return True, str(generated[0]), ""
except subprocess.TimeoutExpired:
return False, None, "Codegen timed out"
except Exception as e:
return False, None, f"Codegen error: {e}"
def _config_key(c: GroupedConvKernelConfig) -> Tuple[Any, ...]:
    """Return a hashable identity tuple for a config (used for dedup)."""
    field_names = (
        "variant",
        "ndim_spatial",
        "dtype",
        "layout",
        "arch",
        "tile_m",
        "tile_n",
        "tile_k",
        "wave_m",
        "wave_n",
        "wave_k",
        "warp_tile_m",
        "warp_tile_n",
        "warp_tile_k",
        "pipeline",
        "epilogue",
        "scheduler",
    )
    return tuple(getattr(c, name) for name in field_names)
def _parse_triplet(value: str) -> Tuple[int, int, int]:
parts = value.split("x")
if len(parts) != 3:
raise ValueError(f"Invalid triplet: {value}")
return int(parts[0]), int(parts[1]), int(parts[2])
def _list_arch_valid_grouped_conv_configs(
    codegen_script: Path,
    arch: str,
    dtype: str,
    variant: str,
    ndim_spatial: int,
) -> List[GroupedConvKernelConfig]:
    """Query codegen defaults for this (arch, dtype, variant, ndim) tuple."""
    import re
    import sys
    cmd = [
        sys.executable,
        str(codegen_script),
        "--list-configs",
        "--arch", arch,
        "--datatype", dtype,
        "--variant", variant,
        "--ndim", str(ndim_spatial),
    ]
    proc = subprocess.run(cmd, capture_output=True, text=True, timeout=180)
    if proc.returncode != 0:
        # Treat any codegen failure as "no valid configs".
        return []
    # Kernel names look like:
    # grouped_conv_fwd_fp16_nhwgc_2d_compv3_cshuffle_intrawave_128x128x32_2x2x1_32x32x16
    name_re = re.compile(
        r"^grouped_conv_(fwd|bwd_data|bwd_weight|bwdd|bwdw)_([a-z0-9]+)_([a-z0-9]+)_([123])d_"
        r"([a-z0-9]+)_([a-z0-9]+)_([a-z0-9]+)_"
        r"([0-9]+x[0-9]+x[0-9]+)_([0-9]+x[0-9]+x[0-9]+)_([0-9]+x[0-9]+x[0-9]+)"
        r"(?:_.*)?$"
    )
    variant_by_short = {
        "fwd": "forward",
        "bwd_data": "bwd_data",
        "bwd_weight": "bwd_weight",
        "bwdd": "bwd_data",
        "bwdw": "bwd_weight",
    }
    configs: List[GroupedConvKernelConfig] = []
    seen_keys = set()
    for raw_line in proc.stdout.splitlines():
        entry = raw_line.strip()
        # Listings are emitted as "- grouped_conv_..." bullet lines.
        if not entry.startswith("- grouped_conv_"):
            continue
        match = name_re.match(entry[2:].strip())
        if match is None:
            continue
        (v_short, dt, layout, ndim, pipe, epi, sched,
         tile_s, wave_s, warp_s) = match.groups()
        tile = _parse_triplet(tile_s)
        wave = _parse_triplet(wave_s)
        warp = _parse_triplet(warp_s)
        cfg = GroupedConvKernelConfig(
            variant=variant_by_short[v_short],
            ndim_spatial=int(ndim),
            dtype=dt,
            layout=layout,
            arch=arch,
            tile_m=tile[0],
            tile_n=tile[1],
            tile_k=tile[2],
            wave_m=wave[0],
            wave_n=wave[1],
            wave_k=wave[2],
            warp_tile_m=warp[0],
            warp_tile_n=warp[1],
            warp_tile_k=warp[2],
            pipeline=pipe,
            epilogue=epi,
            scheduler=sched,
        )
        # De-duplicate by full identity key, preserving first occurrence.
        key = _config_key(cfg)
        if key in seen_keys:
            continue
        seen_keys.add(key)
        configs.append(cfg)
    return configs
def _select_best_arch_valid_conv_config(
    requested: GroupedConvKernelConfig,
    candidates: List[GroupedConvKernelConfig],
) -> GroupedConvKernelConfig:
    """Pick nearest arch-valid config while preferring trait exact matches."""

    def manhattan(lhs, rhs):
        # L1 distance between two same-length numeric triples.
        return sum(abs(a - b) for a, b in zip(lhs, rhs))

    def score(cand):
        # Lexicographic rank: trait mismatches dominate, then tile, wave,
        # and warp-tile distances break ties in that order.
        return (
            int(cand.pipeline != requested.pipeline),
            int(cand.scheduler != requested.scheduler),
            int(cand.epilogue != requested.epilogue),
            manhattan(
                (cand.tile_m, cand.tile_n, cand.tile_k),
                (requested.tile_m, requested.tile_n, requested.tile_k),
            ),
            manhattan(
                (cand.wave_m, cand.wave_n, cand.wave_k),
                (requested.wave_m, requested.wave_n, requested.wave_k),
            ),
            manhattan(
                (cand.warp_tile_m, cand.warp_tile_n, cand.warp_tile_k),
                (requested.warp_tile_m, requested.warp_tile_n, requested.warp_tile_k),
            ),
        )

    selected = copy.deepcopy(min(candidates, key=score))
    # Keep the caller's requested architecture on the returned copy.
    selected.arch = requested.arch
    return selected
def _write_single_conv_dispatch_header(
    config: GroupedConvKernelConfig,
    kernel_header: Path,
    dispatch_header: Path,
) -> None:
    """Create a tiny dispatch header consumed by conv_ctypes_lib.cpp."""
    # Pick the kernel-name symbol, macro stem, and 3D launcher alias for
    # this variant; anything not forward/bwd_data is treated as bwd_weight.
    if config.variant == "forward":
        kernel_name_symbol = "CONV_FWD_KERNEL_NAME"
        macro_stem = "CONV_FWD"
        alias_3d = "using ConvFwd3dLauncher = SelectedConvKernelLauncher;"
    elif config.variant == "bwd_data":
        kernel_name_symbol = "CONV_BWD_DATA_KERNEL_NAME"
        macro_stem = "CONV_BWD_DATA"
        alias_3d = "using ConvBwdData3dLauncher = SelectedConvBwdDataLauncher;"
    else:
        kernel_name_symbol = "CONV_BWD_WEIGHT_KERNEL_NAME"
        macro_stem = "CONV_BWD_WEIGHT"
        alias_3d = "using ConvBwdWeight3dLauncher = SelectedConvBwdWeightLauncher;"
    # Only the 3D variants get a launcher alias; 2D just sets the macro.
    if config.ndim_spatial == 3:
        macros = [f"#define {macro_stem}_3D_AVAILABLE 1"]
        aliases = [alias_3d]
    else:
        macros = [f"#define {macro_stem}_2D_AVAILABLE 1"]
        aliases = []
    header_lines = [
        "// Auto-generated single-kernel dispatch header for Python JIT",
        "#pragma once",
        "",
        f'#include "{kernel_header.name}"',
        "",
        "\n".join(macros),
        "",
        "\n".join(aliases),
        "",
        f"static const char* CONV_KERNEL_NAMES[] = {{{kernel_name_symbol}}};",
        "static constexpr int CONV_KERNEL_COUNT = 1;",
    ]
    dispatch_header.write_text("\n".join(header_lines) + "\n")
class GroupedConvCodegenRunner:
    """Generate and compile grouped-conv JIT libraries in parallel.

    Runs the unified grouped-conv codegen script once per kernel config,
    then compiles each generated kernel plus the ctypes bindings into its
    own shared library. Both phases fan out over a process pool.
    """
    def __init__(self, max_workers: Optional[int] = None):
        import multiprocessing
        # Default worker count: one per CPU, capped at 8.
        self.max_workers = max_workers or min(multiprocessing.cpu_count(), 8)
        # Repository layout is resolved relative to this file's parent package.
        self.root = Path(__file__).parent.parent
        self.build_dir = self.root / "build"
        self.codegen_script = self.root / "codegen" / "unified_grouped_conv_codegen.py"
    def generate_and_compile_parallel(
        self,
        configs: List[GroupedConvKernelConfig],
        verbose: bool = True,
    ) -> List[Optional[Path]]:
        """Generate and compile one shared library per config.

        Returns a list aligned index-for-index with ``configs``: the built
        ``.so`` path for each config, or ``None`` where codegen or
        compilation failed.
        """
        import sys
        from concurrent.futures import ProcessPoolExecutor, as_completed
        if not configs:
            return []
        # Ensure the build tree and JIT staging directories exist.
        if not self.build_dir.exists():
            self.build_dir.mkdir(parents=True, exist_ok=True)
        ctypes_source = self.root / "bindings" / "ctypes" / "conv_ctypes_lib.cpp"
        static_lib = self.build_dir / "libck_tile_dispatcher.a"
        jit_root = self.build_dir / "generated_kernels" / "python_jit"
        jit_root.mkdir(parents=True, exist_ok=True)
        (self.build_dir / "examples").mkdir(parents=True, exist_ok=True)
        # Required inputs missing: fail every config up front.
        if not self.codegen_script.exists():
            if verbose:
                print(f"Codegen script missing: {self.codegen_script}")
            return [None] * len(configs)
        if not ctypes_source.exists() or not static_lib.exists():
            if verbose:
                print("Missing conv ctypes source or static dispatcher library")
            return [None] * len(configs)
        if verbose:
            print(
                f"Generating {len(configs)} grouped-conv kernels in parallel "
                f"(workers={self.max_workers})..."
            )
        # Phase 1: build one codegen command per config, each with its own
        # output directory (cfg_<i>) so header discovery cannot collide.
        gen_jobs: List[Dict[str, Any]] = []
        job_dirs: List[Path] = []
        for i, c in enumerate(configs):
            cfg_dir = jit_root / f"cfg_{i}"
            cfg_dir.mkdir(parents=True, exist_ok=True)
            job_dirs.append(cfg_dir)
            # NOTE: the codegen CLI's --warp-* flags take the wave_* values,
            # while --warp-tile-* take warp_tile_* (naming differs between
            # the CLI and the config fields).
            cmd = [
                sys.executable,
                str(self.codegen_script),
                "--output",
                str(cfg_dir),
                "--datatype",
                c.dtype,
                "--variant",
                c.variant,
                "--ndim",
                str(c.ndim_spatial),
                "--arch",
                c.arch,
                "--tile-m",
                str(c.tile_m),
                "--tile-n",
                str(c.tile_n),
                "--tile-k",
                str(c.tile_k),
                "--warp-m",
                str(c.wave_m),
                "--warp-n",
                str(c.wave_n),
                "--warp-k",
                str(c.wave_k),
                "--warp-tile-m",
                str(c.warp_tile_m),
                "--warp-tile-n",
                str(c.warp_tile_n),
                "--warp-tile-k",
                str(c.warp_tile_k),
                "--pipeline",
                c.pipeline,
                "--scheduler",
                c.scheduler,
                "--epilogue",
                c.epilogue,
            ]
            gen_jobs.append({"cmd": cmd, "output_dir": str(cfg_dir)})
        # Phase 2: run codegen jobs in parallel; collect headers by index.
        generated_headers: List[Optional[Path]] = [None] * len(configs)
        with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
            futures = {
                executor.submit(_run_conv_codegen_subprocess, job): idx
                for idx, job in enumerate(gen_jobs)
            }
            for future in as_completed(futures):
                idx = futures[future]
                ok, header_path, err = future.result()
                if ok and header_path:
                    generated_headers[idx] = Path(header_path)
                    if verbose:
                        print(f"  OK [{idx}] codegen: {Path(header_path).name}")
                else:
                    if verbose:
                        print(f"  FAIL [{idx}] codegen: {err}")
        if verbose:
            compile_count = sum(1 for h in generated_headers if h is not None)
            print(
                f"Compiling {compile_count} grouped-conv libraries in parallel "
                f"(workers={self.max_workers})..."
            )
        # Phase 3: build compile+link jobs for configs whose codegen succeeded.
        compile_jobs: List[Dict[str, Any]] = []
        # Maps compile-job position -> original config index (codegen failures
        # leave gaps, so the two indexings differ).
        compile_to_input_index: Dict[int, int] = {}
        for i, c in enumerate(configs):
            hdr_path = generated_headers[i]
            if hdr_path is None:
                continue
            cfg_dir = job_dirs[i]
            # Single-kernel dispatch header force-included into the bindings.
            dispatch_header = cfg_dir / "conv_python_dispatch.hpp"
            _write_single_conv_dispatch_header(c, hdr_path, dispatch_header)
            # Library name encodes the full kernel identity for reuse.
            lib_name = (
                f"libdispatcher_conv_{c.variant}_{c.ndim_spatial}d_{c.dtype}_"
                f"{c.tile_str}_{c.wave_str}_{c.warp_str}_{c.pipeline}_{c.scheduler}.so"
            )
            lib_path = self.build_dir / "examples" / lib_name
            obj_file = lib_path.with_suffix(".o")
            # hipcc path is hard-coded to the default ROCm install location.
            compile_cmd = [
                "/opt/rocm/bin/hipcc",
                "-c",
                "-fPIC",
                "-O3",
                f"-I{self.root / 'include'}",
                f"-I{self.root.parent / 'include'}",
                f"-I{self.root.parent}",
                f"-I{cfg_dir}",
                "-DCK_TILE_SINGLE_KERNEL_INCLUDE",
                f"-include{dispatch_header}",
                "-D__HIP_PLATFORM_AMD__",
                f"--offload-arch={c.arch}",
                f'-DGFX_ARCH="{c.arch}"',
                "-mllvm",
                "-enable-noalias-to-md-conversion=0",
                "-Wno-undefined-func-template",
                "-Wno-float-equal",
                str(ctypes_source),
                "-o",
                str(obj_file),
            ]
            # Link the object with the prebuilt static dispatcher library.
            link_cmd = [
                "/opt/rocm/bin/hipcc",
                "-shared",
                "-fPIC",
                f"--offload-arch={c.arch}",
                "--hip-link",
                str(obj_file),
                str(static_lib),
                "-o",
                str(lib_path),
            ]
            compile_to_input_index[len(compile_jobs)] = i
            compile_jobs.append(
                {
                    "compile_cmd": compile_cmd,
                    "link_cmd": link_cmd,
                    "lib_path": str(lib_path),
                    "config_name": c.name,
                }
            )
        # Phase 4: run compile+link jobs in parallel and map results back to
        # the original config indices.
        results_map: Dict[int, Optional[Path]] = {i: None for i in range(len(configs))}
        with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
            futures = {
                executor.submit(_run_hipcc_subprocess, job): j
                for j, job in enumerate(compile_jobs)
            }
            for future in as_completed(futures):
                job_idx = futures[future]
                idx = compile_to_input_index[job_idx]
                success, lib_path, err = future.result()
                if success and lib_path:
                    results_map[idx] = Path(lib_path)
                if verbose:
                    status = "OK" if success else f"FAIL ({err})"
                    name = (
                        Path(lib_path).name
                        if success and lib_path
                        else compile_jobs[job_idx]["config_name"]
                    )
                    print(f"  {status} {name}")
        return [results_map.get(i) for i in range(len(configs))]
# =============================================================================
# Convenience functions
# =============================================================================
def get_grouped_conv_default_config(
    variant: str = "forward",
    ndim_spatial: int = 2,
    arch: str = "gfx942",
    dtype: str = "fp16",
) -> GroupedConvKernelConfig:
    """Return a valid default GroupedConvKernelConfig."""
    # Remaining tile/wave/trait fields use GroupedConvKernelConfig defaults.
    overrides = {
        "variant": variant,
        "ndim_spatial": ndim_spatial,
        "arch": arch,
        "dtype": dtype,
    }
    return GroupedConvKernelConfig(**overrides)
def format_grouped_conv_summary(config) -> str:
    """Format a config (dict or GroupedConvKernelConfig) into a human-readable string.

    Accepts a GroupedConvKernelConfig instance or a legacy dict (with
    optional "tile_config"/"trait_config" sub-dicts). Any other input is
    rendered with "?" placeholders for every field.
    """
    if isinstance(config, GroupedConvKernelConfig):
        return "\n".join(
            [
                f"Grouped Conv Config: {config.variant} {config.ndim_spatial}D",
                f" Arch: {config.arch}",
                f" Layout: {config.layout}",
                f" Dtype: {config.dtype}",
                f" Tile: {config.tile_str}",
                f" Wave: {config.wave_str}",
                f" Warp: {config.warp_str}",
                f" Traits: pipeline={config.pipeline} epilogue={config.epilogue} scheduler={config.scheduler}",
            ]
        )
    # Legacy dict support; non-dict inputs degrade to placeholders.
    is_dict = isinstance(config, dict)
    tile_config = _get_tile_config(config) if is_dict else {}
    trait_config = _get_trait_config(config) if is_dict else {}

    def field(key, default):
        # Read a top-level field from the dict, else fall back to default.
        return config.get(key, default) if is_dict else default

    lines = [
        f"Grouped Conv Config: {field('variant', '?')} {field('ndim_spatial', '?')}D"
    ]
    lines.append(f" Arch: {field('arch', '?')}")
    lines.append(f" Layout: {field('layout', '?')}")
    lines.append(f" Dtype: {field('dtype', 'fp16')}")
    if tile_config:
        wave = _extract_wave_config(tile_config)
        warp = _extract_warp_tile_config(tile_config)
        # NOTE(review): tile_m falls back to 1 while tile_n/tile_k fall back
        # to 128 — looks inconsistent; confirm the intended tile_m default.
        lines.append(
            f" Tile: M={_first(tile_config.get('tile_m', 1))} N={_first(tile_config.get('tile_n', 128))} K={_first(tile_config.get('tile_k', 128))}"
        )
        lines.append(f" Wave: {wave[0]}x{wave[1]}x{wave[2]}")
        lines.append(f" Warp: {warp[0]}x{warp[1]}x{warp[2]}")
    if trait_config:
        pipeline = _first(trait_config.get("pipeline", "?"))
        epilogue = _first(trait_config.get("epilogue", "?"))
        scheduler = _first(trait_config.get("scheduler", "?"))
        lines.append(
            f" Traits: pipeline={pipeline} epilogue={epilogue} scheduler={scheduler}"
        )
    # `lines` always holds at least the header line, so the previous
    # `if lines else "(empty config)"` fallback was unreachable; removed.
    return "\n".join(lines)
def setup_multiple_grouped_conv_dispatchers(
    configs: List[GroupedConvKernelConfig],
    verbose: bool = True,
    max_workers: Optional[int] = None,
) -> List[Optional[GroupedConvDispatcherLib]]:
    """
    Setup multiple grouped-conv dispatchers in parallel.
    This keeps architecture filtering strict:
    1. Validate + auto-correct each requested config
    2. Query codegen's arch-valid config set for each (arch, dtype, variant, ndim)
    3. Map each request to nearest valid config
    4. Parallel codegen + parallel compile

    Returns a list aligned index-for-index with ``configs``; entries are
    ``None`` where validation, codegen, compilation, or library loading failed.
    """
    if not configs:
        return []
    codegen_script = (
        Path(__file__).parent.parent / "codegen" / "unified_grouped_conv_codegen.py"
    )
    # Cache of codegen's arch-valid config listings, keyed by
    # (arch, dtype, variant, ndim_spatial) to avoid repeated subprocess calls.
    arch_valid_cache: Dict[
        Tuple[str, str, str, int], List[GroupedConvKernelConfig]
    ] = {}
    selected_configs: List[Optional[GroupedConvKernelConfig]] = []
    for i, original in enumerate(configs):
        # Work on a copy so the caller's config objects are never mutated.
        c = copy.deepcopy(original)
        val = validate_grouped_conv_config(c.to_dict())
        if not val.is_valid:
            corrected, corrected_result = auto_correct_grouped_conv_config(c.to_dict())
            if not corrected_result.is_valid:
                if verbose:
                    print(f"  FAIL [{i}] config remains invalid after auto-correct")
                selected_configs.append(None)
                continue
            # Copy the auto-corrected dict fields back onto the config object.
            # _first() unwraps values that may be stored as lists.
            tile_cfg = corrected.get("tile_config", {})
            trait_cfg = corrected.get("trait_config", {})
            c.variant = _resolve_variant(
                str(_first(corrected.get("variant", c.variant)))
            )
            c.ndim_spatial = int(_first(corrected.get("ndim_spatial", c.ndim_spatial)))
            c.arch = str(corrected.get("arch", c.arch))
            c.layout = str(corrected.get("layout", c.layout))
            c.dtype = str(corrected.get("dtype", c.dtype))
            c.tile_m = int(_first(tile_cfg.get("tile_m", c.tile_m)))
            c.tile_n = int(_first(tile_cfg.get("tile_n", c.tile_n)))
            c.tile_k = int(_first(tile_cfg.get("tile_k", c.tile_k)))
            c.wave_m = int(_first(tile_cfg.get("wave_m", c.wave_m)))
            c.wave_n = int(_first(tile_cfg.get("wave_n", c.wave_n)))
            c.wave_k = int(_first(tile_cfg.get("wave_k", c.wave_k)))
            c.warp_tile_m = int(_first(tile_cfg.get("warp_tile_m", c.warp_tile_m)))
            c.warp_tile_n = int(_first(tile_cfg.get("warp_tile_n", c.warp_tile_n)))
            c.warp_tile_k = int(_first(tile_cfg.get("warp_tile_k", c.warp_tile_k)))
            c.pipeline = str(_first(trait_cfg.get("pipeline", c.pipeline)))
            c.scheduler = str(_first(trait_cfg.get("scheduler", c.scheduler)))
            c.epilogue = str(_first(trait_cfg.get("epilogue", c.epilogue)))
        # Fetch (and cache) codegen's arch-valid configs for this tuple.
        cache_key = (c.arch, c.dtype, c.variant, c.ndim_spatial)
        if cache_key not in arch_valid_cache:
            arch_valid_cache[cache_key] = _list_arch_valid_grouped_conv_configs(
                codegen_script=codegen_script,
                arch=c.arch,
                dtype=c.dtype,
                variant=c.variant,
                ndim_spatial=c.ndim_spatial,
            )
            if verbose and not arch_valid_cache[cache_key]:
                print(
                    f"  FAIL [{i}] no arch-valid configs listed for "
                    f"{c.arch}/{c.dtype}/{c.variant}/{c.ndim_spatial}d"
                )
        candidates = arch_valid_cache[cache_key]
        if not candidates:
            selected_configs.append(None)
            continue
        # Snap the request to the nearest arch-valid candidate.
        selected = _select_best_arch_valid_conv_config(c, candidates)
        if verbose and _config_key(selected) != _config_key(c):
            print(
                f"  INFO [{i}] mapped to arch-valid config: "
                f"{selected.tile_str} {selected.wave_str} {selected.warp_str} "
                f"{selected.pipeline}/{selected.scheduler}/{selected.epilogue}"
            )
        selected_configs.append(selected)
    # De-duplicate identical selections so each unique kernel is generated
    # and compiled only once; input_to_unique maps request -> unique slot.
    unique_configs: List[GroupedConvKernelConfig] = []
    unique_index_by_key: Dict[Tuple[Any, ...], int] = {}
    input_to_unique: List[Optional[int]] = []
    for cfg in selected_configs:
        if cfg is None:
            input_to_unique.append(None)
            continue
        key = _config_key(cfg)
        if key not in unique_index_by_key:
            unique_index_by_key[key] = len(unique_configs)
            unique_configs.append(cfg)
        input_to_unique.append(unique_index_by_key[key])
    # Parallel codegen + compile of the unique configs.
    runner = GroupedConvCodegenRunner(max_workers=max_workers)
    unique_lib_paths = runner.generate_and_compile_parallel(
        unique_configs, verbose=verbose
    )
    # Load each built library once (loaded_cache) and fan results back out
    # to the original request order.
    libs: List[Optional[GroupedConvDispatcherLib]] = []
    loaded_cache: Dict[int, Optional[GroupedConvDispatcherLib]] = {}
    for input_idx, unique_idx in enumerate(input_to_unique):
        if unique_idx is None:
            libs.append(None)
            continue
        if unique_idx in loaded_cache:
            libs.append(loaded_cache[unique_idx])
            continue
        path = (
            unique_lib_paths[unique_idx] if unique_idx < len(unique_lib_paths) else None
        )
        disp: Optional[GroupedConvDispatcherLib] = None
        if path and path.exists():
            try:
                lib = ctypes.CDLL(str(path))
                disp = GroupedConvDispatcherLib(lib, path)
                disp.initialize()
            except Exception as e:
                if verbose:
                    print(f"  FAIL [{input_idx}] failed to load {path}: {e}")
        loaded_cache[unique_idx] = disp
        libs.append(disp)
    return libs
def detect_gpu_arch() -> str:
    """Detect GPU architecture using rocminfo."""
    # Best-effort: any failure (rocminfo absent, parse miss) falls back
    # to gfx942.
    try:
        report = subprocess.check_output(
            ["rocminfo"], stderr=subprocess.DEVNULL, text=True
        )
    except Exception:
        return "gfx942"
    for line in report.split("\n"):
        lowered = line.lower()
        if "gfx" in lowered and "name:" in lowered:
            for token in line.split():
                if token.startswith("gfx"):
                    return token
    return "gfx942"