mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 14:54:47 +00:00
[CK] [CK_Tile] Add GroupConv to Kernel Dispatcher (#5168)
## Motivation This PR adds CK Tile group convolution (forward, backward-data, backward-weight) support to the kernel dispatcher, matching and unifying with the existing dispatcher GEMM infrastructure in architecture and usability. The dispatcher provides a unified kernel dispatch system with both C++ and Python frontends, and until now only supported GEMM operations. This PR enables framework integrators to use the same declarative kernel workflow for convolutions as they do for GEMM: declare kernels, build a registry JIT, select kernels within the registry at runtime, and dispatch to GPU. Future PRs will include runtime kernel selection heuristics for autotuning of kernel parameters based on (problem, hardware arch). ## Technical Details Grouped convolution support has been added to the CK Tile Dispatcher with generated_conv_backend.hpp enabling dispatcher.run(in, wei, out, problem) for all 6 conv variants (fwd/bwdd/bwdw x 2D/3D), runtime heuristic kernel selection, and GroupedConvKernelKey with full ConvConfigBase fields. Python side adds parallel JIT via registry.build(max_workers) and heuristic registry.select(). Includes 7 C++ and 6 Python examples covering all directions with CPU reference validation, and shared infrastructure improvements (BaseRegistry CRTP, structured exceptions). As a sanity check, JIT compile times for a single kernel remains the same and for multiple kernels there is better parallelism: Kernels | 1 worker | 8 workers 1 | 7.7 s | 7.7 s 2 | 15.9 s | 8.2 s 4 | 33.4 s | 9.7 s 6 | 52.3 s | 10.2 s ## Test Plan 145 ephemeral unit tests have been added to test basic functionality. All 30 examples/integration tests run end-to-end on gfx950 (MI350): 7 C++ conv, 7 C++ GEMM, 6 Python conv, 10 Python GEMM. CPU reference validation for forward, backward-data, and backward-weight (2D) in both C++ and Python examples pass. ## Test Result 30 examples pass. Peak performance: 132 TFLOPS (Batch-32 forward 56x56), 53 TFLOPS (pointwise 1x1). CPU reference accuracy: max_abs_diff < 0.002 for all directions (fp16 vs fp32 reference). ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Yaswanth Raparti <113389104+yraparti@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
fb22cd0c69
commit
a2b844d335
@@ -9,8 +9,8 @@ Guide for adding support for a new AMD GPU architecture to the CK Tile Dispatche
|
||||
The dispatcher uses `arch_specs.json` as the **single source of truth** for GPU specifications:
|
||||
|
||||
```
|
||||
arch_specs.json → generate_arch_specs.py → arch_specs_generated.py (Python)
|
||||
→ arch_specs_generated.hpp (C++)
|
||||
arch_specs.json -> generate_arch_specs.py -> arch_specs_generated.py (Python)
|
||||
-> arch_specs_generated.hpp (C++)
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
@@ -175,14 +175,14 @@ for error in result.errors:
|
||||
|
||||
```
|
||||
codegen/
|
||||
├── arch_specs.json # Single source of truth (EDIT THIS)
|
||||
├── generate_arch_specs.py # Generator script
|
||||
├── arch_specs_generated.py # Generated Python module
|
||||
└── ADDING_NEW_GPU.md # This file
|
||||
|---- arch_specs.json # Single source of truth (EDIT THIS)
|
||||
|---- generate_arch_specs.py # Generator script
|
||||
|---- arch_specs_generated.py # Generated Python module
|
||||
+---- ADDING_NEW_GPU.md # This file
|
||||
|
||||
include/ck_tile/dispatcher/
|
||||
├── arch_specs_generated.hpp # Generated C++ header
|
||||
└── arch_filter.hpp # C++ filter
|
||||
|---- arch_specs_generated.hpp # Generated C++ header
|
||||
+---- arch_filter.hpp # C++ filter
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
@@ -1,11 +1,22 @@
|
||||
# CK Tile GEMM Unified Code Generator
|
||||
# CK Tile Unified Code Generators
|
||||
|
||||
Single source of truth for all GEMM kernel generation.
|
||||
Single source of truth for GEMM and Grouped Convolution kernel generation.
|
||||
|
||||
> **See also:** [Main Dispatcher README](../README.md) for installation and core concepts.
|
||||
|
||||
## Shared Infrastructure
|
||||
|
||||
Both GEMM and Grouped Conv generators share common code via `codegen_common.py`:
|
||||
- `TileConfig` - Dataclass for tile dimensions
|
||||
- `TraitConfigBase` - Base for kernel trait configurations with arch-aware validation
|
||||
- `CommonTypeMappings` - Dtype-to-C++ type mappings
|
||||
- `parallel_generate()` - Parallel kernel generation with per-kernel progress logging
|
||||
- Arch-aware expansion helpers (`valid_wave_configs`, `valid_warp_configs`, etc.)
|
||||
|
||||
## Quick Start
|
||||
|
||||
### GEMM
|
||||
|
||||
```bash
|
||||
cd dispatcher/codegen
|
||||
|
||||
@@ -22,6 +33,25 @@ python3 unified_gemm_codegen.py \
|
||||
--variants standard preshuffle multi_d
|
||||
```
|
||||
|
||||
### Grouped Convolution
|
||||
|
||||
```bash
|
||||
cd dispatcher/codegen
|
||||
|
||||
# Generate forward FP16 grouped conv kernels
|
||||
python3 unified_grouped_conv_codegen.py \
|
||||
--output-dir ../build/generated_kernels \
|
||||
--datatype fp16 \
|
||||
--variant forward \
|
||||
--ndim-spatial 2
|
||||
|
||||
# Generate backward data kernels
|
||||
python3 unified_grouped_conv_codegen.py \
|
||||
--output-dir ../build/generated_kernels \
|
||||
--variant backward_data \
|
||||
--ndim-spatial 2
|
||||
```
|
||||
|
||||
## Using from Python
|
||||
|
||||
```python
|
||||
@@ -58,13 +88,13 @@ results = codegen.generate_all()
|
||||
## Variants
|
||||
|
||||
### Standard
|
||||
Basic GEMM: `C = A × B`
|
||||
Basic GEMM: `C = A x B`
|
||||
|
||||
### PreShuffle
|
||||
Optimized weight access with LDS pre-shuffling. Best for large matrices.
|
||||
|
||||
### Multi-D
|
||||
Element-wise fusion: `C = op(A × B + D0 + D1 + ...)`
|
||||
Element-wise fusion: `C = op(A x B + D0 + D1 + ...)`
|
||||
|
||||
Supported ops: `PassThrough`, `MultiDAdd`, `Relu`, `Gelu`, `Sigmoid`, `Tanh`
|
||||
|
||||
@@ -72,10 +102,11 @@ Supported ops: `PassThrough`, `MultiDAdd`, `Relu`, `Gelu`, `Sigmoid`, `Tanh`
|
||||
|
||||
```
|
||||
generated_kernels/
|
||||
├── gemm_fp16_rcr_compv4_..._128x128x32_....hpp
|
||||
├── gemm_fp16_rcr_compv4_..._preshuffle.hpp
|
||||
├── gemm_fp16_rcr_compv4_..._multid_Relu_d1.hpp
|
||||
└── ...
|
||||
|---- gemm_fp16_rcr_compv4_..._128x128x32_....hpp # GEMM kernels
|
||||
|---- gemm_fp16_rcr_compv4_..._preshuffle.hpp
|
||||
|---- gemm_fp16_rcr_compv4_..._multid_Relu_d1.hpp
|
||||
|---- grouped_conv_fwd_fp16_nhwgc_..._128x128x32_....hpp # Grouped conv kernels
|
||||
+---- ...
|
||||
```
|
||||
|
||||
## Configuration Files
|
||||
|
||||
350
dispatcher/codegen/codegen_common.py
Normal file
350
dispatcher/codegen/codegen_common.py
Normal file
@@ -0,0 +1,350 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Shared codegen infrastructure for GEMM and grouped convolution code generators.
|
||||
|
||||
Extracted from unified_gemm_codegen.py + arch-aware expansion helpers from conv.
|
||||
Both unified_gemm_codegen.py and unified_grouped_conv_codegen.py import from here
|
||||
to eliminate duplication.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import concurrent.futures
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
R = TypeVar("R")
|
||||
|
||||
ANY_INT = -1
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tile and Trait Configuration (shared between GEMM and Conv)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class TileConfig:
|
||||
"""Tile configuration parameters shared by GEMM and grouped conv."""
|
||||
|
||||
tile_m: int
|
||||
tile_n: int
|
||||
tile_k: int
|
||||
warp_m: int
|
||||
warp_n: int
|
||||
warp_k: int
|
||||
warp_tile_m: int
|
||||
warp_tile_n: int
|
||||
warp_tile_k: int
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
if self.tile_m <= 0 or self.tile_n <= 0 or self.tile_k <= 0:
|
||||
return False
|
||||
return (
|
||||
self.tile_m % (self.warp_m * self.warp_tile_m) == 0
|
||||
and self.tile_n % (self.warp_n * self.warp_tile_n) == 0
|
||||
and self.tile_k % (self.warp_k * self.warp_tile_k) == 0
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TraitConfigBase:
|
||||
"""
|
||||
Base kernel trait configuration shared by GEMM and grouped conv.
|
||||
|
||||
GEMM extends this with ``persistent``; grouped conv extends with
|
||||
``double_smem_buffer`` and ``num_groups_to_merge``.
|
||||
"""
|
||||
|
||||
pipeline: str # mem, compv3, compv4, compv5, ...
|
||||
epilogue: str # cshuffle, default
|
||||
scheduler: str # intrawave, interwave
|
||||
pad_m: bool
|
||||
pad_n: bool
|
||||
pad_k: bool
|
||||
|
||||
# Unsupported (pipeline, epilogue, scheduler) combinations.
|
||||
# Only 'mem' and 'basic_v1' pipelines support interwave; all compute
|
||||
# pipelines (compv3/v4/v5/v6/async) only support intrawave.
|
||||
_UNSUPPORTED: ClassVar[FrozenSet] = frozenset(
|
||||
{
|
||||
("compv3", "cshuffle", "interwave"),
|
||||
("compv3", "default", "interwave"),
|
||||
("compv4", "cshuffle", "interwave"),
|
||||
("compv4", "default", "interwave"),
|
||||
("compv5", "cshuffle", "interwave"),
|
||||
("compv5", "default", "interwave"),
|
||||
("compv6", "cshuffle", "interwave"),
|
||||
("compv6", "default", "interwave"),
|
||||
("comp_async", "cshuffle", "interwave"),
|
||||
("comp_async", "default", "interwave"),
|
||||
("basic_async_v1", "cshuffle", "interwave"),
|
||||
("basic_async_v1", "default", "interwave"),
|
||||
}
|
||||
)
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
return (self.pipeline, self.epilogue, self.scheduler) not in self._UNSUPPORTED
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Type Mappings (centralized for both GEMM and conv codegen)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CommonTypeMappings:
|
||||
"""Centralized type mappings shared by GEMM and grouped conv codegen."""
|
||||
|
||||
DTYPE_TO_CK = {
|
||||
"fp16": "fp16_t",
|
||||
"bf16": "bf16_t",
|
||||
"fp32": "float",
|
||||
"fp8": "fp8_t",
|
||||
"bf8": "bf8_t",
|
||||
"int8": "int8_t",
|
||||
}
|
||||
|
||||
DTYPE_TO_CK_QUALIFIED = {
|
||||
"fp16": "ck_tile::fp16_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"fp32": "float",
|
||||
"fp8": "ck_tile::fp8_t",
|
||||
"bf8": "ck_tile::bf8_t",
|
||||
"int8": "int8_t",
|
||||
}
|
||||
|
||||
DTYPE_TO_DISPATCHER = {
|
||||
"fp16": "DataType::FP16",
|
||||
"bf16": "DataType::BF16",
|
||||
"fp32": "DataType::FP32",
|
||||
"fp8": "DataType::FP8",
|
||||
"bf8": "DataType::BF8",
|
||||
"int8": "DataType::INT8",
|
||||
}
|
||||
|
||||
# GEMM-specific layout mappings ("r"/"c" for row/column major).
|
||||
# Convolution layouts (NHWGC, GKYXC, etc.) are handled by
|
||||
# unified_grouped_conv_codegen.py via GroupedConvLayout / GroupedConvTypeMappings.
|
||||
GEMM_LAYOUT_TO_CK = {
|
||||
"r": "tensor_layout::gemm::RowMajor",
|
||||
"c": "tensor_layout::gemm::ColumnMajor",
|
||||
}
|
||||
LAYOUT_TO_CK = GEMM_LAYOUT_TO_CK # backward compat alias
|
||||
|
||||
GEMM_LAYOUT_TO_DISPATCHER = {
|
||||
"r": "LayoutTag::RowMajor",
|
||||
"c": "LayoutTag::ColMajor",
|
||||
}
|
||||
LAYOUT_TO_DISPATCHER = GEMM_LAYOUT_TO_DISPATCHER # backward compat alias
|
||||
|
||||
# GEMM-only pipeline mappings (used by unified_gemm_codegen.py).
|
||||
# Convolution pipelines are in GroupedConvTypeMappings
|
||||
# (unified_grouped_conv_codegen.py). CK Tile conv supports:
|
||||
# BASIC_V1, Mem, CompV3, CompV4, CompV5, CompV6, ASYNC_V1, ASYNC_V4.
|
||||
# The dispatcher currently generates: mem, compv3, compv4.
|
||||
# preshufflev2 is GEMM-only (weight pre-shuffle for GEMM, not conv).
|
||||
PIPELINE_TO_CK = {
|
||||
"mem": "GemmPipelineAgBgCrMem",
|
||||
"compv3": "GemmPipelineAgBgCrCompV3",
|
||||
"compv4": "GemmPipelineAgBgCrCompV4",
|
||||
"compv5": "GemmPipelineAgBgCrCompV5",
|
||||
"preshufflev2": "WeightPreshufflePipelineAGmemBGmemCRegV2",
|
||||
}
|
||||
|
||||
PIPELINE_TO_BASE = {
|
||||
"mem": "BaseGemmPipelineAgBgCrMem",
|
||||
"compv3": "BaseGemmPipelineAgBgCrCompV3",
|
||||
"compv4": "BaseGemmPipelineAgBgCrCompV4",
|
||||
"compv5": "BaseGemmPipelineAgBgCrCompV5",
|
||||
"preshufflev2": "BaseWeightPreshufflePipelineAGmemBGmemCRegV2",
|
||||
}
|
||||
|
||||
PIPELINE_TO_DISPATCHER = {
|
||||
"mem": "Pipeline::Mem",
|
||||
"compv3": "Pipeline::CompV3",
|
||||
"compv4": "Pipeline::CompV4",
|
||||
"compv5": "Pipeline::CompV5",
|
||||
"preshufflev2": "Pipeline::PreShuffleV2",
|
||||
}
|
||||
|
||||
SCHEDULER_TO_CK = {
|
||||
"intrawave": "GemmPipelineScheduler::Intrawave",
|
||||
"interwave": "GemmPipelineScheduler::Interwave",
|
||||
"default": "GemmPipelineScheduler::Default",
|
||||
}
|
||||
|
||||
SCHEDULER_TO_DISPATCHER = {
|
||||
"intrawave": "Scheduler::Intrawave",
|
||||
"interwave": "Scheduler::Interwave",
|
||||
"default": "Scheduler::Auto",
|
||||
}
|
||||
|
||||
EPILOGUE_TO_DISPATCHER = {
|
||||
"cshuffle": "Epilogue::CShuffle",
|
||||
"default": "Epilogue::Default",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_output_dtype(dtype: str) -> str:
|
||||
"""Get output datatype (fp8/bf8 -> fp16)."""
|
||||
return "fp16" if dtype in ("fp8", "bf8") else dtype
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Code Generation Helpers
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def generate_cpp_compilation_unit(kernel_name: str) -> str:
|
||||
"""Generate a .cpp compilation unit that includes a kernel header.
|
||||
|
||||
This is the standard pattern: one .cpp per kernel that just includes
|
||||
the generated .hpp header, causing template instantiation.
|
||||
"""
|
||||
return (
|
||||
f"// Auto-generated compilation unit for {kernel_name}\n"
|
||||
f'#include "{kernel_name}.hpp"\n'
|
||||
)
|
||||
|
||||
|
||||
def parallel_generate(
|
||||
generate_fn: Callable[[T], R],
|
||||
items: Sequence[T],
|
||||
parallel: bool = True,
|
||||
) -> List[R]:
|
||||
"""Run ``generate_fn`` over ``items``, optionally in parallel.
|
||||
|
||||
Logs per-item progress (best-of-conv pattern).
|
||||
Returns a flat list of results in completion order.
|
||||
"""
|
||||
results: List[R] = []
|
||||
if not items:
|
||||
return results
|
||||
|
||||
if parallel and len(items) > 1:
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = {executor.submit(generate_fn, item): item for item in items}
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
result = future.result()
|
||||
results.append(result)
|
||||
log.info("Generated: %s", futures[future])
|
||||
else:
|
||||
for item in items:
|
||||
result = generate_fn(item)
|
||||
results.append(result)
|
||||
log.info("Generated: %s", item)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Arch-Aware Expansion Helpers (adopted from conv kernel_decl.hpp)
|
||||
# ============================================================================
|
||||
|
||||
# These load from arch_specs_generated when available, falling back to
|
||||
# hardcoded defaults that match the most common arch (gfx942).
|
||||
|
||||
_arch_data_cache: Optional[Dict] = None
|
||||
|
||||
|
||||
def _get_arch_data() -> Dict:
|
||||
"""Load arch filter data, with caching."""
|
||||
global _arch_data_cache
|
||||
if _arch_data_cache is not None:
|
||||
return _arch_data_cache
|
||||
|
||||
try:
|
||||
from arch_specs_generated import (
|
||||
WARP_SUPPORTED_COMBINATIONS,
|
||||
WARP_TILE_SUPPORTED_COMBINATIONS,
|
||||
TRAIT_UNSUPPORTED_COMBINATIONS,
|
||||
get_supported_archs,
|
||||
)
|
||||
|
||||
_arch_data_cache = {
|
||||
"warp_combos": WARP_SUPPORTED_COMBINATIONS,
|
||||
"warp_tile_combos": WARP_TILE_SUPPORTED_COMBINATIONS,
|
||||
"trait_unsupported": TRAIT_UNSUPPORTED_COMBINATIONS,
|
||||
"supported_archs": get_supported_archs(),
|
||||
}
|
||||
except ImportError:
|
||||
_arch_data_cache = {
|
||||
"warp_combos": {
|
||||
"gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
|
||||
"gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
|
||||
},
|
||||
"warp_tile_combos": {
|
||||
"gfx942": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]},
|
||||
"gfx90a": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]},
|
||||
},
|
||||
"trait_unsupported": {
|
||||
("compv3", "cshuffle", "interwave"),
|
||||
("compv4", "cshuffle", "interwave"),
|
||||
},
|
||||
"supported_archs": ["gfx90a", "gfx942", "gfx950"],
|
||||
}
|
||||
return _arch_data_cache
|
||||
|
||||
|
||||
def valid_wave_configs(arch: str) -> List[List[int]]:
|
||||
"""Return valid [wave_m, wave_n, wave_k] combos for *arch*."""
|
||||
data = _get_arch_data()
|
||||
return data["warp_combos"].get(arch, [[2, 2, 1]])
|
||||
|
||||
|
||||
def valid_warp_configs(arch: str, dtype: str) -> List[List[int]]:
|
||||
"""Return valid [warp_tile_m, warp_tile_n, warp_tile_k] combos for *arch*/*dtype*.
|
||||
|
||||
The dtype key is constructed as ``{dtype}_{dtype}_{acc}`` where acc is
|
||||
fp32 for float types and int32 for int8.
|
||||
"""
|
||||
data = _get_arch_data()
|
||||
acc = "int32" if dtype == "int8" else "fp32"
|
||||
dtype_key = f"{dtype}_{dtype}_{acc}"
|
||||
arch_tiles = data["warp_tile_combos"].get(arch, {})
|
||||
return arch_tiles.get(dtype_key, [[32, 32, 16]])
|
||||
|
||||
|
||||
def valid_trait_configs() -> List[Tuple[str, str]]:
|
||||
"""Return valid (pipeline, scheduler) pairs.
|
||||
|
||||
Compute pipelines only support intrawave; mem supports both.
|
||||
"""
|
||||
return [
|
||||
("compv3", "intrawave"),
|
||||
("compv4", "intrawave"),
|
||||
("compv5", "intrawave"),
|
||||
("mem", "intrawave"),
|
||||
("mem", "interwave"),
|
||||
]
|
||||
|
||||
|
||||
def needs_wave_expansion(config: dict) -> bool:
|
||||
"""True if wave_m or wave_n is a wildcard (ANY_INT = -1)."""
|
||||
return config.get("wave_m", 2) == ANY_INT or config.get("wave_n", 2) == ANY_INT
|
||||
|
||||
|
||||
def needs_warp_expansion(config: dict) -> bool:
|
||||
"""True if warp_m or warp_n is a wildcard (ANY_INT = -1)."""
|
||||
return config.get("warp_m", 32) == ANY_INT or config.get("warp_n", 32) == ANY_INT
|
||||
|
||||
|
||||
def needs_pipeline_expansion(config: dict) -> bool:
|
||||
"""True if pipeline is a wildcard (\"*\")."""
|
||||
return config.get("pipeline", "compv4") == "*"
|
||||
@@ -109,7 +109,7 @@ inline void register_all_kernels()
|
||||
"""
|
||||
|
||||
output_file.write_text(content)
|
||||
print(f"✓ Generated registration header: {output_file}")
|
||||
print(f"OK Generated registration header: {output_file}")
|
||||
|
||||
|
||||
def generate_registration_cpp(kernels: List[KernelConfig], output_file: Path):
|
||||
@@ -143,7 +143,7 @@ namespace generated {
|
||||
"""
|
||||
|
||||
output_file.write_text(content)
|
||||
print(f"✓ Generated registration implementation: {output_file}")
|
||||
print(f"OK Generated registration implementation: {output_file}")
|
||||
|
||||
|
||||
def generate_kernel_wrapper_header(kernel: KernelConfig, output_dir: Path):
|
||||
@@ -414,8 +414,8 @@ def main():
|
||||
with open(manifest_output, "w") as f:
|
||||
json.dump(manifest_data, f, indent=2)
|
||||
|
||||
print(f"✓ Generated manifest: {manifest_output}")
|
||||
print("\n✓ Registration code generation complete!")
|
||||
print(f"OK Generated manifest: {manifest_output}")
|
||||
print("\nOK Registration code generation complete!")
|
||||
print(f" Total kernels: {len(kernels)}")
|
||||
print(" Output files:")
|
||||
print(f" - {registration_header}")
|
||||
|
||||
@@ -17,10 +17,10 @@ Usage:
|
||||
|
||||
Output structure:
|
||||
build/kernel_wrappers/
|
||||
├── gemm_fp16_rcr_128x128x32.cpp
|
||||
├── gemm_fp16_rcr_256x256x64.cpp
|
||||
├── conv_fwd_fp16_2d_128x128.cpp
|
||||
└── ...
|
||||
|---- gemm_fp16_rcr_128x128x32.cpp
|
||||
|---- gemm_fp16_rcr_256x256x64.cpp
|
||||
|---- conv_fwd_fp16_2d_128x128.cpp
|
||||
+---- ...
|
||||
|
||||
Each .cpp simply includes its corresponding .hpp and forces symbol emission.
|
||||
"""
|
||||
|
||||
@@ -359,8 +359,8 @@ class ConvTraitConfig:
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConvKernelConfig:
|
||||
"""Complete convolution kernel configuration"""
|
||||
class GroupedConvKernelConfig:
|
||||
"""Complete grouped convolution kernel configuration"""
|
||||
|
||||
tile: ConvTileConfig = field(default_factory=ConvTileConfig)
|
||||
trait: ConvTraitConfig = field(default_factory=ConvTraitConfig)
|
||||
@@ -419,7 +419,11 @@ class ConvKernelConfig:
|
||||
|
||||
def kernel_name(self) -> str:
|
||||
"""Generate kernel name from config"""
|
||||
variant_map = {"forward": "fwd", "bwd_data": "bwdd", "bwd_weight": "bwdw"}
|
||||
variant_map = {
|
||||
"forward": "fwd",
|
||||
"bwd_data": "bwd_data",
|
||||
"bwd_weight": "bwd_weight",
|
||||
}
|
||||
var_str = variant_map.get(self.variant, self.variant)
|
||||
|
||||
name = f"conv_{var_str}_{self.dtype_input}_{self.ndim}d"
|
||||
@@ -433,11 +437,11 @@ class ConvKernelConfig:
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConvKernelConfigSet:
|
||||
class GroupedConvKernelConfigSet:
|
||||
"""A set of convolution kernel configurations loaded from JSON"""
|
||||
|
||||
name: str = "default"
|
||||
configs: List[ConvKernelConfig] = field(default_factory=list)
|
||||
configs: List[GroupedConvKernelConfig] = field(default_factory=list)
|
||||
|
||||
# Tile parameter ranges
|
||||
tile_m_values: List[int] = field(default_factory=lambda: [128])
|
||||
@@ -481,7 +485,7 @@ class ConvKernelConfigSet:
|
||||
layout: str = "nhwgc"
|
||||
gpu_targets: List[str] = field(default_factory=lambda: ["gfx942"])
|
||||
|
||||
def generate_configs(self) -> Iterator[ConvKernelConfig]:
|
||||
def generate_configs(self) -> Iterator[GroupedConvKernelConfig]:
|
||||
"""Generate all kernel configurations (cartesian product)"""
|
||||
# Tile parameters
|
||||
tile_params = itertools.product(
|
||||
@@ -548,7 +552,7 @@ class ConvKernelConfigSet:
|
||||
double_smem_buffer=trait[6],
|
||||
num_groups_to_merge=trait[7],
|
||||
)
|
||||
yield ConvKernelConfig(
|
||||
yield GroupedConvKernelConfig(
|
||||
tile=tile_cfg,
|
||||
trait=trait_cfg,
|
||||
dtype_input=self.dtype_input,
|
||||
@@ -599,7 +603,9 @@ class ConvKernelConfigSet:
|
||||
return tile_count * trait_count * extra_count * len(self.gpu_targets)
|
||||
|
||||
|
||||
def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet:
|
||||
def load_grouped_conv_kernel_configs(
|
||||
json_path: str | Path,
|
||||
) -> GroupedConvKernelConfigSet:
|
||||
"""
|
||||
Load convolution kernel configurations from a JSON file.
|
||||
|
||||
@@ -607,14 +613,14 @@ def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet:
|
||||
json_path: Path to JSON configuration file
|
||||
|
||||
Returns:
|
||||
ConvKernelConfigSet with all parameter values loaded
|
||||
GroupedConvKernelConfigSet with all parameter values loaded
|
||||
"""
|
||||
json_path = Path(json_path)
|
||||
|
||||
with open(json_path) as f:
|
||||
data = json.load(f)
|
||||
|
||||
config_set = ConvKernelConfigSet()
|
||||
config_set = GroupedConvKernelConfigSet()
|
||||
|
||||
# Name
|
||||
config_set.name = data.get("kernel_set_name", json_path.stem)
|
||||
@@ -680,15 +686,15 @@ def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet:
|
||||
|
||||
|
||||
def generate_cpp_conv_kernel_set_declaration(
|
||||
config_set: ConvKernelConfigSet,
|
||||
config_set: GroupedConvKernelConfigSet,
|
||||
set_name: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate C++ DECL_CONV_KERNEL_SET code from a ConvKernelConfigSet.
|
||||
Generate C++ DECL_GROUPED_CONV_KERNEL_SET code from a GroupedConvKernelConfigSet.
|
||||
"""
|
||||
name = set_name or config_set.name
|
||||
|
||||
lines = [f"DECL_CONV_KERNEL_SET({name},"]
|
||||
lines = [f"DECL_GROUPED_CONV_KERNEL_SET({name},"]
|
||||
|
||||
for config in config_set.generate_configs():
|
||||
line = f' .add("{config.dtype_input}", "{config.variant}", {config.ndim}, '
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
Unified GEMM Code Generator - Single Source of Truth
|
||||
|
||||
This is THE unified code generator for all GEMM kernel variants:
|
||||
- Standard GEMM (C = A × B)
|
||||
- Standard GEMM (C = A x B)
|
||||
- Preshuffle GEMM (optimized weight access)
|
||||
- Multi-D GEMM (element-wise fusion)
|
||||
|
||||
@@ -25,6 +25,12 @@ from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
import concurrent.futures
|
||||
|
||||
from codegen_common import (
|
||||
TileConfig,
|
||||
TraitConfigBase,
|
||||
CommonTypeMappings as TypeMappings,
|
||||
)
|
||||
|
||||
# Import architecture filter for GPU-specific validation
|
||||
try:
|
||||
from arch_filter import ArchFilter, KernelConfig as ArchKernelConfig, OperatorType
|
||||
@@ -194,62 +200,14 @@ class GemmVariant(Enum):
|
||||
MULTI_D = "multi_d"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TileConfig:
|
||||
"""Tile configuration parameters"""
|
||||
|
||||
tile_m: int
|
||||
tile_n: int
|
||||
tile_k: int
|
||||
warp_m: int
|
||||
warp_n: int
|
||||
warp_k: int
|
||||
warp_tile_m: int
|
||||
warp_tile_n: int
|
||||
warp_tile_k: int
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Validate tile configuration"""
|
||||
return (
|
||||
self.tile_m % (self.warp_m * self.warp_tile_m) == 0
|
||||
and self.tile_n % (self.warp_n * self.warp_tile_n) == 0
|
||||
and self.tile_k % (self.warp_k * self.warp_tile_k) == 0
|
||||
and self.tile_m > 0
|
||||
and self.tile_n > 0
|
||||
and self.tile_k > 0
|
||||
)
|
||||
# TileConfig imported from codegen_common
|
||||
|
||||
|
||||
@dataclass
|
||||
class TraitConfig:
|
||||
"""Kernel trait configuration"""
|
||||
class TraitConfig(TraitConfigBase):
|
||||
"""GEMM-specific trait configuration extending TraitConfigBase with persistent mode."""
|
||||
|
||||
pipeline: str # mem, compv3, compv4
|
||||
epilogue: str # default, cshuffle
|
||||
scheduler: str # intrawave, interwave
|
||||
pad_m: bool
|
||||
pad_n: bool
|
||||
pad_k: bool
|
||||
persistent: bool
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if trait combination is valid"""
|
||||
# Unsupported combinations
|
||||
# Only 'mem' pipeline supports interwave scheduler.
|
||||
# All compute pipelines (compv3/v4/v5/v6/async) only support intrawave.
|
||||
unsupported = {
|
||||
("compv3", "cshuffle", "interwave"),
|
||||
("compv3", "default", "interwave"),
|
||||
("compv4", "cshuffle", "interwave"),
|
||||
("compv4", "default", "interwave"),
|
||||
("compv5", "cshuffle", "interwave"),
|
||||
("compv5", "default", "interwave"),
|
||||
("compv6", "cshuffle", "interwave"),
|
||||
("compv6", "default", "interwave"),
|
||||
("comp_async", "cshuffle", "interwave"),
|
||||
("comp_async", "default", "interwave"),
|
||||
}
|
||||
return (self.pipeline, self.epilogue, self.scheduler) not in unsupported
|
||||
persistent: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -345,89 +303,7 @@ class KernelConfig:
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TypeMappings:
|
||||
"""Centralized type mappings for code generation"""
|
||||
|
||||
DTYPE_TO_CK = {
|
||||
"fp16": "fp16_t",
|
||||
"bf16": "bf16_t",
|
||||
"fp32": "float",
|
||||
"fp8": "fp8_t",
|
||||
"bf8": "bf8_t",
|
||||
"int8": "int8_t",
|
||||
}
|
||||
|
||||
# Fully-qualified types for use outside of 'using namespace ck_tile' scope
|
||||
DTYPE_TO_CK_QUALIFIED = {
|
||||
"fp16": "ck_tile::fp16_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"fp32": "float", # Built-in type, no namespace
|
||||
"fp8": "ck_tile::fp8_t",
|
||||
"bf8": "ck_tile::bf8_t",
|
||||
"int8": "int8_t", # Built-in type
|
||||
}
|
||||
|
||||
DTYPE_TO_DISPATCHER = {
|
||||
"fp16": "DataType::FP16",
|
||||
"bf16": "DataType::BF16",
|
||||
"fp32": "DataType::FP32",
|
||||
"fp8": "DataType::FP8",
|
||||
"bf8": "DataType::BF8",
|
||||
"int8": "DataType::INT8",
|
||||
}
|
||||
|
||||
LAYOUT_TO_CK = {
|
||||
"r": "tensor_layout::gemm::RowMajor",
|
||||
"c": "tensor_layout::gemm::ColumnMajor",
|
||||
}
|
||||
|
||||
LAYOUT_TO_DISPATCHER = {
|
||||
"r": "LayoutTag::RowMajor",
|
||||
"c": "LayoutTag::ColMajor",
|
||||
}
|
||||
|
||||
PIPELINE_TO_CK = {
|
||||
"mem": "GemmPipelineAgBgCrMem",
|
||||
"compv3": "GemmPipelineAgBgCrCompV3",
|
||||
"compv4": "GemmPipelineAgBgCrCompV4",
|
||||
"preshufflev2": "WeightPreshufflePipelineAGmemBGmemCRegV2",
|
||||
}
|
||||
|
||||
PIPELINE_TO_BASE = {
|
||||
"mem": "BaseGemmPipelineAgBgCrMem",
|
||||
"compv3": "BaseGemmPipelineAgBgCrCompV3",
|
||||
"compv4": "BaseGemmPipelineAgBgCrCompV4",
|
||||
"preshufflev2": "BaseWeightPreshufflePipelineAGmemBGmemCRegV2",
|
||||
}
|
||||
|
||||
PIPELINE_TO_DISPATCHER = {
|
||||
"mem": "Pipeline::Mem",
|
||||
"compv3": "Pipeline::CompV3",
|
||||
"compv4": "Pipeline::CompV4",
|
||||
"preshufflev2": "Pipeline::PreShuffleV2",
|
||||
}
|
||||
|
||||
SCHEDULER_TO_CK = {
|
||||
"intrawave": "GemmPipelineScheduler::Intrawave",
|
||||
"interwave": "GemmPipelineScheduler::Interwave",
|
||||
"default": "GemmPipelineScheduler::Default",
|
||||
}
|
||||
|
||||
SCHEDULER_TO_DISPATCHER = {
|
||||
"intrawave": "Scheduler::Intrawave",
|
||||
"interwave": "Scheduler::Interwave",
|
||||
"default": "Scheduler::Auto",
|
||||
}
|
||||
|
||||
EPILOGUE_TO_DISPATCHER = {
|
||||
"cshuffle": "Epilogue::CShuffle",
|
||||
"default": "Epilogue::Default",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_output_dtype(dtype: str) -> str:
|
||||
"""Get output datatype (fp8/bf8 -> fp16)"""
|
||||
return "fp16" if dtype in ["fp8", "bf8"] else dtype
|
||||
# TypeMappings imported from codegen_common as CommonTypeMappings -> TypeMappings alias
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@@ -1068,7 +944,11 @@ class UnifiedGemmCodegen:
|
||||
}
|
||||
|
||||
def generate_all(self, parallel: bool = True) -> Dict:
|
||||
"""Generate all kernels"""
|
||||
"""Generate all kernels.
|
||||
|
||||
When parallel=True, all configs across all variants are collected first,
|
||||
then generated concurrently in a single thread pool for maximum throughput.
|
||||
"""
|
||||
log.info("Generating GEMM kernels:")
|
||||
log.info(f" Datatype: {self.datatype}")
|
||||
log.info(f" Layout: {self.layout}")
|
||||
@@ -1078,49 +958,24 @@ class UnifiedGemmCodegen:
|
||||
|
||||
results = {"kernels": [], "wrappers": [], "failed": []}
|
||||
|
||||
# Get configurations
|
||||
# Collect ALL configs across all variants/preselected sets upfront
|
||||
all_configs = []
|
||||
if self.use_preselected:
|
||||
configs = self._get_preselected_configs()
|
||||
log.info(f" Total configurations: {len(configs)}")
|
||||
all_configs = self._get_preselected_configs()
|
||||
log.info(f" Total configurations: {len(all_configs)}")
|
||||
else:
|
||||
for variant in self.variants:
|
||||
log.info(f"\nGenerating {variant.value} kernels...")
|
||||
configs = self._get_configs_for_variant(variant)
|
||||
log.info(f" Configurations: {len(configs)}")
|
||||
log.info(f" {variant.value}: {len(configs)} configurations")
|
||||
all_configs.extend(configs)
|
||||
log.info(f" Total across all variants: {len(all_configs)}")
|
||||
|
||||
if parallel:
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(self._generate_one, cfg) for cfg in configs
|
||||
]
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
try:
|
||||
k, w = future.result()
|
||||
results["kernels"].append(k)
|
||||
results["wrappers"].append(w)
|
||||
except Exception as e:
|
||||
results["failed"].append(str(e))
|
||||
log.error(f"Failed: {e}")
|
||||
else:
|
||||
for cfg in configs:
|
||||
try:
|
||||
k, w = self._generate_one(cfg)
|
||||
results["kernels"].append(k)
|
||||
results["wrappers"].append(w)
|
||||
except Exception as e:
|
||||
results["failed"].append(str(e))
|
||||
log.error(f"Failed: {e}")
|
||||
|
||||
# Generate registration header
|
||||
if results["wrappers"]:
|
||||
self._generate_registration_header(results["wrappers"])
|
||||
|
||||
return results
|
||||
|
||||
# Generate from preselected set
|
||||
if parallel:
|
||||
# Generate all configs in a single parallel pass
|
||||
if parallel and all_configs:
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = [executor.submit(self._generate_one, cfg) for cfg in configs]
|
||||
futures = [
|
||||
executor.submit(self._generate_one, cfg) for cfg in all_configs
|
||||
]
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
try:
|
||||
k, w = future.result()
|
||||
@@ -1130,7 +985,7 @@ class UnifiedGemmCodegen:
|
||||
results["failed"].append(str(e))
|
||||
log.error(f"Failed: {e}")
|
||||
else:
|
||||
for cfg in configs:
|
||||
for cfg in all_configs:
|
||||
try:
|
||||
k, w = self._generate_one(cfg)
|
||||
results["kernels"].append(k)
|
||||
@@ -1139,7 +994,6 @@ class UnifiedGemmCodegen:
|
||||
results["failed"].append(str(e))
|
||||
log.error(f"Failed: {e}")
|
||||
|
||||
# Generate registration header
|
||||
if results["wrappers"]:
|
||||
self._generate_registration_header(results["wrappers"])
|
||||
|
||||
@@ -1638,12 +1492,19 @@ def main():
|
||||
|
||||
# Write to temp file and use as config
|
||||
import tempfile
|
||||
import os as _os
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
_tmp_config = tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as f:
|
||||
json.dump(full_config, f)
|
||||
args.config = Path(f.name)
|
||||
)
|
||||
try:
|
||||
json.dump(full_config, _tmp_config)
|
||||
_tmp_config.close()
|
||||
args.config = Path(_tmp_config.name)
|
||||
except Exception:
|
||||
_tmp_config.close()
|
||||
_os.unlink(_tmp_config.name)
|
||||
raise
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Invalid tile-config-json: {e}")
|
||||
return 1
|
||||
@@ -1672,7 +1533,7 @@ def main():
|
||||
|
||||
results = codegen.generate_all(parallel=not args.no_parallel)
|
||||
|
||||
logging.info("\n✅ Generation complete!")
|
||||
logging.info("\nGeneration complete.")
|
||||
logging.info(f" Kernels: {len(results['kernels'])}")
|
||||
logging.info(f" Wrappers: {len(results['wrappers'])}")
|
||||
logging.info(f" Failed: {len(results['failed'])}")
|
||||
@@ -1684,7 +1545,7 @@ def main():
|
||||
|
||||
# Generate dispatcher registration if requested
|
||||
if args.register:
|
||||
logging.info("\n📝 Generating dispatcher registration code...")
|
||||
logging.info("\nGenerating dispatcher registration code...")
|
||||
try:
|
||||
from generate_dispatcher_registration import (
|
||||
scan_generated_headers,
|
||||
@@ -1701,11 +1562,20 @@ def main():
|
||||
)
|
||||
generate_registration_cpp(kernels, reg_dir / "dispatcher_registration.cpp")
|
||||
|
||||
logging.info(f"✓ Generated registration code for {len(kernels)} kernels")
|
||||
logging.info(f"Generated registration code for {len(kernels)} kernels")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to generate registration code: {e}")
|
||||
return 1
|
||||
|
||||
# Clean up temp config file if we created one
|
||||
if args.tile_config_json and args.config and args.config.exists():
|
||||
try:
|
||||
import os as _os
|
||||
|
||||
_os.unlink(args.config)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return 0 if not results["failed"] else 1
|
||||
|
||||
|
||||
|
||||
1757
dispatcher/codegen/unified_grouped_conv_codegen.py
Normal file
1757
dispatcher/codegen/unified_grouped_conv_codegen.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user