[rocm-libraries] ROCm/rocm-libraries#6327 (commit 1e7a12e)

[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation
The ML heuristic in the dispatcher does not yet support the grouped-conv operator. This PR adds support for fwd, bwd-data, and bwd-weight grouped-conv kernels. A tile_engine utility has also been added to compile and run any selected kernel configuration through the dispatcher infrastructure.

## Technical Details

1. A tile_engine utility has been added to benchmark each shape with all possible kernel + tile_size combinations:
https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv have been added to the models directory, with three separate models for fwd, bwd-data, and bwd-weight:
https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
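The selection flow these models enable can be sketched as follows. This is a minimal illustration, not the actual dispatcher API: the feature layout and config fields are assumptions, and `predict` stands in for any scorer such as a trained `lightgbm.Booster.predict`.

```python
import numpy as np

def select_kernel_config(predict, problem: dict, candidates: list) -> dict:
    """Rank candidate kernel configs for one problem and return the best.

    `predict` maps a 2-D feature matrix (one row per candidate) to per-row
    scores; here it stands in for an LGBM regressor's predict() so the
    sketch is self-contained. Feature names are illustrative.
    """
    feats = np.array(
        [
            [problem["N"], problem["C"], problem["K"], problem["G"],
             c["tile_m"], c["tile_n"], c["tile_k"]]
            for c in candidates
        ],
        dtype=np.float64,
    )
    scores = predict(feats)  # higher predicted efficiency wins
    return candidates[int(np.argmax(scores))]
```

In practice one model per direction (fwd, bwd-data, bwd-weight) is queried with the problem shape, and the highest-scoring kernel + tile_size combination is dispatched.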
3. Implemented lazy GPU initialization (dispatcher/python)
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults
   - **Solution**: Mirror the FMHA pattern: defer GPU initialization until the first run()
   - **Changes**:
     - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs
     - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
     - Added an _ensure_initialized() method for lazy GPU loading
     - GPU context is created only on the first run() call
   - **Benefit**: Parallel compilation now works without GPU conflicts
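The lazy-initialization pattern above can be sketched in isolation. Class and attribute names here are illustrative; the real GpuGroupedConvRunner also resolves the dispatcher library and records tracebacks.

```python
import ctypes
from typing import Optional

class LazyGpuRunner:
    """Minimal sketch: defer library loading (and thus GPU context
    creation) until the first run(), so forked compile workers never
    inherit a live GPU context."""

    def __init__(self, lib_path: Optional[str] = None):
        # Deliberately no ctypes.CDLL here.
        self._lib_path = lib_path
        self._lib = None
        self._initialized = False
        self._init_error: Optional[str] = None

    def _ensure_initialized(self) -> None:
        if self._initialized:
            return
        try:
            if self._lib_path:
                # This is the point where a real runner would create
                # the GPU context (loading libamdhip64.so, etc.).
                self._lib = ctypes.CDLL(self._lib_path)
            self._initialized = True
        except OSError as e:
            self._init_error = str(e)

    def run(self) -> str:
        self._ensure_initialized()  # first call pays the init cost
        if self._init_error:
            return f"init failed: {self._init_error}"
        return "ok"
```

Because `__init__` touches no GPU state, instances can be constructed freely in the parent process before parallel compilation begins.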
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to expand kernel selection
   - Added automatic padding support for unsupported shapes in the dispatcher runner
   - Created a single source of truth between tile_engine and the dispatcher for architecture and tile_size details
   - Built validation scripts to compare oracle_best vs. ml_heuristic selections
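The automatic-padding item above can be illustrated with a small sketch. The function name, the rounding multiple, and the per-group rounding strategy are assumptions for illustration, not the dispatcher's actual implementation.

```python
def pad_to_supported(C: int, K: int, G: int, multiple: int = 8) -> tuple:
    """Hypothetical sketch of automatic padding: round the per-group
    channel counts up to a supported multiple so an otherwise
    unsupported shape can still be dispatched (padded channels would be
    zero-filled on input and ignored on output)."""
    def round_up(x: int, m: int) -> int:
        return ((x + m - 1) // m) * m

    c_per_g, k_per_g = C // G, K // G
    return round_up(c_per_g, multiple) * G, round_up(k_per_g, multiple) * G
```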

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets of up to 300 problems.
2. Added test cases in both the dispatcher and tile_engine to validate the heuristic.

## Test Results
Results on unseen shapes, validated on gfx950:
#### Forward Pass Model
- **Training Data**: 48,845 measurements across 1,372 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model
- **Training Data**: 18,773 measurements across 891 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model
- **Training Data**: 34,900 measurements across 1,508 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**
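The three statistics reported above can be computed as follows. This is a minimal sketch assuming per-problem efficiency is defined as oracle_best_time / heuristic_time (so 1.0 means the heuristic matched the oracle's kernel choice); the actual validation script may differ.

```python
import numpy as np

def efficiency_stats(oracle_ms, heuristic_ms) -> dict:
    """Mean / median / P10 efficiency of heuristic picks vs. the oracle.

    oracle_ms[i]    -- best achievable time for problem i (oracle)
    heuristic_ms[i] -- time of the kernel the ML heuristic selected
    """
    eff = np.asarray(oracle_ms, dtype=float) / np.asarray(heuristic_ms, dtype=float)
    return {
        "mean": float(eff.mean()),
        "median": float(np.median(eff)),
        "p10": float(np.percentile(eff, 10)),  # 90% of problems are at least this efficient
    }
```

P10 is the more demanding tail metric: it bounds how badly the heuristic does on its worst decile of problems.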

## Submission Checklist

- [x] Looked over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Author: Yaswanth Raparti
Date: 2026-05-08 20:48:42 +00:00
Committed by: assistant-librarian[bot]
Parent: b05040b919
Commit: 6989cf800c
65 changed files with 13206 additions and 389 deletions


@@ -38,6 +38,9 @@ import ctypes
import json
import copy
import subprocess
import threading
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
@@ -148,6 +151,12 @@ class GroupedConvKernelConfig:
pad_n: bool = True
pad_k: bool = True
# Additional trait config options
double_smem_buffer: bool = False
split_image: bool = False
explicit_gemm: bool = False
two_stage: bool = False
def __post_init__(self):
self.variant = _resolve_variant(self.variant)
if (
@@ -174,10 +183,21 @@ class GroupedConvKernelConfig:
@property
def name(self) -> str:
return (
f"grouped_conv_{self.variant}_{self.dtype}_{self.ndim_spatial}d_"
f"{self.tile_str}_{self.pipeline}"
)
parts = [
f"grouped_conv_{self.variant}_{self.dtype}_{self.ndim_spatial}d",
self.tile_str,
self.pipeline,
self.scheduler, # NEW: Include scheduler
]
if self.num_groups_to_merge != 1:
parts.append(f"gm{self.num_groups_to_merge}") # NEW: Group merge
if self.double_smem_buffer:
parts.append("dsb") # NEW: Double SMEM buffer
if self.split_image:
parts.append("si") # NEW: Split image
if self.two_stage:
parts.append("2stage") # NEW: Two-stage
return "_".join(parts)
def to_dict(self) -> dict:
"""Convert to legacy dict format for codegen compatibility."""
@@ -206,6 +226,10 @@ class GroupedConvKernelConfig:
"block_per_cu": [self.block_per_cu],
"num_wave_groups": [self.num_wave_groups],
"num_groups_to_merge": [self.num_groups_to_merge],
"double_smem_buffer": [self.double_smem_buffer],
"split_image": [self.split_image],
"explicit_gemm": [self.explicit_gemm],
"two_stage": [self.two_stage],
},
"variant": self.variant,
"ndim_spatial": self.ndim_spatial,
@@ -302,6 +326,17 @@ class GroupedConvProblem:
direction: str = "forward"
split_k: int = 1
def __post_init__(self):
"""Validate grouped convolution constraints."""
if self.C % self.G != 0:
raise ValueError(
f"C must be divisible by G for grouped convolution: C={self.C}, G={self.G}"
)
if self.K % self.G != 0:
raise ValueError(
f"K must be divisible by G for grouped convolution: K={self.K}, G={self.G}"
)
@property
def Ho(self) -> int:
eff_y = (self.Y - 1) * self.dilation_h + 1
@@ -327,8 +362,11 @@ class GroupedConvProblem:
@property
def flops(self) -> float:
"""Total FLOPs for this convolution (any direction, same count)."""
c_per_group = self.C // self.G
"""Total FLOPs for this convolution (any direction, same count).
Uses float division C/G to match canonical formula (validated C % G == 0 in __post_init__).
"""
c_per_group = self.C / self.G # Float division (validated C % G == 0)
if self.is_3d:
return (
2.0
@@ -591,20 +629,38 @@ class GpuGroupedConvRunner:
HIP_MEMCPY_D2H = 2
def __init__(self, lib_path: Optional[str] = None):
"""Initialize runner WITHOUT loading GPU libraries.
GPU context is created lazily on first run() call, avoiding fork() issues
during parallel compilation. This mirrors FMHA design.
Args:
lib_path: Path to dispatcher .so file (or None to auto-detect)
"""
self._lib_path = lib_path
self._dispatch_lib: Optional[GroupedConvDispatcherLib] = None
self._hip = None
self._initialized = False
self._init_error = None
self._init_traceback = None
def _ensure_initialized(self):
"""Lazy initialization - only load GPU libraries when actually needed."""
if self._initialized:
return
try:
if lib_path:
lib = ctypes.CDLL(lib_path)
self._dispatch_lib = GroupedConvDispatcherLib(lib, Path(lib_path))
# Load dispatcher library
if self._lib_path:
lib = ctypes.CDLL(self._lib_path)
self._dispatch_lib = GroupedConvDispatcherLib(lib, Path(self._lib_path))
else:
self._dispatch_lib = GroupedConvDispatcherLib.find()
if self._dispatch_lib is None:
return
# Load HIP library - THIS creates GPU context
self._hip = ctypes.CDLL("libamdhip64.so")
self._hip.hipMalloc.argtypes = [
ctypes.POINTER(ctypes.c_void_p),
@@ -623,14 +679,25 @@ class GpuGroupedConvRunner:
self._hip.hipDeviceSynchronize.argtypes = []
self._hip.hipDeviceSynchronize.restype = ctypes.c_int
# Initialize dispatcher
self._dispatch_lib.initialize()
self._initialized = True
except Exception:
except Exception as e:
self._initialized = False
self._init_error = str(e)
self._init_traceback = traceback.format_exc()
def is_available(self) -> bool:
return self._initialized and self._dispatch_lib is not None
def get_init_error(self) -> Optional[str]:
"""Get initialization error message if initialization failed."""
return self._init_error
def get_init_traceback(self) -> Optional[str]:
"""Get full initialization traceback for debugging."""
return self._init_traceback
@property
def library_path(self) -> Optional[str]:
if self._dispatch_lib:
@@ -647,6 +714,7 @@ class GpuGroupedConvRunner:
weight_np: np.ndarray,
problem: GroupedConvProblem,
output_np: Optional[np.ndarray] = None,
verbose: bool = False,
) -> GroupedConvResult:
"""Run convolution on GPU.
@@ -655,12 +723,27 @@ class GpuGroupedConvRunner:
weight_np: For forward: W (GKYXC). For bwd_data: W. For bwd_weight: dY.
problem: Problem specification.
output_np: Optional pre-allocated output buffer.
verbose: If True, print full traceback on initialization failure.
Returns:
GroupedConvResult with success, time_ms, tflops, output.
"""
# Lazy initialization - load GPU libraries only on first run
self._ensure_initialized()
if not self.is_available():
return GroupedConvResult(error="GPU not available")
# Surface the actual initialization error for diagnosability
if self._init_error:
error_msg = f"GPU initialization failed: {self._init_error}"
if verbose and self._init_traceback:
print("=" * 80)
print("GPU Initialization Traceback:")
print("=" * 80)
print(self._init_traceback)
print("=" * 80)
else:
error_msg = "GPU not available"
return GroupedConvResult(error=error_msg)
try:
# Determine output shape based on direction
@@ -677,52 +760,91 @@ class GpuGroupedConvRunner:
output_size = output_np.nbytes
# Allocate GPU memory
d_a, d_b, d_c = ctypes.c_void_p(), ctypes.c_void_p(), ctypes.c_void_p()
self._hip.hipMalloc(ctypes.byref(d_a), input_np.nbytes)
self._hip.hipMalloc(ctypes.byref(d_b), weight_np.nbytes)
self._hip.hipMalloc(ctypes.byref(d_c), output_size)
# Allocate GPU memory with error checking
d_a = ctypes.c_void_p()
d_b = ctypes.c_void_p()
d_c = ctypes.c_void_p()
allocated_ptrs = [] # Track successfully allocated pointers
# Host to device
self._hip.hipMemcpy(
d_a, input_np.ctypes.data, input_np.nbytes, self.HIP_MEMCPY_H2D
)
self._hip.hipMemcpy(
d_b, weight_np.ctypes.data, weight_np.nbytes, self.HIP_MEMCPY_H2D
)
self._hip.hipDeviceSynchronize()
try:
# Allocate input
ret = self._hip.hipMalloc(ctypes.byref(d_a), input_np.nbytes)
if ret != 0:
raise RuntimeError(
f"hipMalloc failed for input (code {ret}, size {input_np.nbytes})"
)
allocated_ptrs.append(d_a)
# Launch kernel
time_ms = self._dispatch_lib.run(d_a.value, d_b.value, d_c.value, problem)
self._hip.hipDeviceSynchronize()
# Allocate weight
ret = self._hip.hipMalloc(ctypes.byref(d_b), weight_np.nbytes)
if ret != 0:
raise RuntimeError(
f"hipMalloc failed for weight (code {ret}, size {weight_np.nbytes})"
)
allocated_ptrs.append(d_b)
result = GroupedConvResult()
# Allocate output
ret = self._hip.hipMalloc(ctypes.byref(d_c), output_size)
if ret != 0:
raise RuntimeError(
f"hipMalloc failed for output (code {ret}, size {output_size})"
)
allocated_ptrs.append(d_c)
if time_ms > 0:
# Device to host
self._hip.hipMemcpy(
output_np.ctypes.data, d_c, output_size, self.HIP_MEMCPY_D2H
# Host to device
ret = self._hip.hipMemcpy(
d_a, input_np.ctypes.data, input_np.nbytes, self.HIP_MEMCPY_H2D
)
if ret != 0:
raise RuntimeError(f"hipMemcpy H2D failed for input (code {ret})")
ret = self._hip.hipMemcpy(
d_b, weight_np.ctypes.data, weight_np.nbytes, self.HIP_MEMCPY_H2D
)
if ret != 0:
raise RuntimeError(f"hipMemcpy H2D failed for weight (code {ret})")
self._hip.hipDeviceSynchronize()
# Launch kernel
time_ms = self._dispatch_lib.run(
d_a.value, d_b.value, d_c.value, problem
)
self._hip.hipDeviceSynchronize()
result.success = True
result.time_ms = time_ms
result.tflops = problem.flops / (time_ms * 1e9)
result.output = output_np
else:
result.error = (
"unsupported"
if time_ms == -3.0
else "no kernel"
if time_ms == -2.0
else f"error (code {time_ms})"
)
# Free GPU memory
self._hip.hipFree(d_a)
self._hip.hipFree(d_b)
self._hip.hipFree(d_c)
result = GroupedConvResult()
return result
if time_ms > 0:
# Device to host
ret = self._hip.hipMemcpy(
output_np.ctypes.data, d_c, output_size, self.HIP_MEMCPY_D2H
)
if ret != 0:
raise RuntimeError(
f"hipMemcpy D2H failed for output (code {ret})"
)
self._hip.hipDeviceSynchronize()
result.success = True
result.time_ms = time_ms
result.tflops = problem.flops / (time_ms * 1e9)
result.output = output_np
else:
result.error = (
"unsupported"
if time_ms == -3.0
else "no kernel"
if time_ms == -2.0
else f"error (code {time_ms})"
)
return result
finally:
# CRITICAL: Only free successfully allocated pointers
for ptr in allocated_ptrs:
if ptr.value: # Only free non-null pointers
self._hip.hipFree(ptr)
except Exception as e:
return GroupedConvResult(error=str(e))
@@ -877,7 +999,8 @@ class GroupedConvRegistry:
key = (cfg.variant, cfg.ndim_spatial)
if key in runners:
continue
runner = GpuGroupedConvRunner(lib_path=str(lib.path))
runner = GpuGroupedConvRunner(lib_path=str(lib))
runner._ensure_initialized()
if runner.is_available():
runners[key] = runner
return runners
@@ -1135,11 +1258,13 @@ def _run_hipcc_subprocess(args: dict) -> Tuple[bool, Optional[Path], str]:
try:
res_c = subprocess.run(compile_cmd, capture_output=True, text=True, timeout=300)
if res_c.returncode != 0:
return False, None, f"Compile failed: {res_c.stderr[:400]}"
err = (res_c.stderr or res_c.stdout or "").rstrip()
return False, None, f"Compile failed (rc={res_c.returncode}):\n{err}"
res_l = subprocess.run(link_cmd, capture_output=True, text=True, timeout=300)
if res_l.returncode != 0:
return False, None, f"Link failed: {res_l.stderr[:400]}"
err = (res_l.stderr or res_l.stdout or "").rstrip()
return False, None, f"Link failed (rc={res_l.returncode}):\n{err}"
return True, lib_path, ""
except subprocess.TimeoutExpired:
@@ -1165,8 +1290,8 @@ def _run_conv_codegen_subprocess(args: dict) -> Tuple[bool, Optional[str], str]:
try:
res = subprocess.run(args["cmd"], capture_output=True, text=True, timeout=300)
if res.returncode != 0:
err = (res.stderr or res.stdout or "").strip()[:500]
return False, None, f"Codegen failed: {err}"
err = (res.stderr or res.stdout or "").rstrip()
return False, None, f"Codegen failed (rc={res.returncode}):\n{err}"
generated = sorted(
out_dir.glob("grouped_conv_*.hpp"),
@@ -1202,6 +1327,10 @@ def _config_key(c: GroupedConvKernelConfig) -> Tuple[Any, ...]:
c.pipeline,
c.epilogue,
c.scheduler,
c.num_groups_to_merge,
c.double_smem_buffer,
c.split_image,
c.two_stage,
)
@@ -1400,7 +1529,6 @@ class GroupedConvCodegenRunner:
verbose: bool = True,
) -> List[Optional[Path]]:
import sys
from concurrent.futures import ProcessPoolExecutor, as_completed
if not configs:
return []
@@ -1425,8 +1553,8 @@ class GroupedConvCodegenRunner:
if verbose:
print(
f"Generating {len(configs)} grouped-conv kernels in parallel "
f"(workers={self.max_workers})..."
f"Generating {len(configs)} grouped-conv kernels with "
f"{self.max_workers} threads (out-of-order)..."
)
gen_jobs: List[Dict[str, Any]] = []
@@ -1473,31 +1601,47 @@ class GroupedConvCodegenRunner:
c.scheduler,
"--epilogue",
c.epilogue,
"--num-groups-to-merge",
str(c.num_groups_to_merge),
"--double-smem-buffer",
"true" if c.double_smem_buffer else "false",
]
if c.split_image:
cmd.append("--split-image")
if c.two_stage:
cmd.append("--two-stage")
gen_jobs.append({"cmd": cmd, "output_dir": str(cfg_dir)})
generated_headers: List[Optional[Path]] = [None] * len(configs)
with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
# Phase 1 codegen: each worker just calls subprocess.run() to invoke the
# codegen script. The wait releases the GIL, so threads give true parallelism
# without the fork-after-HIP risk of ProcessPoolExecutor.
print_lock = threading.Lock()
with ThreadPoolExecutor(max_workers=self.max_workers) as ex:
futures = {
executor.submit(_run_conv_codegen_subprocess, job): idx
ex.submit(_run_conv_codegen_subprocess, job): idx
for idx, job in enumerate(gen_jobs)
}
for future in as_completed(futures):
idx = futures[future]
ok, header_path, err = future.result()
for fut in as_completed(futures):
idx = futures[fut]
ok, header_path, err = fut.result()
if ok and header_path:
generated_headers[idx] = Path(header_path)
if verbose:
print(f" OK [{idx}] codegen: {Path(header_path).name}")
with print_lock:
print(f" OK [{idx}] codegen: {Path(header_path).name}")
else:
if verbose:
print(f" FAIL [{idx}] codegen: {err}")
with print_lock:
print(f" FAIL [{idx}] codegen: {err}")
if verbose:
compile_count = sum(1 for h in generated_headers if h is not None)
print(
f"Compiling {compile_count} grouped-conv libraries in parallel "
f"(workers={self.max_workers})..."
f"Compiling {compile_count} grouped-conv libraries with "
f"{self.max_workers} threads (out-of-order)..."
)
compile_jobs: List[Dict[str, Any]] = []
@@ -1511,9 +1655,20 @@ class GroupedConvCodegenRunner:
dispatch_header = cfg_dir / "conv_python_dispatch.hpp"
_write_single_conv_dispatch_header(c, hdr_path, dispatch_header)
# Build suffix with all distinguishing config options
suffix = ""
if c.num_groups_to_merge != 1:
suffix += f"_gm{c.num_groups_to_merge}"
if c.double_smem_buffer:
suffix += "_dsb"
if c.split_image:
suffix += "_si"
if c.two_stage:
suffix += "_2stage"
lib_name = (
f"libdispatcher_conv_{c.variant}_{c.ndim_spatial}d_{c.dtype}_"
f"{c.tile_str}_{c.wave_str}_{c.warp_str}_{c.pipeline}_{c.scheduler}.so"
f"{c.tile_str}_{c.wave_str}_{c.warp_str}_{c.pipeline}_{c.scheduler}{suffix}.so"
)
lib_path = self.build_dir / "examples" / lib_name
obj_file = lib_path.with_suffix(".o")
@@ -1563,25 +1718,36 @@ class GroupedConvCodegenRunner:
)
results_map: Dict[int, Optional[Path]] = {i: None for i in range(len(configs))}
with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
# Phase 1 compile: workers shell out to hipcc, releasing the GIL while
# waiting. Threads give true parallelism here; ProcessPool would risk
# fork() corrupting any HIP state the parent might have loaded.
with ThreadPoolExecutor(max_workers=self.max_workers) as ex:
futures = {
executor.submit(_run_hipcc_subprocess, job): j
ex.submit(_run_hipcc_subprocess, job): j
for j, job in enumerate(compile_jobs)
}
for future in as_completed(futures):
job_idx = futures[future]
idx = compile_to_input_index[job_idx]
success, lib_path, err = future.result()
for fut in as_completed(futures):
j = futures[fut]
idx = compile_to_input_index[j]
success, lib_path, err = fut.result()
if success and lib_path:
results_map[idx] = Path(lib_path)
if verbose:
status = "OK" if success else f"FAIL ({err})"
name = (
Path(lib_path).name
if success and lib_path
else compile_jobs[job_idx]["config_name"]
else compile_jobs[j]["config_name"]
)
print(f" {status} {name}")
with print_lock:
if success:
print(f" OK {name}")
else:
# Print the full multi-line error indented for readability
# so users don't have to monkey-patch to see real compile output.
print(f" FAIL {name}")
for line in (err or "").splitlines() or [""]:
print(f" {line}")
return [results_map.get(i) for i in range(len(configs))]
@@ -1659,26 +1825,30 @@ def setup_multiple_grouped_conv_dispatchers(
configs: List[GroupedConvKernelConfig],
verbose: bool = True,
max_workers: Optional[int] = None,
) -> List[Optional[GroupedConvDispatcherLib]]:
) -> List[Optional[Path]]:
"""
Setup multiple grouped-conv dispatchers in parallel.
Setup multiple grouped-conv dispatchers.
This keeps architecture filtering strict:
1. Validate + auto-correct each requested config
2. Query codegen's arch-valid config set for each (arch, dtype, variant, ndim)
3. Map each request to nearest valid config
4. Parallel codegen + parallel compile
Returns library paths WITHOUT loading them, to avoid GPU context during compilation.
This mirrors FMHA design: keep GPU context out of JIT phase entirely.
Architecture filtering workflow:
1. Validate each requested config via validate_grouped_conv_config; if invalid,
attempt auto_correct_grouped_conv_config. Drop configs that remain invalid.
2. Trust the (possibly auto-corrected) config as-is. Knobs such as scheduler,
num_groups_to_merge, double_smem_buffer, split_image, two_stage are preserved
exactly as requested -- no remap to a hardcoded "default" set.
3. Threaded codegen + threaded compile (workers shell out via subprocess,
which releases the GIL; threads avoid the fork-after-HIP risk that
ProcessPoolExecutor would have).
4. Return paths (NOT loaded libraries).
Returns:
List of paths to compiled .so files (or None for failed configs)
"""
if not configs:
return []
codegen_script = (
Path(__file__).parent.parent / "codegen" / "unified_grouped_conv_codegen.py"
)
arch_valid_cache: Dict[
Tuple[str, str, str, int], List[GroupedConvKernelConfig]
] = {}
selected_configs: List[Optional[GroupedConvKernelConfig]] = []
for i, original in enumerate(configs):
c = copy.deepcopy(original)
@@ -1714,34 +1884,10 @@ def setup_multiple_grouped_conv_dispatchers(
c.scheduler = str(_first(trait_cfg.get("scheduler", c.scheduler)))
c.epilogue = str(_first(trait_cfg.get("epilogue", c.epilogue)))
cache_key = (c.arch, c.dtype, c.variant, c.ndim_spatial)
if cache_key not in arch_valid_cache:
arch_valid_cache[cache_key] = _list_arch_valid_grouped_conv_configs(
codegen_script=codegen_script,
arch=c.arch,
dtype=c.dtype,
variant=c.variant,
ndim_spatial=c.ndim_spatial,
)
if verbose and not arch_valid_cache[cache_key]:
print(
f" FAIL [{i}] no arch-valid configs listed for "
f"{c.arch}/{c.dtype}/{c.variant}/{c.ndim_spatial}d"
)
candidates = arch_valid_cache[cache_key]
if not candidates:
selected_configs.append(None)
continue
selected = _select_best_arch_valid_conv_config(c, candidates)
if verbose and _config_key(selected) != _config_key(c):
print(
f" INFO [{i}] mapped to arch-valid config: "
f"{selected.tile_str} {selected.wave_str} {selected.warp_str} "
f"{selected.pipeline}/{selected.scheduler}/{selected.epilogue}"
)
selected_configs.append(selected)
# Trust the validated config -- no remap to a hardcoded arch-valid set.
# Knobs (num_groups_to_merge, double_smem_buffer, split_image, two_stage)
# and scheduler choice are preserved exactly as requested.
selected_configs.append(c)
unique_configs: List[GroupedConvKernelConfig] = []
unique_index_by_key: Dict[Tuple[Any, ...], int] = {}
@@ -1761,33 +1907,32 @@ def setup_multiple_grouped_conv_dispatchers(
unique_configs, verbose=verbose
)
libs: List[Optional[GroupedConvDispatcherLib]] = []
loaded_cache: Dict[int, Optional[GroupedConvDispatcherLib]] = {}
# Map unique lib paths back to input order
# DO NOT load libraries here - just return paths
lib_paths: List[Optional[Path]] = []
path_cache: Dict[int, Optional[Path]] = {}
for input_idx, unique_idx in enumerate(input_to_unique):
if unique_idx is None:
libs.append(None)
lib_paths.append(None)
continue
if unique_idx in loaded_cache:
libs.append(loaded_cache[unique_idx])
if unique_idx in path_cache:
lib_paths.append(path_cache[unique_idx])
continue
path = (
unique_lib_paths[unique_idx] if unique_idx < len(unique_lib_paths) else None
)
disp: Optional[GroupedConvDispatcherLib] = None
if path and path.exists():
try:
lib = ctypes.CDLL(str(path))
disp = GroupedConvDispatcherLib(lib, path)
disp.initialize()
except Exception as e:
if verbose:
print(f" FAIL [{input_idx}] failed to load {path}: {e}")
loaded_cache[unique_idx] = disp
libs.append(disp)
# Validate path exists but don't load it
if path and not path.exists():
if verbose:
print(f" FAIL [{input_idx}] library not found: {path}")
path = None
return libs
path_cache[unique_idx] = path
lib_paths.append(path)
return lib_paths
def detect_gpu_arch() -> str: