mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
[CK] [CK_Tile] Add GroupConv to Kernel Dispatcher ## Motivation This PR adds CK Tile group convolution (forward, backward-data, backward-weight) support to the kernel dispatcher, matching and unifying with the existing dispatcher GEMM infrastructure in architecture and usability. The dispatcher provides a unified kernel dispatch system with both C++ and Python frontends, and until now only supported GEMM operations. This PR enables framework integrators to use the same declarative kernel workflow for convolutions as they do for GEMM: declare kernels, build a registry JIT, select kernels within the registry at runtime, and dispatch to GPU. Future PRs will include runtime kernel selection heuristics for autotuning of kernel parameters based on (problem, hardware arch). ## Technical Details Grouped convolution support has been added to the CK Tile Dispatcher with generated_conv_backend.hpp enabling dispatcher.run(in, wei, out, problem) for all 6 conv variants (fwd/bwdd/bwdw x 2D/3D), runtime heuristic kernel selection, and GroupedConvKernelKey with full ConvConfigBase fields. Python side adds parallel JIT via registry.build(max_workers) and heuristic registry.select(). Includes 7 C++ and 6 Python examples covering all directions with CPU reference validation, and shared infrastructure improvements (BaseRegistry CRTP, structured exceptions). As a sanity check, JIT compile times for a single kernel remains the same and for multiple kernels there is better parallelism: Kernels | 1 worker | 8 workers 1 | 7.7 s | 7.7 s 2 | 15.9 s | 8.2 s 4 | 33.4 s | 9.7 s 6 | 52.3 s | 10.2 s ## Test Plan 145 ephemeral unit tests have been added to test basic functionality. All 30 examples/integration tests run end-to-end on gfx950 (MI350): 7 C++ conv, 7 C++ GEMM, 6 Python conv, 10 Python GEMM. CPU reference validation for forward, backward-data, and backward-weight (2D) in both C++ and Python examples pass. ## Test Result 30 examples pass. 
Peak performance: 132 TFLOPS (Batch-32 forward 56x56), 53 TFLOPS (pointwise 1x1). CPU reference accuracy: max_abs_diff < 0.002 for all directions (fp16 vs fp32 reference). ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2255 lines
81 KiB
Python
2255 lines
81 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Cross-platform build script for declarative kernel workflow.
|
|
|
|
Uses existing ctypes_utils.py for path management and codegen.
|
|
|
|
Usage:
|
|
python3 compile_gemm_examples.py <source_file.cpp> [output_name]
|
|
|
|
Example:
|
|
python3 compile_gemm_examples.py examples/cpp/01_basic_gemm.cpp my_app
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import re
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
import shutil
|
|
|
|
# Add dispatcher/python to path to reuse existing utilities
# (the script lives one level below the dispatcher root).
SCRIPT_DIR = Path(__file__).parent.resolve()
DISPATCHER_DIR = SCRIPT_DIR.parent
sys.path.insert(0, str(DISPATCHER_DIR / "python"))

# Import existing utilities (after sys.path modification)
# noqa: E402 suppresses the "import not at top of file" lint, which is
# unavoidable because the sys.path entry must be added first.
from ctypes_utils import ( # noqa: E402
    get_dispatcher_root,
    get_ck_root,
    get_build_dir,
    get_generated_kernels_dir,
    CodegenRunner,
)
|
|
|
|
|
|
# =============================================================================
|
|
# Terminal Colors (cross-platform)
|
|
# =============================================================================
|
|
|
|
|
|
class Colors:
    """ANSI escape sequences for terminal output.

    All codes collapse to empty strings on Windows or when stdout is not
    attached to a TTY, so messages degrade gracefully when piped to a file.
    """

    # Default: no color (Windows, or stdout redirected).
    GREEN = YELLOW = RED = NC = ""
    if sys.platform != "win32" and sys.stdout.isatty():
        GREEN = "\033[0;32m"
        YELLOW = "\033[1;33m"
        RED = "\033[0;31m"
        NC = "\033[0m"
|
|
|
|
|
|
def print_phase(msg: str):
    """Print *msg* highlighted as a phase header (yellow when color is on)."""
    print("".join((Colors.YELLOW, msg, Colors.NC)))
|
|
|
|
|
|
def print_success(msg: str):
    """Print *msg* as a success message (green when color is on)."""
    print("".join((Colors.GREEN, msg, Colors.NC)))
|
|
|
|
|
|
def print_error(msg: str):
    """Print *msg* as an error (red when color is on) to stderr."""
    print("".join((Colors.RED, msg, Colors.NC)), file=sys.stderr)
|
|
|
|
|
|
# =============================================================================
|
|
# Compiler Detection
|
|
# =============================================================================
|
|
|
|
|
|
def find_hipcc() -> str:
    """Locate the hipcc compiler binary.

    Search order: the ``HIPCC`` environment variable, the standard ROCm
    install locations, then anything on ``PATH``.

    Returns:
        Path to the first hipcc executable found.

    Raises:
        RuntimeError: If no hipcc binary exists in any searched location.
    """
    search_order = (
        os.environ.get("HIPCC"),
        "/opt/rocm/bin/hipcc",
        "/opt/rocm/hip/bin/hipcc",
        shutil.which("hipcc"),
    )

    for candidate in search_order:
        # Skip unset env var / missing PATH entry (None or "") and
        # anything that is not an actual file.
        if candidate and os.path.isfile(candidate):
            return candidate

    raise RuntimeError(
        "hipcc not found. Please install ROCm or set HIPCC environment variable."
    )
|
|
|
|
|
|
# =============================================================================
|
|
# Declaration Extraction
|
|
# =============================================================================
|
|
|
|
|
|
def extract_conv_kernel_declarations(source_file: Path) -> list:
    """Extract GROUPED CONVOLUTION kernel declarations from C++ source file.

    Supports DECL_GROUPED_CONV_KERNEL_SET macro with ConvSig/ConvAlgo pattern.
    Extracts all parameters: dtype, layout, conv_type, dims, tile, wave, warp, pipeline, scheduler.

    Args:
        source_file: Path to the C++ translation unit to scan.

    Returns:
        List of declaration dicts (one per unique kernel), deduplicated by a
        generated name key stored under "name".
    """
    content = source_file.read_text()
    declarations = []
    # Name keys already emitted, used to drop duplicate .add(...) entries.
    seen = set()

    # Pattern: DECL_GROUPED_CONV_KERNEL_SET(name, .add(...).add(...))
    # NOTE(review): the body capture stops at the first ';', so a set body
    # that itself contains a semicolon would be truncated -- confirm none do.
    set_pattern = r"DECL_GROUPED_CONV_KERNEL_SET\s*\(\s*(\w+)\s*,([^;]+)\)"

    for match in re.finditer(set_pattern, content, re.DOTALL):
        set_name = match.group(1)
        set_body = match.group(2)

        # Pattern 1: Simple add("dtype", "layout", "conv_type", tile_k, tile_c)
        simple_add = (
            r'\.add\s*\(\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*(\d+)\s*,\s*(\d+)'
        )
        for add_match in re.finditer(simple_add, set_body):
            dtype = add_match.group(1)
            layout = add_match.group(2)
            conv_type = add_match.group(3)
            tile_k = int(add_match.group(4))
            tile_c = int(add_match.group(5))

            name = f"{set_name}:{dtype}_{layout}_{conv_type}_{tile_k}x{tile_c}"
            if name not in seen:
                seen.add(name)
                # Simple form fixes 2D / 1 group and leaves wave/warp as
                # wildcards (-1) for later expansion by
                # expand_conv_declaration_with_arch_filter().
                declarations.append(
                    {
                        "type": "conv",
                        "dtype": dtype,
                        "layout": layout,
                        "conv_type": conv_type,
                        "num_dims": 2,
                        "groups": 1,
                        "tile_n": 1,
                        "tile_k": tile_k,
                        "tile_c": tile_c,
                        "wave_m": -1,  # Wildcard - will expand
                        "wave_n": -1,
                        "wave_k": 1,
                        "warp_m": -1,
                        "warp_n": -1,
                        "warp_k": 16,
                        "pipeline": "compv3",
                        "scheduler": "intrawave",
                        "epilogue": "cshuffle",
                        "name": name,
                        "set": set_name,
                        "arch": "gfx942",
                    }
                )

        # Pattern 2: Full specification with ConvSig() and ConvAlgo()
        # Match .add( ConvSig()..., ConvAlgo()..., "arch" )
        # Use robust parsing that handles multi-line and comments

        # Find all .add( blocks containing ConvSig; each block extends until
        # the next .add( or the end of the set body (non-greedy lookahead).
        add_blocks = re.findall(
            r"\.add\s*\(\s*ConvSig\(\)([\s\S]*?)(?=\.add\s*\(|$)", set_body
        )

        for add_block in add_blocks:
            # Find ConvAlgo and arch in this block
            algo_match = re.search(r'ConvAlgo\(\)([\s\S]*?),\s*"(\w+)"\s*\)', add_block)
            if not algo_match:
                continue

            # Everything before ConvAlgo() belongs to the ConvSig chain.
            sig_str = add_block[: add_block.find("ConvAlgo()")]
            algo_str = algo_match.group(1)
            arch = algo_match.group(2)

            # Parse ConvSig -- each setter is optional; defaults below apply
            # when the method is absent from the declaration.
            dtype = "fp16"
            dtype_match = re.search(r'\.dtype\s*\(\s*"([^"]+)"', sig_str)
            if dtype_match:
                dtype = dtype_match.group(1)

            layout = "nhwgc"
            layout_match = re.search(r'\.layout\s*\(\s*"([^"]+)"', sig_str)
            if layout_match:
                layout = layout_match.group(1)

            conv_type = "forward"
            conv_type_match = re.search(r'\.conv_type\s*\(\s*"([^"]+)"', sig_str)
            if conv_type_match:
                conv_type = conv_type_match.group(1)

            num_dims = 2
            dims_match = re.search(r"\.dims\s*\(\s*(\d+)", sig_str)
            if dims_match:
                num_dims = int(dims_match.group(1))

            groups = 1
            groups_match = re.search(r"\.groups\s*\(\s*(\d+)", sig_str)
            if groups_match:
                groups = int(groups_match.group(1))

            # Parse ConvAlgo -- wave/warp third argument is optional and
            # falls back to 1 / 16 respectively.
            tile_n, tile_k, tile_c = 1, 128, 128
            tile_match = re.search(
                r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", algo_str
            )
            if tile_match:
                tile_n = int(tile_match.group(1))
                tile_k = int(tile_match.group(2))
                tile_c = int(tile_match.group(3))

            wave_m, wave_n, wave_k = 2, 2, 1
            wave_match = re.search(
                r"\.wave\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?\s*\)", algo_str
            )
            if wave_match:
                wave_m = int(wave_match.group(1))
                wave_n = int(wave_match.group(2))
                wave_k = int(wave_match.group(3) or 1)

            warp_m, warp_n, warp_k = 32, 32, 16
            warp_match = re.search(
                r"\.warp\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?\s*\)", algo_str
            )
            if warp_match:
                warp_m = int(warp_match.group(1))
                warp_n = int(warp_match.group(2))
                warp_k = int(warp_match.group(3) or 16)

            pipeline = "compv3"
            pipeline_match = re.search(r'\.pipeline\s*\(\s*"([^"]+)"', algo_str)
            if pipeline_match:
                pipeline = pipeline_match.group(1)

            scheduler = "intrawave"
            scheduler_match = re.search(r'\.scheduler\s*\(\s*"([^"]+)"', algo_str)
            if scheduler_match:
                scheduler = scheduler_match.group(1)

            epilogue = "cshuffle"
            epilogue_match = re.search(r'\.epilogue\s*\(\s*"([^"]+)"', algo_str)
            if epilogue_match:
                epilogue = epilogue_match.group(1)

            # Build unique name with full config
            # NOTE(review): layout and warp-tile dims are not part of this
            # key, so two declarations differing only in those fields would
            # collapse into one -- confirm this is intended.
            name = f"{set_name}:{dtype}_{conv_type}_{num_dims}d_{pipeline}_{scheduler}_{tile_k}x{tile_c}_{wave_m}x{wave_n}x{wave_k}"
            if name not in seen:
                seen.add(name)
                declarations.append(
                    {
                        "type": "conv",
                        "dtype": dtype,
                        "layout": layout,
                        "conv_type": conv_type,
                        "num_dims": num_dims,
                        "groups": groups,
                        "tile_n": tile_n,
                        "tile_k": tile_k,
                        "tile_c": tile_c,
                        "wave_m": wave_m,
                        "wave_n": wave_n,
                        "wave_k": wave_k,
                        "warp_m": warp_m,
                        "warp_n": warp_n,
                        "warp_k": warp_k,
                        "pipeline": pipeline,
                        "scheduler": scheduler,
                        "epilogue": epilogue,
                        "name": name,
                        "set": set_name,
                        "arch": arch,
                    }
                )

    return declarations
|
|
|
|
|
|
def expand_conv_declaration_with_arch_filter(decl: dict, arch: str = "gfx942") -> list:
    """Expand a convolution declaration to all valid combinations.

    Like GEMM, convolution supports wildcard expansion for:
    - wave/warp: If -1, generates all valid combinations
    - pipeline/scheduler: If "*", generates all valid trait combinations

    Args:
        decl: Declaration dict produced by extract_conv_kernel_declarations().
        arch: GPU architecture used to look up supported wave / warp-tile
            combinations (e.g. "gfx942").

    Returns:
        A non-empty list of fully-specified declaration dicts. If nothing
        needed expansion the input (copied) is returned as a single entry;
        if expansion filtered everything out, a default configuration is
        returned instead.
    """
    # Import arch filter tables generated by the codegen step.
    codegen_dir = get_dispatcher_root() / "codegen"
    # Guard against unbounded sys.path growth when called repeatedly.
    if str(codegen_dir) not in sys.path:
        sys.path.insert(0, str(codegen_dir))

    try:
        from arch_specs_generated import (
            WARP_SUPPORTED_COMBINATIONS,
            WARP_TILE_SUPPORTED_COMBINATIONS,
            TRAIT_UNSUPPORTED_COMBINATIONS,
        )
    except ImportError:
        # Fallback tables when the generated arch specs are unavailable.
        WARP_SUPPORTED_COMBINATIONS = {
            "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
        }
        WARP_TILE_SUPPORTED_COMBINATIONS = {
            "gfx942": {"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]]},
        }
        TRAIT_UNSUPPORTED_COMBINATIONS = set()

    d = decl.copy()
    tile_k = d.get("tile_k", 128)
    tile_c = d.get("tile_c", 128)
    dtype = d.get("dtype", "fp16")

    # Check what needs expansion (negative wave/warp dims and "*" traits
    # are wildcards).
    needs_wave_expansion = d.get("wave_m", -1) < 0 or d.get("wave_n", -1) < 0
    needs_warp_expansion = d.get("warp_m", -1) < 0 or d.get("warp_n", -1) < 0
    needs_pipeline_expansion = d.get("pipeline", "compv4") == "*"
    needs_scheduler_expansion = d.get("scheduler", "intrawave") == "*"

    if (
        not needs_wave_expansion
        and not needs_warp_expansion
        and not needs_pipeline_expansion
        and not needs_scheduler_expansion
    ):
        # Already fully specified -- nothing to enumerate.
        return [d]

    # Build the candidate wave / warp-tile combinations.
    if needs_wave_expansion or needs_warp_expansion:
        wave_configs = WARP_SUPPORTED_COMBINATIONS.get(arch, [[2, 2, 1]])
        dtype_key = f"{dtype}_{dtype}_{dtype}"
        warp_tile_configs = WARP_TILE_SUPPORTED_COMBINATIONS.get(arch, {}).get(
            dtype_key, [[32, 32, 16], [16, 16, 16]]
        )
    else:
        wave_configs = [[d.get("wave_m", 2), d.get("wave_n", 2), d.get("wave_k", 1)]]
        warp_tile_configs = [
            [d.get("warp_m", 32), d.get("warp_n", 32), d.get("warp_k", 16)]
        ]

    # Pipeline/scheduler combinations
    ALL_PIPELINES = ["compv3", "compv4"]
    ALL_SCHEDULERS = ["intrawave", "interwave"]

    pipelines = (
        ALL_PIPELINES if needs_pipeline_expansion else [d.get("pipeline", "compv4")]
    )
    schedulers = (
        ALL_SCHEDULERS
        if needs_scheduler_expansion
        else [d.get("scheduler", "intrawave")]
    )

    expanded = []

    for wm, wn, wk in wave_configs:
        for wtm, wtn, wtk in warp_tile_configs:
            # Divisibility check in the conv GEMM view: tile_k (output K
            # channels) plays the GEMM M role and tile_c (input C channels)
            # the GEMM N role, so each block tile must be divisible by the
            # wave count times the warp-tile extent along its OWN dimension.
            # BUGFIX: previously tile_k was checked against the N-dimension
            # extent (wn * wtn) and tile_c against the K extents (wk * wtk),
            # cross-wiring the dimensions.
            if tile_k % (wm * wtm) != 0:
                continue
            if tile_c % (wn * wtn) != 0:
                continue

            for pipeline in pipelines:
                for scheduler in schedulers:
                    # Skip pipeline/epilogue/scheduler trios the arch rejects.
                    if (
                        pipeline,
                        "cshuffle",
                        scheduler,
                    ) in TRAIT_UNSUPPORTED_COMBINATIONS:
                        continue

                    expanded_d = d.copy()
                    expanded_d["wave_m"] = wm
                    expanded_d["wave_n"] = wn
                    expanded_d["wave_k"] = wk
                    expanded_d["warp_m"] = wtm
                    expanded_d["warp_n"] = wtn
                    expanded_d["warp_k"] = wtk
                    expanded_d["pipeline"] = pipeline
                    expanded_d["scheduler"] = scheduler

                    expanded_d["name"] = (
                        f"conv_{d['conv_type']}_{dtype}_{d['num_dims']}d_{pipeline}_"
                        f"{scheduler}_{tile_k}x{tile_c}_{wm}x{wn}x{wk}"
                    )
                    expanded.append(expanded_d)

    if not expanded:
        # Everything was filtered out -- fall back to a known-good default
        # so callers always receive at least one configuration.
        d["wave_m"] = 2
        d["wave_n"] = 2
        d["wave_k"] = 1
        d["warp_m"] = 32
        d["warp_n"] = 32
        d["warp_k"] = 16
        d["pipeline"] = "compv4"
        d["scheduler"] = "intrawave"
        return [d]

    return expanded
|
|
|
|
|
|
def generate_conv_kernels(declarations: list, gpu_target: str = "gfx942") -> int:
    """Generate grouped convolution kernels using unified_grouped_conv_codegen.

    Args:
        declarations: Fully-expanded conv declaration dicts (see
            extract_conv_kernel_declarations / expand_conv_declaration_with_arch_filter).
        gpu_target: Architecture string stamped into each kernel config.

    Returns:
        Number of kernels successfully generated; 0 if the codegen module
        could not be imported.
    """
    kernel_dir = get_generated_kernels_dir()
    kernel_dir.mkdir(parents=True, exist_ok=True)

    codegen_dir = get_dispatcher_root() / "codegen"
    sys.path.insert(0, str(codegen_dir))

    try:
        from unified_grouped_conv_codegen import (
            UnifiedGroupedConvCodegen as UnifiedConvCodegen,
            GroupedConvKernelConfig as ConvKernelConfig,
            GroupedConvVariant as ConvVariant,
            TileConfig,
            GroupedConvTraitConfig as TraitConfig,
        )
    except ImportError as e:
        print_error(f" Failed to import grouped conv codegen: {e}")
        return 0

    codegen = UnifiedConvCodegen(kernel_dir)
    total_generated = 0

    # Group by dtype and variant for efficient generation
    groups = {}
    for decl in declarations:
        dtype = decl.get("dtype", "fp16")
        conv_type = decl.get("conv_type", "forward")
        num_dims = decl.get("num_dims", 2)
        key = (dtype, conv_type, num_dims)
        if key not in groups:
            groups[key] = []
        groups[key].append(decl)

    for (dtype, conv_type, num_dims), decls in groups.items():
        print(f" Generating {dtype} {conv_type} {num_dims}D kernels...")

        # Map to ConvVariant (anything other than bwd_data/bwd_weight is
        # treated as forward).
        variant = ConvVariant.FORWARD
        if conv_type == "bwd_data":
            variant = ConvVariant.BACKWARD_DATA
        elif conv_type == "bwd_weight":
            variant = ConvVariant.BACKWARD_WEIGHT

        for decl in decls:
            pipeline = decl.get("pipeline", "compv3")
            scheduler = decl.get("scheduler", "intrawave")
            epilogue = decl.get("epilogue", "cshuffle")

            tile_k = decl.get("tile_k", 128)
            tile_c = decl.get("tile_c", 128)
            wave_m = decl.get("wave_m", 2)
            wave_n = decl.get("wave_n", 2)
            warp_m = decl.get("warp_m", 32)
            warp_n = decl.get("warp_n", 32)
            warp_k = decl.get("warp_k", 16)

            # Adjust tile_k for compv4
            # NOTE(review): the GEMM-K block tile is hardcoded here (64, or
            # 128 for compv4's double buffering); the declaration's own
            # value is not consulted -- confirm this is intentional.
            adj_tile_k = 64 * 2 if pipeline == "compv4" else 64

            # Create TileConfig
            tile_config = TileConfig(
                tile_m=tile_k,  # K is M in conv GEMM view
                tile_n=tile_c,  # C is N in conv GEMM view
                tile_k=adj_tile_k,
                warp_m=wave_m,
                warp_n=wave_n,
                warp_k=1,  # NOTE(review): decl's wave_k is ignored -- confirm
                warp_tile_m=warp_m,
                warp_tile_n=warp_n,
                warp_tile_k=warp_k,
            )

            # Create TraitConfig (compv4 implies double LDS buffering)
            trait_config = TraitConfig(
                pipeline=pipeline,
                scheduler=scheduler,
                epilogue=epilogue,
                double_smem_buffer=(pipeline == "compv4"),
                pad_m=True,
                pad_n=True,
                pad_k=True,
            )

            # Create ConvKernelConfig
            config = ConvKernelConfig(
                tile=tile_config,
                trait=trait_config,
                variant=variant,
                ndim_spatial=num_dims,
                arch=gpu_target,
            )

            try:
                filepath = codegen.generate_kernel(config, dtype)
                total_generated += 1
                print(f" Generated: {filepath.name}")
            except Exception as e:
                # Best-effort: report the failure and continue with the rest.
                print_error(f" Failed to generate {decl['name']}: {e}")

    return total_generated
|
|
|
|
|
|
# Original GEMM extraction continues here
def extract_kernel_declarations(source_file: Path) -> list:
    """Extract GEMM kernel declarations from C++ source file.

    Recognizes five macro forms: DECL_KERNEL_SIMPLE, fluent DECL_KERNEL
    (Signature()/Algorithm()), DECL_KERNEL_ALL / DECL_KERNELS_ALL wildcards,
    a second DECL_KERNEL_SIMPLE pass, and DECL_KERNEL_SET bodies (both the
    simple and the fluent .add(...) forms).

    Args:
        source_file: Path to the C++ translation unit to scan.

    Returns:
        List of declaration dicts, deduplicated by a generated "name" key.
    """
    content = source_file.read_text()
    declarations = []
    # Name keys already emitted, for deduplication across all patterns.
    seen = set()

    # -------------------------------------------------------------------------
    # Pattern 1: Simple DECL_KERNEL_SIMPLE(dtype, layout, tile_m, tile_n, tile_k)
    # -------------------------------------------------------------------------
    legacy_pattern = r"DECL_KERNEL_SIMPLE\s*\(\s*(\w+)\s*,\s*(\w+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)"
    for match in re.findall(legacy_pattern, content):
        dtype, layout, tm, tn, tk = match
        name = f"{dtype}_{layout}_{tm}x{tn}x{tk}"
        if name not in seen:
            seen.add(name)
            # Wave/warp stay as -1 wildcards for later arch-filter expansion.
            declarations.append(
                {
                    "dtype_a": dtype,
                    "dtype_b": dtype,
                    "dtype_c": dtype,
                    "layout": layout,
                    "tile_m": int(tm),
                    "tile_n": int(tn),
                    "tile_k": int(tk),
                    "wave_m": -1,
                    "wave_n": -1,
                    "wave_k": 1,
                    "warp_m": -1,
                    "warp_n": -1,
                    "warp_k": 16,
                    "pipeline": "compv4",
                    "scheduler": "intrawave",
                    "epilogue": "cshuffle",
                    "name": name,
                    "wildcard": False,
                }
            )

    # -------------------------------------------------------------------------
    # Pattern 2: Fluent API: DECL_KERNEL(Signature()..., Algorithm()..., arch)
    # -------------------------------------------------------------------------
    # Match DECL_KERNEL( ... ); blocks
    fluent_pattern = r'DECL_KERNEL\s*\(\s*(Signature\(\)[^,]+),\s*(Algorithm\(\)[^,]+)(?:,\s*"([^"]+)")?\s*\)'

    for match in re.finditer(fluent_pattern, content, re.DOTALL):
        sig_str = match.group(1)
        algo_str = match.group(2)
        arch = match.group(3) or "gfx942"

        # Parse Signature -- defaults apply for any omitted setter.
        sig = {"dtype_a": "fp16", "dtype_b": "fp16", "dtype_c": "fp16", "layout": "rcr"}

        # .dtype("fp16", "fp16", "fp16", "fp32") or .dtype("fp16")
        dtype_match = re.search(
            r'\.dtype\("([^"]+)"(?:,\s*"([^"]+)")?(?:,\s*"([^"]+)")?', sig_str
        )
        if dtype_match:
            sig["dtype_a"] = dtype_match.group(1)
            # Single-arg form replicates the first dtype for B and C.
            sig["dtype_b"] = dtype_match.group(2) or dtype_match.group(1)
            sig["dtype_c"] = dtype_match.group(3) or dtype_match.group(1)

        # .layout("rcr") or .layout("row", "col", "row")
        layout_match = re.search(
            r'\.layout\("([^"]+)"(?:,\s*"([^"]+)")?(?:,\s*"([^"]+)")?', sig_str
        )
        if layout_match:
            if layout_match.group(2):  # Three-arg form
                la = layout_match.group(1)
                lb = layout_match.group(2)
                lc = layout_match.group(3) or "row"
                # Collapse row/col words into the compact "rcr"-style code.
                sig["layout"] = (
                    ("r" if la == "row" else "c")
                    + ("r" if lb == "row" else "c")
                    + ("r" if lc == "row" else "c")
                )
            else:  # Single arg "rcr"
                sig["layout"] = layout_match.group(1)

        # Parse Algorithm
        algo = {}

        # .tile(128, 128, 32)
        tile_match = re.search(r"\.tile\((\d+),\s*(\d+),\s*(\d+)\)", algo_str)
        if tile_match:
            algo["tile_m"] = int(tile_match.group(1))
            algo["tile_n"] = int(tile_match.group(2))
            algo["tile_k"] = int(tile_match.group(3))

        # .wave(2, 2, 1)
        wave_match = re.search(r"\.wave\((\d+),\s*(\d+)(?:,\s*(\d+))?\)", algo_str)
        if wave_match:
            algo["wave_m"] = int(wave_match.group(1))
            algo["wave_n"] = int(wave_match.group(2))
            algo["wave_k"] = int(wave_match.group(3) or 1)

        # .warp(32, 32, 16)
        warp_match = re.search(r"\.warp\((\d+),\s*(\d+)(?:,\s*(\d+))?\)", algo_str)
        if warp_match:
            algo["warp_m"] = int(warp_match.group(1))
            algo["warp_n"] = int(warp_match.group(2))
            algo["warp_k"] = int(warp_match.group(3) or 16)

        # .pipeline("compv4"), .scheduler("intrawave"), .epilogue("cshuffle")
        for field in ["pipeline", "scheduler", "epilogue"]:
            fmatch = re.search(rf'\.{field}\("([^"]+)"\)', algo_str)
            if fmatch:
                algo[field] = fmatch.group(1)

        # Build declaration
        tm = algo.get("tile_m", 128)
        tn = algo.get("tile_n", 128)
        tk = algo.get("tile_k", 32)

        name = f"{sig['dtype_a']}_{sig['layout']}_{tm}x{tn}x{tk}"

        if name not in seen:
            seen.add(name)
            declarations.append(
                {
                    "dtype_a": sig["dtype_a"],
                    "dtype_b": sig["dtype_b"],
                    "dtype_c": sig["dtype_c"],
                    "layout": sig["layout"],
                    "tile_m": tm,
                    "tile_n": tn,
                    "tile_k": tk,
                    "wave_m": algo.get("wave_m", -1),
                    "wave_n": algo.get("wave_n", -1),
                    "wave_k": algo.get("wave_k", 1),
                    "warp_m": algo.get("warp_m", -1),
                    "warp_n": algo.get("warp_n", -1),
                    "warp_k": algo.get("warp_k", 16),
                    "pipeline": algo.get("pipeline", "compv4"),
                    "scheduler": algo.get("scheduler", "intrawave"),
                    "epilogue": algo.get("epilogue", "cshuffle"),
                    "arch": arch,
                    "name": name,
                    "wildcard": False,
                }
            )

    # -------------------------------------------------------------------------
    # Pattern 3: DECL_KERNEL_ALL(dtype, layout) - wildcard
    # -------------------------------------------------------------------------
    all_pattern = r"DECL_KERNEL(?:S)?_ALL\s*\(\s*(\w+)\s*,\s*(\w+)\s*\)"
    for match in re.findall(all_pattern, content):
        dtype, layout = match
        name = f"wildcard_{dtype}_{layout}"
        if name not in seen:
            seen.add(name)
            # All tiles -1: everything expands later via the arch filter.
            declarations.append(
                {
                    "dtype_a": dtype,
                    "dtype_b": dtype,
                    "dtype_c": dtype,
                    "layout": layout,
                    "tile_m": -1,
                    "tile_n": -1,
                    "tile_k": -1,
                    "wave_m": -1,
                    "wave_n": -1,
                    "wave_k": 1,
                    "warp_m": -1,
                    "warp_n": -1,
                    "warp_k": 16,
                    "pipeline": "compv4",
                    "scheduler": "intrawave",
                    "epilogue": "cshuffle",
                    "name": name,
                    "wildcard": True,
                }
            )

    # -------------------------------------------------------------------------
    # Pattern 4: DECL_KERNEL_SIMPLE(dtype, layout, tm, tn, tk)
    # NOTE(review): this regex is identical to Pattern 1 and the generated
    # names collide with it, so `seen` makes this pass a no-op except for
    # adding "set": None to any hypothetical new entry -- confirm whether
    # this duplicate pass can be removed.
    # -------------------------------------------------------------------------
    simple_pattern = r"DECL_KERNEL_SIMPLE\s*\(\s*(\w+)\s*,\s*(\w+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)"
    for match in re.findall(simple_pattern, content):
        dtype, layout, tm, tn, tk = match
        name = f"{dtype}_{layout}_{tm}x{tn}x{tk}"
        if name not in seen:
            seen.add(name)
            declarations.append(
                {
                    "dtype_a": dtype,
                    "dtype_b": dtype,
                    "dtype_c": dtype,
                    "layout": layout,
                    "tile_m": int(tm),
                    "tile_n": int(tn),
                    "tile_k": int(tk),
                    "wave_m": -1,
                    "wave_n": -1,
                    "wave_k": 1,
                    "warp_m": -1,
                    "warp_n": -1,
                    "warp_k": 16,
                    "pipeline": "compv4",
                    "scheduler": "intrawave",
                    "epilogue": "cshuffle",
                    "name": name,
                    "wildcard": False,
                    "set": None,
                }
            )

    # -------------------------------------------------------------------------
    # Pattern 5: DECL_KERNEL_SET(name, .add(...).add(...))
    # Named kernel sets for multiple registries
    # Match only DECL_KERNEL_SET at start of line (not in comments)
    # -------------------------------------------------------------------------
    set_pattern = r"^DECL_KERNEL_SET\s*\(\s*(\w+)\s*,([\s\S]*?)\)\s*;"
    for match in re.finditer(set_pattern, content, re.MULTILINE):
        set_name = match.group(1)
        set_body = match.group(2)

        # Parse .add("dtype", "layout", tm, tn, tk) calls - simple form
        add_simple = r'\.add\s*\(\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)'
        for add_match in re.findall(add_simple, set_body):
            dtype, layout, tm, tn, tk = add_match
            name = f"{set_name}:{dtype}_{layout}_{tm}x{tn}x{tk}"
            if name not in seen:
                seen.add(name)
                declarations.append(
                    {
                        "dtype_a": dtype,
                        "dtype_b": dtype,
                        "dtype_c": dtype,
                        "layout": layout,
                        "tile_m": int(tm),
                        "tile_n": int(tn),
                        "tile_k": int(tk),
                        "wave_m": -1,
                        "wave_n": -1,
                        "wave_k": 1,
                        "warp_m": -1,
                        "warp_n": -1,
                        "warp_k": 16,
                        "pipeline": "compv4",
                        "scheduler": "intrawave",
                        "epilogue": "cshuffle",
                        "name": name,
                        "wildcard": False,
                        "set": set_name,
                    }
                )

        # Parse .add(Signature()..., Algorithm()..., "arch") fluent calls
        # Robust approach: find each .add( block and parse methods individually
        # This handles any method order and optional methods

        # Split set_body into .add() blocks
        add_blocks = []
        add_starts = [m.start() for m in re.finditer(r"\.add\s*\(", set_body)]

        for i, start in enumerate(add_starts):
            # Find the matching closing paren by counting parens
            # (string-aware so parens inside quoted args are ignored).
            depth = 0
            end = start
            in_string = False
            escape_next = False

            for j, ch in enumerate(set_body[start:], start):
                if escape_next:
                    escape_next = False
                    continue
                if ch == "\\":
                    escape_next = True
                    continue
                if ch == '"' and not escape_next:
                    in_string = not in_string
                    continue
                if in_string:
                    continue
                if ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
                    if depth == 0:
                        end = j + 1
                        break

            if end > start:
                add_blocks.append(set_body[start:end])

        for add_block in add_blocks:
            # Skip if doesn't have both Signature() and Algorithm()
            if "Signature()" not in add_block or "Algorithm()" not in add_block:
                continue

            # Split on Algorithm() to separate Signature and Algorithm parts
            algo_idx = add_block.find("Algorithm()")
            if algo_idx == -1:
                continue

            sig_str = add_block[:algo_idx]
            algo_str = add_block[algo_idx:]  # Include Algorithm() and everything after

            # Parse dtype from Signature - handles .dtype("fp16", "fp16", "fp16", "fp32")
            dtype = "fp16"
            dtype_m = re.search(r'\.dtype\s*\(\s*"([^"]+)"', sig_str)
            if dtype_m:
                dtype = dtype_m.group(1)

            # Parse layout from Signature - handles .layout("row", "col", "row")
            layout = "rcr"
            layout_m = re.search(
                r'\.layout\s*\(\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"', sig_str
            )
            if layout_m:
                la, lb, lc = layout_m.group(1), layout_m.group(2), layout_m.group(3)
                layout = (
                    ("r" if la == "row" else "c")
                    + ("r" if lb == "row" else "c")
                    + ("r" if lc == "row" else "c")
                )
            else:
                # Single arg form: .layout("rcr")
                layout_m = re.search(r'\.layout\s*\(\s*"([^"]+)"', sig_str)
                if layout_m:
                    layout = layout_m.group(1)

            # Parse tile from Algorithm
            tm, tn, tk = 128, 128, 32
            tile_m = re.search(
                r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", algo_str
            )
            if tile_m:
                tm, tn, tk = (
                    int(tile_m.group(1)),
                    int(tile_m.group(2)),
                    int(tile_m.group(3)),
                )

            # Parse wave
            wave_m, wave_n, wave_k = 2, 2, 1
            wave_match = re.search(
                r"\.wave\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?\s*\)", algo_str
            )
            if wave_match:
                wave_m, wave_n = int(wave_match.group(1)), int(wave_match.group(2))
                wave_k = int(wave_match.group(3) or 1)

            # Parse warp
            warp_m, warp_n, warp_k = 32, 32, 16
            warp_match = re.search(
                r"\.warp\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?\s*\)", algo_str
            )
            if warp_match:
                warp_m, warp_n = int(warp_match.group(1)), int(warp_match.group(2))
                warp_k = int(warp_match.group(3) or 16)

            # Parse pipeline - NEW: extract from declaration
            pipeline = "compv4"
            pipeline_m = re.search(r'\.pipeline\s*\(\s*"([^"]+)"', algo_str)
            if pipeline_m:
                pipeline = pipeline_m.group(1)

            # Parse scheduler - NEW: extract from declaration
            scheduler = "intrawave"
            scheduler_m = re.search(r'\.scheduler\s*\(\s*"([^"]+)"', algo_str)
            if scheduler_m:
                scheduler = scheduler_m.group(1)

            # Parse epilogue - NEW: extract from declaration
            epilogue = "cshuffle"
            epilogue_m = re.search(r'\.epilogue\s*\(\s*"([^"]+)"', algo_str)
            if epilogue_m:
                epilogue = epilogue_m.group(1)

            # Parse padding - NEW: extract from declaration
            pad_m, pad_n, pad_k = False, False, False
            pad_match = re.search(
                r"\.pad\s*\(\s*(true|false)\s*,\s*(true|false)\s*,\s*(true|false)\s*\)",
                algo_str,
                re.IGNORECASE,
            )
            if pad_match:
                pad_m = pad_match.group(1).lower() == "true"
                pad_n = pad_match.group(2).lower() == "true"
                pad_k = pad_match.group(3).lower() == "true"

            # Parse elementwise from Signature - for Multi-D kernels
            elementwise_op = "PassThrough"
            num_d_tensors = 0
            elem_match = re.search(
                r'\.elementwise\s*\(\s*"([^"]+)"\s*,\s*(\d+)\s*\)',
                sig_str,
            )
            if elem_match:
                elementwise_op = elem_match.group(1)
                num_d_tensors = int(elem_match.group(2))

            name = f"{set_name}:{dtype}_{layout}_{pipeline}_{scheduler}_{tm}x{tn}x{tk}_{wave_m}x{wave_n}x{wave_k}"
            if elementwise_op != "PassThrough":
                name += f"_{elementwise_op}_d{num_d_tensors}"
            if name not in seen:
                seen.add(name)
                declarations.append(
                    {
                        "dtype_a": dtype,
                        "dtype_b": dtype,
                        "dtype_c": dtype,
                        "layout": layout,
                        "tile_m": tm,
                        "tile_n": tn,
                        "tile_k": tk,
                        "wave_m": wave_m,
                        "wave_n": wave_n,
                        "wave_k": wave_k,
                        "warp_m": warp_m,
                        "warp_n": warp_n,
                        "warp_k": warp_k,
                        "pipeline": pipeline,
                        "scheduler": scheduler,
                        "epilogue": epilogue,
                        "pad_m": pad_m,
                        "pad_n": pad_n,
                        "pad_k": pad_k,
                        "elementwise_op": elementwise_op,
                        "num_d_tensors": num_d_tensors,
                        "name": name,
                        "wildcard": False,
                        "set": set_name,
                    }
                )

    return declarations
|
|
|
|
|
|
def expand_declaration_with_arch_filter(decl: dict, arch: str = "gfx942") -> list:
    """Expand a declaration to all valid combinations using arch filter.

    Expands wildcards for:
    - wave/warp: If -1, generates all valid wave/warp_tile combinations
    - pipeline/scheduler/epilogue: If "*", generates all valid trait combinations

    Uses the arch_filter module for architecture-specific validation.

    Args:
        decl: Kernel declaration dict; wildcard fields are -1 (numeric) or "*".
        arch: Target GPU architecture used to look up valid combinations.

    Returns:
        A list of fully-specified declaration dicts. Returns [] when the
        dtype has no warp-tile support on this arch; falls back to a single
        hardcoded default when expansion produced no valid combination.
    """
    # Import arch filter
    # NOTE(review): this insert runs on every call, so repeated calls keep
    # prepending the same entry to sys.path.
    codegen_dir = get_dispatcher_root() / "codegen"
    sys.path.insert(0, str(codegen_dir))

    try:
        from arch_specs_generated import (
            WARP_SUPPORTED_COMBINATIONS,
            WARP_TILE_SUPPORTED_COMBINATIONS,
            TRAIT_UNSUPPORTED_COMBINATIONS,
        )
    except ImportError:
        # Fallback to hardcoded valid combinations
        WARP_SUPPORTED_COMBINATIONS = {
            "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
            "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
            "gfx950": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
        }
        WARP_TILE_SUPPORTED_COMBINATIONS = {
            "gfx942": {"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]]},
        }
        TRAIT_UNSUPPORTED_COMBINATIONS = {
            ("compv3", "cshuffle", "interwave"),
            ("compv3", "default", "interwave"),
            ("compv4", "cshuffle", "interwave"),
            ("compv4", "default", "interwave"),
        }

    # Work on a copy so the caller's declaration is never mutated here
    # (except in the no-valid-combination fallback below, which fills d).
    d = decl.copy()
    tm = d.get("tile_m", 128)
    tn = d.get("tile_n", 128)
    tk = d.get("tile_k", 32)
    dtype = d.get("dtype_a", "fp16")

    # Check what needs expansion
    needs_wave_expansion = d.get("wave_m", -1) < 0 or d.get("wave_n", -1) < 0
    needs_warp_expansion = d.get("warp_m", -1) < 0 or d.get("warp_n", -1) < 0
    needs_pipeline_expansion = d.get("pipeline", "compv4") == "*"
    needs_scheduler_expansion = d.get("scheduler", "intrawave") == "*"
    needs_epilogue_expansion = d.get("epilogue", "cshuffle") == "*"
    needs_pad_m_expansion = d.get("pad_m", 1) == -1
    needs_pad_n_expansion = d.get("pad_n", 1) == -1
    needs_pad_k_expansion = d.get("pad_k", 1) == -1
    needs_trait_expansion = (
        needs_pipeline_expansion
        or needs_scheduler_expansion
        or needs_epilogue_expansion
    )
    needs_pad_expansion = (
        needs_pad_m_expansion or needs_pad_n_expansion or needs_pad_k_expansion
    )

    if (
        not needs_wave_expansion
        and not needs_warp_expansion
        and not needs_trait_expansion
        and not needs_pad_expansion
    ):
        # Already fully specified
        return [d]

    # === Build valid combinations ===

    # Wave configurations
    if needs_wave_expansion:
        wave_configs = WARP_SUPPORTED_COMBINATIONS.get(arch, [[2, 2, 1]])
    else:
        wave_configs = [[d.get("wave_m", 2), d.get("wave_n", 2), d.get("wave_k", 1)]]

    # Warp tile configurations
    if needs_warp_expansion:
        arch_warp_tiles = WARP_TILE_SUPPORTED_COMBINATIONS.get(arch, {})

        # Try to find warp tile configs for this dtype
        # Keys are like: fp16_fp16_fp32, int8_int8_int32, etc.
        warp_tile_configs = None
        dtype_key_variants = [
            f"{dtype}_{dtype}_{dtype}",  # e.g., fp32_fp32_fp32
            f"{dtype}_{dtype}_fp32",  # e.g., fp16_fp16_fp32
            f"{dtype}_{dtype}_int32",  # e.g., int8_int8_int32
        ]
        for dtype_key in dtype_key_variants:
            warp_tile_configs = arch_warp_tiles.get(dtype_key, None)
            if warp_tile_configs is not None:
                break

        # If dtype is not supported on this arch, return empty list
        if warp_tile_configs is None:
            return []
    else:
        warp_tile_configs = [
            [d.get("warp_m", 32), d.get("warp_n", 32), d.get("warp_k", 16)]
        ]

    # Pipeline/scheduler/epilogue combinations
    # Valid options per category
    ALL_PIPELINES = ["compv3", "compv4"]  # Most common; add more if needed
    ALL_SCHEDULERS = ["intrawave", "interwave"]
    ALL_EPILOGUES = ["cshuffle", "default"]
    ALL_PAD_OPTIONS = [False, True]  # 0 and 1

    pipelines = (
        ALL_PIPELINES if needs_pipeline_expansion else [d.get("pipeline", "compv4")]
    )
    schedulers = (
        ALL_SCHEDULERS
        if needs_scheduler_expansion
        else [d.get("scheduler", "intrawave")]
    )
    epilogues = (
        ALL_EPILOGUES if needs_epilogue_expansion else [d.get("epilogue", "cshuffle")]
    )
    pad_m_opts = ALL_PAD_OPTIONS if needs_pad_m_expansion else [bool(d.get("pad_m", 1))]
    pad_n_opts = ALL_PAD_OPTIONS if needs_pad_n_expansion else [bool(d.get("pad_n", 1))]
    pad_k_opts = ALL_PAD_OPTIONS if needs_pad_k_expansion else [bool(d.get("pad_k", 1))]

    expanded = []

    # Generate all valid combinations
    for wm, wn, wk in wave_configs:
        for wtm, wtn, wtk in warp_tile_configs:
            # Check divisibility constraints: each block-tile dimension must
            # be an exact multiple of (waves-per-block * warp-tile size).
            if tm % (wm * wtm) != 0:
                continue
            if tn % (wn * wtn) != 0:
                continue
            if tk % (wk * wtk) != 0:
                continue

            for pipeline in pipelines:
                for scheduler in schedulers:
                    for epilogue in epilogues:
                        # Check trait combination is valid
                        if (
                            pipeline,
                            epilogue,
                            scheduler,
                        ) in TRAIT_UNSUPPORTED_COMBINATIONS:
                            continue

                        for pad_m in pad_m_opts:
                            for pad_n in pad_n_opts:
                                for pad_k in pad_k_opts:
                                    # Create expanded declaration
                                    expanded_d = d.copy()
                                    expanded_d["wave_m"] = wm
                                    expanded_d["wave_n"] = wn
                                    expanded_d["wave_k"] = wk
                                    expanded_d["warp_m"] = wtm
                                    expanded_d["warp_n"] = wtn
                                    expanded_d["warp_k"] = wtk
                                    expanded_d["pipeline"] = pipeline
                                    expanded_d["scheduler"] = scheduler
                                    expanded_d["epilogue"] = epilogue
                                    # Pads are stored as 0/1 ints, not bools.
                                    expanded_d["pad_m"] = int(pad_m)
                                    expanded_d["pad_n"] = int(pad_n)
                                    expanded_d["pad_k"] = int(pad_k)

                                    pad_str = f"{'T' if pad_m else 'F'}{'T' if pad_n else 'F'}{'T' if pad_k else 'F'}"
                                    # NOTE(review): the generated name omits the
                                    # warp tile and epilogue, so two expansions
                                    # differing only in those fields share the
                                    # same name — confirm downstream consumers
                                    # tolerate that.
                                    expanded_d["name"] = (
                                        f"{dtype}_{d.get('layout', 'rcr')}_{pipeline}_{scheduler}_"
                                        f"pad{pad_str}_{tm}x{tn}x{tk}_{wm}x{wn}x{wk}"
                                    )
                                    expanded_d["wildcard"] = False
                                    expanded.append(expanded_d)

    if not expanded:
        # No valid combinations found, return single default
        print(f" Warning: No valid combinations for {tm}x{tn}x{tk} on {arch}")
        d["wave_m"] = 2
        d["wave_n"] = 2
        d["wave_k"] = 1
        d["warp_m"] = 32
        d["warp_n"] = 32
        d["warp_k"] = 16
        d["pipeline"] = "compv4"
        d["scheduler"] = "intrawave"
        d["epilogue"] = "cshuffle"
        return [d]

    return expanded
|
|
|
|
|
|
def auto_fill_declaration(decl: dict) -> dict:
    """Resolve a declaration to a single fully-specified config.

    Backward-compat shim: expands wildcards via the arch filter and keeps
    only the first valid combination; when expansion yields nothing, the
    original declaration is returned unchanged.
    """
    target_arch = decl.get("arch", "gfx942")
    candidates = expand_declaration_with_arch_filter(decl, target_arch)
    if not candidates:
        return decl
    return candidates[0]
|
|
|
|
|
|
# =============================================================================
|
|
# Build Functions
|
|
# =============================================================================
|
|
|
|
|
|
def generate_kernels(declarations: list, gpu_target: str = "gfx942") -> int:
    """Generate kernels using CodegenRunner from ctypes_utils.

    Declarations are grouped by (dtype, layout) so each combination triggers
    a single codegen run.

    Args:
        declarations: Declaration dicts (possibly containing wildcards).
        gpu_target: GPU architecture passed to the codegen runner.

    Returns:
        Total kernel count reported by all successful codegen runs.
    """
    kernel_dir = get_generated_kernels_dir()
    kernel_dir.mkdir(parents=True, exist_ok=True)

    # Group by dtype+layout for efficient generation.
    groups: dict = {}
    for decl in declarations:
        dtype = decl.get("dtype_a", decl.get("dtype", "fp16"))
        layout = decl.get("layout", "rcr")
        # setdefault replaces the manual "if key not in groups" dance.
        groups.setdefault((dtype, layout), []).append(auto_fill_declaration(decl))

    total_generated = 0

    for (dtype, layout), decls in groups.items():
        print(f" Generating {dtype} {layout} kernels...")

        # Check for wildcards - if any decl is wildcard, generate all
        has_wildcard = any(d.get("wildcard", False) for d in decls)

        # Use CodegenRunner from ctypes_utils.
        # NOTE(review): the individual declarations are not forwarded to the
        # runner; "standard" generates the full standard set for this
        # dtype/layout regardless of the specific tile configs requested —
        # confirm that is intended.
        runner = CodegenRunner(
            datatype=dtype,
            layout=layout,
            gpu_target=gpu_target,
        )

        result = runner.generate("standard")

        if result.success:
            total_generated += result.kernel_count
            if has_wildcard:
                print(f" [wildcard] Generated all {result.kernel_count} variants")
        else:
            # Truncate stderr to keep the error readable; guard against None.
            print_error(f" Failed: {(result.stderr or '')[:200]}")

    return total_generated
|
|
|
|
|
|
def get_arch_filter_data():
    """Load arch filter data from arch_specs_generated if available.

    Falls back to a conservative hardcoded table when the generated specs
    module cannot be imported.
    """
    codegen_dir = get_dispatcher_root() / "codegen"
    sys.path.insert(0, str(codegen_dir))

    try:
        from arch_specs_generated import (
            TRAIT_UNSUPPORTED_COMBINATIONS,
            WARP_SUPPORTED_COMBINATIONS,
            WARP_TILE_SUPPORTED_COMBINATIONS,
            get_supported_archs,
        )
    except ImportError:
        # Generated specs unavailable: return the fallback defaults directly.
        return {
            "trait_unsupported": {
                ("compv3", "cshuffle", "interwave"),
                ("compv3", "default", "interwave"),
                ("compv4", "cshuffle", "interwave"),
                ("compv4", "default", "interwave"),
            },
            "warp_combos": {
                "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
            },
            "warp_tile_combos": {
                "gfx942": {"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]]},
            },
            "supported_archs": ["gfx90a", "gfx942", "gfx950"],
        }

    return {
        "trait_unsupported": TRAIT_UNSUPPORTED_COMBINATIONS,
        "warp_combos": WARP_SUPPORTED_COMBINATIONS,
        "warp_tile_combos": WARP_TILE_SUPPORTED_COMBINATIONS,
        "supported_archs": get_supported_archs(),
    }
|
|
|
|
|
|
def is_wildcard_declaration(decl: dict) -> bool:
    """Return True when the declaration still contains wildcard fields.

    Wildcards are negative wave/warp sizes or "*" trait strings; such a
    declaration must be expanded before it names a concrete kernel.
    """
    numeric_fields = (
        ("wave_m", 2),
        ("wave_n", 2),
        ("warp_m", 32),
        ("warp_n", 32),
    )
    if any(decl.get(key, default) < 0 for key, default in numeric_fields):
        return True

    trait_fields = (
        ("pipeline", "compv4"),
        ("scheduler", "intrawave"),
        ("epilogue", "cshuffle"),
    )
    return any(decl.get(key, default) == "*" for key, default in trait_fields)
|
|
|
|
|
|
def validate_kernel_config(decl: dict, arch: str = "gfx942") -> tuple:
    """Validate a kernel configuration against known supported combinations.

    Uses arch_specs_generated for architecture-specific validation.

    For wildcard declarations (-1 values or "*" strings), validation is skipped
    because the expansion phase will generate only valid combinations.

    Args:
        decl: Fully-specified (or wildcard) GEMM kernel declaration.
        arch: Target GPU architecture.

    Returns: (is_valid, error_message)
    """
    # Skip validation for wildcards - expansion will filter invalid combos
    if is_wildcard_declaration(decl):
        return (True, None)

    arch_data = get_arch_filter_data()

    pipeline = decl.get("pipeline", "compv4")
    epilogue = decl.get("epilogue", "cshuffle")
    scheduler = decl.get("scheduler", "intrawave")
    dtype = decl.get("dtype_a", "fp16")

    wave_m = decl.get("wave_m", 2)
    wave_n = decl.get("wave_n", 2)
    wave_k = decl.get("wave_k", 1)

    warp_m = decl.get("warp_m", 32)
    warp_n = decl.get("warp_n", 32)
    warp_k = decl.get("warp_k", 16)

    # Collect every failure so the caller sees all problems at once.
    errors = []

    # Check trait combination (pipeline, epilogue, scheduler)
    combo = (pipeline, epilogue, scheduler)
    if combo in arch_data["trait_unsupported"]:
        errors.append(
            f"Unsupported trait combination: pipeline={pipeline}, epilogue={epilogue}, scheduler={scheduler}\n"
            f" Valid schedulers for {pipeline}+{epilogue}: intrawave"
        )

    # Check wave configuration for this arch
    warp_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]])
    wave_cfg = [wave_m, wave_n, wave_k]
    if wave_cfg not in warp_combos:
        valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_combos)
        errors.append(
            f"Unsupported wave configuration [{wave_m},{wave_n},{wave_k}] for {arch}\n"
            f" Valid wave configs: {valid_str}"
        )

    # Check warp tile configuration for this arch and dtype
    # NOTE(review): only the symmetric key (e.g. fp16_fp16_fp16) is tried,
    # while the expansion path also tries *_fp32 / *_int32 accumulator keys —
    # confirm this asymmetry is intended.
    dtype_key = f"{dtype}_{dtype}_{dtype}"
    warp_tile_combos = (
        arch_data["warp_tile_combos"]
        .get(arch, {})
        .get(dtype_key, [[32, 32, 16], [16, 16, 16]])
    )
    warp_cfg = [warp_m, warp_n, warp_k]
    if warp_cfg not in warp_tile_combos:
        # Show at most five options to keep the error message readable.
        valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_tile_combos[:5])
        errors.append(
            f"Unsupported warp tile [{warp_m},{warp_n},{warp_k}] for {arch}/{dtype}\n"
            f" Valid warp tiles: {valid_str}"
        )

    # Check arch is supported
    if arch not in arch_data["supported_archs"]:
        errors.append(
            f"Unsupported architecture: {arch}\n"
            f" Supported: {', '.join(arch_data['supported_archs'])}"
        )

    if errors:
        return (False, "\n".join(errors))

    return (True, None)
|
|
|
|
|
|
def build_exact_kernel_filename(decl: dict) -> str:
    """Build the exact kernel filename from a fully-specified declaration.

    Standard format:
      gemm_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}_{pad_m}_{pad_n}_{pad_k}_{preshuffle}_{tile}_{wave}_{warp}.hpp

    Multi-D format additionally appends:
      _multid_{op}_d{num}
    """

    def _flag(key: str) -> str:
        # Boolean fields are serialized as the literal strings "True"/"False".
        return "True" if decl.get(key, False) else "False"

    def _triple(prefix: str, dm: int, dn: int, dk: int) -> str:
        # MxNxK triple for tile / wave / warp fields with per-field defaults.
        return (
            f"{decl.get(prefix + '_m', dm)}"
            f"x{decl.get(prefix + '_n', dn)}"
            f"x{decl.get(prefix + '_k', dk)}"
        )

    parts = [
        "gemm",
        decl.get("dtype_a", decl.get("dtype", "fp16")),
        decl.get("layout", "rcr"),
        decl.get("pipeline", "compv4"),
        decl.get("epilogue", "cshuffle"),
        decl.get("scheduler", "intrawave"),
        _flag("pad_m"),
        _flag("pad_n"),
        _flag("pad_k"),
        _flag("preshuffle"),
        _triple("tile", 128, 128, 32),
        _triple("wave", 2, 2, 1),
        _triple("warp", 32, 32, 16),
    ]
    base = "_".join(parts)

    # Handle Multi-D kernels: append the elementwise op and D-tensor count.
    elementwise_op = decl.get("elementwise_op", "PassThrough")
    num_d_tensors = decl.get("num_d_tensors", 0)
    if elementwise_op != "PassThrough" and num_d_tensors > 0:
        base += f"_multid_{elementwise_op}_d{num_d_tensors}"

    return f"{base}.hpp"
|
|
|
|
|
|
def generate_specific_kernel(decl: dict, gpu_target: str = "gfx942") -> bool:
    """Run codegen for the dtype/layout named by a declaration.

    Returns True when the codegen run reports success.
    """
    kernel_dtype = decl.get("dtype_a", decl.get("dtype", "fp16"))
    kernel_layout = decl.get("layout", "rcr")

    print(f" Generating kernel for {kernel_dtype}/{kernel_layout}...")

    # Delegate the actual generation to the shared CodegenRunner.
    outcome = CodegenRunner(
        datatype=kernel_dtype,
        layout=kernel_layout,
        gpu_target=gpu_target,
    ).generate("standard")

    return outcome.success
|
|
|
|
|
|
def find_kernel_header(decl: dict, gpu_target: str = "gfx942") -> "Path | None":
    """Find a matching kernel header file for a declaration.

    Tries multiple matching strategies:
    1. Exact filename match
    2. Match with key parameters (dtype, layout, pipeline, scheduler, tile)
    3. Match with just dtype, layout, and tile (more flexible)
    4. Any kernel with matching dtype and layout

    If no kernel exists, attempts to generate it.
    Returns None only if all strategies fail.
    """
    kernel_dir = get_generated_kernels_dir()

    dtype = decl.get("dtype_a", decl.get("dtype", "fp16"))
    layout = decl.get("layout", "rcr")
    pipeline = decl.get("pipeline", "compv4")
    scheduler = decl.get("scheduler", "intrawave")
    tile_m = decl.get("tile_m", 128)
    tile_n = decl.get("tile_n", 128)
    tile_k = decl.get("tile_k", 32)
    wave_m = decl.get("wave_m", 2)
    wave_n = decl.get("wave_n", 2)
    wave_k = decl.get("wave_k", 1)

    tile_str = f"{tile_m}x{tile_n}x{tile_k}"
    wave_str = f"{wave_m}x{wave_n}x{wave_k}"

    # Build exact filename
    exact_filename = build_exact_kernel_filename(decl)
    exact_path = kernel_dir / exact_filename

    # Strategy 1: Exact filename match
    if exact_path.exists():
        print(f" Found exact kernel: {exact_filename}")
        return exact_path

    # Strategy 2: Match with key parameters
    # NOTE(review): glob ordering is filesystem-dependent, so matches[0] is
    # not deterministic when several kernels satisfy a pattern.
    pattern = (
        f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_{wave_str}_*.hpp"
    )
    matches = list(kernel_dir.glob(pattern))
    if matches:
        print(f" Found matching kernel: {matches[0].name}")
        return matches[0]

    # Strategy 3: Match with just dtype, layout, tile (ignore wave/warp)
    pattern = f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_*.hpp"
    matches = list(kernel_dir.glob(pattern))
    if matches:
        print(f" Found kernel with matching tile: {matches[0].name}")
        return matches[0]

    # Strategy 4: Match with just dtype, layout (most flexible, for wildcards)
    # Prefer kernels with intrawave scheduler (known to work)
    pattern = f"gemm_{dtype}_{layout}_*_intrawave_*_{tile_str}_*.hpp"
    matches = list(kernel_dir.glob(pattern))
    if matches:
        print(f" Found kernel with intrawave: {matches[0].name}")
        return matches[0]

    # Strategy 5: Any kernel with matching dtype and layout
    pattern = f"gemm_{dtype}_{layout}_*_{tile_str}_*.hpp"
    matches = list(kernel_dir.glob(pattern))
    if matches:
        print(f" Found kernel with matching dtype/layout/tile: {matches[0].name}")
        return matches[0]

    # Strategy 6: Try to generate the kernel
    print(" No matching kernel found, attempting to generate...")
    if generate_specific_kernel(decl, gpu_target):
        # Check strategies again after generation, from most to least specific.
        for pattern in [
            f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_*.hpp",
            f"gemm_{dtype}_{layout}_*_intrawave_*_{tile_str}_*.hpp",
            f"gemm_{dtype}_{layout}_*_{tile_str}_*.hpp",
        ]:
            matches = list(kernel_dir.glob(pattern))
            if matches:
                print(f" Generated: {matches[0].name}")
                return matches[0]

    # All strategies failed - return None (caller will try next expanded decl)
    return None
|
|
|
|
|
|
def is_conv_wildcard_declaration(decl: dict) -> bool:
    """Return True if the conv declaration needs wildcard expansion.

    A wildcard is a negative wave/warp size or a "*" pipeline/scheduler
    string. Unlike the GEMM variant, the epilogue field is not checked.
    """
    wave = (decl.get("wave_m", 2), decl.get("wave_n", 2))
    if min(wave) < 0:
        return True
    warp = (decl.get("warp_m", 32), decl.get("warp_n", 32))
    if min(warp) < 0:
        return True
    traits = (decl.get("pipeline", "compv3"), decl.get("scheduler", "intrawave"))
    return "*" in traits
|
|
|
|
|
|
def validate_conv_kernel_config(decl: dict, arch: str = "gfx942") -> tuple:
    """Validate a conv kernel configuration against arch filter.

    For wildcard declarations, validation is skipped (expansion handles it).

    Returns: (is_valid, error_message)
    """
    # Wildcards are only expanded into legal combos, so skip them here.
    if is_conv_wildcard_declaration(decl):
        return (True, None)

    arch_data = get_arch_filter_data()
    errors = []

    pipeline = decl.get("pipeline", "compv3")
    epilogue = decl.get("epilogue", "cshuffle")
    scheduler = decl.get("scheduler", "intrawave")
    dtype = decl.get("dtype", "fp16")

    # 1) The (pipeline, epilogue, scheduler) triple must not be blacklisted.
    if (pipeline, epilogue, scheduler) in arch_data["trait_unsupported"]:
        errors.append(
            f"Unsupported trait combination: pipeline={pipeline}, epilogue={epilogue}, scheduler={scheduler}\n"
            f" Valid schedulers for {pipeline}+{epilogue}: intrawave"
        )

    # 2) Wave configuration must be one the architecture supports.
    wave_cfg = [decl.get("wave_m", 2), decl.get("wave_n", 2), decl.get("wave_k", 1)]
    warp_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]])
    if wave_cfg not in warp_combos:
        valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_combos)
        errors.append(
            f"Unsupported wave configuration [{wave_cfg[0]},{wave_cfg[1]},{wave_cfg[2]}] for {arch}\n"
            f" Valid wave configs: {valid_str}"
        )

    # 3) Warp tile must be supported for this arch/dtype.
    warp_cfg = [decl.get("warp_m", 32), decl.get("warp_n", 32), decl.get("warp_k", 16)]
    dtype_key = f"{dtype}_{dtype}_{dtype}"
    warp_tile_combos = (
        arch_data["warp_tile_combos"]
        .get(arch, {})
        .get(dtype_key, [[32, 32, 16], [16, 16, 16]])
    )
    if warp_cfg not in warp_tile_combos:
        # Only the first five valid options are listed, to stay readable.
        valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_tile_combos[:5])
        errors.append(
            f"Unsupported warp tile [{warp_cfg[0]},{warp_cfg[1]},{warp_cfg[2]}] for {arch}/{dtype}\n"
            f" Valid warp tiles: {valid_str}"
        )

    # 4) The architecture itself must be in the supported list.
    if arch not in arch_data["supported_archs"]:
        errors.append(
            f"Unsupported architecture: {arch}\n"
            f" Supported: {', '.join(arch_data['supported_archs'])}"
        )

    return (False, "\n".join(errors)) if errors else (True, None)
|
|
|
|
|
|
def build_exact_conv_kernel_filename(decl: dict) -> str:
    """Build the exact conv kernel filename from a fully-specified declaration.

    Conv filename format:
      conv_{type}_{dtype}_{ndim}d_{pipeline}_{epilogue}_{scheduler}_{tile}_{wave}.hpp

    Example:
      conv_fwd_fp16_2d_compv3_cshuffle_intrawave_128x128x32_2x2x1.hpp
    """
    conv_type = decl.get("conv_type", "forward")
    # Only "forward" is abbreviated; bwd_data / bwd_weight (and any other
    # value) pass through unchanged.
    type_prefix = {"forward": "fwd"}.get(conv_type, conv_type)

    # Conv tiles are serialized as tile_k x tile_c x 32 (fixed inner dim).
    tile_str = f"{decl.get('tile_k', 128)}x{decl.get('tile_c', 128)}x32"
    wave_str = (
        f"{decl.get('wave_m', 2)}x{decl.get('wave_n', 2)}x{decl.get('wave_k', 1)}"
    )

    return (
        f"conv_{type_prefix}"
        f"_{decl.get('dtype', 'fp16')}"
        f"_{decl.get('num_dims', 2)}d"
        f"_{decl.get('pipeline', 'compv3')}"
        f"_{decl.get('epilogue', 'cshuffle')}"
        f"_{decl.get('scheduler', 'intrawave')}"
        f"_{tile_str}_{wave_str}.hpp"
    )
|
|
|
|
|
|
def generate_specific_conv_kernel(decl: dict, gpu_target: str = "gfx942") -> bool:
    """Generate a specific conv kernel based on declaration.

    Invokes codegen/unified_grouped_conv_codegen.py as a subprocess for the
    declaration's dtype / conv variant / dimensionality.

    Args:
        decl: Conv declaration dict (dtype, conv_type, num_dims, ...).
        gpu_target: Target GPU architecture.

    Returns:
        True if the codegen subprocess exited with status 0; False on a
        nonzero exit or when the run exceeds the 5-minute timeout.
    """
    dtype = decl.get("dtype", "fp16")
    conv_type = decl.get("conv_type", "forward")
    num_dims = decl.get("num_dims", 2)

    print(f" Generating conv kernel for {dtype}/{conv_type}/{num_dims}d...")

    # Map conv_type to the codegen variant name; unknown types fall back to
    # "forward" (preserves the original if/elif/else behavior).
    if conv_type in ("forward", "bwd_data", "bwd_weight"):
        variant = conv_type
    else:
        variant = "forward"

    # Use unified_grouped_conv_codegen
    codegen_dir = get_dispatcher_root() / "codegen"
    codegen_script = codegen_dir / "unified_grouped_conv_codegen.py"
    output_dir = get_generated_kernels_dir()

    cmd = [
        # Use the running interpreter instead of a hardcoded "python3",
        # which may resolve to a different (or missing) Python on PATH.
        sys.executable,
        str(codegen_script),
        "--datatype",
        dtype,
        "--variant",
        variant,
        "--ndim",
        str(num_dims),
        "--arch",
        gpu_target,
        "--output",
        str(output_dir),
    ]

    try:
        # Codegen can be slow for large variants; cap the run at 5 minutes.
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
        return result.returncode == 0
    except subprocess.TimeoutExpired:
        return False
|
|
|
|
|
|
def find_conv_kernel_header(decl: dict, gpu_target: str = "gfx942") -> "Path | None":
    """Find the EXACT matching conv kernel header file for a declaration.

    If the kernel doesn't exist, attempts to generate it.
    Returns None only if generation also fails.
    """
    kernel_dir = get_generated_kernels_dir()

    # Build exact filename
    exact_filename = build_exact_conv_kernel_filename(decl)
    exact_path = kernel_dir / exact_filename

    # Check if exact kernel exists
    if exact_path.exists():
        print(f" Found exact conv kernel: {exact_filename}")
        return exact_path

    # Try to find with glob (in case of minor variations)
    dtype = decl.get("dtype", "fp16")
    conv_type = decl.get("conv_type", "forward")
    num_dims = decl.get("num_dims", 2)
    pipeline = decl.get("pipeline", "compv3")
    scheduler = decl.get("scheduler", "intrawave")
    tile_k = decl.get("tile_k", 128)
    tile_c = decl.get("tile_c", 128)
    wave_m = decl.get("wave_m", 2)
    wave_n = decl.get("wave_n", 2)
    wave_k = decl.get("wave_k", 1)

    # Map conv_type to prefix ("forward" abbreviates to "fwd"; others pass through)
    if conv_type == "forward":
        type_prefix = "fwd"
    elif conv_type == "bwd_data":
        type_prefix = "bwd_data"
    elif conv_type == "bwd_weight":
        type_prefix = "bwd_weight"
    else:
        type_prefix = conv_type

    tile_str = f"{tile_k}x{tile_c}"
    wave_str = f"{wave_m}x{wave_n}x{wave_k}"

    # Search pattern with key parameters
    # NOTE(review): glob ordering is filesystem-dependent, so matches[0] is
    # not deterministic when several kernels satisfy the pattern.
    pattern = f"conv_{type_prefix}_{dtype}_{num_dims}d_{pipeline}_*_{scheduler}_*{tile_str}*_{wave_str}.hpp"
    matches = list(kernel_dir.glob(pattern))

    if matches:
        print(f" Found matching conv kernel: {matches[0].name}")
        return matches[0]

    # Kernel doesn't exist - try to generate it
    print(f" Conv kernel not found: {exact_filename}")
    print(" Attempting to generate...")

    if generate_specific_conv_kernel(decl, gpu_target):
        # Check again after generation
        matches = list(kernel_dir.glob(pattern))
        if matches:
            print(f" Generated: {matches[0].name}")
            return matches[0]

        # Check for exact match
        if exact_path.exists():
            print(f" Generated: {exact_filename}")
            return exact_path

    # Still not found - print helpful error
    print_error(
        " ERROR: Could not find or generate conv kernel matching declaration:"
    )
    print_error(f" dtype={dtype}, conv_type={conv_type}, num_dims={num_dims}")
    print_error(f" pipeline={pipeline}, scheduler={scheduler}")
    print_error(f" tile={tile_k}x{tile_c}, wave={wave_str}")
    print_error(f" Expected: {exact_filename}")
    print_error(f" Available conv kernels in {kernel_dir}:")

    # NOTE(review): this directory glob is evaluated twice (here and in the
    # ">5" check below); it could be computed once and reused.
    available = list(kernel_dir.glob(f"conv_{type_prefix}_{dtype}_{num_dims}d_*.hpp"))[
        :5
    ]
    for k in available:
        print_error(f" - {k.name}")
    if len(list(kernel_dir.glob(f"conv_{type_prefix}_{dtype}_{num_dims}d_*.hpp"))) > 5:
        print_error(" ... and more")

    return None
|
|
|
|
|
|
def build_dispatcher_library(hipcc: str) -> bool:
    """Build the static dispatcher library unless it is already present."""
    build_dir = get_build_dir()

    # A previous build left the archive in place: nothing to do.
    if (build_dir / "libck_tile_dispatcher.a").exists():
        return True

    print(" Building dispatcher library...")
    build_dir.mkdir(parents=True, exist_ok=True)

    dispatcher_dir = get_dispatcher_root()

    # Configure, then compile; stop at the first failing step.
    build_steps = [
        ("CMake", ["cmake", str(dispatcher_dir), f"-DCMAKE_CXX_COMPILER={hipcc}"]),
        ("Make", ["make", "ck_tile_dispatcher", f"-j{os.cpu_count() or 4}"]),
    ]
    for step_name, cmd in build_steps:
        proc = subprocess.run(
            cmd, cwd=str(build_dir), capture_output=True, text=True
        )
        if proc.returncode != 0:
            print_error(f"{step_name} failed: {proc.stderr}")
            return False

    return True
|
|
|
|
|
|
def compile_application(
    source_file: Path,
    output_bin: Path,
    kernel_header: Path,
    hipcc: str,
    gpu_target: str = "gfx942",
) -> bool:
    """Compile the application with hipcc.

    The generated kernel header is force-included (-include) and the binary
    is linked against the prebuilt dispatcher library.
    """
    ck_root = get_ck_root()
    dispatcher_dir = get_dispatcher_root()
    build_dir = get_build_dir()
    kernel_dir = get_generated_kernels_dir()

    include_flags = [
        f"-I{ck_root / 'include'}",
        f"-I{dispatcher_dir / 'include'}",
        f"-I{kernel_dir}",
    ]

    compile_cmd = [
        hipcc,
        "-std=c++17",
        "-O3",
        f"--offload-arch={gpu_target}",
        *include_flags,
        "-include",
        str(kernel_header),
        f"-L{build_dir}",
        "-lck_tile_dispatcher",
        "-o",
        str(output_bin),
        str(source_file),
    ]

    proc = subprocess.run(compile_cmd, capture_output=True, text=True)

    # Surface only lines that look like hard errors (warnings are dropped),
    # capped at the first five to keep output readable.
    if proc.stderr:
        error_lines = [ln for ln in proc.stderr.split("\n") if "error:" in ln.lower()]
        for err_line in error_lines[:5]:
            print_error(f" {err_line}")

    return proc.returncode == 0
|
|
|
|
|
|
# =============================================================================
|
|
# Main
|
|
# =============================================================================
|
|
|
|
|
|
def _short_name(decl):
    """Return a declaration's display name, stripping any 'set:' prefix."""
    return decl["name"].split(":")[-1] if ":" in decl["name"] else decl["name"]


def _print_declaration_summary(label, declarations, needs_expansion_fn):
    """Print declarations grouped by kernel set (at most 5 shown per set).

    Args:
        label: Family label for the heading ("GEMM" or "CONV").
        declarations: Parsed declaration dicts.
        needs_expansion_fn: Predicate marking declarations that will be
            wildcard-expanded (shown with an "[expands]" suffix).
    """
    print(f"\n {label}: Found {len(declarations)} declaration(s)")

    # Group by kernel set; declarations without a set go to "(global)".
    sets = {}
    for decl in declarations:
        sets.setdefault(decl.get("set") or "(global)", []).append(decl)

    for set_name, set_decls in sets.items():
        print(f" [{set_name}] ({len(set_decls)} kernels):")
        for decl in set_decls[:5]:
            suffix = " [expands]" if needs_expansion_fn(decl) else ""
            print(f" - {_short_name(decl)}{suffix}")
        if len(set_decls) > 5:
            print(f" ... and {len(set_decls) - 5} more")


def _validate_declarations(
    declarations, gpu_target, what, wildcard_fn, validate_fn, pipeline_default
):
    """Validate explicit declarations against the arch filter, in place.

    Invalid configurations are auto-corrected by resetting the offending
    fields to wildcards so the subsequent expansion pass can substitute
    valid values.

    Args:
        declarations: Declaration dicts; invalid entries are mutated.
        gpu_target: Default arch when a declaration carries none.
        what: Noun used in warning messages ("configuration" /
            "conv configuration").
        wildcard_fn: Predicate for declarations already using wildcards
            (those are validated during expansion instead).
        validate_fn: ``(decl, arch) -> (is_valid, error_msg)``.
        pipeline_default: Family default pipeline used when reporting the
            original pipeline value of a corrected declaration.
    """
    print(f"\n Validating against {gpu_target} arch filter...")
    wildcard_count = 0
    invalid_count = 0

    for decl in declarations:
        arch = decl.get("arch", gpu_target)

        if wildcard_fn(decl):
            wildcard_count += 1
            continue  # Wildcards validated during expansion

        is_valid, error_msg = validate_fn(decl, arch)
        if is_valid:
            continue

        print(f"\n WARNING Invalid {what}: {_short_name(decl)}")

        # Parse the error and record the specific auto-corrections applied.
        corrections = []
        lowered = error_msg.lower()

        if "wave configuration" in lowered:
            old_wave = f"[{decl.get('wave_m', 2)}, {decl.get('wave_n', 2)}, {decl.get('wave_k', 1)}]"
            decl["wave_m"] = -1
            decl["wave_n"] = -1
            corrections.append(f"wave: {old_wave} -> [wildcard expansion]")

        if "warp tile" in lowered:
            old_warp = f"[{decl.get('warp_m', 32)}, {decl.get('warp_n', 32)}, {decl.get('warp_k', 16)}]"
            decl["warp_m"] = -1
            decl["warp_n"] = -1
            corrections.append(f"warp_tile: {old_warp} -> [wildcard expansion]")

        if "trait combination" in lowered:
            old_pipeline = decl.get("pipeline", pipeline_default)
            old_scheduler = decl.get("scheduler", "intrawave")
            decl["pipeline"] = "*"
            decl["scheduler"] = "*"
            corrections.append(f"pipeline: {old_pipeline} -> [wildcard expansion]")
            corrections.append(f"scheduler: {old_scheduler} -> [wildcard expansion]")

        print(" AUTO-CORRECTION:")
        for corr in corrections:
            print(f" - {corr}")

        invalid_count += 1
        # A corrected declaration now behaves like a wildcard declaration.
        wildcard_count += 1

    if invalid_count > 0:
        print(
            f"\n WARNING {invalid_count} invalid config(s) auto-corrected via wildcard expansion"
        )

    if wildcard_count > 0:
        print(
            f" OK {len(declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)"
        )
    else:
        print(f" OK All {len(declarations)} configurations valid")


def _expand_declarations(declarations, gpu_target, wildcard_fn, expand_fn):
    """Expand wildcard declarations into concrete valid configurations.

    Args:
        declarations: Declaration dicts (possibly containing wildcards).
        gpu_target: Default arch when a declaration carries none.
        wildcard_fn: Predicate for wildcard declarations.
        expand_fn: ``(decl, arch) -> list[decl]`` arch-filtered expansion.

    Returns:
        Flat list of expanded declaration dicts.
    """
    print("\n Expanding wildcards to valid configurations...")
    expanded_all = []
    for decl in declarations:
        arch = decl.get("arch", gpu_target)
        decl_name = _short_name(decl)

        expanded = expand_fn(decl, arch)
        expanded_all.extend(expanded)

        # Show what the wildcard expanded to.
        if len(expanded) > 1:
            print(f" {decl_name}: expanded to {len(expanded)} valid configurations")
            for exp in expanded[:3]:
                wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]"
                warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]"
                print(
                    f" -> wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}"
                )
            if len(expanded) > 3:
                print(f" ... and {len(expanded) - 3} more")
        elif len(expanded) == 1 and wildcard_fn(decl):
            exp = expanded[0]
            wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]"
            warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]"
            print(f" {decl_name}: -> wave={wave_str}, warp={warp_str}")

    if len(expanded_all) > len(declarations):
        print(
            f"\n Total: {len(declarations)} declarations -> {len(expanded_all)} configurations"
        )

    return expanded_all


def main():
    """CLI entry point: scan, generate, build, and compile declared kernels.

    Pipeline:
      1. Scan the source file for GEMM / grouped-conv kernel declarations.
      2. Generate the declared kernels (after validation + wildcard expansion).
      3. Select a generated kernel header for compilation.
      4. Build the dispatcher library.
      5. Compile the application against it.

    Returns:
        Process exit code: 0 on success, 1 on any failure.
    """
    parser = argparse.ArgumentParser(
        description="Build CK Tile application with declarative kernels",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Example:
  python3 compile_gemm_examples.py examples/cpp/01_basic_gemm_declarative.cpp my_app

In your C++ code, declare kernels like:
  DECL_KERNEL_SET(my_kernels,
      .add(Signature().dtype("fp16").layout("rcr"),
           Algorithm().tile(128, 128, 32).wave(2, 2, 1).warp(32, 32, 16)
               .pipeline("compv4").scheduler("intrawave"))
  );
""",
    )
    parser.add_argument("source", help="Source file (.cpp)")
    parser.add_argument(
        "output", nargs="?", help="Output name (default: source basename)"
    )
    parser.add_argument(
        "--gpu-target", default="gfx942", help="GPU target architecture"
    )
    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
    args = parser.parse_args()

    # Resolve paths using utilities from ctypes_utils
    dispatcher_dir = get_dispatcher_root()
    build_dir = get_build_dir()

    source_file = Path(args.source)
    if not source_file.is_absolute():
        # Try relative to dispatcher dir first, then CWD
        candidates = [
            dispatcher_dir / args.source,
            dispatcher_dir / "examples" / args.source,  # examples/gemm/cpp/...
            Path.cwd() / args.source,
        ]
        for candidate in candidates:
            if candidate.exists():
                source_file = candidate
                break

    if not source_file.exists():
        print_error(f"Source file not found: {source_file}")
        return 1

    output_name = args.output or source_file.stem
    output_bin = build_dir / output_name

    # Ensure build directory exists
    build_dir.mkdir(parents=True, exist_ok=True)

    print_success("=== CK Tile Declarative Kernel Build ===")
    print()

    # Phase 1: Extract declarations (both GEMM and Conv)
    print_phase("Phase 1: Scanning for kernel declarations...")

    gemm_declarations = extract_kernel_declarations(source_file)
    conv_declarations = extract_conv_kernel_declarations(source_file)

    if not gemm_declarations and not conv_declarations:
        print_error(" No kernel declarations found!")
        print(
            " Add DECL_KERNEL_SET for GEMM or DECL_GROUPED_CONV_KERNEL_SET for Grouped Conv"
        )
        return 1

    # Handle GEMM declarations: summarize, validate, expand.
    if gemm_declarations:
        _print_declaration_summary(
            "GEMM",
            gemm_declarations,
            # GEMM declarations expand when wave or warp dims are unset (-1).
            lambda d: d.get("wave_m", -1) < 0 or d.get("warp_m", -1) < 0,
        )
        _validate_declarations(
            gemm_declarations,
            args.gpu_target,
            "configuration",
            is_wildcard_declaration,
            validate_kernel_config,
            "compv4",
        )
        gemm_declarations = _expand_declarations(
            gemm_declarations,
            args.gpu_target,
            is_wildcard_declaration,
            expand_declaration_with_arch_filter,
        )

    # Handle Conv declarations: same pipeline with conv-specific helpers.
    if conv_declarations:
        _print_declaration_summary(
            "CONV", conv_declarations, is_conv_wildcard_declaration
        )
        _validate_declarations(
            conv_declarations,
            args.gpu_target,
            "conv configuration",
            is_conv_wildcard_declaration,
            validate_conv_kernel_config,
            "compv3",
        )
        conv_declarations = _expand_declarations(
            conv_declarations,
            args.gpu_target,
            is_conv_wildcard_declaration,
            expand_conv_declaration_with_arch_filter,
        )

    print()

    # Phase 2: Generate kernels
    print_phase("Phase 2: Generating kernels...")

    total_generated = 0

    if gemm_declarations:
        print(" GEMM kernels:")
        num_gemm = generate_kernels(gemm_declarations, args.gpu_target)
        total_generated += num_gemm
        print(f" Generated: {num_gemm}")

    if conv_declarations:
        print(" CONV kernels:")
        num_conv = generate_conv_kernels(conv_declarations, args.gpu_target)
        total_generated += num_conv
        print(f" Generated: {num_conv}")

    print(f" Total kernel files: {total_generated}")
    print()

    # Phase 3: Find kernel header
    print_phase("Phase 3: Selecting kernel for compilation...")

    kernel_headers = []

    # Find GEMM kernel header (try each expanded declaration until one matches)
    if gemm_declarations:
        gemm_header = None
        for decl in gemm_declarations:
            header = find_kernel_header(decl, args.gpu_target)
            if header:
                gemm_header = header
                break

        if gemm_header:
            kernel_headers.append(gemm_header)
            print(f" GEMM: {gemm_header.name}")
        else:
            print_error(" GEMM: No kernel found matching any declaration!")
            print_error(
                " The kernels declared in DECL_KERNEL_SET must exist or be generatable."
            )
            return 1

    # Find Conv kernel header
    if conv_declarations:
        conv_header = find_conv_kernel_header(conv_declarations[0])
        if conv_header:
            kernel_headers.append(conv_header)
            print(f" CONV: {conv_header.name}")

    if not kernel_headers:
        print_error(" No kernel headers found!")
        return 1

    # Use first available header (can be extended to use multiple)
    kernel_header = kernel_headers[0]
    print()

    # Phase 4: Build dispatcher library
    print_phase("Phase 4: Building dispatcher library...")
    hipcc = find_hipcc()

    if not build_dispatcher_library(hipcc):
        print_error(" Failed to build dispatcher library!")
        return 1
    print(" Done")
    print()

    # Phase 5: Compile application
    print_phase("Phase 5: Compiling application...")

    if not compile_application(
        source_file, output_bin, kernel_header, hipcc, args.gpu_target
    ):
        print_error(" Compilation failed!")
        return 1

    print(f" Output: {output_bin}")
    print()

    # Done
    print_success("=== Build Complete ===")
    print()
    print("Run with:")
    print(f" {output_bin}")
    print()
    print("List declared kernels:")
    print(f" {output_bin} --list-kernels")

    return 0
|
|
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    sys.exit(main())
|