Files
composable_kernel/dispatcher/scripts/registration_codegen.py
Bartłomiej Kocot 7c2b979de2 [rocm-libraries] ROCm/rocm-libraries#8573 (commit 04c9f1d)
[CK][CK Tile] Drop profiler for experimental builder codegen
 (#8573)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

Switch to dispatcher profiler for ck tile conv.

## Technical Details

- Switch to dispatcher profiler for ck tile conv.
- Drop profiler for experimental codegen
- Minor fixes for bwd data printing
- Minor fixes for 3d conv in dispatcher codegen

## Test Plan

test_grouped_conv*tile

## Test Result

Passed

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-06-19 09:38:44 +00:00

399 lines
16 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
#
# Shared registration code generation utilities for dispatcher kernel scripts.
# Generates chunked registration .cpp files for parallel compilation.
import re
from pathlib import Path
# Default number of kernels per registration chunk.
# Each chunk is written as its own .cpp file (one compilation unit), so this
# value sets how finely the registration code is split up for compilation.
# It is only the default for the `chunk_size` parameter of
# generate_chunked_registration(); callers may override it per call.
# TODO: What is the optimal value?
CHUNK_SIZE = 25
def parse_depthwise_kernel_metadata(kname):
    """Decode the parameters embedded in a depthwise kernel header filename stem.

    Depthwise kernel names look like:
        grouped_conv_fwd_depthwise_fp16_ngchw_2d_8x8_f3_s1x1_p1x1_nb8_sub2x2_vec2_2

    Every field falls back to a default when its token is absent from the
    name: "fp16" for the dtype, "ngchw" for the layout, 1 for both vector
    sizes and 0 for every other numeric field.

    Returns a dict with "is_depthwise": True plus the parsed depthwise
    parameters (ndim, dtype, layout, tile, filter, stride, pad, nbatch,
    sub-tile and vector sizes).
    """

    def first_token(candidates, fallback):
        # First candidate, in list order, that occurs in the name as `_<token>_`.
        return next((tok for tok in candidates if f"_{tok}_" in kname), fallback)

    def numbers(pattern, fallback):
        # Captured groups of `pattern` as ints, or `fallback` when there is no match.
        hit = re.search(pattern, kname)
        return fallback if hit is None else tuple(map(int, hit.groups()))

    # Vector sizes are anchored to the very end of the name: _vec{in}_{out}
    in_vec, out_vec = numbers(r"_vec(\d+)_(\d+)$", (1, 1))

    # NOTE(review): the spatial patterns below capture exactly two components
    # (HxW); names carrying three components would not match and keep the 0
    # defaults.
    tile_h, tile_w = numbers(r"_(\d+)x(\d+)_f", (0, 0))
    (filt,) = numbers(r"_f(\d+)_", (0,))
    str_h, str_w = numbers(r"_s(\d+)x(\d+)_", (0, 0))
    pad_h, pad_w = numbers(r"_p(\d+)x(\d+)_", (0, 0))
    (nbatch,) = numbers(r"_nb(\d+)_", (0,))
    sub_h, sub_w = numbers(r"_sub(\d+)x(\d+)_", (0, 0))

    return {
        "is_depthwise": True,
        "ndim": 3 if "_3d_" in kname else 2,
        "dtype": first_token(["fp16", "bf16", "fp32"], "fp16"),
        "layout": first_token(["ngchw", "ngcdhw", "nhwgc", "ndhwgc"], "ngchw"),
        "tile_h": tile_h, "tile_w": tile_w,
        "filt": filt,
        "str_h": str_h, "str_w": str_w,
        "pad_h": pad_h, "pad_w": pad_w,
        "nbatch": nbatch,
        "sub_h": sub_h, "sub_w": sub_w,
        "in_vec": in_vec, "out_vec": out_vec,
    }
def parse_kernel_metadata(kname):
    """Decode the parameters embedded in a kernel header filename stem.

    Names containing "_depthwise_" are handed to
    parse_depthwise_kernel_metadata(); everything else is parsed here as an
    implicit-GEMM style kernel. Any field whose token is missing from the name
    keeps its default (see the fallbacks below).

    Returns a dict of the parsed fields. "epilogue", "block_per_cu" and
    "num_wave_groups" are not read from the name and are always "cshuffle",
    1 and 1.
    """
    # Depthwise kernels use a different naming scheme and their own parser.
    if "_depthwise_" in kname:
        return parse_depthwise_kernel_metadata(kname)

    def first_token(candidates, fallback):
        # First candidate, in list order, that occurs in the name as `_<token>_`.
        return next((tok for tok in candidates if f"_{tok}_" in kname), fallback)

    def numbers(pattern, fallback):
        # Captured groups of `pattern` as ints, or `fallback` when there is no match.
        hit = re.search(pattern, kname)
        return fallback if hit is None else tuple(map(int, hit.groups()))

    # Up to three `_AxBxC` triplets are consumed in order of appearance:
    # tile, wave, warp. Slots without a triplet keep these defaults.
    shapes = [(128, 128, 32), (2, 2, 1), (32, 32, 16)]
    for slot, triplet in enumerate(re.findall(r"_(\d+)x(\d+)x(\d+)", kname)[:3]):
        shapes[slot] = tuple(map(int, triplet))
    (tile_m, tile_n, tile_k), (wave_m, wave_n, wave_k), (warp_m, warp_n, warp_k) = shapes

    vec_a, vec_b, vec_c = numbers(r"_vec(\d+)_(\d+)_(\d+)", (4, 8, 8))
    (num_groups_to_merge,) = numbers(r"_gm(\d+)", (1,))

    # Order matters: the most specific specialization name is tested first
    # because the shorter names are substrings of related longer ones.
    specialization = next(
        (
            spec
            for spec in ("filter1x1_stride1_pad0", "filter1x1_pad0", "filter3x3")
            if spec in kname
        ),
        "default",
    )

    # Stream-K: the reduction mode and the persistent flag are only looked at
    # when the `_streamk_` marker is present.
    streamk_enabled = "_streamk_" in kname
    streamk_reduction = "none"
    if streamk_enabled:
        for mode in ("tree", "linear"):
            if f"_streamk_{mode}" in kname:
                streamk_reduction = mode
                break
    streamk_persistent = streamk_enabled and kname.endswith("_persistent")

    return {
        "ndim": 3 if "_3d_" in kname else 2,
        "dtype": first_token(["fp16", "bf16", "fp32"], "fp16"),
        "layout": "ndhwgc" if "_ndhwgc_" in kname else "nhwgc",
        "tile_m": tile_m, "tile_n": tile_n, "tile_k": tile_k,
        "wave_m": wave_m, "wave_n": wave_n, "wave_k": wave_k,
        "warp_m": warp_m, "warp_n": warp_n, "warp_k": warp_k,
        "pipeline": first_token(
            ["compv1", "compv2", "compv3", "compv4", "compv5", "compv6", "mem"],
            "compv3",
        ),
        "scheduler": "interwave" if "interwave" in kname else "intrawave",
        "epilogue": "cshuffle",
        "vec_a": vec_a, "vec_b": vec_b, "vec_c": vec_c,
        "block_per_cu": 1, "num_wave_groups": 1,
        "num_groups_to_merge": num_groups_to_merge,
        "specialization": specialization,
        # Large tensor (split image) kernels carry this marker in their name.
        "large_tensor": "_large_tensor" in kname,
        "streamk_enabled": streamk_enabled,
        "streamk_reduction": streamk_reduction,
        "streamk_persistent": streamk_persistent,
    }
def _make_depthwise_conv_key(meta):
"""Generate C++ key assignment lines for a depthwise conv kernel."""
# Map depthwise parameters into the GEMM key fields to ensure each
# depthwise kernel gets a unique registry key.
return [
f' key.dtype_in = "{meta["dtype"]}";',
f' key.dtype_wei = "{meta["dtype"]}";',
f' key.dtype_out = "{meta["dtype"]}";',
f' key.layout = "{meta["layout"]}";',
f" key.ndim_spatial = {meta['ndim']};",
f" // Depthwise params encoded: tile_h/w -> tile_m/n, filt -> tile_k",
f" key.tile_m = {meta['tile_h']};",
f" key.tile_n = {meta['tile_w']};",
f" key.tile_k = {meta['filt']};",
f" // stride_h/w -> wave_m/n, nbatch -> warp_k",
f" key.wave_m = {meta['str_h']};",
f" key.wave_n = {meta['str_w']};",
f" key.wave_k = 0;",
f" // pad_h/w -> warp_m/n",
f" key.warp_m = {meta['pad_h']};",
f" key.warp_n = {meta['pad_w']};",
f" key.warp_k = {meta['nbatch']};",
f' key.pipeline = "depthwise";',
f' key.scheduler = "none";',
f' key.epilogue = "none";',
f" key.vector_size_a = {meta['in_vec']};",
f" key.vector_size_b = {meta['in_vec']};",
f" key.vector_size_c = {meta['out_vec']};",
f" key.block_per_cu = 2;",
f" key.num_wave_groups = 1;",
f" key.num_groups_to_merge = 1;",
f' key.specialization = "sub{meta["sub_h"]}x{meta["sub_w"]}";',
]
def _make_implicit_gemm_conv_key(meta):
"""Generate C++ key assignment lines for an implicit GEMM-based conv kernel."""
return [
f' key.dtype_in = "{meta["dtype"]}";',
f' key.dtype_wei = "{meta["dtype"]}";',
f' key.dtype_out = "{meta["dtype"]}";',
f' key.layout = "{meta.get("layout", "nhwgc")}";',
f" key.ndim_spatial = {meta['ndim']};",
f" key.tile_m = {meta['tile_m']};",
f" key.tile_n = {meta['tile_n']};",
f" key.tile_k = {meta['tile_k']};",
f" key.wave_m = {meta['wave_m']};",
f" key.wave_n = {meta['wave_n']};",
f" key.wave_k = {meta['wave_k']};",
f" key.warp_m = {meta['warp_m']};",
f" key.warp_n = {meta['warp_n']};",
f" key.warp_k = {meta['warp_k']};",
f' key.pipeline = "{meta["pipeline"]}";',
f' key.scheduler = "{meta["scheduler"]}";',
f' key.epilogue = "{meta["epilogue"]}";',
f" key.vector_size_a = {meta['vec_a']};",
f" key.vector_size_b = {meta['vec_b']};",
f" key.vector_size_c = {meta['vec_c']};",
f" key.block_per_cu = {meta['block_per_cu']};",
f" key.num_wave_groups = {meta['num_wave_groups']};",
f" key.num_groups_to_merge = {meta['num_groups_to_merge']};",
f' key.specialization = "{meta["specialization"]}";',
f' key.large_tensor = {str(meta["large_tensor"]).lower()};',
f' key.streamk_enabled = {str(meta["streamk_enabled"]).lower()};',
f' key.streamk_reduction = "{meta["streamk_reduction"]}";',
f' key.streamk_persistent = {str(meta["streamk_persistent"]).lower()};',
]
def make_registration_block(kname, global_idx, op_enum, run_fn_maker, is_supported_fn_maker):
    """Build the C++ lines that register a single kernel.

    Args:
        kname: kernel header filename stem; also used to derive the
            `ns_{kname}::{kname}_Launcher` launcher type name
        global_idx: kernel index, emitted only in the leading comment line
        op_enum: C++ expression assigned to `key.op`
        run_fn_maker: C++ template function instantiated with the launcher
            type and the spatial dimension count to build `run_fn`
        is_supported_fn_maker: same, for `is_supported_fn`

    Returns:
        list of C++ source lines forming one braced registration scope
    """
    meta = parse_kernel_metadata(kname)
    launcher = f"ns_{kname}::{kname}_Launcher"
    maker_args = f"<{launcher}, {meta['ndim']}>"

    # Depthwise kernels fill the key from their own parameter set.
    if meta.get("is_depthwise", False):
        key_fields = _make_depthwise_conv_key(meta)
    else:
        key_fields = _make_implicit_gemm_conv_key(meta)

    # Common prefix of both make_shared calls; only the experimental builder
    # variant passes the extra instance string argument.
    make_shared_prefix = (
        " auto inst = std::make_shared<GroupedConvKernelInstance>"
        f'(key, "{kname}", std::move(run_fn), std::move(is_supported_fn)'
    )

    return [
        f" // Kernel {global_idx}: {kname}",
        " {",
        " GroupedConvKernelKey key;",
        f" key.op = {op_enum};",
        *key_fields,
        " key.arch = arch;",
        f" auto run_fn = {run_fn_maker}{maker_args}();",
        f" auto is_supported_fn = {is_supported_fn_maker}{maker_args}();",
        "#ifdef CK_EXPERIMENTAL_BUILDER",
        f" auto instance_str = backends::get_instance_string<{launcher}>();",
        f"{make_shared_prefix}, instance_str);",
        "#else",
        f"{make_shared_prefix});",
        "#endif",
        " registry.register_kernel(key, inst);",
        " }",
    ]
def generate_chunked_registration(headers, output_dir, variant, op_enum,
                                  run_fn_maker, is_supported_fn_maker,
                                  register_fn_name, chunk_size=CHUNK_SIZE):
    """Generate chunked registration .cpp files for parallel compilation.

    One ``register_{variant}_chunk_{i}.cpp`` is written per group of
    ``chunk_size`` headers, followed by ``register_all_grouped_conv_kernels.cpp``,
    which forward-declares every chunk function and calls them in order.
    All files are written as UTF-8 regardless of the process locale, and
    ``output_dir`` is created when it does not exist yet.

    Args:
        headers: list of Path objects for kernel .hpp files
        output_dir: directory to write generated files
        variant: short name like "bwd_weight", "fwd", "bwd_data"
        op_enum: C++ enum value like "GroupedConvOp::BackwardWeight"
        run_fn_maker: C++ template function like "backends::make_conv_bwd_weight_run_fn"
        is_supported_fn_maker: C++ template function like "backends::make_conv_bwd_weight_is_supported_fn"
        register_fn_name: C++ function name like "register_all_grouped_conv_bwd_weight_kernels"
        chunk_size: number of kernels per chunk file (must be >= 1)

    Returns:
        list of generated .cpp file paths (chunk files first, register_all last)

    Raises:
        ValueError: if chunk_size is smaller than 1.
    """
    # A negative step would silently produce zero chunks (and therefore an
    # empty register_all that registers nothing); reject it up front.
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    headers = list(headers)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    generated_files = []

    # Split headers into chunks
    chunks = [headers[i:i + chunk_size] for i in range(0, len(headers), chunk_size)]
    chunk_fns = [f"register_{variant}_chunk_{idx}" for idx in range(len(chunks))]

    for chunk_idx, (chunk_fn, chunk_headers) in enumerate(zip(chunk_fns, chunks)):
        chunk_cpp = output_dir / f"{chunk_fn}.cpp"
        lines = [
            "// Auto-generated — do not edit",
            f"// Registration chunk {chunk_idx} for {variant} kernels ({len(chunk_headers)} kernels).",
            "",
            "#pragma clang diagnostic push",
            '#pragma clang diagnostic ignored "-Wheader-hygiene"',
            '#pragma clang diagnostic ignored "-Wunused-parameter"',
        ]
        # Include only headers for this chunk
        lines.extend(f'#include "{h.name}"' for h in chunk_headers)
        lines.append("#pragma clang diagnostic pop")
        lines.append("")
        lines.append('#include "ck_tile/dispatcher/grouped_conv_registry.hpp"')
        lines.append('#include "ck_tile/dispatcher/backends/generated_conv_backend.hpp"')
        lines.append("")
        lines.append("namespace ck_tile {")
        lines.append("namespace dispatcher {")
        lines.append("")
        lines.append(f"void {chunk_fn}(GroupedConvRegistry& registry, const std::string& arch)")
        lines.append("{")
        # Global index offset for this chunk, so the per-kernel comment
        # carries the kernel's index across all chunks.
        global_offset = chunk_idx * chunk_size
        for i, h in enumerate(chunk_headers):
            lines.extend(
                make_registration_block(
                    h.stem, global_offset + i, op_enum, run_fn_maker, is_supported_fn_maker
                )
            )
        lines.append("}")
        lines.append("")
        lines.append("} // namespace dispatcher")
        lines.append("} // namespace ck_tile")
        lines.append("")
        # The generated text contains a non-ASCII character (the em dash in
        # the banner), so the encoding must not be left to the locale.
        chunk_cpp.write_text("\n".join(lines), encoding="utf-8")
        generated_files.append(chunk_cpp)

    # Generate the thin register_all .cpp that calls all chunks.
    # NOTE(review): this file name does not include `variant`, so generating
    # several variants into the same directory would overwrite it — confirm
    # that callers use one output directory per variant.
    all_cpp = output_dir / "register_all_grouped_conv_kernels.cpp"
    lines = [
        "// Auto-generated — do not edit",
        f"// Top-level registration that calls {len(chunks)} chunk functions.",
        "",
        '#include "ck_tile/dispatcher/grouped_conv_registry.hpp"',
        "",
        "namespace ck_tile {",
        "namespace dispatcher {",
        "",
    ]
    # Forward-declare chunk functions
    lines.extend(
        f"void {chunk_fn}(GroupedConvRegistry& registry, const std::string& arch);"
        for chunk_fn in chunk_fns
    )
    lines.append("")
    # Main registration function (with registry parameter)
    lines.append(f"void {register_fn_name}(")
    lines.append(" GroupedConvRegistry& registry, const std::string& arch)")
    lines.append("{")
    lines.extend(f" {chunk_fn}(registry, arch);" for chunk_fn in chunk_fns)
    lines.append("}")
    lines.append("")
    # Convenience overload
    lines.append(f"void {register_fn_name}(const std::string& arch)")
    lines.append("{")
    lines.append(" auto& registry = GroupedConvRegistry::instance();")
    lines.append(f" {register_fn_name}(registry, arch);")
    lines.append("}")
    lines.append("")
    lines.append("} // namespace dispatcher")
    lines.append("} // namespace ck_tile")
    lines.append("")
    all_cpp.write_text("\n".join(lines), encoding="utf-8")
    generated_files.append(all_cpp)

    print(f"Generated {len(chunks)} chunk files + register_all ({len(headers)} total registrations)")
    return generated_files