mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327) ## Motivation The ML heuristic in dispatcher does not support grouped-conv operator yet. In this PR, the support for fwd, bwd-data, and bwd-weight grouped-conv kernels has been added. A tile_engine utility has also been added to compile and run any selected kernel configuration through dispatcher infrastructure. ## Technical Details 1. Tile engine utility is added to benchmark each shape with all the possible kernel+tile_size combinations here - [https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py](url) 2. New LGBM regressor models for grouped conv are added to models directory. We have 3 separate models for fwd, bwd-data, and bwd-weights [https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models](url) 3. Implemented lazy GPU initialization (dispatcher/python) - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults - **Solution**: Mirror FMHA pattern - defer GPU initialization until first run() - **Changes**: - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL - Added _ensure_initialized() method for lazy GPU loading - GPU context created only on first run() call - **Benefit**: Parallel compilation now works without GPU conflicts 4. 
Addressed a few miscellaneous issues such as: - Fixed BF16->FP16 naming bug in the dispatcher wrapper - Added new tile sizes, and comp_v5 pipeline to the arch spec to expand the kernel selection - Added automatic padding support for unsupported shapes in dispatcher runner - Created a single source of truth between tile_engine and dispatcher about the architecture and tile_size details - Built validation scripts to compare oracle_best vs ml_heuristic results ## Test Plan 1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets with up to 300 problems. 2. Ensured that test cases are added in both dispatcher and tile_engine to validate the heuristic. ## Test Result Results on Unseen shapes validated on gfx950 #### Forward Pass Model - **Training Data**: 48,845 measurements across 1,372 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.05%** - Median Efficiency: **96.8%** - P10 Efficiency: **79.9%** #### Backward Data Gradient (bwd_data) Model - **Training Data**: 18,773 measurements across 891 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.8%** - Median Efficiency: **96.5%** - P10 Efficiency: **82.9%** #### Backward Weight Gradient (bwd_weight) Model - **Training Data**: 34,900 measurements across 1,508 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **96.1%** - Median Efficiency: **99.2%** - P10 Efficiency: **89.4%** ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
1952 lines
68 KiB
Python
1952 lines
68 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Grouped Convolution Dispatcher Utilities
|
|
|
|
Typed Python API for grouped convolution kernels, matching the patterns from
|
|
the old conv_utils.py and the GEMM ctypes_utils.py.
|
|
|
|
Classes:
|
|
GroupedConvKernelConfig - Kernel configuration (tile, wave, pipeline, arch)
|
|
GroupedConvProblem - Runtime problem specification (N,C,K,H,W,etc.)
|
|
GroupedConvProblemC - ctypes struct matching C++ ConvProblemC
|
|
GroupedConvDispatcherLib - Wrapper for libdispatcher_conv_lib.so
|
|
GpuGroupedConvRunner - High-level GPU execution runner
|
|
GroupedConvResult - Result of GPU execution (output, time, tflops)
|
|
GroupedConvRegistry - Collection of kernel configs with JSON export
|
|
|
|
Usage:
|
|
from grouped_conv_utils import (
|
|
GroupedConvKernelConfig,
|
|
GroupedConvProblem,
|
|
GpuGroupedConvRunner,
|
|
)
|
|
|
|
config = GroupedConvKernelConfig(variant="forward", ndim_spatial=2)
|
|
problem = GroupedConvProblem(N=1, C=64, K=128, Hi=28, Wi=28, Y=3, X=3,
|
|
stride_h=1, pad_h=1, direction="forward")
|
|
runner = GpuGroupedConvRunner()
|
|
if runner.is_available():
|
|
result = runner.run(input_np, weight_np, problem)
|
|
print(f"Time: {result.time_ms:.4f} ms, TFLOPS: {result.tflops:.2f}")
|
|
"""
|
|
|
|
import ctypes
|
|
import json
|
|
import copy
|
|
import subprocess
|
|
import threading
|
|
import traceback
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
|
|
from dispatcher_common import (
|
|
ValidationResultBase,
|
|
auto_correct_trait,
|
|
auto_correct_wave,
|
|
get_arch_filter_data,
|
|
validate_trait_combo,
|
|
validate_wave_config,
|
|
validate_warp_tile_config,
|
|
)
|
|
|
|
|
|
# =============================================================================
|
|
# Constants
|
|
# =============================================================================
|
|
|
|
# Canonical variant names accepted by the dispatcher.
VALID_VARIANTS = ("forward", "bwd_data", "bwd_weight")
# Supported spatial dimensionalities (1D/2D/3D convolution).
VALID_NDIM_SPATIAL = (1, 2, 3)
# Variants that compute gradients rather than the forward pass.
BACKWARD_VARIANTS = ("bwd_data", "bwd_weight")
# Pipelines allowed for the backward variants; any other pipeline is
# coerced to "compv3" in GroupedConvKernelConfig.__post_init__.
BACKWARD_PIPELINES = ("compv3", "mem")

# Short / legacy variant spellings mapped to canonical names
# (resolved by _resolve_variant).
VARIANT_ALIASES = {
    "2d_fwd": "forward",
    "2d_bwdd": "bwd_data",
    "2d_bwdw": "bwd_weight",
    "fwd": "forward",
    "bwdd": "bwd_data",
    "bwdw": "bwd_weight",
}

# Variant name -> integer direction code used by the C ABI
# (GroupedConvProblemC.direction).
DIRECTION_MAP = {"forward": 0, "bwd_data": 1, "bwd_weight": 2}
|
|
|
|
|
|
def _resolve_variant(v: str) -> str:
    """Map a variant alias (e.g. "fwd", "2d_bwdd") to its canonical name.

    Strings that are not known aliases pass through unchanged.
    """
    canonical = VARIANT_ALIASES.get(v)
    return canonical if canonical is not None else v
|
|
|
|
|
|
# =============================================================================
|
|
# GroupedConvDataType
|
|
# =============================================================================
|
|
|
|
|
|
class GroupedConvDataType(Enum):
    """Element data types supported by the grouped-conv kernels.

    Enum values are the lowercase strings used in kernel names and in
    GroupedConvKernelConfig.dtype.
    """

    FP16 = "fp16"
    BF16 = "bf16"
    FP32 = "fp32"
    FP8 = "fp8"
    BF8 = "bf8"
    INT8 = "int8"
|
|
|
|
|
|
# =============================================================================
|
|
# GroupedConvKernelConfig
|
|
# =============================================================================
|
|
|
|
|
|
@dataclass
class GroupedConvKernelConfig:
    """Complete kernel configuration for grouped convolution.

    Captures all parameters needed to identify and run a specific kernel.
    Mirrors the C++ GroupedConvSignature + GroupedConvAlgorithm.
    """

    # What: signature
    variant: str = "forward"  # canonical name; aliases resolved in __post_init__
    ndim_spatial: int = 2
    dtype: str = "fp16"
    layout: str = "nhwgc"
    arch: str = "gfx942"

    # How: algorithm - tile shape (block tile M x N x K)
    tile_m: int = 1
    tile_n: int = 128
    tile_k: int = 128

    # How: wave config (waves per block along M/N/K)
    wave_m: int = 2
    wave_n: int = 2
    wave_k: int = 1

    # How: warp tile (per-warp tile shape along M/N/K)
    warp_tile_m: int = 32
    warp_tile_n: int = 32
    warp_tile_k: int = 16

    # How: pipeline traits
    pipeline: str = "compv4"  # coerced to "compv3" for backward variants
    epilogue: str = "cshuffle"
    scheduler: str = "intrawave"

    # ConvConfigBase parity fields
    vector_size_a: int = 4
    vector_size_b: int = 8
    vector_size_c: int = 8
    block_per_cu: int = 1
    num_wave_groups: int = 1
    num_groups_to_merge: int = 1

    # Padding (enables arbitrary problem sizes)
    pad_m: bool = True
    pad_n: bool = True
    pad_k: bool = True

    # Additional trait config options
    double_smem_buffer: bool = False
    split_image: bool = False
    explicit_gemm: bool = False
    two_stage: bool = False

    def __post_init__(self):
        """Normalize the variant alias and coerce incompatible pipelines."""
        self.variant = _resolve_variant(self.variant)
        # Backward variants only support BACKWARD_PIPELINES; silently fall
        # back to compv3 instead of failing on a forward-only pipeline.
        if (
            self.variant in BACKWARD_VARIANTS
            and self.pipeline not in BACKWARD_PIPELINES
        ):
            self.pipeline = "compv3"

    @property
    def tile_str(self) -> str:
        """Block tile shape as "MxNxK"."""
        return f"{self.tile_m}x{self.tile_n}x{self.tile_k}"

    @property
    def wave_str(self) -> str:
        """Wave configuration as "MxNxK"."""
        return f"{self.wave_m}x{self.wave_n}x{self.wave_k}"

    @property
    def warp_str(self) -> str:
        """Warp tile shape as "MxNxK"."""
        return f"{self.warp_tile_m}x{self.warp_tile_n}x{self.warp_tile_k}"

    @property
    def vec_str(self) -> str:
        """Vector sizes as "AxBxC"."""
        return f"{self.vector_size_a}x{self.vector_size_b}x{self.vector_size_c}"

    @property
    def name(self) -> str:
        """Human-readable kernel name derived from the config fields.

        Non-default trait options are appended as short suffixes so the
        name distinguishes otherwise-identical configurations.
        """
        parts = [
            f"grouped_conv_{self.variant}_{self.dtype}_{self.ndim_spatial}d",
            self.tile_str,
            self.pipeline,
            self.scheduler,  # NEW: Include scheduler
        ]
        if self.num_groups_to_merge != 1:
            parts.append(f"gm{self.num_groups_to_merge}")  # NEW: Group merge
        if self.double_smem_buffer:
            parts.append("dsb")  # NEW: Double SMEM buffer
        if self.split_image:
            parts.append("si")  # NEW: Split image
        if self.two_stage:
            parts.append("2stage")  # NEW: Two-stage
        return "_".join(parts)

    def to_dict(self) -> dict:
        """Convert to legacy dict format for codegen compatibility.

        Scalar fields are wrapped in single-element lists because the legacy
        codegen config format expects lists of candidate values.
        """
        return {
            "tile_config": {
                "tile_m": [self.tile_m],
                "tile_n": [self.tile_n],
                "tile_k": [self.tile_k],
                "wave_m": [self.wave_m],
                "wave_n": [self.wave_n],
                "wave_k": [self.wave_k],
                "warp_tile_m": [self.warp_tile_m],
                "warp_tile_n": [self.warp_tile_n],
                "warp_tile_k": [self.warp_tile_k],
            },
            "trait_config": {
                "pipeline": [self.pipeline],
                "epilogue": [self.epilogue],
                "scheduler": [self.scheduler],
                "pad_m": [self.pad_m],
                "pad_n": [self.pad_n],
                "pad_k": [self.pad_k],
                "vector_size_a": [self.vector_size_a],
                "vector_size_b": [self.vector_size_b],
                "vector_size_c": [self.vector_size_c],
                "block_per_cu": [self.block_per_cu],
                "num_wave_groups": [self.num_wave_groups],
                "num_groups_to_merge": [self.num_groups_to_merge],
                "double_smem_buffer": [self.double_smem_buffer],
                "split_image": [self.split_image],
                "explicit_gemm": [self.explicit_gemm],
                "two_stage": [self.two_stage],
            },
            "variant": self.variant,
            "ndim_spatial": self.ndim_spatial,
            "arch": self.arch,
            "layout": self.layout,
            "dtype": self.dtype,
        }

    def to_json_obj(self) -> dict:
        """Serializable dict for JSON export.

        The key layout ("signature"/"algorithm"/"arch") matches what
        GroupedConvRegistry.from_json reads back.
        """
        return {
            "name": self.name,
            "signature": {
                "variant": self.variant,
                "dtype": self.dtype,
                "ndim_spatial": self.ndim_spatial,
                "layout": self.layout,
            },
            "algorithm": {
                "tile_m": self.tile_m,
                "tile_n": self.tile_n,
                "tile_k": self.tile_k,
                "wave": self.wave_str,
                "warp": self.warp_str,
                "pipeline": self.pipeline,
                "epilogue": self.epilogue,
                "scheduler": self.scheduler,
                "vector_sizes": [
                    self.vector_size_a,
                    self.vector_size_b,
                    self.vector_size_c,
                ],
                "block_per_cu": self.block_per_cu,
                "num_wave_groups": self.num_wave_groups,
                "num_groups_to_merge": self.num_groups_to_merge,
            },
            "arch": self.arch,
        }

    def print_config(self, indent: str = "  "):
        """Pretty-print the configuration to stdout, one field per line."""
        print(f"{indent}GroupedConvKernelConfig:")
        print(f"{indent}  Variant: {self.variant} {self.ndim_spatial}D")
        print(f"{indent}  Dtype: {self.dtype}")
        print(f"{indent}  Layout: {self.layout}")
        print(f"{indent}  Arch: {self.arch}")
        print(f"{indent}  Tile: {self.tile_str}")
        print(f"{indent}  Wave: {self.wave_str}")
        print(f"{indent}  Warp: {self.warp_str}")
        print(f"{indent}  Pipeline: {self.pipeline}/{self.scheduler}/{self.epilogue}")
        print(f"{indent}  VecSizes: {self.vec_str}")
        print(
            f"{indent}  BlockCU: {self.block_per_cu} WaveGroups: {self.num_wave_groups} MergeGroups: {self.num_groups_to_merge}"
        )
|
|
|
|
|
|
# =============================================================================
|
|
# GroupedConvProblem
|
|
# =============================================================================
|
|
|
|
|
|
@dataclass
class GroupedConvProblem:
    """Runtime convolution problem specification.

    Describes the actual sizes of a convolution to be computed.
    Matches the old ConvProblem from conv_utils.py.

    C and K are the TOTAL channel counts across all G groups; per-group
    counts are C // G and K // G (validated divisible in __post_init__).
    """

    # Batch / channel sizes
    N: int = 1
    C: int = 64
    K: int = 128
    G: int = 1

    # Input spatial sizes (Di == 1 for 2D problems)
    Hi: int = 28
    Wi: int = 28
    Di: int = 1

    # Filter spatial sizes (Z == 1 for 2D problems)
    Y: int = 3
    X: int = 3
    Z: int = 1

    stride_h: int = 1
    stride_w: int = 1
    stride_d: int = 1

    pad_h: int = 0
    pad_w: int = 0
    pad_d: int = 0

    dilation_h: int = 1
    dilation_w: int = 1
    dilation_d: int = 1

    # "forward" | "bwd_data" | "bwd_weight" (see DIRECTION_MAP)
    direction: str = "forward"
    split_k: int = 1

    def __post_init__(self):
        """Validate grouped convolution constraints.

        Raises:
            ValueError: if G < 1, or if C or K is not divisible by G.
        """
        # Guard G first so the modulo checks below cannot raise a confusing
        # ZeroDivisionError (or silently accept a negative group count).
        if self.G < 1:
            raise ValueError(
                f"G must be >= 1 for grouped convolution: G={self.G}"
            )
        if self.C % self.G != 0:
            raise ValueError(
                f"C must be divisible by G for grouped convolution: C={self.C}, G={self.G}"
            )
        if self.K % self.G != 0:
            raise ValueError(
                f"K must be divisible by G for grouped convolution: K={self.K}, G={self.G}"
            )

    @property
    def Ho(self) -> int:
        """Output height from the standard conv size formula (floor mode)."""
        eff_y = (self.Y - 1) * self.dilation_h + 1
        return (self.Hi + 2 * self.pad_h - eff_y) // self.stride_h + 1

    @property
    def Wo(self) -> int:
        """Output width from the standard conv size formula (floor mode)."""
        eff_x = (self.X - 1) * self.dilation_w + 1
        return (self.Wi + 2 * self.pad_w - eff_x) // self.stride_w + 1

    @property
    def Do(self) -> int:
        """Output depth from the standard conv size formula (floor mode)."""
        eff_z = (self.Z - 1) * self.dilation_d + 1
        return (self.Di + 2 * self.pad_d - eff_z) // self.stride_d + 1

    @property
    def is_3d(self) -> bool:
        """Treat the problem as 3D when any depth dimension is non-trivial."""
        return self.Di > 1 or self.Z > 1 or self.pad_d > 0

    @property
    def ndim_spatial(self) -> int:
        """Spatial dimensionality: 3 when is_3d, else 2."""
        return 3 if self.is_3d else 2

    @property
    def flops(self) -> float:
        """Total FLOPs for this convolution (any direction, same count).

        Uses float division C/G to match canonical formula (validated C % G == 0 in __post_init__).
        """
        c_per_group = self.C / self.G  # Float division (validated C % G == 0)
        if self.is_3d:
            return (
                2.0
                * self.N
                * self.K
                * self.Do
                * self.Ho
                * self.Wo
                * c_per_group
                * self.Z
                * self.Y
                * self.X
            )
        return 2.0 * self.N * self.K * self.Ho * self.Wo * c_per_group * self.Y * self.X

    @property
    def gflops(self) -> float:
        """Total GFLOPs (flops / 1e9)."""
        return self.flops / 1e9

    def input_shape(self) -> tuple:
        """NHWGC or NDHWGC layout."""
        c_per_g = self.C // self.G
        if self.is_3d:
            return (self.N, self.Di, self.Hi, self.Wi, self.G, c_per_g)
        return (self.N, self.Hi, self.Wi, self.G, c_per_g)

    def weight_shape(self) -> tuple:
        """GKYXC or GKZYXC layout."""
        c_per_g = self.C // self.G
        k_per_g = self.K // self.G
        if self.is_3d:
            return (self.G, k_per_g, self.Z, self.Y, self.X, c_per_g)
        return (self.G, k_per_g, self.Y, self.X, c_per_g)

    def output_shape(self) -> tuple:
        """NHWGK or NDHWGK layout."""
        k_per_g = self.K // self.G
        if self.is_3d:
            return (self.N, self.Do, self.Ho, self.Wo, self.G, k_per_g)
        return (self.N, self.Ho, self.Wo, self.G, k_per_g)

    def print_problem(self, indent: str = "  "):
        """Pretty-print the problem sizes to stdout."""
        dim_str = "3D" if self.is_3d else "2D"
        print(f"{indent}GroupedConvProblem ({dim_str} {self.direction}):")
        print(f"{indent}  Batch: N={self.N}, G={self.G}")
        print(f"{indent}  Channels: C={self.C}, K={self.K}")
        if self.is_3d:
            print(f"{indent}  Input: Di={self.Di}, Hi={self.Hi}, Wi={self.Wi}")
            print(f"{indent}  Filter: Z={self.Z}, Y={self.Y}, X={self.X}")
            print(f"{indent}  Output: Do={self.Do}, Ho={self.Ho}, Wo={self.Wo}")
        else:
            print(f"{indent}  Input: Hi={self.Hi}, Wi={self.Wi}")
            print(f"{indent}  Filter: Y={self.Y}, X={self.X}")
            print(f"{indent}  Output: Ho={self.Ho}, Wo={self.Wo}")
        print(f"{indent}  GFLOPs: {self.gflops:.2f}")
|
|
|
|
|
|
# =============================================================================
|
|
# GroupedConvProblemC (ctypes struct matching C++)
|
|
# =============================================================================
|
|
|
|
|
|
class GroupedConvProblemC(ctypes.Structure):
    """C structure matching ConvProblemC in conv_ctypes_lib.cpp.

    Field names, order, and types must stay in sync with the C++
    definition; every member is a 32-bit int.
    """

    _fields_ = [
        ("N", ctypes.c_int),
        ("G", ctypes.c_int),
        ("C", ctypes.c_int),
        ("K", ctypes.c_int),
        ("input_d", ctypes.c_int),
        ("input_h", ctypes.c_int),
        ("input_w", ctypes.c_int),
        ("filter_z", ctypes.c_int),
        ("filter_y", ctypes.c_int),
        ("filter_x", ctypes.c_int),
        ("stride_d", ctypes.c_int),
        ("stride_h", ctypes.c_int),
        ("stride_w", ctypes.c_int),
        ("pad_d", ctypes.c_int),
        ("pad_h", ctypes.c_int),
        ("pad_w", ctypes.c_int),
        ("dilation_d", ctypes.c_int),
        ("dilation_h", ctypes.c_int),
        ("dilation_w", ctypes.c_int),
        ("direction", ctypes.c_int),
        ("split_k", ctypes.c_int),
    ]

    @classmethod
    def from_problem(cls, p: GroupedConvProblem) -> "GroupedConvProblemC":
        """Build the C struct from a Python GroupedConvProblem."""
        c = cls()
        c.N, c.G, c.C, c.K = p.N, p.G, p.C, p.K
        c.input_d, c.input_h, c.input_w = p.Di, p.Hi, p.Wi
        c.filter_z, c.filter_y, c.filter_x = p.Z, p.Y, p.X
        c.stride_d, c.stride_h, c.stride_w = p.stride_d, p.stride_h, p.stride_w
        c.pad_d, c.pad_h, c.pad_w = p.pad_d, p.pad_h, p.pad_w
        c.dilation_d, c.dilation_h, c.dilation_w = (
            p.dilation_d,
            p.dilation_h,
            p.dilation_w,
        )
        # Unknown direction strings map to 0 (forward).
        c.direction = DIRECTION_MAP.get(p.direction, 0)
        # getattr guard keeps compatibility with problem objects that
        # predate the split_k field.
        c.split_k = getattr(p, "split_k", 1)
        return c
|
|
|
|
|
|
# =============================================================================
|
|
# GroupedConvResult
|
|
# =============================================================================
|
|
|
|
|
|
@dataclass
class GroupedConvResult:
    """Result of GPU convolution execution."""

    success: bool = False  # True iff the kernel ran and output was copied back
    time_ms: float = 0.0  # kernel time in milliseconds (0.0 on failure)
    tflops: float = 0.0  # achieved TFLOPS computed from problem.flops and time_ms
    output: Optional[np.ndarray] = None  # output tensor (None on failure)
    error: str = ""  # human-readable error description on failure
|
|
|
|
|
|
# =============================================================================
|
|
# GroupedConvDispatcherLib
|
|
# =============================================================================
|
|
|
|
|
|
class GroupedConvDispatcherLib:
    """Wrapper for the compiled convolution dispatcher library.

    Provides Python interface to the C API in conv_ctypes_lib.cpp.
    """

    # Candidate library locations, relative to the repo root
    # (grandparent of this file's directory — see find()).
    SEARCH_PATHS = [
        "build/examples/libdispatcher_conv_lib.so",
        "build/bindings/libdispatcher_conv_lib.so",
        "build/lib/libdispatcher_conv_lib.so",
    ]

    def __init__(self, lib: ctypes.CDLL, path: Path):
        """Wrap an already-loaded CDLL and declare its C signatures.

        Args:
            lib: The loaded shared library handle.
            path: Filesystem path the library was loaded from.
        """
        self._lib = lib
        self._path = path
        self._setup_functions()

    def _setup_functions(self):
        """Declare argtypes/restype for every exported C entry point."""
        self._lib.conv_dispatcher_init.argtypes = []
        self._lib.conv_dispatcher_init.restype = ctypes.c_int
        self._lib.conv_dispatcher_cleanup.argtypes = []
        self._lib.conv_dispatcher_cleanup.restype = ctypes.c_int
        self._lib.conv_dispatcher_version.argtypes = []
        self._lib.conv_dispatcher_version.restype = ctypes.c_char_p
        self._lib.conv_dispatcher_has_kernels.argtypes = []
        self._lib.conv_dispatcher_has_kernels.restype = ctypes.c_int
        self._lib.conv_dispatcher_has_bwd_data.argtypes = []
        self._lib.conv_dispatcher_has_bwd_data.restype = ctypes.c_int
        self._lib.conv_dispatcher_has_bwd_weight.argtypes = []
        self._lib.conv_dispatcher_has_bwd_weight.restype = ctypes.c_int
        self._lib.conv_dispatcher_get_kernel_count.argtypes = []
        self._lib.conv_dispatcher_get_kernel_count.restype = ctypes.c_int
        # (kernel_index, out_buffer, buffer_len) -> 0 on success
        self._lib.conv_dispatcher_get_kernel_name.argtypes = [
            ctypes.c_int,
            ctypes.c_char_p,
            ctypes.c_int,
        ]
        self._lib.conv_dispatcher_get_kernel_name.restype = ctypes.c_int
        self._lib.conv_dispatcher_is_supported.argtypes = [
            ctypes.POINTER(GroupedConvProblemC),
        ]
        self._lib.conv_dispatcher_is_supported.restype = ctypes.c_int
        # (a_ptr, b_ptr, c_ptr, problem, extra) -> time_ms as float;
        # the trailing pointer is passed as None by run().
        self._lib.conv_dispatcher_run.argtypes = [
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.POINTER(GroupedConvProblemC),
            ctypes.c_void_p,
        ]
        self._lib.conv_dispatcher_run.restype = ctypes.c_float

    @classmethod
    def find(cls) -> Optional["GroupedConvDispatcherLib"]:
        """Search standard paths for the conv library."""
        root = Path(__file__).parent.parent
        for rel in cls.SEARCH_PATHS:
            path = root / rel
            if path.exists():
                try:
                    lib = ctypes.CDLL(str(path))
                    return cls(lib, path)
                except OSError:
                    # Present but unloadable; try the next candidate.
                    continue
        return None

    @property
    def path(self) -> Path:
        """Filesystem path of the loaded shared library."""
        return self._path

    def initialize(self):
        """Call the library's init entry point."""
        self._lib.conv_dispatcher_init()

    def cleanup(self):
        """Call the library's cleanup entry point."""
        self._lib.conv_dispatcher_cleanup()

    def version(self) -> str:
        """Library version string."""
        return self._lib.conv_dispatcher_version().decode()

    def has_forward(self) -> bool:
        """True if forward kernels are compiled into the library."""
        return self._lib.conv_dispatcher_has_kernels() != 0

    def has_bwd_data(self) -> bool:
        """True if backward-data kernels are compiled into the library."""
        return self._lib.conv_dispatcher_has_bwd_data() != 0

    def has_bwd_weight(self) -> bool:
        """True if backward-weight kernels are compiled into the library."""
        return self._lib.conv_dispatcher_has_bwd_weight() != 0

    def kernel_count(self) -> int:
        """Number of kernels registered in the library."""
        return self._lib.conv_dispatcher_get_kernel_count()

    def kernel_names(self) -> List[str]:
        """Names of all registered kernels (entries that fail lookup are skipped)."""
        names = []
        for i in range(self.kernel_count()):
            buf = ctypes.create_string_buffer(256)
            if self._lib.conv_dispatcher_get_kernel_name(i, buf, 256) == 0:
                names.append(buf.value.decode())
        return names

    def is_supported(self, problem: GroupedConvProblem) -> bool:
        """Whether the library can dispatch a kernel for this problem."""
        pc = GroupedConvProblemC.from_problem(problem)
        return self._lib.conv_dispatcher_is_supported(ctypes.byref(pc)) != 0

    def run(
        self, a_ptr: int, b_ptr: int, c_ptr: int, problem: GroupedConvProblem
    ) -> float:
        """Run convolution. Returns time_ms (>0 success, <0 error).

        Args:
            a_ptr: Device address of the first operand buffer.
            b_ptr: Device address of the second operand buffer.
            c_ptr: Device address of the output buffer.
            problem: Problem specification (converted to the C struct).
        """
        pc = GroupedConvProblemC.from_problem(problem)
        return self._lib.conv_dispatcher_run(
            a_ptr, b_ptr, c_ptr, ctypes.byref(pc), None
        )
|
|
|
|
|
|
# =============================================================================
|
|
# GpuGroupedConvRunner
|
|
# =============================================================================
|
|
|
|
|
|
class GpuGroupedConvRunner:
    """High-level GPU convolution runner.

    Handles library loading, HIP memory management, and kernel execution.
    Follows the same pattern as the old GpuConvRunner from conv_utils.py.

    Usage:
        runner = GpuGroupedConvRunner()
        if runner.is_available():
            result = runner.run(input_np, weight_np, problem)
            print(f"Time: {result.time_ms:.4f} ms, TFLOPS: {result.tflops:.2f}")
    """

    # hipMemcpy direction codes: 1 = host-to-device, 2 = device-to-host
    # (match the HIP runtime's hipMemcpyKind enum values).
    HIP_MEMCPY_H2D = 1
    HIP_MEMCPY_D2H = 2

    def __init__(self, lib_path: Optional[str] = None):
        """Initialize runner WITHOUT loading GPU libraries.

        GPU context is created lazily on first run() call, avoiding fork() issues
        during parallel compilation. This mirrors FMHA design.

        Args:
            lib_path: Path to dispatcher .so file (or None to auto-detect)
        """
        self._lib_path = lib_path
        self._dispatch_lib: Optional[GroupedConvDispatcherLib] = None
        self._hip = None  # ctypes handle to libamdhip64.so, set on lazy init
        self._initialized = False
        self._init_error = None  # message captured from a failed init
        self._init_traceback = None  # full traceback text for debugging

    def _ensure_initialized(self):
        """Lazy initialization - only load GPU libraries when actually needed."""
        if self._initialized:
            return

        try:
            # Load dispatcher library: an explicit path wins over auto-detect.
            if self._lib_path:
                lib = ctypes.CDLL(self._lib_path)
                self._dispatch_lib = GroupedConvDispatcherLib(lib, Path(self._lib_path))
            else:
                self._dispatch_lib = GroupedConvDispatcherLib.find()

            # No library found: stay uninitialized so is_available() is False.
            if self._dispatch_lib is None:
                return

            # Load HIP library - THIS creates GPU context
            self._hip = ctypes.CDLL("libamdhip64.so")
            self._hip.hipMalloc.argtypes = [
                ctypes.POINTER(ctypes.c_void_p),
                ctypes.c_size_t,
            ]
            self._hip.hipMalloc.restype = ctypes.c_int
            self._hip.hipFree.argtypes = [ctypes.c_void_p]
            self._hip.hipFree.restype = ctypes.c_int
            self._hip.hipMemcpy.argtypes = [
                ctypes.c_void_p,
                ctypes.c_void_p,
                ctypes.c_size_t,
                ctypes.c_int,
            ]
            self._hip.hipMemcpy.restype = ctypes.c_int
            self._hip.hipDeviceSynchronize.argtypes = []
            self._hip.hipDeviceSynchronize.restype = ctypes.c_int

            # Initialize dispatcher
            self._dispatch_lib.initialize()
            self._initialized = True
        except Exception as e:
            # Record the failure; run() surfaces it to the caller.
            self._initialized = False
            self._init_error = str(e)
            self._init_traceback = traceback.format_exc()

    def is_available(self) -> bool:
        """True once lazy init succeeded and a dispatcher library is loaded."""
        return self._initialized and self._dispatch_lib is not None

    def get_init_error(self) -> Optional[str]:
        """Get initialization error message if initialization failed."""
        return self._init_error

    def get_init_traceback(self) -> Optional[str]:
        """Get full initialization traceback for debugging."""
        return self._init_traceback

    @property
    def library_path(self) -> Optional[str]:
        """Path of the loaded dispatcher library, or None before/without init."""
        if self._dispatch_lib:
            return str(self._dispatch_lib.path)
        return None

    @property
    def lib(self) -> Optional[GroupedConvDispatcherLib]:
        """Underlying dispatcher library wrapper (None before init)."""
        return self._dispatch_lib

    def run(
        self,
        input_np: np.ndarray,
        weight_np: np.ndarray,
        problem: GroupedConvProblem,
        output_np: Optional[np.ndarray] = None,
        verbose: bool = False,
    ) -> GroupedConvResult:
        """Run convolution on GPU.

        Args:
            input_np: For forward: X (NHWGC). For bwd_data: dY. For bwd_weight: X.
            weight_np: For forward: W (GKYXC). For bwd_data: W. For bwd_weight: dY.
            problem: Problem specification.
            output_np: Optional pre-allocated output buffer.
            verbose: If True, print full traceback on initialization failure.

        Returns:
            GroupedConvResult with success, time_ms, tflops, output.
        """
        # Lazy initialization - load GPU libraries only on first run
        self._ensure_initialized()

        if not self.is_available():
            # Surface the actual initialization error for diagnosability
            if self._init_error:
                error_msg = f"GPU initialization failed: {self._init_error}"
                if verbose and self._init_traceback:
                    print("=" * 80)
                    print("GPU Initialization Traceback:")
                    print("=" * 80)
                    print(self._init_traceback)
                    print("=" * 80)
            else:
                error_msg = "GPU not available"
            return GroupedConvResult(error=error_msg)

        try:
            # Determine output shape based on direction: bwd_data produces a
            # gradient with the input's shape, bwd_weight with the weight's
            # shape, and forward produces the regular output shape.
            d = problem.direction
            if d == "bwd_data":
                out_shape = problem.input_shape()
            elif d == "bwd_weight":
                out_shape = problem.weight_shape()
            else:
                out_shape = problem.output_shape()

            if output_np is None:
                output_np = np.zeros(out_shape, dtype=input_np.dtype)

            output_size = output_np.nbytes

            # Allocate GPU memory with error checking
            d_a = ctypes.c_void_p()
            d_b = ctypes.c_void_p()
            d_c = ctypes.c_void_p()
            allocated_ptrs = []  # Track successfully allocated pointers

            try:
                # Allocate input
                ret = self._hip.hipMalloc(ctypes.byref(d_a), input_np.nbytes)
                if ret != 0:
                    raise RuntimeError(
                        f"hipMalloc failed for input (code {ret}, size {input_np.nbytes})"
                    )
                allocated_ptrs.append(d_a)

                # Allocate weight
                ret = self._hip.hipMalloc(ctypes.byref(d_b), weight_np.nbytes)
                if ret != 0:
                    raise RuntimeError(
                        f"hipMalloc failed for weight (code {ret}, size {weight_np.nbytes})"
                    )
                allocated_ptrs.append(d_b)

                # Allocate output
                ret = self._hip.hipMalloc(ctypes.byref(d_c), output_size)
                if ret != 0:
                    raise RuntimeError(
                        f"hipMalloc failed for output (code {ret}, size {output_size})"
                    )
                allocated_ptrs.append(d_c)

                # Host to device
                ret = self._hip.hipMemcpy(
                    d_a, input_np.ctypes.data, input_np.nbytes, self.HIP_MEMCPY_H2D
                )
                if ret != 0:
                    raise RuntimeError(f"hipMemcpy H2D failed for input (code {ret})")

                ret = self._hip.hipMemcpy(
                    d_b, weight_np.ctypes.data, weight_np.nbytes, self.HIP_MEMCPY_H2D
                )
                if ret != 0:
                    raise RuntimeError(f"hipMemcpy H2D failed for weight (code {ret})")

                self._hip.hipDeviceSynchronize()

                # Launch kernel; the dispatcher returns elapsed milliseconds
                # (> 0 on success, <= 0 encodes an error — see below).
                time_ms = self._dispatch_lib.run(
                    d_a.value, d_b.value, d_c.value, problem
                )
                self._hip.hipDeviceSynchronize()

                result = GroupedConvResult()

                if time_ms > 0:
                    # Device to host
                    ret = self._hip.hipMemcpy(
                        output_np.ctypes.data, d_c, output_size, self.HIP_MEMCPY_D2H
                    )
                    if ret != 0:
                        raise RuntimeError(
                            f"hipMemcpy D2H failed for output (code {ret})"
                        )

                    self._hip.hipDeviceSynchronize()
                    result.success = True
                    result.time_ms = time_ms
                    # flops / (ms * 1e9) == flops / (s * 1e12) == TFLOPS
                    result.tflops = problem.flops / (time_ms * 1e9)
                    result.output = output_np
                else:
                    # Error codes from the C library: -3.0 = problem not
                    # supported, -2.0 = no kernel registered, anything else
                    # is reported generically.
                    result.error = (
                        "unsupported"
                        if time_ms == -3.0
                        else "no kernel"
                        if time_ms == -2.0
                        else f"error (code {time_ms})"
                    )

                return result

            finally:
                # CRITICAL: Only free successfully allocated pointers
                for ptr in allocated_ptrs:
                    if ptr.value:  # Only free non-null pointers
                        self._hip.hipFree(ptr)

        except Exception as e:
            return GroupedConvResult(error=str(e))

    def cleanup(self):
        """Best-effort dispatcher cleanup; errors are intentionally swallowed."""
        if self._dispatch_lib:
            try:
                self._dispatch_lib.cleanup()
            except Exception:
                pass
|
|
|
|
|
|
# =============================================================================
|
|
# GroupedConvRegistry
|
|
# =============================================================================
|
|
|
|
|
|
class GroupedConvRegistry:
    """Collection of grouped conv kernel configs with JSON export/import."""

    def __init__(self, name: str = "default"):
        # Registry label; filtered copies derive their names from it
        # (see filter_by_variant / filter_by_arch).
        self.name = name
        # Internal kernel list; exposed only as a copy via the `kernels`
        # property so callers cannot mutate registry state directly.
        self._kernels: List[GroupedConvKernelConfig] = []

    def add(self, config: GroupedConvKernelConfig):
        """Append a kernel config. No de-duplication is performed here."""
        self._kernels.append(config)

    @property
    def kernels(self) -> List[GroupedConvKernelConfig]:
        """Return a shallow copy of the kernel list."""
        return list(self._kernels)

    def __len__(self) -> int:
        return len(self._kernels)

    def select(
        self, problem: "GroupedConvProblem", heuristic=None
    ) -> Optional[GroupedConvKernelConfig]:
        """Select the best kernel for a problem.

        Args:
            problem: The convolution problem.
            heuristic: Optional callable(problem) -> List[str] returning
                       ranked kernel name substrings. The registry tries
                       each in order; falls back to first matching kernel.

        Returns:
            The best matching GroupedConvKernelConfig, or None.
        """
        # Only kernels built for the problem's direction are candidates.
        matching = [k for k in self._kernels if k.variant == problem.direction]
        if not matching:
            return None

        if heuristic is not None:
            ranked = heuristic(problem)
            # First hint (in rank order) that is a substring of a candidate
            # kernel name wins.
            for hint in ranked:
                for k in matching:
                    if hint in k.name:
                        return k

        # No heuristic, or no hint matched: fall back to registration order.
        return matching[0] if matching else None

    def filter_by_variant(self, variant: str) -> "GroupedConvRegistry":
        """Return a new registry holding only kernels of the given variant."""
        # Normalize aliases (e.g. "fwd") to the canonical variant string.
        variant = _resolve_variant(variant)
        reg = GroupedConvRegistry(f"{self.name}_{variant}")
        for k in self._kernels:
            if k.variant == variant:
                reg.add(k)
        return reg

    def filter_by_arch(self, arch: str) -> "GroupedConvRegistry":
        """Return a new registry holding only kernels targeting `arch`."""
        reg = GroupedConvRegistry(f"{self.name}_{arch}")
        for k in self._kernels:
            if k.arch == arch:
                reg.add(k)
        return reg

    def to_json(self, indent: int = 2) -> str:
        """Serialize the registry (name + per-kernel JSON objects) to JSON."""
        return json.dumps(
            {
                "name": self.name,
                "kernels": [k.to_json_obj() for k in self._kernels],
            },
            indent=indent,
        )

    @classmethod
    def from_json(cls, json_str: str) -> "GroupedConvRegistry":
        """Rebuild a registry from JSON produced by to_json().

        Missing fields fall back to the module-wide defaults (wave 2x2x1,
        warp tile 32x32x16, vector sizes [4, 8, 8], etc.).
        """
        data = json.loads(json_str)
        reg = cls(data.get("name", "imported"))
        for kd in data.get("kernels", []):
            sig = kd.get("signature", {})
            algo = kd.get("algorithm", {})
            # wave/warp are encoded as "MxNxK" strings; vector sizes as a list.
            wave = algo.get("wave", "2x2x1").split("x")
            warp = algo.get("warp", "32x32x16").split("x")
            vec = algo.get("vector_sizes", [4, 8, 8])
            reg.add(
                GroupedConvKernelConfig(
                    variant=sig.get("variant", "forward"),
                    ndim_spatial=sig.get("ndim_spatial", 2),
                    dtype=sig.get("dtype", "fp16"),
                    layout=sig.get("layout", "nhwgc"),
                    arch=kd.get("arch", "gfx942"),
                    tile_m=algo.get("tile_m", 1),
                    tile_n=algo.get("tile_n", 128),
                    tile_k=algo.get("tile_k", 128),
                    wave_m=int(wave[0]),
                    wave_n=int(wave[1]),
                    wave_k=int(wave[2]),
                    warp_tile_m=int(warp[0]),
                    warp_tile_n=int(warp[1]),
                    warp_tile_k=int(warp[2]),
                    pipeline=algo.get("pipeline", "compv3"),
                    epilogue=algo.get("epilogue", "cshuffle"),
                    scheduler=algo.get("scheduler", "intrawave"),
                    # Pad vector sizes defensively in case the list is short.
                    vector_size_a=vec[0] if len(vec) > 0 else 4,
                    vector_size_b=vec[1] if len(vec) > 1 else 8,
                    vector_size_c=vec[2] if len(vec) > 2 else 8,
                    block_per_cu=algo.get("block_per_cu", 1),
                    num_wave_groups=algo.get("num_wave_groups", 1),
                    num_groups_to_merge=algo.get("num_groups_to_merge", 1),
                )
            )
        return reg

    def build(
        self,
        verbose: bool = False,
        max_workers: Optional[int] = None,
    ) -> Dict[Tuple[str, int], "GpuGroupedConvRunner"]:
        """Parallel JIT compile all kernels in this registry.

        Args:
            verbose: Print progress during build.
            max_workers: Max parallel codegen/compile processes (default: cpu_count capped at 8).

        Returns a dict mapping (variant, ndim_spatial) to a ready-to-use
        GpuGroupedConvRunner.
        """
        if not self._kernels:
            return {}

        # JIT phase returns library *paths* only; no GPU context is created
        # during compilation (lazy init happens below).
        libs = setup_multiple_grouped_conv_dispatchers(
            self._kernels,
            verbose=verbose,
            max_workers=max_workers,
        )

        runners: Dict[Tuple[str, int], GpuGroupedConvRunner] = {}
        for cfg, lib in zip(self._kernels, libs):
            if lib is None:
                # Config failed validation/codegen/compile; skip it.
                continue
            key = (cfg.variant, cfg.ndim_spatial)
            # First successfully-built kernel wins each (variant, ndim) slot.
            if key in runners:
                continue
            runner = GpuGroupedConvRunner(lib_path=str(lib))
            # Force GPU init now so is_available() reflects reality.
            runner._ensure_initialized()
            if runner.is_available():
                runners[key] = runner
        return runners

    def print_registry(self, indent: str = " "):
        """Print a one-line summary per kernel, including validation status."""
        print(f"{indent}Registry '{self.name}': {len(self)} kernels")
        for i, k in enumerate(self._kernels):
            print(
                f"{indent} [{i}] {k.name} (valid={validate_grouped_conv_config(k.to_dict()).is_valid})"
            )
|
|
|
|
|
|
# =============================================================================
|
|
# GroupedConvValidationResult
|
|
# =============================================================================
|
|
|
|
|
|
@dataclass
class GroupedConvValidationResult(ValidationResultBase):
    """Result of grouped conv kernel config validation."""

    # Conv variant the validated config targets (e.g. "forward",
    # "bwd_data", "bwd_weight").
    variant: str = "forward"

    # NOTE: @dataclass does not generate __init__ when the class body
    # already defines one, so this explicit constructor is what runs.
    # It exists to normalize the None defaults into fresh containers
    # before delegating to the base class.
    def __init__(
        self,
        is_valid=True,
        errors=None,
        warnings=None,
        suggested_fixes=None,
        variant="forward",
    ):
        super().__init__(
            is_valid=is_valid,
            errors=errors or [],
            warnings=warnings or [],
            suggested_fixes=suggested_fixes or {},
        )
        self.variant = variant
|
|
|
|
|
|
# =============================================================================
|
|
# Validation helpers (extracted from the original config extraction code)
|
|
# =============================================================================
|
|
|
|
|
|
def _first(val):
|
|
if isinstance(val, list) and len(val) > 0:
|
|
return val[0]
|
|
return val
|
|
|
|
|
|
def _get_tile_config(config: dict) -> dict:
|
|
return config.get("tile_config") or {}
|
|
|
|
|
|
def _get_trait_config(config: dict) -> dict:
|
|
return config.get("trait_config") or {}
|
|
|
|
|
|
def _extract_wave_config(tile_config: dict) -> List[int]:
|
|
wm = tile_config.get("wave_m") or tile_config.get("warp_m")
|
|
wn = tile_config.get("wave_n") or tile_config.get("warp_n")
|
|
wk = tile_config.get("wave_k") or tile_config.get("warp_k")
|
|
if wm is not None and wn is not None and wk is not None:
|
|
return [_first(wm), _first(wn), _first(wk)]
|
|
return [2, 2, 1]
|
|
|
|
|
|
def _extract_warp_tile_config(tile_config: dict) -> List[int]:
|
|
wtm = tile_config.get("warp_tile_m") or tile_config.get("warp_m")
|
|
wtn = tile_config.get("warp_tile_n") or tile_config.get("warp_n")
|
|
wtk = tile_config.get("warp_tile_k") or tile_config.get("warp_k")
|
|
if wtm is not None and wtn is not None and wtk is not None:
|
|
return [_first(wtm), _first(wtn), _first(wtk)]
|
|
return [32, 32, 16]
|
|
|
|
|
|
def _extract_trait_values(trait_config: dict) -> Tuple[str, str, str]:
|
|
p = _first(trait_config.get("pipeline", "compv4"))
|
|
e = _first(trait_config.get("epilogue", "cshuffle"))
|
|
s = _first(trait_config.get("scheduler", "intrawave"))
|
|
if isinstance(p, list):
|
|
p = p[0] if p else "compv4"
|
|
if isinstance(e, list):
|
|
e = e[0] if e else "cshuffle"
|
|
if isinstance(s, list):
|
|
s = s[0] if s else "intrawave"
|
|
return (str(p), str(e), str(s))
|
|
|
|
|
|
# =============================================================================
|
|
# validate_grouped_conv_config / auto_correct_grouped_conv_config
|
|
# =============================================================================
|
|
|
|
|
|
def validate_grouped_conv_config(config: dict) -> GroupedConvValidationResult:
    """Validate a grouped conv kernel config dict.

    Accepts either a raw dict (legacy) or GroupedConvKernelConfig.to_dict() output.

    Checks performed, in order:
      * presence of required top-level keys (hard stop when missing),
      * variant membership in VALID_VARIANTS,
      * ndim_spatial membership in VALID_NDIM_SPATIAL,
      * pipeline restrictions for backward variants,
      * trait combo / wave config / warp-tile config validity for the
        target architecture (via the shared validate_* helpers),
      * architecture support.

    Returns:
        GroupedConvValidationResult whose suggested_fixes dict maps field
        names to replacement values consumable by
        auto_correct_grouped_conv_config.
    """
    errors: List[str] = []
    warnings: List[str] = []
    suggested_fixes: Dict[str, Any] = {}

    # Structural keys every later check assumes to exist.
    required = (
        "tile_config",
        "trait_config",
        "variant",
        "ndim_spatial",
        "arch",
        "layout",
    )
    for key in required:
        if key not in config:
            errors.append(f"Missing required key: {key}")
    if errors:
        # Bail out early: the remaining checks cannot run meaningfully.
        return GroupedConvValidationResult(
            is_valid=False,
            errors=errors,
            warnings=warnings,
            suggested_fixes=suggested_fixes,
            variant=config.get("variant", "forward"),
        )

    tile_config = _get_tile_config(config)
    trait_config = _get_trait_config(config)
    # variant may be a scalar or a (possibly empty) list in legacy dicts;
    # normalize to the canonical variant string.
    variant = _first(config.get("variant", "forward"))
    if isinstance(variant, list):
        variant = variant[0] if variant else "forward"
    variant = _resolve_variant(str(variant))

    ndim_spatial = config.get("ndim_spatial")
    arch = config.get("arch", "gfx942")
    dtype = config.get("dtype", "fp16")

    if variant not in VALID_VARIANTS:
        errors.append(f"Invalid variant: {variant}. Valid: {', '.join(VALID_VARIANTS)}")
        suggested_fixes["variant"] = "forward"

    if ndim_spatial is not None:
        ndim = ndim_spatial
        if isinstance(ndim, list):
            ndim = ndim[0] if ndim else 2
        if ndim not in VALID_NDIM_SPATIAL:
            errors.append(
                f"Invalid ndim_spatial: {ndim}. Valid: {', '.join(map(str, VALID_NDIM_SPATIAL))}"
            )
            suggested_fixes["ndim_spatial"] = 2

    pipeline, epilogue, scheduler = _extract_trait_values(trait_config)
    # Backward variants only support the restricted BACKWARD_PIPELINES set.
    if variant in BACKWARD_VARIANTS and pipeline not in BACKWARD_PIPELINES:
        errors.append(
            f"Backward variant '{variant}' requires pipeline compv3 or mem, got {pipeline}"
        )
        suggested_fixes["pipeline"] = "compv3"

    ok, msg = validate_trait_combo(pipeline, epilogue, scheduler)
    if not ok:
        errors.append(msg)
        suggested_fixes["scheduler"] = "intrawave"

    wave_cfg = _extract_wave_config(tile_config)
    ok, msg = validate_wave_config(wave_cfg, arch)
    if not ok:
        errors.append(msg)
        arch_data = get_arch_filter_data()
        # Suggest the first arch-valid wave combo as the replacement.
        valid_waves = arch_data["warp_combos"].get(arch, [[2, 2, 1]])
        if valid_waves:
            suggested_fixes["wave_m"] = valid_waves[0][0]
            suggested_fixes["wave_n"] = valid_waves[0][1]
            suggested_fixes["wave_k"] = valid_waves[0][2]

    warp_cfg = _extract_warp_tile_config(tile_config)
    ok, msg = validate_warp_tile_config(warp_cfg, arch, dtype)
    if not ok:
        errors.append(msg)
        arch_data = get_arch_filter_data()
        # Warp-tile validity is keyed by an "<a>_<b>_<acc>" dtype string;
        # int8 accumulates in int32, everything else in fp32.
        acc = "int32" if dtype == "int8" else "fp32"
        dtype_key = f"{dtype}_{dtype}_{acc}"
        valid_tiles = (
            arch_data["warp_tile_combos"]
            .get(arch, {})
            .get(dtype_key, [[32, 32, 16], [16, 16, 16]])
        )
        if valid_tiles:
            suggested_fixes["warp_tile_m"] = valid_tiles[0][0]
            suggested_fixes["warp_tile_n"] = valid_tiles[0][1]
            suggested_fixes["warp_tile_k"] = valid_tiles[0][2]

    arch_data = get_arch_filter_data()
    if arch not in arch_data["supported_archs"]:
        errors.append(
            f"Unsupported architecture: {arch}. Supported: {', '.join(arch_data['supported_archs'])}"
        )

    return GroupedConvValidationResult(
        is_valid=len(errors) == 0,
        errors=errors,
        warnings=warnings,
        suggested_fixes=suggested_fixes,
        variant=variant,
    )
|
|
|
|
|
|
def auto_correct_grouped_conv_config(
    config: dict,
) -> Tuple[dict, GroupedConvValidationResult]:
    """Auto-correct invalid grouped conv config. Returns (corrected, result).

    The input dict is never mutated; corrections are applied to a deep copy.
    Correction strategy:
      * wave triple is remapped through auto_correct_wave for the arch,
      * (pipeline, scheduler) is remapped through auto_correct_trait,
      * backward variants are forced onto a BACKWARD_PIPELINES pipeline,
      * warp-tile fixes suggested by the validator are applied verbatim,
      * the corrected dict is re-validated and returned with that result
        (which may still be invalid, e.g. for an unsupported arch).
    """
    result = validate_grouped_conv_config(config)
    corrected = copy.deepcopy(config)

    if result.is_valid:
        # Nothing to do; hand back the untouched copy.
        return corrected, result

    tile_config = corrected.setdefault("tile_config", {})
    trait_config = corrected.setdefault("trait_config", {})

    # Wave correction is applied unconditionally; auto_correct_wave is
    # expected to be a no-op for already-valid triples.
    wave_cfg = _extract_wave_config(tile_config)
    arch = config.get("arch", "gfx942")
    fixed_wave = auto_correct_wave(wave_cfg, arch)
    tile_config["wave_m"] = fixed_wave[0]
    tile_config["wave_n"] = fixed_wave[1]
    tile_config["wave_k"] = fixed_wave[2]

    pipeline, epilogue, scheduler = _extract_trait_values(trait_config)
    fixed_pipeline, fixed_scheduler = auto_correct_trait(pipeline, scheduler)
    trait_config["pipeline"] = fixed_pipeline
    trait_config["scheduler"] = fixed_scheduler

    # Backward variants must use a backward-capable pipeline even after
    # the generic trait correction above.
    variant = _first(config.get("variant", "forward"))
    if isinstance(variant, list):
        variant = variant[0] if variant else "forward"
    variant = _resolve_variant(str(variant))
    if variant in BACKWARD_VARIANTS and fixed_pipeline not in BACKWARD_PIPELINES:
        trait_config["pipeline"] = "compv3"

    # Apply the validator's warp-tile suggestion as a unit (m/n/k are
    # always suggested together).
    if "warp_tile_m" in result.suggested_fixes:
        tile_config["warp_tile_m"] = result.suggested_fixes["warp_tile_m"]
        tile_config["warp_tile_n"] = result.suggested_fixes["warp_tile_n"]
        tile_config["warp_tile_k"] = result.suggested_fixes["warp_tile_k"]

    # Re-validate so the caller sees the post-correction status.
    result = validate_grouped_conv_config(corrected)
    return corrected, result
|
|
|
|
|
|
def _run_hipcc_subprocess(args: dict) -> Tuple[bool, Optional[Path], str]:
|
|
"""Run one hipcc compile+link job in a subprocess worker."""
|
|
import subprocess
|
|
from pathlib import Path
|
|
|
|
compile_cmd = args["compile_cmd"]
|
|
link_cmd = args["link_cmd"]
|
|
lib_path = Path(args["lib_path"])
|
|
|
|
try:
|
|
res_c = subprocess.run(compile_cmd, capture_output=True, text=True, timeout=300)
|
|
if res_c.returncode != 0:
|
|
err = (res_c.stderr or res_c.stdout or "").rstrip()
|
|
return False, None, f"Compile failed (rc={res_c.returncode}):\n{err}"
|
|
|
|
res_l = subprocess.run(link_cmd, capture_output=True, text=True, timeout=300)
|
|
if res_l.returncode != 0:
|
|
err = (res_l.stderr or res_l.stdout or "").rstrip()
|
|
return False, None, f"Link failed (rc={res_l.returncode}):\n{err}"
|
|
|
|
return True, lib_path, ""
|
|
except subprocess.TimeoutExpired:
|
|
return False, None, "Timeout"
|
|
except Exception as e:
|
|
return False, None, f"Error: {e}"
|
|
|
|
|
|
def _run_conv_codegen_subprocess(args: dict) -> Tuple[bool, Optional[str], str]:
|
|
"""Run grouped-conv codegen once and return generated kernel header path."""
|
|
import subprocess
|
|
from pathlib import Path
|
|
|
|
out_dir = Path(args["output_dir"])
|
|
out_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Remove stale kernels so header discovery is exact for this invocation.
|
|
for stale in out_dir.glob("grouped_conv_*.hpp"):
|
|
stale.unlink(missing_ok=True)
|
|
for stale in out_dir.glob("include_all_grouped_conv_*.hpp"):
|
|
stale.unlink(missing_ok=True)
|
|
|
|
try:
|
|
res = subprocess.run(args["cmd"], capture_output=True, text=True, timeout=300)
|
|
if res.returncode != 0:
|
|
err = (res.stderr or res.stdout or "").rstrip()
|
|
return False, None, f"Codegen failed (rc={res.returncode}):\n{err}"
|
|
|
|
generated = sorted(
|
|
out_dir.glob("grouped_conv_*.hpp"),
|
|
key=lambda p: p.stat().st_mtime,
|
|
reverse=True,
|
|
)
|
|
if not generated:
|
|
return False, None, "Codegen produced no grouped_conv_*.hpp header"
|
|
|
|
return True, str(generated[0]), ""
|
|
except subprocess.TimeoutExpired:
|
|
return False, None, "Codegen timed out"
|
|
except Exception as e:
|
|
return False, None, f"Codegen error: {e}"
|
|
|
|
|
|
def _config_key(c: GroupedConvKernelConfig) -> Tuple[Any, ...]:
    """Hashable identity tuple used to de-duplicate kernel configs.

    Field order is part of the key's identity and must stay stable.
    """
    identity_fields = (
        "variant",
        "ndim_spatial",
        "dtype",
        "layout",
        "arch",
        "tile_m",
        "tile_n",
        "tile_k",
        "wave_m",
        "wave_n",
        "wave_k",
        "warp_tile_m",
        "warp_tile_n",
        "warp_tile_k",
        "pipeline",
        "epilogue",
        "scheduler",
        "num_groups_to_merge",
        "double_smem_buffer",
        "split_image",
        "two_stage",
    )
    return tuple(getattr(c, field) for field in identity_fields)
|
|
|
|
|
|
def _parse_triplet(value: str) -> Tuple[int, int, int]:
|
|
parts = value.split("x")
|
|
if len(parts) != 3:
|
|
raise ValueError(f"Invalid triplet: {value}")
|
|
return int(parts[0]), int(parts[1]), int(parts[2])
|
|
|
|
|
|
def _list_arch_valid_grouped_conv_configs(
    codegen_script: Path,
    arch: str,
    dtype: str,
    variant: str,
    ndim_spatial: int,
) -> List[GroupedConvKernelConfig]:
    """Query codegen defaults for this (arch, dtype, variant, ndim) tuple.

    Invokes the codegen script with --list-configs and parses the kernel
    names it prints into GroupedConvKernelConfig objects. Returns an empty
    list when the script fails; unparseable lines are silently skipped.
    """
    import re
    import sys

    cmd = [
        sys.executable,
        str(codegen_script),
        "--list-configs",
        "--arch",
        arch,
        "--datatype",
        dtype,
        "--variant",
        variant,
        "--ndim",
        str(ndim_spatial),
    ]
    res = subprocess.run(cmd, capture_output=True, text=True, timeout=180)
    if res.returncode != 0:
        return []

    # Example:
    # grouped_conv_fwd_fp16_nhwgc_2d_compv3_cshuffle_intrawave_128x128x32_2x2x1_32x32x16
    # Groups: variant, dtype, layout, ndim, pipeline, epilogue, scheduler,
    # tile triple, wave triple, warp-tile triple; trailing suffixes ignored.
    name_re = re.compile(
        r"^grouped_conv_(fwd|bwd_data|bwd_weight|bwdd|bwdw)_([a-z0-9]+)_([a-z0-9]+)_([123])d_"
        r"([a-z0-9]+)_([a-z0-9]+)_([a-z0-9]+)_"
        r"([0-9]+x[0-9]+x[0-9]+)_([0-9]+x[0-9]+x[0-9]+)_([0-9]+x[0-9]+x[0-9]+)"
        r"(?:_.*)?$"
    )
    # Map the short variant tokens used in kernel names to canonical variants.
    short_to_variant = {
        "fwd": "forward",
        "bwd_data": "bwd_data",
        "bwd_weight": "bwd_weight",
        "bwdd": "bwd_data",
        "bwdw": "bwd_weight",
    }

    out: List[GroupedConvKernelConfig] = []
    seen = set()
    for raw in res.stdout.splitlines():
        line = raw.strip()
        # Config lines are printed as "- grouped_conv_..." bullets.
        if not line.startswith("- grouped_conv_"):
            continue
        name = line[2:].strip()
        m = name_re.match(name)
        if not m:
            continue

        v_short, dt, layout, ndim, pipe, epi, sched, tile_s, wave_s, warp_s = m.groups()
        tm, tn, tk = _parse_triplet(tile_s)
        wm, wn, wk = _parse_triplet(wave_s)
        wtm, wtn, wtk = _parse_triplet(warp_s)

        cfg = GroupedConvKernelConfig(
            variant=short_to_variant[v_short],
            ndim_spatial=int(ndim),
            dtype=dt,
            layout=layout,
            arch=arch,
            tile_m=tm,
            tile_n=tn,
            tile_k=tk,
            wave_m=wm,
            wave_n=wn,
            wave_k=wk,
            warp_tile_m=wtm,
            warp_tile_n=wtn,
            warp_tile_k=wtk,
            pipeline=pipe,
            epilogue=epi,
            scheduler=sched,
        )
        # De-duplicate by full identity key while preserving first-seen order.
        key = _config_key(cfg)
        if key not in seen:
            out.append(cfg)
            seen.add(key)

    return out
|
|
|
|
|
|
def _select_best_arch_valid_conv_config(
    requested: GroupedConvKernelConfig,
    candidates: List[GroupedConvKernelConfig],
) -> GroupedConvKernelConfig:
    """Pick nearest arch-valid config while preferring trait exact matches.

    Candidates are ranked lexicographically: trait mismatches first
    (pipeline, then scheduler, then epilogue), then L1 distance on the
    tile, wave, and warp-tile triples. Ties keep candidate order.
    """

    def l1(lhs, rhs) -> int:
        # Manhattan distance between two (m, n, k) triples.
        return sum(abs(a - b) for a, b in zip(lhs, rhs))

    def score(c: GroupedConvKernelConfig) -> Tuple[int, int, int, int, int, int]:
        return (
            int(c.pipeline != requested.pipeline),
            int(c.scheduler != requested.scheduler),
            int(c.epilogue != requested.epilogue),
            l1(
                (c.tile_m, c.tile_n, c.tile_k),
                (requested.tile_m, requested.tile_n, requested.tile_k),
            ),
            l1(
                (c.wave_m, c.wave_n, c.wave_k),
                (requested.wave_m, requested.wave_n, requested.wave_k),
            ),
            l1(
                (c.warp_tile_m, c.warp_tile_n, c.warp_tile_k),
                (requested.warp_tile_m, requested.warp_tile_n, requested.warp_tile_k),
            ),
        )

    # Copy the winner so the candidate list stays untouched, then pin it
    # to the requested architecture.
    selected = copy.deepcopy(min(candidates, key=score))
    selected.arch = requested.arch
    return selected
|
|
|
|
|
|
def _write_single_conv_dispatch_header(
    config: GroupedConvKernelConfig,
    kernel_header: Path,
    dispatch_header: Path,
) -> None:
    """Create a tiny dispatch header consumed by conv_ctypes_lib.cpp.

    Defines the per-variant availability macro, the 3D launcher alias
    (3D builds only), and the single-entry kernel-name table.
    """
    # Per-variant macro prefix, kernel-name symbol, and 3D launcher alias.
    if config.variant == "forward":
        kernel_name_symbol = "CONV_FWD_KERNEL_NAME"
        macro_prefix = "CONV_FWD"
        alias_3d = "using ConvFwd3dLauncher = SelectedConvKernelLauncher;"
    elif config.variant == "bwd_data":
        kernel_name_symbol = "CONV_BWD_DATA_KERNEL_NAME"
        macro_prefix = "CONV_BWD_DATA"
        alias_3d = "using ConvBwdData3dLauncher = SelectedConvBwdDataLauncher;"
    else:
        kernel_name_symbol = "CONV_BWD_WEIGHT_KERNEL_NAME"
        macro_prefix = "CONV_BWD_WEIGHT"
        alias_3d = "using ConvBwdWeight3dLauncher = SelectedConvBwdWeightLauncher;"

    macros: List[str] = []
    aliases: List[str] = []
    if config.ndim_spatial == 3:
        macros.append(f"#define {macro_prefix}_3D_AVAILABLE 1")
        aliases.append(alias_3d)
    else:
        macros.append(f"#define {macro_prefix}_2D_AVAILABLE 1")

    parts = [
        "// Auto-generated single-kernel dispatch header for Python JIT\n",
        "#pragma once\n\n",
        f'#include "{kernel_header.name}"\n\n',
        "\n".join(macros),
        "\n\n",
        "\n".join(aliases),
        "\n\n",
        f"static const char* CONV_KERNEL_NAMES[] = {{{kernel_name_symbol}}};\n",
        "static constexpr int CONV_KERNEL_COUNT = 1;\n",
    ]
    dispatch_header.write_text("".join(parts))
|
|
|
|
|
|
class GroupedConvCodegenRunner:
    """Generate and compile grouped-conv JIT libraries in parallel.

    Two-phase pipeline: (1) run the codegen script once per config to emit
    a kernel header, (2) compile+link each header into a shared library via
    hipcc. Both phases fan out over a ThreadPoolExecutor; the workers only
    shell out via subprocess, which releases the GIL, so threads give real
    parallelism without fork-after-HIP hazards.
    """

    def __init__(self, max_workers: Optional[int] = None):
        import multiprocessing

        # Cap parallelism at 8 to avoid saturating the machine with hipcc jobs.
        self.max_workers = max_workers or min(multiprocessing.cpu_count(), 8)
        self.root = Path(__file__).parent.parent
        self.build_dir = self.root / "build"
        self.codegen_script = self.root / "codegen" / "unified_grouped_conv_codegen.py"

    def generate_and_compile_parallel(
        self,
        configs: List[GroupedConvKernelConfig],
        verbose: bool = True,
    ) -> List[Optional[Path]]:
        """Codegen + compile all configs; return per-config .so paths.

        The returned list is index-aligned with `configs`; entries are None
        for configs whose codegen or compilation failed.
        """
        import sys

        if not configs:
            return []

        if not self.build_dir.exists():
            self.build_dir.mkdir(parents=True, exist_ok=True)

        ctypes_source = self.root / "bindings" / "ctypes" / "conv_ctypes_lib.cpp"
        static_lib = self.build_dir / "libck_tile_dispatcher.a"
        jit_root = self.build_dir / "generated_kernels" / "python_jit"
        jit_root.mkdir(parents=True, exist_ok=True)
        (self.build_dir / "examples").mkdir(parents=True, exist_ok=True)

        # Hard prerequisites: codegen script, ctypes wrapper source, and the
        # prebuilt static dispatcher library must all exist.
        if not self.codegen_script.exists():
            if verbose:
                print(f"Codegen script missing: {self.codegen_script}")
            return [None] * len(configs)
        if not ctypes_source.exists() or not static_lib.exists():
            if verbose:
                print("Missing conv ctypes source or static dispatcher library")
            return [None] * len(configs)

        if verbose:
            print(
                f"Generating {len(configs)} grouped-conv kernels with "
                f"{self.max_workers} threads (out-of-order)..."
            )

        # Each config gets its own output directory (cfg_<i>) so generated
        # headers never collide across parallel jobs.
        gen_jobs: List[Dict[str, Any]] = []
        job_dirs: List[Path] = []
        for i, c in enumerate(configs):
            cfg_dir = jit_root / f"cfg_{i}"
            cfg_dir.mkdir(parents=True, exist_ok=True)
            job_dirs.append(cfg_dir)

            cmd = [
                sys.executable,
                str(self.codegen_script),
                "--output",
                str(cfg_dir),
                "--datatype",
                c.dtype,
                "--variant",
                c.variant,
                "--ndim",
                str(c.ndim_spatial),
                "--arch",
                c.arch,
                "--tile-m",
                str(c.tile_m),
                "--tile-n",
                str(c.tile_n),
                "--tile-k",
                str(c.tile_k),
                "--warp-m",
                str(c.wave_m),
                "--warp-n",
                str(c.wave_n),
                "--warp-k",
                str(c.wave_k),
                "--warp-tile-m",
                str(c.warp_tile_m),
                "--warp-tile-n",
                str(c.warp_tile_n),
                "--warp-tile-k",
                str(c.warp_tile_k),
                "--pipeline",
                c.pipeline,
                "--scheduler",
                c.scheduler,
                "--epilogue",
                c.epilogue,
                "--num-groups-to-merge",
                str(c.num_groups_to_merge),
                "--double-smem-buffer",
                "true" if c.double_smem_buffer else "false",
            ]
            # Boolean flags are only passed when enabled.
            if c.split_image:
                cmd.append("--split-image")
            if c.two_stage:
                cmd.append("--two-stage")
            gen_jobs.append({"cmd": cmd, "output_dir": str(cfg_dir)})

        generated_headers: List[Optional[Path]] = [None] * len(configs)

        # Phase 1 codegen: each worker just calls subprocess.run() to invoke the
        # codegen script. The wait releases the GIL, so threads give true parallelism
        # without the fork-after-HIP risk of ProcessPoolExecutor.
        print_lock = threading.Lock()

        with ThreadPoolExecutor(max_workers=self.max_workers) as ex:
            futures = {
                ex.submit(_run_conv_codegen_subprocess, job): idx
                for idx, job in enumerate(gen_jobs)
            }
            for fut in as_completed(futures):
                idx = futures[fut]
                ok, header_path, err = fut.result()
                if ok and header_path:
                    generated_headers[idx] = Path(header_path)
                    if verbose:
                        with print_lock:
                            print(f"  OK [{idx}] codegen: {Path(header_path).name}")
                else:
                    if verbose:
                        with print_lock:
                            print(f"  FAIL [{idx}] codegen: {err}")

        if verbose:
            compile_count = sum(1 for h in generated_headers if h is not None)
            print(
                f"Compiling {compile_count} grouped-conv libraries with "
                f"{self.max_workers} threads (out-of-order)..."
            )

        # Build one hipcc compile+link job per successfully-generated header.
        compile_jobs: List[Dict[str, Any]] = []
        compile_to_input_index: Dict[int, int] = {}
        for i, c in enumerate(configs):
            hdr_path = generated_headers[i]
            if hdr_path is None:
                # Codegen failed for this config; skip compilation.
                continue

            cfg_dir = job_dirs[i]
            dispatch_header = cfg_dir / "conv_python_dispatch.hpp"
            _write_single_conv_dispatch_header(c, hdr_path, dispatch_header)

            # Build suffix with all distinguishing config options
            suffix = ""
            if c.num_groups_to_merge != 1:
                suffix += f"_gm{c.num_groups_to_merge}"
            if c.double_smem_buffer:
                suffix += "_dsb"
            if c.split_image:
                suffix += "_si"
            if c.two_stage:
                suffix += "_2stage"

            lib_name = (
                f"libdispatcher_conv_{c.variant}_{c.ndim_spatial}d_{c.dtype}_"
                f"{c.tile_str}_{c.wave_str}_{c.warp_str}_{c.pipeline}_{c.scheduler}{suffix}.so"
            )
            lib_path = self.build_dir / "examples" / lib_name
            obj_file = lib_path.with_suffix(".o")

            compile_cmd = [
                "/opt/rocm/bin/hipcc",
                "-c",
                "-fPIC",
                "-O3",
                f"-I{self.root / 'include'}",
                f"-I{self.root.parent / 'include'}",
                f"-I{self.root.parent}",
                f"-I{cfg_dir}",
                "-DCK_TILE_SINGLE_KERNEL_INCLUDE",
                # Force-include the per-config dispatch header so the shared
                # ctypes source compiles against exactly one kernel.
                f"-include{dispatch_header}",
                "-D__HIP_PLATFORM_AMD__",
                f"--offload-arch={c.arch}",
                f'-DGFX_ARCH="{c.arch}"',
                "-mllvm",
                "-enable-noalias-to-md-conversion=0",
                "-Wno-undefined-func-template",
                "-Wno-float-equal",
                str(ctypes_source),
                "-o",
                str(obj_file),
            ]
            link_cmd = [
                "/opt/rocm/bin/hipcc",
                "-shared",
                "-fPIC",
                f"--offload-arch={c.arch}",
                "--hip-link",
                str(obj_file),
                str(static_lib),
                "-o",
                str(lib_path),
            ]

            # Compile jobs are dense; remember which input config each maps to.
            compile_to_input_index[len(compile_jobs)] = i
            compile_jobs.append(
                {
                    "compile_cmd": compile_cmd,
                    "link_cmd": link_cmd,
                    "lib_path": str(lib_path),
                    "config_name": c.name,
                }
            )

        results_map: Dict[int, Optional[Path]] = {i: None for i in range(len(configs))}

        # Phase 2 compile: workers shell out to hipcc, releasing the GIL while
        # waiting. Threads give true parallelism here; ProcessPool would risk
        # fork() corrupting any HIP state the parent might have loaded.
        with ThreadPoolExecutor(max_workers=self.max_workers) as ex:
            futures = {
                ex.submit(_run_hipcc_subprocess, job): j
                for j, job in enumerate(compile_jobs)
            }
            for fut in as_completed(futures):
                j = futures[fut]
                idx = compile_to_input_index[j]
                success, lib_path, err = fut.result()
                if success and lib_path:
                    results_map[idx] = Path(lib_path)
                if verbose:
                    name = (
                        Path(lib_path).name
                        if success and lib_path
                        else compile_jobs[j]["config_name"]
                    )
                    with print_lock:
                        if success:
                            print(f"  OK {name}")
                        else:
                            # Print the full multi-line error indented for readability
                            # so users don't have to monkey-patch to see real compile output.
                            print(f"  FAIL {name}")
                            for line in (err or "").splitlines() or [""]:
                                print(f"    {line}")

        return [results_map.get(i) for i in range(len(configs))]
|
|
|
|
|
|
# =============================================================================
|
|
# Convenience functions
|
|
# =============================================================================
|
|
|
|
|
|
def get_grouped_conv_default_config(
    variant: str = "forward",
    ndim_spatial: int = 2,
    arch: str = "gfx942",
    dtype: str = "fp16",
) -> GroupedConvKernelConfig:
    """Return a valid default GroupedConvKernelConfig.

    Only the signature fields are set explicitly; every algorithmic knob
    (tiles, waves, traits) comes from GroupedConvKernelConfig's defaults.
    """
    signature = {
        "variant": variant,
        "ndim_spatial": ndim_spatial,
        "arch": arch,
        "dtype": dtype,
    }
    return GroupedConvKernelConfig(**signature)
|
|
|
|
|
|
def format_grouped_conv_summary(config) -> str:
    """Format a config (dict or GroupedConvKernelConfig) into a human-readable string."""
    # Typed-config path: all fields are attributes, so format directly.
    if isinstance(config, GroupedConvKernelConfig):
        return "\n".join(
            [
                f"Grouped Conv Config: {config.variant} {config.ndim_spatial}D",
                f" Arch: {config.arch}",
                f" Layout: {config.layout}",
                f" Dtype: {config.dtype}",
                f" Tile: {config.tile_str}",
                f" Wave: {config.wave_str}",
                f" Warp: {config.warp_str}",
                f" Traits: pipeline={config.pipeline} epilogue={config.epilogue} scheduler={config.scheduler}",
            ]
        )

    # Legacy dict support: tolerate non-dict inputs by falling back to "?".
    is_dict = isinstance(config, dict)
    tile_config = _get_tile_config(config) if is_dict else {}
    trait_config = _get_trait_config(config) if is_dict else {}
    variant = config.get("variant", "?") if is_dict else "?"
    ndim = config.get("ndim_spatial", "?") if is_dict else "?"
    arch = config.get("arch", "?") if is_dict else "?"
    layout = config.get("layout", "?") if is_dict else "?"
    dtype = config.get("dtype", "fp16") if is_dict else "fp16"

    lines = [
        f"Grouped Conv Config: {variant} {ndim}D",
        f" Arch: {arch}",
        f" Layout: {layout}",
        f" Dtype: {dtype}",
    ]

    if tile_config:
        wave = _extract_wave_config(tile_config)
        warp = _extract_warp_tile_config(tile_config)
        lines.append(
            f" Tile: M={_first(tile_config.get('tile_m', 1))} N={_first(tile_config.get('tile_n', 128))} K={_first(tile_config.get('tile_k', 128))}"
        )
        lines.append(f" Wave: {wave[0]}x{wave[1]}x{wave[2]}")
        lines.append(f" Warp: {warp[0]}x{warp[1]}x{warp[2]}")

    if trait_config:
        pipeline = _first(trait_config.get("pipeline", "?"))
        epilogue = _first(trait_config.get("epilogue", "?"))
        scheduler = _first(trait_config.get("scheduler", "?"))
        lines.append(
            f" Traits: pipeline={pipeline} epilogue={epilogue} scheduler={scheduler}"
        )

    return "\n".join(lines) if lines else "(empty config)"
|
|
|
|
|
|
def setup_multiple_grouped_conv_dispatchers(
    configs: List[GroupedConvKernelConfig],
    verbose: bool = True,
    max_workers: Optional[int] = None,
) -> List[Optional[Path]]:
    """
    Setup multiple grouped-conv dispatchers.

    Returns library paths WITHOUT loading them, to avoid GPU context during compilation.
    This mirrors FMHA design: keep GPU context out of JIT phase entirely.

    Architecture filtering workflow:
    1. Validate each requested config via validate_grouped_conv_config; if invalid,
       attempt auto_correct_grouped_conv_config. Drop configs that remain invalid.
    2. Trust the (possibly auto-corrected) config as-is. Knobs such as scheduler,
       num_groups_to_merge, double_smem_buffer, split_image, two_stage are preserved
       exactly as requested -- no remap to a hardcoded "default" set.
    3. Threaded codegen + threaded compile (workers shell out via subprocess,
       which releases the GIL; threads avoid the fork-after-HIP risk that
       ProcessPoolExecutor would have).
    4. Return paths (NOT loaded libraries).

    Returns:
        List of paths to compiled .so files (or None for failed configs)
    """
    if not configs:
        return []

    # Stage 1: validate/auto-correct each config. Entries become None when a
    # config cannot be made valid; index alignment with `configs` is kept.
    selected_configs: List[Optional[GroupedConvKernelConfig]] = []
    for i, original in enumerate(configs):
        # Work on a copy so the caller's config objects are never mutated.
        c = copy.deepcopy(original)

        val = validate_grouped_conv_config(c.to_dict())
        if not val.is_valid:
            corrected, corrected_result = auto_correct_grouped_conv_config(c.to_dict())
            if not corrected_result.is_valid:
                if verbose:
                    print(f"  FAIL [{i}] config remains invalid after auto-correct")
                selected_configs.append(None)
                continue

            # Copy the corrected dict fields back onto the config object,
            # collapsing any legacy list-valued entries via _first.
            tile_cfg = corrected.get("tile_config", {})
            trait_cfg = corrected.get("trait_config", {})
            c.variant = _resolve_variant(
                str(_first(corrected.get("variant", c.variant)))
            )
            c.ndim_spatial = int(_first(corrected.get("ndim_spatial", c.ndim_spatial)))
            c.arch = str(corrected.get("arch", c.arch))
            c.layout = str(corrected.get("layout", c.layout))
            c.dtype = str(corrected.get("dtype", c.dtype))
            c.tile_m = int(_first(tile_cfg.get("tile_m", c.tile_m)))
            c.tile_n = int(_first(tile_cfg.get("tile_n", c.tile_n)))
            c.tile_k = int(_first(tile_cfg.get("tile_k", c.tile_k)))
            c.wave_m = int(_first(tile_cfg.get("wave_m", c.wave_m)))
            c.wave_n = int(_first(tile_cfg.get("wave_n", c.wave_n)))
            c.wave_k = int(_first(tile_cfg.get("wave_k", c.wave_k)))
            c.warp_tile_m = int(_first(tile_cfg.get("warp_tile_m", c.warp_tile_m)))
            c.warp_tile_n = int(_first(tile_cfg.get("warp_tile_n", c.warp_tile_n)))
            c.warp_tile_k = int(_first(tile_cfg.get("warp_tile_k", c.warp_tile_k)))
            c.pipeline = str(_first(trait_cfg.get("pipeline", c.pipeline)))
            c.scheduler = str(_first(trait_cfg.get("scheduler", c.scheduler)))
            c.epilogue = str(_first(trait_cfg.get("epilogue", c.epilogue)))

        # Trust the validated config -- no remap to a hardcoded arch-valid set.
        # Knobs (num_groups_to_merge, double_smem_buffer, split_image, two_stage)
        # and scheduler choice are preserved exactly as requested.
        selected_configs.append(c)

    # Stage 2: de-duplicate identical configs so each unique kernel is
    # generated and compiled only once; remember each input's unique slot.
    unique_configs: List[GroupedConvKernelConfig] = []
    unique_index_by_key: Dict[Tuple[Any, ...], int] = {}
    input_to_unique: List[Optional[int]] = []
    for cfg in selected_configs:
        if cfg is None:
            input_to_unique.append(None)
            continue
        key = _config_key(cfg)
        if key not in unique_index_by_key:
            unique_index_by_key[key] = len(unique_configs)
            unique_configs.append(cfg)
        input_to_unique.append(unique_index_by_key[key])

    # Stage 3: threaded codegen + compile of the unique configs.
    runner = GroupedConvCodegenRunner(max_workers=max_workers)
    unique_lib_paths = runner.generate_and_compile_parallel(
        unique_configs, verbose=verbose
    )

    # Map unique lib paths back to input order
    # DO NOT load libraries here - just return paths
    lib_paths: List[Optional[Path]] = []
    path_cache: Dict[int, Optional[Path]] = {}
    for input_idx, unique_idx in enumerate(input_to_unique):
        if unique_idx is None:
            lib_paths.append(None)
            continue

        if unique_idx in path_cache:
            lib_paths.append(path_cache[unique_idx])
            continue

        path = (
            unique_lib_paths[unique_idx] if unique_idx < len(unique_lib_paths) else None
        )
        # Validate path exists but don't load it
        if path and not path.exists():
            if verbose:
                print(f"  FAIL [{input_idx}] library not found: {path}")
            path = None

        path_cache[unique_idx] = path
        lib_paths.append(path)

    return lib_paths
|
|
|
|
|
|
def detect_gpu_arch() -> str:
    """Detect GPU architecture using rocminfo.

    Scans rocminfo output for a "Name: gfx..." line and returns the gfx
    token; falls back to "gfx942" when rocminfo is unavailable or no
    token is found.
    """
    try:
        output = subprocess.check_output(
            ["rocminfo"], stderr=subprocess.DEVNULL, text=True
        )
        for line in output.split("\n"):
            lowered = line.lower()
            if "gfx" in lowered and "name:" in lowered:
                for token in line.split():
                    if token.startswith("gfx"):
                        return token
    except Exception:
        # rocminfo missing or failing: best-effort detection, use default.
        pass
    return "gfx942"
|