mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
[CK_TILE] Integrate CK Tile Dispatcher code generation into CK Tile Profiler (#7284) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation CK Tile is going to be delivered to hipDNN via CK Dispatcher. Currently the CK Tile Profiler uses CK Builder to generate the profiled instances from the configuration files that identify the instances that old CK exposes. We need to replace this instance generation with the CK Tile Dispatcher codegen. ## Technical Details The old CK Profiler config files are converted to JSON files that the CK Tile Dispatcher can digest. The conversion script for configurations is stored in source control in case we need to update the JSON configurations later. The dispatcher generates instance libraries per conv direction (fwd, bwd data, and bwd weight) that are linked to the CK Profiler executable. I also implemented codegen for the stream-K and depthwise conv instances. The proposed solution replaces the CK Builder codegen with the CK Tile Dispatcher codegen. There are two new methods that are exposed via the dispatcher backend: - `is_supported` - required to enable the profiler workflow where we check the applicability of the kernel instance before running it. - `get_instance_string` - this is mainly for verification. It provides the CK Builder instance string for verifying that the old CK Builder based profiler and the new CK Tile Dispatcher based profiler have the same instances. The rules that limit the generated instances are now collected in a single location under the dispatcher. The CK Builder codegen uses these, which ensures that the two codegen pipelines are in sync. The next step (different PR) is to remove the CK Builder codegen pipeline altogether. ## Test Plan Verified that the old CK Builder based profiler and the new CK Tile Dispatcher based profiler have the same instances, that is, the Dispatcher based codegen can generate the same instances as the old CK Builder.
## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2824 lines
109 KiB
Python
2824 lines
109 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Unified Grouped Convolution Code Generator
|
|
|
|
This is the unified code generator for all grouped convolution kernel variants:
|
|
- Forward grouped convolution
|
|
- Backward data grouped convolution
|
|
- Backward weight grouped convolution
|
|
|
|
Generates both CK Tile kernels AND dispatcher wrappers.
|
|
Based on the GEMM codegen pattern.
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import List, Optional, Tuple, Union
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
|
|
from codegen_common import (
|
|
TileConfig,
|
|
TraitConfigBase,
|
|
parallel_generate,
|
|
)
|
|
|
|
# Script-style console logging: "LEVEL: message" lines at INFO and above.
logging.basicConfig(
    format="%(levelname)s: %(message)s",
    level=logging.INFO,
)
# Module logger used by the validators and generators in this file.
log = logging.getLogger(__name__)
|
|
|
|
# Import architecture filter for GPU-specific validation
|
|
try:
|
|
from arch_filter import ArchFilter, OperatorType
|
|
|
|
HAS_ARCH_FILTER = True
|
|
except ImportError:
|
|
HAS_ARCH_FILTER = False
|
|
ArchFilter = None
|
|
OperatorType = None
|
|
|
|
# Import tile configurations and shared validation rules from grouped_config_rules
|
|
# (single source of truth)
|
|
try:
|
|
from grouped_config_rules import (
|
|
COMMON_TILES,
|
|
TILE_TO_WAVE,
|
|
TILE_TO_WARP,
|
|
VARIANT_PIPELINES,
|
|
BWD_WEIGHT_TILES,
|
|
COMPV4_COMPATIBLE_TILES,
|
|
# Shared validation functions
|
|
check_vectors,
|
|
check_warp_coverage,
|
|
check_bwd_data_vec_coverage,
|
|
is_valid_pipeline_for_variant,
|
|
is_streamk_valid_for_variant,
|
|
)
|
|
HAS_TILE_CONFIGS = True
|
|
except ImportError:
|
|
HAS_TILE_CONFIGS = False
|
|
COMMON_TILES = []
|
|
TILE_TO_WAVE = {}
|
|
TILE_TO_WARP = {}
|
|
VARIANT_PIPELINES = {}
|
|
BWD_WEIGHT_TILES = []
|
|
COMPV4_COMPATIBLE_TILES = []
|
|
|
|
|
|
# ============================================================================
|
|
# Configuration and Data Structures
|
|
# ============================================================================
|
|
|
|
|
|
class GroupedConvVariant(Enum):
    """Convolution direction implemented by a generated kernel.

    The member ``.value`` is the string passed to the shared validation
    rules and printed in the ``// Variant:`` banner of generated headers.
    """

    # Forward grouped convolution.
    FORWARD = "forward"
    # Forward depthwise convolution (described by DepthwiseConvKernelConfig).
    FORWARD_DEPTHWISE = "forward_depthwise"
    # Backward data (gradient w.r.t. the input).
    BACKWARD_DATA = "bwd_data"
    # Backward weight (gradient w.r.t. the filter).
    BACKWARD_WEIGHT = "bwd_weight"
|
|
|
|
|
|
class GroupedConvLayout(Enum):
    """Grouped convolution tensor layouts.

    There is one input / weight / output triple per spatial rank. Each
    member's value equals its name, and the comments spell out the
    dimension order.
    """

    # ---- 1D ----
    NWGC = "NWGC"  # Input:  N W G C
    GKXC = "GKXC"  # Weight: G K X C
    NWGK = "NWGK"  # Output: N W G K

    # ---- 2D ----
    NHWGC = "NHWGC"  # Input:  N H W G C
    GKYXC = "GKYXC"  # Weight: G K Y X C
    NHWGK = "NHWGK"  # Output: N H W G K

    # ---- 3D ----
    NDHWGC = "NDHWGC"  # Input:  N D H W G C
    GKZYXC = "GKZYXC"  # Weight: G K Z Y X C
    NDHWGK = "NDHWGK"  # Output: N D H W G K
|
|
|
|
|
|
class StreamKReductionStrategy(Enum):
    """How stream-K partial results are reduced.

    The lower-cased value is appended to kernel names as
    ``_streamk_<strategy>``.
    """

    TREE = "TREE"  # default strategy used by StreamKConfig
    LINEAR = "LINEAR"
|
|
|
|
@dataclass
class StreamKConfig:
    """Stream-K options attached to a grouped-conv trait.

    When ``streamk_enabled`` is False, kernel naming ignores the other two
    fields.
    """

    # Master switch for the stream-K code path.
    streamk_enabled: bool = False
    # Reduction strategy for partial tiles (encoded in the kernel name).
    strategy: StreamKReductionStrategy = StreamKReductionStrategy.TREE
    # Adds a "_persistent" suffix to the kernel name when stream-K is enabled.
    streamk_persistent: bool = False
|
|
|
|
|
|
@dataclass
class GroupedConvTraitConfig(TraitConfigBase):
    """Kernel trait configuration for grouped convolution (extends TraitConfigBase).

    Code in this module also reads ``pipeline``, ``epilogue``, ``scheduler``
    and ``pad_m``/``pad_n``/``pad_k`` on trait objects. These presumably come
    from TraitConfigBase (not visible here; confirm).

    Conv-specific extensions beyond TraitConfigBase. These map to
    GroupedConvTraits template parameters in grouped_convolution_utils.hpp:
      - double_smem_buffer: ping-pong LDS for compute V4+ pipelines
      - num_groups_to_merge: fuse multiple groups into one tile (NumGroupsToMerge)
      - split_image: split spatial dims for large tensors (EnableSplitImage)
      - explicit_gemm: use explicit GEMM path (ExplicitGemm)
      - two_stage: two-stage bwd_weight with fp32 workspace + elementwise convert
      - specialization: convolution specialization name. A non-"default"
        value is appended to the kernel name and is passed to the generator
        when the C++ ConvSpec is emitted.
      - streamk_config: stream-K options. For bwd_weight kernels with
        stream-K enabled, the generator uses a dedicated stream-K launcher.

    Note: CK Tile already uses long_index_t (64-bit) for group strides and
    batch offsets, so there is no separate "large_tensor" flag. For large
    spatial dimensions, use split_image=True instead.
    """

    # Field order matters: these are dataclass init parameters that follow
    # the base-class fields.
    double_smem_buffer: bool = False
    num_groups_to_merge: int = 1
    split_image: bool = False
    explicit_gemm: bool = False
    two_stage: bool = False
    specialization: str = "default"  # default, filter1x1_pad0, filter1x1_stride1_pad0, filter3x3
    # default_factory gives every trait its own StreamKConfig instance.
    streamk_config: StreamKConfig = field(default_factory=StreamKConfig)


# Backward compatibility alias: older code refers to the conv trait as TraitConfig.
TraitConfig = GroupedConvTraitConfig
|
|
|
|
|
|
def deduce_block_per_cu(pipeline: str, double_smem_buffer: bool) -> int:
    """Return the minimum blocks-per-CU hint for a pipeline / LDS mode.

    The hint follows from whether the pipeline ends up double-buffering LDS:

      - compv4 / comp_async always double-buffer LDS (static_assert enforced).
        Double LDS already halves achievable occupancy, so the hint is 1 and
        the compiler may use as many registers as it needs.
      - compv1 / compv2 / basic_v1 / basic_v2 / basic_async_v1 always use a
        single LDS buffer, so the hint is 2. This matches the CK Tile global
        default (CK_TILE_MIN_BLOCK_PER_CU=2).
      - Every other pipeline (mem, compv3, compv5, compv6, ...) follows
        ``double_smem_buffer``: 1 when double-buffered, 2 otherwise.

    Args:
        pipeline: Pipeline key, e.g. "compv3".
        double_smem_buffer: Requested LDS double buffering. Only consulted for
            pipelines where it is configurable.

    Returns:
        1 or 2.
    """
    forced_double_lds = ("compv4", "comp_async")
    forced_single_lds = ("compv1", "compv2", "basic_v1", "basic_v2", "basic_async_v1")

    if pipeline in forced_double_lds:
        uses_double_lds = True
    elif pipeline in forced_single_lds:
        uses_double_lds = False
    else:
        uses_double_lds = double_smem_buffer

    return 1 if uses_double_lds else 2
|
|
|
|
|
|
@dataclass
class GroupedConvKernelConfig:
    """Complete grouped convolution kernel configuration.

    Combines the following:
      - the block/warp tile shape (``tile``);
      - the pipeline/epilogue/scheduler traits (``trait``);
      - problem-level options: direction, spatial rank, target arch, layout,
        vector widths and occupancy hints.

    The config is used to build a unique kernel name and to check whether
    the configuration is valid for a GPU architecture.
    """

    tile: TileConfig
    trait: GroupedConvTraitConfig
    variant: GroupedConvVariant = GroupedConvVariant.FORWARD
    ndim_spatial: int = 2  # 1D, 2D, or 3D
    arch: str = "gfx942"  # Target architecture
    layout: Union[str, GroupedConvLayout] = (
        "nhwgc"  # Data layout (e.g., "nhwgc", "ndhwgc")
    )

    # Vector sizes: a=4 for fp16 input (8-byte aligned global loads),
    # b=8 for weight tensor, c=8 for output stores. These match the
    # CK Tile default vectorization widths for fp16 on CDNA3 (gfx942).
    vector_size_a: int = 4
    vector_size_b: int = 8
    vector_size_c: int = 8
    # Optional (a, b, c) shorthand. When given, __post_init__ uses it to
    # override the three vector_size_* fields.
    vector_sizes: Optional[Tuple[int, int, int]] = None

    # Merging multiple conv groups into a single GEMM batch.
    # By default no merging. This helps when the number of channels per group is small.
    num_groups_to_merge: int = 1

    # Occupancy parameters
    num_wave_groups: int = 1

    # Double buffering
    double_smem_buffer: bool = False

    def __post_init__(self) -> None:
        """Apply the ``vector_sizes`` shorthand and reconcile duplicated fields.

        ``double_smem_buffer`` and ``num_groups_to_merge`` exist both here and
        on ``trait``:
          - a non-default top-level value overrides a default trait value;
          - otherwise a non-default trait value is mirrored to the top level.

        When the trait has to change, a shallow copy of it is updated instead
        of the caller's object. A trait shared between several configs is
        therefore never modified as a side effect of building one of them.
        """
        if self.vector_sizes is not None:
            self.vector_size_a, self.vector_size_b, self.vector_size_c = (
                self.vector_sizes[:3]
            )

        trait_overrides = {}
        if self.double_smem_buffer and not self.trait.double_smem_buffer:
            trait_overrides["double_smem_buffer"] = self.double_smem_buffer
        elif self.trait.double_smem_buffer:
            self.double_smem_buffer = self.trait.double_smem_buffer

        if self.num_groups_to_merge != 1 and self.trait.num_groups_to_merge == 1:
            trait_overrides["num_groups_to_merge"] = self.num_groups_to_merge
        elif self.trait.num_groups_to_merge != 1:
            self.num_groups_to_merge = self.trait.num_groups_to_merge

        if trait_overrides:
            import copy  # local: only needed when a trait override is applied

            own_trait = copy.copy(self.trait)
            for attr, value in trait_overrides.items():
                setattr(own_trait, attr, value)
            self.trait = own_trait

    @property
    def block_per_cu(self) -> int:
        """Deduce min blocks-per-CU from pipeline type and LDS buffering mode."""
        return deduce_block_per_cu(self.trait.pipeline, self.double_smem_buffer)

    def _layout_str(self) -> str:
        """Get layout as lowercase string for naming."""
        if hasattr(self.layout, "value"):
            return self.layout.value.lower()
        return str(self.layout).lower()

    def name(self, datatype: str) -> str:
        """
        Generate kernel name that uniquely identifies the kernel configuration.

        Format:
          grouped_conv_{variant}_{dtype}_{layout}_{ndim}d_{pipeline}_{epilogue}_{scheduler}
          _{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}
          _{warp_tile_m}x{warp_tile_n}x{warp_tile_k}
          [_vec{a}_{b}_{c}][_wg{n}][_gm{n}][_dsb][_2stage][_{specialization}]
          [_explicit_gemm][_streamk_{strategy}[_persistent]][_large_tensor][_pad{mnk}]

        Optional parts are emitted only when they differ from the defaults.
        block_per_cu is derived from the pipeline and the double-SMEM flag,
        which are both already in the name, so it is not encoded separately.

        Args:
            datatype: Data type tag (e.g. "fp16").

        Returns:
            The kernel name string.
        """
        t = self.tile
        tr = self.trait
        layout_str = self._layout_str()

        variant_str = {
            GroupedConvVariant.FORWARD: "fwd",
            # Same prefix as DepthwiseConvKernelConfig names.
            GroupedConvVariant.FORWARD_DEPTHWISE: "fwd_depthwise",
            GroupedConvVariant.BACKWARD_DATA: "bwd_data",
            GroupedConvVariant.BACKWARD_WEIGHT: "bwd_weight",
        }[self.variant]

        # Core identity: variant, dtype, layout, dims
        name = (
            f"grouped_conv_{variant_str}_{datatype}_{layout_str}_{self.ndim_spatial}d"
        )

        # Pipeline configuration
        name += f"_{tr.pipeline}_{tr.epilogue}_{tr.scheduler}"

        # Block tile dimensions (M_Tile x N_Tile x K_Tile)
        name += f"_{t.tile_m}x{t.tile_n}x{t.tile_k}"

        # Wave distribution (M_Warp x N_Warp x K_Warp)
        name += f"_{t.warp_m}x{t.warp_n}x{t.warp_k}"

        # Warp tile dimensions (M_Warp_Tile x N_Warp_Tile x K_Warp_Tile)
        name += f"_{t.warp_tile_m}x{t.warp_tile_n}x{t.warp_tile_k}"

        # Vector sizes (only if non-default)
        if (self.vector_size_a, self.vector_size_b, self.vector_size_c) != (4, 8, 8):
            name += (
                f"_vec{self.vector_size_a}_{self.vector_size_b}_{self.vector_size_c}"
            )

        if self.num_wave_groups != 1:
            name += f"_wg{self.num_wave_groups}"

        if self.num_groups_to_merge != 1:
            name += f"_gm{self.num_groups_to_merge}"

        # Double SMEM buffer (for compute V4+)
        if self.double_smem_buffer or tr.double_smem_buffer:
            name += "_dsb"

        # Two-stage bwd_weight (fp32 workspace + elementwise convert)
        if tr.two_stage:
            name += "_2stage"

        # Specialization suffix (only if non-default)
        specialization = getattr(tr, "specialization", "default")
        if specialization != "default":
            name += f"_{specialization}"

        if tr.explicit_gemm:
            name += "_explicit_gemm"

        # Stream-K suffix
        sk = tr.streamk_config
        if sk.streamk_enabled:
            name += f"_streamk_{sk.strategy.value.lower()}"
            if sk.streamk_persistent:
                name += "_persistent"

        # Large tensor (split image) suffix
        if tr.split_image:
            name += "_large_tensor"

        # Padding suffix (only if not all enabled)
        if not (tr.pad_m and tr.pad_n and tr.pad_k):
            name += f"_pad{int(tr.pad_m)}{int(tr.pad_n)}{int(tr.pad_k)}"

        return name

    def is_valid_for_arch(self, arch: Optional[str] = None) -> bool:
        """Check if configuration is valid for target architecture.

        Uses shared validation rules from grouped_config_rules.py, then the
        per-arch warp combinations from arch_specs_generated when that module
        is importable.

        Args:
            arch: Architecture to check against. Defaults to ``self.arch``.

        Returns:
            True if every applicable rule accepts the configuration.
        """
        target_arch = arch if arch is not None else self.arch

        # Check trait validity (pipeline+epilogue+scheduler combination)
        if not self.trait.is_valid():
            return False

        tr = self.trait
        variant_str = self.variant.value  # e.g. "forward", "bwd_data", "bwd_weight"

        # Stream-K is only allowed for the variants the shared rules accept.
        if tr.streamk_config.streamk_enabled and not is_streamk_valid_for_variant(variant_str):
            return False

        # Pipeline must be allowed for this variant (e.g. backward ops reject compv5).
        if not is_valid_pipeline_for_variant(tr.pipeline, variant_str):
            return False

        # Reject irregular vector sizes (AMD GPUs: 1, 2, 4, 8, 16 only)
        if not check_vectors(self.vector_size_a, self.vector_size_b, self.vector_size_c):
            log.warning(
                "Rejecting config: irregular vector size "
                "(vec_a=%s, vec_b=%s, vec_c=%s)",
                self.vector_size_a, self.vector_size_b, self.vector_size_c,
            )
            return False

        # Reject tile dims that exceed single-warp vector load coverage
        t = self.tile
        if not check_warp_coverage(
            t.tile_m, t.tile_n, t.tile_k,
            self.vector_size_a, self.vector_size_b,
            variant=variant_str,
        ):
            log.warning(
                "Rejecting config: tile exceeds warp coverage "
                "(tile=%sx%sx%s, vec_a=%s, vec_b=%s)",
                t.tile_m, t.tile_n, t.tile_k,
                self.vector_size_a, self.vector_size_b,
            )
            return False

        # Bwd_data only: vector width must not exceed elements per thread
        if self.variant == GroupedConvVariant.BACKWARD_DATA:
            if not check_bwd_data_vec_coverage(
                t.tile_m, t.tile_n, t.tile_k,
                t.warp_m, t.warp_n, t.warp_k,
                self.vector_size_a, self.vector_size_b,
            ):
                log.warning(
                    "Rejecting bwd_data config: vec exceeds tile coverage "
                    "(tile=%sx%sx%s, vec_a=%s, vec_b=%s)",
                    t.tile_m, t.tile_n, t.tile_k,
                    self.vector_size_a, self.vector_size_b,
                )
                return False

        # Check warp configuration (from arch_specs)
        try:
            from arch_specs_generated import WARP_SUPPORTED_COMBINATIONS
        except ImportError:
            return True  # Allow if arch_specs not available

        supported = WARP_SUPPORTED_COMBINATIONS.get(target_arch)
        if supported is None:
            return False  # Unknown architecture
        # Normalize to tuples so list- and tuple-encoded combos both match.
        warp_cfg = (t.warp_m, t.warp_n, t.warp_k)
        return warp_cfg in {tuple(combo) for combo in supported}
|
|
|
|
|
|
@dataclass
class DepthwiseConvKernelConfig:
    """Complete depthwise convolution kernel configuration.

    Describes one forward depthwise kernel instance with:
      - a spatial tile (``tile_h`` x ``tile_w``);
      - a square ``filt`` x ``filt`` filter;
      - stride and padding;
      - batching and sub-tiling;
      - input/output vector widths.

    ``block_size`` and the dilations are normally left at their defaults.
    """

    # Depthwise tile parameters
    tile_h: int = 8
    tile_w: int = 8
    filt: int = 3  # filter_h == filter_w (square filters)
    str_h: int = 1
    str_w: int = 1
    pad_h: int = 1
    pad_w: int = 1
    nbatch: int = 1
    sub_h: int = 1
    sub_w: int = 1
    in_vec: int = 1
    out_vec: int = 1

    # Fixed parameters (depthwise always uses these)
    block_size: int = 64
    dil_h: int = 1
    dil_w: int = 1
    ndim_spatial: int = 2

    # Metadata
    arch: str = "gfx942"
    layout: str = "ngchw"
    datatype: str = "fp16"

    def name(self, datatype: Optional[str] = None) -> str:
        """Generate unique kernel name for depthwise convolution.

        Args:
            datatype: Data type tag for the name (e.g. "fp16"). Defaults to
                ``self.datatype`` when omitted.

        Returns:
            A name of the form::

                grouped_conv_fwd_depthwise_{dtype}_{layout}_{ndim}d
                _{tile_h}x{tile_w}_f{filt}_s{str_h}x{str_w}_p{pad_h}x{pad_w}
                _nb{nbatch}_sub{sub_h}x{sub_w}_vec{in_vec}_{out_vec}

            ``_d{dil_h}x{dil_w}`` and ``_bs{block_size}`` are appended only
            when those normally-fixed fields differ from their defaults.
            Names for default configurations are therefore unchanged while
            non-default ones stay unique.
        """
        dtype = self.datatype if datatype is None else datatype
        name = (
            f"grouped_conv_fwd_depthwise_{dtype}_{self.layout}_{self.ndim_spatial}d"
            f"_{self.tile_h}x{self.tile_w}"
            f"_f{self.filt}"
            f"_s{self.str_h}x{self.str_w}"
            f"_p{self.pad_h}x{self.pad_w}"
            f"_nb{self.nbatch}"
            f"_sub{self.sub_h}x{self.sub_w}"
            f"_vec{self.in_vec}_{self.out_vec}"
        )
        if (self.dil_h, self.dil_w) != (1, 1):
            name += f"_d{self.dil_h}x{self.dil_w}"
        if self.block_size != 64:
            name += f"_bs{self.block_size}"
        return name
|
|
|
|
|
|
# ============================================================================
|
|
# Type Mappings
|
|
# ============================================================================
|
|
|
|
|
|
class GroupedConvTypeMappings:
    """Centralized type mappings for grouped convolution code generation.

    Translates the generator's string keys (data types, pipeline and
    scheduler names, spatial rank) into the CK Tile C++ spellings written
    into generated kernels.
    """

    DTYPE_TO_CK = {
        "fp16": "half_t",
        "bf16": "bf16_t",
        "fp32": "float",
    }

    # CK Tile conv pipelines (from conv_configs.hpp PipelineTypeTraits).
    # basic_v1/mem/compv3 use GroupedConvUniversalPipelineAgBgCrPolicy;
    # compv4/compv5/compv6/comp_async/basic_async_v1 use their own default policy.
    PIPELINE_TO_CK = {
        "basic_v1": "GemmPipeline::BASIC_V1",
        "basic_v2": "GemmPipeline::BASIC_V2",
        "compv1": "GemmPipeline::BASIC_V1",  # alias used by dispatcher/converter
        "compv2": "GemmPipeline::BASIC_V2",  # alias used by dispatcher/converter
        "mem": "GemmPipeline::MEMORY",
        "compv3": "GemmPipeline::COMPUTE_V3",
        "compv4": "GemmPipeline::COMPUTE_V4",
        "compv5": "GemmPipeline::COMPUTE_V5",
        "compv6": "GemmPipeline::COMPUTE_V6",
        "comp_async": "GemmPipeline::COMPUTE_ASYNC",
        "basic_async_v1": "GemmPipeline::BASIC_ASYNC_V1",
    }

    SCHEDULER_TO_CK = {
        "intrawave": "GemmPipelineScheduler::Intrawave",
        "interwave": "GemmPipelineScheduler::Interwave",
    }

    LAYOUT_1D = {
        "in": "tensor_layout::convolution::NWGC",
        "wei": "tensor_layout::convolution::GKXC",
        "out": "tensor_layout::convolution::NWGK",
    }

    LAYOUT_2D = {
        "in": "tensor_layout::convolution::NHWGC",
        "wei": "tensor_layout::convolution::GKYXC",
        "out": "tensor_layout::convolution::NHWGK",
    }

    LAYOUT_3D = {
        "in": "tensor_layout::convolution::NDHWGC",
        "wei": "tensor_layout::convolution::GKZYXC",
        "out": "tensor_layout::convolution::NDHWGK",
    }

    @classmethod
    def get_layouts(cls, ndim: int) -> dict:
        """Return the in/wei/out CK layout type names for ``ndim`` spatial dims.

        Args:
            ndim: Number of spatial dimensions (1, 2 or 3).

        Returns:
            A dict with "in", "wei" and "out" keys.

        Raises:
            ValueError: if ``ndim`` is not 1, 2 or 3. Any other value would
                otherwise produce code with mismatched layouts.
        """
        layouts = {1: cls.LAYOUT_1D, 2: cls.LAYOUT_2D, 3: cls.LAYOUT_3D}.get(ndim)
        if layouts is None:
            raise ValueError(
                f"Unsupported number of spatial dimensions: {ndim!r} (expected 1, 2 or 3)"
            )
        return layouts
|
|
|
|
|
|
# ============================================================================
|
|
# CK Tile Grouped Conv Kernel Generator
|
|
# ============================================================================
|
|
|
|
|
|
class CKTileGroupedConvKernelGenerator:
|
|
"""Generates CK Tile grouped convolution kernel instance code"""
|
|
|
|
def __init__(
    self,
    datatype: str,
    variant: GroupedConvVariant = GroupedConvVariant.FORWARD,
):
    """Create a kernel-source generator for one data type and direction.

    Args:
        datatype: Data type key (e.g. "fp16"). It is used for kernel names and
            for the CK type lookup in ``GroupedConvTypeMappings.DTYPE_TO_CK``.
        variant: Convolution direction to emit code for.
    """
    self.tm = GroupedConvTypeMappings()
    self.variant = variant
    self.datatype = datatype
|
|
|
|
def generate(self, config: GroupedConvKernelConfig) -> str:
    """Render the full source text of one CK Tile grouped-conv kernel.

    The three sections are joined by newlines, and the result ends with a
    trailing newline:
      1. the include header;
      2. the ``<kernel_name>_Config`` struct;
      3. the namespaced launcher.
    """
    kernel_name = config.name(self.datatype)
    sections = (
        self._header(kernel_name, config),
        self._config_struct(config, kernel_name),
        self._kernel_instance(config, kernel_name),
    )
    return "\n".join(sections) + "\n"
|
|
|
|
def _header(self, kernel_name: str, config: GroupedConvKernelConfig) -> str:
    """Generate the preamble of a kernel file: license, pragma once, includes.

    The grouped-convolution kernel header included depends on
    ``self.variant``; any variant other than bwd_data/bwd_weight uses the
    forward kernel header.

    Args:
        kernel_name: Name written into the banner comment.
        config: Kernel configuration. Its trait decides whether the
            elementwise (two-stage) and stream-K partitioner headers are added.

    Returns:
        The header text, ending with ``using namespace ck_tile;`` and a newline.
    """
    if self.variant == GroupedConvVariant.BACKWARD_DATA:
        kernel_header = "grouped_convolution_backward_data_kernel.hpp"
    elif self.variant == GroupedConvVariant.BACKWARD_WEIGHT:
        kernel_header = "grouped_convolution_backward_weight_kernel.hpp"
    else:
        kernel_header = "grouped_convolution_forward_kernel.hpp"

    elementwise_include = ""
    if config.trait.two_stage:
        elementwise_include = '\n#include "ck_tile/ops/elementwise.hpp"'

    streamk_include = ""
    if config.trait.streamk_config.streamk_enabled:
        streamk_include = '\n#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"'

    # <stdexcept> is included explicitly because the generated launcher throws
    # std::runtime_error; relying on transitive includes is fragile.
    return f"""// SPDX-License-Identifier: MIT
// Auto-generated CK Tile Grouped Convolution kernel: {kernel_name}
// Variant: {self.variant.value}
#pragma once

#include <cstdint>
#include <numeric>
#include <functional>
#include <stdexcept>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/grouped_convolution/kernel/{kernel_header}"
#include "ck_tile/ops/grouped_convolution/pipeline/grouped_conv_universal_pipeline_ag_bg_cr_policy.hpp"{elementwise_include}{streamk_include}

using namespace ck_tile;
"""
|
|
|
|
def _config_struct(self, config: GroupedConvKernelConfig, kernel_name: str) -> str:
    """Render the ``<kernel_name>_Config`` C++ struct.

    The struct holds the kernel's compile-time parameters so the launcher
    can refer to them as ``Config::<name>``:
      - data types and layouts;
      - tile, warp and warp-tile shapes;
      - vector widths and padding flags;
      - pipeline and scheduler enums, double-SMEM flag, epilogue selection;
      - occupancy hints, group merging, split-image and explicit-GEMM flags;
      - spatial rank and target arch.
    """
    tile = config.tile
    trait = config.trait
    layouts = self.tm.get_layouts(config.ndim_spatial)
    ck_type = self.tm.DTYPE_TO_CK[self.datatype]
    ck_pipeline = self.tm.PIPELINE_TO_CK[trait.pipeline]
    ck_scheduler = self.tm.SCHEDULER_TO_CK[trait.scheduler]

    def cpp_bool(value) -> str:
        # Python True/False -> C++ true/false.
        return str(value).lower()

    return f"""
// Kernel configuration
struct {kernel_name}_Config {{
    // Data types
    using InDataType = {ck_type};
    using WeiDataType = {ck_type};
    using AccDataType = float;
    using OutDataType = {ck_type};

    // Layouts
    using InLayout = {layouts["in"]};
    using WeiLayout = {layouts["wei"]};
    using OutLayout = {layouts["out"]};

    // Tile shape
    static constexpr index_t M_Tile = {tile.tile_m};
    static constexpr index_t N_Tile = {tile.tile_n};
    static constexpr index_t K_Tile = {tile.tile_k};

    static constexpr index_t M_Warp = {tile.warp_m};
    static constexpr index_t N_Warp = {tile.warp_n};
    static constexpr index_t K_Warp = {tile.warp_k};

    static constexpr index_t M_Warp_Tile = {tile.warp_tile_m};
    static constexpr index_t N_Warp_Tile = {tile.warp_tile_n};
    static constexpr index_t K_Warp_Tile = {tile.warp_tile_k};

    // Vector sizes
    static constexpr index_t VectorSizeA = {config.vector_size_a};
    static constexpr index_t VectorSizeB = {config.vector_size_b};
    static constexpr index_t VectorSizeC = {config.vector_size_c};

    // Padding
    static constexpr bool kPadM = {cpp_bool(trait.pad_m)};
    static constexpr bool kPadN = {cpp_bool(trait.pad_n)};
    static constexpr bool kPadK = {cpp_bool(trait.pad_k)};

    // Pipeline & Epilogue
    static constexpr auto Pipeline = {ck_pipeline};
    static constexpr auto Scheduler = {ck_scheduler};
    static constexpr bool DoubleSmemBuffer = {cpp_bool(trait.double_smem_buffer)};
    static constexpr bool UseCShuffleEpilogue = {cpp_bool(trait.epilogue == "cshuffle")};

    // Other params
    static constexpr int kBlockPerCu = {config.block_per_cu};
    static constexpr index_t NumWaveGroups = {config.num_wave_groups};
    static constexpr index_t NumGroupsToMerge = {trait.num_groups_to_merge};
    static constexpr bool EnableSplitImage = {cpp_bool(trait.split_image)};
    static constexpr bool ExplicitGemm = {cpp_bool(trait.explicit_gemm)};
    static constexpr index_t NDimSpatial = {config.ndim_spatial};

    // Target architecture
    static constexpr const char* TargetArch = "{config.arch}";
}};
"""
|
|
|
|
def _kernel_instance(
    self, config: GroupedConvKernelConfig, kernel_name: str
) -> str:
    """Generate kernel instantiation code with launch function.

    Backward-weight configs with stream-K or two-stage enabled are delegated
    to their dedicated generators.

    Every other case emits namespace ``ns_<kernel_name>`` containing:
      - the Config alias and a kernel-name constant;
      - a ``<kernel_name>_Launcher`` struct whose static ``launch`` builds the
        implicit-GEMM pipeline, epilogue and kernel types and runs them through
        the base pipeline's TailHandler;
      - the is-supported and instance-string helpers.

    The direction-specific pieces are selected from ``self.variant``: host
    args type, kernel template, A/B/C GEMM operand data types, GEMM-K formula
    and alias names. Any variant other than bwd_data/bwd_weight takes the
    forward branch.
    """
    tr = config.trait

    if self.variant == GroupedConvVariant.BACKWARD_WEIGHT and tr.streamk_config.streamk_enabled:
        return self._kernel_instance_streamk(config, kernel_name)

    if self.variant == GroupedConvVariant.BACKWARD_WEIGHT and tr.two_stage:
        return self._kernel_instance_two_stage(config, kernel_name)

    # Variant-specific configuration
    if self.variant == GroupedConvVariant.BACKWARD_DATA:
        host_args_type = "GroupedConvBwdDataHostArgs"
        kernel_type = "GroupedConvolutionBackwardDataKernel"
        gemm_traits = "GroupedConvImplicitGemmTraitsBwdData"
        layout_suffix = "BwdData"
        # For bwd_data: A=dOutput, B=Weight, C=dInput
        a_dtype = "OutDataType"
        b_dtype = "WeiDataType"
        c_dtype = "InDataType"
        gemm_k_calc = "args.K_ * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end()"
        direction_prefix = "BWD_DATA"
        launcher_alias = "SelectedConvBwdDataLauncher"
    elif self.variant == GroupedConvVariant.BACKWARD_WEIGHT:
        host_args_type = "GroupedConvBwdWeightHostArgs"
        kernel_type = "GroupedConvolutionBackwardWeightKernel"
        gemm_traits = "GroupedConvImplicitGemmTraitsBwdWeight"
        layout_suffix = "BwdWeight"
        # For bwd_weight: A=dOutput, B=Input, C=dWeight (per CK Tile invoker)
        a_dtype = "OutDataType"
        b_dtype = "InDataType"
        c_dtype = "WeiDataType"
        gemm_k_calc = "args.N_ * std::accumulate(args.output_spatial_lengths_.begin(), args.output_spatial_lengths_.end()"
        direction_prefix = "BWD_WEIGHT"
        launcher_alias = "SelectedConvBwdWeightLauncher"
    else:  # Forward
        host_args_type = "GroupedConvFwdHostArgs<>"
        kernel_type = "GroupedConvolutionForwardKernel"
        gemm_traits = "GroupedConvImplicitGemmTraitsFwd"
        layout_suffix = "Fwd"
        a_dtype = "InDataType"
        b_dtype = "WeiDataType"
        c_dtype = "OutDataType"
        gemm_k_calc = "args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end()"
        direction_prefix = "FWD"
        launcher_alias = "SelectedConvKernelLauncher"

    # compv1 / basic_v1 / basic_async_v1 inherit BaseGemmPipelineAGmemBGmemCRegV1,
    # whose TailHandler takes (run_func, has_hot_loop) and invokes a 1-arg
    # run_func. Other pipelines pass (run_func, has_hot_loop, tail_number) and
    # invoke a 2-arg run_func, so they also need the tail_num declaration.
    # The condition is evaluated once so the three pieces cannot drift apart.
    is_v1_pipeline = tr.pipeline in ("compv1", "basic_v1", "basic_async_v1")
    if is_v1_pipeline:
        tail_num_decl = ""
        tail_handler_call = "BaseGemmPipeline::TailHandler(Run, has_hot_loop);"
        run_lambda_signature = "[&](const auto has_hot_loop_)"
    else:
        tail_num_decl = "const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);"
        tail_handler_call = (
            "BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);"
        )
        run_lambda_signature = (
            "[&](const auto has_hot_loop_, const auto tail_number_)"
        )

    # Create valid C++ namespace name
    ns_name = "ns_" + kernel_name.replace("-", "_")

    return f"""
// Unique namespace for this kernel to avoid conflicts when including multiple kernels
namespace {ns_name} {{

// Bring Config into namespace
using Config = {kernel_name}_Config;

// Kernel name for identification
constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = "{kernel_name}";

// Selected kernel alias
using SelectedConv{direction_prefix.title()}Kernel = Config;

// =============================================================================
// Kernel Launch Implementation ({self.variant.value})
// =============================================================================

struct {kernel_name}_Launcher {{
    using KernelConfig = Config; // Use the Config alias from namespace
    using InDataType = typename Config::InDataType;
    using WeiDataType = typename Config::WeiDataType;
    using OutDataType = typename Config::OutDataType;
    using AccDataType = typename Config::AccDataType;
    using InLayout = typename Config::InLayout;
    using WeiLayout = typename Config::WeiLayout;
    using OutLayout = typename Config::OutLayout;

    static constexpr index_t NDimSpatial = Config::NDimSpatial;

    // Implicit GEMM shape
    using GemmShape = TileGemmShape<
        sequence<Config::M_Tile, Config::N_Tile, Config::K_Tile>,
        sequence<Config::M_Warp, Config::N_Warp, Config::K_Warp>,
        sequence<Config::M_Warp_Tile, Config::N_Warp_Tile, Config::K_Warp_Tile>>;

    // Convolution traits
    static constexpr auto ConvSpec = {self._get_conv_specialization(config.trait)};
    using GroupedConvTraitsType = GroupedConvTraits<
        NDimSpatial, ConvSpec, InLayout, WeiLayout, tuple<>, OutLayout,
        Config::VectorSizeA, Config::VectorSizeB, Config::VectorSizeC,
        Config::NumGroupsToMerge, Config::EnableSplitImage, Config::ExplicitGemm>;

    // Tile partitioner
    using TilePartitioner = GemmSpatiallyLocalTilePartitioner<
        GemmShape,
        GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
        GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;

    // Universal traits - layout suffix changes per variant
    using GemmUniversalTraits = TileGemmUniversalTraits<
        GroupedConvTraitsType::FixedGemmParams::kPadM,
        GroupedConvTraitsType::FixedGemmParams::kPadN,
        GroupedConvTraitsType::FixedGemmParams::kPadK,
        Config::DoubleSmemBuffer,
        typename GroupedConvTraitsType::AsLayout{layout_suffix},
        typename GroupedConvTraitsType::BsLayout{layout_suffix},
        typename GroupedConvTraitsType::CLayout{layout_suffix},
        GroupedConvTraitsType::FixedGemmParams::TransposeC,
        GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
        GroupedConvTraitsType::FixedGemmParams::Persistent,
        Config::NumWaveGroups>;

    // Pipeline problem - data types change per variant
    using GemmPipelineProblem = GemmPipelineProblem<
        {a_dtype}, {b_dtype}, AccDataType, GemmShape,
        typename GroupedConvTraitsType::template {gemm_traits}<Config::NumWaveGroups>,
        {a_dtype}, {b_dtype},
        element_wise::PassThrough, element_wise::PassThrough,
        GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
        GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>;

    // Base pipeline for tail handling
    using BaseGemmPipeline = {self._get_base_pipeline(tr.pipeline)}<GemmPipelineProblem>;

    static float launch(const {host_args_type}& args, const stream_config& s) {{
        const index_t gemm_k = {gemm_k_calc}, 1, std::multiplies<index_t>());

        const index_t k_grain = args.k_batch * Config::K_Tile;
        const index_t K_split = (gemm_k + k_grain - 1) / k_grain * Config::K_Tile;
        const index_t num_loop = TilePartitioner::GetLoopNum(K_split);
        const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
        {tail_num_decl}

        float ave_time{{0}};

        constexpr auto scheduler = Config::Scheduler;

        using UniversalGemmProblem = UniversalGemmPipelineProblem<
            {a_dtype}, {b_dtype}, AccDataType, GemmShape, GemmUniversalTraits,
            scheduler,
            element_wise::PassThrough, element_wise::PassThrough,
            {a_dtype}, {b_dtype},
            GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
            GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>;

        using GemmPipeline = {self._get_pipeline_template_args(tr.pipeline, "UniversalGemmProblem")};

        using ConvEpilogue = CShuffleEpilogue<CShuffleEpilogueProblem<
            {a_dtype}, {b_dtype}, tuple<>, AccDataType, {c_dtype},
            typename GroupedConvTraitsType::ImplicitGemmDsLayout,
            typename GroupedConvTraitsType::FixedGemmParams::ELayout,
            element_wise::PassThrough,
            TilePartitioner::MPerBlock, TilePartitioner::NPerBlock,
            Config::M_Warp, Config::N_Warp, Config::M_Warp_Tile,
            Config::N_Warp_Tile, Config::K_Warp_Tile,
            GroupedConvTraitsType::FixedGemmParams::TransposeC,
            Config::NumWaveGroups,
            GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
            Config::VectorSizeC, 1, Config::DoubleSmemBuffer>>;

        using Kernel = {kernel_type}<
            GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>;

        const auto Run = {run_lambda_signature} {{
            auto kargs = Kernel::MakeKernelArgs(args);

            if (!Kernel::IsSupportedArgument(kargs)) {{
                throw std::runtime_error("Arguments not supported for grouped conv kernel");
            }}

            const dim3 grids = Kernel::GridSize(kargs);
            const dim3 blocks = Kernel::BlockSize();

            {self._get_launch_code()}

            return ave_time;
        }};

        {tail_handler_call}
        return ave_time;
    }}

    {self._get_is_supported_code(config, host_args_type, kernel_type, a_dtype, b_dtype, c_dtype)}

    {self._get_instance_string_code(config, kernel_type, a_dtype, b_dtype, c_dtype)}
}};

// Launcher alias for tile_engine compatibility
using {launcher_alias} = {kernel_name}_Launcher;

}} // namespace {ns_name}

// Export specific launcher to global namespace
using {kernel_name}_Launcher = {ns_name}::{kernel_name}_Launcher;

// When used with -include compiler flag, export aliases to global namespace
#ifdef CK_TILE_SINGLE_KERNEL_INCLUDE
using {launcher_alias} = {ns_name}::{launcher_alias};
constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = {ns_name}::CONV_{direction_prefix}_KERNEL_NAME;
#endif
"""
|
|
|
|
# Pipelines that accept GroupedConvUniversalPipelineAgBgCrPolicy
|
|
# as a second template parameter for conv-specific LDS layout.
|
|
# (from conv_configs.hpp PipelineTypeTraits -- basic_v1/mem/compv3)
|
|
# CompV4/V5/V6/comp_async/basic_async_v1 use their own default policies.
|
|
_CONV_POLICY_PIPELINES = {"basic_v1", "basic_v2", "compv1", "compv2", "mem", "compv3"}
|
|
|
|
_SPECIALIZATION_TO_CK = {
|
|
"default": "ConvolutionSpecialization::Default",
|
|
"filter1x1_pad0": "ConvolutionSpecialization::Filter1x1Pad0",
|
|
"filter1x1_stride1_pad0": "ConvolutionSpecialization::Filter1x1Stride1Pad0",
|
|
"filter3x3": "ConvolutionSpecialization::Filter3x3",
|
|
}
|
|
|
|
def _get_conv_specialization(self, trait) -> str:
|
|
"""Get C++ ConvolutionSpecialization enum from trait."""
|
|
spec = getattr(trait, "specialization", "default")
|
|
return self._SPECIALIZATION_TO_CK.get(spec, "ConvolutionSpecialization::Default")
|
|
|
|
def _get_pipeline(self, pipeline: str) -> str:
|
|
"""Get pipeline class name."""
|
|
pipelines = {
|
|
"basic_v1": "GemmPipelineAGmemBGmemCRegV1",
|
|
"basic_v2": "GemmPipelineAGmemBGmemCRegV2",
|
|
"compv1": "GemmPipelineAGmemBGmemCRegV1", # alias
|
|
"compv2": "GemmPipelineAGmemBGmemCRegV2", # alias
|
|
"mem": "GemmPipelineAgBgCrMem",
|
|
"compv3": "GemmPipelineAgBgCrCompV3",
|
|
"compv4": "GemmPipelineAgBgCrCompV4",
|
|
"compv5": "GemmPipelineAgBgCrCompV5",
|
|
"compv6": "GemmPipelineAgBgCrCompV6",
|
|
"comp_async": "GemmPipelineAgBgCrCompAsync",
|
|
"basic_async_v1": "GemmPipelineAGmemBGmemCRegAsyncV1",
|
|
}
|
|
return pipelines.get(pipeline, "GemmPipelineAgBgCrCompV3")
|
|
|
|
def _get_pipeline_template_args(self, pipeline: str, problem_type: str) -> str:
|
|
"""Get full template argument list for pipeline instantiation.
|
|
|
|
For basic_v1/mem/compv3, passes GroupedConvUniversalPipelineAgBgCrPolicy
|
|
as a second template argument for conv-specific LDS banking.
|
|
"""
|
|
base = self._get_pipeline(pipeline)
|
|
if pipeline in self._CONV_POLICY_PIPELINES:
|
|
return f"{base}<{problem_type}, GroupedConvUniversalPipelineAgBgCrPolicy>"
|
|
return f"{base}<{problem_type}>"
|
|
|
|
def _get_base_pipeline(self, pipeline: str) -> str:
|
|
"""Get base pipeline class name (used for tail handling only).
|
|
|
|
Note: basic_async_v1 inherits from BaseGemmPipelineAGmemBGmemCRegV1
|
|
(there is no separate BaseGemmPipelineAGmemBGmemCRegAsyncV1).
|
|
"""
|
|
pipelines = {
|
|
"basic_v1": "BaseGemmPipelineAGmemBGmemCRegV1",
|
|
"basic_v2": "BaseGemmPipelineAGmemBGmemCRegV2",
|
|
"compv1": "BaseGemmPipelineAGmemBGmemCRegV1", # alias
|
|
"compv2": "BaseGemmPipelineAGmemBGmemCRegV2", # alias
|
|
"mem": "BaseGemmPipelineAgBgCrMem",
|
|
"compv3": "BaseGemmPipelineAgBgCrCompV3",
|
|
"compv4": "BaseGemmPipelineAgBgCrCompV4",
|
|
"compv5": "BaseGemmPipelineAgBgCrCompV5",
|
|
"compv6": "BaseGemmPipelineAgBgCrCompV6",
|
|
"comp_async": "BaseGemmPipelineAgBgCrCompAsync",
|
|
"basic_async_v1": "BaseGemmPipelineAGmemBGmemCRegV1",
|
|
}
|
|
return pipelines.get(pipeline, "BaseGemmPipelineAgBgCrCompV3")
|
|
|
|
def _get_is_supported_code(self, config, host_args_type, kernel_type, a_dtype, b_dtype, c_dtype) -> str:
    """Generate the is_supported() static method for the launcher.

    Constructs the same Kernel type as launch() and calls
    MakeKernelArgs + IsSupportedArgument without actually launching.
    The host arguments are built from the ConvParam and k_batch with all
    data pointers passed as nullptr.

    Args:
        config: Kernel configuration; only ``config.trait.pipeline`` is read.
        host_args_type: C++ host-argument type, constructed as
            ``(conv_param, nullptr, nullptr, {}, nullptr, k_batch)``.
        kernel_type: C++ kernel class template, instantiated as
            ``<GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>``.
        a_dtype, b_dtype: C++ data types of the GEMM A/B operands.
        c_dtype: C++ output (E) data type written by the epilogue.

    Returns:
        C++ source for a ``static bool is_supported(const ck_tile::conv::ConvParam&, int)``
        launcher member.
    """
    tr = config.trait
    # The type declarations emitted below must mirror those emitted inside
    # launch(); otherwise is_supported() would validate a different kernel
    # type than the one that is actually launched.
    pipeline_template = self._get_pipeline_template_args(tr.pipeline, "UniversalGemmProblem")

    return f"""static bool is_supported(const ck_tile::conv::ConvParam& conv_param, int k_batch) {{
{host_args_type} args(conv_param,
nullptr, nullptr, {{}}, nullptr, k_batch);

constexpr auto scheduler = Config::Scheduler;

using UniversalGemmProblem = UniversalGemmPipelineProblem<
{a_dtype}, {b_dtype}, AccDataType, GemmShape, GemmUniversalTraits,
scheduler,
element_wise::PassThrough, element_wise::PassThrough,
{a_dtype}, {b_dtype},
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>;

using GemmPipeline = {pipeline_template};

using ConvEpilogue = CShuffleEpilogue<CShuffleEpilogueProblem<
{a_dtype}, {b_dtype}, tuple<>, AccDataType, {c_dtype},
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
element_wise::PassThrough,
TilePartitioner::MPerBlock, TilePartitioner::NPerBlock,
Config::M_Warp, Config::N_Warp, Config::M_Warp_Tile,
Config::N_Warp_Tile, Config::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
Config::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
Config::VectorSizeC, 1, Config::DoubleSmemBuffer>>;

using Kernel = {kernel_type}<
GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>;

auto kargs = Kernel::MakeKernelArgs(args);
return Kernel::IsSupportedArgument(kargs);
}}"""
|
|
|
|
def _get_instance_string_code(self, config, kernel_type, a_dtype, b_dtype, c_dtype) -> str:
    """Generate the get_instance_string() static method for the launcher.

    Constructs the same Kernel type and calls Kernel{}.GetInstanceString()
    (available when CK_EXPERIMENTAL_BUILDER is defined). The whole emitted
    method is wrapped in ``#ifdef CK_EXPERIMENTAL_BUILDER`` / ``#endif``.

    Args:
        config: Kernel configuration; only ``config.trait.pipeline`` is read.
        kernel_type: C++ kernel class template, instantiated as
            ``<GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>``.
        a_dtype, b_dtype: C++ data types of the GEMM A/B operands.
        c_dtype: C++ output (E) data type written by the epilogue.

    Returns:
        C++ source for a ``static std::string get_instance_string()``
        launcher member, guarded by the CK_EXPERIMENTAL_BUILDER macro.
    """
    tr = config.trait
    pipeline_template = self._get_pipeline_template_args(tr.pipeline, "UniversalGemmProblem")

    # This helper emits the non-two-stage variant only (epilogue output type is
    # c_dtype). The two-stage generator emits its own get_instance_string()
    # whose epilogue writes to an fp32 workspace, so VectorSizeC and the E data
    # type differ there.
    return f"""#ifdef CK_EXPERIMENTAL_BUILDER
static std::string get_instance_string() {{
constexpr auto scheduler = Config::Scheduler;

using UniversalGemmProblem = UniversalGemmPipelineProblem<
{a_dtype}, {b_dtype}, AccDataType, GemmShape, GemmUniversalTraits,
scheduler,
element_wise::PassThrough, element_wise::PassThrough,
{a_dtype}, {b_dtype},
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>;

using GemmPipeline = {pipeline_template};

using ConvEpilogue = CShuffleEpilogue<CShuffleEpilogueProblem<
{a_dtype}, {b_dtype}, tuple<>, AccDataType, {c_dtype},
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
element_wise::PassThrough,
TilePartitioner::MPerBlock, TilePartitioner::NPerBlock,
Config::M_Warp, Config::N_Warp, Config::M_Warp_Tile,
Config::N_Warp_Tile, Config::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
Config::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
Config::VectorSizeC, 1, Config::DoubleSmemBuffer>>;

using Kernel = {kernel_type}<
GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>;

return Kernel{{}}.GetInstanceString();
}}
#endif"""
|
|
|
|
def _get_launch_code(self) -> str:
    """Generate the kernel launch code for the non-two-stage launcher.

    Returns C++ statements that assign ``ave_time``. They are spliced into
    the generated ``Run`` lambda and reference ``s``, ``kargs``, ``grids``,
    ``blocks`` and ``Kernel`` from that scope, plus the ``WeiDataType`` /
    ``InDataType`` names.

    * bwd_weight: before each timed launch the weight buffer
      (``kargs.wei_ptr``) is zeroed, but only when ``kargs.k_batch > 1``,
      because split-K uses atomic accumulation into it.
    * bwd_data: before each timed launch the dX buffer (``kargs.in_ptr``) is
      zeroed unconditionally, regardless of ``k_batch``.
      NOTE(review): older docs described this as split-K only; presumably the
      kernel does not write every dX element -- confirm.
    * forward: plain ``launch_kernel``, no zeroing.
    """
    # Plain (non-f) string: the ``Kernel{}`` braces are literal and are kept
    # verbatim when this value is substituted into the f-strings below.
    kernel_launch = (
        "make_kernel<Config::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs)"
    )
    if self.variant == GroupedConvVariant.BACKWARD_WEIGHT:
        return f"""// Compute zeroing size for split-K atomic accumulation
const std::size_t zeroing_size = std::accumulate(
std::begin(kargs.wei_g_k_c_xs_lengths.data),
std::end(kargs.wei_g_k_c_xs_lengths.data),
std::size_t{{1}}, std::multiplies<std::size_t>());
auto preprocess = [&]() {{
if(kargs.k_batch > 1) {{
hip_check_error(hipMemsetAsync(
kargs.wei_ptr, 0,
zeroing_size * sizeof(WeiDataType),
s.stream_id_));
}}
}};
ave_time = launch_kernel_time_mask(s, preprocess, {kernel_launch});"""
    elif self.variant == GroupedConvVariant.BACKWARD_DATA:
        return f"""// Compute zeroing size for split-K atomic accumulation
const std::size_t zeroing_size = std::accumulate(
std::begin(kargs.in_g_n_c_wis_lengths.data),
std::end(kargs.in_g_n_c_wis_lengths.data),
std::size_t{{1}}, std::multiplies<std::size_t>());
auto preprocess = [&]() {{
hip_check_error(hipMemsetAsync(
kargs.in_ptr, 0,
zeroing_size * sizeof(InDataType),
s.stream_id_));
}};
ave_time = launch_kernel_time_mask(s, preprocess, {kernel_launch});"""
    else:
        return f"ave_time = launch_kernel(s, {kernel_launch});"
|
|
|
|
|
|
def _kernel_instance_two_stage(
    self, config: GroupedConvKernelConfig, kernel_name: str
) -> str:
    """Generate two-stage bwd_weight kernel: GEMM into fp32 workspace + ElementWise convert.

    Mirrors grouped_convolution_backward_weight_two_stage_invoker.hpp from
    example/ck_tile/20_grouped_convolution/.

    The emitted ``launch()`` does the following:

    1. Allocates a float workspace of
       ``G * K * C * prod(filter_spatial_lengths_)`` elements.
    2. Redirects the backward-weight GEMM into that workspace, zeroing it
       before each timed run when ``k_batch > 1``.
    3. Runs an ElementWise ``UnaryConvert`` kernel that writes
       ``WeiDataType`` into the caller's original ``wei_ptr``.

    Both kernels are timed together by ``launch_kernel_time_mask``.
    NOTE(review): the workspace is allocated on every ``launch()`` call.

    ``is_supported()`` and ``get_instance_string()`` rebuild the same GEMM
    kernel type (fp32-workspace epilogue) without launching anything.

    Args:
        config: Kernel configuration; ``config.trait`` supplies the pipeline
            and the convolution specialization.
        kernel_name: Base identifier for the ``<kernel_name>_Config`` /
            ``<kernel_name>_Launcher`` C++ names and the namespace.

    Returns:
        C++ source for the namespace, launcher struct and exported aliases.
    """
    tr = config.trait
    ns_name = "ns_" + kernel_name.replace("-", "_")
    direction_prefix = "BWD_WEIGHT"
    launcher_alias = "SelectedConvBwdWeightLauncher"

    # "BWD_WEIGHT".title() == "Bwd_Weight", so the alias emitted below is
    # SelectedConvBwd_WeightKernel.
    # NOTE(review): _kernel_instance_streamk emits SelectedConvBwd_Weight_Kernel
    # (extra underscore) for the same role -- confirm which spelling is intended.
    return f"""
namespace {ns_name} {{

using Config = {kernel_name}_Config;
constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = "{kernel_name}";
using SelectedConv{direction_prefix.title()}Kernel = Config;

struct {kernel_name}_Launcher {{
using KernelConfig = Config;
using InDataType = typename Config::InDataType;
using WeiDataType = typename Config::WeiDataType;
using OutDataType = typename Config::OutDataType;
using AccDataType = typename Config::AccDataType;
using InLayout = typename Config::InLayout;
using WeiLayout = typename Config::WeiLayout;
using OutLayout = typename Config::OutLayout;
using WorkspaceDataType = float;

static constexpr index_t NDimSpatial = Config::NDimSpatial;

using GemmShape = TileGemmShape<
sequence<Config::M_Tile, Config::N_Tile, Config::K_Tile>,
sequence<Config::M_Warp, Config::N_Warp, Config::K_Warp>,
sequence<Config::M_Warp_Tile, Config::N_Warp_Tile, Config::K_Warp_Tile>>;

static constexpr auto ConvSpec = {self._get_conv_specialization(config.trait)};
using GroupedConvTraitsType = GroupedConvTraits<
NDimSpatial, ConvSpec, InLayout, WeiLayout, tuple<>, OutLayout,
Config::VectorSizeA, Config::VectorSizeB, Config::VectorSizeC,
Config::NumGroupsToMerge, Config::EnableSplitImage, Config::ExplicitGemm>;

using TilePartitioner = GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;

using GemmUniversalTraits = TileGemmUniversalTraits<
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
Config::DoubleSmemBuffer,
typename GroupedConvTraitsType::AsLayoutBwdWeight,
typename GroupedConvTraitsType::BsLayoutBwdWeight,
typename GroupedConvTraitsType::CLayoutBwdWeight,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
Config::NumWaveGroups>;

using GemmPipelineProblem = GemmPipelineProblem<
OutDataType, InDataType, AccDataType, GemmShape,
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight<Config::NumWaveGroups>,
OutDataType, InDataType,
element_wise::PassThrough, element_wise::PassThrough,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>;

using BaseGemmPipeline = {self._get_base_pipeline(tr.pipeline)}<GemmPipelineProblem>;

static float launch(const GroupedConvBwdWeightHostArgs& args, const stream_config& s) {{
float ave_time{{0}};

constexpr auto scheduler = Config::Scheduler;

using UniversalGemmProblem = UniversalGemmPipelineProblem<
OutDataType, InDataType, AccDataType, GemmShape, GemmUniversalTraits,
scheduler,
element_wise::PassThrough, element_wise::PassThrough,
OutDataType, InDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>;

using GemmPipeline = {self._get_pipeline_template_args(tr.pipeline, "UniversalGemmProblem")};

// Epilogue writes to fp32 workspace (not fp16 output)
using ConvEpilogue = CShuffleEpilogue<CShuffleEpilogueProblem<
OutDataType, InDataType, tuple<>, AccDataType, WorkspaceDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
element_wise::PassThrough,
TilePartitioner::MPerBlock, TilePartitioner::NPerBlock,
Config::M_Warp, Config::N_Warp, Config::M_Warp_Tile,
Config::N_Warp_Tile, Config::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
Config::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;

using Kernel = GroupedConvolutionBackwardWeightKernel<
GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>;

// ElementWise kernel: fp32 workspace -> fp16/bf16 output
using XElementwiseOp = element_wise::UnaryConvert;
using EwBlockTile = sequence<2048>;
using EwBlockWarps = sequence<8>;
using EwWarpTile = sequence<64>;
using EwShape = ElementWiseShape<EwBlockWarps, EwBlockTile, EwWarpTile, WorkspaceDataType>;
using EwProblem = ElementWisePipelineProblem<
WorkspaceDataType, WorkspaceDataType, WeiDataType, EwShape, XElementwiseOp>;
using EwKernel = ElementWiseKernel<EwProblem, ElementWiseDefaultPolicy>;

// Workspace: G * K * C * product(filter_spatial) elements in fp32
const index_t spatial_accum = std::accumulate(
args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end(),
1, std::multiplies<index_t>());
DeviceMem ws_buf(args.G_ * args.K_ * args.C_ * spatial_accum * sizeof(WorkspaceDataType));

GroupedConvBwdWeightHostArgs ws_args(args);
auto* c_ptr = ws_args.wei_ptr;
ws_args.wei_ptr = ws_buf.GetDeviceBuffer();

auto kargs = Kernel::MakeKernelArgs(ws_args);

if(!Kernel::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Arguments not supported for two-stage bwd_weight kernel");
}}

const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();

// ElementWise kernel setup
const index_t ew_block_size = EwKernel::BlockSize();
const index_t total_elems = args.G_ * args.K_ * args.C_ * spatial_accum;
constexpr index_t elems_per_block = EwBlockTile::at(number<0>{{}});
const index_t ew_grid_size = (total_elems + elems_per_block - 1) / elems_per_block;

auto ew_shape = make_tuple(args.G_ * args.K_,
args.C_ * spatial_accum);
auto ew_inputs = make_tuple(static_cast<WorkspaceDataType*>(ws_args.wei_ptr));

if(!EwKernel::IsSupportedArgument(ew_shape)) {{
throw std::runtime_error("ElementWise arguments not supported for two-stage convert");
}}

auto preprocess = [&]() {{
if(kargs.k_batch > 1)
hip_check_error(hipMemsetAsync(
ws_args.wei_ptr, 0,
total_elems * sizeof(WorkspaceDataType),
s.stream_id_));
}};

ave_time = launch_kernel_time_mask(
s, preprocess,
make_kernel<Config::kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs),
make_kernel<Config::kBlockPerCu>(
EwKernel{{}}, ew_grid_size, ew_block_size, 0,
ew_shape,
make_tuple(args.C_ * spatial_accum, 1),
make_tuple(args.C_ * spatial_accum, 1),
ew_inputs,
static_cast<WeiDataType*>(c_ptr)));

return ave_time;
}}

static bool is_supported(const ck_tile::conv::ConvParam& conv_param, int k_batch) {{
GroupedConvBwdWeightHostArgs args(conv_param,
nullptr, nullptr, {{}}, nullptr, k_batch);

constexpr auto scheduler = Config::Scheduler;

using UniversalGemmProblem = UniversalGemmPipelineProblem<
OutDataType, InDataType, AccDataType, GemmShape, GemmUniversalTraits,
scheduler,
element_wise::PassThrough, element_wise::PassThrough,
OutDataType, InDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>;

using GemmPipeline = {self._get_pipeline_template_args(tr.pipeline, "UniversalGemmProblem")};

// Epilogue writes to fp32 workspace (not fp16 output)
using ConvEpilogue = CShuffleEpilogue<CShuffleEpilogueProblem<
OutDataType, InDataType, tuple<>, AccDataType, WorkspaceDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
element_wise::PassThrough,
TilePartitioner::MPerBlock, TilePartitioner::NPerBlock,
Config::M_Warp, Config::N_Warp, Config::M_Warp_Tile,
Config::N_Warp_Tile, Config::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
Config::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;

using Kernel = GroupedConvolutionBackwardWeightKernel<
GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>;

auto kargs = Kernel::MakeKernelArgs(args);
return Kernel::IsSupportedArgument(kargs);
}}

#ifdef CK_EXPERIMENTAL_BUILDER
static std::string get_instance_string() {{
constexpr auto scheduler = Config::Scheduler;

using UniversalGemmProblem = UniversalGemmPipelineProblem<
OutDataType, InDataType, AccDataType, GemmShape, GemmUniversalTraits,
scheduler,
element_wise::PassThrough, element_wise::PassThrough,
OutDataType, InDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>;

using GemmPipeline = {self._get_pipeline_template_args(tr.pipeline, "UniversalGemmProblem")};

// Two-stage: epilogue writes to fp32 workspace
using ConvEpilogue = CShuffleEpilogue<CShuffleEpilogueProblem<
OutDataType, InDataType, tuple<>, AccDataType, WorkspaceDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
element_wise::PassThrough,
TilePartitioner::MPerBlock, TilePartitioner::NPerBlock,
Config::M_Warp, Config::N_Warp, Config::M_Warp_Tile,
Config::N_Warp_Tile, Config::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
Config::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;

using Kernel = GroupedConvolutionBackwardWeightKernel<
GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>;

return Kernel{{}}.GetInstanceString();
}}
#endif
}};

using {launcher_alias} = {kernel_name}_Launcher;

}} // namespace {ns_name}

using {kernel_name}_Launcher = {ns_name}::{kernel_name}_Launcher;

#ifdef CK_TILE_SINGLE_KERNEL_INCLUDE
using {launcher_alias} = {ns_name}::{launcher_alias};
constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = {ns_name}::CONV_{direction_prefix}_KERNEL_NAME;
#endif
"""
|
|
|
|
|
|
def _kernel_instance_streamk(
    self, config: GroupedConvKernelConfig, kernel_name: str
) -> str:
    """Generate stream-K bwd_weight kernel: StreamKTilePartitioner, workspace-based reduction.

    Unlike the two-stage path, the epilogue writes ``WeiDataType`` directly.
    The pipeline, epilogue and ``Kernel`` types are declared once at struct
    scope and shared by ``launch()``, ``is_supported()`` and
    ``get_instance_string()``. ``TdmEpilogue`` replaces ``CShuffleEpilogue``
    when ``Config::Pipeline`` is COMPUTE_TDM_V1 or COMPUTE_TDM_V2.

    The emitted ``launch()`` queries ``Kernel::GetWorkSpaceSize``, allocates
    that workspace, and zeroes it before each timed launch when it is
    non-empty.

    Args:
        config: Kernel configuration; ``config.trait`` supplies the pipeline,
            the specialization and ``streamk_config`` (``strategy`` and
            ``streamk_persistent``).
        kernel_name: Base identifier for the ``<kernel_name>_Config`` /
            ``<kernel_name>_Launcher`` C++ names and the namespace.

    Returns:
        C++ source for the namespace, launcher struct and exported aliases.
    """
    tr = config.trait
    sk = tr.streamk_config
    ns_name = "ns_" + kernel_name.replace("-", "_")
    direction_prefix = "BWD_WEIGHT"
    launcher_alias = "SelectedConvBwdWeightLauncher"
    # str.capitalize() upper-cases the first character and lower-cases the rest.
    # NOTE(review): this assumes each strategy value is a single word whose
    # capitalised form is the C++ StreamKReductionStrategy enumerator; a
    # multi-word snake_case value would not become PascalCase -- confirm.
    strategy_cpp = f"StreamKReductionStrategy::{sk.strategy.value.capitalize()}"
    persistent_cpp = "true" if sk.streamk_persistent else "false"

    # NOTE(review): this emits SelectedConvBwd_Weight_Kernel, whereas
    # _kernel_instance_two_stage emits SelectedConvBwd_WeightKernel for the
    # same role -- confirm which spelling is intended.
    return f"""
namespace {ns_name} {{

using Config = {kernel_name}_Config;
constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = "{kernel_name}";
using SelectedConv{direction_prefix.title()}_Kernel = Config;

struct {kernel_name}_Launcher {{
using KernelConfig = Config;
using InDataType = typename Config::InDataType;
using WeiDataType = typename Config::WeiDataType;
using OutDataType = typename Config::OutDataType;
using AccDataType = typename Config::AccDataType;
using InLayout = typename Config::InLayout;
using WeiLayout = typename Config::WeiLayout;
using OutLayout = typename Config::OutLayout;

static constexpr index_t NDimSpatial = Config::NDimSpatial;

using GemmShape = TileGemmShape<
sequence<Config::M_Tile, Config::N_Tile, Config::K_Tile>,
sequence<Config::M_Warp, Config::N_Warp, Config::K_Warp>,
sequence<Config::M_Warp_Tile, Config::N_Warp_Tile, Config::K_Warp_Tile>>;

static constexpr auto ConvSpec = {self._get_conv_specialization(config.trait)};
using GroupedConvTraitsType = GroupedConvTraits<
NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
tuple<>,
OutLayout,
Config::VectorSizeA,
Config::VectorSizeB,
Config::VectorSizeC,
Config::NumGroupsToMerge,
Config::EnableSplitImage,
Config::ExplicitGemm>;

using TilePartitioner = StreamKTilePartitioner<GemmShape, {strategy_cpp}, {persistent_cpp}>;

using GemmUniversalTraits = TileGemmUniversalTraits<
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
Config::DoubleSmemBuffer,
typename GroupedConvTraitsType::AsLayoutBwdWeight,
typename GroupedConvTraitsType::BsLayoutBwdWeight,
typename GroupedConvTraitsType::CLayoutBwdWeight,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
Config::NumWaveGroups>;

using UniversalGemmProblem = UniversalGemmPipelineProblem<
OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
Config::Scheduler,
element_wise::PassThrough,
element_wise::PassThrough,
OutDataType,
InDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;

using EpilogueProblem = CShuffleEpilogueProblem<
OutDataType,
InDataType,
tuple<>,
AccDataType,
WeiDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Config::M_Warp,
Config::N_Warp,
Config::M_Warp_Tile,
Config::N_Warp_Tile,
Config::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
Config::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
Config::VectorSizeC>;

using ConvEpilogue =
std::conditional_t<Config::Pipeline == ck_tile::GemmPipeline::COMPUTE_TDM_V1 ||
Config::Pipeline == ck_tile::GemmPipeline::COMPUTE_TDM_V2,
ck_tile::TdmEpilogue<EpilogueProblem>,
ck_tile::CShuffleEpilogue<EpilogueProblem>>;

using GemmPipeline = {self._get_pipeline_template_args(tr.pipeline, "UniversalGemmProblem")};

using Kernel = GroupedConvolutionBackwardWeightKernel<
GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>;

static float launch(const GroupedConvBwdWeightHostArgs& args, const stream_config& s) {{
float ave_time{{0}};

auto kargs = Kernel::MakeKernelArgs(args);

if (!Kernel::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Arguments not supported for stream-K bwd_weight kernel");
}}

// Stream-K workspace allocation
auto ws_size = Kernel::GetWorkSpaceSize(kargs);
DeviceMem workspace_dev(ws_size);
Kernel::SetWorkSpacePointer(kargs, workspace_dev.GetDeviceBuffer());

const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();

auto preprocess = [&]() {{
// Stream-K: zero workspace flags before each kernel launch
if(ws_size > 0) {{
hip_check_error(
hipMemsetAsync(workspace_dev.GetDeviceBuffer(), 0, ws_size, s.stream_id_));
}}
}};

ave_time = launch_kernel_time_mask(
s, preprocess,
make_kernel<Config::kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));

return ave_time;
}}

static bool is_supported(const ck_tile::conv::ConvParam& conv_param, int k_batch) {{
GroupedConvBwdWeightHostArgs args(conv_param, nullptr, nullptr, {{}}, nullptr, k_batch);

auto kargs = Kernel::MakeKernelArgs(args);
return Kernel::IsSupportedArgument(kargs);
}}

#ifdef CK_EXPERIMENTAL_BUILDER
static std::string get_instance_string() {{
return Kernel{{}}.GetInstanceString();
}}
#endif
}};

using {launcher_alias} = {kernel_name}_Launcher;

}} // namespace {ns_name}

using {kernel_name}_Launcher = {ns_name}::{kernel_name}_Launcher;

#ifdef CK_TILE_SINGLE_KERNEL_INCLUDE
using {launcher_alias} = {ns_name}::{launcher_alias};
constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = {ns_name}::CONV_{direction_prefix}_KERNEL_NAME;
#endif
"""
|
|
|
|
|
|
# ============================================================================
|
|
# CK Tile Depthwise Conv Kernel Generator
|
|
# ============================================================================
|
|
|
|
|
|
class CKTileDepthwiseConvKernelGenerator:
    """Generates CK Tile depthwise convolution kernel instance code.

    ``generate()`` concatenates three C++ fragments into one header:

    1. the include preamble;
    2. the ``ns_<kernel_name>`` namespace with the data types, traits,
       pipeline and ``Kernel`` alias (the namespace is opened but not closed);
    3. the launcher struct, which closes that namespace, followed by the
       global launcher aliases.
    """

    # Generator datatype key -> CK Tile C++ element type.
    DTYPE_TO_CK = {
        "fp16": "half_t",
        "bf16": "bf16_t",
        "fp32": "float",
    }

    def __init__(self, datatype: str):
        # Key into DTYPE_TO_CK; validated when code is generated.
        self.datatype = datatype

    def generate(self, config: "DepthwiseConvKernelConfig") -> str:
        """Generate complete depthwise convolution kernel header.

        Args:
            config: Depthwise kernel configuration. ``config.name(datatype)``
                provides the kernel name; the remaining attributes fill the
                ``DepthwiseConvFwdTraits`` / ``GroupedConvTraits`` arguments.

        Returns:
            The full C++ header text.

        Raises:
            KeyError: If ``self.datatype`` is not a key of ``DTYPE_TO_CK``.
        """
        kernel_name = config.name(self.datatype)
        return f"""{self._header(kernel_name)}
{self._config_and_types(config, kernel_name)}
{self._launcher(config, kernel_name)}
"""

    def _header(self, kernel_name: str) -> str:
        """Return the license banner, ``#pragma once`` and include block.

        ``<stdexcept>`` and ``<string>`` are included explicitly because the
        generated launcher throws ``std::runtime_error`` and returns
        ``std::string``; relying on transitive includes is fragile.
        """
        return f"""// SPDX-License-Identifier: MIT
// Auto-generated CK Tile Depthwise Convolution kernel: {kernel_name}
// Variant: forward_depthwise
#pragma once

#include <cstdint>
#include <numeric>
#include <functional>
#include <stdexcept>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "ck_tile/ops/grouped_convolution/pipeline/grouped_convolution_forward_depthwise_pipeline.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"

using namespace ck_tile;
"""

    def _config_and_types(self, config: "DepthwiseConvKernelConfig", kernel_name: str) -> str:
        """Open ``namespace ns_<kernel_name>`` and emit the type definitions.

        Emits the data types, depthwise traits, pipeline, null epilogue and
        the ``Kernel`` alias. The namespace is intentionally left open;
        ``_launcher`` closes it.

        Raises:
            KeyError: If ``self.datatype`` is not a key of ``DTYPE_TO_CK``.
        """
        try:
            ck_dtype = self.DTYPE_TO_CK[self.datatype]
        except KeyError:
            supported = ", ".join(sorted(self.DTYPE_TO_CK))
            # Same exception type as before, but with an actionable message.
            raise KeyError(
                f"unsupported depthwise conv datatype {self.datatype!r} "
                f"(supported: {supported})"
            ) from None
        c = config
        return f"""
// Kernel configuration and type definitions
namespace ns_{kernel_name} {{

using InDataType = {ck_dtype};
using WeiDataType = {ck_dtype};
using AccDataType = float;
using OutDataType = {ck_dtype};

// Depthwise convolution traits
using DwTraits = DepthwiseConvFwdTraits<
InDataType, WeiDataType, AccDataType, OutDataType,
{c.block_size}, // BlockSize
{c.tile_h}, // TileH
{c.tile_w}, // TileW
{c.filt}, // FilterH
{c.filt}, // FilterW
{c.str_h}, // StrideH
{c.str_w}, // StrideW
{c.dil_h}, // DilationH
{c.dil_w}, // DilationW
{c.pad_h}, // PadH
{c.pad_w}, // PadW
{c.nbatch}, // NBatch
{c.sub_h}, // SubTileH
{c.sub_w}, // SubTileW
{c.in_vec}, // InVec
{c.out_vec}>; // OutVec

// Depthwise pipeline
using DwPipeline = DepthwiseConvFwdPipeline<DwTraits>;

// Grouped convolution traits (depthwise specialization)
using ConvTraitsType = GroupedConvTraits<
{c.ndim_spatial}, // NDimSpatial
ConvolutionSpecialization::Default, // ConvSpec
void, // InLayout (unused for depthwise)
void, // WeiLayout (unused for depthwise)
tuple<>, // DsLayout
void, // OutLayout (unused for depthwise)
{c.in_vec}, // VectorSizeA
{c.in_vec}, // VectorSizeB
{c.out_vec}, // VectorSizeC
1, // NumGroupsToMerge
false, // EnableSplitImage
false, // ExplicitGemm
DwTraits>; // DepthwiseTraits

// Null epilogue for depthwise (no shuffle needed)
struct DepthwiseNullEpilogue {{
using DsLayout = tuple<>;
using DsDataType = tuple<>;
using ODataType = OutDataType;
using AccDataType = float;
using CDElementwise = element_wise::PassThrough;
}};

// Complete kernel type
using Kernel = GroupedConvolutionForwardKernel<
ConvTraitsType, void, DwPipeline, DepthwiseNullEpilogue>;
"""

    def _launcher(self, config: "DepthwiseConvKernelConfig", kernel_name: str) -> str:
        """Emit the launcher struct and close the ``ns_<kernel_name>`` namespace.

        The launcher exposes ``launch``, ``is_supported`` and, under
        ``CK_EXPERIMENTAL_BUILDER``, ``get_instance_string``. After the
        namespace is closed, the launcher is re-exported to the global
        namespace.
        """
        # Must match the namespace name opened in _config_and_types.
        ns_name = f"ns_{kernel_name}"
        return f"""
constexpr const char* CONV_FWD_KERNEL_NAME = "{kernel_name}";

struct {kernel_name}_Launcher {{
using KernelConfig = DwTraits;
using InDataType = {ns_name}::InDataType;
using WeiDataType = {ns_name}::WeiDataType;
using OutDataType = {ns_name}::OutDataType;
using AccDataType = {ns_name}::AccDataType;

static constexpr index_t NDimSpatial = {config.ndim_spatial};

static float launch(const GroupedConvFwdHostArgs<>& args, const stream_config& s) {{
auto kargs = Kernel::MakeKernelArgs(args);

if (!Kernel::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Arguments not supported for depthwise conv kernel");
}}

const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();

float ave_time = launch_kernel(s, make_kernel(Kernel{{}}, grids, blocks, 0, kargs));
return ave_time;
}}

static bool is_supported(const ck_tile::conv::ConvParam& conv_param, int k_batch) {{
GroupedConvFwdHostArgs<> args(conv_param,
nullptr, nullptr, {{}}, nullptr, k_batch);

auto kargs = Kernel::MakeKernelArgs(args);
return Kernel::IsSupportedArgument(kargs);
}}

#ifdef CK_EXPERIMENTAL_BUILDER
static std::string get_instance_string() {{
return Kernel{{}}.GetInstanceString();
}}
#endif
}};

using SelectedConvKernelLauncher = {kernel_name}_Launcher;

}} // namespace {ns_name}

using {kernel_name}_Launcher = {ns_name}::{kernel_name}_Launcher;

#ifdef CK_TILE_SINGLE_KERNEL_INCLUDE
using SelectedConvKernelLauncher = {ns_name}::SelectedConvKernelLauncher;
constexpr const char* CONV_FWD_KERNEL_NAME = {ns_name}::CONV_FWD_KERNEL_NAME;
#endif
"""
|
|
|
|
|
|
# ============================================================================
|
|
# Dispatcher Wrapper Generator
|
|
# ============================================================================
|
|
|
|
|
|
class GroupedConvDispatcherWrapperGenerator:
    """Generates dispatcher integration wrapper following GEMM pattern.

    For one generated kernel header, ``generate`` returns the source of a
    companion header. That header declares
    ``ck_tile::dispatcher::generated::make_<kernel_name>()``, which fills a
    ``GroupedConvKernelKey`` describing the kernel and wraps the kernel's
    ``<kernel_name>_Launcher::launch`` in a ``GroupedConvKernelInstance`` for
    the registry.
    """

    # Static mappings for pipeline and scheduler enum names (matches kernel_key.hpp)
    PIPELINE_TO_DISPATCHER = {
        "mem": "Pipeline::Mem",
        "compv1": "Pipeline::CompV1",
        "compv2": "Pipeline::CompV2",
        "basic_v1": "Pipeline::CompV1",
        "basic_v2": "Pipeline::CompV2",
        "compv3": "Pipeline::CompV3",
        "compv4": "Pipeline::CompV4",
        "compv5": "Pipeline::CompV5",
        "compv6": "Pipeline::CompV6",
        "preshufflev1": "Pipeline::PreShuffleV1",
        "preshufflev2": "Pipeline::PreShuffleV2",
    }

    SCHEDULER_TO_DISPATCHER = {
        "default": "Scheduler::Default",
        "intrawave": "Scheduler::Intrawave",
        "interwave": "Scheduler::Interwave",
    }

    # Map datatype string to dispatcher DataType enum
    DTYPE_TO_DISPATCHER = {
        "fp16": "DataType::FP16",
        "bf16": "DataType::BF16",
        "fp32": "DataType::FP32",
    }

    def __init__(
        self,
        datatype: str,
        variant: GroupedConvVariant = GroupedConvVariant.FORWARD,
    ):
        """Store the datatype and convolution direction used for every wrapper.

        Args:
            datatype: Datatype key such as ``"fp16"``; it is used for the kernel
                name and the signature dtypes.
            variant: Convolution direction. It is ignored for depthwise configs,
                which are always emitted as forward kernels.
        """
        self.datatype = datatype
        self.variant = variant

    def _pipeline_to_dispatcher(self, pipeline: str) -> str:
        """Convert pipeline string to dispatcher enum value.

        Unknown names fall back to ``Pipeline::<Capitalized>``.
        """
        return self.PIPELINE_TO_DISPATCHER.get(
            pipeline.lower(), f"Pipeline::{pipeline.capitalize()}"
        )

    def _scheduler_to_dispatcher(self, scheduler: str) -> str:
        """Convert scheduler string to dispatcher enum value.

        Unknown names fall back to ``Scheduler::<Capitalized>``.
        """
        return self.SCHEDULER_TO_DISPATCHER.get(
            scheduler.lower(), f"Scheduler::{scheduler.capitalize()}"
        )

    def generate(
        self,
        config: Union[GroupedConvKernelConfig, DepthwiseConvKernelConfig],
        kernel_path: Path,
        output_dir: Path,
    ) -> str:
        """Generate dispatcher wrapper with factory function for registry.

        Args:
            config: Kernel configuration. ``DepthwiseConvKernelConfig`` objects
                are always described as forward kernels and get a zeroed GEMM
                tile description.
            kernel_path: Path of the generated kernel header.
            output_dir: Root output directory. The wrapper includes the kernel
                as ``"../<kernel_path relative to output_dir>"``, so the wrapper
                must live one directory below ``output_dir``.

        Returns:
            Source text of the wrapper header.

        Raises:
            ValueError: If ``kernel_path`` is not inside ``output_dir``
                (propagated from ``Path.relative_to``).
        """
        kernel_name = config.name(self.datatype)
        # The path ends up in a C++ #include directive, which needs forward
        # slashes even when the generator itself runs on Windows.
        rel_path = kernel_path.relative_to(output_dir).as_posix()
        is_depthwise = isinstance(config, DepthwiseConvKernelConfig)

        # NOTE(review): unknown datatypes silently map to FP16 here.
        dtype_enum = self.DTYPE_TO_DISPATCHER.get(self.datatype, "DataType::FP16")

        # Determine variant-specific fields
        if is_depthwise or self.variant == GroupedConvVariant.FORWARD:
            launcher_alias = "SelectedConvKernelLauncher"
            host_args_type = "GroupedConvFwdHostArgs<>"
            conv_type_str = "forward"
        elif self.variant == GroupedConvVariant.BACKWARD_DATA:
            launcher_alias = "SelectedConvBwdDataLauncher"
            host_args_type = "GroupedConvBwdDataHostArgs"
            conv_type_str = "bwd_data"
        else:  # BACKWARD_WEIGHT
            launcher_alias = "SelectedConvBwdWeightLauncher"
            host_args_type = "GroupedConvBwdWeightHostArgs"
            conv_type_str = "bwd_weight"

        # NOTE(review): implicit-GEMM kernels always report "nhwgc", whatever
        # their ndim_spatial / config layout is — confirm this is intended.
        layout = config.layout if is_depthwise else "nhwgc"

        # Algorithm key fields differ between implicit GEMM and depthwise algorithms
        if is_depthwise:
            algorithm_spec = """    // Depthwise kernels have no GEMM tile parameters
    key.algorithm.tile_shape = {0, 0, 0};
    key.algorithm.wave_shape = {0, 0, 0};
    key.algorithm.warp_tile_shape = {0, 0, 0};
    key.algorithm.epilogue = Epilogue::None;"""
        else:
            tile = config.tile
            # TileConfig.warp_m/n/k carry the wave counts per dimension.
            # warp_k was previously hard-coded to 1 even though JSON configs
            # supply it.
            algorithm_spec = f"""    key.algorithm.tile_shape = {{{tile.tile_m}, {tile.tile_n}, {tile.tile_k}}};
    key.algorithm.wave_shape = {{{tile.warp_m}, {tile.warp_n}, {tile.warp_k}}};
    key.algorithm.warp_tile_shape = {{{tile.warp_tile_m}, {tile.warp_tile_n}, {tile.warp_tile_k}}};
    key.algorithm.pipeline = {self._pipeline_to_dispatcher(config.trait.pipeline)};
    key.algorithm.scheduler = {self._scheduler_to_dispatcher(config.trait.scheduler)};
    key.algorithm.epilogue = Epilogue::CShuffle;"""

        # NOTE(review): the trailing global `using {launcher_alias} = ...` is
        # emitted in every wrapper. Including several wrappers of the same
        # direction in one TU (as the registration header does) redeclares the
        # alias with different types — confirm whether it should be guarded.
        return f"""// SPDX-License-Identifier: MIT
// Auto-generated dispatcher wrapper for: {kernel_name}
#pragma once

#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/grouped_conv_utils.hpp"
#include "../{rel_path}"

namespace ck_tile {{
namespace dispatcher {{
namespace generated {{

using ::ck_tile::dispatcher::GroupedConvKernelInstancePtr;
using ::ck_tile::dispatcher::GroupedConvKernelKey;
using ::ck_tile::dispatcher::DataType;
using ::ck_tile::dispatcher::LayoutTag;
using ::ck_tile::dispatcher::Pipeline;
using ::ck_tile::dispatcher::Scheduler;
using ::ck_tile::dispatcher::Epilogue;
using Priority = ::ck_tile::dispatcher::GroupedConvRegistry::Priority;

// Factory function to create kernel instance for registry
inline GroupedConvKernelInstancePtr make_{kernel_name}(const std::string& gfx_arch = "gfx942") {{
    GroupedConvKernelKey key;
    key.signature.dtype_in = {dtype_enum};
    key.signature.dtype_wei = {dtype_enum};
    key.signature.dtype_out = {dtype_enum};
    key.signature.dtype_acc = DataType::FP32;
    key.signature.layout = "{layout}";
    key.signature.conv_type = "{conv_type_str}";
    key.signature.num_dims = {config.ndim_spatial};
    key.signature.groups = 1;

{algorithm_spec}
    key.gfx_arch = gfx_arch;

    // Create kernel instance that wraps the launcher
    return std::make_shared<GroupedConvKernelInstance>(
        key,
        "{kernel_name}",
        []({host_args_type}& args, const stream_config& cfg) -> float {{
            return {kernel_name}_Launcher::launch(args, cfg);
        }}
    );
}}

}} // namespace generated
}} // namespace dispatcher
}} // namespace ck_tile

// Export launcher alias to global namespace for direct use
using {launcher_alias} = {kernel_name}_Launcher;
"""
|
|
|
|
|
|
# ============================================================================
|
|
# Configuration Parser
|
|
# ============================================================================
|
|
|
|
|
|
def load_depthwise_configs_from_json(
    data: dict,
    arch: str = "gfx942",
    instance_id: Optional[int] = None,
) -> List[DepthwiseConvKernelConfig]:
    """Build depthwise convolution configs from an already-parsed JSON document.

    File-level fields (``ndim_spatial``, ``layout``, ``datatype``) are shared
    by every instance. Per-instance tuning fields are copied verbatim from each
    entry of ``data["instances"]``.

    Args:
        data: Parsed JSON config data
        arch: Target GPU architecture
        instance_id: When given, keep only the instance whose ``id`` matches

    Returns:
        List of DepthwiseConvKernelConfig objects

    Raises:
        ValueError: If ``instance_id`` is given but no instance has that id.
        KeyError: If a required field is missing from the document.
    """
    ndim_spatial = data["ndim_spatial"]
    layout = data["layout"]
    datatype = data["datatype"]

    selected = data["instances"]
    if instance_id is not None:
        selected = [entry for entry in selected if entry["id"] == instance_id]
        if not selected:
            raise ValueError(f"Instance ID {instance_id} not found in depthwise config")

    # Per-instance keys forwarded 1:1 as DepthwiseConvKernelConfig kwargs.
    # NOTE(review): dil_h / dil_w are not read from JSON, so the config's
    # defaults are always used for dilation — confirm that is intended.
    per_instance_keys = (
        "tile_h",
        "tile_w",
        "filt",
        "str_h",
        "str_w",
        "pad_h",
        "pad_w",
        "nbatch",
        "sub_h",
        "sub_w",
        "in_vec",
        "out_vec",
    )
    configs = [
        DepthwiseConvKernelConfig(
            **{key: entry[key] for key in per_instance_keys},
            ndim_spatial=ndim_spatial,
            arch=arch,
            layout=layout,
            datatype=datatype,
        )
        for entry in selected
    ]

    log.info(
        f"Loaded {len(configs)} depthwise configs "
        f"(layout={layout}, dtype={datatype})"
    )
    return configs
|
|
|
|
|
|
def load_configs_from_json(
    config_path: Path,
    arch: str = "gfx942",
    instance_id: Optional[int] = None,
) -> List[Union[GroupedConvKernelConfig, DepthwiseConvKernelConfig]]:
    """Load kernel configurations from a JSON config file.

    The top-level ``variant`` field selects the config type:

    * ``"forward_depthwise"`` files are handed to
      ``load_depthwise_configs_from_json``.
    * Every other known variant yields ``GroupedConvKernelConfig`` objects,
      built from the per-instance tile, pipeline and vector-size fields.

    Args:
        config_path: Path to JSON config file (read as UTF-8)
        arch: Target GPU architecture
        instance_id: If specified, load only the instance with this ID

    Returns:
        List of GroupedConvKernelConfig objects, or DepthwiseConvKernelConfig
        objects for a ``forward_depthwise`` file. Backward-weight instances that
        request the compv2/basic_v2 pipeline are skipped and logged.

    Raises:
        ValueError: If ``variant`` is unknown or ``instance_id`` matches no
            instance.
    """
    # JSON text is UTF-8 by specification; don't depend on the locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    variant_map = {
        "forward": GroupedConvVariant.FORWARD,
        "fwd": GroupedConvVariant.FORWARD,
        "forward_depthwise": GroupedConvVariant.FORWARD_DEPTHWISE,
        "bwd_data": GroupedConvVariant.BACKWARD_DATA,
        "bwd_weight": GroupedConvVariant.BACKWARD_WEIGHT,
    }
    variant = variant_map.get(data["variant"])
    if variant is None:
        raise ValueError(f"Unknown variant: {data['variant']}")

    if variant == GroupedConvVariant.FORWARD_DEPTHWISE:
        return load_depthwise_configs_from_json(data, arch, instance_id)

    ndim_spatial = data["ndim_spatial"]
    layout = data["layout"]
    datatype = data["datatype"]

    instances = data["instances"]
    if instance_id is not None:
        instances = [inst for inst in instances if inst["id"] == instance_id]
        if not instances:
            raise ValueError(f"Instance ID {instance_id} not found in {config_path}")

    configs = []
    for inst in instances:
        # Stream-K settings are only read when enabled; otherwise use defaults.
        streamk_enabled = inst.get("streamk_enabled", False)
        streamk_config = (
            StreamKConfig(
                streamk_enabled=streamk_enabled,
                strategy=StreamKReductionStrategy(
                    inst.get("streamk_reduction_strategy", "TREE")
                ),
                streamk_persistent=inst.get("streamk_persistent", False),
            )
            if streamk_enabled
            else StreamKConfig()
        )

        # The specialization (e.g. filter1x1_stride1_pad0) does not alter the
        # pipeline parameters; it is only carried on the trait (presumably for
        # kernel naming and runtime applicability checks).
        trait = GroupedConvTraitConfig(
            pipeline=inst["pipeline"],
            scheduler=inst["scheduler"],
            epilogue=inst["epilogue"],
            pad_m=True,
            pad_n=True,
            pad_k=True,
            double_smem_buffer=inst.get("double_smem_buffer", False),
            num_groups_to_merge=inst.get("num_groups_to_merge", 1),
            split_image=inst.get("split_image", False),
            explicit_gemm=inst.get("explicit_gemm", False),
            two_stage=inst.get("two_stage", False),
            specialization=inst.get("specialization", "default"),
            streamk_config=streamk_config,
        )

        # compv2/basic_v2 (GemmPipelineAGmemBGmemCRegV2) is not compatible with
        # CK Tile's GroupedConvolutionBackwardWeightKernel. The builder maps
        # PipelineVersion::V2 to GemmPipelineAgBgCrMem (i.e. "mem"), not to
        # GemmPipelineAGmemBGmemCRegV2. Skip if any config somehow has compv2.
        if variant == GroupedConvVariant.BACKWARD_WEIGHT and trait.pipeline in ("compv2", "basic_v2"):
            log.info(f"Skipping instance {inst['id']}: compv2/basic_v2 pipeline not compatible with CK Tile bwd_weight")
            continue

        config = GroupedConvKernelConfig(
            tile=TileConfig(
                tile_m=inst["tile_m"],
                tile_n=inst["tile_n"],
                tile_k=inst["tile_k"],
                warp_m=inst["warp_m"],
                warp_n=inst["warp_n"],
                warp_k=inst["warp_k"],
                warp_tile_m=inst["warp_tile_m"],
                warp_tile_n=inst["warp_tile_n"],
                warp_tile_k=inst["warp_tile_k"],
            ),
            trait=trait,
            variant=variant,
            ndim_spatial=ndim_spatial,
            arch=arch,
            layout=layout,
            vector_size_a=inst["vector_size_a"],
            vector_size_b=inst["vector_size_b"],
            vector_size_c=inst["vector_size_c"],
            num_wave_groups=inst.get("num_wave_groups", 1),
        )
        configs.append(config)

    log.info(
        f"Loaded {len(configs)} configs from {config_path} "
        f"(variant={data['variant']}, layout={layout}, dtype={datatype})"
    )
    return configs
|
|
|
|
|
|
def _expand_tile_shapes(tiles) -> List[tuple]:
    """Attach wave / warp-tile shapes to block tiles for get_default_configs.

    Each ``(tile_m, tile_n, tile_k)`` triple becomes
    ``(tile_m, tile_n, tile_k, wave_m, wave_n, warp_tile_m, warp_tile_n, warp_tile_k)``,
    with the extra values looked up in ``TILE_TO_WAVE`` and ``TILE_TO_WARP``.
    Tiles missing from either table are dropped. The K wave count from
    ``TILE_TO_WAVE`` is discarded because the caller always uses ``warp_k=1``.
    """
    expanded = []
    for tile_m, tile_n, tile_k in tiles:
        tile_key = (tile_m, tile_n, tile_k)
        if tile_key in TILE_TO_WAVE and tile_key in TILE_TO_WARP:
            wave_m, wave_n, _wave_k = TILE_TO_WAVE[tile_key]
            warp_tile_m, warp_tile_n, warp_tile_k = TILE_TO_WARP[tile_key]
            expanded.append(
                (tile_m, tile_n, tile_k, wave_m, wave_n, warp_tile_m, warp_tile_n, warp_tile_k)
            )
    return expanded


def get_default_configs(
    arch: str = "gfx942",
    variants: Optional[List[GroupedConvVariant]] = None,
    ndims: Optional[List[int]] = None,
) -> List[GroupedConvKernelConfig]:
    """Get default grouped convolution configurations for target architecture.

    Tile shapes come from the ``grouped_config_rules`` tables (``COMMON_TILES``,
    ``BWD_WEIGHT_TILES``, ``TILE_TO_WAVE``, ``TILE_TO_WARP``,
    ``COMPV4_COMPATIBLE_TILES``). When those are unavailable, a small built-in
    fallback set is used instead.

    Args:
        arch: Target GPU architecture.
        variants: Convolution directions to generate; defaults to forward only.
        ndims: Spatial dimensionalities to generate; defaults to ``[2]``.

    Returns:
        Configs whose trait passes ``is_valid()`` and which pass
        ``is_valid_for_arch()``.
    """
    if variants is None:
        variants = [GroupedConvVariant.FORWARD]
    if ndims is None:
        ndims = [2]

    # Tile tables are shared with the rest of the codegen via grouped_config_rules.
    if not HAS_TILE_CONFIGS or not COMMON_TILES:
        log.warning("grouped_config_rules not available, using fallback tile configs")
        # Fallback to minimal set if grouped_config_rules unavailable
        fwd_bwd_data_tiles = [
            (128, 128, 32, 2, 2, 32, 32, 16),
            (64, 64, 32, 1, 4, 16, 16, 16),
            (16, 64, 64, 1, 4, 16, 16, 32),
        ]
        bwd_weight_tiles = [(16, 64, 64, 1, 4, 16, 16, 32)]
    else:
        fwd_bwd_data_tiles = _expand_tile_shapes(COMMON_TILES)
        # Backward weight: use BWD_WEIGHT_TILES from config rules
        bwd_weight_tiles = _expand_tile_shapes(BWD_WEIGHT_TILES)

    configs = []
    for variant in variants:
        # Select tile configs based on variant
        if variant == GroupedConvVariant.BACKWARD_WEIGHT:
            tile_configs = bwd_weight_tiles
            # Backward weight supports compv3 and mem pipelines
            # (compv4/compv5 have transpose_tile2d issues)
            pipelines = [("compv3", "cshuffle"), ("mem", "default")]
            # Also generate two-stage variants (fp32 workspace + elementwise convert)
            two_stage_flags = [False, True]
        elif variant == GroupedConvVariant.BACKWARD_DATA:
            tile_configs = fwd_bwd_data_tiles
            # Backward data supports compv3 and mem pipelines
            # (compv4/compv5 have get_length issues in bwd_data kernel)
            pipelines = [("compv3", "cshuffle"), ("mem", "default")]
            two_stage_flags = [False]
        else:
            tile_configs = fwd_bwd_data_tiles
            # Only forward grouped convolution supports both compv3 and compv4
            pipelines = [("compv3", "cshuffle"), ("compv4", "cshuffle")]
            two_stage_flags = [False]

        for ndim in ndims:
            for pipeline, epilogue in pipelines:
                is_compv4 = pipeline == "compv4"
                for (
                    tile_m,
                    tile_n,
                    tile_k,
                    warp_m,
                    warp_n,
                    warp_tile_m,
                    warp_tile_n,
                    warp_tile_k,
                ) in tile_configs:
                    # Skip tiles incompatible with compv4 (only known when the
                    # rule tables are available).
                    if (
                        is_compv4
                        and HAS_TILE_CONFIGS
                        and (tile_m, tile_n, tile_k) not in COMPV4_COMPATIBLE_TILES
                    ):
                        continue

                    # compv4 doubles the K tile and uses a double shared-memory
                    # buffer; both are independent of two_stage, so compute once.
                    adj_tile_k = tile_k * 2 if is_compv4 else tile_k

                    for two_stage in two_stage_flags:
                        trait = GroupedConvTraitConfig(
                            pipeline=pipeline,
                            scheduler="intrawave",
                            epilogue=epilogue,
                            double_smem_buffer=is_compv4,
                            pad_m=True,
                            pad_n=True,
                            pad_k=True,
                            two_stage=two_stage,
                        )
                        if not trait.is_valid():
                            continue

                        config = GroupedConvKernelConfig(
                            tile=TileConfig(
                                tile_m=tile_m,
                                tile_n=tile_n,
                                tile_k=adj_tile_k,
                                warp_m=warp_m,
                                warp_n=warp_n,
                                warp_k=1,
                                warp_tile_m=warp_tile_m,
                                warp_tile_n=warp_tile_n,
                                warp_tile_k=warp_tile_k,
                            ),
                            trait=trait,
                            variant=variant,
                            ndim_spatial=ndim,
                            arch=arch,
                        )
                        if config.is_valid_for_arch():
                            configs.append(config)

    return configs
|
|
|
|
|
|
def get_arch_filter():
    """Return the optional ``ArchFilter`` class.

    Returns ``None`` when the ``arch_filter`` module cannot be imported.
    """
    try:
        from arch_filter import ArchFilter
    except ImportError:
        return None
    return ArchFilter
|
|
|
|
|
|
# ============================================================================
|
|
# Main Generator
|
|
# ============================================================================
|
|
|
|
|
|
class _GenItem:
    """One generation job handed to ``parallel_generate``.

    Holds the config, datatype and variant to generate, plus the job's position
    (``idx`` of ``total``). The position is used only for progress text.
    """

    def __init__(
        self,
        idx: int,
        total: int,
        config: Union[GroupedConvKernelConfig, DepthwiseConvKernelConfig],
        datatype: str,
        variant: GroupedConvVariant,
    ):
        # Position within the batch (for progress reporting).
        self.idx, self.total = idx, total
        # What to generate.
        self.config = config
        self.datatype, self.variant = datatype, variant

    def __str__(self) -> str:
        kernel_label = self.config.name(self.datatype)
        return f"kernel {self.idx}/{self.total}: {kernel_label}"
|
|
|
|
|
|
class UnifiedGroupedConvCodegen:
    """Main grouped convolution code generator.

    For each configuration it writes three files:

    * the kernel header ``<output_dir>/<kernel_name>.hpp``;
    * a compilation unit ``<output_dir>/<kernel_name>.cpp``;
    * a dispatcher wrapper
      ``<output_dir>/dispatcher_wrappers/dispatcher_wrapper_<kernel_name>.hpp``.

    After a run that produced at least one wrapper, it rewrites the
    ``include_all_grouped_conv_*_kernels.hpp`` headers and
    ``dispatcher_wrappers/register_all_grouped_conv_kernels.hpp``. Both are
    built from a scan of the output directories.
    """

    def __init__(
        self,
        output_dir: Path,
        gpu_target: str = "gfx942",
        datatype: str = "fp16",
        ndim_spatial: int = 2,
        enable_arch_filter: bool = True,
    ):
        """Create output directories and, optionally, an architecture filter.

        Args:
            output_dir: Directory receiving generated kernels; created if missing.
            gpu_target: Target GPU architecture.
            datatype: Datatype used by ``generate_all`` when none are given.
            ndim_spatial: Spatial dimensionality used by ``_get_configs``.
            enable_arch_filter: Build an ``ArchFilter`` when one is available.
                A ``ValueError`` from its construction is logged and filtering
                stays disabled.
        """
        self.output_dir = output_dir
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Create wrapper directory for dispatcher integration
        self.wrapper_dir = self.output_dir / "dispatcher_wrappers"
        self.wrapper_dir.mkdir(parents=True, exist_ok=True)

        self.generated_files: List[Path] = []
        self.generated_wrappers: List[Path] = []
        self.gpu_target = gpu_target
        self.datatype = datatype
        self.ndim_spatial = ndim_spatial

        # Initialize architecture filter for GPU-specific validation
        self.arch_filter = None
        if enable_arch_filter and HAS_ARCH_FILTER:
            try:
                self.arch_filter = ArchFilter(gpu_target, strict_mode=False)
                log.info(f"Architecture filter enabled for {gpu_target}")
            except ValueError as e:
                log.warning(f"Could not create arch filter: {e}")

    def _get_configs(self) -> List[GroupedConvKernelConfig]:
        """Get configurations for this codegen's datatype and ndim_spatial.

        Returns the default configs for all three directions.
        """
        return get_default_configs(
            arch=self.gpu_target,
            variants=[
                GroupedConvVariant.FORWARD,
                GroupedConvVariant.BACKWARD_DATA,
                GroupedConvVariant.BACKWARD_WEIGHT,
            ],
            ndims=[self.ndim_spatial],
        )

    def _get_operator_type(
        self, variant: GroupedConvVariant
    ) -> Optional["OperatorType"]:
        """Map GroupedConvVariant to OperatorType for arch validation.

        Returns ``None`` when ``OperatorType`` is unavailable. Variants without
        an explicit mapping fall back to ``CONV_FWD``.
        """
        if OperatorType is None:
            return None

        variant_to_operator = {
            GroupedConvVariant.FORWARD: OperatorType.CONV_FWD,
            GroupedConvVariant.BACKWARD_DATA: OperatorType.CONV_BWD_DATA,
            GroupedConvVariant.BACKWARD_WEIGHT: OperatorType.CONV_BWD_WEIGHT,
        }
        return variant_to_operator.get(variant, OperatorType.CONV_FWD)

    def is_config_valid(
        self, config: GroupedConvKernelConfig, datatype: str = "fp16"
    ) -> bool:
        """Validate configuration against architecture constraints.

        Always returns ``True`` when no architecture filter is active.
        """
        if not self.arch_filter or not HAS_ARCH_FILTER:
            return True

        operator = self._get_operator_type(config.variant)

        return self.arch_filter.is_kernel_valid(
            datatype_a=datatype,
            datatype_b=datatype,
            datatype_c=datatype,
            tile_m=config.tile.tile_m,
            tile_n=config.tile.tile_n,
            tile_k=config.tile.tile_k,
            warp_m=config.tile.warp_m,
            warp_n=config.tile.warp_n,
            # NOTE(review): config.tile.warp_k is ignored; JSON configs may
            # carry warp_k != 1 — confirm validating as warp_k=1 is intended.
            warp_k=1,  # Grouped conv typically uses warp_k=1
            warp_tile_m=config.tile.warp_tile_m,
            warp_tile_n=config.tile.warp_tile_n,
            warp_tile_k=config.tile.warp_tile_k,
            pipeline=config.trait.pipeline,
            epilogue=config.trait.epilogue,
            scheduler=config.trait.scheduler,
            operator=operator,
        )

    def generate_kernel(
        self,
        config: Union[GroupedConvKernelConfig, DepthwiseConvKernelConfig],
        datatype: str,
        variant: GroupedConvVariant = GroupedConvVariant.FORWARD,
    ) -> Tuple[Path, Path]:
        """Generate a single kernel file and dispatcher wrapper.

        Also writes a ``<kernel_name>.cpp`` compilation unit. That file is not
        recorded in ``generated_files``.

        Returns:
            ``(kernel_path, wrapper_path)``.
        """
        if isinstance(config, DepthwiseConvKernelConfig):
            kernel_gen = CKTileDepthwiseConvKernelGenerator(datatype)
            # Depthwise kernels are forward-only, use the forward wrapper generator
            wrapper_gen = GroupedConvDispatcherWrapperGenerator(datatype, GroupedConvVariant.FORWARD)
        else:
            kernel_gen = CKTileGroupedConvKernelGenerator(datatype, variant)
            wrapper_gen = GroupedConvDispatcherWrapperGenerator(datatype, variant)

        kernel_name = config.name(datatype)
        filename = f"{kernel_name}.hpp"
        filepath = self.output_dir / filename

        # Generate kernel header
        content = kernel_gen.generate(config)
        filepath.write_text(content)
        self.generated_files.append(filepath)

        wrapper_content = wrapper_gen.generate(config, filepath, self.output_dir)
        wrapper_path = self.wrapper_dir / f"dispatcher_wrapper_{kernel_name}.hpp"
        wrapper_path.write_text(wrapper_content)
        self.generated_wrappers.append(wrapper_path)

        # Generate .cpp compilation unit for per-kernel parallel builds
        cpp_filepath = self.output_dir / f"{kernel_name}.cpp"
        loaded_symbol = kernel_name.replace("-", "_")
        cpp_content = f"""// SPDX-License-Identifier: MIT
// Auto-generated compilation unit for: {kernel_name}
// Enables per-kernel parallel compilation with make -j

#include "{filename}"

namespace ck_tile {{ namespace generated {{
volatile bool _{loaded_symbol}_loaded = true;
}} }}
"""
        cpp_filepath.write_text(cpp_content)

        return filepath, wrapper_path

    def _generate_single_kernel(self, item: _GenItem):
        """Generate one kernel (used by parallel_generate).

        Returns:
            ``(kernel_path, wrapper_path)``.

        Raises:
            Any exception raised while generating or writing the files.
        """
        kernel_path, wrapper_path = self.generate_kernel(
            item.config, item.datatype, item.variant
        )
        log.info(
            "Generated kernel %d/%d: %s",
            item.idx,
            item.total,
            item.config.name(item.datatype),
        )
        return (kernel_path, wrapper_path)

    def generate_all(
        self,
        configs: Optional[List[Union[GroupedConvKernelConfig, DepthwiseConvKernelConfig]]] = None,
        datatypes: Optional[List[str]] = None,
        parallel: bool = True,
    ) -> dict:
        """Generate all kernel files (optionally in parallel).

        Configs are filtered using architecture validation before generation.
        Depthwise configs bypass that filter. A failure for one kernel is
        recorded and does not stop the others.

        Returns:
            Dict with keys ``kernels`` and ``wrappers`` (paths of successful
            jobs) and ``failed`` (one message per failed job, prefixed with the
            job's ``kernel i/N`` position).
        """
        if configs is None:
            configs = self._get_configs()
        if datatypes is None:
            datatypes = [self.datatype]

        results = {"kernels": [], "wrappers": [], "failed": []}

        # Filter configs using arch validation
        valid_tasks = []
        rejected_count = 0

        for datatype in datatypes:
            for config in configs:
                if isinstance(config, DepthwiseConvKernelConfig):
                    # Depthwise configs skip arch filter validation (not applicable)
                    valid_tasks.append((config, datatype, GroupedConvVariant.FORWARD_DEPTHWISE))
                elif self.is_config_valid(config, datatype):
                    valid_tasks.append((config, datatype, config.variant))
                else:
                    rejected_count += 1
                    log.debug(
                        f"Rejected config for {self.gpu_target}: "
                        f"{config.tile.tile_m}x{config.tile.tile_n}x{config.tile.tile_k} "
                        f"variant={config.variant.value}"
                    )

        if rejected_count > 0:
            log.info(
                f"Filtered {rejected_count} configs for {self.gpu_target}, "
                f"{len(valid_tasks)} remaining"
            )

        total = len(valid_tasks)
        # 1-based positions so progress reads "kernel 1/N" ... "kernel N/N".
        items = [
            _GenItem(idx, total, config, datatype, variant)
            for idx, (config, datatype, variant) in enumerate(valid_tasks, start=1)
        ]

        def _safe_generate(item: _GenItem):
            """Run one job, turning any exception into a ("fail", ...) tuple."""
            try:
                kernel_path, wrapper_path = self._generate_single_kernel(item)
            except Exception as e:
                # Identify the job by position; str(item) is avoided because
                # config.name() might be what raised.
                return ("fail", None, None, f"kernel {item.idx}/{item.total}: {e}")
            return ("ok", kernel_path, wrapper_path, None)

        raw = parallel_generate(
            _safe_generate, items, parallel=parallel and len(items) > 1
        )
        for status, kernel_path, wrapper_path, error in raw:
            if status == "ok":
                results["kernels"].append(kernel_path)
                results["wrappers"].append(wrapper_path)
            else:
                results["failed"].append(error)
                log.error("Failed: %s", error)

        # Generate include_all_*.hpp headers for Python ctypes libraries
        if results["wrappers"]:
            self._generate_include_all_headers()

        return results

    def _generate_include_all_headers(self):
        """Generate include_all_grouped_conv_*.hpp headers and registration header.

        Kernels are classified by file-name prefix:

        * forward: ``grouped_conv_fwd_``;
        * backward data: ``grouped_conv_bwd_data_`` / ``grouped_conv_bwdd_``;
        * backward weight: ``grouped_conv_bwd_weight_`` / ``grouped_conv_bwdw_``.
        """
        # Scan output directory for ALL kernel files (not just this run's generated_files).
        # This handles the case where fwd and bwd kernels are generated in separate make targets.
        fwd_headers = []
        bwd_data_headers = []
        bwd_weight_headers = []
        fwd_kernels = []
        bwd_data_kernels = []
        bwd_weight_kernels = []

        for filepath in self.output_dir.glob("grouped_conv_*.hpp"):
            name = filepath.name
            kernel_name = filepath.stem
            if name.startswith("grouped_conv_fwd_"):
                fwd_headers.append(name)
                fwd_kernels.append(kernel_name)
            elif name.startswith(("grouped_conv_bwd_data_", "grouped_conv_bwdd_")):
                bwd_data_headers.append(name)
                bwd_data_kernels.append(kernel_name)
            elif name.startswith(("grouped_conv_bwd_weight_", "grouped_conv_bwdw_")):
                bwd_weight_headers.append(name)
                bwd_weight_kernels.append(kernel_name)

        # Global alias each include-all header defines for its first kernel.
        alias_names = {
            "forward": "SelectedConvKernelLauncher",
            "backward data": "SelectedConvBwdDataLauncher",
            "backward weight": "SelectedConvBwdWeightLauncher",
        }
        headers_to_generate = [
            ("include_all_grouped_conv_fwd_kernels.hpp", fwd_headers, "forward"),
            ("include_all_grouped_conv_bwd_data_kernels.hpp", bwd_data_headers, "backward data"),
            ("include_all_grouped_conv_bwd_weight_kernels.hpp", bwd_weight_headers, "backward weight"),
        ]

        for header_name, kernel_headers, variant_desc in headers_to_generate:
            header_path = self.output_dir / header_name
            sorted_headers = sorted(kernel_headers)
            includes = "\n".join(f'#include "{h}"' for h in sorted_headers)

            # Pick the first kernel as the default Selected*Launcher
            if sorted_headers:
                first_kernel = sorted_headers[0][: -len(".hpp")]
                launcher_alias = f"using {alias_names[variant_desc]} = {first_kernel}_Launcher;"
            else:
                launcher_alias = "// No kernels generated for this variant"

            content = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Auto-generated header for grouped conv {variant_desc} kernels
#pragma once

{includes}

// Default launcher alias (uses first kernel)
{launcher_alias}
"""
            header_path.write_text(content)
            if kernel_headers:
                log.info(f"Generated: {header_name} ({len(kernel_headers)} kernels)")

        # Generate registration header (following GEMM pattern)
        self._generate_registration_header(
            fwd_kernels, bwd_data_kernels, bwd_weight_kernels
        )

    def _generate_registration_header(
        self,
        fwd_kernels: List[str],
        bwd_data_kernels: List[str],
        bwd_weight_kernels: List[str],
    ):
        """Generate master registration header for all grouped conv kernels.

        Writes ``register_all_grouped_conv_kernels.hpp`` into ``wrapper_dir``.
        The header includes every ``dispatcher_wrapper_grouped_conv_*.hpp``
        found there, plus per-direction ``register_all_*`` functions and
        kernel-count helpers.
        """
        # Scan wrapper directory for ALL wrapper files
        all_wrappers = sorted(
            path.name
            for path in self.wrapper_dir.glob("dispatcher_wrapper_grouped_conv_*.hpp")
        )
        wrapper_includes = "\n".join(f'#include "{w}"' for w in all_wrappers)

        def _registrations(kernels: List[str], empty_comment: str) -> str:
            """One register_kernel call per kernel (sorted), or a placeholder comment."""
            calls = "\n    ".join(
                f"registry.register_kernel(generated::make_{k}(gfx_arch), priority);"
                for k in sorted(kernels)
            )
            return calls if calls else empty_comment

        fwd_body = _registrations(fwd_kernels, "// No forward kernels")
        bwd_data_body = _registrations(bwd_data_kernels, "// No backward data kernels")
        bwd_weight_body = _registrations(bwd_weight_kernels, "// No backward weight kernels")
        total_kernels = len(fwd_kernels) + len(bwd_data_kernels) + len(bwd_weight_kernels)

        content = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Auto-generated master registration header for grouped conv kernels
#pragma once

#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/grouped_conv_utils.hpp"

{wrapper_includes}

namespace ck_tile {{
namespace dispatcher {{

using Priority = GroupedConvRegistry::Priority;

inline void register_all_grouped_conv_fwd_kernels(
    const std::string& gfx_arch = "gfx942",
    Priority priority = Priority::Normal)
{{
    auto& registry = GroupedConvRegistry::instance();
    {fwd_body}
}}

inline void register_all_grouped_conv_bwd_data_kernels(
    const std::string& gfx_arch = "gfx942",
    Priority priority = Priority::Normal)
{{
    auto& registry = GroupedConvRegistry::instance();
    {bwd_data_body}
}}

inline void register_all_grouped_conv_bwd_weight_kernels(
    const std::string& gfx_arch = "gfx942",
    Priority priority = Priority::Normal)
{{
    auto& registry = GroupedConvRegistry::instance();
    {bwd_weight_body}
}}

inline void register_all_grouped_conv_kernels(
    const std::string& gfx_arch = "gfx942",
    Priority priority = Priority::Normal)
{{
    register_all_grouped_conv_fwd_kernels(gfx_arch, priority);
    register_all_grouped_conv_bwd_data_kernels(gfx_arch, priority);
    register_all_grouped_conv_bwd_weight_kernels(gfx_arch, priority);
}}

inline std::size_t get_grouped_conv_fwd_kernel_count() {{ return {len(fwd_kernels)}; }}
inline std::size_t get_grouped_conv_bwd_data_kernel_count() {{ return {len(bwd_data_kernels)}; }}
inline std::size_t get_grouped_conv_bwd_weight_kernel_count() {{ return {len(bwd_weight_kernels)}; }}
inline std::size_t get_grouped_conv_kernel_count() {{ return {total_kernels}; }}

}} // namespace dispatcher
}} // namespace ck_tile
"""
        reg_path = self.wrapper_dir / "register_all_grouped_conv_kernels.hpp"
        reg_path.write_text(content)
        log.info(f"Generated registration header: {reg_path}")
|
|
|
|
|
|
# ============================================================================
|
|
# CLI
|
|
# ============================================================================
|
|
|
|
|
|
# Accepted spellings for boolean-valued CLI options (compared case-insensitively).
_CLI_TRUE_VALUES = frozenset({"true", "1", "yes", "on"})
_CLI_FALSE_VALUES = frozenset({"false", "0", "no", "off"})

# Tile-geometry / pipeline options that default to None; passing any of them
# means the user wants a single hand-specified kernel instead of the
# predefined configuration set.
_CUSTOM_KERNEL_ARG_NAMES = (
    "tile_m",
    "tile_n",
    "tile_k",
    "warp_m",
    "warp_n",
    "warp_tile_m",
    "warp_tile_n",
    "pipeline",
)

# How many codegen failures to echo individually before summarizing the rest.
_MAX_FAILURES_SHOWN = 5


def _parse_cli_bool(value):
    """Convert a command-line string to a bool, for use as an argparse ``type=``.

    ``type=bool`` is an argparse pitfall: ``bool("false")`` is ``True``
    because every non-empty string is truthy. With it, ``--pad-m false``
    silently kept padding enabled. This converter accepts the usual spellings
    case-insensitively and rejects anything else, so a typo fails loudly
    instead of being coerced.

    Args:
        value: Raw option string from the command line.

    Returns:
        ``True`` for true/1/yes/on, ``False`` for false/0/no/off.

    Raises:
        argparse.ArgumentTypeError: For any other spelling. argparse reports
            this as a usage error that names the offending option.
    """
    normalized = str(value).strip().lower()
    if normalized in _CLI_TRUE_VALUES:
        return True
    if normalized in _CLI_FALSE_VALUES:
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean (true/false), got {value!r}"
    )


def _build_arg_parser():
    """Build the argument parser for the grouped-convolution codegen CLI."""
    parser = argparse.ArgumentParser(
        description="Unified Grouped Convolution Code Generator"
    )
    parser.add_argument(
        "--output",
        "-o",
        type=Path,
        default=Path("build/generated_kernels"),
        help="Output directory",
    )
    parser.add_argument(
        "--datatype",
        "-d",
        type=str,
        nargs="+",
        default=["fp16"],
        choices=["fp16", "bf16", "fp32"],
        help="Data types to generate",
    )
    parser.add_argument(
        "--variant",
        "-v",
        type=str,
        nargs="+",
        default=["forward"],
        choices=["forward", "bwd_data", "bwd_weight"],
        help="Grouped convolution variants",
    )
    parser.add_argument(
        "--ndim",
        "-n",
        type=int,
        nargs="+",
        default=[2],
        choices=[1, 2, 3],
        help="Spatial dimensions",
    )
    parser.add_argument(
        "--arch",
        "-a",
        type=str,
        default="gfx942",
        choices=["gfx90a", "gfx942", "gfx950", "gfx1201"],
        help="Target GPU architecture",
    )
    parser.add_argument("--verbose", action="store_true", help="Verbose output")
    parser.add_argument(
        "--list-configs",
        action="store_true",
        help="List configurations without generating",
    )

    # JSON config file
    parser.add_argument(
        "--config-file",
        type=Path,
        default=None,
        help="Path to JSON config file. "
        "Overrides --variant, --ndim, and individual tile/pipeline args.",
    )
    parser.add_argument(
        "--instance-id",
        type=int,
        default=None,
        help="Generate only the instance with this ID from the config file. "
        "Requires --config-file.",
    )

    # Individual kernel configuration (when not using predefined configs)
    parser.add_argument("--tile-m", type=int, help="Block tile M dimension")
    parser.add_argument("--tile-n", type=int, help="Block tile N dimension")
    parser.add_argument("--tile-k", type=int, help="Block tile K dimension")
    parser.add_argument("--warp-m", type=int, help="Wave distribution M")
    parser.add_argument("--warp-n", type=int, help="Wave distribution N")
    parser.add_argument("--warp-k", type=int, default=1, help="Wave distribution K")
    parser.add_argument("--warp-tile-m", type=int, help="Warp tile M")
    parser.add_argument("--warp-tile-n", type=int, help="Warp tile N")
    parser.add_argument("--warp-tile-k", type=int, default=16, help="Warp tile K")
    parser.add_argument(
        "--pipeline",
        type=str,
        choices=[
            "basic_v1",
            "basic_async_v1",
            "mem",
            "compv3",
            "compv4",
            "compv5",
            "compv6",
            "comp_async",
        ],
        help="Pipeline type",
    )
    parser.add_argument(
        "--scheduler",
        type=str,
        choices=["intrawave", "interwave"],
        help="Scheduler type",
    )
    parser.add_argument(
        "--epilogue",
        type=str,
        default="cshuffle",
        choices=["cshuffle", "default"],
        help="Epilogue type",
    )
    # _parse_cli_bool instead of type=bool: bool("false") would be True.
    parser.add_argument("--pad-m", type=_parse_cli_bool, default=True, help="Pad M dimension")
    parser.add_argument("--pad-n", type=_parse_cli_bool, default=True, help="Pad N dimension")
    parser.add_argument("--pad-k", type=_parse_cli_bool, default=True, help="Pad K dimension")
    parser.add_argument("--vector-a", type=int, default=4, help="Vector size A")
    parser.add_argument("--vector-b", type=int, default=8, help="Vector size B")
    parser.add_argument("--vector-c", type=int, default=8, help="Vector size C")
    parser.add_argument("--num-wave-groups", type=int, default=1, help="Wave groups")
    parser.add_argument("--num-groups-to-merge", type=int, default=1, help="Groups to merge")
    # None means "not given": the pipeline-dependent default is chosen later.
    parser.add_argument(
        "--double-smem-buffer",
        type=_parse_cli_bool,
        default=None,
        help="Double SMEM buffer (true/false)",
    )
    parser.add_argument(
        "--split-image",
        action="store_true",
        help="Enable split-image (EnableSplitImage) for large spatial tensors",
    )
    parser.add_argument(
        "--two-stage",
        action="store_true",
        help="Enable two-stage bwd_weight (fp32 workspace + elementwise convert)",
    )
    parser.add_argument(
        "--explicit-gemm",
        action="store_true",
        help="Enable explicit GEMM",
    )
    parser.add_argument(
        "--streamk-enabled",
        action="store_true",
        help="Use StreamK for grouped convolution kernels",
    )
    parser.add_argument(
        "--streamk-reduction-strategy",
        type=str,
        choices=["TREE", "LINEAR"],
        default=None,
        help="Reduction strategy for Stream-K",
    )
    parser.add_argument(
        "--streamk-persistent",
        action="store_true",
        help="Use persistent threads for Stream-K",
    )
    return parser


def _build_cli_config(args, requested_variants):
    """Build the single kernel config described by the individual CLI options.

    Only the first requested variant and the first spatial dimension are used.
    A warning is logged when more were requested, so they are not dropped
    silently.
    """
    if len(args.variant) > 1 or len(args.ndim) > 1:
        logging.getLogger(__name__).warning(
            "Custom kernel arguments build a single config; using variant=%s, "
            "ndim=%d and ignoring the remaining requested variants/ndims",
            args.variant[0],
            args.ndim[0],
        )

    tile = TileConfig(
        tile_m=args.tile_m or 128,
        tile_n=args.tile_n or 128,
        tile_k=args.tile_k or 64,
        warp_m=args.warp_m or 2,
        warp_n=args.warp_n or 2,
        warp_k=args.warp_k or 1,
        warp_tile_m=args.warp_tile_m or 32,
        warp_tile_n=args.warp_tile_n or 32,
        warp_tile_k=args.warp_tile_k or 16,
    )
    pipeline = args.pipeline or "compv4"

    # Determine double_smem_buffer: use CLI arg if given, else default based on pipeline
    if args.double_smem_buffer is not None:
        dsb = args.double_smem_buffer
    else:
        # Historical default: only compv4 auto-defaults to dsb=true.
        # Other pipelines that also require DoubleSmemBuffer (e.g. comp_async)
        # must be told explicitly via --double-smem-buffer true; otherwise
        # they will fail loudly at the pipeline header static_assert. This
        # is intentional -- silent fallback to a different config would
        # mask the user's input.
        dsb = pipeline == "compv4"

    if args.streamk_enabled:
        streamk_config = StreamKConfig(
            streamk_enabled=args.streamk_enabled,
            strategy=StreamKReductionStrategy(args.streamk_reduction_strategy or "TREE"),
            streamk_persistent=args.streamk_persistent,
        )
    else:
        streamk_config = StreamKConfig()

    trait = GroupedConvTraitConfig(
        pipeline=pipeline,
        scheduler=args.scheduler or "intrawave",
        epilogue=args.epilogue or "cshuffle",
        pad_m=args.pad_m,
        pad_n=args.pad_n,
        pad_k=args.pad_k,
        double_smem_buffer=dsb,
        num_groups_to_merge=args.num_groups_to_merge,
        split_image=args.split_image,
        two_stage=args.two_stage,
        explicit_gemm=args.explicit_gemm,
        streamk_config=streamk_config,
    )
    return GroupedConvKernelConfig(
        tile=tile,
        trait=trait,
        variant=requested_variants[0]
        if requested_variants
        else GroupedConvVariant.FORWARD,
        ndim_spatial=args.ndim[0] if args.ndim else 2,
        arch=args.arch,
        vector_size_a=args.vector_a,
        vector_size_b=args.vector_b,
        vector_size_c=args.vector_c,
        num_wave_groups=args.num_wave_groups,
    )


def _print_config_listing(args, filtered_configs):
    """Print the selected configurations (the ``--list-configs`` mode)."""
    print(f"Grouped convolution configurations for {args.arch}:")
    print(f" Datatypes: {args.datatype}")
    print(f" Variants: {args.variant}")
    print(f" Spatial dims: {args.ndim}")
    print(f"\nConfigurations ({len(filtered_configs)}):")
    for cfg in filtered_configs:
        # List configs for each requested datatype (fixes bf16 -> fp16 bug)
        for dt in args.datatype:
            print(f" - {cfg.name(dt)}")
        if isinstance(cfg, DepthwiseConvKernelConfig):
            print(f" Depthwise: tile={cfg.tile_h}x{cfg.tile_w}, filter={cfg.filt}")
            print(f" Stride: {cfg.str_h}x{cfg.str_w}, Pad: {cfg.pad_h}x{cfg.pad_w}")
            print(f" NBatch: {cfg.nbatch}, Sub: {cfg.sub_h}x{cfg.sub_w}")
            print(f" Vec: in={cfg.in_vec}, out={cfg.out_vec}")
        else:
            print(f" Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}")
            print(f" Warp: {cfg.tile.warp_m}x{cfg.tile.warp_n}x{cfg.tile.warp_k}")
            print(
                f" WarpTile: {cfg.tile.warp_tile_m}x{cfg.tile.warp_tile_n}x{cfg.tile.warp_tile_k}"
            )
            print(
                f" Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}, Scheduler: {cfg.trait.scheduler}"
            )
            print(
                f" Padding: M={cfg.trait.pad_m}, N={cfg.trait.pad_n}, K={cfg.trait.pad_k}"
            )


def main():
    """Command-line entry point for grouped-convolution kernel generation.

    Configurations come from one of three sources, in priority order:
    a JSON file (``--config-file``), the individual tile/pipeline options
    (any of ``_CUSTOM_KERNEL_ARG_NAMES``), or the predefined set for
    ``--arch``. With ``--list-configs`` the selection is printed and nothing
    is generated.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Map variant strings to enums
    variant_map = {
        "forward": GroupedConvVariant.FORWARD,
        "bwd_data": GroupedConvVariant.BACKWARD_DATA,
        "bwd_weight": GroupedConvVariant.BACKWARD_WEIGHT,
    }
    requested_variants = [variant_map[v] for v in args.variant]

    # Validate --instance-id requires --config-file
    if args.instance_id is not None and args.config_file is None:
        parser.error("--instance-id requires --config-file")

    if args.config_file is not None:
        filtered_configs = load_configs_from_json(
            args.config_file, arch=args.arch, instance_id=args.instance_id
        )
        # Extract datatype from JSON config for code generation
        with open(args.config_file, "r", encoding="utf-8") as f:
            config_data = json.load(f)
        if "datatype" not in config_data:
            parser.error(f"config file {args.config_file} has no 'datatype' field")
        args.datatype = [config_data["datatype"]]
    elif any(getattr(args, name) is not None for name in _CUSTOM_KERNEL_ARG_NAMES):
        filtered_configs = [_build_cli_config(args, requested_variants)]
    else:
        # Get predefined configurations for target arch with requested variants and ndims
        filtered_configs = get_default_configs(
            arch=args.arch, variants=requested_variants, ndims=args.ndim
        )

    if args.list_configs:
        _print_config_listing(args, filtered_configs)
        return

    # Generate (disable arch filter when using pre-validated JSON configs)
    codegen = UnifiedGroupedConvCodegen(
        output_dir=args.output,
        gpu_target=args.arch,
        enable_arch_filter=(args.config_file is None),
    )
    results = codegen.generate_all(
        configs=filtered_configs, datatypes=args.datatype, parallel=True
    )

    print(
        f"\nGenerated {len(results['kernels'])} grouped convolution kernel files "
        f"for {args.arch} in {args.output}"
    )
    failed = results["failed"]
    if failed:
        print(f" Failed: {len(failed)}")
        for err in failed[:_MAX_FAILURES_SHOWN]:
            print(f" - {err}")
        # Say how many were hidden instead of truncating silently.
        if len(failed) > _MAX_FAILURES_SHOWN:
            print(f" ... and {len(failed) - _MAX_FAILURES_SHOWN} more")


if __name__ == "__main__":
    main()
|