Files
composable_kernel/dispatcher/scripts/compile_grouped_conv_examples.py
Vidyasagar Ananthan 920acd2c12 [rocm-libraries] ROCm/rocm-libraries#5168 (commit 8b5afcb)
[CK] [CK_Tile] Add GroupConv to Kernel Dispatcher

## Motivation

This PR adds CK Tile group convolution (forward, backward-data,
backward-weight) support to the kernel dispatcher, matching and unifying
with the existing dispatcher GEMM infrastructure in architecture and
usability. The dispatcher provides a unified kernel dispatch system with
both C++ and Python frontends, and until now only supported GEMM
operations. This PR enables framework integrators to use the same
declarative kernel workflow for convolutions as they do for GEMM:
declare kernels, build a registry JIT, select kernels within the
registry at runtime, and dispatch to GPU. Future PRs will include
runtime kernel selection heuristics for autotuning of kernel parameters
based on (problem, hardware arch).

## Technical Details

Grouped convolution support has been added to the CK Tile Dispatcher
with generated_conv_backend.hpp enabling dispatcher.run(in, wei, out,
problem) for all 6 conv variants (fwd/bwdd/bwdw x 2D/3D), runtime
heuristic kernel selection, and GroupedConvKernelKey with full
ConvConfigBase fields. Python side adds parallel JIT via
registry.build(max_workers) and heuristic registry.select(). Includes 7
C++ and 6 Python examples covering all directions with CPU reference
validation, and shared infrastructure improvements (BaseRegistry CRTP,
structured exceptions). As a sanity check, JIT compile time for a
single kernel remains the same, and for multiple kernels there is better
parallelism:
Kernels | 1 worker | 8 workers
1 | 7.7 s | 7.7 s
2 | 15.9 s | 8.2 s
4 | 33.4 s | 9.7 s
6 | 52.3 s | 10.2 s

## Test Plan

145 ephemeral unit tests have been added to test basic functionality.
All 30 examples/integration tests run end-to-end on gfx950 (MI350): 7
C++ conv, 7 C++ GEMM, 6 Python conv, 10 Python GEMM. CPU reference
validation for forward, backward-data, and backward-weight (2D) passes
in both the C++ and Python examples.

## Test Result

30 examples pass. Peak performance: 132 TFLOPS (Batch-32 forward 56x56),
53 TFLOPS (pointwise 1x1). CPU reference accuracy: max_abs_diff < 0.002
for all directions (fp16 vs fp32 reference).

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-09 17:39:35 +00:00

883 lines
31 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Self-contained build script for C++ grouped convolution examples.
Parses DECL_GROUPED_CONV_KERNEL_SET declarations from source files,
generates the needed kernels, and compiles the example.
Includes validation and auto-correction via wildcard expansion.
Usage:
python3 compile_grouped_conv_examples.py examples/grouped_conv/cpp/02_grouped_conv_forward.cpp
python3 compile_grouped_conv_examples.py examples/grouped_conv/cpp/03_grouped_conv_validation.cpp --no-compile
"""
import argparse
import os
import re
import subprocess
import sys
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Optional
# Setup paths
SCRIPT_DIR = Path(__file__).parent.resolve()
DISPATCHER_DIR = SCRIPT_DIR.parent
CK_ROOT = DISPATCHER_DIR.parent
sys.path.insert(0, str(DISPATCHER_DIR / "python"))
sys.path.insert(0, str(DISPATCHER_DIR / "codegen"))
from dispatcher_common import ( # noqa: E402
print_phase,
print_success,
print_error,
print_info,
find_hipcc,
get_arch_filter_data,
get_build_dir,
get_ck_root,
get_dispatcher_root,
get_generated_kernels_dir,
)
def extract_grouped_conv_declarations(source_file: Path) -> list:
    """Extract DECL_GROUPED_CONV_KERNEL_SET declarations from C++ source.

    Two .add(...) spellings are recognized inside each declaration set:
      1. Simple: .add("dtype", "layout", "conv_type", tile_k, tile_c)
         -> remaining fields are filled with hard-coded defaults.
      2. Full:   .add(ConvSig()..., ConvAlgo()..., "arch")
         -> each ConvSig/ConvAlgo builder call is parsed individually,
            with defaults for anything the source omits.

    Args:
        source_file: Path to the C++ example source to scan.

    Returns:
        A list of dicts, one per parsed .add(...), each carrying the full
        kernel parameter set consumed by the codegen step.
    """
    content = source_file.read_text()
    declarations = []
    # Pattern: DECL_GROUPED_CONV_KERNEL_SET(name, .add(...).add(...))
    # Find all DECL_GROUPED_CONV_KERNEL_SET blocks by matching parentheses
    pattern_start = r"DECL_GROUPED_CONV_KERNEL_SET\s*\(\s*(\w+)\s*,"
    for match in re.finditer(pattern_start, content):
        set_name = match.group(1)
        start_pos = match.end()
        # Find matching closing paren by counting parens
        paren_count = 1  # We're already inside the first paren
        end_pos = start_pos
        for i, c in enumerate(content[start_pos:]):
            if c == "(":
                paren_count += 1
            elif c == ")":
                paren_count -= 1
                if paren_count == 0:
                    end_pos = start_pos + i
                    break
        set_body = content[start_pos:end_pos]
        # Pattern 1: Simple add("dtype", "layout", "conv_type", tile_k, tile_c)
        simple_add = (
            r'\.add\s*\(\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*(\d+)\s*,\s*(\d+)'
        )
        for add_match in re.finditer(simple_add, set_body):
            conv_type = add_match.group(3)
            # Backward passes default to compv3; everything else to compv4.
            default_pipeline = (
                "compv3" if conv_type in ("bwd_data", "bwd_weight") else "compv4"
            )
            declarations.append(
                {
                    "set": set_name,
                    "dtype": add_match.group(1),
                    "layout": add_match.group(2),
                    "conv_type": conv_type,
                    "tile_k": int(add_match.group(4)),
                    "tile_c": int(add_match.group(5)),
                    "num_dims": 2,
                    "pipeline": default_pipeline,
                    "scheduler": "intrawave",
                    "wave_m": 2,
                    "wave_n": 2,
                    "wave_k": 1,
                    "warp_m": 32,
                    "warp_n": 32,
                    "warp_k": 16,
                    "arch": "gfx942",
                }
            )
        # Pattern 2: Full ConvSig()/ConvAlgo() specification
        # Find all .add( positions that start with ConvSig()
        full_add = r"\.add\s*\(\s*ConvSig\(\)"
        add_positions = [m.start() for m in re.finditer(full_add, set_body)]
        for pos in add_positions:
            # Find matching closing paren by counting parens
            paren_count = 0
            in_add = False  # becomes True once the .add( open paren is seen
            end = pos
            for i, c in enumerate(set_body[pos:]):
                if c == "(":
                    paren_count += 1
                    in_add = True
                elif c == ")":
                    paren_count -= 1
                if in_add and paren_count == 0:
                    end = pos + i + 1
                    break
            add_str = set_body[pos:end]
            # Extract signature part (between ConvSig() and ConvAlgo())
            sig_match = re.search(r"ConvSig\(\)(.*?)ConvAlgo\(\)", add_str, re.DOTALL)
            if not sig_match:
                continue
            sig_str = sig_match.group(1)
            # Extract algorithm part (between ConvAlgo() and arch string)
            algo_match = re.search(
                r'ConvAlgo\(\)(.*?),\s*"(\w+)"\s*\)', add_str, re.DOTALL
            )
            if not algo_match:
                continue
            algo_str = algo_match.group(1)
            arch = algo_match.group(2)
            # Parse signature
            dtype = "fp16"
            dtype_match = re.search(r'\.dtype\s*\(\s*"(\w+)"', sig_str)
            if dtype_match:
                dtype = dtype_match.group(1)
            layout = "nhwgc"
            layout_match = re.search(r'\.layout\s*\(\s*"(\w+)"', sig_str)
            if layout_match:
                layout = layout_match.group(1)
            conv_type = "forward"
            conv_type_match = re.search(r'\.conv_type\s*\(\s*"(\w+)"', sig_str)
            if conv_type_match:
                conv_type = conv_type_match.group(1)
            num_dims = 2
            dims_match = re.search(r"\.dims\s*\(\s*(\d+)", sig_str)
            if dims_match:
                num_dims = int(dims_match.group(1))
            # Parse algorithm
            tile_k, tile_c = 128, 128
            # .tile(first, tile_k, tile_c) -- the first argument is skipped here
            tile_match = re.search(
                r"\.tile\s*\(\s*\d+\s*,\s*(\d+)\s*,\s*(\d+)", algo_str
            )
            if tile_match:
                tile_k = int(tile_match.group(1))
                tile_c = int(tile_match.group(2))
            wave_m, wave_n, wave_k = 2, 2, 1
            # Third .wave() argument is optional; defaults to 1.
            wave_match = re.search(
                r"\.wave\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?", algo_str
            )
            if wave_match:
                wave_m = int(wave_match.group(1))
                wave_n = int(wave_match.group(2))
                wave_k = int(wave_match.group(3) or 1)
            warp_m, warp_n, warp_k = 32, 32, 16
            # Third .warp() argument is optional; defaults to 16.
            warp_match = re.search(
                r"\.warp\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?", algo_str
            )
            if warp_match:
                warp_m = int(warp_match.group(1))
                warp_n = int(warp_match.group(2))
                warp_k = int(warp_match.group(3) or 16)
            pipeline = "compv4"
            pipeline_match = re.search(r'\.pipeline\s*\(\s*"(\w+)"', algo_str)
            if pipeline_match:
                pipeline = pipeline_match.group(1)
            scheduler = "intrawave"
            scheduler_match = re.search(r'\.scheduler\s*\(\s*"(\w+)"', algo_str)
            if scheduler_match:
                scheduler = scheduler_match.group(1)
            # Parse additional parameters
            vector_a, vector_b, vector_c = 4, 8, 8
            vector_match = re.search(
                r"\.vector_sizes\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)", algo_str
            )
            if vector_match:
                vector_a = int(vector_match.group(1))
                vector_b = int(vector_match.group(2))
                vector_c = int(vector_match.group(3))
            block_per_cu = 1
            block_per_cu_match = re.search(r"\.block_per_cu\s*\(\s*(\d+)", algo_str)
            if block_per_cu_match:
                block_per_cu = int(block_per_cu_match.group(1))
            memory_op = "set"
            memory_op_match = re.search(r'\.memory_op\s*\(\s*"(\w+)"', algo_str)
            if memory_op_match:
                memory_op = memory_op_match.group(1)
            epilogue = "cshuffle"
            epilogue_match = re.search(r'\.epilogue\s*\(\s*"(\w+)"', algo_str)
            if epilogue_match:
                epilogue = epilogue_match.group(1)
            # Parse num_wave_groups (for V5 pipeline)
            num_wave_groups = 1
            nwg_match = re.search(r"\.num_wave_groups\s*\(\s*(\d+)", algo_str)
            if nwg_match:
                num_wave_groups = int(nwg_match.group(1))
            # Parse num_groups_to_merge (for merged group grouped convolution)
            num_groups_to_merge = 1
            ngm_match = re.search(r"\.num_groups_to_merge\s*\(\s*(\d+)", algo_str)
            if ngm_match:
                num_groups_to_merge = int(ngm_match.group(1))
            # Parse double_smem_buffer (for V4 pipeline)
            double_smem_buffer = False
            dsb_match = re.search(
                r"\.double_smem_buffer\s*\(\s*(true|false)", algo_str, re.I
            )
            if dsb_match:
                double_smem_buffer = dsb_match.group(1).lower() == "true"
            # Parse padding flags
            pad_m, pad_n, pad_k = True, True, True
            padding_match = re.search(
                r"\.padding\s*\(\s*(true|false)\s*,\s*(true|false)\s*,\s*(true|false)",
                algo_str,
                re.I,
            )
            if padding_match:
                pad_m = padding_match.group(1).lower() == "true"
                pad_n = padding_match.group(2).lower() == "true"
                pad_k = padding_match.group(3).lower() == "true"
            declarations.append(
                {
                    "set": set_name,
                    "dtype": dtype,
                    "layout": layout,
                    "conv_type": conv_type,
                    "tile_k": tile_k,
                    "tile_c": tile_c,
                    "num_dims": num_dims,
                    "pipeline": pipeline,
                    "scheduler": scheduler,
                    "wave_m": wave_m,
                    "wave_n": wave_n,
                    "wave_k": wave_k,
                    "warp_m": warp_m,
                    "warp_n": warp_n,
                    "warp_k": warp_k,
                    "vector_a": vector_a,
                    "vector_b": vector_b,
                    "vector_c": vector_c,
                    "block_per_cu": block_per_cu,
                    "memory_op": memory_op,
                    "epilogue": epilogue,
                    "num_wave_groups": num_wave_groups,
                    "num_groups_to_merge": num_groups_to_merge,
                    "double_smem_buffer": double_smem_buffer,
                    "pad_m": pad_m,
                    "pad_n": pad_n,
                    "pad_k": pad_k,
                    "arch": arch,
                }
            )
    return declarations
# =============================================================================
# VALIDATION AND AUTO-CORRECTION
# =============================================================================
def is_grouped_conv_wildcard_declaration(decl: dict) -> bool:
    """Return True when any expandable field holds a wildcard (-1 or '*')."""
    return any(
        decl.get(key) == -1 or decl.get(key) == "*"
        for key in ("wave_m", "wave_n", "warp_m", "warp_n", "pipeline", "scheduler")
    )
def validate_grouped_conv_kernel_config(decl: dict, arch: str = "gfx942") -> tuple:
    """Validate a fully-specified grouped conv config against the arch filter.

    Wildcard declarations pass unconditionally; wildcard expansion filters
    out invalid combinations later.

    Returns:
        (is_valid, error_message) -- error_message is None when valid,
        otherwise a newline-joined description of every failed check.
    """
    # Wildcards are resolved by expansion, not validated here.
    if is_grouped_conv_wildcard_declaration(decl):
        return (True, None)
    arch_data = get_arch_filter_data()
    pipeline = decl.get("pipeline", "compv4")
    scheduler = decl.get("scheduler", "intrawave")
    dtype = decl.get("dtype", "fp16")
    wave_cfg = [decl.get("wave_m", 2), decl.get("wave_n", 2), decl.get("wave_k", 1)]
    warp_cfg = [
        decl.get("warp_m", 32),
        decl.get("warp_n", 32),
        decl.get("warp_k", 16),
    ]
    errors = []
    # 1. Trait combination (pipeline, epilogue, scheduler).
    if (pipeline, "cshuffle", scheduler) in arch_data["trait_unsupported"]:
        errors.append(
            f"Unsupported trait combination: pipeline={pipeline}, scheduler={scheduler}\n"
            f" Valid schedulers for {pipeline}: intrawave"
        )
    # 2. Wave (warp-count) configuration for this arch.
    warp_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]])
    if wave_cfg not in warp_combos:
        valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_combos)
        errors.append(
            f"Unsupported wave configuration [{wave_cfg[0]},{wave_cfg[1]},{wave_cfg[2]}] for {arch}\n"
            f" Valid wave configs: {valid_str}"
        )
    # 3. Warp tile configuration for this arch and dtype.
    acc_dtype = "int32" if dtype == "int8" else "fp32"
    dtype_key = f"{dtype}_{dtype}_{acc_dtype}"
    warp_tile_combos = (
        arch_data["warp_tile_combos"]
        .get(arch, {})
        .get(dtype_key, [[32, 32, 16], [16, 16, 16], [16, 16, 32]])
    )
    if warp_cfg not in warp_tile_combos:
        valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_tile_combos[:5])
        errors.append(
            f"Unsupported warp tile [{warp_cfg[0]},{warp_cfg[1]},{warp_cfg[2]}] for {arch}/{dtype}\n"
            f" Valid warp tiles: {valid_str}"
        )
    # 4. Architecture is recognized at all.
    if arch not in arch_data["supported_archs"]:
        errors.append(
            f"Unsupported architecture: {arch}\n"
            f" Supported: {', '.join(arch_data['supported_archs'])}"
        )
    return (False, "\n".join(errors)) if errors else (True, None)
def expand_grouped_conv_declaration_with_arch_filter(
    decl: dict, arch: str = "gfx942"
) -> list:
    """Expand a grouped conv declaration with wildcards into valid configurations.

    Wildcards:
      - wave_m/wave_n = -1: try every valid wave config for this arch
      - warp_m/warp_n = -1: try every valid warp tile for this arch/dtype
      - pipeline/scheduler = "*": try every valid combination

    Returns:
        A single-element list holding the first valid fully-specified
        declaration, or [decl] unchanged when no valid expansion exists
        (it will then fail validation later).
    """
    arch_data = get_arch_filter_data()
    dtype = decl.get("dtype", "fp16")
    acc_dtype = "int32" if dtype == "int8" else "fp32"
    dtype_key = f"{dtype}_{dtype}_{acc_dtype}"
    # Per-axis candidates: either the arch-filter lists (wildcard) or the
    # single explicitly requested value.
    if decl.get("wave_m", 2) == -1 or decl.get("wave_n", 2) == -1:
        wave_options = arch_data["warp_combos"].get(arch, [[2, 2, 1]])
    else:
        wave_options = [
            [decl.get("wave_m", 2), decl.get("wave_n", 2), decl.get("wave_k", 1)]
        ]
    if decl.get("warp_m", 32) == -1 or decl.get("warp_n", 32) == -1:
        warp_options = (
            arch_data["warp_tile_combos"]
            .get(arch, {})
            .get(dtype_key, [[32, 32, 16], [16, 16, 16]])
        )
    else:
        warp_options = [
            [decl.get("warp_m", 32), decl.get("warp_n", 32), decl.get("warp_k", 16)]
        ]
    if decl.get("pipeline", "compv4") == "*":
        pipeline_options = ["compv3", "compv4"]
    else:
        pipeline_options = [decl.get("pipeline", "compv4")]
    if decl.get("scheduler", "intrawave") == "*":
        scheduler_options = ["intrawave"]  # interwave often unsupported
    else:
        scheduler_options = [decl.get("scheduler", "intrawave")]
    unsupported = arch_data["trait_unsupported"]
    expanded = []
    for wave in wave_options:
        for warp in warp_options:
            for pipeline in pipeline_options:
                for scheduler in scheduler_options:
                    # Drop trait combinations known to be invalid.
                    if (pipeline, "cshuffle", scheduler) in unsupported:
                        continue
                    candidate = decl.copy()
                    candidate["wave_m"] = wave[0]
                    candidate["wave_n"] = wave[1]
                    candidate["wave_k"] = wave[2]
                    candidate["warp_m"] = warp[0]
                    candidate["warp_n"] = warp[1]
                    candidate["warp_k"] = warp[2]
                    candidate["pipeline"] = pipeline
                    candidate["scheduler"] = scheduler
                    expanded.append(candidate)
    if not expanded:
        # Nothing valid: hand back the original so validation reports it.
        return [decl]
    # Only the first valid config is used for grouped conv.
    return expanded[:1]
def validate_and_expand_grouped_conv_declarations(
    declarations: list, arch: str, verbose: bool = False
) -> list:
    """Validate declarations and auto-correct invalid ones via wildcard expansion.

    Fields that fail validate_grouped_conv_kernel_config are rewritten IN
    PLACE to wildcard values (-1 / "*"), then every declaration is passed
    through expand_grouped_conv_declaration_with_arch_filter.

    Args:
        declarations: declaration dicts; mutated in place when auto-corrected.
        arch: fallback architecture for declarations without an "arch" field.
        verbose: accepted but never read in this function.

    Returns:
        The expanded, fully-specified declaration list.
    """
    print(f"\n Validating against {arch} arch filter...")
    wildcard_count = 0
    invalid_count = 0
    auto_corrections = []  # collected per declaration; not read afterwards
    for decl in declarations:
        decl_arch = decl.get("arch", arch)
        decl_name = (
            f"{decl['dtype']}_{decl['conv_type']}_{decl['tile_k']}x{decl['tile_c']}"
        )
        # Check for wildcards
        if is_grouped_conv_wildcard_declaration(decl):
            wildcard_count += 1
            continue
        is_valid, error_msg = validate_grouped_conv_kernel_config(decl, decl_arch)
        if not is_valid:
            print(f"\n WARNING Invalid grouped conv configuration: {decl_name}")
            # Parse the error and show specific auto-corrections
            corrections = []
            original_values = {}
            if "wave configuration" in error_msg.lower():
                original_values["wave"] = (
                    f"[{decl.get('wave_m', 2)}, {decl.get('wave_n', 2)}, {decl.get('wave_k', 1)}]"
                )
                decl["wave_m"] = -1
                decl["wave_n"] = -1
                corrections.append(
                    f"wave: {original_values['wave']} -> [wildcard expansion]"
                )
            if "warp tile" in error_msg.lower():
                original_values["warp"] = (
                    f"[{decl.get('warp_m', 32)}, {decl.get('warp_n', 32)}, {decl.get('warp_k', 16)}]"
                )
                decl["warp_m"] = -1
                decl["warp_n"] = -1
                corrections.append(
                    f"warp_tile: {original_values['warp']} -> [wildcard expansion]"
                )
            if "trait combination" in error_msg.lower():
                original_values["pipeline"] = decl.get("pipeline", "compv4")
                original_values["scheduler"] = decl.get("scheduler", "intrawave")
                decl["pipeline"] = "*"
                decl["scheduler"] = "*"
                corrections.append(
                    f"pipeline: {original_values['pipeline']} -> [wildcard expansion]"
                )
                corrections.append(
                    f"scheduler: {original_values['scheduler']} -> [wildcard expansion]"
                )
            # Print the auto-corrections
            print(" AUTO-CORRECTION:")
            for corr in corrections:
                print(f" - {corr}")
            auto_corrections.append((decl_name, corrections))
            invalid_count += 1
            wildcard_count += 1  # corrected decls now count as wildcards
    if invalid_count > 0:
        print(
            f"\n WARNING {invalid_count} invalid config(s) auto-corrected via wildcard expansion"
        )
    if wildcard_count > 0:
        print(
            f" OK {len(declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)"
        )
    else:
        print(f" OK All {len(declarations)} configurations valid")
    # Expand wildcards
    print("\n Expanding wildcards to valid configurations...")
    expanded_declarations = []
    for decl in declarations:
        decl_arch = decl.get("arch", arch)
        decl_name = (
            f"{decl['dtype']}_{decl['conv_type']}_{decl['tile_k']}x{decl['tile_c']}"
        )
        expanded = expand_grouped_conv_declaration_with_arch_filter(decl, decl_arch)
        expanded_declarations.extend(expanded)
        # NOTE(review): the expander returns expanded[:1], so len(expanded)
        # appears to never exceed 1 and this branch looks unreachable — confirm.
        if len(expanded) > 1:
            print(
                f" {decl_name}: expanded to {len(expanded)} valid configurations"
            )
            for exp in expanded[:3]:
                wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]"
                warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]"
                print(
                    f" -> wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}"
                )
            if len(expanded) > 3:
                print(f" ... and {len(expanded) - 3} more")
        elif is_grouped_conv_wildcard_declaration(decl) and len(expanded) == 1:
            exp = expanded[0]
            wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]"
            warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]"
            print(f" {decl_name}: -> wave={wave_str}, warp={warp_str}")
    if len(expanded_declarations) != len(declarations):
        print(
            f"\n Total: {len(declarations)} declarations -> {len(expanded_declarations)} configurations"
        )
    return expanded_declarations
def _generate_single_grouped_conv_kernel(args: tuple) -> tuple:
    """Generate one grouped conv kernel (picklable for ProcessPoolExecutor).

    Codegen modules are imported lazily inside the worker so each subprocess
    resolves them through its own sys.path.

    Args: (decl, output_dir_str, gpu_target)
    Returns: (idx, filepath_str or None, error_str or None)
    """
    decl, output_dir_str, gpu_target = args
    output_dir = Path(output_dir_str)
    idx = decl.get("_idx", 0)
    try:
        from codegen_common import TileConfig
        from unified_grouped_conv_codegen import (
            GroupedConvKernelConfig,
            GroupedConvTraitConfig,
            GroupedConvVariant,
            UnifiedGroupedConvCodegen,
        )
        # Map conv_type to variant
        variant = GroupedConvVariant.FORWARD
        if decl["conv_type"] == "bwd_data":
            variant = GroupedConvVariant.BACKWARD_DATA
        elif decl["conv_type"] == "bwd_weight":
            variant = GroupedConvVariant.BACKWARD_WEIGHT
        pipeline = decl.get("pipeline", "compv4")
        # K tile is doubled for compv4 — presumably tied to its buffering
        # scheme; confirm against the pipeline implementation.
        adj_tile_k = 64 * 2 if pipeline == "compv4" else 64
        # Create tile config (tile_m=tile_k, tile_n=tile_c for conv GEMM view)
        tile = TileConfig(
            tile_m=decl["tile_k"],
            tile_n=decl["tile_c"],
            tile_k=adj_tile_k,
            warp_m=decl["wave_m"],
            warp_n=decl["wave_n"],
            warp_k=decl.get("wave_k", 1),
            warp_tile_m=decl["warp_m"],
            warp_tile_n=decl["warp_n"],
            warp_tile_k=decl["warp_k"],
        )
        trait = GroupedConvTraitConfig(
            pipeline=pipeline,
            scheduler=decl["scheduler"],
            epilogue=decl.get("epilogue", "cshuffle"),
            double_smem_buffer=decl.get("double_smem_buffer", False),
            pad_m=decl.get("pad_m", True),
            pad_n=decl.get("pad_n", True),
            pad_k=decl.get("pad_k", True),
            num_groups_to_merge=decl.get("num_groups_to_merge", 1),
        )
        config = GroupedConvKernelConfig(
            tile=tile,
            trait=trait,
            variant=variant,
            ndim_spatial=decl["num_dims"],
            arch=decl.get("arch", gpu_target),
            vector_size_a=decl.get("vector_a", 4),
            vector_size_b=decl.get("vector_b", 8),
            vector_size_c=decl.get("vector_c", 8),
            block_per_cu=decl.get("block_per_cu", 1),
            num_wave_groups=decl.get("num_wave_groups", 1),
            num_groups_to_merge=decl.get("num_groups_to_merge", 1),
            double_smem_buffer=decl.get("double_smem_buffer", False),
        )
        codegen = UnifiedGroupedConvCodegen(output_dir, gpu_target=gpu_target)
        kernel_path, _ = codegen.generate_kernel(config, decl["dtype"], variant)
        return (idx, str(kernel_path), None)
    except Exception as e:
        # Report failures back to the parent instead of raising across the pool.
        return (idx, None, str(e))
def generate_grouped_conv_kernels(
    declarations: list,
    output_dir: Path,
    gpu_target: str = "gfx942",
    max_workers: Optional[int] = None,
) -> list:
    """Generate grouped convolution kernels using unified_grouped_conv_codegen.

    Kernels are generated in parallel with a ProcessPoolExecutor; each work
    item is handled by _generate_single_grouped_conv_kernel.

    Args:
        declarations: fully-expanded kernel declaration dicts.
        output_dir: destination for generated kernel files (created if needed).
        gpu_target: default GPU arch when a declaration has no "arch" field.
        max_workers: parallel processes; defaults to
            min(len(declarations), cpu_count), clamped to at least 1.

    Returns:
        List of Path objects for successfully generated kernel files
        (failures are printed, not raised).
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    # Nothing to do; also avoids ProcessPoolExecutor(max_workers=0),
    # which raises ValueError.
    if not declarations:
        return []
    # Prepare work items (add _idx so results can be traced back to inputs)
    work_items = []
    for idx, decl in enumerate(declarations):
        decl_copy = decl.copy()
        decl_copy["_idx"] = idx
        work_items.append((decl_copy, str(output_dir), gpu_target))
    # Clamp to >= 1: a caller-supplied 0/negative value would make the
    # executor constructor raise.
    max_workers = max(1, max_workers or min(len(work_items), os.cpu_count() or 4))
    generated = []
    failed = []
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(_generate_single_grouped_conv_kernel, w): w[0]["_idx"]
            for w in work_items
        }
        for future in as_completed(futures):
            idx, path, err = future.result()
            if path:
                generated.append(Path(path))
                print_info(f" Generated: {Path(path).name}")
            else:
                failed.append((idx, err))
                print_error(f" Failed kernel {idx + 1}: {err}")
    if failed:
        # Summarize at most three failures to keep the log readable.
        for idx, err in failed[:3]:
            print_error(f" Kernel {idx + 1}: {err[:200]}")
        if len(failed) > 3:
            print_error(f" ... and {len(failed) - 3} more failures")
    return generated
def compile_grouped_conv_example(
    source_file: Path,
    output_bin: Path,
    kernel_headers: list,
    hipcc: str,
    gpu_target: str,
) -> bool:
    """Compile the C++ example with generated kernels.

    Assembles a single hipcc invocation that force-includes every generated
    kernel header; on failure, prints up to five compiler error lines.

    Returns:
        True when hipcc exits with status 0, False otherwise.
    """
    kernel_dir = get_generated_kernels_dir()
    ck_root = get_ck_root()
    dispatcher_dir = get_dispatcher_root()
    # Include search paths: CK headers, dispatcher headers, generated kernels.
    includes = [
        f"-I{ck_root / 'include'}",
        f"-I{dispatcher_dir / 'include'}",
        f"-I{kernel_dir}",
    ]
    # Force-include every generated kernel header.
    kernel_includes = [
        flag for header in kernel_headers for flag in ("-include", str(header))
    ]
    # Signal to the example that kernels are compiled in.
    cmd = [hipcc, "-std=c++20", "-O2", f"--offload-arch={gpu_target}"]
    cmd += includes
    cmd += ["-DGROUPED_CONV_KERNEL_AVAILABLE=1"]
    cmd += kernel_includes
    cmd += ["-o", str(output_bin), str(source_file)]
    print_info(f" Compiling: {source_file.name}")
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode == 0:
        return True
    if result.stderr:
        error_lines = [
            ln for ln in result.stderr.split("\n") if "error:" in ln.lower()
        ][:5]
        for err_line in error_lines:
            print_error(f" {err_line}")
    return False
def main():
    """CLI entry point: extract declarations, validate/expand, generate,
    and optionally compile the example.

    Returns:
        Process exit code: 0 on success, 1 on any failure.
    """
    parser = argparse.ArgumentParser(
        description="Build C++ grouped convolution example with self-contained kernel generation"
    )
    parser.add_argument("source", help="Source file (.cpp)")
    parser.add_argument("--output", "-o", help="Output binary name")
    parser.add_argument("--gpu-target", default="gfx942", help="GPU target")
    parser.add_argument(
        "--no-compile", action="store_true", help="Only generate kernels, don't compile"
    )
    parser.add_argument("--verbose", "-v", action="store_true")
    parser.add_argument(
        "--jobs",
        "-j",
        type=int,
        default=None,
        help="Parallel jobs for kernel generation (default: cpu_count)",
    )
    args = parser.parse_args()
    # Resolve source file: relative paths are tried against the dispatcher
    # root first, then the current working directory.
    source_file = Path(args.source)
    if not source_file.is_absolute():
        candidates = [
            get_dispatcher_root() / args.source,
            Path.cwd() / args.source,
        ]
        for c in candidates:
            if c.exists():
                source_file = c
                break
    if not source_file.exists():
        print_error(f"Source file not found: {source_file}")
        return 1
    build_dir = get_build_dir()
    kernel_dir = get_generated_kernels_dir()
    # Output binary name defaults to the source file's stem.
    output_name = args.output or source_file.stem
    output_bin = build_dir / output_name
    print_success("=== Grouped Conv Example Builder (Self-Contained) ===")
    # Phase 1: Extract declarations
    print_phase(1, "Scanning for DECL_GROUPED_CONV_KERNEL_SET...")
    declarations = extract_grouped_conv_declarations(source_file)
    if not declarations:
        print_error(" No DECL_GROUPED_CONV_KERNEL_SET declarations found!")
        return 1
    print(f" Found {len(declarations)} kernel declaration(s):")
    for decl in declarations:
        name = f"{decl['dtype']}_{decl['conv_type']}_{decl['num_dims']}d_{decl['tile_k']}x{decl['tile_c']}"
        print(f" [{decl['set']}] {name}")
    # Phase 2: Validate and expand
    print_phase(2, "Validating and expanding declarations...")
    declarations = validate_and_expand_grouped_conv_declarations(
        declarations, args.gpu_target, args.verbose
    )
    print()
    # Phase 3: Generate kernels
    print_phase(3, "Generating kernels...")
    generated = generate_grouped_conv_kernels(
        declarations, kernel_dir, args.gpu_target, max_workers=args.jobs
    )
    if not generated:
        print_error(" No kernels generated!")
        return 1
    print(f" Generated {len(generated)} kernel file(s)")
    print()
    # Phase 4: Compile (optional)
    if args.no_compile:
        print_info("Skipping compilation (--no-compile)")
        print()
        print_success("=== Kernel Generation Complete ===")
        print(f"Kernels in: {kernel_dir}")
        return 0
    print_phase(4, "Compiling example...")
    hipcc_path = find_hipcc()
    if not hipcc_path:
        # No compiler available: print a manual command line and fail.
        print_error(" hipcc not found. Install ROCm or set HIPCC env var.")
        print(" To compile manually:")
        ck_root = get_dispatcher_root().parent
        print(
            f" hipcc -std=c++20 -O2 -I{ck_root / 'include'} -I{get_dispatcher_root() / 'include'} \\"
        )
        print(f" -I{kernel_dir} \\")
        for h in generated[:1]:
            print(f" -include {h} \\")
        print(" -DGROUPED_CONV_KERNEL_AVAILABLE=1 \\")
        print(f" --offload-arch={args.gpu_target} \\")
        print(f" {source_file} -o {output_bin}")
        return 1
    build_dir.mkdir(parents=True, exist_ok=True)
    if not compile_grouped_conv_example(
        source_file, output_bin, generated, hipcc_path, args.gpu_target
    ):
        print_error(" Compilation failed!")
        return 1
    print_success(f" Output: {output_bin}")
    print()
    print_success("=== Build Complete ===")
    print()
    print("Run with:")
    print(f" {output_bin}")
    return 0
if __name__ == "__main__":
sys.exit(main())