mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-24 00:57:39 +00:00
* WIP POC of dispatcher * Dispatcher python workflow setup. * Dispatcher cleanup and updates. Further dispatcher cleanup and updates. Build fixes Improvements and python to CK example Improvements to readme * Fixes to python paths * Cleaning up code * Improving dispatcher support for different arch Fixing typos * Fix formatting errors * Cleaning up examples * Improving codegeneration * Improving and fixing C++ examples * Adding conv functionality (fwd,bwd,bwdw) and examples. * Fixes based on feedback. * Further fixes based on feedback. * Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug. * Another round of improvements based on feedback. * Trimming out unnecessary code. * Fixing the multi-D implementation. * Using gpu verification for gemms and fixing convolutions tflops calculation. * Fix counter usage issue and arch filtering per ops. * Adding changelog and other fixes. * Improve examples and resolve critical bugs. * Reduce build time for python examples. * Fixing minor bug. * Fix compilation error. * Improve installation instructions for dispatcher. * Add docker based installation instructions for dispatcher. * Fixing arch-based filtering to match tile engine. * Remove dead code and fix arch filtering. * Minor bugfix. * Updates after rebase. * Trimming code. * Fix copyright headers. * Consolidate examples, cut down code. * Minor fixes. * Improving python examples. * Update readmes. * Remove conv functionality. * Cleanup following conv removable.
1714 lines
62 KiB
Python
Executable File
1714 lines
62 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
|
||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||
# SPDX-License-Identifier: MIT
|
||
|
||
"""
|
||
Unified GEMM Code Generator - Single Source of Truth
|
||
|
||
This is THE unified code generator for all GEMM kernel variants:
|
||
- Standard GEMM (C = A × B)
|
||
- Preshuffle GEMM (optimized weight access)
|
||
- Multi-D GEMM (element-wise fusion)
|
||
|
||
Generates both CK Tile kernels AND dispatcher wrappers in one pass.
|
||
Replaces all tile_engine GEMM codegen.
|
||
"""
|
||
|
||
import json
|
||
import argparse
|
||
import itertools
|
||
import logging
|
||
from pathlib import Path
|
||
from typing import Dict, List, Tuple, Optional
|
||
from dataclasses import dataclass, asdict
|
||
from enum import Enum
|
||
import concurrent.futures
|
||
|
||
# Import architecture filter for GPU-specific validation
|
||
# Import architecture filter for GPU-specific validation
try:
    from arch_filter import ArchFilter, KernelConfig as ArchKernelConfig, OperatorType

    HAS_ARCH_FILTER = True
except ImportError:
    # arch_filter is an optional dependency: when absent, disable
    # arch-specific filtering and null out the names so later
    # `if HAS_ARCH_FILTER` guards (and accidental references) fail loudly
    # rather than import-error at module load time.
    HAS_ARCH_FILTER = False
    ArchFilter = None
    ArchKernelConfig = None
    OperatorType = None
# =============================================================================
# Preshuffle Validation (copied from tile_engine/ops/commons/gemm_validation_utils.py)
# =============================================================================

# Size in bytes of one element for each supported datatype string.
# Used by the preshuffle validators below to convert element counts into
# byte counts for vector-load alignment checks.
ELEMENT_SIZE_MAP = {
    "fp16": 2,
    "bf16": 2,
    "fp32": 4,
    "fp64": 8,
    "fp8": 1,
    "bf8": 1,
    "int8": 1,
}
def _validate_preshuffle_vector_load(
    warp_tile_m: int,
    warp_tile_k: int,
    datatype: str,
    m_iter_per_warp: float,
    wave_size: int = 64,
    vector_load_size: int = 16,
) -> bool:
    """
    Validate vector load alignment for the preshuffle pipeline.

    The per-lane access size,
    (warp_tile_m * warp_tile_k * elem_size * m_iter_per_warp) / wave_size,
    must be an exact multiple of vector_load_size bytes.
    """
    # Unknown datatypes fall back to 2 bytes (fp16-sized), matching the
    # behavior of the tile_engine validation utility this was copied from.
    bytes_per_elem = ELEMENT_SIZE_MAP.get(datatype, 2)
    total_bytes = warp_tile_m * warp_tile_k * bytes_per_elem * m_iter_per_warp
    per_lane_bytes = total_bytes / wave_size
    return per_lane_bytes % vector_load_size == 0
def _validate_preshuffle_m0_m1_m2(
    tile_m: int,
    tile_k: int,
    warp_m: int,
    warp_n: int,
    warp_k: int,
    datatype: str,
    vector_load_size: int = 16,
    warp_size: int = 64,
) -> bool:
    """
    Validate the M0/M1/M2 decomposition for preshuffle matrix A (row-major).

    Derives K1/K0 from the vector-load width, then checks that MPerBlock
    factors exactly as M0 * M1 * M2, which guarantees a well-aligned
    memory access pattern.
    """
    try:
        bytes_per_elem = ELEMENT_SIZE_MAP.get(datatype, 2)
        m_per_block = tile_m

        # K1: elements covered by one vector load; must divide evenly.
        k1 = vector_load_size / bytes_per_elem
        if int(k1) != k1:
            return False
        k1 = int(k1)

        # K0: vector loads along the K dimension.
        if tile_k % k1 != 0:
            return False
        k0 = tile_k // k1

        # M2: lanes per warp remaining after covering K0.
        if warp_size % k0 != 0:
            return False
        m2 = warp_size // k0

        # M0: one slice per warp in the block.
        m0 = warp_m * warp_n * warp_k

        # M1: remaining factor of MPerBlock; must divide evenly.
        denom = m2 * m0
        if denom == 0:
            return False
        if m_per_block % denom != 0:
            return False
        m1 = m_per_block // denom

        # Sanity: the factorization must reproduce MPerBlock exactly.
        return m0 * m1 * m2 == m_per_block

    except (ZeroDivisionError, ValueError):
        # Degenerate inputs (e.g. zero sizes) are simply invalid configs.
        return False
def is_preshuffle_config_valid(
    tile_m: int,
    tile_n: int,
    tile_k: int,
    warp_m: int,
    warp_n: int,
    warp_k: int,
    warp_tile_m: int,
    warp_tile_n: int,
    warp_tile_k: int,
    datatype: str,
) -> bool:
    """
    Comprehensive preshuffle configuration validation.

    Copied from tile_engine/ops/commons/gemm_validation_utils.py.
    A config is valid when every block dimension splits evenly over its
    warp grid AND both the vector-load and M0/M1/M2 alignment checks pass.
    """
    # Each block tile must split evenly across warps and warp tiles.
    divisibility = (
        (tile_m, warp_m * warp_tile_m),
        (tile_n, warp_n * warp_tile_n),
        (tile_k, warp_k * warp_tile_k),
    )
    if any(tile % span != 0 for tile, span in divisibility):
        return False

    # How many M iterations each warp performs across the block tile.
    m_iter_per_warp = tile_m / (warp_m * warp_tile_m)

    vector_load_ok = _validate_preshuffle_vector_load(
        warp_tile_m,
        warp_tile_k,
        datatype,
        m_iter_per_warp,
        wave_size=64,
        vector_load_size=16,
    )
    if not vector_load_ok:
        return False

    m_decomposition_ok = _validate_preshuffle_m0_m1_m2(
        tile_m,
        tile_k,
        warp_m,
        warp_n,
        warp_k,
        datatype,
        vector_load_size=16,
        warp_size=64,
    )
    return m_decomposition_ok
# Module-wide logging: terse "LEVEL: message" lines at INFO and above.
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

log = logging.getLogger(__name__)


# ============================================================================
# Configuration and Data Structures
# ============================================================================
class GemmVariant(Enum):
    """GEMM kernel variants emitted by the unified generator."""

    STANDARD = "standard"  # plain C = A x B
    PRESHUFFLE = "preshuffle"  # weight-preshuffled B operand access
    MULTI_D = "multi_d"  # element-wise fusion with extra D input tensors
@dataclass
class TileConfig:
    """Tile configuration parameters.

    Describes the per-block tile (tile_*), the warp grid within the block
    (warp_*), and the per-warp MMA tile (warp_tile_*).
    """

    tile_m: int
    tile_n: int
    tile_k: int
    warp_m: int
    warp_n: int
    warp_k: int
    warp_tile_m: int
    warp_tile_n: int
    warp_tile_k: int

    def is_valid(self) -> bool:
        """Validate tile configuration.

        Returns True when every dimension is positive and each block tile
        dimension splits evenly across its warp grid and warp tile.

        Fix over the previous version: non-positive warp/warp_tile values
        now make the configuration invalid instead of raising
        ZeroDivisionError (warp == 0) or being accepted via Python's
        negative-divisor modulo semantics (warp < 0).
        """
        dimensions = (
            (self.tile_m, self.warp_m, self.warp_tile_m),
            (self.tile_n, self.warp_n, self.warp_tile_n),
            (self.tile_k, self.warp_k, self.warp_tile_k),
        )
        # Positivity first: also guards the modulo below against a zero
        # divisor.
        if any(t <= 0 or w <= 0 or wt <= 0 for t, w, wt in dimensions):
            return False
        return all(t % (w * wt) == 0 for t, w, wt in dimensions)
@dataclass
class TraitConfig:
    """Kernel trait configuration"""

    pipeline: str  # mem, compv3, compv4
    epilogue: str  # default, cshuffle
    scheduler: str  # intrawave, interwave
    pad_m: bool
    pad_n: bool
    pad_k: bool
    persistent: bool

    def is_valid(self) -> bool:
        """Check if trait combination is valid.

        Only the 'mem' pipeline supports the interwave scheduler; every
        compute pipeline (compv3/v4/v5/v6/comp_async) is intrawave-only,
        for both the 'default' and 'cshuffle' epilogues.
        """
        # Any non-interwave scheduler is unconstrained.
        if self.scheduler != "interwave":
            return True
        # Epilogues outside the known pair are not covered by the rule.
        if self.epilogue not in ("default", "cshuffle"):
            return True
        # Interwave with a known epilogue: reject compute pipelines.
        intrawave_only_pipelines = (
            "compv3",
            "compv4",
            "compv5",
            "compv6",
            "comp_async",
        )
        return self.pipeline not in intrawave_only_pipelines
@dataclass
class KernelConfig:
    """Complete kernel configuration.

    Aggregates the tile shape, trait flags, and variant-specific options
    for one generated kernel instance.
    """

    tile: TileConfig
    trait: TraitConfig
    variant: GemmVariant = GemmVariant.STANDARD

    # Variant-specific
    preshuffle: bool = False
    elementwise_op: str = "PassThrough"
    num_d_tensors: int = 0
    d_layout: str = "r"  # Layout for D tensors (r=row, c=col) - same for all D tensors

    # Fixed parameters
    block_size: int = 256
    k_block_per_cu: int = 1
    num_wave_groups: int = 1

    def name(self, datatype: str, layout: str) -> str:
        """C++ alias for template instance"""
        return f"ck_tile_gemm_{self.key_name(datatype, layout)}"

    def key_name(self, datatype: str, layout: str) -> str:
        """
        Unique identifier for this kernel configuration.

        All parameters that affect kernel behavior MUST be included to ensure
        unique names for unique configurations:
        - Data type and layout (signature)
        - Tile, warp, warp_tile dimensions (algorithm)
        - Pipeline, epilogue, scheduler (traits)
        - Padding flags (affects divisibility requirements)
        - Persistent mode
        - Preshuffle variant
        - Multi-D: elementwise op, num D tensors, D layout
        - Occupancy: wave groups, k_block_per_cu (if non-default)
        """
        parts = []
        # Signature
        parts.append(f"dt_{datatype}")
        parts.append(f"ly_{layout}")

        # Tile configuration
        parts.append(f"tile_{self.tile.tile_m}x{self.tile.tile_n}x{self.tile.tile_k}")
        parts.append(f"warp_{self.tile.warp_m}x{self.tile.warp_n}x{self.tile.warp_k}")
        parts.append(
            f"wtile_{self.tile.warp_tile_m}x{self.tile.warp_tile_n}x{self.tile.warp_tile_k}"
        )

        # Traits
        parts.append(f"pipe_{self.trait.pipeline}")
        parts.append(f"epi_{self.trait.epilogue}")
        parts.append(f"sched_{self.trait.scheduler}")

        # Padding flags (only if not all True - the common case)
        if not (self.trait.pad_m and self.trait.pad_n and self.trait.pad_k):
            parts.append(
                f"pad{int(self.trait.pad_m)}{int(self.trait.pad_n)}{int(self.trait.pad_k)}"
            )

        # Persistent mode
        if self.trait.persistent:
            parts.append("persist")

        # Preshuffle variant
        if self.preshuffle:
            parts.append("preshuffle")

        # Multi-D variant: include elementwise op, num tensors, and D layout
        if self.variant == GemmVariant.MULTI_D:
            parts.append(f"ew_{self.elementwise_op}")
            parts.append(f"nd{self.num_d_tensors}")
            parts.append(f"dly_{self.d_layout}")

        # Occupancy parameters (only if non-default)
        if self.num_wave_groups != 1:
            parts.append(f"wg{self.num_wave_groups}")
        if self.k_block_per_cu != 1:
            parts.append(f"kbpc{self.k_block_per_cu}")

        return "_".join(parts)

    def dict_items(self):
        """Iterator over (field, value) pairs (recursively via asdict)."""
        return asdict(self).items()
# ============================================================================
# Type Mappings
# ============================================================================


class TypeMappings:
    """Centralized type mappings for code generation.

    Maps this generator's datatype/layout/pipeline/scheduler/epilogue
    string identifiers to CK Tile C++ type names and dispatcher enum
    values. Keys and values must stay in sync with the CK Tile headers
    the generated code includes.
    """

    # Datatype string -> CK Tile type name (valid inside `using namespace ck_tile`).
    DTYPE_TO_CK = {
        "fp16": "fp16_t",
        "bf16": "bf16_t",
        "fp32": "float",
        "fp8": "fp8_t",
        "bf8": "bf8_t",
        "int8": "int8_t",
    }

    # Fully-qualified types for use outside of 'using namespace ck_tile' scope
    DTYPE_TO_CK_QUALIFIED = {
        "fp16": "ck_tile::fp16_t",
        "bf16": "ck_tile::bf16_t",
        "fp32": "float",  # Built-in type, no namespace
        "fp8": "ck_tile::fp8_t",
        "bf8": "ck_tile::bf8_t",
        "int8": "int8_t",  # Built-in type
    }

    # Datatype string -> dispatcher DataType enum value.
    DTYPE_TO_DISPATCHER = {
        "fp16": "DataType::FP16",
        "bf16": "DataType::BF16",
        "fp32": "DataType::FP32",
        "fp8": "DataType::FP8",
        "bf8": "DataType::BF8",
        "int8": "DataType::INT8",
    }

    # Layout char ('r'/'c') -> CK Tile tensor layout type.
    LAYOUT_TO_CK = {
        "r": "tensor_layout::gemm::RowMajor",
        "c": "tensor_layout::gemm::ColumnMajor",
    }

    # Layout char -> dispatcher LayoutTag enum value.
    LAYOUT_TO_DISPATCHER = {
        "r": "LayoutTag::RowMajor",
        "c": "LayoutTag::ColMajor",
    }

    # Pipeline id -> concrete CK Tile pipeline template.
    PIPELINE_TO_CK = {
        "mem": "GemmPipelineAgBgCrMem",
        "compv3": "GemmPipelineAgBgCrCompV3",
        "compv4": "GemmPipelineAgBgCrCompV4",
        "preshufflev2": "WeightPreshufflePipelineAGmemBGmemCRegV2",
    }

    # Pipeline id -> base pipeline template (used for tail handling).
    PIPELINE_TO_BASE = {
        "mem": "BaseGemmPipelineAgBgCrMem",
        "compv3": "BaseGemmPipelineAgBgCrCompV3",
        "compv4": "BaseGemmPipelineAgBgCrCompV4",
        "preshufflev2": "BaseWeightPreshufflePipelineAGmemBGmemCRegV2",
    }

    # Pipeline id -> dispatcher Pipeline enum value.
    PIPELINE_TO_DISPATCHER = {
        "mem": "Pipeline::Mem",
        "compv3": "Pipeline::CompV3",
        "compv4": "Pipeline::CompV4",
        "preshufflev2": "Pipeline::PreShuffleV2",
    }

    # Scheduler id -> CK Tile scheduler enum value.
    SCHEDULER_TO_CK = {
        "intrawave": "GemmPipelineScheduler::Intrawave",
        "interwave": "GemmPipelineScheduler::Interwave",
        "default": "GemmPipelineScheduler::Default",
    }

    # Scheduler id -> dispatcher Scheduler enum value
    # (note: "default" maps to Scheduler::Auto, not a "Default" member).
    SCHEDULER_TO_DISPATCHER = {
        "intrawave": "Scheduler::Intrawave",
        "interwave": "Scheduler::Interwave",
        "default": "Scheduler::Auto",
    }

    # Epilogue id -> dispatcher Epilogue enum value.
    EPILOGUE_TO_DISPATCHER = {
        "cshuffle": "Epilogue::CShuffle",
        "default": "Epilogue::Default",
    }

    @staticmethod
    def get_output_dtype(dtype: str) -> str:
        """Get output datatype (fp8/bf8 -> fp16)"""
        return "fp16" if dtype in ["fp8", "bf8"] else dtype


# ============================================================================
# Kernel Name Generator
# ============================================================================
class KernelNaming:
    """Unified kernel naming"""

    @staticmethod
    def generate(config: KernelConfig, datatype: str, layout: str) -> str:
        """Generate kernel name following tile_engine convention"""
        tile = config.tile
        trait = config.trait
        is_multi_d = config.variant == GemmVariant.MULTI_D

        # Multi-D appends the D-tensor layout char: "rcr" + "r" -> "rcrr".
        full_layout = layout + config.d_layout if is_multi_d else layout

        segments = [
            f"gemm_{datatype}",
            full_layout,
            trait.pipeline,
            trait.epilogue,
            trait.scheduler,
            str(trait.pad_m).capitalize(),
            str(trait.pad_n).capitalize(),
            str(trait.pad_k).capitalize(),
            str(trait.persistent).capitalize(),
            f"{tile.tile_m}x{tile.tile_n}x{tile.tile_k}",
            f"{tile.warp_m}x{tile.warp_n}x{tile.warp_k}",
            f"{tile.warp_tile_m}x{tile.warp_tile_n}x{tile.warp_tile_k}",
        ]
        name = "_".join(segments)

        # Variant suffix.
        if config.variant == GemmVariant.PRESHUFFLE:
            name += "_preshuffle"
        elif is_multi_d:
            name += f"_multid_{config.elementwise_op}_d{config.num_d_tensors}"

        return name
# ============================================================================
# CK Tile Kernel Generator
# ============================================================================


class CKTileKernelGenerator:
    """Generates CK Tile kernel instance code.

    Emits a self-contained C++ header per kernel configuration: include
    directives, a per-kernel namespace with type aliases, and a
    SelectedKernel struct exposing a static launch() entry point.
    """

    def __init__(self, datatype: str, layout: str):
        # datatype: key into TypeMappings (e.g. "fp16");
        # layout: 3-char string of 'r'/'c' for the A, B, C tensors (e.g. "rcr").
        self.datatype = datatype
        self.layout = layout
        self.tm = TypeMappings()

    def generate(self, config: KernelConfig) -> str:
        """Generate complete CK Tile kernel"""
        kernel_name = KernelNaming.generate(config, self.datatype, self.layout)

        return f"""{self._header(kernel_name, config)}
{self._types(config, kernel_name)}
{self._selected_kernel_struct(config, kernel_name)}
"""

    def _header(self, kernel_name: str, config: KernelConfig) -> str:
        """Generate header includes"""
        # Base include set shared by all variants.
        includes = """// SPDX-License-Identifier: MIT
// Auto-generated CK Tile GEMM kernel
#pragma once

#include <cstdint>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"

"""

        # Variant-specific extra includes.
        if config.variant == GemmVariant.MULTI_D:
            includes += """
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
"""

        if config.preshuffle:
            includes += """
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
"""

        return includes

    def _types(self, config: KernelConfig, kernel_name: str) -> str:
        """Generate type definitions - just the namespace import, types are in kernel namespace"""
        # Note: Data types and layouts are now defined inside each kernel's unique namespace
        # to avoid type alias redefinition conflicts when mixing layouts (e.g., RCR + RRR)
        types = """
// Use ck_tile namespace for generated code
using namespace ck_tile;
"""
        return types

    def _kernel_local_types(self, config: KernelConfig) -> str:
        """Generate data type and layout definitions inside kernel namespace.

        NOTE(review): no caller of this helper is visible in this file;
        _selected_kernel_struct emits equivalent aliases inline. Possibly
        dead code - confirm before removing.
        """
        output_dtype = self.tm.get_output_dtype(self.datatype)

        return f"""
// Data types (inside namespace to avoid conflicts across layouts)
using ADataType = {self.tm.DTYPE_TO_CK[self.datatype]};
using BDataType = {self.tm.DTYPE_TO_CK[self.datatype]};
using AccDataType = float;
using CDataType = {self.tm.DTYPE_TO_CK[output_dtype]};

// Layouts (inside namespace to avoid conflicts when mixing layouts)
using ALayout = {self.tm.LAYOUT_TO_CK[self.layout[0]]};
using BLayout = {self.tm.LAYOUT_TO_CK[self.layout[1]]};
using CLayout = {self.tm.LAYOUT_TO_CK[self.layout[2]]};
"""

    def _multi_d_types(self, config: KernelConfig) -> str:
        """Generate multi-d type definitions (inside namespace to avoid conflicts)"""
        # Empty for non-Multi-D variants so the caller can interpolate
        # unconditionally.
        if config.variant != GemmVariant.MULTI_D:
            return ""

        # All D tensors share the C datatype and a single layout.
        d_types = ", ".join(["CDataType"] * config.num_d_tensors)
        d_layout_ck = self.tm.LAYOUT_TO_CK[config.d_layout]
        d_layouts = ", ".join([d_layout_ck] * config.num_d_tensors)

        return f"""
// Multi-D types (defined in namespace to avoid conflicts)
using DsDataType = tuple<{d_types}>;
using DLayout = {d_layout_ck}; // D tensor layout (can differ from C)
using DsLayout = tuple<{d_layouts}>;
using ElementWiseFn = element_wise::{config.elementwise_op};
static constexpr index_t NumDTensor = {config.num_d_tensors};
using GemmMultiDArgs = GemmMultiDHostArgs<NumDTensor>;
"""

    def _selected_kernel_struct(self, config: KernelConfig, kernel_name: str) -> str:
        """Generate SelectedKernel struct with unique name in unique namespace"""
        t = config.tile
        tr = config.trait
        output_dtype = self.tm.get_output_dtype(self.datatype)

        # Generate unique struct name and namespace from kernel name
        struct_name = f"Kernel_{kernel_name}"
        # Create valid C++ namespace name (replace invalid chars)
        ns_name = "ns_" + kernel_name.replace("-", "_")

        multi_d_types = self._multi_d_types(config)

        return f"""
namespace {ns_name} {{
constexpr const char* KERNEL_NAME = "{kernel_name}";

// Data types (inside namespace to avoid conflicts across different kernels)
using ADataType = {self.tm.DTYPE_TO_CK[self.datatype]};
using BDataType = {self.tm.DTYPE_TO_CK[self.datatype]};
using AccDataType = float;
using CDataType = {self.tm.DTYPE_TO_CK[output_dtype]};

// Layouts (inside namespace to avoid conflicts when mixing layouts like RCR + RRR)
using ALayout = {self.tm.LAYOUT_TO_CK[self.layout[0]]};
using BLayout = {self.tm.LAYOUT_TO_CK[self.layout[1]]};
using CLayout = {self.tm.LAYOUT_TO_CK[self.layout[2]]};
{multi_d_types}
struct {struct_name} {{
// Data types (required by backend as member types)
using ADataType = {ns_name}::ADataType;
using BDataType = {ns_name}::BDataType;
using CDataType = {ns_name}::CDataType;
using AccDataType = {ns_name}::AccDataType;

// Configuration
static constexpr index_t BlockSize = {config.block_size};
static constexpr index_t TileM = {t.tile_m};
static constexpr index_t TileN = {t.tile_n};
static constexpr index_t TileK = {t.tile_k};
static constexpr index_t WarpPerBlock_M = {t.warp_m};
static constexpr index_t WarpPerBlock_N = {t.warp_n};
static constexpr index_t WarpPerBlock_K = {t.warp_k};
static constexpr index_t WarpTileM = {t.warp_tile_m};
static constexpr index_t WarpTileN = {t.warp_tile_n};
static constexpr index_t WarpTileK = {t.warp_tile_k};

// Traits
static constexpr bool kPadM = {str(tr.pad_m).lower()};
static constexpr bool kPadN = {str(tr.pad_n).lower()};
static constexpr bool kPadK = {str(tr.pad_k).lower()};
static constexpr bool TransposeC = false;
static constexpr bool UsePersistentKernel = {str(tr.persistent).lower()};
static constexpr bool DoubleSmemBuffer = {str(tr.pipeline == "compv4" or tr.pipeline == "preshufflev2").lower()};
static constexpr bool UseStructuredSparsity = false;
static constexpr bool Preshuffle = {str(config.preshuffle).lower()};
static constexpr index_t NumWaveGroups = {config.num_wave_groups};

{self._tile_types(config, ns_name)}
{self._launch_function(config)}
}};

// Alias for tile_engine style compatibility (when used with -include)
using SelectedKernel = {struct_name};
using SelectedKernelLauncher = {struct_name};
}} // namespace {ns_name}

// Export to global namespace ONLY for single-kernel includes
// Define CK_TILE_SINGLE_KERNEL_INCLUDE before including this header to enable these aliases
#ifdef CK_TILE_SINGLE_KERNEL_INCLUDE
using {struct_name} = {ns_name}::{struct_name};
using SelectedKernel = {ns_name}::{struct_name};
constexpr const char* KERNEL_NAME = {ns_name}::KERNEL_NAME;
using ADataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.datatype]};
using BDataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.datatype]};
using CDataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.tm.get_output_dtype(self.datatype)]};
using AccDataType = float;
#endif // CK_TILE_SINGLE_KERNEL_INCLUDE
"""

    def _tile_types(self, config: KernelConfig, ns_name: str) -> str:
        """Generate tile type definitions - uses namespace-qualified types"""
        return (
            f"""// Tile shape
using TileShape = TileGemmShape<
sequence<TileM, TileN, TileK>,
sequence<WarpPerBlock_M, WarpPerBlock_N, WarpPerBlock_K>,
sequence<WarpTileM, WarpTileN, WarpTileK>,
false, false>;

using TilePartitioner = GemmSpatiallyLocalTilePartitioner<TileShape, 8, 4>;
using Traits = TileGemmTraits<kPadM, kPadN, kPadK, {ns_name}::ALayout, {ns_name}::BLayout, {ns_name}::CLayout, NumWaveGroups>;
using GemmPipelineProblem = GemmPipelineProblem<ADataType, BDataType, AccDataType, TileShape, Traits>;
using BaseGemmPipeline = """
            + self.tm.PIPELINE_TO_BASE[config.trait.pipeline]
            + """<GemmPipelineProblem>;"""
        )

    def _launch_function(self, config: KernelConfig) -> str:
        """Generate launch function.

        Dispatch order matters: Multi-D takes precedence over the
        preshuffle flag, then the standard path is the fallback.
        """
        if config.variant == GemmVariant.MULTI_D:
            return self._launch_function_multi_d(config)
        if config.preshuffle:
            return self._launch_function_preshuffle(config)
        return self._launch_function_standard(config)

    def _launch_function_standard(self, config: KernelConfig) -> str:
        """Generate launch function for standard GEMM"""
        return f"""
static float launch(const GemmHostArgs& args, const stream_config& stream) {{
const index_t k_grain = args.k_batch * TileK;
const index_t K_split = (args.K + k_grain - 1) / k_grain * TileK;
const index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

float ave_time{{0}};

constexpr auto scheduler = {self.tm.SCHEDULER_TO_CK[config.trait.scheduler]};

using UniversalGemmProblem = UniversalGemmPipelineProblem<
ADataType, BDataType, AccDataType, TileShape,
TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC,
UseStructuredSparsity, UsePersistentKernel,
NumWaveGroups, Preshuffle>,
scheduler>;

using GemmPipeline = {self.tm.PIPELINE_TO_CK[config.trait.pipeline]}<UniversalGemmProblem>;
{self._epilogue_code(config)}

using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;

const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{
auto kargs = GemmKernel::MakeKernelArgs(args);

if (!GemmKernel::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Arguments not supported!");
}}

const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if config.trait.persistent else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
const dim3 blocks = GemmKernel::BlockSize();

constexpr int kBlockPerCu = {config.k_block_per_cu};
ave_time = launch_kernel(stream,
make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));

return ave_time;
}};

BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
return ave_time;
}}"""

    def _launch_function_preshuffle(self, config: KernelConfig) -> str:
        """Generate launch function for preshuffle GEMM (weight preshuffle variant)

        Preshuffle uses WeightPreshufflePipelineAGmemBGmemCRegV2 which has a different
        API than standard pipelines. It's designed for weight-preshuffled GEMM operations.
        """
        return f"""
static float launch(const GemmHostArgs& args, const stream_config& stream) {{
const index_t k_grain = args.k_batch * TileK;
const index_t K_split = (args.K + k_grain - 1) / k_grain * TileK;
const index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

float ave_time{{0}};

constexpr auto scheduler = GemmPipelineScheduler::Default; // Preshuffle uses Default scheduler

// Preshuffle uses TileFlatmmShape instead of TileGemmShape for the problem
using UniversalGemmProblem = UniversalGemmPipelineProblem<
ADataType, BDataType, AccDataType, TileShape,
TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC,
UseStructuredSparsity, UsePersistentKernel,
NumWaveGroups, Preshuffle>,
scheduler>;

using GemmPipeline = WeightPreshufflePipelineAGmemBGmemCRegV2<UniversalGemmProblem>;
{self._epilogue_code(config)}

using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;

const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{
auto kargs = GemmKernel::MakeKernelArgs(args);

if (!GemmKernel::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Arguments not supported for preshuffle kernel!");
}}

const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if config.trait.persistent else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
const dim3 blocks = GemmKernel::BlockSize();

constexpr int kBlockPerCu = {config.k_block_per_cu};
ave_time = launch_kernel(stream,
make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));

return ave_time;
}};

BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
return ave_time;
}}"""

    def _launch_function_multi_d(self, config: KernelConfig) -> str:
        """Generate launch function for Multi-D GEMM.

        Emits two overloads: the Multi-D entry point plus a GemmHostArgs
        adapter that forwards with empty D tensors.
        """
        return f"""
// Multi-D launch function - takes GemmMultiDHostArgs with D tensor pointers
static float launch(const GemmMultiDArgs& args, const stream_config& stream) {{
const index_t k_grain = args.k_batch * TileK;
const index_t K_split = (args.K + k_grain - 1) / k_grain * TileK;
const index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

float ave_time{{0}};

constexpr auto scheduler = {self.tm.SCHEDULER_TO_CK[config.trait.scheduler]};

using UniversalGemmProblem = UniversalGemmPipelineProblem<
ADataType, BDataType, AccDataType, TileShape,
TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC,
UseStructuredSparsity, UsePersistentKernel,
NumWaveGroups, Preshuffle>,
scheduler>;

using GemmPipeline = {self.tm.PIPELINE_TO_CK[config.trait.pipeline]}<UniversalGemmProblem>;
{self._epilogue_code(config)}

// Use GemmKernelMultiD for Multi-D variant
using GemmKernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;

const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{
auto kargs = GemmKernel::MakeKernelArgs(args);

if (!GemmKernel::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Arguments not supported! Multi-D currently doesn't support k_batch > 1");
}}

const dim3 grids = GemmKernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = GemmKernel::BlockSize();

constexpr int kBlockPerCu = {config.k_block_per_cu};
ave_time = launch_kernel(stream,
make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));

return ave_time;
}};

BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
return ave_time;
}}

// Overload for standard GemmHostArgs (converts to Multi-D args with empty D tensors)
static float launch(const GemmHostArgs& args, const stream_config& stream) {{
std::array<const void*, NumDTensor> empty_ds{{}};
std::array<index_t, NumDTensor> empty_strides{{}};
for (index_t i = 0; i < NumDTensor; ++i) {{
empty_ds[i] = nullptr;
empty_strides[i] = 0;
}}
GemmMultiDArgs multi_d_args{{
args.a_ptr,
args.b_ptr,
empty_ds,
args.e_ptr,
args.k_batch,
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
empty_strides,
args.stride_C
}};
return launch(multi_d_args, stream);
}}"""

    def _epilogue_code(self, config: KernelConfig) -> str:
        """Generate epilogue code.

        Multi-D always uses CShuffle (with the fused element-wise op);
        otherwise the trait's epilogue selects CShuffle or the default
        2D epilogue.
        """
        if config.variant == GemmVariant.MULTI_D:
            return """
using EpilogueProblem = CShuffleEpilogueProblem<
ADataType, BDataType, DsDataType, AccDataType, CDataType,
DsLayout, CLayout, ElementWiseFn,
TilePartitioner::MPerBlock, TilePartitioner::NPerBlock,
WarpPerBlock_M, WarpPerBlock_N, WarpTileM, WarpTileN, WarpTileK,
TransposeC, NumWaveGroups, false, 1, false, 1, DoubleSmemBuffer>;
using GemmEpilogue = CShuffleEpilogue<EpilogueProblem>;"""
        elif config.trait.epilogue == "cshuffle":
            return """
using EpilogueProblem = CShuffleEpilogueProblem<
ADataType, BDataType, tuple<>, AccDataType, CDataType,
tuple<>, CLayout, element_wise::PassThrough,
TilePartitioner::MPerBlock, TilePartitioner::NPerBlock,
WarpPerBlock_M, WarpPerBlock_N, WarpTileM, WarpTileN, WarpTileK,
TransposeC, NumWaveGroups, false, 1, false, 1, DoubleSmemBuffer>;
using GemmEpilogue = CShuffleEpilogue<EpilogueProblem>;"""
        else:
            return """
using EpilogueProblem = DefaultGemm2DEpilogueProblem<
ADataType, BDataType, tuple<>, AccDataType, CDataType,
tuple<>, CLayout, element_wise::PassThrough,
TilePartitioner::MPerBlock, TilePartitioner::NPerBlock,
kPadM, kPadN, WarpTileM, WarpTileN, WarpTileK, TransposeC>;
using GemmEpilogue = DefaultGemm2DEpilogue<EpilogueProblem>;"""
# ============================================================================
|
||
# Dispatcher Wrapper Generator
|
||
# ============================================================================
|
||
|
||
|
||
class DispatcherWrapperGenerator:
    """Generates dispatcher wrapper code.

    For each generated kernel header this emits a companion C++ header that
    builds a ``KernelKey`` describing the kernel (signature + algorithm) and
    exposes a ``make_<kernel_name>()`` factory returning a registrable
    ``KernelInstancePtr``.
    """

    def __init__(self, datatype: str, layout: str):
        # Generator-wide defaults used for every wrapper emitted by this instance.
        self.datatype = datatype
        self.layout = layout
        # TypeMappings translates codegen names (dtype/layout/pipeline/...) to
        # the dispatcher's C++ enum spellings.
        self.tm = TypeMappings()

    def generate(
        self, config: KernelConfig, kernel_path: Path, output_dir: Path
    ) -> str:
        """Generate dispatcher wrapper.

        Args:
            config: Full kernel configuration (tile, trait, variant, ...).
            kernel_path: Path of the generated kernel header to #include.
            output_dir: Root the include path is made relative to.

        Returns:
            Complete C++ header source for the dispatcher wrapper.
        """
        kernel_name = KernelNaming.generate(config, self.datatype, self.layout)
        output_dtype = self.tm.get_output_dtype(self.datatype)
        # Relative path keeps the generated #include portable across build trees.
        rel_path = kernel_path.relative_to(output_dir)

        return f"""// SPDX-License-Identifier: MIT
// Auto-generated dispatcher wrapper
#pragma once

#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/backends/generated_kernel_backend.hpp"
#include "{rel_path}"

namespace ck_tile {{
namespace dispatcher {{
namespace generated {{

using ::ck_tile::dispatcher::KernelInstancePtr;
using ::ck_tile::dispatcher::KernelKey;
using ::ck_tile::dispatcher::DataType;
using ::ck_tile::dispatcher::LayoutTag;
using ::ck_tile::dispatcher::Pipeline;
using ::ck_tile::dispatcher::Scheduler;
using ::ck_tile::dispatcher::Epilogue;
using Priority = ::ck_tile::dispatcher::Registry::Priority;
namespace backends = ::ck_tile::dispatcher::backends;

inline KernelInstancePtr make_{kernel_name}(const std::string& gfx_arch = "gfx942") {{
// Use the unique kernel struct name
using KernelStruct = Kernel_{kernel_name};

KernelKey key;

// Signature
key.signature.dtype_a = {self.tm.DTYPE_TO_DISPATCHER[self.datatype]};
key.signature.dtype_b = {self.tm.DTYPE_TO_DISPATCHER[self.datatype]};
key.signature.dtype_c = {self.tm.DTYPE_TO_DISPATCHER[output_dtype]};
key.signature.dtype_acc = DataType::FP32;
key.signature.layout_a = {self.tm.LAYOUT_TO_DISPATCHER[self.layout[0]]};
key.signature.layout_b = {self.tm.LAYOUT_TO_DISPATCHER[self.layout[1]]};
key.signature.layout_c = {self.tm.LAYOUT_TO_DISPATCHER[self.layout[2]]};
key.signature.transpose_a = false;
key.signature.transpose_b = false;
key.signature.grouped = false;
key.signature.split_k = 1;
key.signature.elementwise_op = "{config.elementwise_op}";
key.signature.num_d_tensors = {config.num_d_tensors};
key.signature.structured_sparsity = false;

// Algorithm
key.algorithm.tile_shape = {{{config.tile.tile_m}, {config.tile.tile_n}, {config.tile.tile_k}}};
key.algorithm.wave_shape = {{{config.tile.warp_m}, {config.tile.warp_n}, {config.tile.warp_k}}};
key.algorithm.warp_tile_shape = {{{config.tile.warp_tile_m}, {config.tile.warp_tile_n}, {config.tile.warp_tile_k}}};
key.algorithm.pipeline = {self.tm.PIPELINE_TO_DISPATCHER[config.trait.pipeline]};
key.algorithm.scheduler = {self.tm.SCHEDULER_TO_DISPATCHER[config.trait.scheduler]};
key.algorithm.epilogue = {self.tm.EPILOGUE_TO_DISPATCHER[config.trait.epilogue]};
key.algorithm.block_size = {config.block_size};
key.algorithm.double_buffer = {str(config.trait.pipeline == "compv4").lower()};
key.algorithm.persistent = {str(config.trait.persistent).lower()};
key.algorithm.preshuffle = {str(config.preshuffle).lower()};
key.algorithm.transpose_c = false;
key.algorithm.num_wave_groups = {config.num_wave_groups};

key.gfx_arch = gfx_arch;

return std::make_shared<backends::GeneratedKernelInstance<KernelStruct>>(key, "{kernel_name}");
}}

}}}}}}
"""
|
||
|
||
|
||
# ============================================================================
|
||
# Main Unified Generator
|
||
# ============================================================================
|
||
|
||
|
||
class UnifiedGemmCodegen:
    """Unified GEMM code generator - single entry point.

    Orchestrates generation of CK Tile kernel headers, dispatcher wrappers,
    per-kernel .cpp compilation units and a master registration header for
    the requested datatype/layout/variant combinations.
    """

    def __init__(
        self,
        output_dir: Path,
        datatype: str,
        layout: str,
        gpu_target: str = "gfx942",
        config_file: Optional[Path] = None,
        variants: List[GemmVariant] = None,
        use_preselected: Optional[str] = None,
        enable_arch_filter: bool = True,
        kernel_set_name: Optional[str] = None,
    ):
        """Set up output directories, config and sub-generators.

        Args:
            output_dir: Root directory for all generated files.
            datatype: Kernel data type key (e.g. "fp16", "bf16", ...).
            layout: 3- or 4-char layout code (e.g. "rcr" or "rcrr").
            gpu_target: Target GPU architecture (e.g. "gfx942").
            config_file: Optional JSON sweep configuration; defaults used if None.
            variants: GEMM variants to generate; defaults to [STANDARD].
            use_preselected: Name of a preselected kernel set, or None.
            enable_arch_filter: Enable architecture-specific config filtering.
            kernel_set_name: Optional subdirectory name for this kernel set.
        """
        self.output_dir = Path(output_dir)
        self.datatype = datatype
        # Support 3-char (rcr) or 4-char (rcrr) layout codes
        # 4th char specifies D tensor layout for multi-d
        self.layout = layout[:3]  # A, B, C layouts
        self.d_layout = (
            layout[3] if len(layout) >= 4 else layout[2]
        )  # D layout (default = C layout)
        self.gpu_target = gpu_target
        self.variants = variants or [GemmVariant.STANDARD]
        self.use_preselected = use_preselected
        self.kernel_set_name = kernel_set_name

        # Create directories - optionally with kernel set subdirectory
        if kernel_set_name:
            self.kernel_dir = self.output_dir / kernel_set_name
        else:
            self.kernel_dir = self.output_dir
        self.kernel_dir.mkdir(parents=True, exist_ok=True)
        self.wrapper_dir = self.kernel_dir / "dispatcher_wrappers"
        self.wrapper_dir.mkdir(parents=True, exist_ok=True)

        # Load configuration (falls back to built-in defaults if no file given)
        self.config = self._load_config(config_file)

        # Initialize architecture filter for GPU-specific validation.
        # Best-effort: a missing module or unknown target disables filtering
        # rather than aborting generation.
        self.arch_filter = None
        if enable_arch_filter and HAS_ARCH_FILTER:
            try:
                self.arch_filter = ArchFilter(gpu_target, strict_mode=False)
                log.info(f"Architecture filter enabled for {gpu_target}")
            except ValueError as e:
                log.warning(f"Could not create arch filter: {e}")

        # Initialize generators (use self.layout which is the 3-char A,B,C layout)
        self.ck_gen = CKTileKernelGenerator(datatype, self.layout)
        self.disp_gen = DispatcherWrapperGenerator(datatype, self.layout)
|
||
|
||
def _load_config(self, config_file: Optional[Path]) -> Dict:
|
||
"""Load or create default configuration"""
|
||
if config_file and config_file.exists():
|
||
with open(config_file) as f:
|
||
return json.load(f)
|
||
|
||
# Match tile_engine default configs for GEMM/Preshuffle/Multi-D
|
||
# See: tile_engine/ops/gemm/configs/default_config.json
|
||
# tile_engine/ops/gemm_preshuffle/configs/default_config.json
|
||
# tile_engine/ops/gemm_multi_d/configs/default_config.json
|
||
return {
|
||
"tile_config": {
|
||
# tile_m/n/k: 64-256 step 64 = [64, 128, 192, 256]
|
||
"tile_m": [64, 128, 192, 256],
|
||
"tile_n": [64, 128, 192, 256],
|
||
"tile_k": [64, 128, 192, 256],
|
||
# warp configs matching tile_engine
|
||
"warp_m": [1, 2, 4],
|
||
"warp_n": [1, 2, 4],
|
||
"warp_k": [1],
|
||
# warp_tile configs matching tile_engine
|
||
"warp_tile_m": [4, 16, 32],
|
||
"warp_tile_n": [16, 32, 64],
|
||
"warp_tile_k": [8, 16, 32, 64, 128],
|
||
},
|
||
"trait_config": {
|
||
"pipeline": ["compv3", "compv4", "mem"],
|
||
"epilogue": ["cshuffle", "default"],
|
||
"scheduler": ["intrawave", "interwave"],
|
||
"pad_m": [False],
|
||
"pad_n": [False],
|
||
"pad_k": [False],
|
||
"persistent": [False, True],
|
||
},
|
||
"multi_d_config": {
|
||
# Note: Only MultiDAdd and MultiDMultiply are compatible with multi-D GEMM.
|
||
# Relu/Gelu are unary ops with signature (y, x), not multi-D signature (e, c, ds...)
|
||
"elementwise_ops": ["MultiDAdd", "MultiDMultiply"],
|
||
"num_d_tensors": [1, 2],
|
||
},
|
||
}
|
||
|
||
def generate_all(self, parallel: bool = True) -> Dict:
|
||
"""Generate all kernels"""
|
||
log.info("Generating GEMM kernels:")
|
||
log.info(f" Datatype: {self.datatype}")
|
||
log.info(f" Layout: {self.layout}")
|
||
log.info(f" Variants: {[v.value for v in self.variants]}")
|
||
if self.use_preselected:
|
||
log.info(f" Using preselected set: {self.use_preselected}")
|
||
|
||
results = {"kernels": [], "wrappers": [], "failed": []}
|
||
|
||
# Get configurations
|
||
if self.use_preselected:
|
||
configs = self._get_preselected_configs()
|
||
log.info(f" Total configurations: {len(configs)}")
|
||
else:
|
||
for variant in self.variants:
|
||
log.info(f"\nGenerating {variant.value} kernels...")
|
||
configs = self._get_configs_for_variant(variant)
|
||
log.info(f" Configurations: {len(configs)}")
|
||
|
||
if parallel:
|
||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||
futures = [
|
||
executor.submit(self._generate_one, cfg) for cfg in configs
|
||
]
|
||
for future in concurrent.futures.as_completed(futures):
|
||
try:
|
||
k, w = future.result()
|
||
results["kernels"].append(k)
|
||
results["wrappers"].append(w)
|
||
except Exception as e:
|
||
results["failed"].append(str(e))
|
||
log.error(f"Failed: {e}")
|
||
else:
|
||
for cfg in configs:
|
||
try:
|
||
k, w = self._generate_one(cfg)
|
||
results["kernels"].append(k)
|
||
results["wrappers"].append(w)
|
||
except Exception as e:
|
||
results["failed"].append(str(e))
|
||
log.error(f"Failed: {e}")
|
||
|
||
# Generate registration header
|
||
if results["wrappers"]:
|
||
self._generate_registration_header(results["wrappers"])
|
||
|
||
return results
|
||
|
||
# Generate from preselected set
|
||
if parallel:
|
||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||
futures = [executor.submit(self._generate_one, cfg) for cfg in configs]
|
||
for future in concurrent.futures.as_completed(futures):
|
||
try:
|
||
k, w = future.result()
|
||
results["kernels"].append(k)
|
||
results["wrappers"].append(w)
|
||
except Exception as e:
|
||
results["failed"].append(str(e))
|
||
log.error(f"Failed: {e}")
|
||
else:
|
||
for cfg in configs:
|
||
try:
|
||
k, w = self._generate_one(cfg)
|
||
results["kernels"].append(k)
|
||
results["wrappers"].append(w)
|
||
except Exception as e:
|
||
results["failed"].append(str(e))
|
||
log.error(f"Failed: {e}")
|
||
|
||
# Generate registration header
|
||
if results["wrappers"]:
|
||
self._generate_registration_header(results["wrappers"])
|
||
|
||
return results
|
||
|
||
def _get_preselected_configs(self) -> List[KernelConfig]:
|
||
"""Get preselected kernel configurations"""
|
||
try:
|
||
from preselected_kernels import get_preselected_set
|
||
|
||
return get_preselected_set(self.use_preselected)
|
||
except ImportError:
|
||
log.warning(
|
||
"preselected_kernels module not found, falling back to config-based generation"
|
||
)
|
||
return []
|
||
except ValueError as e:
|
||
log.error(f"Invalid preselected set: {e}")
|
||
return []
|
||
|
||
def _get_configs_for_variant(self, variant: GemmVariant) -> List[KernelConfig]:
|
||
"""Get all configurations for a variant
|
||
|
||
Args:
|
||
variant: GEMM variant (STANDARD, PRESHUFFLE, MULTI_D)
|
||
|
||
Returns:
|
||
List of valid kernel configurations for the variant
|
||
"""
|
||
configs = []
|
||
|
||
# Get base configs
|
||
tile_configs = self._get_tile_configs()
|
||
trait_configs = self._get_trait_configs()
|
||
|
||
for tile, trait in itertools.product(tile_configs, trait_configs):
|
||
# Perform variant-specific architecture validation
|
||
if self.arch_filter and HAS_ARCH_FILTER:
|
||
if not self._is_tile_arch_valid(tile, variant):
|
||
continue
|
||
|
||
if variant == GemmVariant.STANDARD:
|
||
configs.append(KernelConfig(tile=tile, trait=trait, variant=variant))
|
||
|
||
elif variant == GemmVariant.PRESHUFFLE:
|
||
# Preshuffle needs specific pipeline (preshufflev2) and scheduler (default)
|
||
# Skip configs that don't use preshuffle-compatible traits
|
||
preshuffle_trait = TraitConfig(
|
||
pipeline="preshufflev2",
|
||
epilogue="cshuffle",
|
||
scheduler="default",
|
||
pad_m=trait.pad_m,
|
||
pad_n=trait.pad_n,
|
||
pad_k=trait.pad_k,
|
||
persistent=trait.persistent,
|
||
)
|
||
# Only generate one preshuffle config per tile (not per trait)
|
||
# since preshuffle has fixed pipeline/scheduler
|
||
if trait.pipeline == "compv3" and trait.scheduler == "intrawave":
|
||
configs.append(
|
||
KernelConfig(
|
||
tile=tile,
|
||
trait=preshuffle_trait,
|
||
variant=variant,
|
||
preshuffle=True,
|
||
)
|
||
)
|
||
|
||
elif variant == GemmVariant.MULTI_D:
|
||
multi_d = self.config.get("multi_d_config", {})
|
||
for ew_op, num_d in itertools.product(
|
||
multi_d.get("elementwise_ops", ["MultiDAdd"]),
|
||
multi_d.get("num_d_tensors", [1]),
|
||
):
|
||
configs.append(
|
||
KernelConfig(
|
||
tile=tile,
|
||
trait=trait,
|
||
variant=variant,
|
||
elementwise_op=ew_op,
|
||
num_d_tensors=num_d,
|
||
d_layout=self.d_layout, # Use extracted D layout
|
||
)
|
||
)
|
||
|
||
return configs
|
||
|
||
def _get_tile_configs(self) -> List[TileConfig]:
|
||
"""Get valid tile configurations, filtered by architecture constraints"""
|
||
tc = self.config["tile_config"]
|
||
configs = []
|
||
rejected_count = 0
|
||
|
||
for params in itertools.product(
|
||
tc["tile_m"],
|
||
tc["tile_n"],
|
||
tc["tile_k"],
|
||
tc["warp_m"],
|
||
tc["warp_n"],
|
||
tc["warp_k"],
|
||
tc["warp_tile_m"],
|
||
tc["warp_tile_n"],
|
||
tc["warp_tile_k"],
|
||
):
|
||
tile = TileConfig(*params)
|
||
|
||
# Basic validation
|
||
if not tile.is_valid():
|
||
rejected_count += 1
|
||
continue
|
||
|
||
# Architecture-specific validation
|
||
if self.arch_filter and HAS_ARCH_FILTER:
|
||
if not self._is_tile_arch_valid(tile):
|
||
rejected_count += 1
|
||
continue
|
||
|
||
configs.append(tile)
|
||
|
||
if rejected_count > 0:
|
||
log.debug(f"Rejected {rejected_count} tile configs for {self.gpu_target}")
|
||
|
||
return configs
|
||
|
||
    def _is_tile_arch_valid(
        self, tile: TileConfig, variant: GemmVariant = None
    ) -> bool:
        """Check if tile configuration is valid for target architecture

        Args:
            tile: Tile configuration to validate
            variant: GEMM variant (affects operator-specific constraints)

        Returns:
            True when the configuration passes all available checks, or when
            no architecture filter is present (filtering is best-effort).
        """
        if not self.arch_filter or not HAS_ARCH_FILTER:
            return True

        # Determine data types based on self.datatype
        # Note: dtype_c is the ACCUMULATOR type, not output type (which may be fp16)
        # WMMA instructions on gfx942 always use fp32 accumulator for fp16 inputs
        dtype_map = {
            "fp16": ("fp16", "fp16", "fp32"),  # A=fp16, B=fp16, Acc=fp32
            "bf16": ("bf16", "bf16", "fp32"),  # A=bf16, B=bf16, Acc=fp32
            "fp8": ("fp8", "fp8", "fp32"),  # A=fp8, B=fp8, Acc=fp32
            "bf8": ("bf8", "bf8", "fp32"),  # A=bf8, B=bf8, Acc=fp32
            "int8": ("int8", "int8", "int32"),  # A=int8, B=int8, Acc=int32
        }
        # Unknown datatypes fall back to the fp16 triple.
        dtype_a, dtype_b, dtype_c = dtype_map.get(
            self.datatype, ("fp16", "fp16", "fp32")
        )

        # Map GEMM variant to operator type for validation
        operator = None
        pipeline = "compv4"  # Default
        scheduler = "intrawave"  # Default

        if OperatorType is not None and variant is not None:
            variant_to_operator = {
                GemmVariant.STANDARD: OperatorType.GEMM,
                GemmVariant.PRESHUFFLE: OperatorType.GEMM_PRESHUFFLE,
                GemmVariant.MULTI_D: OperatorType.GEMM_MULTI_D,
            }
            operator = variant_to_operator.get(variant, OperatorType.GEMM)

            # Preshuffle requires specific pipeline and scheduler
            # NOTE(review): this override appears to apply only when
            # OperatorType is importable — confirm whether it should also run
            # when OperatorType is None (the preshuffle check below does).
            if variant == GemmVariant.PRESHUFFLE:
                pipeline = "preshufflev2"
                scheduler = "default"

        # Use preshuffle-specific validation (comprehensive CK-specific checks)
        if variant == GemmVariant.PRESHUFFLE:
            if not is_preshuffle_config_valid(
                tile_m=tile.tile_m,
                tile_n=tile.tile_n,
                tile_k=tile.tile_k,
                warp_m=tile.warp_m,
                warp_n=tile.warp_n,
                warp_k=tile.warp_k,
                warp_tile_m=tile.warp_tile_m,
                warp_tile_n=tile.warp_tile_n,
                warp_tile_k=tile.warp_tile_k,
                datatype=self.datatype,
            ):
                return False

        # Delegate the remaining (dtype/layout/pipeline/scheduler/operator)
        # checks to the architecture filter.
        return self.arch_filter.is_kernel_valid(
            datatype_a=dtype_a,
            datatype_b=dtype_b,
            datatype_c=dtype_c,
            tile_m=tile.tile_m,
            tile_n=tile.tile_n,
            tile_k=tile.tile_k,
            warp_m=tile.warp_m,
            warp_n=tile.warp_n,
            warp_k=tile.warp_k,
            warp_tile_m=tile.warp_tile_m,
            warp_tile_n=tile.warp_tile_n,
            warp_tile_k=tile.warp_tile_k,
            pipeline=pipeline,
            scheduler=scheduler,
            layout=self.layout,
            operator=operator,
        )
|
||
|
||
def _get_trait_configs(self) -> List[TraitConfig]:
|
||
"""Get valid trait configurations, filtered by architecture constraints"""
|
||
tc = self.config["trait_config"]
|
||
configs = []
|
||
rejected_count = 0
|
||
|
||
for params in itertools.product(
|
||
tc["pipeline"],
|
||
tc["epilogue"],
|
||
tc["scheduler"],
|
||
tc["pad_m"],
|
||
tc["pad_n"],
|
||
tc["pad_k"],
|
||
tc["persistent"],
|
||
):
|
||
trait = TraitConfig(*params)
|
||
|
||
# Basic trait validation (unsupported combinations)
|
||
if not trait.is_valid():
|
||
rejected_count += 1
|
||
continue
|
||
|
||
configs.append(trait)
|
||
|
||
if rejected_count > 0:
|
||
log.debug(f"Rejected {rejected_count} trait configs")
|
||
|
||
return configs
|
||
|
||
    def _generate_one(self, config: KernelConfig) -> Tuple[str, str]:
        """Generate one kernel and wrapper.

        Writes three files into the kernel/wrapper directories:
        the kernel header, its dispatcher wrapper header, and a tiny .cpp
        compilation unit so each kernel can be compiled independently.

        Args:
            config: Fully-specified kernel configuration.

        Returns:
            (kernel_header_path, wrapper_header_path) as strings.
        """
        kernel_name = KernelNaming.generate(config, self.datatype, self.layout)

        # Generate CK Tile kernel
        kernel_code = self.ck_gen.generate(config)
        kernel_path = self.kernel_dir / f"{kernel_name}.hpp"
        kernel_path.write_text(kernel_code)

        # Generate dispatcher wrapper
        wrapper_code = self.disp_gen.generate(config, kernel_path, self.kernel_dir)
        wrapper_path = self.wrapper_dir / f"dispatcher_wrapper_{kernel_name}.hpp"
        wrapper_path.write_text(wrapper_code)

        # Generate .cpp compilation unit for per-kernel parallel builds.
        # The volatile dummy symbol keeps the TU non-empty; '-' is replaced in
        # the symbol name to keep it a valid C++ identifier.
        cpp_path = self.kernel_dir / f"{kernel_name}.cpp"
        cpp_code = f'''// SPDX-License-Identifier: MIT
// Auto-generated compilation unit for: {kernel_name}
// Enables per-kernel parallel compilation with make -j

#include "{kernel_name}.hpp"

namespace ck_tile {{ namespace generated {{
volatile bool _{kernel_name.replace("-", "_")}_loaded = true;
}} }}
'''
        cpp_path.write_text(cpp_code)

        return str(kernel_path), str(wrapper_path)
|
||
|
||
def _generate_registration_header(self, wrapper_paths: List[str]):
|
||
"""Generate master registration header"""
|
||
kernel_names = [
|
||
Path(w).stem.replace("dispatcher_wrapper_", "") for w in wrapper_paths
|
||
]
|
||
|
||
includes = "\n".join(
|
||
[f'#include "dispatcher_wrapper_{n}.hpp"' for n in kernel_names]
|
||
)
|
||
registrations = "\n ".join(
|
||
[
|
||
f"registry.register_kernel(generated::make_{n}(gfx_arch), priority);"
|
||
for n in kernel_names
|
||
]
|
||
)
|
||
|
||
content = f"""// SPDX-License-Identifier: MIT
|
||
// Auto-generated master registration
|
||
#pragma once
|
||
|
||
#include "ck_tile/dispatcher.hpp"
|
||
{includes}
|
||
|
||
namespace ck_tile {{
|
||
namespace dispatcher {{
|
||
|
||
using ::ck_tile::dispatcher::Registry;
|
||
using Priority = ::ck_tile::dispatcher::Registry::Priority;
|
||
|
||
inline void register_all_tile_gemm_kernels(
|
||
const std::string& gfx_arch = "gfx942",
|
||
Priority priority = Priority::Normal)
|
||
{{
|
||
auto& registry = Registry::instance();
|
||
{registrations}
|
||
}}
|
||
|
||
inline std::size_t get_tile_gemm_kernel_count() {{ return {len(kernel_names)}; }}
|
||
|
||
}}}}
|
||
"""
|
||
|
||
reg_path = self.wrapper_dir / "register_all_kernels.hpp"
|
||
reg_path.write_text(content)
|
||
logging.info(f"Generated registration header: {reg_path}")
|
||
|
||
|
||
# ============================================================================
|
||
# CLI
|
||
# ============================================================================
|
||
|
||
|
||
def _show_arch_info(gpu_target: str, datatype: str):
    """Display supported configurations for a GPU architecture.

    Prints (to stdout, for the --show-arch-info CLI flag) the supported warp
    shapes, warp tile shapes for the given datatype, LDS limits and known
    unsupported trait combinations, as exposed by the arch_filter module.

    Args:
        gpu_target: GPU architecture name (e.g. "gfx942").
        datatype: Input data type key used to pick the warp-tile table.
    """
    if not HAS_ARCH_FILTER:
        print("Architecture filter module not available")
        return

    try:
        from arch_filter import (
            get_supported_archs,
            WARP_SUPPORTED_COMBINATIONS,
            WARP_TILE_SUPPORTED_COMBINATIONS,
            LDS_CAPACITY_LIMITS,
            TRAIT_UNSUPPORTED_COMBINATIONS,
        )

        print(f"\n=== Architecture Info for {gpu_target} ===\n")

        # Supported architectures
        print(f"Supported GPUs: {get_supported_archs()}")

        # Warp configurations
        warp_cfgs = WARP_SUPPORTED_COMBINATIONS.get(gpu_target, [])
        print("\nWarp configurations [warp_m, warp_n, warp_k]:")
        for cfg in warp_cfgs:
            print(f" {cfg}")

        # Warp tile configurations for data type.
        # Keys in the table are "a_b_c" dtype triples; unknown datatypes fall
        # back to the fp16 triple.
        dtype_map = {
            "fp16": "fp16_fp16_fp16",
            "bf16": "bf16_bf16_bf16",
            "fp8": "fp8_fp8_fp16",
            "bf8": "bf8_bf8_fp16",
            "int8": "int8_int8_int32",
        }
        dtype_key = dtype_map.get(datatype, "fp16_fp16_fp16")

        gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_target, {})
        warp_tiles = gpu_combos.get(dtype_key, [])
        print(
            f"\nWarp tile configurations for {dtype_key} [warp_tile_m, warp_tile_n, warp_tile_k]:"
        )
        for cfg in warp_tiles:
            print(f" {cfg}")

        # All supported data types
        print(f"\nAll supported data types on {gpu_target}:")
        for dtype in gpu_combos.keys():
            print(f" {dtype}")

        # LDS limits (stored in bytes; displayed in KB)
        print("\nLDS capacity limits:")
        for pipeline, limit in LDS_CAPACITY_LIMITS.items():
            print(f" {pipeline}: {limit // 1024}KB")

        # Unsupported trait combinations
        print("\nUnsupported trait combinations (pipeline, epilogue, scheduler):")
        for combo in TRAIT_UNSUPPORTED_COMBINATIONS:
            print(f" {combo}")

        print()

    except Exception as e:
        # Informational CLI path: report and continue instead of crashing.
        print(f"Error showing arch info: {e}")
|
||
|
||
|
||
def main():
    """CLI entry point for the unified GEMM code generator.

    Parses arguments, optionally materializes an inline --tile-config-json
    into a temporary config file (removed once consumed), runs generation,
    and optionally emits dispatcher registration code.

    Returns:
        Process exit code: 0 on success, 1 on any failure.
    """
    parser = argparse.ArgumentParser(
        description="Unified GEMM Code Generator - Single Source of Truth"
    )
    parser.add_argument(
        "--output-dir", type=Path, required=True, help="Output directory"
    )
    parser.add_argument(
        "--datatype",
        type=str,
        default="fp16",
        choices=["fp16", "bf16", "fp32", "fp8", "bf8", "int8", "pk_fp4"],
        help="Data type (fp16, bf16, fp32, fp8, bf8, int8, pk_fp4)",
    )
    parser.add_argument(
        "--layout",
        type=str,
        default="rcr",
        help="Layout (e.g., rcr for A=row, B=col, C=row; or rcrr for multi-d with D=row)",
    )
    parser.add_argument(
        "--gpu-target",
        type=str,
        default="gfx942",
        help="Target GPU (gfx90a, gfx942, gfx950, gfx1201)",
    )
    parser.add_argument("--config", type=Path, help="Configuration JSON file")
    parser.add_argument(
        "--variants",
        nargs="+",
        choices=["standard", "preshuffle", "multi_d"],
        default=["standard"],
        help="Variants to generate",
    )
    parser.add_argument(
        "--preselected",
        type=str,
        help="Use preselected kernel set (e.g., fp16_rcr_essential)",
    )
    parser.add_argument(
        "--no-parallel", action="store_true", help="Disable parallel generation"
    )
    parser.add_argument(
        "--register", action="store_true", help="Generate dispatcher registration code"
    )
    parser.add_argument(
        "--no-arch-filter",
        action="store_true",
        help="Disable architecture-specific kernel filtering",
    )
    parser.add_argument(
        "--show-arch-info",
        action="store_true",
        help="Show supported configurations for target GPU and exit",
    )
    parser.add_argument(
        "--kernel-set",
        type=str,
        help="Kernel set name (creates subdirectory for organization)",
    )
    parser.add_argument(
        "--tile-config-json",
        type=str,
        help="JSON string specifying exact tile configuration (for minimal builds)",
    )

    args = parser.parse_args()

    # Temp file created for --tile-config-json; removed after the codegen has
    # consumed it (previously this file was leaked: delete=False, never unlinked).
    tmp_config_path = None

    # Handle inline tile config JSON for minimal/single-kernel builds
    if args.tile_config_json:
        try:
            cfg = json.loads(args.tile_config_json)

            # Build proper config structure
            full_config = {}

            # Extract tile config
            tile_keys = [
                "tile_m",
                "tile_n",
                "tile_k",
                "warp_m",
                "warp_n",
                "warp_k",
                "warp_tile_m",
                "warp_tile_n",
                "warp_tile_k",
                "block_size",
            ]
            tile_config = {k: cfg[k] for k in tile_keys if k in cfg}
            if tile_config:
                full_config["tile_config"] = tile_config

            # Extract trait config
            trait_keys = ["pipeline", "epilogue", "scheduler"]
            trait_config = {k: cfg[k] for k in trait_keys if k in cfg}
            # Add default pad/persistent values
            trait_config.setdefault("pad_m", [False])
            trait_config.setdefault("pad_n", [False])
            trait_config.setdefault("pad_k", [False])
            trait_config.setdefault("persistent", [False])
            if trait_config:
                full_config["trait_config"] = trait_config

            # Extract multi_d config (for multi_d variant)
            if "elementwise_ops" in cfg or "num_d_tensors" in cfg:
                multi_d_config = {}
                if "elementwise_ops" in cfg:
                    multi_d_config["elementwise_ops"] = cfg["elementwise_ops"]
                if "num_d_tensors" in cfg:
                    multi_d_config["num_d_tensors"] = cfg["num_d_tensors"]
                full_config["multi_d_config"] = multi_d_config

            # Use already structured config if provided
            if "tile_config" in cfg:
                full_config = cfg

            # Write to temp file and use as config
            import tempfile

            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".json", delete=False
            ) as f:
                json.dump(full_config, f)
            tmp_config_path = Path(f.name)
            args.config = tmp_config_path
        except json.JSONDecodeError as e:
            logging.error(f"Invalid tile-config-json: {e}")
            return 1
        except KeyError as e:
            logging.error(f"Missing required key in tile-config-json: {e}")
            return 1

    # Show architecture info if requested
    if args.show_arch_info:
        _show_arch_info(args.gpu_target, args.datatype)
        return 0

    variants = [GemmVariant(v) for v in args.variants] if not args.preselected else None

    try:
        codegen = UnifiedGemmCodegen(
            output_dir=args.output_dir,
            datatype=args.datatype,
            layout=args.layout,
            gpu_target=args.gpu_target,
            config_file=args.config,
            variants=variants,
            use_preselected=args.preselected,
            enable_arch_filter=not args.no_arch_filter,
            kernel_set_name=args.kernel_set,
        )
    finally:
        # The temp config is only read during construction; remove it now.
        if tmp_config_path is not None and tmp_config_path.exists():
            tmp_config_path.unlink()

    results = codegen.generate_all(parallel=not args.no_parallel)

    logging.info("\n✅ Generation complete!")
    logging.info(f" Kernels: {len(results['kernels'])}")
    logging.info(f" Wrappers: {len(results['wrappers'])}")
    logging.info(f" Failed: {len(results['failed'])}")

    if results["failed"]:
        logging.error(f"\nFailed kernels: {len(results['failed'])}")
        for err in results["failed"][:5]:
            logging.error(f" {err}")

    # Generate dispatcher registration if requested
    if args.register:
        logging.info("\n📝 Generating dispatcher registration code...")
        try:
            from generate_dispatcher_registration import (
                scan_generated_headers,
                generate_registration_header,
                generate_registration_cpp,
            )

            kernels = scan_generated_headers(args.output_dir)
            reg_dir = args.output_dir / "registration"
            reg_dir.mkdir(exist_ok=True)

            generate_registration_header(
                kernels, reg_dir / "dispatcher_registration.hpp"
            )
            generate_registration_cpp(kernels, reg_dir / "dispatcher_registration.cpp")

            logging.info(f"✓ Generated registration code for {len(kernels)} kernels")
        except Exception as e:
            logging.error(f"Failed to generate registration code: {e}")
            return 1

    return 0 if not results["failed"] else 1
|
||
|
||
|
||
if __name__ == "__main__":
    # SystemExit carries the return code without relying on the
    # site-provided interactive `exit()` helper.
    raise SystemExit(main())
|