mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
[rocm-libraries] ROCm/rocm-libraries#5168 (commit 8b5afcb)
[CK] [CK_Tile] Add GroupConv to Kernel Dispatcher ## Motivation This PR adds CK Tile group convolution (forward, backward-data, backward-weight) support to the kernel dispatcher, matching and unifying with the existing dispatcher GEMM infrastructure in architecture and usability. The dispatcher provides a unified kernel dispatch system with both C++ and Python frontends, and until now only supported GEMM operations. This PR enables framework integrators to use the same declarative kernel workflow for convolutions as they do for GEMM: declare kernels, build a registry JIT, select kernels within the registry at runtime, and dispatch to GPU. Future PRs will include runtime kernel selection heuristics for autotuning of kernel parameters based on (problem, hardware arch). ## Technical Details Grouped convolution support has been added to the CK Tile Dispatcher with generated_conv_backend.hpp enabling dispatcher.run(in, wei, out, problem) for all 6 conv variants (fwd/bwdd/bwdw x 2D/3D), runtime heuristic kernel selection, and GroupedConvKernelKey with full ConvConfigBase fields. Python side adds parallel JIT via registry.build(max_workers) and heuristic registry.select(). Includes 7 C++ and 6 Python examples covering all directions with CPU reference validation, and shared infrastructure improvements (BaseRegistry CRTP, structured exceptions). As a sanity check, JIT compile times for a single kernel remains the same and for multiple kernels there is better parallelism: Kernels | 1 worker | 8 workers 1 | 7.7 s | 7.7 s 2 | 15.9 s | 8.2 s 4 | 33.4 s | 9.7 s 6 | 52.3 s | 10.2 s ## Test Plan 145 ephemeral unit tests have been added to test basic functionality. All 30 examples/integration tests run end-to-end on gfx950 (MI350): 7 C++ conv, 7 C++ GEMM, 6 Python conv, 10 Python GEMM. CPU reference validation for forward, backward-data, and backward-weight (2D) in both C++ and Python examples pass. ## Test Result 30 examples pass. Peak performance: 132 TFLOPS (Batch-32 forward 56x56), 53 TFLOPS (pointwise 1x1). CPU reference accuracy: max_abs_diff < 0.002 for all directions (fp16 vs fp32 reference). ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
4c0e73ab12
commit
920acd2c12
@@ -55,10 +55,10 @@ def extract_balanced_parens(text: str, start_pos: int) -> str:
|
||||
|
||||
|
||||
def parse_conv_declarations(content: str) -> List[Dict]:
|
||||
"""Parse DECL_CONV_KERNEL_SET declarations with all parameters."""
|
||||
"""Parse DECL_GROUPED_CONV_KERNEL_SET declarations with all parameters."""
|
||||
kernels = []
|
||||
|
||||
for match in re.finditer(r"DECL_CONV_KERNEL_SET\s*\(", content):
|
||||
for match in re.finditer(r"DECL_GROUPED_CONV_KERNEL_SET\s*\(", content):
|
||||
body = extract_balanced_parens(content, match.end() - 1)
|
||||
if not body:
|
||||
continue
|
||||
@@ -619,7 +619,7 @@ def strip_cpp_strings_and_comments(content: str) -> str:
|
||||
n = len(content)
|
||||
|
||||
# Patterns that indicate a string is problematic and should be stripped
|
||||
problematic_patterns = ["DECL_KERNEL_SET", "DECL_CONV_KERNEL_SET", ".add("]
|
||||
problematic_patterns = ["DECL_KERNEL_SET", "DECL_GROUPED_CONV_KERNEL_SET", ".add("]
|
||||
|
||||
while i < n:
|
||||
# Check for raw string literal: R"delimiter(...)delimiter"
|
||||
@@ -697,7 +697,7 @@ def detect_and_parse(source_path: Path) -> Tuple[str, List[Dict]]:
|
||||
content = source_path.read_text()
|
||||
content = strip_cpp_strings_and_comments(content)
|
||||
|
||||
if "DECL_CONV_KERNEL_SET" in content:
|
||||
if "DECL_GROUPED_CONV_KERNEL_SET" in content:
|
||||
return "conv", parse_conv_declarations(content)
|
||||
elif "DECL_KERNEL_SET" in content:
|
||||
return "gemm", parse_gemm_declarations(content)
|
||||
@@ -966,30 +966,128 @@ def generate_per_set_functions(source_stem: str) -> str:
|
||||
def generate_conv_registration(
|
||||
kernel_headers: List[Path], example_name: str, kernels: List[Dict]
|
||||
) -> str:
|
||||
"""Generate Conv kernel registration code for the dispatcher registry."""
|
||||
"""Generate Conv kernel registration code for the dispatcher registry.
|
||||
|
||||
Creates real GroupedConvKernelInstance entries backed by the generated
|
||||
launcher's launch() method via the conv backend RunFn factories.
|
||||
"""
|
||||
if not kernel_headers:
|
||||
return " // No kernels to register"
|
||||
|
||||
lines = []
|
||||
lines.append(
|
||||
" (void)registry; (void)arch; // Conv uses direct launcher pattern for now"
|
||||
)
|
||||
|
||||
# For conv, we provide direct access to kernel launchers
|
||||
for i, h in enumerate(kernel_headers):
|
||||
kernel_name = h.stem
|
||||
lines.append(f" // Kernel {i + 1}: {kernel_name}")
|
||||
kname = h.stem
|
||||
ns = f"ns_{kname}"
|
||||
launcher = f"{ns}::{kname}_Launcher"
|
||||
|
||||
# Determine direction and ndim from the kernel header name
|
||||
if "_fwd_" in kname:
|
||||
direction = "Forward"
|
||||
run_fn_factory = "make_conv_fwd_run_fn"
|
||||
elif "_bwd_data_" in kname or "_bwdd_" in kname:
|
||||
direction = "BackwardData"
|
||||
run_fn_factory = "make_conv_bwd_data_run_fn"
|
||||
elif "_bwd_weight_" in kname or "_bwdw_" in kname:
|
||||
direction = "BackwardWeight"
|
||||
run_fn_factory = "make_conv_bwd_weight_run_fn"
|
||||
else:
|
||||
direction = "Forward"
|
||||
run_fn_factory = "make_conv_fwd_run_fn"
|
||||
|
||||
ndim = 3 if "_3d_" in kname else 2
|
||||
|
||||
# Parse dtype from name (e.g. grouped_conv_fwd_fp16_...)
|
||||
dtype = "fp16"
|
||||
for dt in ["fp16", "bf16", "fp32"]:
|
||||
if f"_{dt}_" in kname:
|
||||
dtype = dt
|
||||
break
|
||||
|
||||
# Parse tile, wave, warp from name.
|
||||
# Format: ..._TILExTILExTILE_WAVExWAVExWAVE_WARPxWARPxWARP_...
|
||||
import re as _re
|
||||
|
||||
tile_m, tile_n, tile_k = 1, 128, 128
|
||||
wave_m, wave_n, wave_k = 2, 2, 1
|
||||
warp_m, warp_n, warp_k = 32, 32, 16
|
||||
|
||||
triplets = _re.findall(r"_(\d+)x(\d+)x(\d+)", kname)
|
||||
if len(triplets) >= 1:
|
||||
tile_m, tile_n, tile_k = (
|
||||
int(triplets[0][0]),
|
||||
int(triplets[0][1]),
|
||||
int(triplets[0][2]),
|
||||
)
|
||||
if len(triplets) >= 2:
|
||||
wave_m, wave_n, wave_k = (
|
||||
int(triplets[1][0]),
|
||||
int(triplets[1][1]),
|
||||
int(triplets[1][2]),
|
||||
)
|
||||
if len(triplets) >= 3:
|
||||
warp_m, warp_n, warp_k = (
|
||||
int(triplets[2][0]),
|
||||
int(triplets[2][1]),
|
||||
int(triplets[2][2]),
|
||||
)
|
||||
|
||||
pipeline = "compv4" if "compv4" in kname else "compv3"
|
||||
scheduler = "interwave" if "interwave" in kname else "intrawave"
|
||||
epilogue = "cshuffle" if "cshuffle" in kname else "default"
|
||||
|
||||
# ConvConfigBase defaults
|
||||
vec_a, vec_b, vec_c = 4, 8, 8
|
||||
block_per_cu = 1
|
||||
num_wave_groups = 1
|
||||
num_groups_to_merge = 1
|
||||
|
||||
lines.append(f" // Kernel {i + 1}: {kname}")
|
||||
lines.append(" {")
|
||||
lines.append(f" ck_tile::dispatcher::GroupedConvKernelKey key_{i};")
|
||||
lines.append(f' key_{i}.dtype_in = "{dtype}";')
|
||||
lines.append(f' key_{i}.dtype_wei = "{dtype}";')
|
||||
lines.append(f' key_{i}.dtype_out = "{dtype}";')
|
||||
lines.append(f' key_{i}.layout = "nhwgc";')
|
||||
lines.append(f" key_{i}.ndim_spatial = {ndim};")
|
||||
lines.append(
|
||||
f" key_{i}.op = ck_tile::dispatcher::GroupedConvOp::{direction};"
|
||||
)
|
||||
lines.append(f" key_{i}.tile_m = {tile_m};")
|
||||
lines.append(f" key_{i}.tile_n = {tile_n};")
|
||||
lines.append(f" key_{i}.tile_k = {tile_k};")
|
||||
lines.append(f" key_{i}.wave_m = {wave_m};")
|
||||
lines.append(f" key_{i}.wave_n = {wave_n};")
|
||||
lines.append(f" key_{i}.wave_k = {wave_k};")
|
||||
lines.append(f" key_{i}.warp_m = {warp_m};")
|
||||
lines.append(f" key_{i}.warp_n = {warp_n};")
|
||||
lines.append(f" key_{i}.warp_k = {warp_k};")
|
||||
lines.append(f' key_{i}.pipeline = "{pipeline}";')
|
||||
lines.append(f' key_{i}.scheduler = "{scheduler}";')
|
||||
lines.append(f' key_{i}.epilogue = "{epilogue}";')
|
||||
lines.append(f" key_{i}.vector_size_a = {vec_a};")
|
||||
lines.append(f" key_{i}.vector_size_b = {vec_b};")
|
||||
lines.append(f" key_{i}.vector_size_c = {vec_c};")
|
||||
lines.append(f" key_{i}.block_per_cu = {block_per_cu};")
|
||||
lines.append(f" key_{i}.num_wave_groups = {num_wave_groups};")
|
||||
lines.append(f" key_{i}.num_groups_to_merge = {num_groups_to_merge};")
|
||||
lines.append(f" key_{i}.arch = arch;")
|
||||
lines.append(
|
||||
f" auto run_fn_{i} = ck_tile::dispatcher::backends::{run_fn_factory}<{launcher}, {ndim}>();"
|
||||
)
|
||||
lines.append(
|
||||
f' auto inst_{i} = std::make_shared<ck_tile::dispatcher::GroupedConvKernelInstance>(key_{i}, "{kname}", std::move(run_fn_{i}));'
|
||||
)
|
||||
lines.append(f" registry.register_kernel(key_{i}, inst_{i});")
|
||||
lines.append(" }")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def generate_conv_kernels(
|
||||
kernels: List[Dict], output_dir: Path, codegen_dir: Path
|
||||
) -> bool:
|
||||
"""Generate Conv kernels for ALL declarations using unified codegen."""
|
||||
if not kernels:
|
||||
return False
|
||||
|
||||
def _build_conv_codegen_cmd(
|
||||
idx: int, k: Dict, codegen_dir: Path, output_dir: Path
|
||||
) -> Tuple[int, List[str], str]:
|
||||
"""Build the command for a single conv kernel codegen invocation."""
|
||||
variant_map = {
|
||||
"forward": "forward",
|
||||
"bwd_data": "bwd_data",
|
||||
@@ -997,93 +1095,130 @@ def generate_conv_kernels(
|
||||
"bwd_weight": "bwd_weight",
|
||||
"backward_weight": "bwd_weight",
|
||||
}
|
||||
variant = variant_map.get(k.get("conv_type", "forward"), "forward")
|
||||
|
||||
cmd = [
|
||||
sys.executable,
|
||||
str(codegen_dir / "unified_grouped_conv_codegen.py"),
|
||||
"--datatype",
|
||||
k.get("dtype", "fp16"),
|
||||
"--variant",
|
||||
variant,
|
||||
"--ndim",
|
||||
str(k.get("ndim", 2)),
|
||||
"--output",
|
||||
str(output_dir),
|
||||
]
|
||||
|
||||
if k.get("tile_m"):
|
||||
cmd.extend(["--tile-m", str(k["tile_m"])])
|
||||
if k.get("tile_n"):
|
||||
cmd.extend(["--tile-n", str(k["tile_n"])])
|
||||
if k.get("warp_m"):
|
||||
cmd.extend(["--warp-m", str(k["warp_m"])])
|
||||
if k.get("warp_n"):
|
||||
cmd.extend(["--warp-n", str(k["warp_n"])])
|
||||
if k.get("warp_k"):
|
||||
cmd.extend(["--warp-k", str(k["warp_k"])])
|
||||
if k.get("warp_tile_m"):
|
||||
cmd.extend(["--warp-tile-m", str(k["warp_tile_m"])])
|
||||
if k.get("warp_tile_n"):
|
||||
cmd.extend(["--warp-tile-n", str(k["warp_tile_n"])])
|
||||
if k.get("warp_tile_k"):
|
||||
cmd.extend(["--warp-tile-k", str(k["warp_tile_k"])])
|
||||
if k.get("pipeline"):
|
||||
cmd.extend(["--pipeline", k["pipeline"]])
|
||||
if k.get("scheduler"):
|
||||
cmd.extend(["--scheduler", k["scheduler"]])
|
||||
if k.get("epilogue"):
|
||||
cmd.extend(["--epilogue", k["epilogue"]])
|
||||
if k.get("vector_a"):
|
||||
cmd.extend(["--vector-a", str(k["vector_a"])])
|
||||
if k.get("vector_b"):
|
||||
cmd.extend(["--vector-b", str(k["vector_b"])])
|
||||
if k.get("vector_c"):
|
||||
cmd.extend(["--vector-c", str(k["vector_c"])])
|
||||
if k.get("block_per_cu"):
|
||||
cmd.extend(["--block-per-cu", str(k["block_per_cu"])])
|
||||
if k.get("num_wave_groups"):
|
||||
cmd.extend(["--num-wave-groups", str(k["num_wave_groups"])])
|
||||
if k.get("num_groups_to_merge"):
|
||||
cmd.extend(["--num-groups-to-merge", str(k["num_groups_to_merge"])])
|
||||
if k.get("double_smem_buffer") is not None:
|
||||
cmd.extend(["--double-smem-buffer", str(k["double_smem_buffer"]).lower()])
|
||||
if k.get("tile_k"):
|
||||
cmd.extend(["--tile-k", str(k["tile_k"])])
|
||||
|
||||
return (idx, cmd, str(codegen_dir))
|
||||
|
||||
|
||||
def _run_conv_codegen(args: Tuple) -> Tuple[int, bool, str]:
|
||||
"""Run unified_grouped_conv_codegen.py for a single kernel config (picklable for ProcessPoolExecutor)."""
|
||||
idx, cmd, cwd = args
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, cwd=cwd)
|
||||
if result.returncode != 0:
|
||||
return (idx, False, result.stderr[:300])
|
||||
return (idx, True, "")
|
||||
|
||||
|
||||
def generate_conv_kernels(
|
||||
kernels: List[Dict], output_dir: Path, codegen_dir: Path
|
||||
) -> bool:
|
||||
"""Generate Conv kernels for ALL declarations using unified codegen.
|
||||
|
||||
Launches all codegen subprocesses in parallel via ProcessPoolExecutor
|
||||
for significantly faster generation when multiple conv kernels are declared.
|
||||
"""
|
||||
if not kernels:
|
||||
return False
|
||||
|
||||
work_items = [
|
||||
_build_conv_codegen_cmd(idx, k, codegen_dir, output_dir)
|
||||
for idx, k in enumerate(kernels)
|
||||
]
|
||||
|
||||
success_count = 0
|
||||
max_workers = min(len(work_items), os.cpu_count() or 4)
|
||||
|
||||
# Generate a kernel for EACH declaration
|
||||
for idx, k in enumerate(kernels):
|
||||
variant = variant_map.get(k.get("conv_type", "forward"), "forward")
|
||||
|
||||
cmd = [
|
||||
sys.executable,
|
||||
str(codegen_dir / "unified_conv_codegen.py"),
|
||||
"--datatype",
|
||||
k.get("dtype", "fp16"),
|
||||
"--variant",
|
||||
variant,
|
||||
"--ndim",
|
||||
str(k.get("ndim", 2)),
|
||||
"--output",
|
||||
str(output_dir),
|
||||
]
|
||||
|
||||
# Add optional parameters if specified
|
||||
if k.get("tile_m"):
|
||||
cmd.extend(["--tile-m", str(k["tile_m"])])
|
||||
if k.get("tile_n"):
|
||||
cmd.extend(["--tile-n", str(k["tile_n"])])
|
||||
if k.get("warp_m"):
|
||||
cmd.extend(["--warp-m", str(k["warp_m"])])
|
||||
if k.get("warp_n"):
|
||||
cmd.extend(["--warp-n", str(k["warp_n"])])
|
||||
if k.get("warp_k"):
|
||||
cmd.extend(["--warp-k", str(k["warp_k"])])
|
||||
if k.get("warp_tile_m"):
|
||||
cmd.extend(["--warp-tile-m", str(k["warp_tile_m"])])
|
||||
if k.get("warp_tile_n"):
|
||||
cmd.extend(["--warp-tile-n", str(k["warp_tile_n"])])
|
||||
if k.get("warp_tile_k"):
|
||||
cmd.extend(["--warp-tile-k", str(k["warp_tile_k"])])
|
||||
if k.get("pipeline"):
|
||||
cmd.extend(["--pipeline", k["pipeline"]])
|
||||
if k.get("scheduler"):
|
||||
cmd.extend(["--scheduler", k["scheduler"]])
|
||||
if k.get("epilogue"):
|
||||
cmd.extend(["--epilogue", k["epilogue"]])
|
||||
if k.get("vector_a"):
|
||||
cmd.extend(["--vector-a", str(k["vector_a"])])
|
||||
if k.get("vector_b"):
|
||||
cmd.extend(["--vector-b", str(k["vector_b"])])
|
||||
if k.get("vector_c"):
|
||||
cmd.extend(["--vector-c", str(k["vector_c"])])
|
||||
if k.get("block_per_cu"):
|
||||
cmd.extend(["--block-per-cu", str(k["block_per_cu"])])
|
||||
if k.get("num_wave_groups"):
|
||||
cmd.extend(["--num-wave-groups", str(k["num_wave_groups"])])
|
||||
if k.get("num_groups_to_merge"):
|
||||
cmd.extend(["--num-groups-to-merge", str(k["num_groups_to_merge"])])
|
||||
if k.get("double_smem_buffer") is not None:
|
||||
cmd.extend(["--double-smem-buffer", str(k["double_smem_buffer"]).lower()])
|
||||
if k.get("tile_k"):
|
||||
cmd.extend(["--tile-k", str(k["tile_k"])])
|
||||
|
||||
result = subprocess.run(
|
||||
cmd, capture_output=True, text=True, cwd=str(codegen_dir)
|
||||
)
|
||||
if result.returncode != 0:
|
||||
print(f" Codegen error for kernel {idx + 1}: {result.stderr[:300]}")
|
||||
else:
|
||||
success_count += 1
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = {executor.submit(_run_conv_codegen, w): w[0] for w in work_items}
|
||||
for future in as_completed(futures):
|
||||
idx, ok, err = future.result()
|
||||
if ok:
|
||||
success_count += 1
|
||||
else:
|
||||
print(f" Codegen error for kernel {idx + 1}: {err}")
|
||||
|
||||
return success_count > 0
|
||||
|
||||
|
||||
def _run_gemm_codegen(args: Tuple) -> Tuple[int, bool, str]:
|
||||
"""Run unified_gemm_codegen.py for a single kernel config (picklable for ProcessPoolExecutor)."""
|
||||
idx, cmd, cwd = args
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, cwd=cwd)
|
||||
if result.returncode != 0:
|
||||
return (idx, False, result.stderr[:300])
|
||||
return (idx, True, "")
|
||||
|
||||
|
||||
def generate_gemm_kernels(
|
||||
kernels: List[Dict], output_dir: Path, codegen_dir: Path
|
||||
) -> bool:
|
||||
"""Generate GEMM kernels for ALL declarations using unified codegen."""
|
||||
"""Generate GEMM kernels for ALL declarations using unified codegen.
|
||||
|
||||
Launches all codegen subprocesses in parallel via ProcessPoolExecutor
|
||||
for significantly faster generation when multiple kernels are declared.
|
||||
"""
|
||||
import json
|
||||
|
||||
if not kernels:
|
||||
return False
|
||||
|
||||
success_count = 0
|
||||
|
||||
# Generate a kernel for EACH declaration
|
||||
# Build all commands upfront
|
||||
work_items = []
|
||||
for idx, k in enumerate(kernels):
|
||||
variant = "multi_d" if k.get("elementwise_op") else "standard"
|
||||
|
||||
# Build tile config JSON for this specific kernel
|
||||
tile_config = {
|
||||
"tile_m": [k.get("tile_m", 128)],
|
||||
"tile_n": [k.get("tile_n", 128)],
|
||||
@@ -1125,13 +1260,20 @@ def generate_gemm_kernels(
|
||||
config_json,
|
||||
]
|
||||
|
||||
result = subprocess.run(
|
||||
cmd, capture_output=True, text=True, cwd=str(codegen_dir)
|
||||
)
|
||||
if result.returncode != 0:
|
||||
print(f" Codegen error for kernel {idx + 1}: {result.stderr[:300]}")
|
||||
else:
|
||||
success_count += 1
|
||||
work_items.append((idx, cmd, str(codegen_dir)))
|
||||
|
||||
# Run all codegen subprocesses in parallel
|
||||
success_count = 0
|
||||
max_workers = min(len(work_items), os.cpu_count() or 4)
|
||||
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = {executor.submit(_run_gemm_codegen, w): w[0] for w in work_items}
|
||||
for future in as_completed(futures):
|
||||
idx, ok, err = future.result()
|
||||
if ok:
|
||||
success_count += 1
|
||||
else:
|
||||
print(f" Codegen error for kernel {idx + 1}: {err}")
|
||||
|
||||
return success_count > 0
|
||||
|
||||
@@ -1229,15 +1371,17 @@ def main():
|
||||
if example_type == "gemm":
|
||||
kernel_headers = list(args.output_dir.glob("gemm_*.hpp"))
|
||||
else:
|
||||
k = kernels[0] if kernels else {}
|
||||
variant = k.get("conv_type", "forward")
|
||||
prefix_map = {
|
||||
"forward": "conv_fwd",
|
||||
"bwd_data": "conv_bwdd",
|
||||
"bwd_weight": "conv_bwdw",
|
||||
"forward": "grouped_conv_fwd",
|
||||
"bwd_data": "grouped_conv_bwd_data",
|
||||
"bwd_weight": "grouped_conv_bwd_weight",
|
||||
}
|
||||
prefix = prefix_map.get(variant, "conv_fwd")
|
||||
kernel_headers = list(args.output_dir.glob(f"{prefix}_*.hpp"))
|
||||
# Collect headers from ALL variants present in declarations
|
||||
variants_used = set(k.get("conv_type", "forward") for k in kernels)
|
||||
kernel_headers = []
|
||||
for variant in variants_used:
|
||||
prefix = prefix_map.get(variant, "grouped_conv_fwd")
|
||||
kernel_headers.extend(args.output_dir.glob(f"{prefix}_*.hpp"))
|
||||
|
||||
if not kernel_headers:
|
||||
print(f"[{target_name}] No kernel headers generated!")
|
||||
@@ -1347,29 +1491,39 @@ def main():
|
||||
)
|
||||
|
||||
if has_bwd_data:
|
||||
bwdd_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdd_")
|
||||
if bwdd_kernel:
|
||||
bwdd_ns = f"ns_{bwdd_kernel.stem}"
|
||||
launcher_aliases.append(
|
||||
f"using BwdDataKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;"
|
||||
bwd_data_kernel = find_kernel_by_dtype_type(
|
||||
kernel_headers, "fp16", "_bwd_data_"
|
||||
)
|
||||
if not bwd_data_kernel:
|
||||
bwd_data_kernel = find_kernel_by_dtype_type(
|
||||
kernel_headers, "fp16", "_bwdd_"
|
||||
)
|
||||
if not has_fwd: # If no fwd, use bwd_data as first
|
||||
if bwd_data_kernel:
|
||||
bwd_data_ns = f"ns_{bwd_data_kernel.stem}"
|
||||
launcher_aliases.append(
|
||||
f"using BwdDataKernelLauncher = {bwd_data_ns}::{bwd_data_kernel.stem}_Launcher;"
|
||||
)
|
||||
if not has_fwd:
|
||||
launcher_aliases.append(
|
||||
f"using FirstKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;"
|
||||
f"using FirstKernelLauncher = {bwd_data_ns}::{bwd_data_kernel.stem}_Launcher;"
|
||||
)
|
||||
|
||||
if has_bwd_weight:
|
||||
bwdw_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdw_")
|
||||
if bwdw_kernel:
|
||||
bwdw_ns = f"ns_{bwdw_kernel.stem}"
|
||||
launcher_aliases.append(
|
||||
f"using BwdWeightKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;"
|
||||
bwd_weight_kernel = find_kernel_by_dtype_type(
|
||||
kernel_headers, "fp16", "_bwd_weight_"
|
||||
)
|
||||
if not bwd_weight_kernel:
|
||||
bwd_weight_kernel = find_kernel_by_dtype_type(
|
||||
kernel_headers, "fp16", "_bwdw_"
|
||||
)
|
||||
if (
|
||||
not has_fwd and not has_bwd_data
|
||||
): # If no fwd or bwdd, use bwdw as first
|
||||
if bwd_weight_kernel:
|
||||
bwd_weight_ns = f"ns_{bwd_weight_kernel.stem}"
|
||||
launcher_aliases.append(
|
||||
f"using BwdWeightKernelLauncher = {bwd_weight_ns}::{bwd_weight_kernel.stem}_Launcher;"
|
||||
)
|
||||
if not has_fwd and not has_bwd_data:
|
||||
launcher_aliases.append(
|
||||
f"using FirstKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;"
|
||||
f"using FirstKernelLauncher = {bwd_weight_ns}::{bwd_weight_kernel.stem}_Launcher;"
|
||||
)
|
||||
|
||||
launcher_section = "\n".join(launcher_aliases)
|
||||
@@ -1382,14 +1536,16 @@ def main():
|
||||
#include "ck_tile/dispatcher/registry.hpp"
|
||||
#include "ck_tile/dispatcher/kernel_instance.hpp"
|
||||
#include "ck_tile/dispatcher/kernel_key.hpp"
|
||||
#include "ck_tile/dispatcher/grouped_conv_registry.hpp"
|
||||
#include "ck_tile/dispatcher/backends/generated_conv_backend.hpp"
|
||||
|
||||
namespace generated {{
|
||||
|
||||
// Kernel launchers for direct use
|
||||
{launcher_section}
|
||||
|
||||
// Registration function
|
||||
inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{
|
||||
// Registration function (takes GroupedConvRegistry for conv kernels)
|
||||
inline void {func_name}(ck_tile::dispatcher::GroupedConvRegistry& registry, const std::string& arch) {{
|
||||
{register_body}
|
||||
}}
|
||||
|
||||
@@ -1439,7 +1595,7 @@ inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::stri
|
||||
"""
|
||||
header_path.write_text(header_content)
|
||||
|
||||
print(f"[{target_name}] ✓ {len(obj_files)} kernels compiled")
|
||||
print(f"[{target_name}] OK {len(obj_files)} kernels compiled")
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user