Files
composable_kernel/dispatcher/codegen/generate_dispatcher_registration.py
Vidyasagar Ananthan 920acd2c12 [rocm-libraries] ROCm/rocm-libraries#5168 (commit 8b5afcb)
[CK] [CK_Tile] Add GroupConv to Kernel Dispatcher

## Motivation

This PR adds CK Tile group convolution (forward, backward-data,
backward-weight) support to the kernel dispatcher, matching and unifying
with the existing dispatcher GEMM infrastructure in architecture and
usability. The dispatcher provides a unified kernel dispatch system with
both C++ and Python frontends, and until now only supported GEMM
operations. This PR enables framework integrators to use the same
declarative kernel workflow for convolutions as they do for GEMM:
declare kernels, build a registry JIT, select kernels within the
registry at runtime, and dispatch to GPU. Future PRs will include
runtime kernel selection heuristics for autotuning of kernel parameters
based on (problem, hardware arch).

## Technical Details

Grouped convolution support has been added to the CK Tile Dispatcher
with generated_conv_backend.hpp enabling dispatcher.run(in, wei, out,
problem) for all 6 conv variants (fwd/bwdd/bwdw x 2D/3D), runtime
heuristic kernel selection, and GroupedConvKernelKey with full
ConvConfigBase fields. Python side adds parallel JIT via
registry.build(max_workers) and heuristic registry.select(). Includes 7
C++ and 6 Python examples covering all directions with CPU reference
validation, and shared infrastructure improvements (BaseRegistry CRTP,
structured exceptions). As a sanity check, the JIT compile time for a
single kernel remains the same, while multiple kernels benefit from
better parallelism:
Kernels | 1 worker | 8 workers
1 | 7.7 s | 7.7 s
2 | 15.9 s | 8.2 s
4 | 33.4 s | 9.7 s
6 | 52.3 s | 10.2 s

## Test Plan

145 ephemeral unit tests have been added to test basic functionality.
All 30 examples/integration tests run end-to-end on gfx950 (MI350): 7
C++ conv, 7 C++ GEMM, 6 Python conv, 10 Python GEMM. CPU reference
validation for forward, backward-data, and backward-weight (2D) in both
C++ and Python examples pass.

## Test Result

30 examples pass. Peak performance: 132 TFLOPS (Batch-32 forward 56x56),
53 TFLOPS (pointwise 1x1). CPU reference accuracy: max_abs_diff < 0.002
for all directions (fp16 vs fp32 reference).

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-09 17:39:35 +00:00

430 lines
14 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Generate dispatcher registration code for CK Tile kernels
This script generates C++ registration code that instantiates TileKernelInstance
templates for each generated kernel, solving the "cannot instantiate from parsed headers" problem.
"""
import json
import argparse
from pathlib import Path
from typing import List
from dataclasses import dataclass
@dataclass
class KernelConfig:
    """Kernel configuration for registration"""

    # Identity: unique kernel name (used to form the C++ alias
    # SelectedKernel_<name>) and the generated header that defines it.
    name: str
    header_file: str
    # Block tile shape: M/N/K elements computed per workgroup tile.
    tile_m: int
    tile_n: int
    tile_k: int
    # Warps per block along M/N/K (parsed from WarpPerBlock_* constants).
    warp_m: int
    warp_n: int
    warp_k: int
    # Per-warp tile shape along M/N/K (parsed from WarpTile* constants).
    warp_tile_m: int
    warp_tile_n: int
    warp_tile_k: int
    # Threads per workgroup (parsed from BlockSize).
    block_size: int
    # Variant tags for pipeline / epilogue / scheduler; string names such
    # as "compv4", "cshuffle", "intrawave" (see scan_generated_headers).
    pipeline: str
    epilogue: str
    scheduler: str
    # Padding flags for problem sizes not divisible by the tile shape
    # (parsed from kPadM/kPadN/kPadK).
    pad_m: bool
    pad_n: bool
    pad_k: bool
    # Feature flags parsed from UsePersistentKernel, DoubleSmemBuffer and
    # TransposeC respectively.
    persistent: bool
    double_buffer: bool
    transpose_c: bool
    # Element dtypes; defaults are the common fp16-in / fp32-accumulate case.
    dtype_a: str = "fp16"
    dtype_b: str = "fp16"
    dtype_c: str = "fp16"
    dtype_acc: str = "fp32"
    # Operand memory layouts ("row" / "col") for A, B and C.
    layout_a: str = "row"
    layout_b: str = "col"
    layout_c: str = "row"
def generate_registration_header(kernels: List["KernelConfig"], output_file: Path):
    """Write the auto-generated C++ header that registers every kernel.

    The header #includes each kernel's generated header and defines two
    ``register_all_kernels`` overloads: one taking an explicit Registry and
    one using the global ``Registry::instance()``.

    Args:
        kernels: kernel configurations; only ``name`` and ``header_file``
            are consumed here.
        output_file: destination path for the generated .hpp file.
    """
    # Collect fragments in a list and join once instead of repeated string
    # concatenation, which is quadratic in the number of kernels.
    parts = ["""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
//
// AUTO-GENERATED FILE - DO NOT EDIT
// Generated by generate_dispatcher_registration.py
#pragma once
#include "ck_tile/dispatcher/registry.hpp"
#include "ck_tile/dispatcher/backends/tile_backend.hpp"
#include "ck_tile/dispatcher/backends/kernel_registration.hpp"
// Include all generated kernel headers
"""]
    # One include line per generated kernel header.
    parts.extend(f'#include "{kernel.header_file}"\n' for kernel in kernels)
    parts.append("""
namespace ck_tile {
namespace dispatcher {
namespace generated {
/// Register all generated kernels with the dispatcher
inline void register_all_kernels(Registry& registry)
{
""")
    for kernel in kernels:
        # The generated/wrapper header is expected to define this alias.
        kernel_type = f"SelectedKernel_{kernel.name}"
        parts.append(f""" // Register {kernel.name}
register_tile_kernel<{kernel_type}>(registry, "{kernel.name}");
""")
    parts.append("""}
/// Register all generated kernels with the global registry
inline void register_all_kernels()
{
auto& registry = Registry::instance();
register_all_kernels(registry);
}
} // namespace generated
} // namespace dispatcher
} // namespace ck_tile
""")
    output_file.write_text("".join(parts))
    print(f"OK Generated registration header: {output_file}")
def generate_registration_cpp(kernels: List["KernelConfig"], output_file: Path):
    """Write the auto-generated .cpp with explicit template instantiations.

    Emits one ``template class backends::TileKernelInstance<...>;`` line per
    kernel so each instantiation is compiled exactly once, reducing build
    time for translation units that include the registration header.

    Args:
        kernels: kernel configurations; only ``name`` is consumed here.
        output_file: destination path for the generated .cpp file.
    """
    # Join once rather than growing a string inside the loop (quadratic).
    parts = ["""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
//
// AUTO-GENERATED FILE - DO NOT EDIT
// Generated by generate_dispatcher_registration.py
#include "dispatcher_registration.hpp"
namespace ck_tile {
namespace dispatcher {
namespace generated {
// Explicit instantiations to reduce compile time
// These ensure the templates are instantiated once
"""]
    parts.extend(
        f"template class backends::TileKernelInstance<SelectedKernel_{kernel.name}>;\n"
        for kernel in kernels
    )
    parts.append("""
} // namespace generated
} // namespace dispatcher
} // namespace ck_tile
""")
    output_file.write_text("".join(parts))
    print(f"OK Generated registration implementation: {output_file}")
def generate_kernel_wrapper_header(kernel: KernelConfig, output_dir: Path):
    """Generate a wrapper header that defines SelectedKernel type"""
    # NOTE(review): the emitted `using` alias ends in a placeholder comment
    # instead of an actual type, so the wrapper header will not compile
    # as-is. This helper is also not invoked from main() in this file —
    # confirm whether it is called elsewhere or is still work-in-progress.
    wrapper_file = output_dir / f"{kernel.name}_wrapper.hpp"
    content = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
//
// AUTO-GENERATED FILE - DO NOT EDIT
// Generated by generate_dispatcher_registration.py
#pragma once
#include "{kernel.header_file}"
namespace ck_tile {{
namespace dispatcher {{
namespace generated {{
// Type alias for dispatcher registration
// This allows the registration code to reference the kernel type
using SelectedKernel_{kernel.name} = /* Actual kernel type from generated header */;
}} // namespace generated
}} // namespace dispatcher
}} // namespace ck_tile
"""
    wrapper_file.write_text(content)
def load_kernel_manifest(manifest_file: Path) -> List[KernelConfig]:
    """Load kernel configurations from a JSON manifest file.

    Required per-kernel keys: ``name``, ``header_file``, ``tile_m``,
    ``tile_n``, ``tile_k``. Every other key is optional and falls back to
    the defaults below (kept in sync with scan_generated_headers and the
    KernelConfig dataclass defaults).

    Args:
        manifest_file: path to a JSON file with a top-level "kernels" list.

    Returns:
        One KernelConfig per manifest entry; empty list if "kernels" is
        absent.

    Raises:
        KeyError: if a required key is missing from an entry.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    with open(manifest_file, "r") as f:
        data = json.load(f)
    kernels = []
    for kernel_data in data.get("kernels", []):
        kernel = KernelConfig(
            name=kernel_data["name"],
            header_file=kernel_data["header_file"],
            tile_m=kernel_data["tile_m"],
            tile_n=kernel_data["tile_n"],
            tile_k=kernel_data["tile_k"],
            warp_m=kernel_data.get("warp_m", 2),
            warp_n=kernel_data.get("warp_n", 2),
            warp_k=kernel_data.get("warp_k", 1),
            warp_tile_m=kernel_data.get("warp_tile_m", 32),
            warp_tile_n=kernel_data.get("warp_tile_n", 32),
            warp_tile_k=kernel_data.get("warp_tile_k", 16),
            block_size=kernel_data.get("block_size", 256),
            pipeline=kernel_data.get("pipeline", "compv4"),
            epilogue=kernel_data.get("epilogue", "cshuffle"),
            scheduler=kernel_data.get("scheduler", "intrawave"),
            pad_m=kernel_data.get("pad_m", False),
            pad_n=kernel_data.get("pad_n", False),
            pad_k=kernel_data.get("pad_k", False),
            persistent=kernel_data.get("persistent", False),
            double_buffer=kernel_data.get("double_buffer", True),
            transpose_c=kernel_data.get("transpose_c", False),
            dtype_a=kernel_data.get("dtype_a", "fp16"),
            dtype_b=kernel_data.get("dtype_b", "fp16"),
            dtype_c=kernel_data.get("dtype_c", "fp16"),
            dtype_acc=kernel_data.get("dtype_acc", "fp32"),
            # Fix: layout keys in the manifest were previously ignored;
            # pass them through (defaults match the dataclass defaults,
            # so existing manifests behave identically).
            layout_a=kernel_data.get("layout_a", "row"),
            layout_b=kernel_data.get("layout_b", "col"),
            layout_c=kernel_data.get("layout_c", "row"),
        )
        kernels.append(kernel)
    return kernels
def scan_generated_headers(generated_dir: Path) -> List[KernelConfig]:
    """Scan generated kernel headers and extract kernel configurations.

    A header is considered a kernel header when it defines a
    ``constexpr const char* KERNEL_NAME = "..."`` constant; tile/warp/block
    parameters are read from ``constexpr`` integer constants and feature
    flags from ``<Name> = true`` assignments, with defaults applied when a
    constant is absent. Unparseable headers are skipped with a warning.

    Args:
        generated_dir: directory tree to glob for ``**/*.hpp``.

    Returns:
        One KernelConfig per recognized kernel header; ``header_file`` is
        stored relative to the parent of ``generated_dir``.
    """
    import re

    def _const_int(content: str, name: str, default: int) -> int:
        # Matches e.g. "static constexpr ck_tile::index_t TileM = 256".
        m = re.search(
            rf"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+{name}\s*=\s*(\d+)",
            content,
        )
        return int(m.group(1)) if m else default

    def _flag(content: str, name: str) -> bool:
        # True iff the header contains "<name> = true".
        return re.search(rf"{name}\s*=\s*true", content) is not None

    kernels = []
    for header_file in generated_dir.glob("**/*.hpp"):
        try:
            content = header_file.read_text()
            name_match = re.search(
                r'constexpr const char\* KERNEL_NAME\s*=\s*"([^"]+)"', content
            )
            if not name_match:
                # Not a generated kernel header; skip silently.
                continue
            kernel = KernelConfig(
                name=name_match.group(1),
                header_file=str(header_file.relative_to(generated_dir.parent)),
                # Block tile shape.
                tile_m=_const_int(content, "TileM", 256),
                tile_n=_const_int(content, "TileN", 256),
                tile_k=_const_int(content, "TileK", 32),
                # Warps per block.
                warp_m=_const_int(content, "WarpPerBlock_M", 2),
                warp_n=_const_int(content, "WarpPerBlock_N", 2),
                warp_k=_const_int(content, "WarpPerBlock_K", 1),
                # Per-warp tile shape.
                warp_tile_m=_const_int(content, "WarpTileM", 32),
                warp_tile_n=_const_int(content, "WarpTileN", 32),
                warp_tile_k=_const_int(content, "WarpTileK", 16),
                block_size=_const_int(content, "BlockSize", 256),
                # Variant tags are not parsed from headers (yet); use the
                # same fixed defaults as the original implementation.
                pipeline="compv4",
                epilogue="cshuffle",
                scheduler="intrawave",
                # Boolean feature flags.
                pad_m=_flag(content, "kPadM"),
                pad_n=_flag(content, "kPadN"),
                pad_k=_flag(content, "kPadK"),
                persistent=_flag(content, "UsePersistentKernel"),
                double_buffer=_flag(content, "DoubleSmemBuffer"),
                transpose_c=_flag(content, "TransposeC"),
            )
            kernels.append(kernel)
        except Exception as e:
            # Best-effort scan: report the bad header and keep going.
            print(f"Warning: Failed to parse {header_file}: {e}")
            continue
    return kernels
def main():
    """CLI entry point: parse options, collect kernels, emit all outputs.

    Returns 0 on success, 1 when neither --manifest nor --scan was given.
    """
    arg_parser = argparse.ArgumentParser(
        description="Generate dispatcher registration code"
    )
    arg_parser.add_argument(
        "--generated-dir",
        type=str,
        required=True,
        help="Directory containing generated kernel headers",
    )
    arg_parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Output directory for registration code",
    )
    arg_parser.add_argument(
        "--manifest", type=str, help="Optional manifest file with kernel configurations"
    )
    arg_parser.add_argument(
        "--scan",
        action="store_true",
        help="Scan generated headers instead of using manifest",
    )
    opts = arg_parser.parse_args()

    gen_dir = Path(opts.generated_dir)
    out_dir = Path(opts.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Kernel configurations come from an explicit manifest or from scanning
    # the generated headers; exactly one of the two must be requested.
    if opts.manifest:
        print(f"Loading kernels from manifest: {opts.manifest}")
        kernels = load_kernel_manifest(Path(opts.manifest))
    elif opts.scan:
        print(f"Scanning generated headers in: {gen_dir}")
        kernels = scan_generated_headers(gen_dir)
    else:
        print("Error: Must specify either --manifest or --scan")
        return 1
    print(f"Found {len(kernels)} kernels")

    # Emit the C++ registration header and implementation files.
    header_path = out_dir / "dispatcher_registration.hpp"
    cpp_path = out_dir / "dispatcher_registration.cpp"
    generate_registration_header(kernels, header_path)
    generate_registration_cpp(kernels, cpp_path)

    # Emit a reduced JSON manifest consumed by the Python side.
    manifest_path = out_dir / "kernels_manifest.json"
    entries = []
    for k in kernels:
        entries.append(
            {
                "name": k.name,
                "header_file": k.header_file,
                "tile_m": k.tile_m,
                "tile_n": k.tile_n,
                "tile_k": k.tile_k,
                "block_size": k.block_size,
                "persistent": k.persistent,
            }
        )
    with open(manifest_path, "w") as f:
        json.dump({"kernels": entries}, f, indent=2)
    print(f"OK Generated manifest: {manifest_path}")

    print("\nOK Registration code generation complete!")
    print(f" Total kernels: {len(kernels)}")
    print(" Output files:")
    print(f" - {header_path}")
    print(f" - {cpp_path}")
    print(f" - {manifest_path}")
    return 0
if __name__ == "__main__":
    # Use SystemExit directly: the bare `exit` builtin is injected by the
    # `site` module and is absent under `python -S` or in frozen apps.
    raise SystemExit(main())