Files
composable_kernel/dispatcher/scripts/example_kernel_builder.py
Vidyasagar Ananthan 920acd2c12 [rocm-libraries] ROCm/rocm-libraries#5168 (commit 8b5afcb)
[CK] [CK_Tile] Add GroupConv to Kernel Dispatcher

## Motivation

This PR adds CK Tile group convolution (forward, backward-data,
backward-weight) support to the kernel dispatcher, matching and unifying
with the existing dispatcher GEMM infrastructure in architecture and
usability. The dispatcher provides a unified kernel dispatch system with
both C++ and Python frontends, and until now only supported GEMM
operations. This PR enables framework integrators to use the same
declarative kernel workflow for convolutions as they do for GEMM:
declare kernels, build a registry JIT, select kernels within the
registry at runtime, and dispatch to GPU. Future PRs will include
runtime kernel selection heuristics for autotuning of kernel parameters
based on (problem, hardware arch).

## Technical Details

Grouped convolution support has been added to the CK Tile Dispatcher
with generated_conv_backend.hpp enabling dispatcher.run(in, wei, out,
problem) for all 6 conv variants (fwd/bwdd/bwdw x 2D/3D), runtime
heuristic kernel selection, and GroupedConvKernelKey with full
ConvConfigBase fields. Python side adds parallel JIT via
registry.build(max_workers) and heuristic registry.select(). Includes 7
C++ and 6 Python examples covering all directions with CPU reference
validation, and shared infrastructure improvements (BaseRegistry CRTP,
structured exceptions). As a sanity check, JIT compile times for a
single kernel remain the same and for multiple kernels there is better
parallelism:
Kernels | 1 worker | 8 workers
1 | 7.7 s | 7.7 s
2 | 15.9 s | 8.2 s
4 | 33.4 s | 9.7 s
6 | 52.3 s | 10.2 s

## Test Plan

145 ephemeral unit tests have been added to test basic functionality.
All 30 examples/integration tests run end-to-end on gfx950 (MI350): 7
C++ conv, 7 C++ GEMM, 6 Python conv, 10 Python GEMM. CPU reference
validation for forward, backward-data, and backward-weight (2D) in both
C++ and Python examples passes.

## Test Result

30 examples pass. Peak performance: 132 TFLOPS (Batch-32 forward 56x56),
53 TFLOPS (pointwise 1x1). CPU reference accuracy: max_abs_diff < 0.002
for all directions (fp16 vs fp32 reference).

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-09 17:39:35 +00:00

1604 lines
58 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Build example kernels - generates and compiles kernels for a single example.
Detects if example is GEMM or Conv based on macro presence, extracts all
configuration parameters, and generates appropriate kernels.
"""
import argparse
import os
import re
import shutil
import subprocess
import sys
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Dict, List, Tuple
def find_hipcc() -> str:
    """Locate the hipcc compiler.

    Preference order: the $HIPCC environment variable, the standard ROCm
    install path, then whatever is on PATH. Falls back to the bare name
    "hipcc" so a later subprocess call can resolve (or fail on) it.
    """
    candidates = (
        os.environ.get("HIPCC"),
        "/opt/rocm/bin/hipcc",
        shutil.which("hipcc"),
    )
    for candidate in candidates:
        if candidate and os.path.isfile(candidate):
            return candidate
    return "hipcc"
def find_ar() -> str:
    """Locate a static-library archiver.

    Prefers ROCm's bundled llvm-ar, then llvm-ar or plain ar found on PATH,
    falling back to the bare name "ar".
    """
    candidates = (
        "/opt/rocm/llvm/bin/llvm-ar",
        shutil.which("llvm-ar"),
        shutil.which("ar"),
    )
    for candidate in candidates:
        if candidate and os.path.isfile(candidate):
            return candidate
    return "ar"
def extract_balanced_parens(text: str, start_pos: int) -> str:
    """Return the text between the '(' at start_pos and its matching ')'.

    start_pos must index an opening parenthesis in text; returns "" when it
    does not, or when no matching close parenthesis exists. Nested
    parentheses are balanced by depth counting. NOTE: parentheses inside
    C++ string literals are not special-cased, so callers should strip
    strings first if that matters.
    """
    if start_pos >= len(text) or text[start_pos] != "(":
        return ""
    depth = 0
    pos = start_pos
    while pos < len(text):
        ch = text[pos]
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                return text[start_pos + 1 : pos]
        pos += 1
    return ""
def parse_conv_declarations(content: str) -> List[Dict]:
    """Parse DECL_GROUPED_CONV_KERNEL_SET declarations with all parameters.

    Scans `content` (C++ example source, ideally already run through
    strip_cpp_strings_and_comments) for DECL_GROUPED_CONV_KERNEL_SET(...)
    macro calls and extracts one config dict per `.add(...)` call found in
    each macro body. Every recognized builder method (.dtype, .layout,
    .tile, .wave, .warp, ...) contributes keys to the dict; unrecognized
    methods are silently ignored. Entries without a dtype are dropped; the
    rest are completed via auto_fill_conv_defaults.

    Returns:
        List of kernel-configuration dicts, one per valid .add() call.
    """
    kernels = []
    for match in re.finditer(r"DECL_GROUPED_CONV_KERNEL_SET\s*\(", content):
        # match.end() - 1 lands on the '(' itself so the balanced-paren
        # extractor can grab the whole macro body.
        body = extract_balanced_parens(content, match.end() - 1)
        if not body:
            continue
        # Parse each .add() call
        for add_match in re.finditer(r"\.add\s*\(", body):
            add_body = extract_balanced_parens(body, add_match.end() - 1)
            kernel = {}
            # ConvSig parameters - handle both single dtype and multi-dtype
            # Multi-dtype: .dtype("fp16", "fp16", "fp16", "fp32") or .dtype("fp16", "bf16", "fp16")
            if m := re.search(
                r'\.dtype\s*\(\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"(?:\s*,\s*"([^"]+)")?\s*\)',
                add_body,
            ):
                kernel["dtype_in"] = m.group(1)
                kernel["dtype_wei"] = m.group(2)
                kernel["dtype_out"] = m.group(3)
                # The accumulator dtype is optional; default to fp32.
                kernel["dtype_acc"] = m.group(4) if m.group(4) else "fp32"
                kernel["dtype"] = m.group(1)  # Default for codegen
            # Single dtype: .dtype("fp16")
            elif m := re.search(r'\.dtype\s*\(\s*"([^"]+)"\s*\)', add_body):
                kernel["dtype"] = m.group(1)
                kernel["dtype_in"] = m.group(1)
                kernel["dtype_wei"] = m.group(1)
                kernel["dtype_out"] = m.group(1)
                kernel["dtype_acc"] = "fp32"
            if m := re.search(r'\.layout\s*\(\s*"([^"]+)"', add_body):
                kernel["layout"] = m.group(1)
            if m := re.search(r'\.conv_type\s*\(\s*"([^"]+)"', add_body):
                kernel["conv_type"] = m.group(1)
            if m := re.search(r"\.dims\s*\(\s*(\d+)\s*\)", add_body):
                kernel["ndim"] = int(m.group(1))
            # ConvAlgo parameters - tile(G, M, N) where G=batch, M=output, N=reduction
            if m := re.search(
                r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body
            ):
                kernel["tile_g"] = int(m.group(1))  # batch tile (usually 1)
                kernel["tile_m"] = int(m.group(2))  # output channel tile
                kernel["tile_n"] = int(m.group(3))  # input channel tile (reduction)
            # wave(M_Warp, N_Warp, K_Warp) - warp distribution
            if m := re.search(
                r"\.wave\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body
            ):
                kernel["warp_m"] = int(m.group(1))
                kernel["warp_n"] = int(m.group(2))
                kernel["warp_k"] = int(m.group(3))
            # warp(M_Warp_Tile, N_Warp_Tile, K_Warp_Tile) - warp tile sizes
            if m := re.search(
                r"\.warp\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body
            ):
                kernel["warp_tile_m"] = int(m.group(1))
                kernel["warp_tile_n"] = int(m.group(2))
                kernel["warp_tile_k"] = int(m.group(3))
            # vector_sizes(A, B, C)
            if m := re.search(
                r"\.vector_sizes\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body
            ):
                kernel["vector_a"] = int(m.group(1))
                kernel["vector_b"] = int(m.group(2))
                kernel["vector_c"] = int(m.group(3))
            # Single-value parameters
            if m := re.search(r'\.pipeline\s*\(\s*"([^"]+)"', add_body):
                kernel["pipeline"] = m.group(1)
            if m := re.search(r'\.scheduler\s*\(\s*"([^"]+)"', add_body):
                kernel["scheduler"] = m.group(1)
            if m := re.search(r'\.epilogue\s*\(\s*"([^"]+)"', add_body):
                kernel["epilogue"] = m.group(1)
            if m := re.search(r"\.block_per_cu\s*\(\s*(\d+)\s*\)", add_body):
                kernel["block_per_cu"] = int(m.group(1))
            if m := re.search(r"\.num_wave_groups\s*\(\s*(\d+)\s*\)", add_body):
                kernel["num_wave_groups"] = int(m.group(1))
            if m := re.search(r"\.num_groups_to_merge\s*\(\s*(\d+)\s*\)", add_body):
                kernel["num_groups_to_merge"] = int(m.group(1))
            if m := re.search(
                r"\.double_smem_buffer\s*\(\s*(true|false)\s*\)", add_body, re.I
            ):
                kernel["double_smem_buffer"] = m.group(1).lower() == "true"
            # Architecture
            # NOTE(review): matches any quoted "gfxNNN" anywhere in the .add()
            # body, presumably from an .arch("gfx...") call - confirm there is
            # no other quoted gfx string that could match first.
            if m := re.search(r'"(gfx\d+)"', add_body):
                kernel["arch"] = m.group(1)
            if kernel.get("dtype"):
                # Auto-fill missing parameters with defaults (autocorrect)
                kernel = auto_fill_conv_defaults(kernel)
                kernels.append(kernel)
    return kernels
def auto_fill_conv_defaults(kernel: Dict) -> Dict:
    """Complete a conv kernel config in place with defaults and corrections.

    Two passes over the dict:
    1. AUTOFILL: any key that is absent, None, or -1 receives the
       ConvConfigComputeV3 default value.
    2. AUTOCORRECT: a wave layout outside the gfx942-legal set is forced to
       (1,4,1), and compv4/compv5 pipelines are downgraded to compv3 for
       backward convolutions, where they are not supported.

    Mutates and returns the same dict. Progress is reported on stdout.
    """
    # ConvConfigComputeV3 reference configuration.
    defaults = {
        "tile_g": 1,
        "tile_m": 16,
        "tile_n": 64,
        "warp_m": 1,
        "warp_n": 4,
        "warp_k": 1,
        "warp_tile_m": 16,
        "warp_tile_n": 16,
        "warp_tile_k": 32,
        "pipeline": "compv3",
        "scheduler": "intrawave",
        "epilogue": "cshuffle",
        "vector_a": 4,
        "vector_b": 8,
        "vector_c": 8,
        "block_per_cu": 1,
        "num_wave_groups": 1,
        "num_groups_to_merge": 1,
        "ndim": 2,
        "layout": "nhwgc",
        "conv_type": "forward",
        "arch": "gfx942",
    }
    # Pass 1: fill in anything the declaration left out (or wildcarded).
    filled = []
    for name, fallback in defaults.items():
        if name not in kernel or kernel[name] is None or kernel[name] == -1:
            kernel[name] = fallback
            filled.append(f"{name}={fallback}")
    if filled:
        print(f" [AUTOFILL] {', '.join(filled)}")
    # Pass 2a: only a few wave layouts are legal on gfx942.
    allowed_waves = {(1, 4, 1), (2, 2, 1), (4, 1, 1)}
    wave = (
        kernel.get("warp_m", 1),
        kernel.get("warp_n", 4),
        kernel.get("warp_k", 1),
    )
    if wave not in allowed_waves:
        kernel["warp_m"], kernel["warp_n"], kernel["warp_k"] = 1, 4, 1
        print(f" [AUTOCORRECT] wave{wave} -> wave(1,4,1) (invalid for gfx942)")
    # Pass 2b: backward conv directions do not support compv4/compv5.
    direction = kernel.get("conv_type", "forward")
    pipe = kernel.get("pipeline", "compv3")
    if direction in ["bwd_data", "bwd_weight"] and pipe in ["compv4", "compv5"]:
        kernel["pipeline"] = "compv3"
        print(
            f" [AUTOCORRECT] pipeline {pipe} -> compv3 (invalid for {direction})"
        )
    return kernel
def expand_conv_wildcards(kernel: Dict, arch: str = "gfx942") -> List[Dict]:
    """Expand wildcard wave/warp parameters into all valid configurations.

    A wildcard is a missing or -1 warp_m (wave expansion) or warp_tile_m
    (warp-tile expansion). Each wildcard dimension is replaced by every
    valid option for the target architecture; the cross product of the two
    dimensions is returned as independent config dicts. Without any
    wildcard the original dict is returned unchanged in a one-element list.

    Note: `arch` is currently unused; the option tables are gfx942-specific.
    """
    wave_options = [(1, 4, 1), (2, 2, 1), (4, 1, 1)]
    warp_tile_options = [(16, 16, 32), (32, 32, 16)]
    wave_is_wild = kernel.get("warp_m") is None or kernel.get("warp_m") == -1
    warp_is_wild = (
        kernel.get("warp_tile_m") is None or kernel.get("warp_tile_m") == -1
    )
    if not (wave_is_wild or warp_is_wild):
        return [kernel]
    waves = (
        wave_options
        if wave_is_wild
        else [
            (kernel.get("warp_m", 2), kernel.get("warp_n", 2), kernel.get("warp_k", 1))
        ]
    )
    warp_tiles = (
        warp_tile_options
        if warp_is_wild
        else [
            (
                kernel.get("warp_tile_m", 32),
                kernel.get("warp_tile_n", 32),
                kernel.get("warp_tile_k", 16),
            )
        ]
    )
    variants = []
    for wave in waves:
        for tile in warp_tiles:
            variant = dict(kernel)
            variant["warp_m"], variant["warp_n"], variant["warp_k"] = wave
            (
                variant["warp_tile_m"],
                variant["warp_tile_n"],
                variant["warp_tile_k"],
            ) = tile
            variants.append(variant)
    return variants
def parse_int_or_wildcard(val: str) -> int:
    """Convert a declaration token to int, mapping wildcard spellings to -1.

    Recognized wildcards (all normalized to -1):
    - "ANY_INT": the macro, which is defined as -1
    - "-1": the direct numeric wildcard
    - "*": the string wildcard, reused for integer params

    Anything else must parse as a plain integer (ValueError otherwise).
    """
    token = val.strip()
    return -1 if token in ("ANY_INT", "-1", "*") else int(token)
def parse_gemm_declarations(content: str) -> List[Dict]:
    """Parse DECL_KERNEL_SET declarations for GEMM.

    Supports wildcards:
    - ANY_INT for numeric params (wave, warp) -> expands to all valid combos
    - "*" for string params (pipeline, scheduler) -> expands to valid options
    Each kernel is tagged with its kernel_set name for separate registration.

    Returns the fully expanded and default-filled list of kernel-config
    dicts (wildcard expansion via expand_gemm_wildcards, completion via
    auto_fill_gemm_defaults).
    """
    kernels = []
    for match in re.finditer(r"DECL_KERNEL_SET\s*\(\s*(\w+)\s*,", content):
        kernel_set_name = match.group(1)
        # Re-locate the macro's opening '(' from match.start(); since the
        # regex itself matched a '(', this find() lands on that same paren.
        body = extract_balanced_parens(
            content, match.start() + content[match.start() :].find("(")
        )
        if not body:
            continue
        for add_match in re.finditer(r"\.add\s*\(", body):
            add_body = extract_balanced_parens(body, add_match.end() - 1)
            kernel = {}
            # Signature parameters
            if m := re.search(r'\.dtype\s*\(\s*"([^"]+)"', add_body):
                kernel["dtype"] = m.group(1)
            if m := re.search(r'\.layout\s*\(\s*"([^"]+)"', add_body):
                kernel["layout"] = m.group(1)
            if m := re.search(r'\.elementwise\s*\(\s*"([^"]+)"\s*,\s*(\d+)', add_body):
                kernel["elementwise_op"] = m.group(1)
                kernel["num_d_tensors"] = int(m.group(2))
            # Algorithm parameters - support ANY_INT wildcard
            if m := re.search(
                r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body
            ):
                kernel["tile_m"] = int(m.group(1))
                kernel["tile_n"] = int(m.group(2))
                kernel["tile_k"] = int(m.group(3))
            # Wave: support ANY_INT, -1, and "*" as wildcards
            if m := re.search(
                r"\.wave\s*\(\s*([\w*-]+)\s*,\s*([\w*-]+)\s*,\s*([\w*-]+)\s*\)",
                add_body,
            ):
                kernel["warp_m"] = parse_int_or_wildcard(m.group(1))
                kernel["warp_n"] = parse_int_or_wildcard(m.group(2))
                kernel["warp_k"] = parse_int_or_wildcard(m.group(3))
            # Warp: support ANY_INT, -1, and "*" as wildcards
            if m := re.search(
                r"\.warp\s*\(\s*([\w*-]+)\s*,\s*([\w*-]+)\s*,\s*([\w*-]+)\s*\)",
                add_body,
            ):
                kernel["warp_tile_m"] = parse_int_or_wildcard(m.group(1))
                kernel["warp_tile_n"] = parse_int_or_wildcard(m.group(2))
                kernel["warp_tile_k"] = parse_int_or_wildcard(m.group(3))
            # Pipeline/Scheduler: support "*" wildcard
            if m := re.search(r'\.pipeline\s*\(\s*"([^"]+)"', add_body):
                kernel["pipeline"] = m.group(1)
            if m := re.search(r'\.scheduler\s*\(\s*"([^"]+)"', add_body):
                kernel["scheduler"] = m.group(1)
            if m := re.search(r'\.epilogue\s*\(\s*"([^"]+)"', add_body):
                kernel["epilogue"] = m.group(1)
            if m := re.search(
                r"\.pad\s*\(\s*(true|false)\s*,\s*(true|false)\s*,\s*(true|false)",
                add_body,
                re.I,
            ):
                kernel["pad_m"] = m.group(1).lower() == "true"
                kernel["pad_n"] = m.group(2).lower() == "true"
                kernel["pad_k"] = m.group(3).lower() == "true"
            # Shorthand format: .add("dtype", "layout", M, N, K)
            if not kernel.get("dtype"):
                if m := re.match(
                    r'\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)',
                    add_body,
                ):
                    kernel["dtype"] = m.group(1)
                    kernel["layout"] = m.group(2)
                    kernel["tile_m"] = int(m.group(3))
                    kernel["tile_n"] = int(m.group(4))
                    kernel["tile_k"] = int(m.group(5))
            if kernel.get("dtype"):
                kernel["kernel_set"] = kernel_set_name
                kernels.append(kernel)
    # Expand wildcards to multiple kernels
    expanded = []
    for kernel in kernels:
        expanded.extend(expand_gemm_wildcards(kernel))
    # Apply autocorrect to each expanded kernel
    return [auto_fill_gemm_defaults(k) for k in expanded]
def expand_gemm_wildcards(kernel: Dict, arch: str = "gfx942") -> List[Dict]:
    """Expand ANY_INT/"*" wildcards into all valid GEMM configurations.

    Wildcards (warp_m == -1, warp_tile_m == -1, pipeline == "*",
    scheduler == "*") are each replaced with every valid option for the
    target architecture; the cross product is returned. Combinations whose
    implied block size exceeds 1024 threads are filtered out:
    (tile_m/warp_tile_m) * (tile_n/warp_tile_n) * 64 <= 1024, so for a
    128x128 tile only the (32,32,k) warp tile survives, while a 64x64 tile
    admits both. Without any wildcard, returns [kernel] unchanged; if the
    filter removes every combination, the original kernel is also returned
    as-is. Note: `arch` is currently unused; tables are gfx942-specific.
    """
    wave_options = [(1, 4, 1), (2, 2, 1), (4, 1, 1)]
    warp_tile_options = [(16, 16, 32), (32, 32, 16)]
    pipeline_options = ["compv3"]  # compv4 requires special handling
    scheduler_options = ["intrawave"]
    wild_wave = kernel.get("warp_m") == -1
    wild_warp = kernel.get("warp_tile_m") == -1
    wild_pipe = kernel.get("pipeline") == "*"
    wild_sched = kernel.get("scheduler") == "*"
    if not (wild_wave or wild_warp or wild_pipe or wild_sched):
        return [kernel]
    waves = (
        wave_options
        if wild_wave
        else [
            (kernel.get("warp_m", 2), kernel.get("warp_n", 2), kernel.get("warp_k", 1))
        ]
    )
    warp_tiles = (
        warp_tile_options
        if wild_warp
        else [
            (
                kernel.get("warp_tile_m", 32),
                kernel.get("warp_tile_n", 32),
                kernel.get("warp_tile_k", 16),
            )
        ]
    )
    pipelines = pipeline_options if wild_pipe else [kernel.get("pipeline", "compv3")]
    schedulers = (
        scheduler_options if wild_sched else [kernel.get("scheduler", "intrawave")]
    )
    tile_m = kernel.get("tile_m", 128)
    tile_n = kernel.get("tile_n", 128)
    variants = []
    for wm, wn, wk in waves:
        for wtm, wtn, wtk in warp_tiles:
            # 64 threads per warp; drop combos over the 1024-thread limit.
            if (tile_m // wtm) * (tile_n // wtn) * 64 > 1024:
                continue
            for pipe in pipelines:
                for sched in schedulers:
                    variant = dict(kernel)
                    variant["warp_m"], variant["warp_n"], variant["warp_k"] = wm, wn, wk
                    variant["warp_tile_m"] = wtm
                    variant["warp_tile_n"] = wtn
                    variant["warp_tile_k"] = wtk
                    variant["pipeline"] = pipe
                    variant["scheduler"] = sched
                    variants.append(variant)
    if variants:
        print(f" [WILDCARD] Expanded 1 declaration -> {len(variants)} kernel(s)")
    return variants if variants else [kernel]
def auto_fill_gemm_defaults(kernel: Dict) -> Dict:
    """Complete a GEMM kernel config in place with defaults and corrections.

    Passes, in order:
    1. AUTOFILL: absent / None / -1 values get the defaults below.
    2. AUTOCORRECT wave: layouts outside the gfx942-legal set are forced
       to the balanced (2,2,1).
    3. AUTOCORRECT pipeline/scheduler: compv3 and compv4 cannot pair with
       interwave, so the scheduler is forced to intrawave.
    4. AUTOCORRECT warp tile: if (tile_m/warp_tile_m) * (tile_n/warp_tile_n)
       warps * 64 threads exceeds the 1024-thread block limit, warp tiles
       are enlarged (64 for tiles >= 256, else tile/4) until it fits.

    Mutates and returns the same dict. Progress is reported on stdout.
    """
    defaults = {
        "tile_m": 128,
        "tile_n": 128,
        "tile_k": 64,
        "warp_m": 2,
        "warp_n": 2,
        "warp_k": 1,
        "warp_tile_m": 32,
        "warp_tile_n": 32,
        "warp_tile_k": 16,
        "pipeline": "compv3",
        "scheduler": "intrawave",
        "epilogue": "cshuffle",
        "pad_m": False,
        "pad_n": False,
        "pad_k": False,
        "layout": "rcr",
    }
    # Pass 1: fill anything missing or wildcarded.
    filled = []
    for name, fallback in defaults.items():
        if name not in kernel or kernel[name] is None or kernel[name] == -1:
            kernel[name] = fallback
            filled.append(f"{name}={fallback}")
    if filled:
        print(f" [AUTOFILL] {', '.join(filled)}")
    # Pass 2: wave layouts legal on gfx942 are (1,4,1), (2,2,1), (4,1,1).
    allowed_waves = {(1, 4, 1), (2, 2, 1), (4, 1, 1)}
    wave = (
        kernel.get("warp_m", 2),
        kernel.get("warp_n", 2),
        kernel.get("warp_k", 1),
    )
    if wave not in allowed_waves:
        old = wave
        kernel["warp_m"], kernel["warp_n"], kernel["warp_k"] = 2, 2, 1
        print(f" [AUTOCORRECT] wave{old} -> wave(2,2,1) (invalid for gfx942)")
    # Pass 3: pipeline/scheduler pairs that are not supported together.
    bad_combos = [("compv3", "interwave"), ("compv4", "interwave")]
    combo = (
        kernel.get("pipeline", "compv3"),
        kernel.get("scheduler", "intrawave"),
    )
    if combo in bad_combos:
        old = combo
        kernel["scheduler"] = "intrawave"
        print(
            f" [AUTOCORRECT] {old[0]}/{old[1]} -> {old[0]}/intrawave (invalid combo)"
        )
    # Pass 4: keep the implied block size within 1024 threads.
    # block_size = (tile_m / warp_tile_m) * (tile_n / warp_tile_n) * 64
    tile_m = kernel.get("tile_m", 128)
    tile_n = kernel.get("tile_n", 128)
    wtm = kernel.get("warp_tile_m", 32)
    wtn = kernel.get("warp_tile_n", 32)
    block_size = (tile_m // wtm) * (tile_n // wtn) * 64  # 64 threads per warp
    if block_size > 1024:
        old_warp = (wtm, wtn, kernel.get("warp_tile_k", 16))
        # Very large tiles can absorb a larger warp tile directly.
        if tile_m >= 256:
            kernel["warp_tile_m"] = 64
        if tile_n >= 256:
            kernel["warp_tile_n"] = 64
        block_size = (
            (tile_m // kernel["warp_tile_m"])
            * (tile_n // kernel["warp_tile_n"])
            * 64
        )
        if block_size <= 1024:
            new_warp = (
                kernel["warp_tile_m"],
                kernel["warp_tile_n"],
                kernel["warp_tile_k"],
            )
            print(
                f" [AUTOCORRECT] warp{old_warp} -> warp{new_warp} (block_size={block_size})"
            )
        else:
            # Still over budget: force exactly 4x4 warps per block.
            kernel["warp_tile_m"] = tile_m // 4
            kernel["warp_tile_n"] = tile_n // 4
            new_warp = (
                kernel["warp_tile_m"],
                kernel["warp_tile_n"],
                kernel["warp_tile_k"],
            )
            print(
                f" [AUTOCORRECT] warp{old_warp} -> warp{new_warp} (block_size adjusted)"
            )
    return kernel
def strip_cpp_strings_and_comments(content: str) -> str:
    """Blank out C++ comments and risky string literals, preserving layout.

    Stripping rules:
    - // and /* */ comments: always blanked.
    - Raw string literals R"delim(...)delim": always blanked (they can
      contain arbitrary content).
    - Ordinary string literals: blanked ONLY when they contain
      declaration-like text (DECL_KERNEL_SET, DECL_GROUPED_CONV_KERNEL_SET,
      ".add(") that would confuse the regex parsers; innocuous strings
      such as "fp16" or "rcr" are kept because parsing needs them.

    Replaced spans become spaces of equal length, with embedded newlines
    kept, so the output has the same length and line structure as the input.
    """
    risky = ["DECL_KERNEL_SET", "DECL_GROUPED_CONV_KERNEL_SET", ".add("]

    def blank(span: str) -> str:
        # Space out a span while keeping its newlines (line numbers survive).
        return "".join("\n" if ch == "\n" else " " for ch in span)

    out = []
    pos = 0
    total = len(content)
    while pos < total:
        # Raw string literal: R"delim( ... )delim" -- always blank.
        if pos < total - 1 and content[pos] == "R" and content[pos + 1] == '"':
            cursor = pos + 2
            while cursor < total and content[cursor] != "(":
                cursor += 1
            closing = ")" + content[pos + 2 : cursor] + '"'
            end = content.find(closing, cursor + 1)
            if end != -1:
                stop = end + len(closing)
                out.append(blank(content[pos:stop]))
                pos = stop
                continue
        # Ordinary string literal -- blank only when it holds risky text.
        if content[pos] == '"':
            cursor = pos + 1
            while cursor < total:
                if content[cursor] == "\\" and cursor + 1 < total:
                    cursor += 2  # skip the escaped character
                elif content[cursor] == '"':
                    cursor += 1
                    break
                else:
                    cursor += 1
            literal = content[pos:cursor]
            if any(pattern in literal for pattern in risky):
                out.append(" " * len(literal))
            else:
                out.append(literal)
            pos = cursor
            continue
        # Single-line comment -- always blank up to (not including) newline.
        if content[pos : pos + 2] == "//":
            cursor = pos
            while cursor < total and content[cursor] != "\n":
                cursor += 1
            out.append(" " * (cursor - pos))
            pos = cursor
            continue
        # Multi-line comment -- always blank, newlines preserved.
        if content[pos : pos + 2] == "/*":
            end = content.find("*/", pos + 2)
            if end != -1:
                out.append(blank(content[pos : end + 2]))
                pos = end + 2
                continue
        out.append(content[pos])
        pos += 1
    return "".join(out)
def detect_and_parse(source_path: Path) -> Tuple[str, List[Dict]]:
    """Classify an example source file and parse its kernel declarations.

    Reads the file, strips string literals and comments (so declarations
    hiding inside strings or commented-out code are not picked up), then
    dispatches on which declaration macro appears: conv is checked before
    gemm. Returns ("conv"|"gemm"|"unknown", parsed-kernel list).
    """
    text = strip_cpp_strings_and_comments(source_path.read_text())
    if "DECL_GROUPED_CONV_KERNEL_SET" in text:
        return "conv", parse_conv_declarations(text)
    if "DECL_KERNEL_SET" in text:
        return "gemm", parse_gemm_declarations(text)
    return "unknown", []
def generate_gemm_registration(
    kernel_headers: List[Path], example_name: str, kernels: List[Dict] = None
) -> str:
    """Generate GEMM kernel registration code for the dispatcher registry.

    Uses GeneratedKernelInstance<SelectedKernel> to wrap the generated kernels
    and provide the KernelInstance interface for the Dispatcher.
    If kernels list is provided with kernel_set info, generates separate
    registration functions per kernel set.

    Side effect: when headers fall into more than one kernel set, the
    (set -> headers) mapping and the per-kernel block generator are stashed
    in the module-level _kernels_by_set_cache for a later call to
    generate_per_set_functions. `example_name` is unused in this function.

    Returns the C++ registration body as a single string.
    """
    if not kernel_headers:
        return " // No kernels to register"
    # Build mapping from kernel config pattern to kernel set
    kernel_to_set = {}
    kernel_sets = set()
    if kernels:
        for k in kernels:
            tile_m = k.get("tile_m", 128)
            tile_n = k.get("tile_n", 128)
            tile_k = k.get("tile_k", 64)
            warp_m = k.get("warp_m", 2)
            warp_n = k.get("warp_n", 2)
            warp_k = k.get("warp_k", 1)
            warp_tile_m = k.get("warp_tile_m", 32)
            warp_tile_n = k.get("warp_tile_n", 32)
            warp_tile_k = k.get("warp_tile_k", 16)
            # Pattern that appears in kernel filename
            key_pattern = f"{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
            kernel_set = k.get("kernel_set", "default")
            kernel_to_set[key_pattern] = kernel_set
            kernel_sets.add(kernel_set)

    def generate_registration_block(h: Path) -> str:
        """Generate registration code for a single kernel.

        Recovers datatype/layout/pipeline/scheduler/epilogue by scanning the
        underscore-separated tokens of the header's stem; unrecognized tokens
        leave the fp16/rcr/compv3/intrawave/cshuffle defaults in place.
        """
        kernel_name = h.stem
        ns = f"ns_{kernel_name}"
        # Parse pipeline, scheduler, and layout from kernel name
        # Format: gemm_fp16_rcr_compv3_cshuffle_intrawave_...
        parts = kernel_name.split("_")
        pipeline = "CompV3"
        scheduler = "Intrawave"
        epilogue = "CShuffle"
        datatype = "FP16"
        layout_a = "RowMajor"
        layout_b = "ColMajor"
        layout_c = "RowMajor"
        # Parse datatype (e.g., fp16, bf16, fp32)
        dtype_map = {
            "fp16": "FP16",
            "bf16": "BF16",
            "fp32": "FP32",
            "fp64": "FP64",
            "int8": "INT8",
        }
        # Parse layout from 3-char codes (e.g., rcr, rrr, rrc, ccc)
        # r = RowMajor, c = ColMajor
        layout_map = {"r": "RowMajor", "c": "ColMajor"}
        # Find pipeline, epilogue, scheduler in the name parts
        pipeline_map = {
            "mem": "Mem",
            "compv1": "CompV1",
            "compv2": "CompV2",
            "compv3": "CompV3",
            "compv4": "CompV4",
            "compv5": "CompV5",
            "preshufflev1": "PreShuffleV1",
            "preshufflev2": "PreShuffleV2",
        }
        scheduler_map = {
            "intrawave": "Intrawave",
            "interwave": "Interwave",
            "auto": "Auto",
        }
        epilogue_map = {"default": "Default", "cshuffle": "CShuffle", "none": "None"}
        for part in parts:
            if part in pipeline_map:
                pipeline = pipeline_map[part]
            if part in scheduler_map:
                scheduler = scheduler_map[part]
            if part in epilogue_map:
                epilogue = epilogue_map[part]
            if part in dtype_map:
                datatype = dtype_map[part]
            # Parse 3-char layout codes (e.g., rcr, rrr)
            if len(part) == 3 and all(c in "rc" for c in part):
                layout_a = layout_map[part[0]]
                layout_b = layout_map[part[1]]
                layout_c = layout_map[part[2]]
        # Emit one brace-scoped C++ registration stanza for this kernel.
        block = []
        block.append(f" // Register kernel: {kernel_name}")
        block.append(" {")
        block.append(f" using SelectedKernel = {ns}::SelectedKernel;")
        block.append(" ck_tile::dispatcher::KernelKey key;")
        block.append(
            f" key.signature.dtype_a = ck_tile::dispatcher::DataType::{datatype};"
        )
        block.append(
            f" key.signature.dtype_b = ck_tile::dispatcher::DataType::{datatype};"
        )
        block.append(
            f" key.signature.dtype_c = ck_tile::dispatcher::DataType::{datatype};"
        )
        block.append(
            " key.signature.dtype_acc = ck_tile::dispatcher::DataType::FP32;"
        )
        block.append(
            f" key.signature.layout_a = ck_tile::dispatcher::LayoutTag::{layout_a};"
        )
        block.append(
            f" key.signature.layout_b = ck_tile::dispatcher::LayoutTag::{layout_b};"
        )
        block.append(
            f" key.signature.layout_c = ck_tile::dispatcher::LayoutTag::{layout_c};"
        )
        # Tile/wave/warp shapes come from the generated kernel type itself,
        # not from the filename, so they stay consistent with codegen.
        block.append(" key.algorithm.tile_shape.m = SelectedKernel::TileM;")
        block.append(" key.algorithm.tile_shape.n = SelectedKernel::TileN;")
        block.append(" key.algorithm.tile_shape.k = SelectedKernel::TileK;")
        block.append(
            " key.algorithm.wave_shape.m = SelectedKernel::WarpPerBlock_M;"
        )
        block.append(
            " key.algorithm.wave_shape.n = SelectedKernel::WarpPerBlock_N;"
        )
        block.append(
            " key.algorithm.wave_shape.k = SelectedKernel::WarpPerBlock_K;"
        )
        block.append(
            " key.algorithm.warp_tile_shape.m = SelectedKernel::WarpTileM;"
        )
        block.append(
            " key.algorithm.warp_tile_shape.n = SelectedKernel::WarpTileN;"
        )
        block.append(
            " key.algorithm.warp_tile_shape.k = SelectedKernel::WarpTileK;"
        )
        block.append(
            " key.algorithm.block_size = SelectedKernel::BlockSize;"
        )
        block.append(
            f" key.algorithm.pipeline = ck_tile::dispatcher::Pipeline::{pipeline};"
        )
        block.append(
            f" key.algorithm.scheduler = ck_tile::dispatcher::Scheduler::{scheduler};"
        )
        block.append(
            f" key.algorithm.epilogue = ck_tile::dispatcher::Epilogue::{epilogue};"
        )
        block.append(" key.gfx_arch = arch;")
        block.append(
            f' auto instance = std::make_shared<ck_tile::dispatcher::backends::GeneratedKernelInstance<SelectedKernel>>(key, "{kernel_name}");'
        )
        block.append(" registry.register_kernel(instance);")
        block.append(" }")
        return "\n".join(block)

    def find_kernel_set(header: Path) -> str:
        """Find which kernel set a header belongs to (by filename pattern)."""
        name = header.stem
        for pattern, kset in kernel_to_set.items():
            if pattern in name:
                return kset
        return "default"

    # Group kernels by set
    kernels_by_set = {}
    for h in kernel_headers:
        kset = find_kernel_set(h)
        if kset not in kernels_by_set:
            kernels_by_set[kset] = []
        kernels_by_set[kset].append(h)
    # If only one set or no set info, use simple registration
    if len(kernels_by_set) <= 1:
        lines = [" (void)arch;", ""]
        for h in kernel_headers:
            lines.append(generate_registration_block(h))
        return "\n".join(lines)
    # Multiple sets - generate registration for all, plus store per-set info
    lines = [" // Register ALL kernels from all sets", " (void)arch;", ""]
    for h in kernel_headers:
        lines.append(generate_registration_block(h))
    # Store per-set mapping for separate function generation
    # (consumed and cleared by generate_per_set_functions).
    global _kernels_by_set_cache
    _kernels_by_set_cache = (kernels_by_set, generate_registration_block)
    return "\n".join(lines)
# Global cache for per-set kernel info.
# Holds a (kernels_by_set, registration-block generator) tuple stashed by
# generate_gemm_registration when headers span multiple kernel sets; it is
# consumed and reset to None by generate_per_set_functions.
_kernels_by_set_cache = None
def generate_per_set_functions(source_stem: str) -> str:
    """Generate separate C++ registration functions for each kernel set.

    Emits:
    1. Per-set functions: register_<set_name>(registry, arch)
    2. A string-based dispatcher: register_kernel_set("name", registry, arch)
    3. get_kernel_set_names() listing the available sets

    Consumes (and clears) the module-level _kernels_by_set_cache populated
    by generate_gemm_registration; returns "" when nothing is cached.
    `source_stem` is currently unused.
    """
    global _kernels_by_set_cache
    if not _kernels_by_set_cache:
        return ""
    sets_by_name, emit_block = _kernels_by_set_cache
    _kernels_by_set_cache = None  # consume the cache exactly once
    out = []
    names = []
    for raw_name, headers in sets_by_name.items():
        # '-' is legal in a set name but not in a C++ identifier.
        cpp_name = raw_name.replace("-", "_")
        names.append((raw_name, cpp_name))
        out.append(
            f"inline void register_{cpp_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{"
        )
        out.append(" (void)arch;")
        for header in headers:
            out.append(emit_block(header))
        out.append("}")
        out.append("")
    if names:
        out.append("// Dynamic registration by kernel set name")
        out.append(
            "inline bool register_kernel_set(const std::string& set_name, ck_tile::dispatcher::Registry& registry, const std::string& arch) {"
        )
        for raw_name, cpp_name in names:
            out.append(
                f' if (set_name == "{raw_name}") {{ register_{cpp_name}(registry, arch); return true; }}'
            )
        out.append(" return false; // Unknown set name")
        out.append("}")
        out.append("")
        out.append("// Get list of available kernel set names")
        out.append("inline std::vector<std::string> get_kernel_set_names() {")
        quoted = ", ".join(f'"{raw_name}"' for raw_name, _ in names)
        out.append(f" return {{{quoted}}};")
        out.append("}")
        out.append("")
    return "\n".join(out)
def generate_conv_registration(
    kernel_headers: List[Path], example_name: str, kernels: List[Dict]
) -> str:
    """Generate Conv kernel registration code for the dispatcher registry.

    Creates real GroupedConvKernelInstance entries backed by the generated
    launcher's launch() method via the conv backend RunFn factories.

    Every kernel property (direction, ndim, dtype, tile/wave/warp triplets,
    pipeline, scheduler, epilogue) is parsed back out of the generated
    header's file name, e.g.
    ``grouped_conv_fwd_fp16_128x128x32_2x2x1_32x32x16``.

    Args:
        kernel_headers: Paths of the generated kernel headers.
        example_name: Example/target name (unused; kept for signature parity
            with generate_gemm_registration).
        kernels: Parsed kernel declarations (unused; kept for signature
            parity with generate_gemm_registration).

    Returns:
        A C++ snippet registering one GroupedConvKernelInstance per header.
    """
    # Fix: the original re-imported `re` on every loop iteration; import once
    # here and precompile the triplet pattern before the loop.
    import re

    if not kernel_headers:
        return "    // No kernels to register"
    triplet_re = re.compile(r"_(\d+)x(\d+)x(\d+)")
    lines: List[str] = []
    for i, h in enumerate(kernel_headers):
        kname = h.stem
        ns = f"ns_{kname}"
        launcher = f"{ns}::{kname}_Launcher"
        # Determine direction and RunFn factory from the kernel header name.
        if "_fwd_" in kname:
            direction = "Forward"
            run_fn_factory = "make_conv_fwd_run_fn"
        elif "_bwd_data_" in kname or "_bwdd_" in kname:
            direction = "BackwardData"
            run_fn_factory = "make_conv_bwd_data_run_fn"
        elif "_bwd_weight_" in kname or "_bwdw_" in kname:
            direction = "BackwardWeight"
            run_fn_factory = "make_conv_bwd_weight_run_fn"
        else:
            # No recognizable direction marker: fall back to forward.
            direction = "Forward"
            run_fn_factory = "make_conv_fwd_run_fn"
        ndim = 3 if "_3d_" in kname else 2
        # Parse dtype from name (e.g. grouped_conv_fwd_fp16_...); fp16 default.
        dtype = next(
            (dt for dt in ("fp16", "bf16", "fp32") if f"_{dt}_" in kname), "fp16"
        )
        # Parse tile, wave, warp from name.
        # Format: ..._TILExTILExTILE_WAVExWAVExWAVE_WARPxWARPxWARP_...
        tile_m, tile_n, tile_k = 1, 128, 128
        wave_m, wave_n, wave_k = 2, 2, 1
        warp_m, warp_n, warp_k = 32, 32, 16
        triplets = triplet_re.findall(kname)
        if len(triplets) >= 1:
            tile_m, tile_n, tile_k = map(int, triplets[0])
        if len(triplets) >= 2:
            wave_m, wave_n, wave_k = map(int, triplets[1])
        if len(triplets) >= 3:
            warp_m, warp_n, warp_k = map(int, triplets[2])
        pipeline = "compv4" if "compv4" in kname else "compv3"
        scheduler = "interwave" if "interwave" in kname else "intrawave"
        epilogue = "cshuffle" if "cshuffle" in kname else "default"
        # ConvConfigBase defaults (these are not encoded in the header name).
        vec_a, vec_b, vec_c = 4, 8, 8
        block_per_cu = 1
        num_wave_groups = 1
        num_groups_to_merge = 1
        lines.append(f"    // Kernel {i + 1}: {kname}")
        lines.append("    {")
        lines.append(f"        ck_tile::dispatcher::GroupedConvKernelKey key_{i};")
        lines.append(f'        key_{i}.dtype_in = "{dtype}";')
        lines.append(f'        key_{i}.dtype_wei = "{dtype}";')
        lines.append(f'        key_{i}.dtype_out = "{dtype}";')
        lines.append(f'        key_{i}.layout = "nhwgc";')
        lines.append(f"        key_{i}.ndim_spatial = {ndim};")
        lines.append(
            f"        key_{i}.op = ck_tile::dispatcher::GroupedConvOp::{direction};"
        )
        lines.append(f"        key_{i}.tile_m = {tile_m};")
        lines.append(f"        key_{i}.tile_n = {tile_n};")
        lines.append(f"        key_{i}.tile_k = {tile_k};")
        lines.append(f"        key_{i}.wave_m = {wave_m};")
        lines.append(f"        key_{i}.wave_n = {wave_n};")
        lines.append(f"        key_{i}.wave_k = {wave_k};")
        lines.append(f"        key_{i}.warp_m = {warp_m};")
        lines.append(f"        key_{i}.warp_n = {warp_n};")
        lines.append(f"        key_{i}.warp_k = {warp_k};")
        lines.append(f'        key_{i}.pipeline = "{pipeline}";')
        lines.append(f'        key_{i}.scheduler = "{scheduler}";')
        lines.append(f'        key_{i}.epilogue = "{epilogue}";')
        lines.append(f"        key_{i}.vector_size_a = {vec_a};")
        lines.append(f"        key_{i}.vector_size_b = {vec_b};")
        lines.append(f"        key_{i}.vector_size_c = {vec_c};")
        lines.append(f"        key_{i}.block_per_cu = {block_per_cu};")
        lines.append(f"        key_{i}.num_wave_groups = {num_wave_groups};")
        lines.append(f"        key_{i}.num_groups_to_merge = {num_groups_to_merge};")
        lines.append(f"        key_{i}.arch = arch;")
        lines.append(
            f"        auto run_fn_{i} = ck_tile::dispatcher::backends::{run_fn_factory}<{launcher}, {ndim}>();"
        )
        lines.append(
            f'        auto inst_{i} = std::make_shared<ck_tile::dispatcher::GroupedConvKernelInstance>(key_{i}, "{kname}", std::move(run_fn_{i}));'
        )
        lines.append(f"        registry.register_kernel(key_{i}, inst_{i});")
        lines.append("    }")
    return "\n".join(lines)
def _build_conv_codegen_cmd(
idx: int, k: Dict, codegen_dir: Path, output_dir: Path
) -> Tuple[int, List[str], str]:
"""Build the command for a single conv kernel codegen invocation."""
variant_map = {
"forward": "forward",
"bwd_data": "bwd_data",
"backward_data": "bwd_data",
"bwd_weight": "bwd_weight",
"backward_weight": "bwd_weight",
}
variant = variant_map.get(k.get("conv_type", "forward"), "forward")
cmd = [
sys.executable,
str(codegen_dir / "unified_grouped_conv_codegen.py"),
"--datatype",
k.get("dtype", "fp16"),
"--variant",
variant,
"--ndim",
str(k.get("ndim", 2)),
"--output",
str(output_dir),
]
if k.get("tile_m"):
cmd.extend(["--tile-m", str(k["tile_m"])])
if k.get("tile_n"):
cmd.extend(["--tile-n", str(k["tile_n"])])
if k.get("warp_m"):
cmd.extend(["--warp-m", str(k["warp_m"])])
if k.get("warp_n"):
cmd.extend(["--warp-n", str(k["warp_n"])])
if k.get("warp_k"):
cmd.extend(["--warp-k", str(k["warp_k"])])
if k.get("warp_tile_m"):
cmd.extend(["--warp-tile-m", str(k["warp_tile_m"])])
if k.get("warp_tile_n"):
cmd.extend(["--warp-tile-n", str(k["warp_tile_n"])])
if k.get("warp_tile_k"):
cmd.extend(["--warp-tile-k", str(k["warp_tile_k"])])
if k.get("pipeline"):
cmd.extend(["--pipeline", k["pipeline"]])
if k.get("scheduler"):
cmd.extend(["--scheduler", k["scheduler"]])
if k.get("epilogue"):
cmd.extend(["--epilogue", k["epilogue"]])
if k.get("vector_a"):
cmd.extend(["--vector-a", str(k["vector_a"])])
if k.get("vector_b"):
cmd.extend(["--vector-b", str(k["vector_b"])])
if k.get("vector_c"):
cmd.extend(["--vector-c", str(k["vector_c"])])
if k.get("block_per_cu"):
cmd.extend(["--block-per-cu", str(k["block_per_cu"])])
if k.get("num_wave_groups"):
cmd.extend(["--num-wave-groups", str(k["num_wave_groups"])])
if k.get("num_groups_to_merge"):
cmd.extend(["--num-groups-to-merge", str(k["num_groups_to_merge"])])
if k.get("double_smem_buffer") is not None:
cmd.extend(["--double-smem-buffer", str(k["double_smem_buffer"]).lower()])
if k.get("tile_k"):
cmd.extend(["--tile-k", str(k["tile_k"])])
return (idx, cmd, str(codegen_dir))
def _run_conv_codegen(args: Tuple) -> Tuple[int, bool, str]:
"""Run unified_grouped_conv_codegen.py for a single kernel config (picklable for ProcessPoolExecutor)."""
idx, cmd, cwd = args
result = subprocess.run(cmd, capture_output=True, text=True, cwd=cwd)
if result.returncode != 0:
return (idx, False, result.stderr[:300])
return (idx, True, "")
def generate_conv_kernels(
    kernels: List[Dict], output_dir: Path, codegen_dir: Path
) -> bool:
    """Generate Conv kernels for ALL declarations using unified codegen.

    Fans one codegen subprocess per declaration out to a ProcessPoolExecutor
    so that multi-kernel builds generate in parallel.

    Returns True when at least one declaration generated successfully.
    """
    if not kernels:
        return False
    jobs = [
        _build_conv_codegen_cmd(i, decl, codegen_dir, output_dir)
        for i, decl in enumerate(kernels)
    ]
    workers = min(len(jobs), os.cpu_count() or 4)
    generated = 0
    with ProcessPoolExecutor(max_workers=workers) as pool:
        pending = {pool.submit(_run_conv_codegen, job): job[0] for job in jobs}
        for fut in as_completed(pending):
            i, ok, err = fut.result()
            if ok:
                generated += 1
            else:
                print(f" Codegen error for kernel {i + 1}: {err}")
    return generated > 0
def _run_gemm_codegen(args: Tuple) -> Tuple[int, bool, str]:
"""Run unified_gemm_codegen.py for a single kernel config (picklable for ProcessPoolExecutor)."""
idx, cmd, cwd = args
result = subprocess.run(cmd, capture_output=True, text=True, cwd=cwd)
if result.returncode != 0:
return (idx, False, result.stderr[:300])
return (idx, True, "")
def generate_gemm_kernels(
    kernels: List[Dict], output_dir: Path, codegen_dir: Path
) -> bool:
    """Generate GEMM kernels for ALL declarations using unified codegen.

    Builds one unified_gemm_codegen.py invocation per declaration, then runs
    them all concurrently through a ProcessPoolExecutor.

    Returns True when at least one declaration generated successfully.
    """
    import json

    if not kernels:
        return False
    script = str(codegen_dir / "unified_gemm_codegen.py")
    jobs = []
    for i, decl in enumerate(kernels):
        # The multi-D variant is selected by the presence of an elementwise op.
        variant = "multi_d" if decl.get("elementwise_op") else "standard"
        # Tile/trait defaults mirror the codegen's single-kernel configuration;
        # key order matters for stable json output.
        config_json = json.dumps(
            {
                "tile_config": {
                    "tile_m": [decl.get("tile_m", 128)],
                    "tile_n": [decl.get("tile_n", 128)],
                    "tile_k": [decl.get("tile_k", 32)],
                    "warp_m": [decl.get("warp_m", 2)],
                    "warp_n": [decl.get("warp_n", 2)],
                    "warp_k": [decl.get("warp_k", 1)],
                    "warp_tile_m": [decl.get("warp_tile_m", 32)],
                    "warp_tile_n": [decl.get("warp_tile_n", 32)],
                    "warp_tile_k": [decl.get("warp_tile_k", 16)],
                },
                "trait_config": {
                    "pipeline": [decl.get("pipeline", "compv3")],
                    "epilogue": [decl.get("epilogue", "cshuffle")],
                    "scheduler": [decl.get("scheduler", "intrawave")],
                    "pad_m": [decl.get("pad_m", False)],
                    "pad_n": [decl.get("pad_n", False)],
                    "pad_k": [decl.get("pad_k", False)],
                    "persistent": [False],
                },
            }
        )
        argv = [
            sys.executable,
            script,
            "--datatype",
            decl.get("dtype", "fp16"),
            "--layout",
            decl.get("layout", "rcr"),
            "--variants",
            variant,
            "--output",
            str(output_dir),
            "--tile-config-json",
            config_json,
        ]
        jobs.append((i, argv, str(codegen_dir)))
    # Run all codegen subprocesses in parallel.
    done = 0
    workers = min(len(jobs), os.cpu_count() or 4)
    with ProcessPoolExecutor(max_workers=workers) as pool:
        pending = {pool.submit(_run_gemm_codegen, job): job[0] for job in jobs}
        for fut in as_completed(pending):
            idx, ok, err = fut.result()
            if ok:
                done += 1
            else:
                print(f" Codegen error for kernel {idx + 1}: {err}")
    return done > 0
def compile_kernel(args: Tuple) -> Tuple[str, bool, str]:
    """Compile a single kernel header into an object file.

    args is a picklable tuple:
    (kernel_hpp, output_dir, include_dirs, hipcc, gpu_target, idx, total).
    Returns (kernel_name, True, obj_path) on success or
    (kernel_name, False, stderr_excerpt) on failure.
    """
    kernel_hpp, output_dir, include_dirs, hipcc, gpu_target, idx, total = args
    kernel_name = kernel_hpp.stem
    # Emit a tiny translation unit that just includes the kernel header; the
    # anonymous-namespace volatile keeps the TU non-empty.
    wrapper_cpp = output_dir / f"{kernel_name}.cpp"
    wrapper_cpp.write_text(
        f'#include "{kernel_hpp.name}"\nnamespace {{ volatile bool _k{idx} = true; }}\n'
    )
    obj_file = output_dir / f"{kernel_name}.o"
    cmd = [hipcc, "-c", "-fPIC", "-std=c++17", "-O3", f"--offload-arch={gpu_target}"]
    cmd += ["-mllvm", "-enable-noalias-to-md-conversion=0"]
    cmd += ["-Wno-undefined-func-template", "-Wno-float-equal", "--offload-compress"]
    # The header's own directory is always searched last.
    for inc in (*include_dirs, kernel_hpp.parent):
        cmd += ["-I", str(inc)]
    cmd += ["-o", str(obj_file), str(wrapper_cpp)]
    proc = subprocess.run(cmd, capture_output=True, text=True)
    if proc.returncode == 0:
        return (kernel_name, True, str(obj_file))
    return (kernel_name, False, proc.stderr[:500])
def main() -> int:
    """Generate, compile, archive, and wrap the example's declared kernels.

    Pipeline:
      1. Parse kernel declarations out of the C++ example source.
      2. Run the matching codegen (GEMM or grouped conv) for every declaration.
      3. Compile each generated kernel header to an object file in parallel.
      4. Archive the objects into lib<target>_kernels.a.
      5. Emit a registration header exposing register_<stem>_kernels() plus
         the REGISTER_GENERATED_KERNELS convenience macro.

    Returns:
        0 on success (including the no-declarations stub case), 1 when
        codegen produces nothing or any kernel fails to compile.
    """
    parser = argparse.ArgumentParser(description="Build example kernels")
    parser.add_argument("source", type=Path, help="C++ source file")
    parser.add_argument("--output-dir", type=Path, required=True)
    parser.add_argument("--include-dirs", type=str, required=True)
    parser.add_argument("--gpu-target", type=str, default="gfx942")
    parser.add_argument("--jobs", type=int, default=os.cpu_count())
    parser.add_argument(
        "--target-name", type=str, help="CMake target name (for library naming)"
    )
    args = parser.parse_args()
    script_dir = Path(__file__).parent
    codegen_dir = script_dir.parent / "codegen"
    source_stem = args.source.stem  # e.g., "01_basic_gemm"
    target_name = args.target_name or source_stem  # e.g., "gemm_01_basic" from CMake
    args.output_dir.mkdir(parents=True, exist_ok=True)
    # Detect and parse
    example_type, kernels = detect_and_parse(args.source)
    if example_type == "conv":
        # First declaration drives the summary line only; all are generated.
        k = kernels[0] if kernels else {}
        variant = k.get("conv_type", "forward")
        print(
            f"[{target_name}] Conv {k.get('dtype', 'fp16')} {variant} {k.get('ndim', 2)}D ({len(kernels)} declarations)"
        )
    elif example_type == "gemm":
        k = kernels[0] if kernels else {}
        print(
            f"[{target_name}] GEMM {k.get('dtype', 'fp16')} {k.get('layout', 'rcr')} ({len(kernels)} declarations)"
        )
    else:
        # No declarations: still emit an empty archive and a stub header so
        # the CMake target links successfully.
        print(f"[{target_name}] No kernel declarations - creating empty library")
        lib_path = args.output_dir / f"lib{target_name}_kernels.a"
        subprocess.run([find_ar(), "rcs", str(lib_path)], check=True)
        header = args.output_dir / f"{source_stem}_kernels.hpp"
        header.write_text(f"// No kernels for {target_name}\n#pragma once\n")
        return 0
    # Generate kernels
    print(f"[{target_name}] Generating kernels...")
    if example_type == "conv":
        success = generate_conv_kernels(kernels, args.output_dir, codegen_dir)
    else:
        success = generate_gemm_kernels(kernels, args.output_dir, codegen_dir)
    if not success:
        print(f"[{target_name}] Kernel generation failed!")
        return 1
    # Find generated headers
    if example_type == "gemm":
        kernel_headers = list(args.output_dir.glob("gemm_*.hpp"))
    else:
        # Conv codegen emits headers with a direction-specific file prefix.
        prefix_map = {
            "forward": "grouped_conv_fwd",
            "bwd_data": "grouped_conv_bwd_data",
            "bwd_weight": "grouped_conv_bwd_weight",
        }
        # Collect headers from ALL variants present in declarations
        variants_used = set(k.get("conv_type", "forward") for k in kernels)
        kernel_headers = []
        for variant in variants_used:
            prefix = prefix_map.get(variant, "grouped_conv_fwd")
            kernel_headers.extend(args.output_dir.glob(f"{prefix}_*.hpp"))
    if not kernel_headers:
        print(f"[{target_name}] No kernel headers generated!")
        return 1
    print(f"[{target_name}] Compiling {len(kernel_headers)} kernels...")
    include_dirs = [Path(p.strip()) for p in args.include_dirs.split(",")]
    hipcc = find_hipcc()
    # One picklable work tuple per header; (i + 1, total) are 1-based
    # progress indices consumed by compile_kernel.
    work = [
        (
            h,
            args.output_dir,
            include_dirs,
            hipcc,
            args.gpu_target,
            i + 1,
            len(kernel_headers),
        )
        for i, h in enumerate(kernel_headers)
    ]
    obj_files = []
    failed = []
    with ProcessPoolExecutor(max_workers=args.jobs) as executor:
        futures = {executor.submit(compile_kernel, w): w[0].name for w in work}
        for future in as_completed(futures):
            name, ok, result = future.result()
            if ok:
                obj_files.append(result)
            else:
                failed.append((name, result))
                print(f"[{target_name}] FAILED: {name}")
    if failed:
        # Show at most three failures with truncated compiler stderr.
        print(f"[{target_name}] {len(failed)} kernels failed")
        for name, err in failed[:3]:
            print(f" {name}: {err[:200]}")
        return 1
    # Create static library (use target_name for CMake compatibility)
    lib_path = args.output_dir / f"lib{target_name}_kernels.a"
    subprocess.run([find_ar(), "rcs", str(lib_path)] + obj_files, check=True)
    # Generate registration header (use source_stem for header name to match CMake's EXAMPLE_STEM)
    header_path = args.output_dir / f"{source_stem}_kernels.hpp"
    # Build includes
    includes = "\n".join(f'#include "{h.name}"' for h in kernel_headers)
    # Build kernel registration entries
    # Function name uses source_stem (e.g., register_01_basic_gemm_kernels)
    func_name = f"register_{source_stem}_kernels"
    # Generate registration code based on example type
    if example_type == "gemm":
        register_body = generate_gemm_registration(kernel_headers, target_name, kernels)
    else:
        register_body = generate_conv_registration(kernel_headers, target_name, kernels)
    # Generate appropriate header based on example type
    if example_type == "conv" and kernel_headers:
        launcher_aliases = []

        # Helper to find kernel by dtype and type
        def find_kernel_by_dtype_type(headers, dtype, conv_type_marker):
            """Return a header whose stem contains conv_type_marker, preferring
            the requested dtype; falls back to any direction match, else None."""
            matching = [h for h in headers if conv_type_marker in h.stem]
            # Prefer the requested dtype (fp16 at all current call sites)
            fp16_kernels = [h for h in matching if f"_{dtype}_" in h.stem]
            return (
                fp16_kernels[0] if fp16_kernels else (matching[0] if matching else None)
            )

        # Check what conv types are in the declarations
        has_fwd = any("forward" in k.get("conv_type", "forward") for k in kernels)
        has_bwd_data = any("bwd_data" in k.get("conv_type", "") for k in kernels)
        has_bwd_weight = any("bwd_weight" in k.get("conv_type", "") for k in kernels)
        # Export dtype-specific launcher aliases for each available dtype
        for dtype in ["fp16", "bf16", "fp32"]:
            dtype_fwd_kernels = [
                h
                for h in kernel_headers
                if "_fwd_" in h.stem and f"_{dtype}_" in h.stem
            ]
            if dtype_fwd_kernels:
                k = dtype_fwd_kernels[0]
                ns = f"ns_{k.stem}"
                dtype_upper = dtype.upper()
                launcher_aliases.append(
                    f"using {dtype_upper}FwdKernelLauncher = {ns}::{k.stem}_Launcher;"
                )
        # Export generic launcher aliases (prioritize fp16)
        if has_fwd:
            fwd_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_fwd_")
            if fwd_kernel:
                fwd_ns = f"ns_{fwd_kernel.stem}"
                launcher_aliases.append(
                    f"using FwdKernelLauncher = {fwd_ns}::{fwd_kernel.stem}_Launcher;"
                )
                launcher_aliases.append(
                    f"using FirstKernelLauncher = {fwd_ns}::{fwd_kernel.stem}_Launcher;"
                )
        if has_bwd_data:
            # Accept both "_bwd_data_" and the short "_bwdd_" naming scheme.
            bwd_data_kernel = find_kernel_by_dtype_type(
                kernel_headers, "fp16", "_bwd_data_"
            )
            if not bwd_data_kernel:
                bwd_data_kernel = find_kernel_by_dtype_type(
                    kernel_headers, "fp16", "_bwdd_"
                )
            if bwd_data_kernel:
                bwd_data_ns = f"ns_{bwd_data_kernel.stem}"
                launcher_aliases.append(
                    f"using BwdDataKernelLauncher = {bwd_data_ns}::{bwd_data_kernel.stem}_Launcher;"
                )
                if not has_fwd:
                    launcher_aliases.append(
                        f"using FirstKernelLauncher = {bwd_data_ns}::{bwd_data_kernel.stem}_Launcher;"
                    )
        if has_bwd_weight:
            bwd_weight_kernel = find_kernel_by_dtype_type(
                kernel_headers, "fp16", "_bwd_weight_"
            )
            if not bwd_weight_kernel:
                bwd_weight_kernel = find_kernel_by_dtype_type(
                    kernel_headers, "fp16", "_bwdw_"
                )
            if bwd_weight_kernel:
                bwd_weight_ns = f"ns_{bwd_weight_kernel.stem}"
                launcher_aliases.append(
                    f"using BwdWeightKernelLauncher = {bwd_weight_ns}::{bwd_weight_kernel.stem}_Launcher;"
                )
                if not has_fwd and not has_bwd_data:
                    launcher_aliases.append(
                        f"using FirstKernelLauncher = {bwd_weight_ns}::{bwd_weight_kernel.stem}_Launcher;"
                    )
        launcher_section = "\n".join(launcher_aliases)
        header_content = f"""// Auto-generated for {target_name}
#pragma once
{includes}
#include "ck_tile/dispatcher/registry.hpp"
#include "ck_tile/dispatcher/kernel_instance.hpp"
#include "ck_tile/dispatcher/kernel_key.hpp"
#include "ck_tile/dispatcher/grouped_conv_registry.hpp"
#include "ck_tile/dispatcher/backends/generated_conv_backend.hpp"
namespace generated {{
// Kernel launchers for direct use
{launcher_section}
// Registration function (takes GroupedConvRegistry for conv kernels)
inline void {func_name}(ck_tile::dispatcher::GroupedConvRegistry& registry, const std::string& arch) {{
{register_body}
}}
}} // namespace generated
// Generic registration - avoids hardcoding the example name in user code
// Safe for single-example executables (typical use case)
#ifndef REGISTER_GENERATED_KERNELS
#define REGISTER_GENERATED_KERNELS(registry, arch) generated::{func_name}(registry, arch)
#endif
"""
    else:
        # GEMM: Generate per-set functions if multiple kernel sets declared
        per_set_funcs = generate_per_set_functions(source_stem)
        header_content = f"""// Auto-generated for {target_name}
#pragma once
{includes}
#include "ck_tile/dispatcher/registry.hpp"
#include "ck_tile/dispatcher/kernel_instance.hpp"
#include "ck_tile/dispatcher/kernel_key.hpp"
#include "ck_tile/dispatcher/backends/generated_kernel_backend.hpp"
namespace generated {{
// Register ALL kernels from all declared sets
inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{
{register_body}
}}
{per_set_funcs}
}} // namespace generated
// Generic registration - avoids hardcoding the example name in user code
// Safe for single-example executables (typical use case)
#ifndef REGISTER_GENERATED_KERNELS
#define REGISTER_GENERATED_KERNELS(registry, arch) generated::{func_name}(registry, arch)
#endif
// Register a specific kernel set by name (for multi-registry patterns)
// Usage: REGISTER_KERNEL_SET("compute_bound_set", registry, arch)
#ifndef REGISTER_KERNEL_SET
#define REGISTER_KERNEL_SET(set_name, registry, arch) generated::register_kernel_set(set_name, registry, arch)
#endif
"""
    header_path.write_text(header_content)
    print(f"[{target_name}] OK {len(obj_files)} kernels compiled")
    return 0
if __name__ == "__main__":
    # Equivalent to sys.exit(main()): propagate main()'s status code.
    raise SystemExit(main())