#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Build example kernels - generates and compiles kernels for a single example.

Detects whether the example is GEMM or Conv based on macro presence, extracts all
configuration parameters, and generates the appropriate kernels.
"""

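# Example invocation (script and path names are illustrative; see main() for the CLI):
#   python3 build_example_kernels.py 01_basic_gemm.cpp \
#       --output-dir build/kernels \
#       --include-dirs /opt/rocm/include,../../../include \
#       --gpu-target gfx942 --target-name gemm_01_basic --jobs 8
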
import argparse
import os
import re
import shutil
import subprocess
import sys
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Dict, List, Tuple


def find_hipcc() -> str:
    for path in [os.environ.get("HIPCC"), "/opt/rocm/bin/hipcc", shutil.which("hipcc")]:
        if path and os.path.isfile(path):
            return path
    return "hipcc"


def find_ar() -> str:
    for path in [
        "/opt/rocm/llvm/bin/llvm-ar",
        shutil.which("llvm-ar"),
        shutil.which("ar"),
    ]:
        if path and os.path.isfile(path):
            return path
    return "ar"


def extract_balanced_parens(text: str, start_pos: int) -> str:
    """Extract content between balanced parentheses."""
    if start_pos >= len(text) or text[start_pos] != "(":
        return ""
    depth = 0
    for i, c in enumerate(text[start_pos:], start_pos):
        if c == "(":
            depth += 1
        elif c == ")":
            depth -= 1
            if depth == 0:
                return text[start_pos + 1 : i]
    return ""

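# Quick illustration of extract_balanced_parens (verifiable from the code above):
#   extract_balanced_parens('foo(a, (b))', 3) -> 'a, (b)'
#   extract_balanced_parens('no parens here', 0) -> ''
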
def parse_conv_declarations(content: str) -> List[Dict]:
    """Parse DECL_CONV_KERNEL_SET declarations with all parameters."""
    kernels = []

    for match in re.finditer(r"DECL_CONV_KERNEL_SET\s*\(", content):
        body = extract_balanced_parens(content, match.end() - 1)
        if not body:
            continue

        # Parse each .add() call
        for add_match in re.finditer(r"\.add\s*\(", body):
            add_body = extract_balanced_parens(body, add_match.end() - 1)

            kernel = {}

            # ConvSig parameters - handle both single dtype and multi-dtype
            # Multi-dtype: .dtype("fp16", "fp16", "fp16", "fp32") or .dtype("fp16", "bf16", "fp16")
            if m := re.search(
                r'\.dtype\s*\(\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"(?:\s*,\s*"([^"]+)")?\s*\)',
                add_body,
            ):
                kernel["dtype_in"] = m.group(1)
                kernel["dtype_wei"] = m.group(2)
                kernel["dtype_out"] = m.group(3)
                kernel["dtype_acc"] = m.group(4) if m.group(4) else "fp32"
                kernel["dtype"] = m.group(1)  # Default for codegen
            # Single dtype: .dtype("fp16")
            elif m := re.search(r'\.dtype\s*\(\s*"([^"]+)"\s*\)', add_body):
                kernel["dtype"] = m.group(1)
                kernel["dtype_in"] = m.group(1)
                kernel["dtype_wei"] = m.group(1)
                kernel["dtype_out"] = m.group(1)
                kernel["dtype_acc"] = "fp32"
            if m := re.search(r'\.layout\s*\(\s*"([^"]+)"', add_body):
                kernel["layout"] = m.group(1)
            if m := re.search(r'\.conv_type\s*\(\s*"([^"]+)"', add_body):
                kernel["conv_type"] = m.group(1)
            if m := re.search(r"\.dims\s*\(\s*(\d+)\s*\)", add_body):
                kernel["ndim"] = int(m.group(1))

            # ConvAlgo parameters - tile(G, M, N) where G=batch, M=output, N=reduction
            if m := re.search(
                r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body
            ):
                kernel["tile_g"] = int(m.group(1))  # batch tile (usually 1)
                kernel["tile_m"] = int(m.group(2))  # output channel tile
                kernel["tile_n"] = int(m.group(3))  # input channel tile (reduction)

            # wave(M_Warp, N_Warp, K_Warp) - warp distribution
            if m := re.search(
                r"\.wave\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body
            ):
                kernel["warp_m"] = int(m.group(1))
                kernel["warp_n"] = int(m.group(2))
                kernel["warp_k"] = int(m.group(3))

            # warp(M_Warp_Tile, N_Warp_Tile, K_Warp_Tile) - warp tile sizes
            if m := re.search(
                r"\.warp\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body
            ):
                kernel["warp_tile_m"] = int(m.group(1))
                kernel["warp_tile_n"] = int(m.group(2))
                kernel["warp_tile_k"] = int(m.group(3))

            # vector_sizes(A, B, C)
            if m := re.search(
                r"\.vector_sizes\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body
            ):
                kernel["vector_a"] = int(m.group(1))
                kernel["vector_b"] = int(m.group(2))
                kernel["vector_c"] = int(m.group(3))

            # Single-value parameters
            if m := re.search(r'\.pipeline\s*\(\s*"([^"]+)"', add_body):
                kernel["pipeline"] = m.group(1)
            if m := re.search(r'\.scheduler\s*\(\s*"([^"]+)"', add_body):
                kernel["scheduler"] = m.group(1)
            if m := re.search(r'\.epilogue\s*\(\s*"([^"]+)"', add_body):
                kernel["epilogue"] = m.group(1)
            if m := re.search(r"\.block_per_cu\s*\(\s*(\d+)\s*\)", add_body):
                kernel["block_per_cu"] = int(m.group(1))
            if m := re.search(r"\.num_wave_groups\s*\(\s*(\d+)\s*\)", add_body):
                kernel["num_wave_groups"] = int(m.group(1))
            if m := re.search(r"\.num_groups_to_merge\s*\(\s*(\d+)\s*\)", add_body):
                kernel["num_groups_to_merge"] = int(m.group(1))
            if m := re.search(
                r"\.double_smem_buffer\s*\(\s*(true|false)\s*\)", add_body, re.I
            ):
                kernel["double_smem_buffer"] = m.group(1).lower() == "true"

            # Architecture
            if m := re.search(r'"(gfx\d+)"', add_body):
                kernel["arch"] = m.group(1)

            if kernel.get("dtype"):
                # Auto-fill missing parameters with defaults (autocorrect)
                kernel = auto_fill_conv_defaults(kernel)
                kernels.append(kernel)

    return kernels

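# Illustrative (not verbatim) shape of a declaration the parser above handles; the
# builder-method names come from the regexes, while the surrounding C++ macro syntax
# is an assumption - the real definition lives in the C++ examples:
#   DECL_CONV_KERNEL_SET(
#       ...
#           .add(... .dtype("fp16").layout("nhwgc").conv_type("forward").dims(2)
#                    .tile(1, 16, 64).wave(1, 4, 1).warp(16, 16, 32) ...)
#   )
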
def auto_fill_conv_defaults(kernel: Dict) -> Dict:
    """Auto-fill missing conv parameters with sensible defaults (autofill + autocorrect).

    This implements:
    1. AUTOFILL: Missing parameters are filled with valid defaults (ConvConfigComputeV3)
    2. AUTOCORRECT: Invalid values are corrected to valid ones
    """
    # Default tile configuration matching ConvConfigComputeV3
    defaults = {
        "tile_g": 1,
        "tile_m": 16,
        "tile_n": 64,
        "warp_m": 1,
        "warp_n": 4,
        "warp_k": 1,
        "warp_tile_m": 16,
        "warp_tile_n": 16,
        "warp_tile_k": 32,
        "pipeline": "compv3",
        "scheduler": "intrawave",
        "epilogue": "cshuffle",
        "vector_a": 4,
        "vector_b": 8,
        "vector_c": 8,
        "block_per_cu": 1,
        "num_wave_groups": 1,
        "num_groups_to_merge": 1,
        "ndim": 2,
        "layout": "nhwgc",
        "conv_type": "forward",
        "arch": "gfx942",
    }

    # AUTOFILL: Fill missing parameters with defaults
    autofilled = []
    for key, value in defaults.items():
        if key not in kernel or kernel[key] is None or kernel[key] == -1:
            kernel[key] = value
            autofilled.append(f"{key}={value}")

    if autofilled:
        print(f"  [AUTOFILL] {', '.join(autofilled)}")

    # AUTOCORRECT: Fix invalid wave configurations for gfx942
    valid_wave_configs = [(1, 4, 1), (2, 2, 1), (4, 1, 1)]
    current_wave = (
        kernel.get("warp_m", 1),
        kernel.get("warp_n", 4),
        kernel.get("warp_k", 1),
    )

    if current_wave not in valid_wave_configs:
        old = current_wave
        kernel["warp_m"] = 1
        kernel["warp_n"] = 4
        kernel["warp_k"] = 1
        print(f"  [AUTOCORRECT] wave{old} -> wave(1,4,1) (invalid for gfx942)")

    # AUTOCORRECT: Fix invalid pipeline for backward ops
    conv_type = kernel.get("conv_type", "forward")
    pipeline = kernel.get("pipeline", "compv3")

    if conv_type in ["bwd_data", "bwd_weight"] and pipeline in ["compv4", "compv5"]:
        old_pipeline = pipeline
        kernel["pipeline"] = "compv3"
        print(
            f"  [AUTOCORRECT] pipeline {old_pipeline} -> compv3 (invalid for {conv_type})"
        )

    return kernel


def expand_conv_wildcards(kernel: Dict, arch: str = "gfx942") -> List[Dict]:
    """Expand wildcard parameters to multiple valid configurations.

    When users specify wildcards (-1 or *), this expands them to all
    valid configurations for the target architecture.
    """
    expanded = []

    # Valid wave configurations for gfx942
    valid_wave_configs = [(1, 4, 1), (2, 2, 1), (4, 1, 1)]

    # Valid warp tile configurations for gfx942 fp16
    valid_warp_configs = [(16, 16, 32), (32, 32, 16)]

    # Check if expansion is needed
    needs_wave = kernel.get("warp_m") is None or kernel.get("warp_m") == -1
    needs_warp = kernel.get("warp_tile_m") is None or kernel.get("warp_tile_m") == -1

    if not needs_wave and not needs_warp:
        return [kernel]

    # Expand wave configurations
    wave_configs = (
        valid_wave_configs
        if needs_wave
        else [
            (kernel.get("warp_m", 2), kernel.get("warp_n", 2), kernel.get("warp_k", 1))
        ]
    )

    # Expand warp tile configurations
    warp_configs = (
        valid_warp_configs
        if needs_warp
        else [
            (
                kernel.get("warp_tile_m", 32),
                kernel.get("warp_tile_n", 32),
                kernel.get("warp_tile_k", 16),
            )
        ]
    )

    for wm, wn, wk in wave_configs:
        for wtm, wtn, wtk in warp_configs:
            new_kernel = kernel.copy()
            new_kernel["warp_m"] = wm
            new_kernel["warp_n"] = wn
            new_kernel["warp_k"] = wk
            new_kernel["warp_tile_m"] = wtm
            new_kernel["warp_tile_n"] = wtn
            new_kernel["warp_tile_k"] = wtk
            expanded.append(new_kernel)

    return expanded


def parse_int_or_wildcard(val: str) -> int:
    """Parse integer or return -1 for wildcards.

    Supported wildcard formats:
    - ANY_INT: Macro defined as -1
    - -1: Direct numeric wildcard
    - "*": String wildcard (also maps to -1 for integer params)
    """
    val = val.strip()
    if val == "ANY_INT" or val == "-1" or val == "*":
        return -1
    return int(val)

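# For instance, parse_int_or_wildcard("ANY_INT"), parse_int_or_wildcard("-1") and
# parse_int_or_wildcard("*") all return -1, while parse_int_or_wildcard("64") returns 64.
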
def parse_gemm_declarations(content: str) -> List[Dict]:
    """Parse DECL_KERNEL_SET declarations for GEMM.

    Supports wildcards:
    - ANY_INT for numeric params (wave, warp) -> expands to all valid combos
    - "*" for string params (pipeline, scheduler) -> expands to valid options

    Each kernel is tagged with its kernel_set name for separate registration.
    """
    kernels = []

    for match in re.finditer(r"DECL_KERNEL_SET\s*\(\s*(\w+)\s*,", content):
        kernel_set_name = match.group(1)
        body = extract_balanced_parens(
            content, match.start() + content[match.start() :].find("(")
        )
        if not body:
            continue

        for add_match in re.finditer(r"\.add\s*\(", body):
            add_body = extract_balanced_parens(body, add_match.end() - 1)

            kernel = {}

            # Signature parameters
            if m := re.search(r'\.dtype\s*\(\s*"([^"]+)"', add_body):
                kernel["dtype"] = m.group(1)
            if m := re.search(r'\.layout\s*\(\s*"([^"]+)"', add_body):
                kernel["layout"] = m.group(1)
            if m := re.search(r'\.elementwise\s*\(\s*"([^"]+)"\s*,\s*(\d+)', add_body):
                kernel["elementwise_op"] = m.group(1)
                kernel["num_d_tensors"] = int(m.group(2))

            # Algorithm parameters - support ANY_INT wildcard
            if m := re.search(
                r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body
            ):
                kernel["tile_m"] = int(m.group(1))
                kernel["tile_n"] = int(m.group(2))
                kernel["tile_k"] = int(m.group(3))

            # Wave: support ANY_INT, -1, and "*" as wildcards
            if m := re.search(
                r"\.wave\s*\(\s*([\w*-]+)\s*,\s*([\w*-]+)\s*,\s*([\w*-]+)\s*\)",
                add_body,
            ):
                kernel["warp_m"] = parse_int_or_wildcard(m.group(1))
                kernel["warp_n"] = parse_int_or_wildcard(m.group(2))
                kernel["warp_k"] = parse_int_or_wildcard(m.group(3))

            # Warp: support ANY_INT, -1, and "*" as wildcards
            if m := re.search(
                r"\.warp\s*\(\s*([\w*-]+)\s*,\s*([\w*-]+)\s*,\s*([\w*-]+)\s*\)",
                add_body,
            ):
                kernel["warp_tile_m"] = parse_int_or_wildcard(m.group(1))
                kernel["warp_tile_n"] = parse_int_or_wildcard(m.group(2))
                kernel["warp_tile_k"] = parse_int_or_wildcard(m.group(3))

            # Pipeline/Scheduler: support "*" wildcard
            if m := re.search(r'\.pipeline\s*\(\s*"([^"]+)"', add_body):
                kernel["pipeline"] = m.group(1)
            if m := re.search(r'\.scheduler\s*\(\s*"([^"]+)"', add_body):
                kernel["scheduler"] = m.group(1)
            if m := re.search(r'\.epilogue\s*\(\s*"([^"]+)"', add_body):
                kernel["epilogue"] = m.group(1)
            if m := re.search(
                r"\.pad\s*\(\s*(true|false)\s*,\s*(true|false)\s*,\s*(true|false)",
                add_body,
                re.I,
            ):
                kernel["pad_m"] = m.group(1).lower() == "true"
                kernel["pad_n"] = m.group(2).lower() == "true"
                kernel["pad_k"] = m.group(3).lower() == "true"

            # Shorthand format: .add("dtype", "layout", M, N, K)
            if not kernel.get("dtype"):
                if m := re.match(
                    r'\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)',
                    add_body,
                ):
                    kernel["dtype"] = m.group(1)
                    kernel["layout"] = m.group(2)
                    kernel["tile_m"] = int(m.group(3))
                    kernel["tile_n"] = int(m.group(4))
                    kernel["tile_k"] = int(m.group(5))

            if kernel.get("dtype"):
                kernel["kernel_set"] = kernel_set_name
                kernels.append(kernel)

    # Expand wildcards to multiple kernels
    expanded = []
    for kernel in kernels:
        expanded.extend(expand_gemm_wildcards(kernel))

    # Apply autocorrect to each expanded kernel
    return [auto_fill_gemm_defaults(k) for k in expanded]

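# The shorthand form accepted above means a declaration as terse as
#   .add("fp16", "rcr", 128, 128, 64)
# yields a kernel with dtype, layout, and tile sizes set, leaving everything else to
# auto_fill_gemm_defaults(); the DECL_KERNEL_SET macro wrapping such calls is defined
# in the C++ examples, not in this script.
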
def expand_gemm_wildcards(kernel: Dict, arch: str = "gfx942") -> List[Dict]:
    """Expand wildcard parameters to multiple valid configurations.

    When users specify ANY_INT (-1) or "*", this expands them to all
    valid configurations for the target architecture.

    Note: Block size constraint filters invalid combos:
    - (tile_m/warp_tile_m) * (tile_n/warp_tile_n) * 64 <= 1024
    - For 128x128 tile: only (32,32,k) works (16 warps * 64 = 1024)
    - For 64x64 tile: both (16,16,k) and (32,32,k) work
    """
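    # Worked example of the constraint above: with a 128x128 tile and 16x16 warp
    # tiles, (128/16) * (128/16) = 64 warps -> 64 * 64 = 4096 threads > 1024, so
    # that combination is skipped; 32x32 warp tiles give 16 warps -> exactly 1024.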
    # Valid wave configurations for gfx942
    valid_wave_configs = [(1, 4, 1), (2, 2, 1), (4, 1, 1)]

    # Valid warp tile configurations for gfx942 fp16
    valid_warp_configs = [(16, 16, 32), (32, 32, 16)]

    # Valid pipelines and schedulers
    valid_pipelines = ["compv3"]  # compv4 requires special handling
    valid_schedulers = ["intrawave"]

    # Check what needs expansion
    needs_wave = kernel.get("warp_m") == -1
    needs_warp = kernel.get("warp_tile_m") == -1
    needs_pipeline = kernel.get("pipeline") == "*"
    needs_scheduler = kernel.get("scheduler") == "*"

    if not any([needs_wave, needs_warp, needs_pipeline, needs_scheduler]):
        return [kernel]

    # Determine configs to iterate
    wave_configs = (
        valid_wave_configs
        if needs_wave
        else [
            (kernel.get("warp_m", 2), kernel.get("warp_n", 2), kernel.get("warp_k", 1))
        ]
    )
    warp_configs = (
        valid_warp_configs
        if needs_warp
        else [
            (
                kernel.get("warp_tile_m", 32),
                kernel.get("warp_tile_n", 32),
                kernel.get("warp_tile_k", 16),
            )
        ]
    )
    pipelines = (
        valid_pipelines if needs_pipeline else [kernel.get("pipeline", "compv3")]
    )
    schedulers = (
        valid_schedulers if needs_scheduler else [kernel.get("scheduler", "intrawave")]
    )

    expanded = []
    for wm, wn, wk in wave_configs:
        for wtm, wtn, wtk in warp_configs:
            # Check block size constraint: (tile_m/warp_tile_m) * (tile_n/warp_tile_n) * 64 <= 1024
            tile_m = kernel.get("tile_m", 128)
            tile_n = kernel.get("tile_n", 128)
            num_warps = (tile_m // wtm) * (tile_n // wtn)
            if num_warps * 64 > 1024:
                continue  # Skip invalid config

            for pipe in pipelines:
                for sched in schedulers:
                    new_kernel = kernel.copy()
                    new_kernel["warp_m"] = wm
                    new_kernel["warp_n"] = wn
                    new_kernel["warp_k"] = wk
                    new_kernel["warp_tile_m"] = wtm
                    new_kernel["warp_tile_n"] = wtn
                    new_kernel["warp_tile_k"] = wtk
                    new_kernel["pipeline"] = pipe
                    new_kernel["scheduler"] = sched
                    expanded.append(new_kernel)

    if expanded:
        print(f"  [WILDCARD] Expanded 1 declaration -> {len(expanded)} kernel(s)")

    return expanded if expanded else [kernel]


def auto_fill_gemm_defaults(kernel: Dict) -> Dict:
|
|
"""Auto-fill missing GEMM parameters with sensible defaults (autofill + autocorrect).
|
|
|
|
This implements:
|
|
1. AUTOFILL: Missing parameters are filled with valid defaults
|
|
2. AUTOCORRECT: Invalid values are corrected to valid ones (e.g., wave(1,1,1) -> wave(2,2,1))
|
|
"""
|
|
defaults = {
|
|
"tile_m": 128,
|
|
"tile_n": 128,
|
|
"tile_k": 64,
|
|
"warp_m": 2,
|
|
"warp_n": 2,
|
|
"warp_k": 1,
|
|
"warp_tile_m": 32,
|
|
"warp_tile_n": 32,
|
|
"warp_tile_k": 16,
|
|
"pipeline": "compv3",
|
|
"scheduler": "intrawave",
|
|
"epilogue": "cshuffle",
|
|
"pad_m": False,
|
|
"pad_n": False,
|
|
"pad_k": False,
|
|
"layout": "rcr",
|
|
}
|
|
|
|
# AUTOFILL: Fill missing parameters with defaults
|
|
autofilled = []
|
|
for key, value in defaults.items():
|
|
if key not in kernel or kernel[key] is None or kernel[key] == -1:
|
|
kernel[key] = value
|
|
autofilled.append(f"{key}={value}")
|
|
|
|
if autofilled:
|
|
print(f" [AUTOFILL] {', '.join(autofilled)}")
|
|
|
|
# AUTOCORRECT: Fix invalid wave configurations for gfx942
|
|
# Valid wave configs: (1,4,1), (2,2,1), (4,1,1)
|
|
valid_wave_configs = [(1, 4, 1), (2, 2, 1), (4, 1, 1)]
|
|
current_wave = (
|
|
kernel.get("warp_m", 2),
|
|
kernel.get("warp_n", 2),
|
|
kernel.get("warp_k", 1),
|
|
)
|
|
|
|
if current_wave not in valid_wave_configs:
|
|
# Correct to (2,2,1) which is a balanced default
|
|
old = current_wave
|
|
kernel["warp_m"] = 2
|
|
kernel["warp_n"] = 2
|
|
kernel["warp_k"] = 1
|
|
print(f" [AUTOCORRECT] wave{old} -> wave(2,2,1) (invalid for gfx942)")
|
|
|
|
# AUTOCORRECT: Fix invalid pipeline/scheduler combinations
|
|
invalid_combos = [
|
|
("compv3", "interwave"),
|
|
("compv4", "interwave"),
|
|
]
|
|
current_combo = (
|
|
kernel.get("pipeline", "compv3"),
|
|
kernel.get("scheduler", "intrawave"),
|
|
)
|
|
if current_combo in invalid_combos:
|
|
old = current_combo
|
|
kernel["scheduler"] = "intrawave"
|
|
print(
|
|
f" [AUTOCORRECT] {old[0]}/{old[1]} -> {old[0]}/intrawave (invalid combo)"
|
|
)
|
|
|
|
# AUTOCORRECT: Fix warp tile to avoid exceeding max block size (1024 threads)
|
|
# Block size = (tile_m / warp_tile_m) * (tile_n / warp_tile_n) * 64
|
|
tile_m = kernel.get("tile_m", 128)
|
|
tile_n = kernel.get("tile_n", 128)
|
|
warp_tile_m = kernel.get("warp_tile_m", 32)
|
|
warp_tile_n = kernel.get("warp_tile_n", 32)
|
|
|
|
num_warps = (tile_m // warp_tile_m) * (tile_n // warp_tile_n)
|
|
block_size = num_warps * 64 # 64 threads per warp
|
|
|
|
if block_size > 1024:
|
|
# Find valid warp tile that fits
|
|
old_warp = (warp_tile_m, warp_tile_n, kernel.get("warp_tile_k", 16))
|
|
|
|
# For large tiles, use larger warp tiles
|
|
if tile_m >= 256:
|
|
kernel["warp_tile_m"] = 64
|
|
if tile_n >= 256:
|
|
kernel["warp_tile_n"] = 64
|
|
|
|
# Recalculate
|
|
num_warps = (tile_m // kernel["warp_tile_m"]) * (
|
|
tile_n // kernel["warp_tile_n"]
|
|
)
|
|
block_size = num_warps * 64
|
|
|
|
if block_size <= 1024:
|
|
new_warp = (
|
|
kernel["warp_tile_m"],
|
|
kernel["warp_tile_n"],
|
|
kernel["warp_tile_k"],
|
|
)
|
|
print(
|
|
f" [AUTOCORRECT] warp{old_warp} -> warp{new_warp} (block_size={block_size})"
|
|
)
|
|
else:
|
|
# Still too large, try even larger warp tiles
|
|
kernel["warp_tile_m"] = tile_m // 4
|
|
kernel["warp_tile_n"] = tile_n // 4
|
|
new_warp = (
|
|
kernel["warp_tile_m"],
|
|
kernel["warp_tile_n"],
|
|
kernel["warp_tile_k"],
|
|
)
|
|
print(
|
|
f" [AUTOCORRECT] warp{old_warp} -> warp{new_warp} (block_size adjusted)"
|
|
)
|
|
|
|
return kernel
|
|
|
|
|
|
def strip_cpp_strings_and_comments(content: str) -> str:
|
|
"""Strip C++ string literals and comments that could cause false positives.
|
|
|
|
Only strips:
|
|
- Comments (// and /* */) - always stripped
|
|
- Raw string literals (R"...") - always stripped (can contain anything)
|
|
- Regular strings ONLY if they contain problematic patterns like DECL_KERNEL_SET
|
|
|
|
Preserves normal string literals like "fp16", "rcr" which are needed for parsing.
|
|
"""
|
|
result = []
|
|
i = 0
|
|
n = len(content)
|
|
|
|
# Patterns that indicate a string is problematic and should be stripped
|
|
problematic_patterns = ["DECL_KERNEL_SET", "DECL_CONV_KERNEL_SET", ".add("]
|
|
|
|
while i < n:
|
|
# Check for raw string literal: R"delimiter(...)delimiter"
|
|
# Always strip these as they can contain arbitrary content
|
|
if i < n - 1 and content[i] == "R" and content[i + 1] == '"':
|
|
# Find the delimiter (between R" and ()
|
|
j = i + 2
|
|
delimiter_start = j
|
|
while j < n and content[j] != "(":
|
|
j += 1
|
|
delimiter = content[delimiter_start:j]
|
|
# Find the closing )delimiter"
|
|
end_marker = ")" + delimiter + '"'
|
|
end_pos = content.find(end_marker, j + 1)
|
|
if end_pos != -1:
|
|
# Replace with spaces to preserve line numbers
|
|
span = content[i : end_pos + len(end_marker)]
|
|
result.append("".join("\n" if c == "\n" else " " for c in span))
|
|
i = end_pos + len(end_marker)
|
|
continue
|
|
|
|
# Check for regular string literal - only strip if it contains problematic patterns
|
|
if content[i] == '"':
|
|
j = i + 1
|
|
while j < n:
|
|
if content[j] == "\\" and j + 1 < n:
|
|
j += 2 # Skip escaped character
|
|
elif content[j] == '"':
|
|
j += 1
|
|
break
|
|
else:
|
|
j += 1
|
|
string_content = content[i:j]
|
|
|
|
# Only strip if this string contains problematic patterns
|
|
should_strip = any(pat in string_content for pat in problematic_patterns)
|
|
if should_strip:
|
|
result.append(" " * len(string_content))
|
|
else:
|
|
result.append(string_content)
|
|
i = j
|
|
continue
|
|
|
|
# Check for single-line comment - always strip
|
|
if i < n - 1 and content[i : i + 2] == "//":
|
|
j = i
|
|
while j < n and content[j] != "\n":
|
|
j += 1
|
|
result.append(" " * (j - i))
|
|
i = j
|
|
continue
|
|
|
|
# Check for multi-line comment - always strip
|
|
if i < n - 1 and content[i : i + 2] == "/*":
|
|
end_pos = content.find("*/", i + 2)
|
|
if end_pos != -1:
|
|
span = content[i : end_pos + 2]
|
|
# Preserve newlines in multi-line comments
|
|
result.append("".join("\n" if c == "\n" else " " for c in span))
|
|
i = end_pos + 2
|
|
continue
|
|
|
|
result.append(content[i])
|
|
i += 1
|
|
|
|
return "".join(result)
|
|
|
|
|
|
def detect_and_parse(source_path: Path) -> Tuple[str, List[Dict]]:
    """Detect example type and parse kernel declarations.

    Properly strips string literals and comments before parsing to avoid
    picking up declarations inside strings or commented-out code.
    """
    content = source_path.read_text()
    content = strip_cpp_strings_and_comments(content)

    if "DECL_CONV_KERNEL_SET" in content:
        return "conv", parse_conv_declarations(content)
    elif "DECL_KERNEL_SET" in content:
        return "gemm", parse_gemm_declarations(content)
    return "unknown", []


def generate_gemm_registration(
|
|
kernel_headers: List[Path], example_name: str, kernels: List[Dict] = None
|
|
) -> str:
|
|
"""Generate GEMM kernel registration code for the dispatcher registry.
|
|
|
|
Uses GeneratedKernelInstance<SelectedKernel> to wrap the generated kernels
|
|
and provide the KernelInstance interface for the Dispatcher.
|
|
|
|
If kernels list is provided with kernel_set info, generates separate
|
|
registration functions per kernel set.
|
|
"""
|
|
if not kernel_headers:
|
|
return " // No kernels to register"
|
|
|
|
# Build mapping from kernel config pattern to kernel set
|
|
kernel_to_set = {}
|
|
kernel_sets = set()
|
|
if kernels:
|
|
for k in kernels:
|
|
tile_m = k.get("tile_m", 128)
|
|
tile_n = k.get("tile_n", 128)
|
|
tile_k = k.get("tile_k", 64)
|
|
warp_m = k.get("warp_m", 2)
|
|
warp_n = k.get("warp_n", 2)
|
|
warp_k = k.get("warp_k", 1)
|
|
warp_tile_m = k.get("warp_tile_m", 32)
|
|
warp_tile_n = k.get("warp_tile_n", 32)
|
|
warp_tile_k = k.get("warp_tile_k", 16)
|
|
|
|
# Pattern that appears in kernel filename
|
|
key_pattern = f"{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
|
|
kernel_set = k.get("kernel_set", "default")
|
|
kernel_to_set[key_pattern] = kernel_set
|
|
kernel_sets.add(kernel_set)
|
|
|
|
def generate_registration_block(h: Path) -> str:
|
|
"""Generate registration code for a single kernel."""
|
|
kernel_name = h.stem
|
|
ns = f"ns_{kernel_name}"
|
|
|
|
# Parse pipeline, scheduler, and layout from kernel name
|
|
# Format: gemm_fp16_rcr_compv3_cshuffle_intrawave_...
|
|
parts = kernel_name.split("_")
|
|
pipeline = "CompV3"
|
|
scheduler = "Intrawave"
|
|
epilogue = "CShuffle"
|
|
datatype = "FP16"
|
|
layout_a = "RowMajor"
|
|
layout_b = "ColMajor"
|
|
layout_c = "RowMajor"
|
|
|
|
# Parse datatype (e.g., fp16, bf16, fp32)
|
|
dtype_map = {
|
|
"fp16": "FP16",
|
|
"bf16": "BF16",
|
|
"fp32": "FP32",
|
|
"fp64": "FP64",
|
|
"int8": "INT8",
|
|
}
|
|
|
|
# Parse layout from 3-char codes (e.g., rcr, rrr, rrc, ccc)
|
|
# r = RowMajor, c = ColMajor
|
|
layout_map = {"r": "RowMajor", "c": "ColMajor"}
|
|
|
|
# Find pipeline, epilogue, scheduler in the name parts
|
|
pipeline_map = {
|
|
"mem": "Mem",
|
|
"compv1": "CompV1",
|
|
"compv2": "CompV2",
|
|
"compv3": "CompV3",
|
|
"compv4": "CompV4",
|
|
"compv5": "CompV5",
|
|
"preshufflev1": "PreShuffleV1",
|
|
"preshufflev2": "PreShuffleV2",
|
|
}
|
|
scheduler_map = {
|
|
"intrawave": "Intrawave",
|
|
"interwave": "Interwave",
|
|
"auto": "Auto",
|
|
}
|
|
epilogue_map = {"default": "Default", "cshuffle": "CShuffle", "none": "None"}
|
|
|
|
for part in parts:
|
|
if part in pipeline_map:
|
|
pipeline = pipeline_map[part]
|
|
if part in scheduler_map:
|
|
scheduler = scheduler_map[part]
|
|
if part in epilogue_map:
|
|
epilogue = epilogue_map[part]
|
|
if part in dtype_map:
|
|
datatype = dtype_map[part]
|
|
# Parse 3-char layout codes (e.g., rcr, rrr)
|
|
if len(part) == 3 and all(c in "rc" for c in part):
|
|
layout_a = layout_map[part[0]]
|
|
layout_b = layout_map[part[1]]
|
|
layout_c = layout_map[part[2]]
|
|
|
|
block = []
|
|
block.append(f" // Register kernel: {kernel_name}")
|
|
block.append(" {")
|
|
block.append(f" using SelectedKernel = {ns}::SelectedKernel;")
|
|
block.append(" ck_tile::dispatcher::KernelKey key;")
|
|
block.append(
|
|
f" key.signature.dtype_a = ck_tile::dispatcher::DataType::{datatype};"
|
|
)
|
|
block.append(
|
|
f" key.signature.dtype_b = ck_tile::dispatcher::DataType::{datatype};"
|
|
)
|
|
block.append(
|
|
f" key.signature.dtype_c = ck_tile::dispatcher::DataType::{datatype};"
|
|
)
|
|
block.append(
|
|
" key.signature.dtype_acc = ck_tile::dispatcher::DataType::FP32;"
|
|
)
|
|
block.append(
|
|
f" key.signature.layout_a = ck_tile::dispatcher::LayoutTag::{layout_a};"
|
|
)
|
|
block.append(
|
|
f" key.signature.layout_b = ck_tile::dispatcher::LayoutTag::{layout_b};"
|
|
)
|
|
block.append(
|
|
f" key.signature.layout_c = ck_tile::dispatcher::LayoutTag::{layout_c};"
|
|
)
|
|
block.append(" key.algorithm.tile_shape.m = SelectedKernel::TileM;")
|
|
block.append(" key.algorithm.tile_shape.n = SelectedKernel::TileN;")
|
|
block.append(" key.algorithm.tile_shape.k = SelectedKernel::TileK;")
|
|
block.append(
|
|
" key.algorithm.wave_shape.m = SelectedKernel::WarpPerBlock_M;"
|
|
)
|
|
block.append(
|
|
" key.algorithm.wave_shape.n = SelectedKernel::WarpPerBlock_N;"
|
|
)
|
|
block.append(
|
|
" key.algorithm.wave_shape.k = SelectedKernel::WarpPerBlock_K;"
|
|
)
|
|
block.append(
|
|
" key.algorithm.warp_tile_shape.m = SelectedKernel::WarpTileM;"
|
|
)
|
|
block.append(
|
|
" key.algorithm.warp_tile_shape.n = SelectedKernel::WarpTileN;"
|
|
)
|
|
block.append(
|
|
" key.algorithm.warp_tile_shape.k = SelectedKernel::WarpTileK;"
|
|
)
|
|
block.append(
|
|
" key.algorithm.block_size = SelectedKernel::BlockSize;"
|
|
)
|
|
block.append(
|
|
f" key.algorithm.pipeline = ck_tile::dispatcher::Pipeline::{pipeline};"
|
|
)
|
|
block.append(
|
|
f" key.algorithm.scheduler = ck_tile::dispatcher::Scheduler::{scheduler};"
|
|
)
|
|
block.append(
|
|
f" key.algorithm.epilogue = ck_tile::dispatcher::Epilogue::{epilogue};"
|
|
)
|
|
block.append(" key.gfx_arch = arch;")
|
|
block.append(
|
|
f' auto instance = std::make_shared<ck_tile::dispatcher::backends::GeneratedKernelInstance<SelectedKernel>>(key, "{kernel_name}");'
|
|
)
|
|
block.append(" registry.register_kernel(instance);")
|
|
block.append(" }")
|
|
return "\n".join(block)
|
|
|
|
def find_kernel_set(header: Path) -> str:
|
|
"""Find which kernel set a header belongs to."""
|
|
name = header.stem
|
|
for pattern, kset in kernel_to_set.items():
|
|
if pattern in name:
|
|
return kset
|
|
return "default"
|
|
|
|
# Group kernels by set
|
|
kernels_by_set = {}
|
|
for h in kernel_headers:
|
|
kset = find_kernel_set(h)
|
|
if kset not in kernels_by_set:
|
|
kernels_by_set[kset] = []
|
|
kernels_by_set[kset].append(h)
|
|
|
|
# If only one set or no set info, use simple registration
|
|
if len(kernels_by_set) <= 1:
|
|
lines = [" (void)arch;", ""]
|
|
for h in kernel_headers:
|
|
lines.append(generate_registration_block(h))
|
|
return "\n".join(lines)
|
|
|
|
# Multiple sets - generate registration for all, plus store per-set info
|
|
lines = [" // Register ALL kernels from all sets", " (void)arch;", ""]
|
|
for h in kernel_headers:
|
|
lines.append(generate_registration_block(h))
|
|
|
|
# Store per-set mapping for separate function generation
|
|
global _kernels_by_set_cache
|
|
_kernels_by_set_cache = (kernels_by_set, generate_registration_block)
|
|
|
|
return "\n".join(lines)
|
|
|
|
|
|
# Global cache for per-set kernel info
_kernels_by_set_cache = None


def generate_per_set_functions(source_stem: str) -> str:
    """Generate separate registration functions for each kernel set.

    Generates:
    1. Per-set functions: register_<set_name>(registry, arch)
    2. String-based dispatcher: register_kernel_set("set_name", registry, arch)
    3. get_kernel_set_names() to list available sets
    """
    global _kernels_by_set_cache
    if not _kernels_by_set_cache:
        return ""

    kernels_by_set, gen_block = _kernels_by_set_cache
    _kernels_by_set_cache = None  # Clear cache

    lines = []
    set_names = []

    # Generate per-set functions
    for set_name, headers in kernels_by_set.items():
        safe_name = set_name.replace("-", "_")
        set_names.append((set_name, safe_name))
        lines.append(
            f"inline void register_{safe_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{"
        )
        lines.append("    (void)arch;")
        for h in headers:
            lines.append(gen_block(h))
        lines.append("}")
        lines.append("")

    # Generate the string-based dispatcher whenever at least one set exists
    if len(set_names) > 0:
        lines.append("// Dynamic registration by kernel set name")
        lines.append(
            "inline bool register_kernel_set(const std::string& set_name, ck_tile::dispatcher::Registry& registry, const std::string& arch) {"
        )
        for set_name, safe_name in set_names:
            lines.append(
                f'    if (set_name == "{set_name}") {{ register_{safe_name}(registry, arch); return true; }}'
            )
        lines.append("    return false; // Unknown set name")
        lines.append("}")
        lines.append("")

    # Generate helper to list available set names
    lines.append("// Get list of available kernel set names")
    lines.append("inline std::vector<std::string> get_kernel_set_names() {")
    names_str = ", ".join(f'"{name}"' for name, _ in set_names)
    lines.append(f"    return {{{names_str}}};")
    lines.append("}")
    lines.append("")

    return "\n".join(lines)


def generate_conv_registration(
    kernel_headers: List[Path], example_name: str, kernels: List[Dict]
) -> str:
    """Generate Conv kernel registration code for the dispatcher registry."""
    if not kernel_headers:
        return "    // No kernels to register"

    lines = []
    lines.append(
        "    (void)registry; (void)arch; // Conv uses direct launcher pattern for now"
    )

    # For conv, we provide direct access to kernel launchers
    for i, h in enumerate(kernel_headers):
        kernel_name = h.stem
        lines.append(f"    // Kernel {i + 1}: {kernel_name}")

    return "\n".join(lines)


def generate_conv_kernels(
|
|
kernels: List[Dict], output_dir: Path, codegen_dir: Path
|
|
) -> bool:
|
|
"""Generate Conv kernels for ALL declarations using unified codegen."""
|
|
if not kernels:
|
|
return False
|
|
|
|
variant_map = {
|
|
"forward": "forward",
|
|
"bwd_data": "bwd_data",
|
|
"backward_data": "bwd_data",
|
|
"bwd_weight": "bwd_weight",
|
|
"backward_weight": "bwd_weight",
|
|
}
|
|
|
|
success_count = 0
|
|
|
|
# Generate a kernel for EACH declaration
|
|
for idx, k in enumerate(kernels):
|
|
variant = variant_map.get(k.get("conv_type", "forward"), "forward")
|
|
|
|
cmd = [
|
|
sys.executable,
|
|
str(codegen_dir / "unified_conv_codegen.py"),
|
|
"--datatype",
|
|
k.get("dtype", "fp16"),
|
|
"--variant",
|
|
variant,
|
|
"--ndim",
|
|
str(k.get("ndim", 2)),
|
|
"--output",
|
|
str(output_dir),
|
|
]
|
|
|
|
# Add optional parameters if specified
|
|
if k.get("tile_m"):
|
|
cmd.extend(["--tile-m", str(k["tile_m"])])
|
|
if k.get("tile_n"):
|
|
cmd.extend(["--tile-n", str(k["tile_n"])])
|
|
if k.get("warp_m"):
|
|
cmd.extend(["--warp-m", str(k["warp_m"])])
|
|
if k.get("warp_n"):
|
|
cmd.extend(["--warp-n", str(k["warp_n"])])
|
|
if k.get("warp_k"):
|
|
cmd.extend(["--warp-k", str(k["warp_k"])])
|
|
if k.get("warp_tile_m"):
|
|
cmd.extend(["--warp-tile-m", str(k["warp_tile_m"])])
|
|
if k.get("warp_tile_n"):
|
|
cmd.extend(["--warp-tile-n", str(k["warp_tile_n"])])
|
|
if k.get("warp_tile_k"):
|
|
cmd.extend(["--warp-tile-k", str(k["warp_tile_k"])])
|
|
if k.get("pipeline"):
|
|
cmd.extend(["--pipeline", k["pipeline"]])
|
|
if k.get("scheduler"):
|
|
cmd.extend(["--scheduler", k["scheduler"]])
|
|
if k.get("epilogue"):
|
|
cmd.extend(["--epilogue", k["epilogue"]])
|
|
if k.get("vector_a"):
|
|
cmd.extend(["--vector-a", str(k["vector_a"])])
|
|
if k.get("vector_b"):
|
|
cmd.extend(["--vector-b", str(k["vector_b"])])
|
|
if k.get("vector_c"):
|
|
cmd.extend(["--vector-c", str(k["vector_c"])])
|
|
if k.get("block_per_cu"):
|
|
cmd.extend(["--block-per-cu", str(k["block_per_cu"])])
|
|
if k.get("num_wave_groups"):
|
|
cmd.extend(["--num-wave-groups", str(k["num_wave_groups"])])
|
|
if k.get("num_groups_to_merge"):
|
|
cmd.extend(["--num-groups-to-merge", str(k["num_groups_to_merge"])])
|
|
if k.get("double_smem_buffer") is not None:
|
|
cmd.extend(["--double-smem-buffer", str(k["double_smem_buffer"]).lower()])
|
|
if k.get("tile_k"):
|
|
cmd.extend(["--tile-k", str(k["tile_k"])])
|
|
|
|
result = subprocess.run(
|
|
cmd, capture_output=True, text=True, cwd=str(codegen_dir)
|
|
)
|
|
if result.returncode != 0:
|
|
print(f" Codegen error for kernel {idx + 1}: {result.stderr[:300]}")
|
|
else:
|
|
success_count += 1
|
|
|
|
return success_count > 0
|
|
|
|
|
|
def generate_gemm_kernels(
    kernels: List[Dict], output_dir: Path, codegen_dir: Path
) -> bool:
    """Generate GEMM kernels for ALL declarations using unified codegen."""
    import json

    if not kernels:
        return False

    success_count = 0

    # Generate a kernel for EACH declaration
    for idx, k in enumerate(kernels):
        variant = "multi_d" if k.get("elementwise_op") else "standard"

        # Build tile config JSON for this specific kernel
        tile_config = {
            "tile_m": [k.get("tile_m", 128)],
            "tile_n": [k.get("tile_n", 128)],
            "tile_k": [k.get("tile_k", 32)],
            "warp_m": [k.get("warp_m", 2)],
            "warp_n": [k.get("warp_n", 2)],
            "warp_k": [k.get("warp_k", 1)],
            "warp_tile_m": [k.get("warp_tile_m", 32)],
            "warp_tile_n": [k.get("warp_tile_n", 32)],
            "warp_tile_k": [k.get("warp_tile_k", 16)],
        }

        trait_config = {
            "pipeline": [k.get("pipeline", "compv3")],
            "epilogue": [k.get("epilogue", "cshuffle")],
            "scheduler": [k.get("scheduler", "intrawave")],
            "pad_m": [k.get("pad_m", False)],
            "pad_n": [k.get("pad_n", False)],
            "pad_k": [k.get("pad_k", False)],
            "persistent": [False],
        }

        config_json = json.dumps(
            {"tile_config": tile_config, "trait_config": trait_config}
        )

        cmd = [
            sys.executable,
            str(codegen_dir / "unified_gemm_codegen.py"),
            "--datatype",
            k.get("dtype", "fp16"),
            "--layout",
            k.get("layout", "rcr"),
            "--variants",
            variant,
            "--output",
            str(output_dir),
            "--tile-config-json",
            config_json,
        ]

        result = subprocess.run(
            cmd, capture_output=True, text=True, cwd=str(codegen_dir)
        )
        if result.returncode != 0:
            print(f"  Codegen error for kernel {idx + 1}: {result.stderr[:300]}")
        else:
            success_count += 1

    return success_count > 0

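# For a single fp16 rcr declaration with all defaults, the --tile-config-json payload
# built above serializes to roughly:
#   {"tile_config": {"tile_m": [128], "tile_n": [128], "tile_k": [32], ...},
#    "trait_config": {"pipeline": ["compv3"], "epilogue": ["cshuffle"], ...}}
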
def compile_kernel(args: Tuple) -> Tuple[str, bool, str]:
    """Compile a single kernel to object file."""
    kernel_hpp, output_dir, include_dirs, hipcc, gpu_target, idx, total = args
    kernel_name = kernel_hpp.stem

    wrapper_cpp = output_dir / f"{kernel_name}.cpp"
    wrapper_cpp.write_text(
        f'#include "{kernel_hpp.name}"\nnamespace {{ volatile bool _k{idx} = true; }}\n'
    )

    obj_file = output_dir / f"{kernel_name}.o"

    cmd = [
        hipcc,
        "-c",
        "-fPIC",
        "-std=c++17",
        "-O3",
        f"--offload-arch={gpu_target}",
        "-mllvm",
        "-enable-noalias-to-md-conversion=0",
        "-Wno-undefined-func-template",
        "-Wno-float-equal",
        "--offload-compress",
    ]

    for inc_dir in include_dirs:
        cmd.extend(["-I", str(inc_dir)])
    cmd.extend(["-I", str(kernel_hpp.parent)])
    cmd.extend(["-o", str(obj_file), str(wrapper_cpp)])

    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        return (kernel_name, False, result.stderr[:500])
    return (kernel_name, True, str(obj_file))

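# The resulting compile line looks roughly like this (paths and names illustrative):
#   /opt/rocm/bin/hipcc -c -fPIC -std=c++17 -O3 --offload-arch=gfx942 \
#       -mllvm -enable-noalias-to-md-conversion=0 -Wno-undefined-func-template \
#       -Wno-float-equal --offload-compress -I <include dirs> -I <kernel dir> \
#       -o gemm_<config>.o gemm_<config>.cpp
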
def main():
|
|
parser = argparse.ArgumentParser(description="Build example kernels")
|
|
parser.add_argument("source", type=Path, help="C++ source file")
|
|
parser.add_argument("--output-dir", type=Path, required=True)
|
|
parser.add_argument("--include-dirs", type=str, required=True)
|
|
parser.add_argument("--gpu-target", type=str, default="gfx942")
|
|
parser.add_argument("--jobs", type=int, default=os.cpu_count())
|
|
parser.add_argument(
|
|
"--target-name", type=str, help="CMake target name (for library naming)"
|
|
)
|
|
args = parser.parse_args()
|
|
|
|
script_dir = Path(__file__).parent
|
|
codegen_dir = script_dir.parent / "codegen"
|
|
source_stem = args.source.stem # e.g., "01_basic_gemm"
|
|
target_name = args.target_name or source_stem # e.g., "gemm_01_basic" from CMake
|
|
|
|
args.output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Detect and parse
|
|
example_type, kernels = detect_and_parse(args.source)
|
|
|
|
if example_type == "conv":
|
|
k = kernels[0] if kernels else {}
|
|
variant = k.get("conv_type", "forward")
|
|
print(
|
|
f"[{target_name}] Conv {k.get('dtype', 'fp16')} {variant} {k.get('ndim', 2)}D ({len(kernels)} declarations)"
|
|
)
|
|
elif example_type == "gemm":
|
|
k = kernels[0] if kernels else {}
|
|
print(
|
|
f"[{target_name}] GEMM {k.get('dtype', 'fp16')} {k.get('layout', 'rcr')} ({len(kernels)} declarations)"
|
|
)
|
|
else:
|
|
print(f"[{target_name}] No kernel declarations - creating empty library")
|
|
lib_path = args.output_dir / f"lib{target_name}_kernels.a"
|
|
subprocess.run([find_ar(), "rcs", str(lib_path)], check=True)
|
|
header = args.output_dir / f"{source_stem}_kernels.hpp"
|
|
header.write_text(f"// No kernels for {target_name}\n#pragma once\n")
|
|
return 0
|
|
|
|
# Generate kernels
|
|
print(f"[{target_name}] Generating kernels...")
|
|
if example_type == "conv":
|
|
success = generate_conv_kernels(kernels, args.output_dir, codegen_dir)
|
|
else:
|
|
success = generate_gemm_kernels(kernels, args.output_dir, codegen_dir)
|
|
|
|
if not success:
|
|
print(f"[{target_name}] Kernel generation failed!")
|
|
return 1
|
|
|
|
# Find generated headers
|
|
if example_type == "gemm":
|
|
kernel_headers = list(args.output_dir.glob("gemm_*.hpp"))
|
|
else:
|
|
k = kernels[0] if kernels else {}
|
|
variant = k.get("conv_type", "forward")
|
|
prefix_map = {
|
|
"forward": "conv_fwd",
|
|
"bwd_data": "conv_bwdd",
|
|
"bwd_weight": "conv_bwdw",
|
|
}
|
|
prefix = prefix_map.get(variant, "conv_fwd")
|
|
kernel_headers = list(args.output_dir.glob(f"{prefix}_*.hpp"))
|
|
|
|
if not kernel_headers:
|
|
print(f"[{target_name}] No kernel headers generated!")
|
|
return 1
|
|
|
|
print(f"[{target_name}] Compiling {len(kernel_headers)} kernels...")
|
|
|
|
include_dirs = [Path(p.strip()) for p in args.include_dirs.split(",")]
|
|
hipcc = find_hipcc()
|
|
|
|
work = [
|
|
(
|
|
h,
|
|
args.output_dir,
|
|
include_dirs,
|
|
hipcc,
|
|
args.gpu_target,
|
|
i + 1,
|
|
len(kernel_headers),
|
|
)
|
|
for i, h in enumerate(kernel_headers)
|
|
]
|
|
|
|
obj_files = []
|
|
failed = []
|
|
|
|
with ProcessPoolExecutor(max_workers=args.jobs) as executor:
|
|
futures = {executor.submit(compile_kernel, w): w[0].name for w in work}
|
|
for future in as_completed(futures):
|
|
name, ok, result = future.result()
|
|
if ok:
|
|
obj_files.append(result)
|
|
else:
|
|
failed.append((name, result))
|
|
print(f"[{target_name}] FAILED: {name}")
|
|
|
|
if failed:
|
|
print(f"[{target_name}] {len(failed)} kernels failed")
|
|
for name, err in failed[:3]:
|
|
print(f" {name}: {err[:200]}")
|
|
return 1
|
|
|
|
# Create static library (use target_name for CMake compatibility)
|
|
lib_path = args.output_dir / f"lib{target_name}_kernels.a"
|
|
subprocess.run([find_ar(), "rcs", str(lib_path)] + obj_files, check=True)
|
|
|
|
# Generate registration header (use source_stem for header name to match CMake's EXAMPLE_STEM)
|
|
header_path = args.output_dir / f"{source_stem}_kernels.hpp"
|
|
|
|
# Build includes
|
|
includes = "\n".join(f'#include "{h.name}"' for h in kernel_headers)
|
|
|
|
# Build kernel registration entries
|
|
# Function name uses source_stem (e.g., register_01_basic_gemm_kernels)
|
|
func_name = f"register_{source_stem}_kernels"
|
|
|
|
# Generate registration code based on example type
|
|
if example_type == "gemm":
|
|
register_body = generate_gemm_registration(kernel_headers, target_name, kernels)
|
|
else:
|
|
register_body = generate_conv_registration(kernel_headers, target_name, kernels)
|
|
|
|
# Generate appropriate header based on example type
|
|
if example_type == "conv" and kernel_headers:
|
|
launcher_aliases = []
|
|
|
|
# Helper to find kernel by dtype and type
|
|
def find_kernel_by_dtype_type(headers, dtype, conv_type_marker):
|
|
"""Find kernel matching dtype and conv type, prioritize fp16."""
|
|
matching = [h for h in headers if conv_type_marker in h.stem]
|
|
# Prefer fp16 over bf16 for default launchers
|
|
fp16_kernels = [h for h in matching if f"_{dtype}_" in h.stem]
|
|
return (
|
|
fp16_kernels[0] if fp16_kernels else (matching[0] if matching else None)
|
|
)
|
|
|
|
# Check what conv types are in the declarations
|
|
has_fwd = any("forward" in k.get("conv_type", "forward") for k in kernels)
|
|
has_bwd_data = any("bwd_data" in k.get("conv_type", "") for k in kernels)
|
|
has_bwd_weight = any("bwd_weight" in k.get("conv_type", "") for k in kernels)
|
|
|
|
# Export dtype-specific launcher aliases for each available dtype
|
|
for dtype in ["fp16", "bf16", "fp32"]:
|
|
dtype_fwd_kernels = [
|
|
h
|
|
for h in kernel_headers
|
|
if "_fwd_" in h.stem and f"_{dtype}_" in h.stem
|
|
]
|
|
if dtype_fwd_kernels:
|
|
k = dtype_fwd_kernels[0]
|
|
ns = f"ns_{k.stem}"
|
|
dtype_upper = dtype.upper()
|
|
launcher_aliases.append(
|
|
f"using {dtype_upper}FwdKernelLauncher = {ns}::{k.stem}_Launcher;"
|
|
)
|
|
|
|
# Export generic launcher aliases (prioritize fp16)
|
|
if has_fwd:
|
|
fwd_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_fwd_")
|
|
if fwd_kernel:
|
|
fwd_ns = f"ns_{fwd_kernel.stem}"
|
|
launcher_aliases.append(
|
|
f"using FwdKernelLauncher = {fwd_ns}::{fwd_kernel.stem}_Launcher;"
|
|
)
|
|
launcher_aliases.append(
|
|
f"using FirstKernelLauncher = {fwd_ns}::{fwd_kernel.stem}_Launcher;"
|
|
)
|
|
|
|
if has_bwd_data:
|
|
bwdd_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdd_")
|
|
if bwdd_kernel:
|
|
bwdd_ns = f"ns_{bwdd_kernel.stem}"
|
|
launcher_aliases.append(
|
|
f"using BwdDataKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;"
|
|
)
|
|
if not has_fwd: # If no fwd, use bwd_data as first
|
|
launcher_aliases.append(
|
|
f"using FirstKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;"
|
|
)
|
|
|
|
if has_bwd_weight:
|
|
bwdw_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdw_")
|
|
if bwdw_kernel:
|
|
bwdw_ns = f"ns_{bwdw_kernel.stem}"
|
|
launcher_aliases.append(
|
|
f"using BwdWeightKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;"
|
|
)
|
|
if (
|
|
not has_fwd and not has_bwd_data
|
|
): # If no fwd or bwdd, use bwdw as first
|
|
launcher_aliases.append(
|
|
f"using FirstKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;"
|
|
)
|
|
|
|
launcher_section = "\n".join(launcher_aliases)
|
|
|
|
header_content = f"""// Auto-generated for {target_name}
|
|
#pragma once
|
|
|
|
{includes}
|
|
|
|
#include "ck_tile/dispatcher/registry.hpp"
|
|
#include "ck_tile/dispatcher/kernel_instance.hpp"
|
|
#include "ck_tile/dispatcher/kernel_key.hpp"
|
|
|
|
namespace generated {{
|
|
|
|
// Kernel launchers for direct use
|
|
{launcher_section}
|
|
|
|
// Registration function
|
|
inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{
|
|
{register_body}
|
|
}}
|
|
|
|
}} // namespace generated
|
|
|
|
// Generic registration - avoids hardcoding the example name in user code
|
|
// Safe for single-example executables (typical use case)
|
|
#ifndef REGISTER_GENERATED_KERNELS
|
|
#define REGISTER_GENERATED_KERNELS(registry, arch) generated::{func_name}(registry, arch)
|
|
#endif
|
|
"""
|
|
else:
|
|
# GEMM: Generate per-set functions if multiple kernel sets declared
|
|
per_set_funcs = generate_per_set_functions(source_stem)
|
|
|
|
header_content = f"""// Auto-generated for {target_name}
|
|
#pragma once
|
|
|
|
{includes}
|
|
|
|
#include "ck_tile/dispatcher/registry.hpp"
|
|
#include "ck_tile/dispatcher/kernel_instance.hpp"
|
|
#include "ck_tile/dispatcher/kernel_key.hpp"
|
|
#include "ck_tile/dispatcher/backends/generated_kernel_backend.hpp"
|
|
|
|
namespace generated {{
|
|
|
|
// Register ALL kernels from all declared sets
|
|
inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{
|
|
{register_body}
|
|
}}
|
|
|
|
{per_set_funcs}
|
|
}} // namespace generated
|
|
|
|
// Generic registration - avoids hardcoding the example name in user code
|
|
// Safe for single-example executables (typical use case)
|
|
#ifndef REGISTER_GENERATED_KERNELS
|
|
#define REGISTER_GENERATED_KERNELS(registry, arch) generated::{func_name}(registry, arch)
|
|
#endif
|
|
|
|
// Register a specific kernel set by name (for multi-registry patterns)
|
|
// Usage: REGISTER_KERNEL_SET("compute_bound_set", registry, arch)
|
|
#ifndef REGISTER_KERNEL_SET
|
|
#define REGISTER_KERNEL_SET(set_name, registry, arch) generated::register_kernel_set(set_name, registry, arch)
|
|
#endif
|
|
"""
|
|
header_path.write_text(header_content)
|
|
|
|
print(f"[{target_name}] ✓ {len(obj_files)} kernels compiled")
|
|
return 0
|
|
|
|
|
|
if __name__ == "__main__":
|
|
sys.exit(main())
|