mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
[CK_TILE] Integrate CK Tile Dispatcher code generation into CK Tile Profiler (#7284) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation CK Tile is going to be delivered to hipDNN via CK Dispatcher. Currently the CK Tile Profiler uses CK Builder to generate the profiled instances from the configuration files that identify the instances that old CK exposes. We need to replace this instance generation with the CK Tile Dispatcher codegen. ## Technical Details The old CK Profiler config files are converted to JSON files that the CK Tile Dispatcher can digest. The conversion script for the configurations is committed to source control in case we need to update the JSON configurations later. The dispatcher generates instance libraries per conv direction (fwd, bwd data, and bwd weight) that are linked to the CK Profiler executable. I also implemented codegen for the stream-K and depthwise conv instances. The proposed solution replaces the CK Builder codegen with the CK Tile Dispatcher codegen. There are two new methods exposed via the dispatcher backend: - `is_supported` - required to enable the profiler workflow, where we check the applicability of the kernel instance before running it. - `get_instance_string` - this is mainly for verification. It provides the CK Builder instance string for verifying that the old CK Builder-based profiler and the new CK Tile Dispatcher-based profiler have the same instances. The rules that limit the generated instances are now collected in a single location under the dispatcher. The CK Builder codegen uses these, which ensures that the two codegen pipelines are in sync. The next step (a different PR) is to remove the CK Builder codegen pipeline altogether. ## Test Plan Verified that the old CK Builder-based profiler and the new CK Tile Dispatcher-based profiler have the same instances, that is, the Dispatcher-based codegen can generate the same instances as the old CK Builder.
## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
399 lines
15 KiB
Python
399 lines
15 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
#
|
|
# Shared registration code generation utilities for dispatcher kernel scripts.
|
|
# Generates chunked registration .cpp files for parallel compilation.
|
|
|
|
import re
|
|
from pathlib import Path
|
|
|
|
|
|
# Number of kernels per registration chunk. Used as the default
# ``chunk_size`` of generate_chunked_registration().
#
# Each chunk is one .cpp file / compilation unit. Smaller values yield more,
# smaller translation units that can be compiled in parallel. Larger values
# yield fewer, larger ones. The value only changes how registrations are
# grouped into files, not which kernels get registered.
#
# TODO: What is the optimal value?
CHUNK_SIZE = 25
|
|
|
|
|
|
def parse_depthwise_kernel_metadata(kname):
    """Decode a depthwise kernel header filename stem into a metadata dict.

    Depthwise stems look like:
        grouped_conv_fwd_depthwise_fp16_ngchw_2d_8x8_f3_s1x1_p1x1_nb8_sub2x2_vec2_2

    Token -> field mapping (the first match in the name wins):
        _3d_                               ndim 3 (otherwise 2)
        _fp16_ / _bf16_ / _fp32_           dtype (default "fp16")
        _ngchw_ / _ngcdhw_ / _nhwgc_ / _ndhwgc_
                                           layout (default "ngchw")
        _<H>x<W>_f                         tile_h, tile_w
        _f<F>_                             filt
        _s<H>x<W>_                         str_h, str_w
        _p<H>x<W>_                         pad_h, pad_w
        _nb<N>_                            nbatch
        _sub<H>x<W>_                       sub_h, sub_w
        _vec<I>_<O> at the very end        in_vec, out_vec (default 1, 1)

    Numeric fields whose token is absent are reported as 0.

    Args:
        kname: kernel header filename stem.

    Returns:
        dict with "is_depthwise": True plus the fields listed above.
    """

    def ints(pattern, fallback):
        # Captured groups of the first match converted to int, or `fallback`
        # when the token does not occur in the name.
        hit = re.search(pattern, kname)
        return tuple(int(g) for g in hit.groups()) if hit else fallback

    def token(candidates, fallback):
        # First candidate that appears as an underscore-delimited token.
        return next((c for c in candidates if f"_{c}_" in kname), fallback)

    tile_h, tile_w = ints(r"_(\d+)x(\d+)_f", (0, 0))
    (filt,) = ints(r"_f(\d+)_", (0,))
    str_h, str_w = ints(r"_s(\d+)x(\d+)_", (0, 0))
    pad_h, pad_w = ints(r"_p(\d+)x(\d+)_", (0, 0))
    (nbatch,) = ints(r"_nb(\d+)_", (0,))
    sub_h, sub_w = ints(r"_sub(\d+)x(\d+)_", (0, 0))
    in_vec, out_vec = ints(r"_vec(\d+)_(\d+)$", (1, 1))

    return {
        "is_depthwise": True,
        "ndim": 3 if "_3d_" in kname else 2,
        "dtype": token(("fp16", "bf16", "fp32"), "fp16"),
        "layout": token(("ngchw", "ngcdhw", "nhwgc", "ndhwgc"), "ngchw"),
        "tile_h": tile_h, "tile_w": tile_w,
        "filt": filt,
        "str_h": str_h, "str_w": str_w,
        "pad_h": pad_h, "pad_w": pad_w,
        "nbatch": nbatch,
        "sub_h": sub_h, "sub_w": sub_w,
        "in_vec": in_vec, "out_vec": out_vec,
    }
|
|
|
|
|
|
def parse_kernel_metadata(kname):
    """Decode a kernel header filename stem into a registration metadata dict.

    Stems containing "_depthwise_" are handed to
    parse_depthwise_kernel_metadata. For every other (implicit-GEMM) stem the
    fields are recovered from tokens in the name, each with a fallback:

        _3d_                          ndim 3 (otherwise 2)
        _fp16_ / _bf16_ / _fp32_      dtype (default "fp16")
        _MxNxK triplets               block tile, wave layout, warp tile, in
                                      that order (defaults 128x128x32, 2x2x1,
                                      32x32x16); extra triplets are ignored
        _compv1_.._compv6_ / _mem_    pipeline (default "compv3")
        "interwave" anywhere          scheduler (default "intrawave")
        _vecA_B_C                     vec_a/b/c (default 4, 8, 8)
        _gmN                          num_groups_to_merge (default 1)
        filter1x1_stride1_pad0 / filter1x1_pad0 / filter3x3
                                      specialization (default "default")
        _large_tensor                 large_tensor flag
        _streamk_                     Stream-K enabled; _streamk_tree /
                                      _streamk_linear pick the reduction
                                      (default "none"); a trailing
                                      "_persistent" sets streamk_persistent

    The epilogue is always "cshuffle", and block_per_cu and num_wave_groups
    are always 1. The layout is always reported as "nhwgc".
    NOTE(review): that includes 3-D kernels. Confirm this is intended.

    Args:
        kname: kernel header filename stem.

    Returns:
        metadata dict consumed by the registration key builders.
    """
    if "_depthwise_" in kname:
        return parse_depthwise_kernel_metadata(kname)

    def first_token(candidates, fallback):
        # First candidate present as an underscore-delimited token.
        return next((c for c in candidates if f"_{c}_" in kname), fallback)

    # Slots: block tile, wave layout, warp tile. Each found triplet overrides
    # the default in its slot.
    shapes = [(128, 128, 32), (2, 2, 1), (32, 32, 16)]
    for slot, triple in enumerate(re.findall(r"_(\d+)x(\d+)x(\d+)", kname)[:3]):
        shapes[slot] = tuple(int(v) for v in triple)
    (tile_m, tile_n, tile_k), (wave_m, wave_n, wave_k), (warp_m, warp_n, warp_k) = shapes

    vec_found = re.search(r"_vec(\d+)_(\d+)_(\d+)", kname)
    vec_a, vec_b, vec_c = tuple(map(int, vec_found.groups())) if vec_found else (4, 8, 8)

    gm_found = re.search(r"_gm(\d+)", kname)

    # Most specific filter specialization first.
    specialization = next(
        (s for s in ("filter1x1_stride1_pad0", "filter1x1_pad0", "filter3x3") if s in kname),
        "default",
    )

    streamk_enabled = "_streamk_" in kname
    streamk_reduction = "none"
    streamk_persistent = False
    if streamk_enabled:
        streamk_reduction = next(
            (mode for mode in ("tree", "linear") if f"_streamk_{mode}" in kname),
            "none",
        )
        streamk_persistent = kname.endswith("_persistent")

    return {
        "ndim": 3 if "_3d_" in kname else 2,
        "dtype": first_token(("fp16", "bf16", "fp32"), "fp16"),
        "layout": "nhwgc",
        "tile_m": tile_m, "tile_n": tile_n, "tile_k": tile_k,
        "wave_m": wave_m, "wave_n": wave_n, "wave_k": wave_k,
        "warp_m": warp_m, "warp_n": warp_n, "warp_k": warp_k,
        "pipeline": first_token(
            ("compv1", "compv2", "compv3", "compv4", "compv5", "compv6", "mem"),
            "compv3",
        ),
        "scheduler": "interwave" if "interwave" in kname else "intrawave",
        "epilogue": "cshuffle",
        "vec_a": vec_a, "vec_b": vec_b, "vec_c": vec_c,
        "block_per_cu": 1, "num_wave_groups": 1,
        "num_groups_to_merge": int(gm_found.group(1)) if gm_found else 1,
        "specialization": specialization,
        "large_tensor": "_large_tensor" in kname,
        "streamk_enabled": streamk_enabled,
        "streamk_reduction": streamk_reduction,
        "streamk_persistent": streamk_persistent,
    }
|
|
|
|
|
|
def _make_depthwise_conv_key(meta):
|
|
"""Generate C++ key assignment lines for a depthwise conv kernel."""
|
|
# Map depthwise parameters into the GEMM key fields to ensure each
|
|
# depthwise kernel gets a unique registry key.
|
|
return [
|
|
f' key.dtype_in = "{meta["dtype"]}";',
|
|
f' key.dtype_wei = "{meta["dtype"]}";',
|
|
f' key.dtype_out = "{meta["dtype"]}";',
|
|
f' key.layout = "{meta["layout"]}";',
|
|
f" key.ndim_spatial = {meta['ndim']};",
|
|
f" // Depthwise params encoded: tile_h/w -> tile_m/n, filt -> tile_k",
|
|
f" key.tile_m = {meta['tile_h']};",
|
|
f" key.tile_n = {meta['tile_w']};",
|
|
f" key.tile_k = {meta['filt']};",
|
|
f" // stride_h/w -> wave_m/n, nbatch -> warp_k",
|
|
f" key.wave_m = {meta['str_h']};",
|
|
f" key.wave_n = {meta['str_w']};",
|
|
f" key.wave_k = 0;",
|
|
f" // pad_h/w -> warp_m/n",
|
|
f" key.warp_m = {meta['pad_h']};",
|
|
f" key.warp_n = {meta['pad_w']};",
|
|
f" key.warp_k = {meta['nbatch']};",
|
|
f' key.pipeline = "depthwise";',
|
|
f' key.scheduler = "none";',
|
|
f' key.epilogue = "none";',
|
|
f" key.vector_size_a = {meta['in_vec']};",
|
|
f" key.vector_size_b = {meta['in_vec']};",
|
|
f" key.vector_size_c = {meta['out_vec']};",
|
|
f" key.block_per_cu = 2;",
|
|
f" key.num_wave_groups = 1;",
|
|
f" key.num_groups_to_merge = 1;",
|
|
f' key.specialization = "sub{meta["sub_h"]}x{meta["sub_w"]}";',
|
|
]
|
|
|
|
|
|
def _make_implicit_gemm_conv_key(meta):
|
|
"""Generate C++ key assignment lines for an implicit GEMM-based conv kernel."""
|
|
return [
|
|
f' key.dtype_in = "{meta["dtype"]}";',
|
|
f' key.dtype_wei = "{meta["dtype"]}";',
|
|
f' key.dtype_out = "{meta["dtype"]}";',
|
|
f' key.layout = "{meta.get("layout", "nhwgc")}";',
|
|
f" key.ndim_spatial = {meta['ndim']};",
|
|
f" key.tile_m = {meta['tile_m']};",
|
|
f" key.tile_n = {meta['tile_n']};",
|
|
f" key.tile_k = {meta['tile_k']};",
|
|
f" key.wave_m = {meta['wave_m']};",
|
|
f" key.wave_n = {meta['wave_n']};",
|
|
f" key.wave_k = {meta['wave_k']};",
|
|
f" key.warp_m = {meta['warp_m']};",
|
|
f" key.warp_n = {meta['warp_n']};",
|
|
f" key.warp_k = {meta['warp_k']};",
|
|
f' key.pipeline = "{meta["pipeline"]}";',
|
|
f' key.scheduler = "{meta["scheduler"]}";',
|
|
f' key.epilogue = "{meta["epilogue"]}";',
|
|
f" key.vector_size_a = {meta['vec_a']};",
|
|
f" key.vector_size_b = {meta['vec_b']};",
|
|
f" key.vector_size_c = {meta['vec_c']};",
|
|
f" key.block_per_cu = {meta['block_per_cu']};",
|
|
f" key.num_wave_groups = {meta['num_wave_groups']};",
|
|
f" key.num_groups_to_merge = {meta['num_groups_to_merge']};",
|
|
f' key.specialization = "{meta["specialization"]}";',
|
|
f' key.large_tensor = {str(meta["large_tensor"]).lower()};',
|
|
f' key.streamk_enabled = {str(meta["streamk_enabled"]).lower()};',
|
|
f' key.streamk_reduction = "{meta["streamk_reduction"]}";',
|
|
f' key.streamk_persistent = {str(meta["streamk_persistent"]).lower()};',
|
|
]
|
|
|
|
|
|
def make_registration_block(kname, global_idx, op_enum, run_fn_maker, is_supported_fn_maker):
    """Return the C++ source lines that register a single kernel.

    The emitted braced scope:
      1. fills a GroupedConvKernelKey from parse_kernel_metadata(kname), using
         the depthwise or the implicit-GEMM field mapping;
      2. builds the run and is-supported callables via
         ``<run_fn_maker><Launcher, ndim>()`` and
         ``<is_supported_fn_maker><Launcher, ndim>()``;
      3. wraps them in a GroupedConvKernelInstance. The CK Builder instance
         string is passed in addition when CK_EXPERIMENTAL_BUILDER is defined;
      4. registers the instance under the key.

    Args:
        kname: kernel header filename stem. It also names the generated
            namespace ``ns_<kname>`` and the launcher ``<kname>_Launcher``.
        global_idx: running kernel index, used only in the leading comment.
        op_enum: C++ expression assigned to ``key.op``.
        run_fn_maker: C++ template function producing the run callable.
        is_supported_fn_maker: C++ template function producing the
            applicability-check callable.

    Returns:
        list of C++ source lines.
    """
    meta = parse_kernel_metadata(kname)
    launcher = f"ns_{kname}::{kname}_Launcher"
    ndim = meta["ndim"]

    if meta.get("is_depthwise", False):
        key_lines = _make_depthwise_conv_key(meta)
    else:
        key_lines = _make_implicit_gemm_conv_key(meta)

    # Both preprocessor branches share this constructor-call prefix. Only the
    # trailing arguments differ.
    make_inst = (
        f' auto inst = std::make_shared<GroupedConvKernelInstance>(key, "{kname}",'
        f" std::move(run_fn), std::move(is_supported_fn)"
    )

    return [
        f" // Kernel {global_idx}: {kname}",
        " {",
        " GroupedConvKernelKey key;",
        f" key.op = {op_enum};",
        *key_lines,
        " key.arch = arch;",
        f" auto run_fn = {run_fn_maker}<{launcher}, {ndim}>();",
        f" auto is_supported_fn = {is_supported_fn_maker}<{launcher}, {ndim}>();",
        "#ifdef CK_EXPERIMENTAL_BUILDER",
        f" auto instance_str = backends::get_instance_string<{launcher}>();",
        make_inst + ", instance_str);",
        "#else",
        make_inst + ");",
        "#endif",
        " registry.register_kernel(key, inst);",
        " }",
    ]
|
|
|
|
|
|
def _chunk_fn_name(variant, chunk_idx):
    """Return the C++ function (and file stem) name for chunk *chunk_idx* of *variant*."""
    return f"register_{variant}_chunk_{chunk_idx}"


def _render_chunk_source(chunk_headers, chunk_idx, global_offset, variant, op_enum,
                         run_fn_maker, is_supported_fn_maker):
    """Return the text of one chunk .cpp file that registers *chunk_headers*."""
    lines = [
        "// Auto-generated — do not edit",
        f"// Registration chunk {chunk_idx} for {variant} kernels ({len(chunk_headers)} kernels).",
        "",
        "#pragma clang diagnostic push",
        '#pragma clang diagnostic ignored "-Wheader-hygiene"',
        '#pragma clang diagnostic ignored "-Wunused-parameter"',
    ]
    # Include only the headers of this chunk, so each translation unit stays small.
    lines.extend(f'#include "{h.name}"' for h in chunk_headers)
    lines += [
        "#pragma clang diagnostic pop",
        "",
        '#include "ck_tile/dispatcher/grouped_conv_registry.hpp"',
        '#include "ck_tile/dispatcher/backends/generated_conv_backend.hpp"',
        "",
        "namespace ck_tile {",
        "namespace dispatcher {",
        "",
        f"void {_chunk_fn_name(variant, chunk_idx)}(GroupedConvRegistry& registry, const std::string& arch)",
        "{",
    ]
    for i, h in enumerate(chunk_headers):
        # Kernel indices keep counting across chunks.
        lines.extend(
            make_registration_block(
                h.stem, global_offset + i, op_enum, run_fn_maker, is_supported_fn_maker
            )
        )
    lines += [
        "}",
        "",
        "} // namespace dispatcher",
        "} // namespace ck_tile",
        "",
    ]
    return "\n".join(lines)


def _render_register_all_source(num_chunks, variant, register_fn_name):
    """Return the text of the .cpp file whose entry points call every chunk function."""
    chunk_fns = [_chunk_fn_name(variant, idx) for idx in range(num_chunks)]
    lines = [
        "// Auto-generated — do not edit",
        f"// Top-level registration that calls {num_chunks} chunk functions.",
        "",
        '#include "ck_tile/dispatcher/grouped_conv_registry.hpp"',
        "",
        "namespace ck_tile {",
        "namespace dispatcher {",
        "",
    ]
    # Forward-declare the chunk functions defined in the per-chunk .cpp files.
    lines.extend(
        f"void {fn}(GroupedConvRegistry& registry, const std::string& arch);" for fn in chunk_fns
    )
    lines.append("")

    # Main registration function (explicit registry parameter).
    lines.append(f"void {register_fn_name}(")
    lines.append(" GroupedConvRegistry& registry, const std::string& arch)")
    lines.append("{")
    lines.extend(f" {fn}(registry, arch);" for fn in chunk_fns)
    lines.append("}")
    lines.append("")

    # Convenience overload that targets the singleton registry.
    lines += [
        f"void {register_fn_name}(const std::string& arch)",
        "{",
        " auto& registry = GroupedConvRegistry::instance();",
        f" {register_fn_name}(registry, arch);",
        "}",
        "",
        "} // namespace dispatcher",
        "} // namespace ck_tile",
        "",
    ]
    return "\n".join(lines)


def generate_chunked_registration(headers, output_dir, variant, op_enum,
                                  run_fn_maker, is_supported_fn_maker,
                                  register_fn_name, chunk_size=CHUNK_SIZE):
    """Generate chunked registration .cpp files for parallel compilation.

    The headers are split into consecutive groups of at most ``chunk_size``.
    Each group becomes ``register_<variant>_chunk_<i>.cpp``, which includes
    only its own headers and defines
    ``register_<variant>_chunk_<i>(GroupedConvRegistry&, const std::string&)``.
    A final ``register_all_grouped_conv_kernels.cpp`` defines
    ``register_fn_name(registry, arch)``, which calls every chunk function, and
    a ``register_fn_name(arch)`` overload that uses
    ``GroupedConvRegistry::instance()``. All files are written as UTF-8, and
    ``output_dir`` is created if it does not exist.

    Args:
        headers: list of Path objects for kernel .hpp files
        output_dir: directory to write generated files
        variant: short name like "bwd_weight", "fwd", "bwd_data"
        op_enum: C++ enum value like "GroupedConvOp::BackwardWeight"
        run_fn_maker: C++ template function like "backends::make_conv_bwd_weight_run_fn"
        is_supported_fn_maker: C++ template function like "backends::make_conv_bwd_weight_is_supported_fn"
        register_fn_name: C++ function name like "register_all_grouped_conv_bwd_weight_kernels"
        chunk_size: number of kernels per chunk file (must be >= 1)

    Returns:
        list of generated .cpp file paths: the chunk files in order, followed
        by the register_all file.

    Raises:
        ValueError: if ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        # range() raises on a zero step. A negative step silently yields no
        # chunks, so every kernel would be dropped from registration.
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size!r}")

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    generated_files = []

    chunks = [headers[i:i + chunk_size] for i in range(0, len(headers), chunk_size)]

    for chunk_idx, chunk_headers in enumerate(chunks):
        chunk_cpp = output_dir / f"{_chunk_fn_name(variant, chunk_idx)}.cpp"
        source = _render_chunk_source(
            chunk_headers, chunk_idx, chunk_idx * chunk_size, variant, op_enum,
            run_fn_maker, is_supported_fn_maker,
        )
        # Explicit UTF-8: the banner contains a non-ASCII em dash, and the
        # platform's default encoding is not guaranteed to be UTF-8.
        chunk_cpp.write_text(source, encoding="utf-8")
        generated_files.append(chunk_cpp)

    # NOTE(review): this file name does not include `variant`. If several
    # variants are generated into the same output_dir, they overwrite each
    # other's register_all file. Confirm that callers use separate directories.
    all_cpp = output_dir / "register_all_grouped_conv_kernels.cpp"
    all_cpp.write_text(
        _render_register_all_source(len(chunks), variant, register_fn_name),
        encoding="utf-8",
    )
    generated_files.append(all_cpp)

    print(f"Generated {len(chunks)} chunk files + register_all ({len(headers)} total registrations)")
    return generated_files
|