mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
* WIP POC of dispatcher * Dispatcher python workflow setup. * Dispatcher cleanup and updates. Further dispatcher cleanup and updates. Build fixes Improvements and python to CK example Improvements to readme * Fixes to python paths * Cleaning up code * Improving dispatcher support for different arch Fixing typos * Fix formatting errors * Cleaning up examples * Improving codegeneration * Improving and fixing C++ examples * Adding conv functionality (fwd,bwd,bwdw) and examples. * Fixes based on feedback. * Further fixes based on feedback. * Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug. * Another round of improvements based on feedback. * Trimming out unnecessary code. * Fixing the multi-D implementation. * Using gpu verification for gemms and fixing convolutions tflops calculation. * Fix counter usage issue and arch filtering per ops. * Adding changelog and other fixes. * Improve examples and resolve critical bugs. * Reduce build time for python examples. * Fixing minor bug. * Fix compilation error. * Improve installation instructions for dispatcher. * Add docker based installation instructions for dispatcher. * Fixing arch-based filtering to match tile engine. * Remove dead code and fix arch filtering. * Minor bugfix. * Updates after rebase. * Trimming code. * Fix copyright headers. * Consolidate examples, cut down code. * Minor fixes. * Improving python examples. * Update readmes. * Remove conv functionality. * Cleanup following conv removable.
430 lines
14 KiB
Python
430 lines
14 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Generate dispatcher registration code for CK Tile kernels
|
|
|
|
This script generates C++ registration code that instantiates TileKernelInstance
|
|
templates for each generated kernel, solving the "cannot instantiate from parsed headers" problem.
|
|
"""
|
|
|
|
import json
|
|
import argparse
|
|
from pathlib import Path
|
|
from typing import List
|
|
from dataclasses import dataclass
|
|
|
|
|
|
@dataclass
class KernelConfig:
    """Description of one generated CK Tile kernel used to emit registration code.

    The first 21 fields are required and positional; the data-type and
    layout fields that follow are optional and carry string defaults.
    """

    # Identity: the kernel's name and the header that is #include'd for it.
    name: str
    header_file: str

    # Tile sizes (the scanner reads these from the TileM/TileN/TileK constants).
    tile_m: int
    tile_n: int
    tile_k: int

    # Warp counts (read from WarpPerBlock_M/_N/_K when scanning headers).
    warp_m: int
    warp_n: int
    warp_k: int

    # Warp tile sizes (read from WarpTileM/WarpTileN/WarpTileK when scanning).
    warp_tile_m: int
    warp_tile_n: int
    warp_tile_k: int

    # Read from the BlockSize constant when scanning headers.
    block_size: int

    # Variant selectors, stored as plain strings.
    pipeline: str
    epilogue: str
    scheduler: str

    # Boolean flags (kPadM/kPadN/kPadK, UsePersistentKernel, DoubleSmemBuffer
    # and TransposeC in the scanned headers).
    pad_m: bool
    pad_n: bool
    pad_k: bool
    persistent: bool
    double_buffer: bool
    transpose_c: bool

    # Element types of the A, B and C operands and of the accumulator.
    dtype_a: str = "fp16"
    dtype_b: str = "fp16"
    dtype_c: str = "fp16"
    dtype_acc: str = "fp32"

    # Layouts of the A, B and C operands.
    layout_a: str = "row"
    layout_b: str = "col"
    layout_c: str = "row"
|
|
|
|
|
|
def generate_registration_header(kernels: List["KernelConfig"], output_file: Path):
    """Write the C++ header that registers every generated kernel.

    The emitted header includes each kernel's ``header_file`` and defines two
    inline overloads of ``register_all_kernels`` inside
    ``ck_tile::dispatcher::generated``: one that takes a ``Registry&`` and one
    that forwards to ``Registry::instance()``.

    Args:
        kernels: Kernel descriptions, emitted in the given order. Only the
            ``name`` and ``header_file`` attributes are read.
        output_file: Destination path; an existing file is overwritten.
    """
    prologue = (
        "// SPDX-License-Identifier: MIT\n"
        "// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.\n"
        "//\n"
        "// AUTO-GENERATED FILE - DO NOT EDIT\n"
        "// Generated by generate_dispatcher_registration.py\n"
        "\n"
        "#pragma once\n"
        "\n"
        '#include "ck_tile/dispatcher/registry.hpp"\n'
        '#include "ck_tile/dispatcher/backends/tile_backend.hpp"\n'
        '#include "ck_tile/dispatcher/backends/kernel_registration.hpp"\n'
        "\n"
        "// Include all generated kernel headers\n"
    )

    # One #include per kernel header.
    includes = "".join(f'#include "{kernel.header_file}"\n' for kernel in kernels)

    open_registration = (
        "\n"
        "\n"
        "namespace ck_tile {\n"
        "namespace dispatcher {\n"
        "namespace generated {\n"
        "\n"
        "/// Register all generated kernels with the dispatcher\n"
        "inline void register_all_kernels(Registry& registry)\n"
        "{\n"
    )

    # Each kernel is registered through the type ``SelectedKernel_<name>``.
    # Nothing here declares that type; it is assumed to come from the
    # included kernel headers.
    registrations = "".join(
        f"    // Register {kernel.name}\n"
        f'    register_tile_kernel<SelectedKernel_{kernel.name}>(registry, "{kernel.name}");\n'
        for kernel in kernels
    )

    epilogue = (
        "}\n"
        "\n"
        "/// Register all generated kernels with the global registry\n"
        "inline void register_all_kernels()\n"
        "{\n"
        "    auto& registry = Registry::instance();\n"
        "    register_all_kernels(registry);\n"
        "}\n"
        "\n"
        "} // namespace generated\n"
        "} // namespace dispatcher\n"
        "} // namespace ck_tile\n"
    )

    content = "".join((prologue, includes, open_registration, registrations, epilogue))

    # Explicit encoding keeps the output independent of the build host's locale.
    output_file.write_text(content, encoding="utf-8")
    print(f"✓ Generated registration header: {output_file}")
|
|
|
|
|
|
def generate_registration_cpp(kernels: List["KernelConfig"], output_file: Path):
    """Write the C++ source holding one explicit instantiation per kernel.

    The emitted file includes ``dispatcher_registration.hpp`` and, inside
    ``ck_tile::dispatcher::generated``, emits
    ``template class backends::TileKernelInstance<SelectedKernel_<name>>;``
    for every kernel.

    Args:
        kernels: Kernel descriptions, emitted in the given order. Only the
            ``name`` attribute is read.
        output_file: Destination path; an existing file is overwritten.
    """
    prologue = (
        "// SPDX-License-Identifier: MIT\n"
        "// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.\n"
        "//\n"
        "// AUTO-GENERATED FILE - DO NOT EDIT\n"
        "// Generated by generate_dispatcher_registration.py\n"
        "\n"
        '#include "dispatcher_registration.hpp"\n'
        "\n"
        "namespace ck_tile {\n"
        "namespace dispatcher {\n"
        "namespace generated {\n"
        "\n"
        "// Explicit instantiations to reduce compile time\n"
        "// These ensure the templates are instantiated once\n"
        "\n"
    )

    instantiations = "".join(
        f"template class backends::TileKernelInstance<SelectedKernel_{kernel.name}>;\n"
        for kernel in kernels
    )

    epilogue = (
        "\n"
        "\n"
        "} // namespace generated\n"
        "} // namespace dispatcher\n"
        "} // namespace ck_tile\n"
    )

    # Explicit encoding keeps the output independent of the build host's locale.
    output_file.write_text(prologue + instantiations + epilogue, encoding="utf-8")
    print(f"✓ Generated registration implementation: {output_file}")
|
|
|
|
|
|
def generate_kernel_wrapper_header(kernel: "KernelConfig", output_dir: Path) -> Path:
    """Write ``<kernel.name>_wrapper.hpp`` into ``output_dir``.

    The wrapper includes the kernel's generated header and declares the alias
    ``SelectedKernel_<name>`` inside ``ck_tile::dispatcher::generated``.

    NOTE(review): the right-hand side of the emitted ``using`` declaration is
    only a comment placeholder, so the wrapper is not valid C++ as written.
    The concrete kernel type still has to be filled in before this function is
    wired into the generation flow.

    Args:
        kernel: Kernel description; ``name`` and ``header_file`` are read.
        output_dir: Existing directory that receives the wrapper header.

    Returns:
        The path of the wrapper header that was written.
    """
    wrapper_file = output_dir / f"{kernel.name}_wrapper.hpp"

    lines = [
        "// SPDX-License-Identifier: MIT",
        "// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.",
        "//",
        "// AUTO-GENERATED FILE - DO NOT EDIT",
        "// Generated by generate_dispatcher_registration.py",
        "",
        "#pragma once",
        "",
        f'#include "{kernel.header_file}"',
        "",
        "namespace ck_tile {",
        "namespace dispatcher {",
        "namespace generated {",
        "",
        "// Type alias for dispatcher registration",
        "// This allows the registration code to reference the kernel type",
        f"using SelectedKernel_{kernel.name} = /* Actual kernel type from generated header */;",
        "",
        "} // namespace generated",
        "} // namespace dispatcher",
        "} // namespace ck_tile",
        "",
    ]

    # Explicit encoding keeps the output independent of the build host's locale.
    wrapper_file.write_text("\n".join(lines), encoding="utf-8")
    return wrapper_file
|
|
|
|
|
|
def load_kernel_manifest(manifest_file: Path) -> List["KernelConfig"]:
    """Load kernel configurations from a JSON manifest file.

    The manifest is a JSON object whose optional ``"kernels"`` entry is a list
    of objects. Each object must provide ``name``, ``header_file``,
    ``tile_m``, ``tile_n`` and ``tile_k``; every other ``KernelConfig`` field
    is optional and falls back to the defaults listed below.

    Args:
        manifest_file: Path of the JSON manifest to read.

    Returns:
        One ``KernelConfig`` per manifest entry, in manifest order. The list
        is empty when the manifest has no ``"kernels"`` entry.

    Raises:
        KeyError: If an entry lacks one of the required keys.
    """
    required_keys = ("name", "header_file", "tile_m", "tile_n", "tile_k")
    optional_defaults = {
        "warp_m": 2,
        "warp_n": 2,
        "warp_k": 1,
        "warp_tile_m": 32,
        "warp_tile_n": 32,
        "warp_tile_k": 16,
        "block_size": 256,
        "pipeline": "compv4",
        "epilogue": "cshuffle",
        "scheduler": "intrawave",
        "pad_m": False,
        "pad_n": False,
        "pad_k": False,
        "persistent": False,
        "double_buffer": True,
        "transpose_c": False,
        "dtype_a": "fp16",
        "dtype_b": "fp16",
        "dtype_c": "fp16",
        "dtype_acc": "fp32",
        # Layouts are honoured as well, so a manifest that specifies them is
        # no longer silently overridden by the dataclass defaults.
        "layout_a": "row",
        "layout_b": "col",
        "layout_c": "row",
    }

    # JSON is read as UTF-8 regardless of the host locale.
    with open(manifest_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    kernels = []
    for kernel_data in data.get("kernels", []):
        required = {key: kernel_data[key] for key in required_keys}
        optional = {
            key: kernel_data.get(key, default)
            for key, default in optional_defaults.items()
        }
        kernels.append(KernelConfig(**required, **optional))

    return kernels
|
|
|
|
|
|
def scan_generated_headers(generated_dir: Path) -> List["KernelConfig"]:
    """Scan generated ``.hpp`` files and build a kernel configuration for each.

    Every ``*.hpp`` file below ``generated_dir`` (searched recursively) that
    defines ``constexpr const char* KERNEL_NAME = "..."`` yields one
    ``KernelConfig``. Integer and boolean settings are recovered with regular
    expressions; a setting that is not found falls back to a default. The
    ``pipeline``, ``epilogue`` and ``scheduler`` fields are not parsed and are
    always ``"compv4"``, ``"cshuffle"`` and ``"intrawave"``.

    Args:
        generated_dir: Root directory of the generated kernel headers.

    Returns:
        Kernel configurations ordered by header path. ``header_file`` is the
        header's path relative to the parent of ``generated_dir``, written
        with forward slashes.
    """
    import re

    # Shared prefix of the integer constants: optional ``static``, then
    # ``constexpr`` and one of the accepted integral types.
    int_decl = r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+"

    def int_constant(text: str, identifier: str, fallback: int) -> int:
        """Return the value assigned to ``identifier``, or ``fallback``."""
        found = re.search(int_decl + identifier + r"\s*=\s*(\d+)", text)
        return int(found.group(1)) if found else fallback

    def is_set_true(text: str, identifier: str) -> bool:
        """Return True when ``identifier = true`` occurs in ``text``."""
        return re.search(identifier + r"\s*=\s*true", text) is not None

    kernels = []

    # Sorted so the generated code does not depend on directory listing order.
    for header_file in sorted(generated_dir.glob("**/*.hpp")):
        try:
            content = header_file.read_text(encoding="utf-8")
            # Forward slashes keep the path usable in an #include on any host.
            include_path = header_file.relative_to(generated_dir.parent).as_posix()
        except (OSError, ValueError) as e:
            # Unreadable or undecodable headers are reported and skipped
            # (UnicodeDecodeError is a ValueError); other errors propagate.
            print(f"Warning: Failed to parse {header_file}: {e}")
            continue

        # Headers without a KERNEL_NAME are not kernel headers; skip silently.
        name_match = re.search(
            r'constexpr const char\* KERNEL_NAME\s*=\s*"([^"]+)"', content
        )
        if not name_match:
            continue

        kernels.append(
            KernelConfig(
                name=name_match.group(1),
                header_file=include_path,
                tile_m=int_constant(content, "TileM", 256),
                tile_n=int_constant(content, "TileN", 256),
                tile_k=int_constant(content, "TileK", 32),
                warp_m=int_constant(content, "WarpPerBlock_M", 2),
                warp_n=int_constant(content, "WarpPerBlock_N", 2),
                warp_k=int_constant(content, "WarpPerBlock_K", 1),
                warp_tile_m=int_constant(content, "WarpTileM", 32),
                warp_tile_n=int_constant(content, "WarpTileN", 32),
                warp_tile_k=int_constant(content, "WarpTileK", 16),
                block_size=int_constant(content, "BlockSize", 256),
                pipeline="compv4",
                epilogue="cshuffle",
                scheduler="intrawave",
                pad_m=is_set_true(content, "kPadM"),
                pad_n=is_set_true(content, "kPadN"),
                pad_k=is_set_true(content, "kPadK"),
                persistent=is_set_true(content, "UsePersistentKernel"),
                double_buffer=is_set_true(content, "DoubleSmemBuffer"),
                transpose_c=is_set_true(content, "TransposeC"),
            )
        )

    return kernels
|
|
|
|
|
|
def main(argv=None) -> int:
    """Command-line entry point: generate the dispatcher registration files.

    Kernel configurations come from ``--manifest`` when it is given, otherwise
    from scanning ``--generated-dir`` when ``--scan`` is given. The function
    then writes ``dispatcher_registration.hpp``,
    ``dispatcher_registration.cpp`` and ``kernels_manifest.json`` into
    ``--output-dir``, creating that directory if needed.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) parses the process command line.

    Returns:
        0 on success, 1 when neither ``--manifest`` nor ``--scan`` was given.
    """
    import sys
    from dataclasses import asdict

    parser = argparse.ArgumentParser(
        description="Generate dispatcher registration code"
    )
    parser.add_argument(
        "--generated-dir",
        type=str,
        required=True,
        help="Directory containing generated kernel headers",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Output directory for registration code",
    )
    parser.add_argument(
        "--manifest", type=str, help="Optional manifest file with kernel configurations"
    )
    parser.add_argument(
        "--scan",
        action="store_true",
        help="Scan generated headers instead of using manifest",
    )

    args = parser.parse_args(argv)

    # Reject an unusable invocation before touching the filesystem.
    if not args.manifest and not args.scan:
        print("Error: Must specify either --manifest or --scan", file=sys.stderr)
        return 1

    generated_dir = Path(args.generated_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load kernel configurations; --manifest wins when both options are given.
    if args.manifest:
        print(f"Loading kernels from manifest: {args.manifest}")
        kernels = load_kernel_manifest(Path(args.manifest))
    else:
        print(f"Scanning generated headers in: {generated_dir}")
        kernels = scan_generated_headers(generated_dir)

    print(f"Found {len(kernels)} kernels")

    # Generate registration code
    registration_header = output_dir / "dispatcher_registration.hpp"
    registration_cpp = output_dir / "dispatcher_registration.cpp"

    generate_registration_header(kernels, registration_header)
    generate_registration_cpp(kernels, registration_cpp)

    # Generate manifest for Python. Every KernelConfig field is written so
    # that reading the file back through load_kernel_manifest does not
    # silently replace warp, padding, dtype or layout settings with defaults.
    manifest_output = output_dir / "kernels_manifest.json"
    manifest_data = {"kernels": [asdict(k) for k in kernels]}

    with open(manifest_output, "w", encoding="utf-8") as f:
        json.dump(manifest_data, f, indent=2)

    print(f"✓ Generated manifest: {manifest_output}")
    print("\n✓ Registration code generation complete!")
    print(f" Total kernels: {len(kernels)}")
    print(" Output files:")
    print(f" - {registration_header}")
    print(f" - {registration_cpp}")
    print(f" - {manifest_output}")

    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Pass main()'s return code to the shell as the process exit status.
    # SystemExit is raised directly because the bare ``exit`` helper is
    # installed by the ``site`` module for interactive use and is not
    # guaranteed to exist (for example under ``python -S``).
    raise SystemExit(main())
|