Files
composable_kernel/dispatcher/scripts/example_kernel_builder.py
Vidyasagar Ananthan 9e049a32a1 Adding dispatcher architecture (#3300)
* WIP POC of dispatcher

* Dispatcher python workflow setup.

* Dispatcher cleanup and updates.

Further dispatcher cleanup and updates.

Build fixes

Improvements and python to CK example

Improvements to readme

* Fixes to python paths

* Cleaning up code

* Improving dispatcher support for different arch

Fixing typos

* Fix formatting errors

* Cleaning up examples

* Improving codegeneration

* Improving and fixing C++ examples

* Adding conv functionality (fwd,bwd,bwdw) and examples.

* Fixes based on feedback.

* Further fixes based on feedback.

* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.

* Another round of improvements based on feedback.

* Trimming out unnecessary code.

* Fixing the multi-D implementation.

* Using gpu verification for gemms and fixing convolutions tflops calculation.

* Fix counter usage issue and arch filtering per ops.

* Adding changelog and other fixes.

* Improve examples and resolve critical bugs.

* Reduce build time for python examples.

* Fixing minor bug.

* Fix compilation error.

* Improve installation instructions for dispatcher.

* Add docker-based installation instructions for dispatcher.

* Fixing arch-based filtering to match tile engine.

* Remove dead code and fix arch filtering.

* Minor bugfix.

* Updates after rebase.

* Trimming code.

* Fix copyright headers.

* Consolidate examples, cut down code.

* Minor fixes.

* Improving python examples.

* Update readmes.

* Remove conv functionality.

* Cleanup following conv removal.
2026-01-22 09:34:33 -08:00

1448 lines
52 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Build example kernels - generates and compiles kernels for a single example.
Detects if example is GEMM or Conv based on macro presence, extracts all
configuration parameters, and generates appropriate kernels.
"""
import argparse
import os
import re
import shutil
import subprocess
import sys
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Dict, List, Tuple
def find_hipcc() -> str:
    """Locate the hipcc compiler binary.

    Search order: the HIPCC environment variable, the standard ROCm
    install location, then a PATH lookup.  Falls back to the bare name
    "hipcc" so the eventual subprocess error is still meaningful.
    """
    candidates = (
        os.environ.get("HIPCC"),
        "/opt/rocm/bin/hipcc",
        shutil.which("hipcc"),
    )
    for candidate in candidates:
        if candidate and os.path.isfile(candidate):
            return candidate
    return "hipcc"
def find_ar() -> str:
    """Locate an archiver (ar) binary for building static libraries.

    Checks the AR environment variable first — mirroring how find_hipcc
    honors HIPCC — then the ROCm llvm-ar, then llvm-ar/ar on PATH.
    Falls back to the bare name "ar".
    """
    for path in [
        os.environ.get("AR"),  # explicit override, consistent with HIPCC handling
        "/opt/rocm/llvm/bin/llvm-ar",
        shutil.which("llvm-ar"),
        shutil.which("ar"),
    ]:
        if path and os.path.isfile(path):
            return path
    return "ar"
def extract_balanced_parens(text: str, start_pos: int) -> str:
    """Return the text between the '(' at *start_pos* and its matching ')'.

    Returns "" when start_pos is out of range, does not point at '(', or
    the parentheses are never balanced.  Note: parens inside string
    literals are counted like any other character.
    """
    if start_pos >= len(text) or text[start_pos] != "(":
        return ""
    depth = 0
    pos = start_pos
    while pos < len(text):
        ch = text[pos]
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                return text[start_pos + 1 : pos]
        pos += 1
    return ""
def parse_conv_declarations(content: str) -> List[Dict]:
    """Parse DECL_CONV_KERNEL_SET declarations with all parameters.

    Scans *content* (C++ source, ideally pre-stripped of comments/strings)
    for DECL_CONV_KERNEL_SET(...) macro invocations and extracts one kernel
    dict per fluent `.add(...)` call found inside.  Kernels without a
    recognizable dtype are dropped; the rest are passed through
    auto_fill_conv_defaults() before being returned.
    """
    kernels = []
    for match in re.finditer(r"DECL_CONV_KERNEL_SET\s*\(", content):
        # match.end() - 1 is the index of the opening "(" itself.
        body = extract_balanced_parens(content, match.end() - 1)
        if not body:
            continue
        # Parse each .add() call
        for add_match in re.finditer(r"\.add\s*\(", body):
            add_body = extract_balanced_parens(body, add_match.end() - 1)
            kernel = {}
            # ConvSig parameters - handle both single dtype and multi-dtype
            # Multi-dtype: .dtype("fp16", "fp16", "fp16", "fp32") or .dtype("fp16", "bf16", "fp16")
            if m := re.search(
                r'\.dtype\s*\(\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"(?:\s*,\s*"([^"]+)")?\s*\)',
                add_body,
            ):
                kernel["dtype_in"] = m.group(1)
                kernel["dtype_wei"] = m.group(2)
                kernel["dtype_out"] = m.group(3)
                # Fourth dtype (accumulator) is optional; default to fp32.
                kernel["dtype_acc"] = m.group(4) if m.group(4) else "fp32"
                kernel["dtype"] = m.group(1)  # Default for codegen
            # Single dtype: .dtype("fp16") — same type for in/wei/out.
            elif m := re.search(r'\.dtype\s*\(\s*"([^"]+)"\s*\)', add_body):
                kernel["dtype"] = m.group(1)
                kernel["dtype_in"] = m.group(1)
                kernel["dtype_wei"] = m.group(1)
                kernel["dtype_out"] = m.group(1)
                kernel["dtype_acc"] = "fp32"
            if m := re.search(r'\.layout\s*\(\s*"([^"]+)"', add_body):
                kernel["layout"] = m.group(1)
            if m := re.search(r'\.conv_type\s*\(\s*"([^"]+)"', add_body):
                kernel["conv_type"] = m.group(1)
            if m := re.search(r"\.dims\s*\(\s*(\d+)\s*\)", add_body):
                kernel["ndim"] = int(m.group(1))
            # ConvAlgo parameters - tile(G, M, N) where G=batch, M=output, N=reduction
            if m := re.search(
                r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body
            ):
                kernel["tile_g"] = int(m.group(1))  # batch tile (usually 1)
                kernel["tile_m"] = int(m.group(2))  # output channel tile
                kernel["tile_n"] = int(m.group(3))  # input channel tile (reduction)
            # wave(M_Warp, N_Warp, K_Warp) - warp distribution
            if m := re.search(
                r"\.wave\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body
            ):
                kernel["warp_m"] = int(m.group(1))
                kernel["warp_n"] = int(m.group(2))
                kernel["warp_k"] = int(m.group(3))
            # warp(M_Warp_Tile, N_Warp_Tile, K_Warp_Tile) - warp tile sizes
            if m := re.search(
                r"\.warp\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body
            ):
                kernel["warp_tile_m"] = int(m.group(1))
                kernel["warp_tile_n"] = int(m.group(2))
                kernel["warp_tile_k"] = int(m.group(3))
            # vector_sizes(A, B, C)
            if m := re.search(
                r"\.vector_sizes\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body
            ):
                kernel["vector_a"] = int(m.group(1))
                kernel["vector_b"] = int(m.group(2))
                kernel["vector_c"] = int(m.group(3))
            # Single-value parameters
            if m := re.search(r'\.pipeline\s*\(\s*"([^"]+)"', add_body):
                kernel["pipeline"] = m.group(1)
            if m := re.search(r'\.scheduler\s*\(\s*"([^"]+)"', add_body):
                kernel["scheduler"] = m.group(1)
            if m := re.search(r'\.epilogue\s*\(\s*"([^"]+)"', add_body):
                kernel["epilogue"] = m.group(1)
            if m := re.search(r"\.block_per_cu\s*\(\s*(\d+)\s*\)", add_body):
                kernel["block_per_cu"] = int(m.group(1))
            if m := re.search(r"\.num_wave_groups\s*\(\s*(\d+)\s*\)", add_body):
                kernel["num_wave_groups"] = int(m.group(1))
            if m := re.search(r"\.num_groups_to_merge\s*\(\s*(\d+)\s*\)", add_body):
                kernel["num_groups_to_merge"] = int(m.group(1))
            # Case-insensitive true/false (C++ code conventionally lowercase).
            if m := re.search(
                r"\.double_smem_buffer\s*\(\s*(true|false)\s*\)", add_body, re.I
            ):
                kernel["double_smem_buffer"] = m.group(1).lower() == "true"
            # Architecture: any quoted "gfxNNN" token inside the .add() body.
            if m := re.search(r'"(gfx\d+)"', add_body):
                kernel["arch"] = m.group(1)
            # dtype is the one mandatory field; entries without it are skipped.
            if kernel.get("dtype"):
                # Auto-fill missing parameters with defaults (autocorrect)
                kernel = auto_fill_conv_defaults(kernel)
                kernels.append(kernel)
    return kernels
def auto_fill_conv_defaults(kernel: Dict) -> Dict:
    """Fill absent conv parameters and repair invalid ones, in place.

    AUTOFILL: any parameter that is missing, None, or -1 receives the
    ConvConfigComputeV3 default.  AUTOCORRECT: wave shapes that are not
    valid on gfx942 fall back to (1, 4, 1), and backward conv ops are
    moved off the compv4/compv5 pipelines.  Returns the same dict.
    """
    conv_defaults = (
        ("tile_g", 1),
        ("tile_m", 16),
        ("tile_n", 64),
        ("warp_m", 1),
        ("warp_n", 4),
        ("warp_k", 1),
        ("warp_tile_m", 16),
        ("warp_tile_n", 16),
        ("warp_tile_k", 32),
        ("pipeline", "compv3"),
        ("scheduler", "intrawave"),
        ("epilogue", "cshuffle"),
        ("vector_a", 4),
        ("vector_b", 8),
        ("vector_c", 8),
        ("block_per_cu", 1),
        ("num_wave_groups", 1),
        ("num_groups_to_merge", 1),
        ("ndim", 2),
        ("layout", "nhwgc"),
        ("conv_type", "forward"),
        ("arch", "gfx942"),
    )
    # AUTOFILL pass: missing / None / -1 means "use the default".
    filled = []
    for name, default in conv_defaults:
        current = kernel.get(name)
        if current is None or current == -1:
            kernel[name] = default
            filled.append(f"{name}={default}")
    if filled:
        print(f" [AUTOFILL] {', '.join(filled)}")
    # AUTOCORRECT pass 1: wave shape must be one of the gfx942-valid combos.
    wave = (
        kernel.get("warp_m", 1),
        kernel.get("warp_n", 4),
        kernel.get("warp_k", 1),
    )
    if wave not in [(1, 4, 1), (2, 2, 1), (4, 1, 1)]:
        kernel["warp_m"], kernel["warp_n"], kernel["warp_k"] = 1, 4, 1
        print(f" [AUTOCORRECT] wave{wave} -> wave(1,4,1) (invalid for gfx942)")
    # AUTOCORRECT pass 2: backward convs cannot use compv4/compv5 pipelines.
    ctype = kernel.get("conv_type", "forward")
    pipe = kernel.get("pipeline", "compv3")
    if ctype in ["bwd_data", "bwd_weight"] and pipe in ["compv4", "compv5"]:
        kernel["pipeline"] = "compv3"
        print(
            f" [AUTOCORRECT] pipeline {pipe} -> compv3 (invalid for {ctype})"
        )
    return kernel
def expand_conv_wildcards(kernel: Dict, arch: str = "gfx942") -> List[Dict]:
    """Expand wildcard wave/warp-tile parameters into valid conv configs.

    A wave or warp-tile value of None or -1 is treated as a wildcard and
    replaced by every valid gfx942 combination; fully specified kernels
    are returned unchanged as a single-element list.
    """
    # Candidate spaces for gfx942 (fp16 warp tiles).
    wave_options = [(1, 4, 1), (2, 2, 1), (4, 1, 1)]
    warp_options = [(16, 16, 32), (32, 32, 16)]
    wave_is_wild = kernel.get("warp_m") is None or kernel.get("warp_m") == -1
    warp_is_wild = (
        kernel.get("warp_tile_m") is None or kernel.get("warp_tile_m") == -1
    )
    if not (wave_is_wild or warp_is_wild):
        return [kernel]
    # Pin whichever dimension was explicitly specified.
    if not wave_is_wild:
        wave_options = [
            (kernel.get("warp_m", 2), kernel.get("warp_n", 2), kernel.get("warp_k", 1))
        ]
    if not warp_is_wild:
        warp_options = [
            (
                kernel.get("warp_tile_m", 32),
                kernel.get("warp_tile_n", 32),
                kernel.get("warp_tile_k", 16),
            )
        ]
    # Cartesian product of the remaining choices.
    result = []
    for wave in wave_options:
        for warp in warp_options:
            variant = dict(kernel)
            variant["warp_m"], variant["warp_n"], variant["warp_k"] = wave
            (
                variant["warp_tile_m"],
                variant["warp_tile_n"],
                variant["warp_tile_k"],
            ) = warp
            result.append(variant)
    return result
def parse_int_or_wildcard(val: str) -> int:
    """Convert *val* to an int, mapping any wildcard spelling to -1.

    Recognized wildcards: the ANY_INT macro name (defined as -1 in C++),
    a literal -1, and the string "*".  Anything else must parse as int.
    """
    token = val.strip()
    return -1 if token in ("ANY_INT", "-1", "*") else int(token)
def parse_gemm_declarations(content: str) -> List[Dict]:
    """Parse DECL_KERNEL_SET declarations for GEMM.

    Supports wildcards:
    - ANY_INT for numeric params (wave, warp) -> expands to all valid combos
    - "*" for string params (pipeline, scheduler) -> expands to valid options
    Each kernel is tagged with its kernel_set name for separate registration.
    Returns the fully expanded, default-filled kernel dicts.
    """
    kernels = []
    for match in re.finditer(r"DECL_KERNEL_SET\s*\(\s*(\w+)\s*,", content):
        kernel_set_name = match.group(1)
        # The regex consumed past the "(" (to capture the set name), so
        # re-locate the opening paren from match.start() before extracting.
        body = extract_balanced_parens(
            content, match.start() + content[match.start() :].find("(")
        )
        if not body:
            continue
        for add_match in re.finditer(r"\.add\s*\(", body):
            add_body = extract_balanced_parens(body, add_match.end() - 1)
            kernel = {}
            # Signature parameters
            if m := re.search(r'\.dtype\s*\(\s*"([^"]+)"', add_body):
                kernel["dtype"] = m.group(1)
            if m := re.search(r'\.layout\s*\(\s*"([^"]+)"', add_body):
                kernel["layout"] = m.group(1)
            # .elementwise("op", num_d_tensors) marks a multi-D GEMM.
            if m := re.search(r'\.elementwise\s*\(\s*"([^"]+)"\s*,\s*(\d+)', add_body):
                kernel["elementwise_op"] = m.group(1)
                kernel["num_d_tensors"] = int(m.group(2))
            # Algorithm parameters - support ANY_INT wildcard
            if m := re.search(
                r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body
            ):
                kernel["tile_m"] = int(m.group(1))
                kernel["tile_n"] = int(m.group(2))
                kernel["tile_k"] = int(m.group(3))
            # Wave: support ANY_INT, -1, and "*" as wildcards
            # ([\w*-]+) accepts identifiers, "-1", and "*" tokens alike.
            if m := re.search(
                r"\.wave\s*\(\s*([\w*-]+)\s*,\s*([\w*-]+)\s*,\s*([\w*-]+)\s*\)",
                add_body,
            ):
                kernel["warp_m"] = parse_int_or_wildcard(m.group(1))
                kernel["warp_n"] = parse_int_or_wildcard(m.group(2))
                kernel["warp_k"] = parse_int_or_wildcard(m.group(3))
            # Warp: support ANY_INT, -1, and "*" as wildcards
            if m := re.search(
                r"\.warp\s*\(\s*([\w*-]+)\s*,\s*([\w*-]+)\s*,\s*([\w*-]+)\s*\)",
                add_body,
            ):
                kernel["warp_tile_m"] = parse_int_or_wildcard(m.group(1))
                kernel["warp_tile_n"] = parse_int_or_wildcard(m.group(2))
                kernel["warp_tile_k"] = parse_int_or_wildcard(m.group(3))
            # Pipeline/Scheduler: support "*" wildcard
            if m := re.search(r'\.pipeline\s*\(\s*"([^"]+)"', add_body):
                kernel["pipeline"] = m.group(1)
            if m := re.search(r'\.scheduler\s*\(\s*"([^"]+)"', add_body):
                kernel["scheduler"] = m.group(1)
            if m := re.search(r'\.epilogue\s*\(\s*"([^"]+)"', add_body):
                kernel["epilogue"] = m.group(1)
            if m := re.search(
                r"\.pad\s*\(\s*(true|false)\s*,\s*(true|false)\s*,\s*(true|false)",
                add_body,
                re.I,
            ):
                kernel["pad_m"] = m.group(1).lower() == "true"
                kernel["pad_n"] = m.group(2).lower() == "true"
                kernel["pad_k"] = m.group(3).lower() == "true"
            # Shorthand format: .add("dtype", "layout", M, N, K)
            # Only tried when the fluent .dtype() form did not match.
            if not kernel.get("dtype"):
                if m := re.match(
                    r'\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)',
                    add_body,
                ):
                    kernel["dtype"] = m.group(1)
                    kernel["layout"] = m.group(2)
                    kernel["tile_m"] = int(m.group(3))
                    kernel["tile_n"] = int(m.group(4))
                    kernel["tile_k"] = int(m.group(5))
            # dtype is mandatory; entries without it are dropped.
            if kernel.get("dtype"):
                kernel["kernel_set"] = kernel_set_name
                kernels.append(kernel)
    # Expand wildcards to multiple kernels
    expanded = []
    for kernel in kernels:
        expanded.extend(expand_gemm_wildcards(kernel))
    # Apply autocorrect to each expanded kernel
    return [auto_fill_gemm_defaults(k) for k in expanded]
def expand_gemm_wildcards(kernel: Dict, arch: str = "gfx942") -> List[Dict]:
    """Expand GEMM wildcard parameters into every valid configuration.

    Wildcards are -1 (from ANY_INT) for wave/warp-tile values and "*" for
    pipeline/scheduler strings.  Combinations whose implied block size
    would exceed the 1024-thread limit are filtered out:
    (tile_m/warp_tile_m) * (tile_n/warp_tile_n) * 64 <= 1024.
    """
    # Candidate spaces for gfx942 (fp16 warp tiles).
    wave_choices = [(1, 4, 1), (2, 2, 1), (4, 1, 1)]
    warp_choices = [(16, 16, 32), (32, 32, 16)]
    pipeline_choices = ["compv3"]  # compv4 requires special handling
    scheduler_choices = ["intrawave"]
    wild_wave = kernel.get("warp_m") == -1
    wild_warp = kernel.get("warp_tile_m") == -1
    wild_pipe = kernel.get("pipeline") == "*"
    wild_sched = kernel.get("scheduler") == "*"
    if not (wild_wave or wild_warp or wild_pipe or wild_sched):
        return [kernel]
    # Pin every dimension that was explicitly specified.
    if not wild_wave:
        wave_choices = [
            (kernel.get("warp_m", 2), kernel.get("warp_n", 2), kernel.get("warp_k", 1))
        ]
    if not wild_warp:
        warp_choices = [
            (
                kernel.get("warp_tile_m", 32),
                kernel.get("warp_tile_n", 32),
                kernel.get("warp_tile_k", 16),
            )
        ]
    if not wild_pipe:
        pipeline_choices = [kernel.get("pipeline", "compv3")]
    if not wild_sched:
        scheduler_choices = [kernel.get("scheduler", "intrawave")]
    tile_m = kernel.get("tile_m", 128)
    tile_n = kernel.get("tile_n", 128)
    result = []
    for wave in wave_choices:
        for wtm, wtn, wtk in warp_choices:
            # 64 threads per warp; reject combos over the 1024-thread limit.
            if (tile_m // wtm) * (tile_n // wtn) * 64 > 1024:
                continue
            for pipe in pipeline_choices:
                for sched in scheduler_choices:
                    variant = dict(kernel)
                    variant["warp_m"], variant["warp_n"], variant["warp_k"] = wave
                    variant["warp_tile_m"] = wtm
                    variant["warp_tile_n"] = wtn
                    variant["warp_tile_k"] = wtk
                    variant["pipeline"] = pipe
                    variant["scheduler"] = sched
                    result.append(variant)
    if result:
        print(f" [WILDCARD] Expanded 1 declaration -> {len(result)} kernel(s)")
    return result if result else [kernel]
def auto_fill_gemm_defaults(kernel: Dict) -> Dict:
    """Fill absent GEMM parameters and repair invalid combos, in place.

    AUTOFILL: any parameter that is missing, None, or -1 receives a
    default.  AUTOCORRECT: wave shapes invalid on gfx942 become (2, 2, 1);
    interwave scheduling is forced to intrawave for compv3/compv4; warp
    tiles are enlarged when the implied block size would exceed 1024
    threads.  Returns the same dict.
    """
    gemm_defaults = (
        ("tile_m", 128),
        ("tile_n", 128),
        ("tile_k", 64),
        ("warp_m", 2),
        ("warp_n", 2),
        ("warp_k", 1),
        ("warp_tile_m", 32),
        ("warp_tile_n", 32),
        ("warp_tile_k", 16),
        ("pipeline", "compv3"),
        ("scheduler", "intrawave"),
        ("epilogue", "cshuffle"),
        ("pad_m", False),
        ("pad_n", False),
        ("pad_k", False),
        ("layout", "rcr"),
    )
    # AUTOFILL pass: missing / None / -1 means "use the default".
    filled = []
    for name, default in gemm_defaults:
        current = kernel.get(name)
        if current is None or current == -1:
            kernel[name] = default
            filled.append(f"{name}={default}")
    if filled:
        print(f" [AUTOFILL] {', '.join(filled)}")
    # AUTOCORRECT 1: wave shape must be one of the gfx942-valid combos;
    # (2,2,1) is the balanced fallback.
    wave = (
        kernel.get("warp_m", 2),
        kernel.get("warp_n", 2),
        kernel.get("warp_k", 1),
    )
    if wave not in [(1, 4, 1), (2, 2, 1), (4, 1, 1)]:
        kernel["warp_m"], kernel["warp_n"], kernel["warp_k"] = 2, 2, 1
        print(f" [AUTOCORRECT] wave{wave} -> wave(2,2,1) (invalid for gfx942)")
    # AUTOCORRECT 2: compv3/compv4 cannot run with the interwave scheduler.
    combo = (
        kernel.get("pipeline", "compv3"),
        kernel.get("scheduler", "intrawave"),
    )
    if combo in [("compv3", "interwave"), ("compv4", "interwave")]:
        kernel["scheduler"] = "intrawave"
        print(
            f" [AUTOCORRECT] {combo[0]}/{combo[1]} -> {combo[0]}/intrawave (invalid combo)"
        )
    # AUTOCORRECT 3: keep block size within 1024 threads.
    # block_size = (tile_m / warp_tile_m) * (tile_n / warp_tile_n) * 64
    tile_m = kernel.get("tile_m", 128)
    tile_n = kernel.get("tile_n", 128)
    wtm = kernel.get("warp_tile_m", 32)
    wtn = kernel.get("warp_tile_n", 32)
    if (tile_m // wtm) * (tile_n // wtn) * 64 > 1024:
        old_warp = (wtm, wtn, kernel.get("warp_tile_k", 16))
        # First attempt: bump to 64-wide warp tiles on very large tiles.
        if tile_m >= 256:
            kernel["warp_tile_m"] = 64
        if tile_n >= 256:
            kernel["warp_tile_n"] = 64
        block_size = (
            (tile_m // kernel["warp_tile_m"])
            * (tile_n // kernel["warp_tile_n"])
            * 64
        )
        if block_size <= 1024:
            new_warp = (
                kernel["warp_tile_m"],
                kernel["warp_tile_n"],
                kernel["warp_tile_k"],
            )
            print(
                f" [AUTOCORRECT] warp{old_warp} -> warp{new_warp} (block_size={block_size})"
            )
        else:
            # Fallback: a 4x4 warp grid (16 warps * 64 = 1024) always fits.
            kernel["warp_tile_m"] = tile_m // 4
            kernel["warp_tile_n"] = tile_n // 4
            new_warp = (
                kernel["warp_tile_m"],
                kernel["warp_tile_n"],
                kernel["warp_tile_k"],
            )
            print(
                f" [AUTOCORRECT] warp{old_warp} -> warp{new_warp} (block_size adjusted)"
            )
    return kernel
def strip_cpp_strings_and_comments(content: str) -> str:
    """Strip C++ string literals and comments that could cause false positives.

    Only strips:
    - Comments (// and /* */) - always stripped
    - Raw string literals (R"...") - always stripped (can contain anything)
    - Regular strings ONLY if they contain problematic patterns like DECL_KERNEL_SET
    Preserves normal string literals like "fp16", "rcr" which are needed for parsing.
    Stripped spans are replaced character-for-character with spaces (newlines
    kept), so offsets and line numbers in the result match the input.
    """
    result = []
    i = 0
    n = len(content)
    # Patterns that indicate a string is problematic and should be stripped
    problematic_patterns = ["DECL_KERNEL_SET", "DECL_CONV_KERNEL_SET", ".add("]
    while i < n:
        # Check for raw string literal: R"delimiter(...)delimiter"
        # Always strip these as they can contain arbitrary content
        if i < n - 1 and content[i] == "R" and content[i + 1] == '"':
            # Find the delimiter (between R" and ()
            j = i + 2
            delimiter_start = j
            while j < n and content[j] != "(":
                j += 1
            delimiter = content[delimiter_start:j]
            # Find the closing )delimiter"
            end_marker = ")" + delimiter + '"'
            end_pos = content.find(end_marker, j + 1)
            if end_pos != -1:
                # Replace with spaces to preserve line numbers
                span = content[i : end_pos + len(end_marker)]
                result.append("".join("\n" if c == "\n" else " " for c in span))
                i = end_pos + len(end_marker)
                continue
            # No closing marker: fall through and treat as ordinary text.
        # Check for regular string literal - only strip if it contains problematic patterns
        if content[i] == '"':
            j = i + 1
            while j < n:
                if content[j] == "\\" and j + 1 < n:
                    j += 2  # Skip escaped character
                elif content[j] == '"':
                    j += 1
                    break
                else:
                    j += 1
            string_content = content[i:j]
            # Only strip if this string contains problematic patterns
            should_strip = any(pat in string_content for pat in problematic_patterns)
            if should_strip:
                result.append(" " * len(string_content))
            else:
                result.append(string_content)
            i = j
            continue
        # Check for single-line comment - always strip
        if i < n - 1 and content[i : i + 2] == "//":
            j = i
            while j < n and content[j] != "\n":
                j += 1
            # The trailing newline itself is kept (j stops on it).
            result.append(" " * (j - i))
            i = j
            continue
        # Check for multi-line comment - always strip
        if i < n - 1 and content[i : i + 2] == "/*":
            end_pos = content.find("*/", i + 2)
            if end_pos != -1:
                span = content[i : end_pos + 2]
                # Preserve newlines in multi-line comments
                result.append("".join("\n" if c == "\n" else " " for c in span))
                i = end_pos + 2
                continue
            # Unterminated comment: fall through and copy characters verbatim.
        result.append(content[i])
        i += 1
    return "".join(result)
def detect_and_parse(source_path: Path) -> Tuple[str, List[Dict]]:
    """Classify an example source file and parse its kernel declarations.

    String literals and comments are blanked out first so declarations
    appearing only inside strings or commented-out code are ignored.
    Returns ("conv"|"gemm"|"unknown", parsed kernel dicts).
    """
    text = strip_cpp_strings_and_comments(source_path.read_text())
    # Conv check first: its macro name is distinct from DECL_KERNEL_SET.
    if "DECL_CONV_KERNEL_SET" in text:
        return "conv", parse_conv_declarations(text)
    if "DECL_KERNEL_SET" in text:
        return "gemm", parse_gemm_declarations(text)
    return "unknown", []
def generate_gemm_registration(
    kernel_headers: List[Path], example_name: str, kernels: List[Dict] = None
) -> str:
    """Generate GEMM kernel registration code for the dispatcher registry.

    Uses GeneratedKernelInstance<SelectedKernel> to wrap the generated kernels
    and provide the KernelInstance interface for the Dispatcher.
    If kernels list is provided with kernel_set info, generates separate
    registration functions per kernel set.

    Side effect: when multiple kernel sets are present, the per-set grouping
    is stashed in the module-global _kernels_by_set_cache for a later call
    to generate_per_set_functions().
    """
    if not kernel_headers:
        return " // No kernels to register"
    # Build mapping from kernel config pattern to kernel set.  The pattern
    # is the tile/wave/warp-tile triple as it appears in generated filenames.
    kernel_to_set = {}
    kernel_sets = set()
    if kernels:
        for k in kernels:
            tile_m = k.get("tile_m", 128)
            tile_n = k.get("tile_n", 128)
            tile_k = k.get("tile_k", 64)
            warp_m = k.get("warp_m", 2)
            warp_n = k.get("warp_n", 2)
            warp_k = k.get("warp_k", 1)
            warp_tile_m = k.get("warp_tile_m", 32)
            warp_tile_n = k.get("warp_tile_n", 32)
            warp_tile_k = k.get("warp_tile_k", 16)
            # Pattern that appears in kernel filename
            key_pattern = f"{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
            kernel_set = k.get("kernel_set", "default")
            kernel_to_set[key_pattern] = kernel_set
            kernel_sets.add(kernel_set)

    def generate_registration_block(h: Path) -> str:
        """Generate registration code for a single kernel.

        All signature/algorithm fields are inferred from tokens in the
        generated header's filename; unrecognized tokens leave the
        defaults (fp16, rcr-equivalent layouts, compv3/intrawave/cshuffle).
        """
        kernel_name = h.stem
        ns = f"ns_{kernel_name}"
        # Parse pipeline, scheduler, and layout from kernel name
        # Format: gemm_fp16_rcr_compv3_cshuffle_intrawave_...
        parts = kernel_name.split("_")
        pipeline = "CompV3"
        scheduler = "Intrawave"
        epilogue = "CShuffle"
        datatype = "FP16"
        layout_a = "RowMajor"
        layout_b = "ColMajor"
        layout_c = "RowMajor"
        # Parse datatype (e.g., fp16, bf16, fp32)
        dtype_map = {
            "fp16": "FP16",
            "bf16": "BF16",
            "fp32": "FP32",
            "fp64": "FP64",
            "int8": "INT8",
        }
        # Parse layout from 3-char codes (e.g., rcr, rrr, rrc, ccc)
        # r = RowMajor, c = ColMajor
        layout_map = {"r": "RowMajor", "c": "ColMajor"}
        # Find pipeline, epilogue, scheduler in the name parts
        pipeline_map = {
            "mem": "Mem",
            "compv1": "CompV1",
            "compv2": "CompV2",
            "compv3": "CompV3",
            "compv4": "CompV4",
            "compv5": "CompV5",
            "preshufflev1": "PreShuffleV1",
            "preshufflev2": "PreShuffleV2",
        }
        scheduler_map = {
            "intrawave": "Intrawave",
            "interwave": "Interwave",
            "auto": "Auto",
        }
        epilogue_map = {"default": "Default", "cshuffle": "CShuffle", "none": "None"}
        for part in parts:
            if part in pipeline_map:
                pipeline = pipeline_map[part]
            if part in scheduler_map:
                scheduler = scheduler_map[part]
            if part in epilogue_map:
                epilogue = epilogue_map[part]
            if part in dtype_map:
                datatype = dtype_map[part]
            # Parse 3-char layout codes (e.g., rcr, rrr)
            if len(part) == 3 and all(c in "rc" for c in part):
                layout_a = layout_map[part[0]]
                layout_b = layout_map[part[1]]
                layout_c = layout_map[part[2]]
        # Emit one C++ scope that builds a KernelKey and registers the
        # wrapped kernel instance.
        block = []
        block.append(f" // Register kernel: {kernel_name}")
        block.append(" {")
        block.append(f" using SelectedKernel = {ns}::SelectedKernel;")
        block.append(" ck_tile::dispatcher::KernelKey key;")
        block.append(
            f" key.signature.dtype_a = ck_tile::dispatcher::DataType::{datatype};"
        )
        block.append(
            f" key.signature.dtype_b = ck_tile::dispatcher::DataType::{datatype};"
        )
        block.append(
            f" key.signature.dtype_c = ck_tile::dispatcher::DataType::{datatype};"
        )
        block.append(
            " key.signature.dtype_acc = ck_tile::dispatcher::DataType::FP32;"
        )
        block.append(
            f" key.signature.layout_a = ck_tile::dispatcher::LayoutTag::{layout_a};"
        )
        block.append(
            f" key.signature.layout_b = ck_tile::dispatcher::LayoutTag::{layout_b};"
        )
        block.append(
            f" key.signature.layout_c = ck_tile::dispatcher::LayoutTag::{layout_c};"
        )
        block.append(" key.algorithm.tile_shape.m = SelectedKernel::TileM;")
        block.append(" key.algorithm.tile_shape.n = SelectedKernel::TileN;")
        block.append(" key.algorithm.tile_shape.k = SelectedKernel::TileK;")
        block.append(
            " key.algorithm.wave_shape.m = SelectedKernel::WarpPerBlock_M;"
        )
        block.append(
            " key.algorithm.wave_shape.n = SelectedKernel::WarpPerBlock_N;"
        )
        block.append(
            " key.algorithm.wave_shape.k = SelectedKernel::WarpPerBlock_K;"
        )
        block.append(
            " key.algorithm.warp_tile_shape.m = SelectedKernel::WarpTileM;"
        )
        block.append(
            " key.algorithm.warp_tile_shape.n = SelectedKernel::WarpTileN;"
        )
        block.append(
            " key.algorithm.warp_tile_shape.k = SelectedKernel::WarpTileK;"
        )
        block.append(
            " key.algorithm.block_size = SelectedKernel::BlockSize;"
        )
        block.append(
            f" key.algorithm.pipeline = ck_tile::dispatcher::Pipeline::{pipeline};"
        )
        block.append(
            f" key.algorithm.scheduler = ck_tile::dispatcher::Scheduler::{scheduler};"
        )
        block.append(
            f" key.algorithm.epilogue = ck_tile::dispatcher::Epilogue::{epilogue};"
        )
        block.append(" key.gfx_arch = arch;")
        block.append(
            f' auto instance = std::make_shared<ck_tile::dispatcher::backends::GeneratedKernelInstance<SelectedKernel>>(key, "{kernel_name}");'
        )
        block.append(" registry.register_kernel(instance);")
        block.append(" }")
        return "\n".join(block)

    def find_kernel_set(header: Path) -> str:
        """Find which kernel set a header belongs to."""
        name = header.stem
        for pattern, kset in kernel_to_set.items():
            if pattern in name:
                return kset
        return "default"

    # Group kernels by set
    kernels_by_set = {}
    for h in kernel_headers:
        kset = find_kernel_set(h)
        if kset not in kernels_by_set:
            kernels_by_set[kset] = []
        kernels_by_set[kset].append(h)
    # If only one set or no set info, use simple registration
    if len(kernels_by_set) <= 1:
        lines = [" (void)arch;", ""]
        for h in kernel_headers:
            lines.append(generate_registration_block(h))
        return "\n".join(lines)
    # Multiple sets - generate registration for all, plus store per-set info
    lines = [" // Register ALL kernels from all sets", " (void)arch;", ""]
    for h in kernel_headers:
        lines.append(generate_registration_block(h))
    # Store per-set mapping for separate function generation
    # NOTE(review): module-global hand-off consumed by generate_per_set_functions();
    # not safe if the two calls are interleaved across examples — confirm callers.
    global _kernels_by_set_cache
    _kernels_by_set_cache = (kernels_by_set, generate_registration_block)
    return "\n".join(lines)
# Global cache for per-set kernel info
_kernels_by_set_cache = None
def generate_per_set_functions(source_stem: str) -> str:
    """Emit per-kernel-set C++ registration helpers from the cached grouping.

    Consumes the module-global _kernels_by_set_cache populated by
    generate_gemm_registration().  Generates:
    1. Per-set functions: register_<set_name>(registry, arch)
    2. String-based dispatcher: register_kernel_set("set_name", registry, arch)
    3. get_kernel_set_names() to list available sets
    Returns "" when no cache is available.
    """
    global _kernels_by_set_cache
    if not _kernels_by_set_cache:
        return ""
    by_set, emit_block = _kernels_by_set_cache
    _kernels_by_set_cache = None  # one-shot hand-off: consume the cache
    out = []
    names = []
    # One registration function per kernel set.
    for raw_name, headers in by_set.items():
        cpp_name = raw_name.replace("-", "_")
        names.append((raw_name, cpp_name))
        out.append(
            f"inline void register_{cpp_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{"
        )
        out.append(" (void)arch;")
        out.extend(emit_block(h) for h in headers)
        out.append("}")
        out.append("")
    if len(names) > 0:
        # String-keyed dispatcher over the per-set functions.
        out.append("// Dynamic registration by kernel set name")
        out.append(
            "inline bool register_kernel_set(const std::string& set_name, ck_tile::dispatcher::Registry& registry, const std::string& arch) {"
        )
        for raw_name, cpp_name in names:
            out.append(
                f' if (set_name == "{raw_name}") {{ register_{cpp_name}(registry, arch); return true; }}'
            )
        out.append(" return false; // Unknown set name")
        out.append("}")
        out.append("")
        # Enumeration helper so callers can discover the available sets.
        out.append("// Get list of available kernel set names")
        out.append("inline std::vector<std::string> get_kernel_set_names() {")
        joined = ", ".join(f'"{raw_name}"' for raw_name, _ in names)
        out.append(f" return {{{joined}}};")
        out.append("}")
        out.append("")
    return "\n".join(out)
def generate_conv_registration(
    kernel_headers: List[Path], example_name: str, kernels: List[Dict]
) -> str:
    """Generate Conv kernel registration code for the dispatcher registry.

    Conv kernels are invoked through direct launchers rather than the
    registry, so the body only silences the unused parameters and lists
    the generated kernels as comments.
    """
    if not kernel_headers:
        return " // No kernels to register"
    out = [
        " (void)registry; (void)arch; // Conv uses direct launcher pattern for now"
    ]
    out.extend(
        f" // Kernel {number}: {header.stem}"
        for number, header in enumerate(kernel_headers, start=1)
    )
    return "\n".join(out)
def generate_conv_kernels(
    kernels: List[Dict], output_dir: Path, codegen_dir: Path
) -> bool:
    """Generate Conv kernels for every declaration via unified codegen.

    Invokes unified_conv_codegen.py once per kernel dict, forwarding each
    present parameter as a CLI flag.  Returns True when at least one
    invocation succeeds, False for an empty list or all-failed runs.
    """
    if not kernels:
        return False
    conv_variants = {
        "forward": "forward",
        "bwd_data": "bwd_data",
        "backward_data": "bwd_data",
        "bwd_weight": "bwd_weight",
        "backward_weight": "bwd_weight",
    }
    # (dict key, CLI flag) pairs, in the exact order the flags are emitted.
    optional_flags = (
        ("tile_m", "--tile-m"),
        ("tile_n", "--tile-n"),
        ("warp_m", "--warp-m"),
        ("warp_n", "--warp-n"),
        ("warp_k", "--warp-k"),
        ("warp_tile_m", "--warp-tile-m"),
        ("warp_tile_n", "--warp-tile-n"),
        ("warp_tile_k", "--warp-tile-k"),
        ("pipeline", "--pipeline"),
        ("scheduler", "--scheduler"),
        ("epilogue", "--epilogue"),
        ("vector_a", "--vector-a"),
        ("vector_b", "--vector-b"),
        ("vector_c", "--vector-c"),
        ("block_per_cu", "--block-per-cu"),
        ("num_wave_groups", "--num-wave-groups"),
        ("num_groups_to_merge", "--num-groups-to-merge"),
    )
    generated = 0
    for idx, k in enumerate(kernels):
        cmd = [
            sys.executable,
            str(codegen_dir / "unified_conv_codegen.py"),
            "--datatype",
            k.get("dtype", "fp16"),
            "--variant",
            conv_variants.get(k.get("conv_type", "forward"), "forward"),
            "--ndim",
            str(k.get("ndim", 2)),
            "--output",
            str(output_dir),
        ]
        for key, flag in optional_flags:
            if k.get(key):  # falsy values (0, "") are deliberately skipped
                cmd += [flag, str(k[key])]
        # double_smem_buffer is tri-state: only forwarded when explicitly set.
        if k.get("double_smem_buffer") is not None:
            cmd += ["--double-smem-buffer", str(k["double_smem_buffer"]).lower()]
        if k.get("tile_k"):
            cmd += ["--tile-k", str(k["tile_k"])]
        result = subprocess.run(
            cmd, capture_output=True, text=True, cwd=str(codegen_dir)
        )
        if result.returncode != 0:
            print(f" Codegen error for kernel {idx + 1}: {result.stderr[:300]}")
        else:
            generated += 1
    return generated > 0
def generate_gemm_kernels(
    kernels: List[Dict], output_dir: Path, codegen_dir: Path
) -> bool:
    """Invoke the unified GEMM codegen once per kernel declaration.

    Each declaration dict is turned into a tile/trait configuration JSON and
    handed to ``unified_gemm_codegen.py`` as a subprocess.  Failures for a
    single declaration are reported but do not abort the rest.

    Returns True when at least one declaration was generated successfully,
    False for an empty declaration list or when every invocation failed.
    """
    import json

    if not kernels:
        return False

    generated = 0
    for position, decl in enumerate(kernels, start=1):
        # An elementwise op on the declaration selects the multi-D variant.
        variant = "multi_d" if decl.get("elementwise_op") else "standard"

        # Single-element lists: the codegen script expects sweep-style
        # configs, so each scalar is wrapped in a one-entry list.
        payload = json.dumps(
            {
                "tile_config": {
                    "tile_m": [decl.get("tile_m", 128)],
                    "tile_n": [decl.get("tile_n", 128)],
                    "tile_k": [decl.get("tile_k", 32)],
                    "warp_m": [decl.get("warp_m", 2)],
                    "warp_n": [decl.get("warp_n", 2)],
                    "warp_k": [decl.get("warp_k", 1)],
                    "warp_tile_m": [decl.get("warp_tile_m", 32)],
                    "warp_tile_n": [decl.get("warp_tile_n", 32)],
                    "warp_tile_k": [decl.get("warp_tile_k", 16)],
                },
                "trait_config": {
                    "pipeline": [decl.get("pipeline", "compv3")],
                    "epilogue": [decl.get("epilogue", "cshuffle")],
                    "scheduler": [decl.get("scheduler", "intrawave")],
                    "pad_m": [decl.get("pad_m", False)],
                    "pad_n": [decl.get("pad_n", False)],
                    "pad_k": [decl.get("pad_k", False)],
                    "persistent": [False],
                },
            }
        )

        proc = subprocess.run(
            [
                sys.executable,
                str(codegen_dir / "unified_gemm_codegen.py"),
                "--datatype",
                decl.get("dtype", "fp16"),
                "--layout",
                decl.get("layout", "rcr"),
                "--variants",
                variant,
                "--output",
                str(output_dir),
                "--tile-config-json",
                payload,
            ],
            capture_output=True,
            text=True,
            cwd=str(codegen_dir),
        )
        if proc.returncode == 0:
            generated += 1
        else:
            print(f"  Codegen error for kernel {position}: {proc.stderr[:300]}")

    return generated > 0
def compile_kernel(args: Tuple) -> Tuple[str, bool, str]:
    """Compile one generated kernel header into a position-independent object.

    ``args`` is a packed tuple (so the function is picklable for a process
    pool): (kernel_hpp, output_dir, include_dirs, hipcc, gpu_target, idx,
    total).  Returns (kernel_name, ok, detail) where ``detail`` is the object
    path on success or a truncated stderr excerpt on failure.
    """
    kernel_hpp, output_dir, include_dirs, hipcc, gpu_target, idx, total = args
    name = kernel_hpp.stem

    # Emit a tiny translation unit that pulls in the header; the anonymous
    # volatile bool keeps the TU non-empty so an object file is always produced.
    source = output_dir / f"{name}.cpp"
    source.write_text(
        f'#include "{kernel_hpp.name}"\nnamespace {{ volatile bool _k{idx} = true; }}\n'
    )
    target = output_dir / f"{name}.o"

    command = [
        hipcc,
        "-c",
        "-fPIC",
        "-std=c++17",
        "-O3",
        f"--offload-arch={gpu_target}",
        "-mllvm",
        "-enable-noalias-to-md-conversion=0",
        "-Wno-undefined-func-template",
        "-Wno-float-equal",
        "--offload-compress",
    ]
    # Caller-provided include paths first, then the header's own directory.
    for search_path in [*include_dirs, kernel_hpp.parent]:
        command += ["-I", str(search_path)]
    command += ["-o", str(target), str(source)]

    proc = subprocess.run(command, capture_output=True, text=True)
    if proc.returncode == 0:
        return (name, True, str(target))
    return (name, False, proc.stderr[:500])
def main() -> int:
    """CLI entry point: generate, compile, and package kernels for one example.

    Pipeline: parse the example C++ source for kernel declarations, run the
    matching codegen script (conv or gemm), compile each generated header in
    parallel with hipcc, archive the objects into a static library, and write
    a registration header for the dispatcher.

    Returns a process exit code: 0 on success (including the "no kernels"
    case, which yields an empty library), 1 on any generation/compile failure.
    """
    parser = argparse.ArgumentParser(description="Build example kernels")
    parser.add_argument("source", type=Path, help="C++ source file")
    parser.add_argument("--output-dir", type=Path, required=True)
    # Comma-separated list of include paths forwarded to hipcc via -I.
    parser.add_argument("--include-dirs", type=str, required=True)
    parser.add_argument("--gpu-target", type=str, default="gfx942")
    parser.add_argument("--jobs", type=int, default=os.cpu_count())
    parser.add_argument(
        "--target-name", type=str, help="CMake target name (for library naming)"
    )
    args = parser.parse_args()
    # Codegen scripts live in a sibling "codegen" directory of this script.
    script_dir = Path(__file__).parent
    codegen_dir = script_dir.parent / "codegen"
    source_stem = args.source.stem  # e.g., "01_basic_gemm"
    target_name = args.target_name or source_stem  # e.g., "gemm_01_basic" from CMake
    args.output_dir.mkdir(parents=True, exist_ok=True)
    # Detect example type and extract kernel declarations from the source.
    # detect_and_parse is defined elsewhere in this file.
    example_type, kernels = detect_and_parse(args.source)
    if example_type == "conv":
        # Summary line uses the first declaration as representative.
        k = kernels[0] if kernels else {}
        variant = k.get("conv_type", "forward")
        print(
            f"[{target_name}] Conv {k.get('dtype', 'fp16')} {variant} {k.get('ndim', 2)}D ({len(kernels)} declarations)"
        )
    elif example_type == "gemm":
        k = kernels[0] if kernels else {}
        print(
            f"[{target_name}] GEMM {k.get('dtype', 'fp16')} {k.get('layout', 'rcr')} ({len(kernels)} declarations)"
        )
    else:
        # No declarations found: still emit an (empty) archive and stub header
        # so the CMake link step and #include in user code keep working.
        print(f"[{target_name}] No kernel declarations - creating empty library")
        lib_path = args.output_dir / f"lib{target_name}_kernels.a"
        subprocess.run([find_ar(), "rcs", str(lib_path)], check=True)
        header = args.output_dir / f"{source_stem}_kernels.hpp"
        header.write_text(f"// No kernels for {target_name}\n#pragma once\n")
        return 0
    # Generate kernels (headers are written into args.output_dir).
    print(f"[{target_name}] Generating kernels...")
    if example_type == "conv":
        success = generate_conv_kernels(kernels, args.output_dir, codegen_dir)
    else:
        success = generate_gemm_kernels(kernels, args.output_dir, codegen_dir)
    if not success:
        print(f"[{target_name}] Kernel generation failed!")
        return 1
    # Find generated headers by the naming prefix each codegen variant uses.
    if example_type == "gemm":
        kernel_headers = list(args.output_dir.glob("gemm_*.hpp"))
    else:
        # Conv: prefix depends on the conv direction of the first declaration.
        k = kernels[0] if kernels else {}
        variant = k.get("conv_type", "forward")
        prefix_map = {
            "forward": "conv_fwd",
            "bwd_data": "conv_bwdd",
            "bwd_weight": "conv_bwdw",
        }
        prefix = prefix_map.get(variant, "conv_fwd")
        kernel_headers = list(args.output_dir.glob(f"{prefix}_*.hpp"))
    if not kernel_headers:
        print(f"[{target_name}] No kernel headers generated!")
        return 1
    print(f"[{target_name}] Compiling {len(kernel_headers)} kernels...")
    include_dirs = [Path(p.strip()) for p in args.include_dirs.split(",")]
    hipcc = find_hipcc()
    # One packed work tuple per header; matches compile_kernel's unpacking.
    work = [
        (
            h,
            args.output_dir,
            include_dirs,
            hipcc,
            args.gpu_target,
            i + 1,
            len(kernel_headers),
        )
        for i, h in enumerate(kernel_headers)
    ]
    obj_files = []
    failed = []
    # Compile in parallel across processes (hipcc invocations are CPU-heavy).
    with ProcessPoolExecutor(max_workers=args.jobs) as executor:
        futures = {executor.submit(compile_kernel, w): w[0].name for w in work}
        for future in as_completed(futures):
            name, ok, result = future.result()
            if ok:
                obj_files.append(result)
            else:
                failed.append((name, result))
                print(f"[{target_name}] FAILED: {name}")
    if failed:
        # Report at most three failures (truncated) to keep the log readable.
        print(f"[{target_name}] {len(failed)} kernels failed")
        for name, err in failed[:3]:
            print(f"  {name}: {err[:200]}")
        return 1
    # Create static library (use target_name for CMake compatibility)
    lib_path = args.output_dir / f"lib{target_name}_kernels.a"
    subprocess.run([find_ar(), "rcs", str(lib_path)] + obj_files, check=True)
    # Generate registration header (use source_stem for header name to match CMake's EXAMPLE_STEM)
    header_path = args.output_dir / f"{source_stem}_kernels.hpp"
    # Build the #include list for every generated kernel header.
    includes = "\n".join(f'#include "{h.name}"' for h in kernel_headers)
    # Build kernel registration entries.
    # Function name uses source_stem (e.g., register_01_basic_gemm_kernels)
    func_name = f"register_{source_stem}_kernels"
    # Generate registration code based on example type; the generate_*
    # helpers are defined elsewhere in this file.
    if example_type == "gemm":
        register_body = generate_gemm_registration(kernel_headers, target_name, kernels)
    else:
        register_body = generate_conv_registration(kernel_headers, target_name, kernels)
    # Generate appropriate header based on example type
    if example_type == "conv" and kernel_headers:
        launcher_aliases = []
        # Helper to find kernel by dtype and type
        def find_kernel_by_dtype_type(headers, dtype, conv_type_marker):
            """Return the first header matching the dtype and conv-type marker,
            falling back to any header matching the marker alone."""
            matching = [h for h in headers if conv_type_marker in h.stem]
            # Prefer headers matching the requested dtype; fall back to any match.
            fp16_kernels = [h for h in matching if f"_{dtype}_" in h.stem]
            return (
                fp16_kernels[0] if fp16_kernels else (matching[0] if matching else None)
            )
        # Check what conv types are in the declarations
        has_fwd = any("forward" in k.get("conv_type", "forward") for k in kernels)
        has_bwd_data = any("bwd_data" in k.get("conv_type", "") for k in kernels)
        has_bwd_weight = any("bwd_weight" in k.get("conv_type", "") for k in kernels)
        # Export dtype-specific launcher aliases for each available dtype
        for dtype in ["fp16", "bf16", "fp32"]:
            dtype_fwd_kernels = [
                h
                for h in kernel_headers
                if "_fwd_" in h.stem and f"_{dtype}_" in h.stem
            ]
            if dtype_fwd_kernels:
                # First matching header wins for the dtype-specific alias.
                k = dtype_fwd_kernels[0]
                ns = f"ns_{k.stem}"
                dtype_upper = dtype.upper()
                launcher_aliases.append(
                    f"using {dtype_upper}FwdKernelLauncher = {ns}::{k.stem}_Launcher;"
                )
        # Export generic launcher aliases (prioritize fp16)
        if has_fwd:
            fwd_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_fwd_")
            if fwd_kernel:
                fwd_ns = f"ns_{fwd_kernel.stem}"
                launcher_aliases.append(
                    f"using FwdKernelLauncher = {fwd_ns}::{fwd_kernel.stem}_Launcher;"
                )
                launcher_aliases.append(
                    f"using FirstKernelLauncher = {fwd_ns}::{fwd_kernel.stem}_Launcher;"
                )
        if has_bwd_data:
            bwdd_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdd_")
            if bwdd_kernel:
                bwdd_ns = f"ns_{bwdd_kernel.stem}"
                launcher_aliases.append(
                    f"using BwdDataKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;"
                )
                if not has_fwd:  # If no fwd, use bwd_data as first
                    launcher_aliases.append(
                        f"using FirstKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;"
                    )
        if has_bwd_weight:
            bwdw_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdw_")
            if bwdw_kernel:
                bwdw_ns = f"ns_{bwdw_kernel.stem}"
                launcher_aliases.append(
                    f"using BwdWeightKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;"
                )
                if (
                    not has_fwd and not has_bwd_data
                ):  # If no fwd or bwdd, use bwdw as first
                    launcher_aliases.append(
                        f"using FirstKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;"
                    )
        launcher_section = "\n".join(launcher_aliases)
        header_content = f"""// Auto-generated for {target_name}
#pragma once
{includes}
#include "ck_tile/dispatcher/registry.hpp"
#include "ck_tile/dispatcher/kernel_instance.hpp"
#include "ck_tile/dispatcher/kernel_key.hpp"
namespace generated {{
// Kernel launchers for direct use
{launcher_section}
// Registration function
inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{
{register_body}
}}
}} // namespace generated
// Generic registration - avoids hardcoding the example name in user code
// Safe for single-example executables (typical use case)
#ifndef REGISTER_GENERATED_KERNELS
#define REGISTER_GENERATED_KERNELS(registry, arch) generated::{func_name}(registry, arch)
#endif
"""
    else:
        # GEMM: Generate per-set functions if multiple kernel sets declared
        per_set_funcs = generate_per_set_functions(source_stem)
        header_content = f"""// Auto-generated for {target_name}
#pragma once
{includes}
#include "ck_tile/dispatcher/registry.hpp"
#include "ck_tile/dispatcher/kernel_instance.hpp"
#include "ck_tile/dispatcher/kernel_key.hpp"
#include "ck_tile/dispatcher/backends/generated_kernel_backend.hpp"
namespace generated {{
// Register ALL kernels from all declared sets
inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{
{register_body}
}}
{per_set_funcs}
}} // namespace generated
// Generic registration - avoids hardcoding the example name in user code
// Safe for single-example executables (typical use case)
#ifndef REGISTER_GENERATED_KERNELS
#define REGISTER_GENERATED_KERNELS(registry, arch) generated::{func_name}(registry, arch)
#endif
// Register a specific kernel set by name (for multi-registry patterns)
// Usage: REGISTER_KERNEL_SET("compute_bound_set", registry, arch)
#ifndef REGISTER_KERNEL_SET
#define REGISTER_KERNEL_SET(set_name, registry, arch) generated::register_kernel_set(set_name, registry, arch)
#endif
"""
    header_path.write_text(header_content)
    print(f"[{target_name}] ✓ {len(obj_files)} kernels compiled")
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit code.
    raise SystemExit(main())