[rocm-libraries] ROCm/rocm-libraries#5168 (commit 8b5afcb)

[CK] [CK_Tile] Add GroupConv to Kernel Dispatcher

## Motivation

This PR adds CK Tile grouped convolution (forward, backward-data, and
backward-weight) support to the kernel dispatcher, mirroring the existing
GEMM dispatcher infrastructure in both architecture and usability. The
dispatcher provides a unified kernel dispatch system with C++ and Python
frontends, and until now supported only GEMM operations. With this PR,
framework integrators can use the same declarative kernel workflow for
convolutions as they do for GEMM: declare kernels, JIT-build a registry,
select kernels from the registry at runtime, and dispatch to the GPU.
Future PRs will add runtime kernel-selection heuristics that autotune
kernel parameters for a given (problem, hardware architecture) pair.
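
The declare → build → select → dispatch workflow above can be sketched as
follows. All class and method names here are hypothetical stand-ins modeled
on the PR description, not the real CK Tile dispatcher API:

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class ConvProblem:
    """Hypothetical problem descriptor used to pick a kernel at runtime."""
    ndim: int        # 2 or 3 spatial dims
    direction: str   # "fwd", "bwd_data", or "bwd_weight"


@dataclass
class Registry:
    """Hypothetical registry: kernels are declared, JIT-built, then selected."""
    declared: list = field(default_factory=list)
    built: dict = field(default_factory=dict)

    def declare(self, name: str, ndim: int, direction: str) -> None:
        self.declared.append((name, ndim, direction))

    def build(self, max_workers: int = 1) -> None:
        # Stand-in for JIT compilation of all declared kernels.
        for name, ndim, direction in self.declared:
            self.built[(ndim, direction)] = name

    def select(self, problem: ConvProblem) -> str:
        # Stand-in for the runtime kernel-selection heuristic.
        return self.built[(problem.ndim, problem.direction)]


registry = Registry()
registry.declare("grouped_conv_fwd_fp16_2d", ndim=2, direction="fwd")
registry.declare("grouped_conv_bwd_data_fp16_2d", ndim=2, direction="bwd_data")
registry.build(max_workers=8)
kernel = registry.select(ConvProblem(ndim=2, direction="fwd"))
print(kernel)  # grouped_conv_fwd_fp16_2d
```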

## Technical Details

Grouped convolution support has been added to the CK Tile dispatcher:

- `generated_conv_backend.hpp` enables `dispatcher.run(in, wei, out,
  problem)` for all six conv variants (fwd/bwd-data/bwd-weight x 2D/3D),
  with runtime heuristic kernel selection and a `GroupedConvKernelKey`
  carrying the full set of `ConvConfigBase` fields.
- The Python side gains parallel JIT compilation via
  `registry.build(max_workers)` and heuristic selection via
  `registry.select()`.
- 7 C++ and 6 Python examples cover all directions with CPU reference
  validation.
- Shared infrastructure improvements: a `BaseRegistry` CRTP base and
  structured exceptions.

As a sanity check, JIT compile time for a single kernel is unchanged, and
multiple kernels now build with much better parallelism:
| Kernels | 1 worker | 8 workers |
|--------:|---------:|----------:|
| 1       | 7.7 s    | 7.7 s     |
| 2       | 15.9 s   | 8.2 s     |
| 4       | 33.4 s   | 9.7 s     |
| 6       | 52.3 s   | 10.2 s    |
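
The scaling comes from fanning independent codegen/compile jobs out over a
worker pool and collecting results as they finish. The PR uses
`ProcessPoolExecutor`; the sketch below uses `ThreadPoolExecutor` with a
trivial stand-in job so it runs anywhere without pickling concerns — the
fan-out/collect shape is the same:

```python
import os
from concurrent.futures import ThreadPoolExecutor, as_completed


def run_one(idx: int):
    """Stand-in for one codegen subprocess; returns (idx, ok, stderr)."""
    return (idx, True, "")


def build_all(num_jobs: int, max_workers: int = 0) -> int:
    """Submit all jobs, collect as they complete, count successes."""
    work = list(range(num_jobs))
    workers = max_workers or min(len(work), os.cpu_count() or 4)
    done = 0
    with ThreadPoolExecutor(max_workers=workers) as executor:
        futures = {executor.submit(run_one, i): i for i in work}
        for future in as_completed(futures):
            idx, ok, err = future.result()
            if ok:
                done += 1
            else:
                print(f"  Codegen error for kernel {idx + 1}: {err}")
    return done


print(build_all(6, max_workers=8))  # 6
```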

## Test Plan

145 ephemeral unit tests have been added covering basic functionality.
All 30 examples/integration tests run end-to-end on gfx950 (MI350): 7
C++ conv, 7 C++ GEMM, 6 Python conv, and 10 Python GEMM. CPU reference
validation for forward, backward-data, and backward-weight (2D) passes
in both the C++ and Python examples.
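
A CPU reference for grouped convolution of the kind these examples
validate against can be sketched in a few lines of NumPy. This is an
illustrative naive implementation (NHWGC layout per the generated kernel
keys, stride 1, no padding), not the actual reference used by the
examples:

```python
import numpy as np


def grouped_conv2d_fwd_ref(inp, wei):
    """inp: [N, Hi, Wi, G, C], wei: [G, K, Y, X, C] -> out: [N, Ho, Wo, G, K]."""
    N, Hi, Wi, G, C = inp.shape
    G2, K, Y, X, C2 = wei.shape
    assert G == G2 and C == C2
    Ho, Wo = Hi - Y + 1, Wi - X + 1
    out = np.zeros((N, Ho, Wo, G, K), dtype=np.float32)
    for n in range(N):
        for ho in range(Ho):
            for wo in range(Wo):
                for g in range(G):
                    # window: [Y, X, C]; wei[g]: [K, Y, X, C]
                    window = inp[n, ho:ho + Y, wo:wo + X, g, :]
                    out[n, ho, wo, g, :] = np.einsum("yxc,kyxc->k", window, wei[g])
    return out


rng = np.random.default_rng(0)
inp = rng.standard_normal((1, 6, 6, 2, 3)).astype(np.float32)
wei = rng.standard_normal((2, 4, 3, 3, 3)).astype(np.float32)
out = grouped_conv2d_fwd_ref(inp, wei)
print(out.shape)  # (1, 4, 4, 2, 4)
```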

## Test Result

All 30 examples pass. Peak performance: 132 TFLOPS (batch-32 forward,
56x56) and 53 TFLOPS (pointwise 1x1). CPU reference accuracy:
max_abs_diff < 0.002 for all directions (fp16 vs. fp32 reference).
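
The accuracy figure can be reproduced in spirit by round-tripping an fp32
reference through fp16 and measuring the max absolute difference; the
shapes and data here are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
ref = rng.uniform(-1.0, 1.0, size=(64, 64)).astype(np.float32)  # fp32 reference
as_fp16 = ref.astype(np.float16).astype(np.float32)             # fp16 round-trip
max_abs_diff = float(np.max(np.abs(as_fp16 - ref)))
print(max_abs_diff < 0.002)  # True: fp16 rounding error for |v| < 1 is < 2**-11
```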

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Vidyasagar Ananthan
2026-04-09 17:39:35 +00:00
committed by assistant-librarian[bot]
parent 4c0e73ab12
commit 920acd2c12
86 changed files with 15538 additions and 1500 deletions


@@ -55,10 +55,10 @@ def extract_balanced_parens(text: str, start_pos: int) -> str:
def parse_conv_declarations(content: str) -> List[Dict]:
"""Parse DECL_CONV_KERNEL_SET declarations with all parameters."""
"""Parse DECL_GROUPED_CONV_KERNEL_SET declarations with all parameters."""
kernels = []
for match in re.finditer(r"DECL_CONV_KERNEL_SET\s*\(", content):
for match in re.finditer(r"DECL_GROUPED_CONV_KERNEL_SET\s*\(", content):
body = extract_balanced_parens(content, match.end() - 1)
if not body:
continue
@@ -619,7 +619,7 @@ def strip_cpp_strings_and_comments(content: str) -> str:
n = len(content)
# Patterns that indicate a string is problematic and should be stripped
problematic_patterns = ["DECL_KERNEL_SET", "DECL_CONV_KERNEL_SET", ".add("]
problematic_patterns = ["DECL_KERNEL_SET", "DECL_GROUPED_CONV_KERNEL_SET", ".add("]
while i < n:
# Check for raw string literal: R"delimiter(...)delimiter"
@@ -697,7 +697,7 @@ def detect_and_parse(source_path: Path) -> Tuple[str, List[Dict]]:
content = source_path.read_text()
content = strip_cpp_strings_and_comments(content)
if "DECL_CONV_KERNEL_SET" in content:
if "DECL_GROUPED_CONV_KERNEL_SET" in content:
return "conv", parse_conv_declarations(content)
elif "DECL_KERNEL_SET" in content:
return "gemm", parse_gemm_declarations(content)
@@ -966,30 +966,128 @@ def generate_per_set_functions(source_stem: str) -> str:
def generate_conv_registration(
kernel_headers: List[Path], example_name: str, kernels: List[Dict]
) -> str:
"""Generate Conv kernel registration code for the dispatcher registry."""
"""Generate Conv kernel registration code for the dispatcher registry.
Creates real GroupedConvKernelInstance entries backed by the generated
launcher's launch() method via the conv backend RunFn factories.
"""
if not kernel_headers:
return " // No kernels to register"
lines = []
lines.append(
" (void)registry; (void)arch; // Conv uses direct launcher pattern for now"
)
# For conv, we provide direct access to kernel launchers
for i, h in enumerate(kernel_headers):
kernel_name = h.stem
lines.append(f" // Kernel {i + 1}: {kernel_name}")
kname = h.stem
ns = f"ns_{kname}"
launcher = f"{ns}::{kname}_Launcher"
# Determine direction and ndim from the kernel header name
if "_fwd_" in kname:
direction = "Forward"
run_fn_factory = "make_conv_fwd_run_fn"
elif "_bwd_data_" in kname or "_bwdd_" in kname:
direction = "BackwardData"
run_fn_factory = "make_conv_bwd_data_run_fn"
elif "_bwd_weight_" in kname or "_bwdw_" in kname:
direction = "BackwardWeight"
run_fn_factory = "make_conv_bwd_weight_run_fn"
else:
direction = "Forward"
run_fn_factory = "make_conv_fwd_run_fn"
ndim = 3 if "_3d_" in kname else 2
# Parse dtype from name (e.g. grouped_conv_fwd_fp16_...)
dtype = "fp16"
for dt in ["fp16", "bf16", "fp32"]:
if f"_{dt}_" in kname:
dtype = dt
break
# Parse tile, wave, warp from name.
# Format: ..._TILExTILExTILE_WAVExWAVExWAVE_WARPxWARPxWARP_...
import re as _re
tile_m, tile_n, tile_k = 1, 128, 128
wave_m, wave_n, wave_k = 2, 2, 1
warp_m, warp_n, warp_k = 32, 32, 16
triplets = _re.findall(r"_(\d+)x(\d+)x(\d+)", kname)
if len(triplets) >= 1:
tile_m, tile_n, tile_k = (
int(triplets[0][0]),
int(triplets[0][1]),
int(triplets[0][2]),
)
if len(triplets) >= 2:
wave_m, wave_n, wave_k = (
int(triplets[1][0]),
int(triplets[1][1]),
int(triplets[1][2]),
)
if len(triplets) >= 3:
warp_m, warp_n, warp_k = (
int(triplets[2][0]),
int(triplets[2][1]),
int(triplets[2][2]),
)
pipeline = "compv4" if "compv4" in kname else "compv3"
scheduler = "interwave" if "interwave" in kname else "intrawave"
epilogue = "cshuffle" if "cshuffle" in kname else "default"
# ConvConfigBase defaults
vec_a, vec_b, vec_c = 4, 8, 8
block_per_cu = 1
num_wave_groups = 1
num_groups_to_merge = 1
lines.append(f" // Kernel {i + 1}: {kname}")
lines.append(" {")
lines.append(f" ck_tile::dispatcher::GroupedConvKernelKey key_{i};")
lines.append(f' key_{i}.dtype_in = "{dtype}";')
lines.append(f' key_{i}.dtype_wei = "{dtype}";')
lines.append(f' key_{i}.dtype_out = "{dtype}";')
lines.append(f' key_{i}.layout = "nhwgc";')
lines.append(f" key_{i}.ndim_spatial = {ndim};")
lines.append(
f" key_{i}.op = ck_tile::dispatcher::GroupedConvOp::{direction};"
)
lines.append(f" key_{i}.tile_m = {tile_m};")
lines.append(f" key_{i}.tile_n = {tile_n};")
lines.append(f" key_{i}.tile_k = {tile_k};")
lines.append(f" key_{i}.wave_m = {wave_m};")
lines.append(f" key_{i}.wave_n = {wave_n};")
lines.append(f" key_{i}.wave_k = {wave_k};")
lines.append(f" key_{i}.warp_m = {warp_m};")
lines.append(f" key_{i}.warp_n = {warp_n};")
lines.append(f" key_{i}.warp_k = {warp_k};")
lines.append(f' key_{i}.pipeline = "{pipeline}";')
lines.append(f' key_{i}.scheduler = "{scheduler}";')
lines.append(f' key_{i}.epilogue = "{epilogue}";')
lines.append(f" key_{i}.vector_size_a = {vec_a};")
lines.append(f" key_{i}.vector_size_b = {vec_b};")
lines.append(f" key_{i}.vector_size_c = {vec_c};")
lines.append(f" key_{i}.block_per_cu = {block_per_cu};")
lines.append(f" key_{i}.num_wave_groups = {num_wave_groups};")
lines.append(f" key_{i}.num_groups_to_merge = {num_groups_to_merge};")
lines.append(f" key_{i}.arch = arch;")
lines.append(
f" auto run_fn_{i} = ck_tile::dispatcher::backends::{run_fn_factory}<{launcher}, {ndim}>();"
)
lines.append(
f' auto inst_{i} = std::make_shared<ck_tile::dispatcher::GroupedConvKernelInstance>(key_{i}, "{kname}", std::move(run_fn_{i}));'
)
lines.append(f" registry.register_kernel(key_{i}, inst_{i});")
lines.append(" }")
return "\n".join(lines)
def generate_conv_kernels(
kernels: List[Dict], output_dir: Path, codegen_dir: Path
) -> bool:
"""Generate Conv kernels for ALL declarations using unified codegen."""
if not kernels:
return False
def _build_conv_codegen_cmd(
idx: int, k: Dict, codegen_dir: Path, output_dir: Path
) -> Tuple[int, List[str], str]:
"""Build the command for a single conv kernel codegen invocation."""
variant_map = {
"forward": "forward",
"bwd_data": "bwd_data",
@@ -997,93 +1095,130 @@ def generate_conv_kernels(
"bwd_weight": "bwd_weight",
"backward_weight": "bwd_weight",
}
variant = variant_map.get(k.get("conv_type", "forward"), "forward")
cmd = [
sys.executable,
str(codegen_dir / "unified_grouped_conv_codegen.py"),
"--datatype",
k.get("dtype", "fp16"),
"--variant",
variant,
"--ndim",
str(k.get("ndim", 2)),
"--output",
str(output_dir),
]
if k.get("tile_m"):
cmd.extend(["--tile-m", str(k["tile_m"])])
if k.get("tile_n"):
cmd.extend(["--tile-n", str(k["tile_n"])])
if k.get("warp_m"):
cmd.extend(["--warp-m", str(k["warp_m"])])
if k.get("warp_n"):
cmd.extend(["--warp-n", str(k["warp_n"])])
if k.get("warp_k"):
cmd.extend(["--warp-k", str(k["warp_k"])])
if k.get("warp_tile_m"):
cmd.extend(["--warp-tile-m", str(k["warp_tile_m"])])
if k.get("warp_tile_n"):
cmd.extend(["--warp-tile-n", str(k["warp_tile_n"])])
if k.get("warp_tile_k"):
cmd.extend(["--warp-tile-k", str(k["warp_tile_k"])])
if k.get("pipeline"):
cmd.extend(["--pipeline", k["pipeline"]])
if k.get("scheduler"):
cmd.extend(["--scheduler", k["scheduler"]])
if k.get("epilogue"):
cmd.extend(["--epilogue", k["epilogue"]])
if k.get("vector_a"):
cmd.extend(["--vector-a", str(k["vector_a"])])
if k.get("vector_b"):
cmd.extend(["--vector-b", str(k["vector_b"])])
if k.get("vector_c"):
cmd.extend(["--vector-c", str(k["vector_c"])])
if k.get("block_per_cu"):
cmd.extend(["--block-per-cu", str(k["block_per_cu"])])
if k.get("num_wave_groups"):
cmd.extend(["--num-wave-groups", str(k["num_wave_groups"])])
if k.get("num_groups_to_merge"):
cmd.extend(["--num-groups-to-merge", str(k["num_groups_to_merge"])])
if k.get("double_smem_buffer") is not None:
cmd.extend(["--double-smem-buffer", str(k["double_smem_buffer"]).lower()])
if k.get("tile_k"):
cmd.extend(["--tile-k", str(k["tile_k"])])
return (idx, cmd, str(codegen_dir))
def _run_conv_codegen(args: Tuple) -> Tuple[int, bool, str]:
"""Run unified_grouped_conv_codegen.py for a single kernel config (picklable for ProcessPoolExecutor)."""
idx, cmd, cwd = args
result = subprocess.run(cmd, capture_output=True, text=True, cwd=cwd)
if result.returncode != 0:
return (idx, False, result.stderr[:300])
return (idx, True, "")
def generate_conv_kernels(
kernels: List[Dict], output_dir: Path, codegen_dir: Path
) -> bool:
"""Generate Conv kernels for ALL declarations using unified codegen.
Launches all codegen subprocesses in parallel via ProcessPoolExecutor
for significantly faster generation when multiple conv kernels are declared.
"""
if not kernels:
return False
work_items = [
_build_conv_codegen_cmd(idx, k, codegen_dir, output_dir)
for idx, k in enumerate(kernels)
]
success_count = 0
max_workers = min(len(work_items), os.cpu_count() or 4)
# Generate a kernel for EACH declaration
for idx, k in enumerate(kernels):
variant = variant_map.get(k.get("conv_type", "forward"), "forward")
cmd = [
sys.executable,
str(codegen_dir / "unified_conv_codegen.py"),
"--datatype",
k.get("dtype", "fp16"),
"--variant",
variant,
"--ndim",
str(k.get("ndim", 2)),
"--output",
str(output_dir),
]
# Add optional parameters if specified
if k.get("tile_m"):
cmd.extend(["--tile-m", str(k["tile_m"])])
if k.get("tile_n"):
cmd.extend(["--tile-n", str(k["tile_n"])])
if k.get("warp_m"):
cmd.extend(["--warp-m", str(k["warp_m"])])
if k.get("warp_n"):
cmd.extend(["--warp-n", str(k["warp_n"])])
if k.get("warp_k"):
cmd.extend(["--warp-k", str(k["warp_k"])])
if k.get("warp_tile_m"):
cmd.extend(["--warp-tile-m", str(k["warp_tile_m"])])
if k.get("warp_tile_n"):
cmd.extend(["--warp-tile-n", str(k["warp_tile_n"])])
if k.get("warp_tile_k"):
cmd.extend(["--warp-tile-k", str(k["warp_tile_k"])])
if k.get("pipeline"):
cmd.extend(["--pipeline", k["pipeline"]])
if k.get("scheduler"):
cmd.extend(["--scheduler", k["scheduler"]])
if k.get("epilogue"):
cmd.extend(["--epilogue", k["epilogue"]])
if k.get("vector_a"):
cmd.extend(["--vector-a", str(k["vector_a"])])
if k.get("vector_b"):
cmd.extend(["--vector-b", str(k["vector_b"])])
if k.get("vector_c"):
cmd.extend(["--vector-c", str(k["vector_c"])])
if k.get("block_per_cu"):
cmd.extend(["--block-per-cu", str(k["block_per_cu"])])
if k.get("num_wave_groups"):
cmd.extend(["--num-wave-groups", str(k["num_wave_groups"])])
if k.get("num_groups_to_merge"):
cmd.extend(["--num-groups-to-merge", str(k["num_groups_to_merge"])])
if k.get("double_smem_buffer") is not None:
cmd.extend(["--double-smem-buffer", str(k["double_smem_buffer"]).lower()])
if k.get("tile_k"):
cmd.extend(["--tile-k", str(k["tile_k"])])
result = subprocess.run(
cmd, capture_output=True, text=True, cwd=str(codegen_dir)
)
if result.returncode != 0:
print(f" Codegen error for kernel {idx + 1}: {result.stderr[:300]}")
else:
success_count += 1
with ProcessPoolExecutor(max_workers=max_workers) as executor:
futures = {executor.submit(_run_conv_codegen, w): w[0] for w in work_items}
for future in as_completed(futures):
idx, ok, err = future.result()
if ok:
success_count += 1
else:
print(f" Codegen error for kernel {idx + 1}: {err}")
return success_count > 0
def _run_gemm_codegen(args: Tuple) -> Tuple[int, bool, str]:
"""Run unified_gemm_codegen.py for a single kernel config (picklable for ProcessPoolExecutor)."""
idx, cmd, cwd = args
result = subprocess.run(cmd, capture_output=True, text=True, cwd=cwd)
if result.returncode != 0:
return (idx, False, result.stderr[:300])
return (idx, True, "")
def generate_gemm_kernels(
kernels: List[Dict], output_dir: Path, codegen_dir: Path
) -> bool:
"""Generate GEMM kernels for ALL declarations using unified codegen."""
"""Generate GEMM kernels for ALL declarations using unified codegen.
Launches all codegen subprocesses in parallel via ProcessPoolExecutor
for significantly faster generation when multiple kernels are declared.
"""
import json
if not kernels:
return False
success_count = 0
# Generate a kernel for EACH declaration
# Build all commands upfront
work_items = []
for idx, k in enumerate(kernels):
variant = "multi_d" if k.get("elementwise_op") else "standard"
# Build tile config JSON for this specific kernel
tile_config = {
"tile_m": [k.get("tile_m", 128)],
"tile_n": [k.get("tile_n", 128)],
@@ -1125,13 +1260,20 @@ def generate_gemm_kernels(
config_json,
]
result = subprocess.run(
cmd, capture_output=True, text=True, cwd=str(codegen_dir)
)
if result.returncode != 0:
print(f" Codegen error for kernel {idx + 1}: {result.stderr[:300]}")
else:
success_count += 1
work_items.append((idx, cmd, str(codegen_dir)))
# Run all codegen subprocesses in parallel
success_count = 0
max_workers = min(len(work_items), os.cpu_count() or 4)
with ProcessPoolExecutor(max_workers=max_workers) as executor:
futures = {executor.submit(_run_gemm_codegen, w): w[0] for w in work_items}
for future in as_completed(futures):
idx, ok, err = future.result()
if ok:
success_count += 1
else:
print(f" Codegen error for kernel {idx + 1}: {err}")
return success_count > 0
@@ -1229,15 +1371,17 @@ def main():
if example_type == "gemm":
kernel_headers = list(args.output_dir.glob("gemm_*.hpp"))
else:
k = kernels[0] if kernels else {}
variant = k.get("conv_type", "forward")
prefix_map = {
"forward": "conv_fwd",
"bwd_data": "conv_bwdd",
"bwd_weight": "conv_bwdw",
"forward": "grouped_conv_fwd",
"bwd_data": "grouped_conv_bwd_data",
"bwd_weight": "grouped_conv_bwd_weight",
}
prefix = prefix_map.get(variant, "conv_fwd")
kernel_headers = list(args.output_dir.glob(f"{prefix}_*.hpp"))
# Collect headers from ALL variants present in declarations
variants_used = set(k.get("conv_type", "forward") for k in kernels)
kernel_headers = []
for variant in variants_used:
prefix = prefix_map.get(variant, "grouped_conv_fwd")
kernel_headers.extend(args.output_dir.glob(f"{prefix}_*.hpp"))
if not kernel_headers:
print(f"[{target_name}] No kernel headers generated!")
@@ -1347,29 +1491,39 @@ def main():
)
if has_bwd_data:
bwdd_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdd_")
if bwdd_kernel:
bwdd_ns = f"ns_{bwdd_kernel.stem}"
launcher_aliases.append(
f"using BwdDataKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;"
bwd_data_kernel = find_kernel_by_dtype_type(
kernel_headers, "fp16", "_bwd_data_"
)
if not bwd_data_kernel:
bwd_data_kernel = find_kernel_by_dtype_type(
kernel_headers, "fp16", "_bwdd_"
)
if not has_fwd: # If no fwd, use bwd_data as first
if bwd_data_kernel:
bwd_data_ns = f"ns_{bwd_data_kernel.stem}"
launcher_aliases.append(
f"using BwdDataKernelLauncher = {bwd_data_ns}::{bwd_data_kernel.stem}_Launcher;"
)
if not has_fwd:
launcher_aliases.append(
f"using FirstKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;"
f"using FirstKernelLauncher = {bwd_data_ns}::{bwd_data_kernel.stem}_Launcher;"
)
if has_bwd_weight:
bwdw_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdw_")
if bwdw_kernel:
bwdw_ns = f"ns_{bwdw_kernel.stem}"
launcher_aliases.append(
f"using BwdWeightKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;"
bwd_weight_kernel = find_kernel_by_dtype_type(
kernel_headers, "fp16", "_bwd_weight_"
)
if not bwd_weight_kernel:
bwd_weight_kernel = find_kernel_by_dtype_type(
kernel_headers, "fp16", "_bwdw_"
)
if (
not has_fwd and not has_bwd_data
): # If no fwd or bwdd, use bwdw as first
if bwd_weight_kernel:
bwd_weight_ns = f"ns_{bwd_weight_kernel.stem}"
launcher_aliases.append(
f"using BwdWeightKernelLauncher = {bwd_weight_ns}::{bwd_weight_kernel.stem}_Launcher;"
)
if not has_fwd and not has_bwd_data:
launcher_aliases.append(
f"using FirstKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;"
f"using FirstKernelLauncher = {bwd_weight_ns}::{bwd_weight_kernel.stem}_Launcher;"
)
launcher_section = "\n".join(launcher_aliases)
@@ -1382,14 +1536,16 @@ def main():
#include "ck_tile/dispatcher/registry.hpp"
#include "ck_tile/dispatcher/kernel_instance.hpp"
#include "ck_tile/dispatcher/kernel_key.hpp"
#include "ck_tile/dispatcher/grouped_conv_registry.hpp"
#include "ck_tile/dispatcher/backends/generated_conv_backend.hpp"
namespace generated {{
// Kernel launchers for direct use
{launcher_section}
// Registration function
inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{
// Registration function (takes GroupedConvRegistry for conv kernels)
inline void {func_name}(ck_tile::dispatcher::GroupedConvRegistry& registry, const std::string& arch) {{
{register_body}
}}
@@ -1439,7 +1595,7 @@ inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::stri
"""
header_path.write_text(header_content)
print(f"[{target_name}] {len(obj_files)} kernels compiled")
print(f"[{target_name}] OK {len(obj_files)} kernels compiled")
return 0