mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
[CK] [CK_Tile] Add GroupConv to Kernel Dispatcher ## Motivation This PR adds CK Tile group convolution (forward, backward-data, backward-weight) support to the kernel dispatcher, matching and unifying with the existing dispatcher GEMM infrastructure in architecture and usability. The dispatcher provides a unified kernel dispatch system with both C++ and Python frontends, and until now only supported GEMM operations. This PR enables framework integrators to use the same declarative kernel workflow for convolutions as they do for GEMM: declare kernels, build a registry JIT, select kernels within the registry at runtime, and dispatch to GPU. Future PRs will include runtime kernel selection heuristics for autotuning of kernel parameters based on (problem, hardware arch). ## Technical Details Grouped convolution support has been added to the CK Tile Dispatcher with generated_conv_backend.hpp enabling dispatcher.run(in, wei, out, problem) for all 6 conv variants (fwd/bwdd/bwdw x 2D/3D), runtime heuristic kernel selection, and GroupedConvKernelKey with full ConvConfigBase fields. Python side adds parallel JIT via registry.build(max_workers) and heuristic registry.select(). Includes 7 C++ and 6 Python examples covering all directions with CPU reference validation, and shared infrastructure improvements (BaseRegistry CRTP, structured exceptions). As a sanity check, JIT compile times for a single kernel remains the same and for multiple kernels there is better parallelism: Kernels | 1 worker | 8 workers 1 | 7.7 s | 7.7 s 2 | 15.9 s | 8.2 s 4 | 33.4 s | 9.7 s 6 | 52.3 s | 10.2 s ## Test Plan 145 ephemeral unit tests have been added to test basic functionality. All 30 examples/integration tests run end-to-end on gfx950 (MI350): 7 C++ conv, 7 C++ GEMM, 6 Python conv, 10 Python GEMM. CPU reference validation for forward, backward-data, and backward-weight (2D) in both C++ and Python examples pass. ## Test Result 30 examples pass. 
Peak performance: 132 TFLOPS (Batch-32 forward 56x56), 53 TFLOPS (pointwise 1x1). CPU reference accuracy: max_abs_diff < 0.002 for all directions (fp16 vs fp32 reference). ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
430 lines
14 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Generate dispatcher registration code for CK Tile kernels
|
|
|
|
This script generates C++ registration code that instantiates TileKernelInstance
|
|
templates for each generated kernel, solving the "cannot instantiate from parsed headers" problem.
|
|
"""
|
|
|
|
import argparse
import json
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import List
|
|
|
|
|
|
@dataclass
class KernelConfig:
    """Kernel configuration for registration.

    Plain record describing one generated CK Tile kernel; consumed by the
    code generators below and (de)serialized via the JSON kernel manifest.
    """

    # Unique kernel name; registration code derives the C++ alias
    # SelectedKernel_<name> from it.
    name: str
    # Include path of the generated kernel header.
    header_file: str
    # Block-level tile shape (M/N/K).
    tile_m: int
    tile_n: int
    tile_k: int
    # Warps per block along M/N/K.
    warp_m: int
    warp_n: int
    warp_k: int
    # Per-warp tile shape (M/N/K).
    warp_tile_m: int
    warp_tile_n: int
    warp_tile_k: int
    # Presumably threads per block (default 256 elsewhere in this script) —
    # confirm against the kernel generator.
    block_size: int
    # Pipeline/epilogue/scheduler variant names; the defaults used elsewhere
    # in this script are "compv4", "cshuffle" and "intrawave".
    pipeline: str
    epilogue: str
    scheduler: str
    # Padding flags for the M/N/K dimensions (mirror kPadM/kPadN/kPadK in
    # the generated headers).
    pad_m: bool
    pad_n: bool
    pad_k: bool
    # Mirrors UsePersistentKernel in the generated header.
    persistent: bool
    # Mirrors DoubleSmemBuffer in the generated header.
    double_buffer: bool
    # Mirrors TransposeC in the generated header.
    transpose_c: bool
    # Data types for the A/B/C operands and the accumulator.
    dtype_a: str = "fp16"
    dtype_b: str = "fp16"
    dtype_c: str = "fp16"
    dtype_acc: str = "fp32"
    # Layout tags for A/B/C ("row"/"col").
    layout_a: str = "row"
    layout_b: str = "col"
    layout_c: str = "row"
|
|
|
|
|
def generate_registration_header(kernels: List[KernelConfig], output_file: Path):
    """Generate the C++ registration header for the given kernels.

    The header includes every generated kernel header and defines two
    ``register_all_kernels`` overloads: one taking an explicit ``Registry``
    and one using the global registry singleton. Each kernel is registered
    via ``register_tile_kernel<SelectedKernel_<name>>``.

    Args:
        kernels: Kernel configurations to register.
        output_file: Destination path for the generated header.
    """
    # Accumulate fragments and join once: repeated `+=` on a str re-copies
    # the whole accumulated text on every iteration.
    parts = [
        """// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
//
// AUTO-GENERATED FILE - DO NOT EDIT
// Generated by generate_dispatcher_registration.py

#pragma once

#include "ck_tile/dispatcher/registry.hpp"
#include "ck_tile/dispatcher/backends/tile_backend.hpp"
#include "ck_tile/dispatcher/backends/kernel_registration.hpp"

// Include all generated kernel headers
"""
    ]

    # One include per generated kernel header.
    parts.extend(f'#include "{kernel.header_file}"\n' for kernel in kernels)

    parts.append(
        """

namespace ck_tile {
namespace dispatcher {
namespace generated {

/// Register all generated kernels with the dispatcher
inline void register_all_kernels(Registry& registry)
{
"""
    )

    # One registration call per kernel; SelectedKernel_<name> is the alias
    # defined by the kernel's wrapper header.
    for kernel in kernels:
        kernel_type = f"SelectedKernel_{kernel.name}"
        parts.append(
            f"""    // Register {kernel.name}
    register_tile_kernel<{kernel_type}>(registry, "{kernel.name}");
"""
        )

    parts.append(
        """}

/// Register all generated kernels with the global registry
inline void register_all_kernels()
{
    auto& registry = Registry::instance();
    register_all_kernels(registry);
}

} // namespace generated
} // namespace dispatcher
} // namespace ck_tile
"""
    )

    output_file.write_text("".join(parts))
    print(f"OK Generated registration header: {output_file}")
|
|
|
|
|
|
def generate_registration_cpp(kernels: List[KernelConfig], output_file: Path):
    """Generate the C++ registration implementation file.

    Emits one explicit ``TileKernelInstance`` instantiation per kernel so
    the templates are compiled once instead of in every translation unit.

    Args:
        kernels: Kernel configurations to instantiate.
        output_file: Destination path for the generated .cpp file.
    """
    # Accumulate fragments and join once instead of quadratic `+=`.
    parts = [
        """// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
//
// AUTO-GENERATED FILE - DO NOT EDIT
// Generated by generate_dispatcher_registration.py

#include "dispatcher_registration.hpp"

namespace ck_tile {
namespace dispatcher {
namespace generated {

// Explicit instantiations to reduce compile time
// These ensure the templates are instantiated once

"""
    ]

    parts.extend(
        f"template class backends::TileKernelInstance<SelectedKernel_{kernel.name}>;\n"
        for kernel in kernels
    )

    parts.append(
        """
} // namespace generated
} // namespace dispatcher
} // namespace ck_tile
"""
    )

    output_file.write_text("".join(parts))
    print(f"OK Generated registration implementation: {output_file}")
|
|
|
|
|
|
def generate_kernel_wrapper_header(kernel: KernelConfig, output_dir: Path):
    """Emit ``<name>_wrapper.hpp`` declaring the SelectedKernel_<name> alias.

    NOTE(review): the alias right-hand side is a placeholder comment, so the
    emitted wrapper is not yet valid C++ — presumably a later generation
    step substitutes the real kernel type; confirm before compiling it.
    """
    target = output_dir.joinpath(f"{kernel.name}_wrapper.hpp")

    body = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
//
// AUTO-GENERATED FILE - DO NOT EDIT
// Generated by generate_dispatcher_registration.py

#pragma once

#include "{kernel.header_file}"

namespace ck_tile {{
namespace dispatcher {{
namespace generated {{

// Type alias for dispatcher registration
// This allows the registration code to reference the kernel type
using SelectedKernel_{kernel.name} = /* Actual kernel type from generated header */;

}} // namespace generated
}} // namespace dispatcher
}} // namespace ck_tile
"""

    target.write_text(body)
|
|
|
|
|
|
def load_kernel_manifest(manifest_file: Path) -> List[KernelConfig]:
    """Load kernel configurations from a JSON manifest file.

    Required per-kernel fields: ``name``, ``header_file``, ``tile_m``,
    ``tile_n``, ``tile_k``. Every other field is optional and falls back to
    the defaults below. Layout keys (``layout_a``/``layout_b``/``layout_c``)
    are now honored when present (previously they were silently ignored).

    Args:
        manifest_file: Path to a manifest of shape ``{"kernels": [...]}``.

    Returns:
        One KernelConfig per manifest entry.

    Raises:
        KeyError: If a required field is missing from an entry.
        json.JSONDecodeError: If the file is not valid JSON.
        OSError: If the file cannot be read.
    """
    data = json.loads(manifest_file.read_text())

    # Optional fields and their defaults; required fields are read directly.
    optional_defaults = {
        "warp_m": 2,
        "warp_n": 2,
        "warp_k": 1,
        "warp_tile_m": 32,
        "warp_tile_n": 32,
        "warp_tile_k": 16,
        "block_size": 256,
        "pipeline": "compv4",
        "epilogue": "cshuffle",
        "scheduler": "intrawave",
        "pad_m": False,
        "pad_n": False,
        "pad_k": False,
        "persistent": False,
        "double_buffer": True,
        "transpose_c": False,
        "dtype_a": "fp16",
        "dtype_b": "fp16",
        "dtype_c": "fp16",
        "dtype_acc": "fp32",
        "layout_a": "row",
        "layout_b": "col",
        "layout_c": "row",
    }

    kernels = []
    for kernel_data in data.get("kernels", []):
        kernels.append(
            KernelConfig(
                name=kernel_data["name"],
                header_file=kernel_data["header_file"],
                tile_m=kernel_data["tile_m"],
                tile_n=kernel_data["tile_n"],
                tile_k=kernel_data["tile_k"],
                **{k: kernel_data.get(k, d) for k, d in optional_defaults.items()},
            )
        )

    return kernels
|
|
|
|
|
|
# Matches the type portion of a generated "constexpr <type> Name = value;" line.
_CONSTEXPR_INT = r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+"


def _extract_constexpr_int(content: str, name: str, default: int) -> int:
    """Return the value of integer constexpr `name` in `content`, or `default`."""
    match = re.search(_CONSTEXPR_INT + name + r"\s*=\s*(\d+)", content)
    return int(match.group(1)) if match else default


def _flag_is_true(content: str, name: str) -> bool:
    """Return True if boolean constant `name` is assigned true in `content`."""
    return re.search(name + r"\s*=\s*true", content) is not None


def scan_generated_headers(generated_dir: Path) -> List[KernelConfig]:
    """Scan generated headers and extract kernel configurations.

    Walks ``generated_dir`` recursively for ``.hpp`` files, keeps those that
    declare a ``KERNEL_NAME`` constant, and parses their tile/warp/block
    constants (falling back to defaults when a constant is absent, also
    accepting ``ck_tile::index_t`` typed constants). Headers that fail to
    parse are reported with a warning and skipped.

    Args:
        generated_dir: Root directory containing generated kernel headers.

    Returns:
        One KernelConfig per recognized kernel header.
    """
    kernels = []

    for header_file in generated_dir.glob("**/*.hpp"):
        try:
            content = header_file.read_text()

            # Headers without a KERNEL_NAME constant are not kernel headers.
            name_match = re.search(
                r'constexpr const char\* KERNEL_NAME\s*=\s*"([^"]+)"', content
            )
            if not name_match:
                continue

            kernel = KernelConfig(
                name=name_match.group(1),
                # Path kept relative to the parent so generated #includes resolve.
                header_file=str(header_file.relative_to(generated_dir.parent)),
                tile_m=_extract_constexpr_int(content, "TileM", 256),
                tile_n=_extract_constexpr_int(content, "TileN", 256),
                tile_k=_extract_constexpr_int(content, "TileK", 32),
                warp_m=_extract_constexpr_int(content, "WarpPerBlock_M", 2),
                warp_n=_extract_constexpr_int(content, "WarpPerBlock_N", 2),
                warp_k=_extract_constexpr_int(content, "WarpPerBlock_K", 1),
                warp_tile_m=_extract_constexpr_int(content, "WarpTileM", 32),
                warp_tile_n=_extract_constexpr_int(content, "WarpTileN", 32),
                warp_tile_k=_extract_constexpr_int(content, "WarpTileK", 16),
                block_size=_extract_constexpr_int(content, "BlockSize", 256),
                # Not parsed from the header; fixed to the same values the
                # manifest loader uses as defaults.
                pipeline="compv4",
                epilogue="cshuffle",
                scheduler="intrawave",
                pad_m=_flag_is_true(content, "kPadM"),
                pad_n=_flag_is_true(content, "kPadN"),
                pad_k=_flag_is_true(content, "kPadK"),
                persistent=_flag_is_true(content, "UsePersistentKernel"),
                double_buffer=_flag_is_true(content, "DoubleSmemBuffer"),
                transpose_c=_flag_is_true(content, "TransposeC"),
            )
            kernels.append(kernel)

        except Exception as e:
            # Best-effort scan: report the bad header and keep going.
            print(f"Warning: Failed to parse {header_file}: {e}")
            continue

    return kernels
|
|
|
|
|
|
def main():
    """Command-line entry point: generate dispatcher registration code.

    Collects kernel configurations from either a JSON manifest or a header
    scan, writes the registration header/implementation pair, and re-emits
    a reduced JSON manifest for the Python tooling. Returns a process exit
    code (0 on success, 1 on bad usage).
    """
    parser = argparse.ArgumentParser(
        description="Generate dispatcher registration code"
    )
    parser.add_argument(
        "--generated-dir",
        type=str,
        required=True,
        help="Directory containing generated kernel headers",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Output directory for registration code",
    )
    parser.add_argument(
        "--manifest", type=str, help="Optional manifest file with kernel configurations"
    )
    parser.add_argument(
        "--scan",
        action="store_true",
        help="Scan generated headers instead of using manifest",
    )
    args = parser.parse_args()

    generated_dir = Path(args.generated_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Exactly one kernel source is used; an explicit manifest wins over --scan.
    if args.manifest:
        print(f"Loading kernels from manifest: {args.manifest}")
        kernel_list = load_kernel_manifest(Path(args.manifest))
    elif args.scan:
        print(f"Scanning generated headers in: {generated_dir}")
        kernel_list = scan_generated_headers(generated_dir)
    else:
        print("Error: Must specify either --manifest or --scan")
        return 1

    print(f"Found {len(kernel_list)} kernels")

    # Emit the C++ registration pair.
    registration_header = output_dir / "dispatcher_registration.hpp"
    registration_cpp = output_dir / "dispatcher_registration.cpp"
    generate_registration_header(kernel_list, registration_header)
    generate_registration_cpp(kernel_list, registration_cpp)

    # Reduced manifest consumed by the Python side of the dispatcher.
    manifest_output = output_dir / "kernels_manifest.json"
    manifest_entries = [
        {
            "name": cfg.name,
            "header_file": cfg.header_file,
            "tile_m": cfg.tile_m,
            "tile_n": cfg.tile_n,
            "tile_k": cfg.tile_k,
            "block_size": cfg.block_size,
            "persistent": cfg.persistent,
        }
        for cfg in kernel_list
    ]
    with open(manifest_output, "w") as fh:
        json.dump({"kernels": manifest_entries}, fh, indent=2)

    print(f"OK Generated manifest: {manifest_output}")
    print("\nOK Registration code generation complete!")
    print(f"   Total kernels: {len(kernel_list)}")
    print("   Output files:")
    print(f"   - {registration_header}")
    print(f"   - {registration_cpp}")
    print(f"   - {manifest_output}")

    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # sys.exit is always available; the built-in exit() is installed by the
    # site module and may be absent (python -S, frozen/embedded interpreters).
    sys.exit(main())
|