Files
composable_kernel/dispatcher/codegen/generate_dispatcher_registration.py
Vidyasagar Ananthan 9e049a32a1 Adding dispatcher architecture (#3300)
* WIP POC of dispatcher

* Dispatcher python workflow setup.

* Dispatcher cleanup and updates.

Further dispatcher cleanup and updates.

Build fixes

Improvements and python to CK example

Improvements to readme

* Fixes to python paths

* Cleaning up code

* Improving dispatcher support for different arch

Fixing typos

* Fix formatting errors

* Cleaning up examples

* Improving code generation

* Improving and fixing C++ examples

* Adding conv functionality (fwd,bwd,bwdw) and examples.

* Fixes based on feedback.

* Further fixes based on feedback.

* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.

* Another round of improvements based on feedback.

* Trimming out unnecessary code.

* Fixing the multi-D implementation.

* Using gpu verification for gemms and fixing convolutions tflops calculation.

* Fix counter usage issue and arch filtering per ops.

* Adding changelog and other fixes.

* Improve examples and resolve critical bugs.

* Reduce build time for python examples.

* Fixing minor bug.

* Fix compilation error.

* Improve installation instructions for dispatcher.

* Add Docker-based installation instructions for dispatcher.

* Fixing arch-based filtering to match tile engine.

* Remove dead code and fix arch filtering.

* Minor bugfix.

* Updates after rebase.

* Trimming code.

* Fix copyright headers.

* Consolidate examples, cut down code.

* Minor fixes.

* Improving python examples.

* Update readmes.

* Remove conv functionality.

* Cleanup following conv removal.
2026-01-22 09:34:33 -08:00

430 lines
14 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Generate dispatcher registration code for CK Tile kernels
This script generates C++ registration code that instantiates TileKernelInstance
templates for each generated kernel, solving the "cannot instantiate from parsed headers" problem.
"""
import json
import argparse
from pathlib import Path
from typing import List
from dataclasses import dataclass
@dataclass
class KernelConfig:
    """Description of one generated kernel, as consumed by the code generators.

    Instances are built either from a JSON manifest entry or by scanning a
    generated header. Field order is part of the constructor signature.
    """

    # Kernel name; used to form ``SelectedKernel_<name>`` and as the
    # registration string in the generated C++.
    name: str
    # Path emitted verbatim inside ``#include "..."``.
    header_file: str

    # Tile sizes (scanner reads TileM / TileN / TileK).
    tile_m: int
    tile_n: int
    tile_k: int

    # Warp layout (scanner reads WarpPerBlock_M / _N / _K).
    warp_m: int
    warp_n: int
    warp_k: int

    # Warp tile sizes (scanner reads WarpTileM / WarpTileN / WarpTileK).
    warp_tile_m: int
    warp_tile_n: int
    warp_tile_k: int

    # Scanner reads BlockSize.
    block_size: int

    # Free-form strategy names, e.g. "compv4", "cshuffle", "intrawave".
    pipeline: str
    epilogue: str
    scheduler: str

    # Boolean traits (scanner reads kPadM / kPadN / kPadK,
    # UsePersistentKernel, DoubleSmemBuffer, TransposeC).
    pad_m: bool
    pad_n: bool
    pad_k: bool
    persistent: bool
    double_buffer: bool
    transpose_c: bool

    # Data types of the A, B, C operands and of the accumulator.
    dtype_a: str = "fp16"
    dtype_b: str = "fp16"
    dtype_c: str = "fp16"
    dtype_acc: str = "fp32"

    # Layouts of the A, B, C operands.
    layout_a: str = "row"
    layout_b: str = "col"
    layout_c: str = "row"
def generate_registration_header(kernels: List[KernelConfig], output_file: Path):
    """Write the C++ header that registers every kernel with the dispatcher.

    The emitted header includes each kernel's ``header_file`` and defines two
    ``register_all_kernels`` overloads in ``ck_tile::dispatcher::generated``:
    one taking a ``Registry&`` and one that forwards ``Registry::instance()``.
    Each kernel is registered via
    ``register_tile_kernel<SelectedKernel_<name>>(registry, "<name>")``.

    Args:
        kernels: Kernel configurations; only ``name`` and ``header_file`` are
            read.
        output_file: Destination path; overwritten if it already exists.

    Note:
        The ``SelectedKernel_<name>`` types are not defined by the generated
        text itself, so they must come from one of the included headers.
    """
    preamble = (
        "// SPDX-License-Identifier: MIT\n"
        "// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.\n"
        "//\n"
        "// AUTO-GENERATED FILE - DO NOT EDIT\n"
        "// Generated by generate_dispatcher_registration.py\n"
        "#pragma once\n"
        '#include "ck_tile/dispatcher/registry.hpp"\n'
        '#include "ck_tile/dispatcher/backends/tile_backend.hpp"\n'
        '#include "ck_tile/dispatcher/backends/kernel_registration.hpp"\n'
        "// Include all generated kernel headers\n"
    )
    open_registration = (
        "\n"
        "namespace ck_tile {\n"
        "namespace dispatcher {\n"
        "namespace generated {\n"
        "/// Register all generated kernels with the dispatcher\n"
        "inline void register_all_kernels(Registry& registry)\n"
        "{\n"
    )
    footer = (
        "}\n"
        "/// Register all generated kernels with the global registry\n"
        "inline void register_all_kernels()\n"
        "{\n"
        "auto& registry = Registry::instance();\n"
        "register_all_kernels(registry);\n"
        "}\n"
        "} // namespace generated\n"
        "} // namespace dispatcher\n"
        "} // namespace ck_tile\n"
    )

    # Several kernels may live in the same header; emit each include once,
    # keeping first-seen order so the output stays stable.
    include_lines = [
        f'#include "{header}"\n'
        for header in dict.fromkeys(kernel.header_file for kernel in kernels)
    ]

    # One registration call per kernel, keyed by the SelectedKernel_<name> type.
    registration_lines = [
        f" // Register {kernel.name}\n"
        f'register_tile_kernel<SelectedKernel_{kernel.name}>(registry, "{kernel.name}");\n'
        for kernel in kernels
    ]

    content = "".join(
        [preamble, *include_lines, open_registration, *registration_lines, footer]
    )
    # Explicit encoding keeps the generated file identical across platforms.
    output_file.write_text(content, encoding="utf-8")
    print(f"✓ Generated registration header: {output_file}")
def generate_registration_cpp(kernels: List[KernelConfig], output_file: Path):
    """Write the C++ source holding explicit instantiations for each kernel.

    The emitted file includes ``dispatcher_registration.hpp`` and, inside
    ``ck_tile::dispatcher::generated``, contains one
    ``template class backends::TileKernelInstance<SelectedKernel_<name>>;``
    line per distinct kernel name.

    Args:
        kernels: Kernel configurations; only ``name`` is read.
        output_file: Destination path; overwritten if it already exists.
    """
    preamble = (
        "// SPDX-License-Identifier: MIT\n"
        "// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.\n"
        "//\n"
        "// AUTO-GENERATED FILE - DO NOT EDIT\n"
        "// Generated by generate_dispatcher_registration.py\n"
        '#include "dispatcher_registration.hpp"\n'
        "namespace ck_tile {\n"
        "namespace dispatcher {\n"
        "namespace generated {\n"
        "// Explicit instantiations to reduce compile time\n"
        "// These ensure the templates are instantiated once\n"
    )
    footer = (
        "\n"
        "} // namespace generated\n"
        "} // namespace dispatcher\n"
        "} // namespace ck_tile\n"
    )

    # C++ allows an explicit instantiation definition at most once per
    # specialization, so a repeated kernel name must not be emitted twice.
    # dict.fromkeys keeps first-seen order.
    instantiations = [
        f"template class backends::TileKernelInstance<SelectedKernel_{name}>;\n"
        for name in dict.fromkeys(kernel.name for kernel in kernels)
    ]

    content = "".join([preamble, *instantiations, footer])
    # Explicit encoding keeps the generated file identical across platforms.
    output_file.write_text(content, encoding="utf-8")
    print(f"✓ Generated registration implementation: {output_file}")
def generate_kernel_wrapper_header(kernel: KernelConfig, output_dir: Path):
    """Write ``<name>_wrapper.hpp`` declaring a ``SelectedKernel_<name>`` alias.

    The wrapper includes the kernel's ``header_file`` and opens the
    ``ck_tile::dispatcher::generated`` namespace around a ``using`` alias.

    Args:
        kernel: Kernel configuration; ``name`` and ``header_file`` are read.
        output_dir: Directory that receives ``<name>_wrapper.hpp``.

    NOTE(review): the alias's right-hand side is emitted as a placeholder
    comment, so the generated header is not valid C++ as written; the real
    kernel type still has to be filled in.
    """
    lines = [
        "// SPDX-License-Identifier: MIT",
        "// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.",
        "//",
        "// AUTO-GENERATED FILE - DO NOT EDIT",
        "// Generated by generate_dispatcher_registration.py",
        "#pragma once",
        f'#include "{kernel.header_file}"',
        "namespace ck_tile {",
        "namespace dispatcher {",
        "namespace generated {",
        "// Type alias for dispatcher registration",
        "// This allows the registration code to reference the kernel type",
        f"using SelectedKernel_{kernel.name} = /* Actual kernel type from generated header */;",
        "} // namespace generated",
        "} // namespace dispatcher",
        "} // namespace ck_tile",
    ]
    target = output_dir / f"{kernel.name}_wrapper.hpp"
    # Every line, including the last, is newline-terminated.
    target.write_text("\n".join(lines) + "\n")
def load_kernel_manifest(manifest_file: Path) -> List[KernelConfig]:
    """Load kernel configurations from a JSON manifest.

    The manifest is a JSON object whose optional ``"kernels"`` key holds a list
    of objects. Each object must provide ``name``, ``header_file``, ``tile_m``,
    ``tile_n`` and ``tile_k``; every other ``KernelConfig`` field is optional
    and falls back to the default listed below.

    Args:
        manifest_file: Path to the JSON manifest.

    Returns:
        One ``KernelConfig`` per manifest entry, in manifest order. Empty when
        the ``"kernels"`` key is absent.

    Raises:
        KeyError: If an entry lacks one of the required keys.
        OSError: If the manifest cannot be opened.
        json.JSONDecodeError: If the manifest is not valid JSON.
    """
    required_keys = ("name", "header_file", "tile_m", "tile_n", "tile_k")
    optional_defaults = {
        "warp_m": 2,
        "warp_n": 2,
        "warp_k": 1,
        "warp_tile_m": 32,
        "warp_tile_n": 32,
        "warp_tile_k": 16,
        "block_size": 256,
        "pipeline": "compv4",
        "epilogue": "cshuffle",
        "scheduler": "intrawave",
        "pad_m": False,
        "pad_n": False,
        "pad_k": False,
        "persistent": False,
        "double_buffer": True,
        "transpose_c": False,
        "dtype_a": "fp16",
        "dtype_b": "fp16",
        "dtype_c": "fp16",
        "dtype_acc": "fp32",
        # Layouts were previously dropped on load; honour them when present,
        # with the same defaults KernelConfig declares.
        "layout_a": "row",
        "layout_b": "col",
        "layout_c": "row",
    }

    # JSON text is UTF-8 by specification; do not depend on the locale.
    with open(manifest_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    kernels = []
    for entry in data.get("kernels", []):
        # Required keys use plain indexing so a missing one raises KeyError.
        kwargs = {key: entry[key] for key in required_keys}
        for key, default in optional_defaults.items():
            kwargs[key] = entry.get(key, default)
        kernels.append(KernelConfig(**kwargs))
    return kernels
def scan_generated_headers(generated_dir: Path) -> List[KernelConfig]:
    """Scan ``*.hpp`` files under *generated_dir* and build kernel configurations.

    A header is treated as a kernel header only if it contains a
    ``constexpr const char* KERNEL_NAME = "..."`` definition; others are
    skipped. Integer parameters are read from ``constexpr`` definitions typed
    ``int``, ``std::size_t`` or ``ck_tile::index_t`` (optionally ``static``),
    falling back to a default when absent. Boolean traits are ``True`` when the
    text ``<Flag> = true`` appears anywhere in the file.

    ``pipeline``, ``epilogue`` and ``scheduler`` are not parsed; they are always
    set to ``"compv4"``, ``"cshuffle"`` and ``"intrawave"``. Data types and
    layouts keep the ``KernelConfig`` defaults.

    Args:
        generated_dir: Root directory searched recursively for ``*.hpp`` files.

    Returns:
        Kernel configurations ordered by header path. ``header_file`` is the
        header's path relative to the parent of *generated_dir*, written with
        forward slashes.
    """
    import re

    int_type = r"(?:int|std::size_t|ck_tile::index_t)"

    def find_int(text, symbol, default):
        """Return the integer assigned to constexpr *symbol*, else *default*."""
        found = re.search(
            rf"(?:static\s+)?constexpr\s+{int_type}\s+{re.escape(symbol)}\s*=\s*(\d+)",
            text,
        )
        return int(found.group(1)) if found else default

    def flag_is_true(text, flag):
        """Return True when ``<flag> = true`` occurs anywhere in *text*."""
        return re.search(rf"{re.escape(flag)}\s*=\s*true", text) is not None

    kernels = []
    # Sort so the generated registration files do not depend on the
    # filesystem's directory enumeration order.
    for header_file in sorted(generated_dir.glob("**/*.hpp")):
        try:
            # Undecodable bytes are replaced rather than aborting the scan;
            # only ASCII tokens are matched below.
            content = header_file.read_text(encoding="utf-8", errors="replace")

            name_match = re.search(
                r'constexpr const char\* KERNEL_NAME\s*=\s*"([^"]+)"', content
            )
            if not name_match:
                continue

            kernels.append(
                KernelConfig(
                    name=name_match.group(1),
                    # Forward slashes so the path is usable inside #include
                    # on every platform (str(Path) gives backslashes on Windows).
                    header_file=header_file.relative_to(
                        generated_dir.parent
                    ).as_posix(),
                    tile_m=find_int(content, "TileM", 256),
                    tile_n=find_int(content, "TileN", 256),
                    tile_k=find_int(content, "TileK", 32),
                    warp_m=find_int(content, "WarpPerBlock_M", 2),
                    warp_n=find_int(content, "WarpPerBlock_N", 2),
                    warp_k=find_int(content, "WarpPerBlock_K", 1),
                    warp_tile_m=find_int(content, "WarpTileM", 32),
                    warp_tile_n=find_int(content, "WarpTileN", 32),
                    warp_tile_k=find_int(content, "WarpTileK", 16),
                    block_size=find_int(content, "BlockSize", 256),
                    pipeline="compv4",
                    epilogue="cshuffle",
                    scheduler="intrawave",
                    pad_m=flag_is_true(content, "kPadM"),
                    pad_n=flag_is_true(content, "kPadN"),
                    pad_k=flag_is_true(content, "kPadK"),
                    persistent=flag_is_true(content, "UsePersistentKernel"),
                    double_buffer=flag_is_true(content, "DoubleSmemBuffer"),
                    transpose_c=flag_is_true(content, "TransposeC"),
                )
            )
        except (OSError, ValueError) as e:
            # Best effort: an unreadable or malformed header must not stop
            # the remaining headers from being scanned.
            print(f"Warning: Failed to parse {header_file}: {e}")
            continue
    return kernels
def main():
    """Command-line entry point: load kernel configs and emit registration files.

    Kernel configurations come from ``--manifest`` when it is given, otherwise
    from scanning ``--generated-dir`` when ``--scan`` is set. Three files are
    written into ``--output-dir`` (created if missing):
    ``dispatcher_registration.hpp``, ``dispatcher_registration.cpp`` and
    ``kernels_manifest.json``.

    Returns:
        0 on success; 1 when neither ``--manifest`` nor ``--scan`` was given,
        or when ``--scan`` is used and ``--generated-dir`` is not a directory.
    """
    import sys
    from dataclasses import asdict

    parser = argparse.ArgumentParser(
        description="Generate dispatcher registration code"
    )
    parser.add_argument(
        "--generated-dir",
        type=str,
        required=True,
        help="Directory containing generated kernel headers",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Output directory for registration code",
    )
    parser.add_argument(
        "--manifest", type=str, help="Optional manifest file with kernel configurations"
    )
    parser.add_argument(
        "--scan",
        action="store_true",
        help="Scan generated headers instead of using manifest",
    )
    args = parser.parse_args()

    generated_dir = Path(args.generated_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load kernel configurations; --manifest takes precedence over --scan.
    if args.manifest:
        print(f"Loading kernels from manifest: {args.manifest}")
        kernels = load_kernel_manifest(Path(args.manifest))
    elif args.scan:
        if not generated_dir.is_dir():
            # A mistyped path would otherwise yield empty registration files
            # and a zero exit status.
            print(
                f"Error: Generated directory not found: {generated_dir}",
                file=sys.stderr,
            )
            return 1
        print(f"Scanning generated headers in: {generated_dir}")
        kernels = scan_generated_headers(generated_dir)
    else:
        print("Error: Must specify either --manifest or --scan", file=sys.stderr)
        return 1
    print(f"Found {len(kernels)} kernels")

    # Generate registration code
    registration_header = output_dir / "dispatcher_registration.hpp"
    registration_cpp = output_dir / "dispatcher_registration.cpp"
    generate_registration_header(kernels, registration_header)
    generate_registration_cpp(kernels, registration_cpp)

    # Generate manifest for Python. Every KernelConfig field is written so that
    # reloading the file reproduces the same configuration instead of falling
    # back to loader defaults. The previously emitted keys are all still there.
    manifest_output = output_dir / "kernels_manifest.json"
    manifest_data = {"kernels": [asdict(k) for k in kernels]}
    with open(manifest_output, "w", encoding="utf-8") as f:
        json.dump(manifest_data, f, indent=2)
    print(f"✓ Generated manifest: {manifest_output}")

    print("\n✓ Registration code generation complete!")
    print(f" Total kernels: {len(kernels)}")
    print(" Output files:")
    print(f" - {registration_header}")
    print(f" - {registration_cpp}")
    print(f" - {manifest_output}")
    return 0
if __name__ == "__main__":
    # The builtin exit() is injected by the ``site`` module for interactive
    # use and is missing under ``python -S`` or frozen builds; raising
    # SystemExit gives the same exit status without that dependency.
    raise SystemExit(main())