composable_kernel/dispatcher/codegen/unified_grouped_conv_codegen.py
Yaswanth Raparti 6989cf800c [rocm-libraries] ROCm/rocm-libraries#6327 (commit 1e7a12e)
[CK][CK TILE] Dispatcher kernel selection heuristic for
 grouped conv (#6327)

## Motivation
The ML heuristic in the dispatcher does not yet support the grouped-conv
operator. This PR adds support for fwd, bwd-data, and bwd-weight
grouped-conv kernels. It also adds a tile_engine utility that compiles and
runs any selected kernel configuration through the dispatcher
infrastructure.

## Technical Details

1. A tile engine utility is added to benchmark each shape with all
possible kernel + tile_size combinations:
https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv are added to the models
directory, with 3 separate models for fwd, bwd-data, and bwd-weight:
https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python)
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory
     access faults
   - **Solution**: Mirror the FMHA pattern and defer GPU initialization
     until the first run()
   - **Changes**:
     - `setup_multiple_grouped_conv_dispatchers()` returns `List[Path]`,
       not loaded libs
     - `GpuGroupedConvRunner.__init__()` no longer calls `ctypes.CDLL`
     - Added `_ensure_initialized()` method for lazy GPU loading
     - GPU context created only on first `run()` call
   - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to
     expand kernel selection
   - Added automatic padding support for unsupported shapes in the
     dispatcher runner
   - Created a single source of truth shared by tile_engine and
     dispatcher for architecture and tile_size details
   - Built validation scripts that compare oracle_best against
     ml_heuristic
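
The lazy-initialization pattern from item 3 can be sketched roughly as follows. This is a minimal illustration, not the dispatcher's actual code: `LazyLibRunner` and its injectable `loader` argument are hypothetical names chosen for the example; only the ideas of storing a path, an `_ensure_initialized()` helper, a `run()` entry point, and loading through `ctypes.CDLL` come from the description above.

```python
import ctypes
from pathlib import Path
from typing import Any, Callable, Optional


class LazyLibRunner:
    """Hypothetical runner that defers library loading until first use."""

    def __init__(self, lib_path: Path, loader: Callable[[str], Any] = ctypes.CDLL):
        # Only remember where the library lives. Nothing is loaded here, so
        # constructing runners before a fork() touches no GPU state.
        self.lib_path = Path(lib_path)
        self._loader = loader
        self._lib: Optional[Any] = None

    def _ensure_initialized(self) -> Any:
        # Load the shared library on first use and cache the handle.
        if self._lib is None:
            self._lib = self._loader(str(self.lib_path))
        return self._lib

    def run(self) -> Any:
        # A real runner would call into the library here; this sketch just
        # returns the cached handle to show that loading happens once.
        return self._ensure_initialized()
```

Injecting the loader is an assumption made so the sketch can be exercised without a GPU or a compiled library; passing a stub in place of `ctypes.CDLL` shows that construction performs no load and that repeated `run()` calls reuse one handle.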

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels with both known and
unseen data sets with up to 300 problems.
2. Ensured that test cases are added in both dispatcher and tile_engine
to validate the heuristic.

## Test Result
Results on unseen shapes, validated on gfx950:
#### Forward Pass Model
- **Training Data**: 48,845 measurements across 1,372 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model
- **Training Data**: 18,773 measurements across 891 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model
- **Training Data**: 34,900 measurements across 1,508 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**
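
The mean, median, and P10 figures above could be computed along these lines. This is a sketch under stated assumptions: the PR does not define "efficiency", so the code assumes it is the oracle-best kernel time divided by the time of the kernel the heuristic selected, and it assumes a nearest-rank P10. The function names are hypothetical.

```python
import math
import statistics
from typing import Dict, List


def efficiency(oracle_best_ms: float, selected_ms: float) -> float:
    # Assumed definition: 1.0 means the heuristic picked the oracle-best kernel.
    return oracle_best_ms / selected_ms


def summarize(efficiencies: List[float]) -> Dict[str, float]:
    ordered = sorted(efficiencies)
    # Nearest-rank 10th percentile: the value at or below which roughly
    # 10% of the problems fall.
    rank = max(1, math.ceil(0.10 * len(ordered)))
    return {
        "mean": statistics.fmean(ordered),
        "median": statistics.median(ordered),
        "p10": ordered[rank - 1],
    }
```

Reporting P10 alongside the mean shows how the heuristic behaves on its worst-served shapes, which a mean alone can hide.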

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-05-08 20:48:42 +00:00

1845 lines
71 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Unified Grouped Convolution Code Generator
This is the unified code generator for all grouped convolution kernel variants:
- Forward grouped convolution
- Backward data grouped convolution
- Backward weight grouped convolution
Generates both CK Tile kernels AND dispatcher wrappers.
Based on the GEMM codegen pattern.
"""
import argparse
import logging
from pathlib import Path
from typing import List, Optional, Tuple, Union
from dataclasses import dataclass
from enum import Enum

from codegen_common import (
    TileConfig,
    TraitConfigBase,
    parallel_generate,
)

logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
log = logging.getLogger(__name__)

# Import architecture filter for GPU-specific validation
try:
    from arch_filter import ArchFilter, OperatorType

    HAS_ARCH_FILTER = True
except ImportError:
    HAS_ARCH_FILTER = False
    ArchFilter = None
    OperatorType = None

# Import tile configurations from grouped_config_rules (single source of truth)
try:
    from grouped_config_rules import (
        COMMON_TILES,
        TILE_TO_WAVE,
        TILE_TO_WARP,
        VARIANT_PIPELINES,
        BWD_WEIGHT_TILES,
        COMPV4_COMPATIBLE_TILES,
    )

    HAS_TILE_CONFIGS = True
except ImportError:
    HAS_TILE_CONFIGS = False
    COMMON_TILES = []
    TILE_TO_WAVE = {}
    TILE_TO_WARP = {}
    VARIANT_PIPELINES = {}
    BWD_WEIGHT_TILES = []
    COMPV4_COMPATIBLE_TILES = []

# ============================================================================
# Configuration and Data Structures
# ============================================================================
class GroupedConvVariant(Enum):
    """Grouped convolution kernel variants"""

    FORWARD = "forward"
    BACKWARD_DATA = "bwd_data"
    BACKWARD_WEIGHT = "bwd_weight"


class GroupedConvLayout(Enum):
    """Grouped convolution data layouts"""

    # 1D
    NWGC = "NWGC"  # Input: N W G C
    GKXC = "GKXC"  # Weight: G K X C
    NWGK = "NWGK"  # Output: N W G K
    # 2D
    NHWGC = "NHWGC"  # Input: N H W G C
    GKYXC = "GKYXC"  # Weight: G K Y X C
    NHWGK = "NHWGK"  # Output: N H W G K
    # 3D
    NDHWGC = "NDHWGC"  # Input: N D H W G C
    GKZYXC = "GKZYXC"  # Weight: G K Z Y X C
    NDHWGK = "NDHWGK"  # Output: N D H W G K


@dataclass
class GroupedConvTraitConfig(TraitConfigBase):
    """Kernel trait configuration for grouped convolution (extends TraitConfigBase).

    Conv-specific extensions beyond TraitConfigBase. These map to
    GroupedConvTraits template parameters in grouped_convolution_utils.hpp:
    - double_smem_buffer: ping-pong LDS for compute V4+ pipelines
    - num_groups_to_merge: fuse multiple groups into one tile (NumGroupsToMerge)
    - split_image: split spatial dims for large tensors (EnableSplitImage)
    - explicit_gemm: use explicit GEMM path (ExplicitGemm)
    - two_stage: two-stage bwd_weight with fp32 workspace + elementwise convert

    Note: CK Tile already uses long_index_t (64-bit) for group strides and
    batch offsets, so there is no separate "large_tensor" flag. For large
    spatial dimensions, use split_image=True instead.
    """

    double_smem_buffer: bool = False
    num_groups_to_merge: int = 1
    split_image: bool = False
    explicit_gemm: bool = False
    two_stage: bool = False


# Backward compatibility alias
TraitConfig = GroupedConvTraitConfig


@dataclass
class GroupedConvKernelConfig:
    """Complete grouped convolution kernel configuration"""

    tile: TileConfig
    trait: GroupedConvTraitConfig
    variant: GroupedConvVariant = GroupedConvVariant.FORWARD
    ndim_spatial: int = 2  # 1D, 2D, or 3D
    arch: str = "gfx942"  # Target architecture
    layout: Union[str, GroupedConvLayout] = (
        "nhwgc"  # Data layout (e.g., "nhwgc", "ndhwgc")
    )

    # Vector sizes: a=4 for fp16 input (8-byte aligned global loads),
    # b=8 for weight tensor, c=8 for output stores. These match the
    # CK Tile default vectorization widths for fp16 on CDNA3 (gfx942).
    vector_size_a: int = 4
    vector_size_b: int = 8
    vector_size_c: int = 8
    vector_sizes: Optional[Tuple[int, int, int]] = None

    # Occupancy parameters
    block_per_cu: int = 1
    num_wave_groups: int = 1
    num_groups_to_merge: int = 1

    # Double buffering
    double_smem_buffer: bool = False

    def __post_init__(self):
        if self.vector_sizes is not None:
            self.vector_size_a, self.vector_size_b, self.vector_size_c = (
                self.vector_sizes[:3]
            )
        # Sync trait fields with top-level fields (trait is source of truth
        # when both are specified, but top-level overrides default trait values).
        if self.double_smem_buffer and not self.trait.double_smem_buffer:
            self.trait.double_smem_buffer = self.double_smem_buffer
        elif self.trait.double_smem_buffer:
            self.double_smem_buffer = self.trait.double_smem_buffer
        if self.num_groups_to_merge != 1 and self.trait.num_groups_to_merge == 1:
            self.trait.num_groups_to_merge = self.num_groups_to_merge
        elif self.trait.num_groups_to_merge != 1:
            self.num_groups_to_merge = self.trait.num_groups_to_merge

    def _layout_str(self) -> str:
        """Get layout as lowercase string for naming."""
        if hasattr(self.layout, "value"):
            return self.layout.value.lower()
        return str(self.layout).lower()

    def name(self, datatype: str) -> str:
        """
        Generate kernel name that uniquely identifies the kernel configuration.

        Format: grouped_conv_{variant}_{dtype}_{layout}_{ndim}d_{pipeline}_{epilogue}_{scheduler}
                _{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}
                _{warp_tile_m}x{warp_tile_n}x{warp_tile_k}
                [_vec{a}_{b}_{c}][_bpc{n}][_wg{n}][_gm{n}][_dsb][_pad{mnk}]

        All parameters that affect kernel behavior MUST be included to ensure
        unique names for unique configurations:
        - Variant (fwd/bwd_data/bwd_weight)
        - Data type
        - Layout (nhwgc, nchw, ndhwgc, etc.)
        - Spatial dimensions (2d/3d)
        - Pipeline, epilogue, scheduler
        - Tile, warp, warp_tile dimensions
        - Vector sizes, occupancy hints (if non-default)
        - Double SMEM buffer, padding flags
        """
        t = self.tile
        tr = self.trait
        layout_str = self._layout_str()
        variant_str = {
            GroupedConvVariant.FORWARD: "fwd",
            GroupedConvVariant.BACKWARD_DATA: "bwd_data",
            GroupedConvVariant.BACKWARD_WEIGHT: "bwd_weight",
        }[self.variant]

        # Core identity: variant, dtype, layout, dims
        name = (
            f"grouped_conv_{variant_str}_{datatype}_{layout_str}_{self.ndim_spatial}d"
        )
        # Pipeline configuration
        name += f"_{tr.pipeline}_{tr.epilogue}_{tr.scheduler}"
        # Block tile dimensions (M_Tile x N_Tile x K_Tile)
        name += f"_{t.tile_m}x{t.tile_n}x{t.tile_k}"
        # Wave distribution (M_Warp x N_Warp x K_Warp)
        name += f"_{t.warp_m}x{t.warp_n}x{t.warp_k}"
        # Warp tile dimensions (M_Warp_Tile x N_Warp_Tile x K_Warp_Tile)
        name += f"_{t.warp_tile_m}x{t.warp_tile_n}x{t.warp_tile_k}"
        # Vector sizes (only if non-default)
        if (self.vector_size_a, self.vector_size_b, self.vector_size_c) != (4, 8, 8):
            name += (
                f"_vec{self.vector_size_a}_{self.vector_size_b}_{self.vector_size_c}"
            )
        # Occupancy hints (only if non-default)
        if self.block_per_cu != 1:
            name += f"_bpc{self.block_per_cu}"
        if self.num_wave_groups != 1:
            name += f"_wg{self.num_wave_groups}"
        if self.num_groups_to_merge != 1:
            name += f"_gm{self.num_groups_to_merge}"
        # Double SMEM buffer (for compute V4+)
        if self.double_smem_buffer or tr.double_smem_buffer:
            name += "_dsb"
        # Two-stage bwd_weight (fp32 workspace + elementwise convert)
        if tr.two_stage:
            name += "_2stage"
        # Padding suffix (only if not all enabled)
        if not (tr.pad_m and tr.pad_n and tr.pad_k):
            name += f"_pad{int(tr.pad_m)}{int(tr.pad_n)}{int(tr.pad_k)}"
        return name

    def is_valid_for_arch(self, arch: Optional[str] = None) -> bool:
        """Check if configuration is valid for target architecture"""
        target_arch = arch if arch is not None else self.arch

        # Check trait validity
        if not self.trait.is_valid():
            return False

        # Backward operations have stricter pipeline requirements:
        # - Backward weight: compv4/compv5 have transpose_tile2d issues
        # - Backward data: compv4 has get_length issues in bwd_data kernel
        # Both backward operations ONLY support compv3 and mem pipelines
        if self.variant in (
            GroupedConvVariant.BACKWARD_WEIGHT,
            GroupedConvVariant.BACKWARD_DATA,
        ):
            if self.trait.pipeline not in ("compv3", "mem"):
                return False

        # Check warp configuration (from arch_specs)
        try:
            from arch_specs_generated import WARP_SUPPORTED_COMBINATIONS

            supported = WARP_SUPPORTED_COMBINATIONS.get(target_arch)
            if supported is None:
                return False  # Unknown architecture
            warp_cfg = [self.tile.warp_m, self.tile.warp_n, self.tile.warp_k]
            if warp_cfg not in supported:
                return False
        except ImportError:
            pass  # Allow if arch_specs not available

        return True


# ============================================================================
# Type Mappings
# ============================================================================
class GroupedConvTypeMappings:
    """Centralized type mappings for grouped convolution code generation"""

    DTYPE_TO_CK = {
        "fp16": "half_t",
        "bf16": "bf16_t",
        "fp32": "float",
    }

    # CK Tile conv pipelines (from conv_configs.hpp PipelineTypeTraits).
    # basic_v1/mem/compv3 use GroupedConvUniversalPipelineAgBgCrPolicy;
    # compv4/compv5/compv6/comp_async/basic_async_v1 use their own default policy.
    PIPELINE_TO_CK = {
        "basic_v1": "GemmPipeline::BASIC_V1",
        "mem": "GemmPipeline::MEMORY",
        "compv3": "GemmPipeline::COMPUTE_V3",
        "compv4": "GemmPipeline::COMPUTE_V4",
        "compv5": "GemmPipeline::COMPUTE_V5",
        "compv6": "GemmPipeline::COMPUTE_V6",
        "comp_async": "GemmPipeline::COMPUTE_ASYNC",
        "basic_async_v1": "GemmPipeline::BASIC_ASYNC_V1",
    }

    SCHEDULER_TO_CK = {
        "intrawave": "GemmPipelineScheduler::Intrawave",
        "interwave": "GemmPipelineScheduler::Interwave",
    }

    LAYOUT_1D = {
        "in": "tensor_layout::convolution::NWGC",
        "wei": "tensor_layout::convolution::GKXC",
        "out": "tensor_layout::convolution::NWGK",
    }

    LAYOUT_2D = {
        "in": "tensor_layout::convolution::NHWGC",
        "wei": "tensor_layout::convolution::GKYXC",
        "out": "tensor_layout::convolution::NHWGK",
    }

    LAYOUT_3D = {
        "in": "tensor_layout::convolution::NDHWGC",
        "wei": "tensor_layout::convolution::GKZYXC",
        "out": "tensor_layout::convolution::NDHWGK",
    }

    @classmethod
    def get_layouts(cls, ndim: int) -> dict:
        if ndim == 1:
            return cls.LAYOUT_1D
        elif ndim == 2:
            return cls.LAYOUT_2D
        else:
            return cls.LAYOUT_3D


# ============================================================================
# CK Tile Grouped Conv Kernel Generator
# ============================================================================
class CKTileGroupedConvKernelGenerator:
    """Generates CK Tile grouped convolution kernel instance code"""

    def __init__(
        self,
        datatype: str,
        variant: GroupedConvVariant = GroupedConvVariant.FORWARD,
    ):
        self.datatype = datatype
        self.variant = variant
        self.tm = GroupedConvTypeMappings()

    def generate(self, config: GroupedConvKernelConfig) -> str:
        """Generate complete CK Tile grouped convolution kernel"""
        kernel_name = config.name(self.datatype)
        return f"""{self._header(kernel_name, config)}
{self._config_struct(config, kernel_name)}
{self._kernel_instance(config, kernel_name)}
"""

    def _header(self, kernel_name: str, config: GroupedConvKernelConfig) -> str:
        """Generate header includes based on variant"""
        if self.variant == GroupedConvVariant.BACKWARD_DATA:
            kernel_header = "grouped_convolution_backward_data_kernel.hpp"
        elif self.variant == GroupedConvVariant.BACKWARD_WEIGHT:
            kernel_header = "grouped_convolution_backward_weight_kernel.hpp"
        else:
            kernel_header = "grouped_convolution_forward_kernel.hpp"

        elementwise_include = ""
        if config.trait.two_stage:
            elementwise_include = '\n#include "ck_tile/ops/elementwise.hpp"'

        return f"""// SPDX-License-Identifier: MIT
// Auto-generated CK Tile Grouped Convolution kernel: {kernel_name}
// Variant: {self.variant.value}
#pragma once
#include <cstdint>
#include <numeric>
#include <functional>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/grouped_convolution/kernel/{kernel_header}"
#include "ck_tile/ops/grouped_convolution/pipeline/grouped_conv_universal_pipeline_ag_bg_cr_policy.hpp"{elementwise_include}
using namespace ck_tile;
"""

    def _config_struct(self, config: GroupedConvKernelConfig, kernel_name: str) -> str:
        """Generate config struct"""
        t = config.tile
        tr = config.trait
        layouts = self.tm.get_layouts(config.ndim_spatial)
        return f"""
// Kernel configuration
struct {kernel_name}_Config {{
// Data types
using InDataType = {self.tm.DTYPE_TO_CK[self.datatype]};
using WeiDataType = {self.tm.DTYPE_TO_CK[self.datatype]};
using AccDataType = float;
using OutDataType = {self.tm.DTYPE_TO_CK[self.datatype]};
// Layouts
using InLayout = {layouts["in"]};
using WeiLayout = {layouts["wei"]};
using OutLayout = {layouts["out"]};
// Tile shape
static constexpr index_t M_Tile = {t.tile_m};
static constexpr index_t N_Tile = {t.tile_n};
static constexpr index_t K_Tile = {t.tile_k};
static constexpr index_t M_Warp = {t.warp_m};
static constexpr index_t N_Warp = {t.warp_n};
static constexpr index_t K_Warp = {t.warp_k};
static constexpr index_t M_Warp_Tile = {t.warp_tile_m};
static constexpr index_t N_Warp_Tile = {t.warp_tile_n};
static constexpr index_t K_Warp_Tile = {t.warp_tile_k};
// Vector sizes
static constexpr index_t VectorSizeA = {config.vector_size_a};
static constexpr index_t VectorSizeB = {config.vector_size_b};
static constexpr index_t VectorSizeC = {config.vector_size_c};
// Padding
static constexpr bool kPadM = {str(tr.pad_m).lower()};
static constexpr bool kPadN = {str(tr.pad_n).lower()};
static constexpr bool kPadK = {str(tr.pad_k).lower()};
// Pipeline & Epilogue
static constexpr auto Pipeline = {self.tm.PIPELINE_TO_CK[tr.pipeline]};
static constexpr auto Scheduler = {self.tm.SCHEDULER_TO_CK[tr.scheduler]};
static constexpr bool DoubleSmemBuffer = {str(tr.double_smem_buffer).lower()};
static constexpr bool UseCShuffleEpilogue = {str(tr.epilogue == "cshuffle").lower()};
// Other params
static constexpr int kBlockPerCu = {config.block_per_cu};
static constexpr index_t NumWaveGroups = {config.num_wave_groups};
static constexpr index_t NumGroupsToMerge = {tr.num_groups_to_merge};
static constexpr bool EnableSplitImage = {str(tr.split_image).lower()};
static constexpr bool ExplicitGemm = {str(tr.explicit_gemm).lower()};
static constexpr index_t NDimSpatial = {config.ndim_spatial};
// Target architecture
static constexpr const char* TargetArch = "{config.arch}";
}};
"""

    def _kernel_instance(
        self, config: GroupedConvKernelConfig, kernel_name: str
    ) -> str:
        """Generate kernel instantiation code with launch function"""
        tr = config.trait
        if self.variant == GroupedConvVariant.BACKWARD_WEIGHT and tr.two_stage:
            return self._kernel_instance_two_stage(config, kernel_name)

        # Variant-specific configuration
        if self.variant == GroupedConvVariant.BACKWARD_DATA:
            host_args_type = "GroupedConvBwdDataHostArgs"
            kernel_type = "GroupedConvolutionBackwardDataKernel"
            gemm_traits = "GroupedConvImplicitGemmTraitsBwdData"
            layout_suffix = "BwdData"
            # For bwd_data: A=dOutput, B=Weight, C=dInput
            a_dtype = "OutDataType"
            b_dtype = "WeiDataType"
            c_dtype = "InDataType"
            gemm_k_calc = "args.K_ * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end()"
            direction_prefix = "BWD_DATA"
            launcher_alias = "SelectedConvBwdDataLauncher"
        elif self.variant == GroupedConvVariant.BACKWARD_WEIGHT:
            host_args_type = "GroupedConvBwdWeightHostArgs"
            kernel_type = "GroupedConvolutionBackwardWeightKernel"
            gemm_traits = "GroupedConvImplicitGemmTraitsBwdWeight"
            layout_suffix = "BwdWeight"
            # For bwd_weight: A=dOutput, B=Input, C=dWeight (per CK Tile invoker)
            a_dtype = "OutDataType"
            b_dtype = "InDataType"
            c_dtype = "WeiDataType"
            gemm_k_calc = "args.N_ * std::accumulate(args.output_spatial_lengths_.begin(), args.output_spatial_lengths_.end()"
            direction_prefix = "BWD_WEIGHT"
            launcher_alias = "SelectedConvBwdWeightLauncher"
        else:  # Forward
            host_args_type = "GroupedConvFwdHostArgs<>"
            kernel_type = "GroupedConvolutionForwardKernel"
            gemm_traits = "GroupedConvImplicitGemmTraitsFwd"
            layout_suffix = "Fwd"
            a_dtype = "InDataType"
            b_dtype = "WeiDataType"
            c_dtype = "OutDataType"
            gemm_k_calc = "args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end()"
            direction_prefix = "FWD"
            launcher_alias = "SelectedConvKernelLauncher"

        # Create valid C++ namespace name
        ns_name = "ns_" + kernel_name.replace("-", "_")

        # basic_v1 / basic_async_v1 inherit BaseGemmPipelineAGmemBGmemCRegV1
        # whose TailHandler takes (run_func, has_hot_loop) and invokes
        # run_func(bool_constant<...>) -- 1 lambda arg. Other pipelines pass
        # (run_func, has_hot_loop, tail_number) and invoke 2-arg run_func.
        if tr.pipeline in ("basic_v1", "basic_async_v1"):
            tail_handler_call = "BaseGemmPipeline::TailHandler(Run, has_hot_loop);"
            run_lambda_signature = "[&](const auto has_hot_loop_)"
        else:
            tail_handler_call = (
                "BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);"
            )
            run_lambda_signature = (
                "[&](const auto has_hot_loop_, const auto tail_number_)"
            )

        return f"""
// Unique namespace for this kernel to avoid conflicts when including multiple kernels
namespace {ns_name} {{
// Bring Config into namespace
using Config = {kernel_name}_Config;
// Kernel name for identification
constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = "{kernel_name}";
// Selected kernel alias
using SelectedConv{direction_prefix.title()}Kernel = Config;
// =============================================================================
// Kernel Launch Implementation ({self.variant.value})
// =============================================================================
struct {kernel_name}_Launcher {{
using KernelConfig = Config; // Use the Config alias from namespace
using InDataType = typename Config::InDataType;
using WeiDataType = typename Config::WeiDataType;
using OutDataType = typename Config::OutDataType;
using AccDataType = typename Config::AccDataType;
using InLayout = typename Config::InLayout;
using WeiLayout = typename Config::WeiLayout;
using OutLayout = typename Config::OutLayout;
static constexpr index_t NDimSpatial = Config::NDimSpatial;
// Implicit GEMM shape
using GemmShape = TileGemmShape<
sequence<Config::M_Tile, Config::N_Tile, Config::K_Tile>,
sequence<Config::M_Warp, Config::N_Warp, Config::K_Warp>,
sequence<Config::M_Warp_Tile, Config::N_Warp_Tile, Config::K_Warp_Tile>>;
// Convolution traits
static constexpr auto ConvSpec = ConvolutionSpecialization::Default;
using GroupedConvTraitsType = GroupedConvTraits<
NDimSpatial, ConvSpec, InLayout, WeiLayout, tuple<>, OutLayout,
Config::VectorSizeA, Config::VectorSizeB, Config::VectorSizeC,
Config::NumGroupsToMerge, Config::EnableSplitImage, Config::ExplicitGemm>;
// Tile partitioner
using TilePartitioner = GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
// Universal traits - layout suffix changes per variant
using GemmUniversalTraits = TileGemmUniversalTraits<
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
Config::DoubleSmemBuffer,
typename GroupedConvTraitsType::AsLayout{layout_suffix},
typename GroupedConvTraitsType::BsLayout{layout_suffix},
typename GroupedConvTraitsType::CLayout{layout_suffix},
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
Config::NumWaveGroups>;
// Pipeline problem - data types change per variant
using GemmPipelineProblem = GemmPipelineProblem<
{a_dtype}, {b_dtype}, AccDataType, GemmShape,
typename GroupedConvTraitsType::template {gemm_traits}<Config::NumWaveGroups>,
element_wise::PassThrough, element_wise::PassThrough, {c_dtype},
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>;
// Base pipeline for tail handling
using BaseGemmPipeline = {self._get_base_pipeline(tr.pipeline)}<GemmPipelineProblem>;
static float launch(const {host_args_type}& args, const stream_config& s) {{
const index_t gemm_k = {gemm_k_calc}, 1, std::multiplies<index_t>());
const index_t k_grain = args.k_batch * Config::K_Tile;
const index_t K_split = (gemm_k + k_grain - 1) / k_grain * Config::K_Tile;
const index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{{0}};
constexpr auto scheduler = Config::Scheduler;
using UniversalGemmProblem = UniversalGemmPipelineProblem<
{a_dtype}, {b_dtype}, AccDataType, GemmShape, GemmUniversalTraits,
scheduler,
element_wise::PassThrough, element_wise::PassThrough, {c_dtype},
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = {self._get_pipeline_template_args(tr.pipeline, "UniversalGemmProblem")};
using ConvEpilogue = CShuffleEpilogue<CShuffleEpilogueProblem<
{a_dtype}, {b_dtype}, tuple<>, AccDataType, {c_dtype},
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
element_wise::PassThrough,
TilePartitioner::MPerBlock, TilePartitioner::NPerBlock,
Config::M_Warp, Config::N_Warp, Config::M_Warp_Tile,
Config::N_Warp_Tile, Config::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
Config::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
Config::VectorSizeC, false, 1, Config::DoubleSmemBuffer>>;
using Kernel = {kernel_type}<
GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>;
const auto Run = {run_lambda_signature} {{
auto kargs = Kernel::MakeKernelArgs(args);
if (!Kernel::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Arguments not supported for grouped conv kernel");
}}
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
ave_time = launch_kernel(s, make_kernel<Config::kBlockPerCu>(
Kernel{{}}, grids, blocks, 0, kargs));
return ave_time;
}};
{tail_handler_call}
return ave_time;
}}
}};
// Launcher alias for tile_engine compatibility
using {launcher_alias} = {kernel_name}_Launcher;
}} // namespace {ns_name}
// Export specific launcher to global namespace
using {kernel_name}_Launcher = {ns_name}::{kernel_name}_Launcher;
// When used with -include compiler flag, export aliases to global namespace
#ifdef CK_TILE_SINGLE_KERNEL_INCLUDE
using {launcher_alias} = {ns_name}::{launcher_alias};
constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = {ns_name}::CONV_{direction_prefix}_KERNEL_NAME;
#endif
"""

    # Pipelines that accept GroupedConvUniversalPipelineAgBgCrPolicy
    # as a second template parameter for conv-specific LDS layout.
    # (from conv_configs.hpp PipelineTypeTraits -- basic_v1/mem/compv3)
    # CompV4/V5/V6/comp_async/basic_async_v1 use their own default policies.
    _CONV_POLICY_PIPELINES = {"basic_v1", "mem", "compv3"}

    def _get_pipeline(self, pipeline: str) -> str:
        """Get pipeline class name."""
        pipelines = {
            "basic_v1": "GemmPipelineAGmemBGmemCRegV1",
            "mem": "GemmPipelineAgBgCrMem",
            "compv3": "GemmPipelineAgBgCrCompV3",
            "compv4": "GemmPipelineAgBgCrCompV4",
            "compv5": "GemmPipelineAgBgCrCompV5",
            "compv6": "GemmPipelineAgBgCrCompV6",
            "comp_async": "GemmPipelineAgBgCrCompAsync",
            "basic_async_v1": "GemmPipelineAGmemBGmemCRegAsyncV1",
        }
        return pipelines.get(pipeline, "GemmPipelineAgBgCrCompV3")

    def _get_pipeline_template_args(self, pipeline: str, problem_type: str) -> str:
        """Get full template argument list for pipeline instantiation.

        For basic_v1/mem/compv3, passes GroupedConvUniversalPipelineAgBgCrPolicy
        as a second template argument for conv-specific LDS banking.
        """
        base = self._get_pipeline(pipeline)
        if pipeline in self._CONV_POLICY_PIPELINES:
            return f"{base}<{problem_type}, GroupedConvUniversalPipelineAgBgCrPolicy>"
        return f"{base}<{problem_type}>"

    def _get_base_pipeline(self, pipeline: str) -> str:
        """Get base pipeline class name (used for tail handling only).

        Note: basic_async_v1 inherits from BaseGemmPipelineAGmemBGmemCRegV1
        (there is no separate BaseGemmPipelineAGmemBGmemCRegAsyncV1).
        """
        pipelines = {
            "basic_v1": "BaseGemmPipelineAGmemBGmemCRegV1",
            "mem": "BaseGemmPipelineAgBgCrMem",
            "compv3": "BaseGemmPipelineAgBgCrCompV3",
            "compv4": "BaseGemmPipelineAgBgCrCompV4",
            "compv5": "BaseGemmPipelineAgBgCrCompV5",
            "compv6": "BaseGemmPipelineAgBgCrCompV6",
            "comp_async": "BaseGemmPipelineAgBgCrCompAsync",
            "basic_async_v1": "BaseGemmPipelineAGmemBGmemCRegV1",
        }
        return pipelines.get(pipeline, "BaseGemmPipelineAgBgCrCompV3")

    def _kernel_instance_two_stage(
        self, config: GroupedConvKernelConfig, kernel_name: str
    ) -> str:
        """Generate two-stage bwd_weight kernel: GEMM into fp32 workspace + ElementWise convert.

        Mirrors grouped_convolution_backward_weight_two_stage_invoker.hpp from
        example/ck_tile/20_grouped_convolution/.
        """
        tr = config.trait
        ns_name = "ns_" + kernel_name.replace("-", "_")
        direction_prefix = "BWD_WEIGHT"
        launcher_alias = "SelectedConvBwdWeightLauncher"
        return f"""
namespace {ns_name} {{
using Config = {kernel_name}_Config;
constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = "{kernel_name}";
using SelectedConv{direction_prefix.title()}Kernel = Config;
struct {kernel_name}_Launcher {{
using KernelConfig = Config;
using InDataType = typename Config::InDataType;
using WeiDataType = typename Config::WeiDataType;
using OutDataType = typename Config::OutDataType;
using AccDataType = typename Config::AccDataType;
using InLayout = typename Config::InLayout;
using WeiLayout = typename Config::WeiLayout;
using OutLayout = typename Config::OutLayout;
using WorkspaceDataType = float;
static constexpr index_t NDimSpatial = Config::NDimSpatial;
// Two-stage forces VectorSizeC = 1 for workspace writes
static constexpr index_t VectorSizeC_TwoStage = 1;
using GemmShape = TileGemmShape<
sequence<Config::M_Tile, Config::N_Tile, Config::K_Tile>,
sequence<Config::M_Warp, Config::N_Warp, Config::K_Warp>,
sequence<Config::M_Warp_Tile, Config::N_Warp_Tile, Config::K_Warp_Tile>>;
static constexpr auto ConvSpec = ConvolutionSpecialization::Default;
using GroupedConvTraitsType = GroupedConvTraits<
NDimSpatial, ConvSpec, InLayout, WeiLayout, tuple<>, OutLayout,
Config::VectorSizeA, Config::VectorSizeB, VectorSizeC_TwoStage,
Config::NumGroupsToMerge, Config::EnableSplitImage, Config::ExplicitGemm>;
using TilePartitioner = GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = TileGemmUniversalTraits<
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
Config::DoubleSmemBuffer,
typename GroupedConvTraitsType::AsLayoutBwdWeight,
typename GroupedConvTraitsType::BsLayoutBwdWeight,
typename GroupedConvTraitsType::CLayoutBwdWeight,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
Config::NumWaveGroups>;
using GemmPipelineProblem = GemmPipelineProblem<
OutDataType, InDataType, AccDataType, GemmShape,
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight<Config::NumWaveGroups>,
element_wise::PassThrough, element_wise::PassThrough, WeiDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>;
using BaseGemmPipeline = {self._get_base_pipeline(tr.pipeline)}<GemmPipelineProblem>;
static float launch(const GroupedConvBwdWeightHostArgs& args, const stream_config& s) {{
const index_t gemm_k = args.N_ * std::accumulate(
args.output_spatial_lengths_.begin(), args.output_spatial_lengths_.end(),
1, std::multiplies<index_t>());
const index_t k_grain = args.k_batch * Config::K_Tile;
const index_t K_split = (gemm_k + k_grain - 1) / k_grain * Config::K_Tile;
const index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{{0}};
constexpr auto scheduler = Config::Scheduler;
using UniversalGemmProblem = UniversalGemmPipelineProblem<
OutDataType, InDataType, AccDataType, GemmShape, GemmUniversalTraits,
scheduler,
element_wise::PassThrough, element_wise::PassThrough, WeiDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = {self._get_pipeline_template_args(tr.pipeline, "UniversalGemmProblem")};
// Epilogue writes to fp32 workspace (not fp16 output)
using ConvEpilogue = CShuffleEpilogue<CShuffleEpilogueProblem<
OutDataType, InDataType, tuple<>, AccDataType, WorkspaceDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
element_wise::PassThrough,
TilePartitioner::MPerBlock, TilePartitioner::NPerBlock,
Config::M_Warp, Config::N_Warp, Config::M_Warp_Tile,
Config::N_Warp_Tile, Config::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
Config::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = GroupedConvolutionBackwardWeightKernel<
GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>;
// ElementWise kernel: fp32 workspace -> fp16/bf16 output
using XElementwiseOp = element_wise::UnaryConvert;
using EwBlockTile = sequence<2048>;
using EwBlockWarps = sequence<8>;
using EwWarpTile = sequence<64>;
using EwShape = ElementWiseShape<EwBlockWarps, EwBlockTile, EwWarpTile, WorkspaceDataType>;
using EwProblem = ElementWisePipelineProblem<
WorkspaceDataType, WorkspaceDataType, WeiDataType, EwShape, XElementwiseOp>;
using EwKernel = ElementWiseKernel<EwProblem, ElementWiseDefaultPolicy>;
// Workspace: G * K * C * product(filter_spatial) elements in fp32
const index_t spatial_accum = std::accumulate(
args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end(),
1, std::multiplies<index_t>());
DeviceMem ws_buf(args.G_ * args.K_ * args.C_ * spatial_accum * sizeof(WorkspaceDataType));
GroupedConvBwdWeightHostArgs ws_args(args);
auto* c_ptr = ws_args.wei_ptr;
ws_args.wei_ptr = ws_buf.GetDeviceBuffer();
auto kargs = Kernel::MakeKernelArgs(ws_args);
if(!Kernel::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Arguments not supported for two-stage bwd_weight kernel");
}}
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
// ElementWise kernel setup
const index_t ew_block_size = EwKernel::BlockSize();
const index_t total_elems = args.G_ * args.K_ * args.C_ * spatial_accum;
constexpr index_t elems_per_block = EwBlockTile::at(number<0>{{}});
const index_t ew_grid_size = (total_elems + elems_per_block - 1) / elems_per_block;
auto ew_shape = make_tuple(args.G_ * args.K_,
args.C_ * spatial_accum);
auto ew_inputs = make_tuple(static_cast<WorkspaceDataType*>(ws_args.wei_ptr));
if(!EwKernel::IsSupportedArgument(ew_shape)) {{
throw std::runtime_error("ElementWise arguments not supported for two-stage convert");
}}
auto preprocess = [&]() {{
if(kargs.k_batch > 1)
hip_check_error(hipMemsetAsync(
ws_args.wei_ptr, 0,
total_elems * sizeof(WorkspaceDataType),
s.stream_id_));
}};
ave_time = launch_kernel_time_mask(
s, preprocess,
make_kernel<Config::kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs),
make_kernel<Config::kBlockPerCu>(
EwKernel{{}}, ew_grid_size, ew_block_size, 0,
ew_shape,
make_tuple(args.C_ * spatial_accum, 1),
make_tuple(args.C_ * spatial_accum, 1),
ew_inputs,
static_cast<WeiDataType*>(c_ptr)));
return ave_time;
}}
}};
using {launcher_alias} = {kernel_name}_Launcher;
}} // namespace {ns_name}
using {kernel_name}_Launcher = {ns_name}::{kernel_name}_Launcher;
#ifdef CK_TILE_SINGLE_KERNEL_INCLUDE
using {launcher_alias} = {ns_name}::{launcher_alias};
constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = {ns_name}::CONV_{direction_prefix}_KERNEL_NAME;
#endif
"""
# ============================================================================
# Dispatcher Wrapper Generator
# ============================================================================
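# Illustrative sketch only: the lookup-with-fallback pattern used by
# _pipeline_to_dispatcher in the class below. Names missing from the table fall
# back to str.capitalize(), which lowercases everything after the first
# character, so the fallback is not guaranteed to match a real enumerator.
# `_SKETCH_PIPELINES` and `_sketch_pipeline_enum` are hypothetical names with a
# trimmed table, introduced for this example only.
_SKETCH_PIPELINES = {"mem": "Pipeline::Mem", "compv3": "Pipeline::CompV3"}


def _sketch_pipeline_enum(pipeline: str) -> str:
    """Return the mapped enum string, or a capitalized guess if unmapped."""
    return _SKETCH_PIPELINES.get(
        pipeline.lower(), f"Pipeline::{pipeline.capitalize()}"
    )


# A mapped name keeps its mixed case ("CompV3"), while an unmapped name such as
# "comp_async" becomes "Pipeline::Comp_async", so adding a table entry is the
# safer route for any new pipeline.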
class GroupedConvDispatcherWrapperGenerator:
"""Generates dispatcher integration wrapper following GEMM pattern"""
# Static mappings for pipeline and scheduler enum names (matches kernel_key.hpp)
PIPELINE_TO_DISPATCHER = {
"mem": "Pipeline::Mem",
"compv3": "Pipeline::CompV3",
"compv4": "Pipeline::CompV4",
"compv5": "Pipeline::CompV5",
"preshufflev1": "Pipeline::PreShuffleV1",
"preshufflev2": "Pipeline::PreShuffleV2",
}
SCHEDULER_TO_DISPATCHER = {
"default": "Scheduler::Default",
"intrawave": "Scheduler::Intrawave",
"interwave": "Scheduler::Interwave",
}
def __init__(
self,
datatype: str,
variant: GroupedConvVariant = GroupedConvVariant.FORWARD,
):
self.datatype = datatype
self.variant = variant
def _pipeline_to_dispatcher(self, pipeline: str) -> str:
"""Convert pipeline string to dispatcher enum value"""
return self.PIPELINE_TO_DISPATCHER.get(
pipeline.lower(), f"Pipeline::{pipeline.capitalize()}"
)
def _scheduler_to_dispatcher(self, scheduler: str) -> str:
"""Convert scheduler string to dispatcher enum value"""
return self.SCHEDULER_TO_DISPATCHER.get(
scheduler.lower(), f"Scheduler::{scheduler.capitalize()}"
)
def generate(
self,
config: GroupedConvKernelConfig,
kernel_path: Path,
output_dir: Path,
) -> str:
"""Generate dispatcher wrapper with factory function for registry"""
kernel_name = config.name(self.datatype)
rel_path = kernel_path.relative_to(output_dir)
# Determine launcher type based on variant
if self.variant == GroupedConvVariant.FORWARD:
launcher_alias = "SelectedConvKernelLauncher"
host_args_type = "GroupedConvFwdHostArgs<>"
conv_type_str = "forward"
elif self.variant == GroupedConvVariant.BACKWARD_DATA:
launcher_alias = "SelectedConvBwdDataLauncher"
host_args_type = "GroupedConvBwdDataHostArgs"
conv_type_str = "bwd_data"
else: # BACKWARD_WEIGHT
launcher_alias = "SelectedConvBwdWeightLauncher"
host_args_type = "GroupedConvBwdWeightHostArgs"
conv_type_str = "bwd_weight"
return f"""// SPDX-License-Identifier: MIT
// Auto-generated dispatcher wrapper for: {kernel_name}
#pragma once
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/grouped_conv_utils.hpp"
#include "../{rel_path}"
namespace ck_tile {{
namespace dispatcher {{
namespace generated {{
using ::ck_tile::dispatcher::GroupedConvKernelInstancePtr;
using ::ck_tile::dispatcher::GroupedConvKernelKey;
using ::ck_tile::dispatcher::DataType;
using ::ck_tile::dispatcher::LayoutTag;
using ::ck_tile::dispatcher::Pipeline;
using ::ck_tile::dispatcher::Scheduler;
using ::ck_tile::dispatcher::Epilogue;
using Priority = ::ck_tile::dispatcher::GroupedConvRegistry::Priority;
// Factory function to create kernel instance for registry
inline GroupedConvKernelInstancePtr make_{kernel_name}(const std::string& gfx_arch = "gfx942") {{
GroupedConvKernelKey key;
key.signature.dtype_in = DataType::{self.datatype.upper()};
key.signature.dtype_wei = DataType::{self.datatype.upper()};
key.signature.dtype_out = DataType::{self.datatype.upper()};
key.signature.dtype_acc = DataType::FP32;
key.signature.layout = "nhwgc";
key.signature.conv_type = "{conv_type_str}";
key.signature.num_dims = {config.ndim_spatial};
key.signature.groups = 1;
key.algorithm.tile_shape = {{{config.tile.tile_m}, {config.tile.tile_n}, {config.tile.tile_k}}};
key.algorithm.wave_shape = {{{config.tile.warp_m}, {config.tile.warp_n}, 1}};
key.algorithm.warp_tile_shape = {{{config.tile.warp_tile_m}, {config.tile.warp_tile_n}, {config.tile.warp_tile_k}}};
key.algorithm.pipeline = {self._pipeline_to_dispatcher(config.trait.pipeline)};
key.algorithm.scheduler = {self._scheduler_to_dispatcher(config.trait.scheduler)};
key.algorithm.epilogue = Epilogue::CShuffle;
key.gfx_arch = gfx_arch;
// Create kernel instance that wraps the launcher
return std::make_shared<GroupedConvKernelInstance>(
key,
"{kernel_name}",
[]({host_args_type}& args, const stream_config& cfg) -> float {{
return {kernel_name}_Launcher::launch(args, cfg);
}}
);
}}
}} // namespace generated
}} // namespace dispatcher
}} // namespace ck_tile
// Export launcher alias to global namespace for direct use
using {launcher_alias} = {kernel_name}_Launcher;
"""
# ============================================================================
# Configuration Parser
# ============================================================================
def get_default_configs(
arch: str = "gfx942",
variants: Optional[List[GroupedConvVariant]] = None,
ndims: Optional[List[int]] = None,
) -> List[GroupedConvKernelConfig]:
"""Get default grouped convolution configurations for target architecture.
Uses tile configurations from grouped_conv_instance_builder.py as single source of truth.
"""
configs = []
if variants is None:
variants = [GroupedConvVariant.FORWARD]
if ndims is None:
ndims = [2]
# Import tile configs from instance builder (single source of truth)
if not HAS_TILE_CONFIGS or not COMMON_TILES:
log.warning("grouped_config_rules not available, using fallback tile configs")
# Fallback to minimal set if grouped_config_rules unavailable
fwd_bwd_data_tiles = [
(128, 128, 32, 2, 2, 32, 32, 16),
(64, 64, 32, 1, 4, 16, 16, 16),
(16, 64, 64, 1, 4, 16, 16, 32),
]
bwd_weight_tiles = [(16, 64, 64, 1, 4, 16, 16, 32)]
else:
# Build tile list from COMMON_TILES with wave/warp mappings
fwd_bwd_data_tiles = []
for tile_m, tile_n, tile_k in COMMON_TILES:
tile_key = (tile_m, tile_n, tile_k)
if tile_key in TILE_TO_WAVE and tile_key in TILE_TO_WARP:
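# TILE_TO_WAVE holds the wave layout and TILE_TO_WARP the per-warp tile shape
wave_m, wave_n, _wave_k = TILE_TO_WAVE[tile_key]
warp_tile_m, warp_tile_n, warp_tile_k = TILE_TO_WARP[tile_key]
fwd_bwd_data_tiles.append(
(tile_m, tile_n, tile_k, wave_m, wave_n, warp_tile_m, warp_tile_n, warp_tile_k)
)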
# Backward weight: use BWD_WEIGHT_TILES from config rules
bwd_weight_tiles = []
for tile_m, tile_n, tile_k in BWD_WEIGHT_TILES:
tile_key = (tile_m, tile_n, tile_k)
if tile_key in TILE_TO_WAVE and tile_key in TILE_TO_WARP:
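# TILE_TO_WAVE holds the wave layout and TILE_TO_WARP the per-warp tile shape
wave_m, wave_n, _wave_k = TILE_TO_WAVE[tile_key]
warp_tile_m, warp_tile_n, warp_tile_k = TILE_TO_WARP[tile_key]
bwd_weight_tiles.append(
(tile_m, tile_n, tile_k, wave_m, wave_n, warp_tile_m, warp_tile_n, warp_tile_k)
)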
for variant in variants:
# Select tile configs based on variant
if variant == GroupedConvVariant.BACKWARD_WEIGHT:
tile_configs = bwd_weight_tiles
# Backward weight supports compv3 and mem pipelines
# (compv4/compv5 have transpose_tile2d issues)
pipelines = [("compv3", "cshuffle"), ("mem", "default")]
# Also generate two-stage variants (fp32 workspace + elementwise convert)
two_stage_flags = [False, True]
elif variant == GroupedConvVariant.BACKWARD_DATA:
tile_configs = fwd_bwd_data_tiles
# Backward data supports compv3 and mem pipelines
# (compv4/compv5 have get_length issues in bwd_data kernel)
pipelines = [("compv3", "cshuffle"), ("mem", "default")]
two_stage_flags = [False]
else:
tile_configs = fwd_bwd_data_tiles
# Only forward grouped convolution supports both compv3 and compv4
pipelines = [("compv3", "cshuffle"), ("compv4", "cshuffle")]
two_stage_flags = [False]
for ndim in ndims:
for pipeline, epilogue in pipelines:
for (
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_tile_m,
warp_tile_n,
warp_tile_k,
) in tile_configs:
# Skip tiles incompatible with compv4
if pipeline == "compv4" and HAS_TILE_CONFIGS:
tile_key = (tile_m, tile_n, tile_k)
if tile_key not in COMPV4_COMPATIBLE_TILES:
continue # Skip this tile for compv4
for two_stage in two_stage_flags:
adj_tile_k = tile_k * 2 if pipeline == "compv4" else tile_k
trait = GroupedConvTraitConfig(
pipeline=pipeline,
scheduler="intrawave",
epilogue=epilogue,
double_smem_buffer=(pipeline == "compv4"),
pad_m=True,
pad_n=True,
pad_k=True,
two_stage=two_stage,
)
if not trait.is_valid():
continue
config = GroupedConvKernelConfig(
tile=TileConfig(
tile_m=tile_m,
tile_n=tile_n,
tile_k=adj_tile_k,
warp_m=warp_m,
warp_n=warp_n,
warp_k=1,
warp_tile_m=warp_tile_m,
warp_tile_n=warp_tile_n,
warp_tile_k=warp_tile_k,
),
trait=trait,
variant=variant,
ndim_spatial=ndim,
arch=arch,
)
if config.is_valid_for_arch():
configs.append(config)
return configs
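# Illustrative sketch only: how get_default_configs() above expands one tile
# into (pipeline, tile_m, tile_n, tile_k, two_stage) entries. It mirrors the
# per-variant pipeline lists and the compv4 tile_k doubling, but deliberately
# omits the COMPV4_COMPATIBLE_TILES filter and the trait/arch validity checks.
# `_sketch_expand_tile` is a hypothetical helper, and the plain strings stand in
# for the GroupedConvVariant members.
def _sketch_expand_tile(variant: str, tile: tuple) -> list:
    tile_m, tile_n, tile_k = tile
    if variant == "bwd_weight":
        pipelines, two_stage_flags = ["compv3", "mem"], [False, True]
    elif variant == "bwd_data":
        pipelines, two_stage_flags = ["compv3", "mem"], [False]
    else:  # forward
        pipelines, two_stage_flags = ["compv3", "compv4"], [False]
    entries = []
    for pipeline in pipelines:
        for two_stage in two_stage_flags:
            # compv4 runs with a doubled K tile, as in the loop above
            adj_tile_k = tile_k * 2 if pipeline == "compv4" else tile_k
            entries.append((pipeline, tile_m, tile_n, adj_tile_k, two_stage))
    return entries


# One forward tile yields two entries (compv3 and compv4 with tile_k doubled);
# one bwd_weight tile yields four (two pipelines, each with and without two-stage).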
def get_arch_filter():
"""Get arch filter if available"""
try:
from arch_filter import ArchFilter
return ArchFilter
except ImportError:
return None
# ============================================================================
# Main Generator
# ============================================================================
class _GenItem:
"""Item for parallel generation with progress logging."""
def __init__(
self,
idx: int,
total: int,
config: GroupedConvKernelConfig,
datatype: str,
variant: GroupedConvVariant,
):
self.idx = idx
self.total = total
self.config = config
self.datatype = datatype
self.variant = variant
def __str__(self) -> str:
return f"kernel {self.idx}/{self.total}: {self.config.name(self.datatype)}"
class UnifiedGroupedConvCodegen:
"""Main grouped convolution code generator"""
def __init__(
self,
output_dir: Path,
gpu_target: str = "gfx942",
datatype: str = "fp16",
ndim_spatial: int = 2,
enable_arch_filter: bool = True,
):
self.output_dir = output_dir
self.output_dir.mkdir(parents=True, exist_ok=True)
# Create wrapper directory for dispatcher integration
self.wrapper_dir = self.output_dir / "dispatcher_wrappers"
self.wrapper_dir.mkdir(parents=True, exist_ok=True)
self.generated_files: List[Path] = []
self.generated_wrappers: List[Path] = []
self.gpu_target = gpu_target
self.datatype = datatype
self.ndim_spatial = ndim_spatial
# Initialize architecture filter for GPU-specific validation
self.arch_filter = None
if enable_arch_filter and HAS_ARCH_FILTER:
try:
self.arch_filter = ArchFilter(gpu_target, strict_mode=False)
log.info(f"Architecture filter enabled for {gpu_target}")
except ValueError as e:
log.warning(f"Could not create arch filter: {e}")
def _get_configs(self) -> List[GroupedConvKernelConfig]:
"""Get configurations for this codegen's datatype and ndim_spatial."""
return get_default_configs(
arch=self.gpu_target,
variants=[
GroupedConvVariant.FORWARD,
GroupedConvVariant.BACKWARD_DATA,
GroupedConvVariant.BACKWARD_WEIGHT,
],
ndims=[self.ndim_spatial],
)
def _get_operator_type(
self, variant: GroupedConvVariant
) -> Optional["OperatorType"]:
"""Map GroupedConvVariant to OperatorType for arch validation"""
if OperatorType is None:
return None
variant_to_operator = {
GroupedConvVariant.FORWARD: OperatorType.CONV_FWD,
GroupedConvVariant.BACKWARD_DATA: OperatorType.CONV_BWD_DATA,
GroupedConvVariant.BACKWARD_WEIGHT: OperatorType.CONV_BWD_WEIGHT,
}
return variant_to_operator.get(variant, OperatorType.CONV_FWD)
def is_config_valid(
self, config: GroupedConvKernelConfig, datatype: str = "fp16"
) -> bool:
"""Validate configuration against architecture constraints"""
if not self.arch_filter or not HAS_ARCH_FILTER:
return True
operator = self._get_operator_type(config.variant)
return self.arch_filter.is_kernel_valid(
datatype_a=datatype,
datatype_b=datatype,
datatype_c=datatype,
tile_m=config.tile.tile_m,
tile_n=config.tile.tile_n,
tile_k=config.tile.tile_k,
warp_m=config.tile.warp_m,
warp_n=config.tile.warp_n,
warp_k=1, # Grouped conv typically uses warp_k=1
warp_tile_m=config.tile.warp_tile_m,
warp_tile_n=config.tile.warp_tile_n,
warp_tile_k=config.tile.warp_tile_k,
pipeline=config.trait.pipeline,
epilogue=config.trait.epilogue,
scheduler=config.trait.scheduler,
operator=operator,
)
def generate_kernel(
self,
config: GroupedConvKernelConfig,
datatype: str,
variant: GroupedConvVariant = GroupedConvVariant.FORWARD,
) -> Tuple[Path, Path]:
"""Generate a single kernel file and dispatcher wrapper. Returns (kernel_path, wrapper_path)."""
kernel_gen = CKTileGroupedConvKernelGenerator(datatype, variant)
wrapper_gen = GroupedConvDispatcherWrapperGenerator(datatype, variant)
kernel_name = config.name(datatype)
filename = f"{kernel_name}.hpp"
filepath = self.output_dir / filename
# Generate kernel header
content = kernel_gen.generate(config)
filepath.write_text(content)
self.generated_files.append(filepath)
# Generate dispatcher wrapper
wrapper_content = wrapper_gen.generate(config, filepath, self.output_dir)
wrapper_path = self.wrapper_dir / f"dispatcher_wrapper_{kernel_name}.hpp"
wrapper_path.write_text(wrapper_content)
self.generated_wrappers.append(wrapper_path)
# Generate .cpp compilation unit for per-kernel parallel builds
cpp_filename = f"{kernel_name}.cpp"
cpp_filepath = self.output_dir / cpp_filename
cpp_content = f"""// SPDX-License-Identifier: MIT
// Auto-generated compilation unit for: {kernel_name}
// Enables per-kernel parallel compilation with make -j
#include "{filename}"
namespace ck_tile {{ namespace generated {{
volatile bool _{kernel_name.replace("-", "_")}_loaded = true;
}} }}
"""
cpp_filepath.write_text(cpp_content)
return filepath, wrapper_path
def _generate_single_kernel(self, item: _GenItem):
"""Generate one kernel (used by parallel_generate). Returns (kernel_path, wrapper_path) or raises."""
kernel_path, wrapper_path = self.generate_kernel(
item.config, item.datatype, item.variant
)
log.info(
"Generated kernel %d/%d: %s",
item.idx,
item.total,
item.config.name(item.datatype),
)
return (kernel_path, wrapper_path)
def generate_all(
self,
configs: Optional[List[GroupedConvKernelConfig]] = None,
datatypes: Optional[List[str]] = None,
parallel: bool = True,
) -> dict:
"""Generate all kernel files (optionally in parallel).
Configs are filtered using architecture validation before generation.
Returns dict with keys: kernels, wrappers, failed.
"""
if configs is None:
configs = self._get_configs()
if datatypes is None:
datatypes = [self.datatype]
results = {"kernels": [], "wrappers": [], "failed": []}
# Filter configs using arch validation
valid_tasks = []
rejected_count = 0
for datatype in datatypes:
for config in configs:
if self.is_config_valid(config, datatype):
valid_tasks.append((config, datatype, config.variant))
else:
rejected_count += 1
log.debug(
f"Rejected config for {self.gpu_target}: "
f"{config.tile.tile_m}x{config.tile.tile_n}x{config.tile.tile_k} "
f"variant={config.variant.value}"
)
if rejected_count > 0:
log.info(
f"Filtered {rejected_count} configs for {self.gpu_target}, "
f"{len(valid_tasks)} remaining"
)
total = len(valid_tasks)
items = [
_GenItem(i, total, config, datatype, variant)
for i, (config, datatype, variant) in enumerate(valid_tasks, start=1)
]
def _safe_generate(item: _GenItem):
"""Wrapper that catches exceptions for failure tracking."""
try:
k, w = self._generate_single_kernel(item)
return ("ok", k, w, None)
except Exception as e:
return ("fail", None, None, f"{item}: {e}")
raw = parallel_generate(
_safe_generate, items, parallel=parallel and len(items) > 1
)
for r in raw:
if r[0] == "ok":
results["kernels"].append(r[1])
results["wrappers"].append(r[2])
else:
results["failed"].append(r[3])
log.error("Failed: %s", r[3])
# Generate include_all_*.hpp headers for Python ctypes libraries
if results["wrappers"]:
self._generate_include_all_headers()
return results
def _generate_include_all_headers(self):
"""Generate include_all_grouped_conv_*.hpp headers and registration header"""
# Scan output directory for ALL kernel files (not just this run's generated_files)
# This handles the case where fwd and bwd kernels are generated in separate make targets
fwd_headers = []
bwd_data_headers = []
bwd_weight_headers = []
fwd_kernels = []
bwd_data_kernels = []
bwd_weight_kernels = []
for filepath in self.output_dir.glob("grouped_conv_*.hpp"):
name = filepath.name
kernel_name = name[:-4]
if name.startswith("grouped_conv_fwd_"):
fwd_headers.append(name)
fwd_kernels.append(kernel_name)
elif name.startswith(("grouped_conv_bwd_data_", "grouped_conv_bwdd_")):
bwd_data_headers.append(name)
bwd_data_kernels.append(kernel_name)
elif name.startswith(("grouped_conv_bwd_weight_", "grouped_conv_bwdw_")):
bwd_weight_headers.append(name)
bwd_weight_kernels.append(kernel_name)
headers_to_generate = [
("include_all_grouped_conv_fwd_kernels.hpp", fwd_headers, "forward"),
(
"include_all_grouped_conv_bwd_data_kernels.hpp",
bwd_data_headers,
"backward data",
),
(
"include_all_grouped_conv_bwd_weight_kernels.hpp",
bwd_weight_headers,
"backward weight",
),
]
for header_name, kernel_headers, variant_desc in headers_to_generate:
header_path = self.output_dir / header_name
includes = "\n".join(f'#include "{h}"' for h in sorted(kernel_headers))
# Pick the first kernel as the default Selected*Launcher
if kernel_headers:
first_kernel = sorted(kernel_headers)[0][:-4] # Remove .hpp
if variant_desc == "forward":
launcher_alias = (
f"using SelectedConvKernelLauncher = {first_kernel}_Launcher;"
)
elif variant_desc == "backward data":
launcher_alias = (
f"using SelectedConvBwdDataLauncher = {first_kernel}_Launcher;"
)
else: # backward weight
launcher_alias = f"using SelectedConvBwdWeightLauncher = {first_kernel}_Launcher;"
else:
launcher_alias = "// No kernels generated for this variant"
content = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Auto-generated header for grouped conv {variant_desc} kernels
#pragma once
{includes}
// Default launcher alias (uses first kernel)
{launcher_alias}
"""
header_path.write_text(content)
if kernel_headers:
log.info(f"Generated: {header_name} ({len(kernel_headers)} kernels)")
# Generate registration header (following GEMM pattern)
self._generate_registration_header(
fwd_kernels, bwd_data_kernels, bwd_weight_kernels
)
def _generate_registration_header(
self,
fwd_kernels: List[str],
bwd_data_kernels: List[str],
bwd_weight_kernels: List[str],
):
"""Generate master registration header for all grouped conv kernels"""
# Scan wrapper directory for ALL wrapper files
all_wrappers = []
for wrapper_path in self.wrapper_dir.glob(
"dispatcher_wrapper_grouped_conv_*.hpp"
):
all_wrappers.append(wrapper_path.name)
wrapper_includes = "\n".join(f'#include "{w}"' for w in sorted(all_wrappers))
# Generate registration calls
fwd_registrations = "\n ".join(
f"registry.register_kernel(generated::make_{k}(gfx_arch), priority);"
for k in sorted(fwd_kernels)
)
bwd_data_registrations = "\n ".join(
f"registry.register_kernel(generated::make_{k}(gfx_arch), priority);"
for k in sorted(bwd_data_kernels)
)
bwd_weight_registrations = "\n ".join(
f"registry.register_kernel(generated::make_{k}(gfx_arch), priority);"
for k in sorted(bwd_weight_kernels)
)
content = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Auto-generated master registration header for grouped conv kernels
#pragma once
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/grouped_conv_utils.hpp"
{wrapper_includes}
namespace ck_tile {{
namespace dispatcher {{
using Priority = GroupedConvRegistry::Priority;
inline void register_all_grouped_conv_fwd_kernels(
const std::string& gfx_arch = "gfx942",
Priority priority = Priority::Normal)
{{
auto& registry = GroupedConvRegistry::instance();
{fwd_registrations if fwd_registrations else "// No forward kernels"}
}}
inline void register_all_grouped_conv_bwd_data_kernels(
const std::string& gfx_arch = "gfx942",
Priority priority = Priority::Normal)
{{
auto& registry = GroupedConvRegistry::instance();
{bwd_data_registrations if bwd_data_registrations else "// No backward data kernels"}
}}
inline void register_all_grouped_conv_bwd_weight_kernels(
const std::string& gfx_arch = "gfx942",
Priority priority = Priority::Normal)
{{
auto& registry = GroupedConvRegistry::instance();
{bwd_weight_registrations if bwd_weight_registrations else "// No backward weight kernels"}
}}
inline void register_all_grouped_conv_kernels(
const std::string& gfx_arch = "gfx942",
Priority priority = Priority::Normal)
{{
register_all_grouped_conv_fwd_kernels(gfx_arch, priority);
register_all_grouped_conv_bwd_data_kernels(gfx_arch, priority);
register_all_grouped_conv_bwd_weight_kernels(gfx_arch, priority);
}}
inline std::size_t get_grouped_conv_fwd_kernel_count() {{ return {len(fwd_kernels)}; }}
inline std::size_t get_grouped_conv_bwd_data_kernel_count() {{ return {len(bwd_data_kernels)}; }}
inline std::size_t get_grouped_conv_bwd_weight_kernel_count() {{ return {len(bwd_weight_kernels)}; }}
inline std::size_t get_grouped_conv_kernel_count() {{ return {len(fwd_kernels) + len(bwd_data_kernels) + len(bwd_weight_kernels)}; }}
}} // namespace dispatcher
}} // namespace ck_tile
"""
reg_path = self.wrapper_dir / "register_all_grouped_conv_kernels.hpp"
reg_path.write_text(content)
log.info(f"Generated registration header: {reg_path}")
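# Illustrative sketch only: the filename-prefix classification that
# _generate_include_all_headers() applies when it scans the output directory.
# `_sketch_classify_header` is a hypothetical helper, and the filenames in the
# note below are made up for the example.
def _sketch_classify_header(name: str):
    """Return "fwd", "bwd_data", "bwd_weight", or None for an unrelated file."""
    if name.startswith("grouped_conv_fwd_"):
        return "fwd"
    if name.startswith(("grouped_conv_bwd_data_", "grouped_conv_bwdd_")):
        return "bwd_data"
    if name.startswith(("grouped_conv_bwd_weight_", "grouped_conv_bwdw_")):
        return "bwd_weight"
    return None


# "grouped_conv_bwdw_x.hpp" maps to "bwd_weight", while an aggregate header such
# as "include_all_grouped_conv_fwd_kernels.hpp" maps to None because it does not
# start with "grouped_conv_".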
# ============================================================================
# CLI
# ============================================================================
def main():
parser = argparse.ArgumentParser(
description="Unified Grouped Convolution Code Generator"
)
parser.add_argument(
"--output",
"-o",
type=Path,
default=Path("build/generated_kernels"),
help="Output directory",
)
parser.add_argument(
"--datatype",
"-d",
type=str,
nargs="+",
default=["fp16"],
choices=["fp16", "bf16", "fp32"],
help="Data types to generate",
)
parser.add_argument(
"--variant",
"-v",
type=str,
nargs="+",
default=["forward"],
choices=["forward", "bwd_data", "bwd_weight"],
help="Grouped convolution variants",
)
parser.add_argument(
"--ndim",
"-n",
type=int,
nargs="+",
default=[2],
choices=[1, 2, 3],
help="Spatial dimensions",
)
parser.add_argument(
"--arch",
"-a",
type=str,
default="gfx942",
choices=["gfx90a", "gfx942", "gfx950", "gfx1201"],
help="Target GPU architecture",
)
parser.add_argument("--verbose", action="store_true", help="Verbose output")
parser.add_argument(
"--list-configs",
action="store_true",
help="List configurations without generating",
)
# Individual kernel configuration (when not using predefined configs)
parser.add_argument("--tile-m", type=int, help="Block tile M dimension")
parser.add_argument("--tile-n", type=int, help="Block tile N dimension")
parser.add_argument("--tile-k", type=int, help="Block tile K dimension")
parser.add_argument("--warp-m", type=int, help="Wave distribution M")
parser.add_argument("--warp-n", type=int, help="Wave distribution N")
parser.add_argument("--warp-k", type=int, default=1, help="Wave distribution K")
parser.add_argument("--warp-tile-m", type=int, help="Warp tile M")
parser.add_argument("--warp-tile-n", type=int, help="Warp tile N")
parser.add_argument("--warp-tile-k", type=int, default=16, help="Warp tile K")
parser.add_argument(
"--pipeline",
type=str,
choices=[
"basic_v1",
"basic_async_v1",
"mem",
"compv3",
"compv4",
"compv5",
"compv6",
"comp_async",
],
help="Pipeline type",
)
parser.add_argument(
"--scheduler",
type=str,
choices=["intrawave", "interwave"],
help="Scheduler type",
)
parser.add_argument(
"--epilogue",
type=str,
default="cshuffle",
choices=["cshuffle", "default"],
help="Epilogue type",
)
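# argparse applies type=bool to the raw string, so any non-empty value (even
# "false") would parse as True; parse the text explicitly instead.
_parse_bool = lambda s: s.strip().lower() in ("1", "true", "yes", "on")
parser.add_argument("--pad-m", type=_parse_bool, default=True, help="Pad M dimension (true/false)")
parser.add_argument("--pad-n", type=_parse_bool, default=True, help="Pad N dimension (true/false)")
parser.add_argument("--pad-k", type=_parse_bool, default=True, help="Pad K dimension (true/false)")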
parser.add_argument("--vector-a", type=int, default=4, help="Vector size A")
parser.add_argument("--vector-b", type=int, default=8, help="Vector size B")
parser.add_argument("--vector-c", type=int, default=8, help="Vector size C")
parser.add_argument("--block-per-cu", type=int, default=1, help="Blocks per CU")
parser.add_argument("--num-wave-groups", type=int, default=1, help="Wave groups")
parser.add_argument(
"--num-groups-to-merge", type=int, default=1, help="Groups to merge"
)
parser.add_argument(
"--double-smem-buffer",
type=str,
default=None,
help="Double SMEM buffer (true/false)",
)
parser.add_argument(
"--split-image",
action="store_true",
help="Enable split-image (EnableSplitImage) for large spatial tensors",
)
parser.add_argument(
"--two-stage",
action="store_true",
help="Enable two-stage bwd_weight (fp32 workspace + elementwise convert)",
)
args = parser.parse_args()
if args.verbose:
logging.getLogger().setLevel(logging.DEBUG)
# Map variant strings to enums
variant_map = {
"forward": GroupedConvVariant.FORWARD,
"bwd_data": GroupedConvVariant.BACKWARD_DATA,
"bwd_weight": GroupedConvVariant.BACKWARD_WEIGHT,
}
requested_variants = [variant_map[v] for v in args.variant]
# Check if user specified custom configuration
custom_config = (
args.tile_m is not None or args.tile_n is not None or args.pipeline is not None
)
if custom_config:
# Build custom config from CLI arguments
tile = TileConfig(
tile_m=args.tile_m or 128,
tile_n=args.tile_n or 128,
tile_k=args.tile_k or 64,
warp_m=args.warp_m or 2,
warp_n=args.warp_n or 2,
warp_k=args.warp_k or 1,
warp_tile_m=args.warp_tile_m or 32,
warp_tile_n=args.warp_tile_n or 32,
warp_tile_k=args.warp_tile_k or 16,
)
pipeline = args.pipeline or "compv4"
# Determine double_smem_buffer: use CLI arg if given, else default based on pipeline
if args.double_smem_buffer is not None:
dsb = args.double_smem_buffer.lower() == "true"
else:
# Historical default: only compv4 auto-defaults to dsb=true.
# Other pipelines that also require DoubleSmemBuffer (e.g. comp_async)
# must be told explicitly via --double-smem-buffer true; otherwise
# they will fail loudly at the pipeline header static_assert. This
# is intentional -- silent fallback to a different config would
# mask the user's input.
dsb = pipeline == "compv4"
trait = GroupedConvTraitConfig(
pipeline=pipeline,
scheduler=args.scheduler or "intrawave",
epilogue=args.epilogue or "cshuffle",
pad_m=args.pad_m,
pad_n=args.pad_n,
pad_k=args.pad_k,
double_smem_buffer=dsb,
num_groups_to_merge=args.num_groups_to_merge,
split_image=args.split_image,
two_stage=args.two_stage,
)
config = GroupedConvKernelConfig(
tile=tile,
trait=trait,
variant=requested_variants[0]
if requested_variants
else GroupedConvVariant.FORWARD,
ndim_spatial=args.ndim[0] if args.ndim else 2,
arch=args.arch,
vector_size_a=args.vector_a,
vector_size_b=args.vector_b,
vector_size_c=args.vector_c,
block_per_cu=args.block_per_cu,
num_wave_groups=args.num_wave_groups,
)
filtered_configs = [config]
else:
# Get predefined configurations for target arch with requested variants and ndims
filtered_configs = get_default_configs(
arch=args.arch, variants=requested_variants, ndims=args.ndim
)
if args.list_configs:
print(f"Grouped convolution configurations for {args.arch}:")
print(f" Datatypes: {args.datatype}")
print(f" Variants: {args.variant}")
print(f" Spatial dims: {args.ndim}")
print(f"\nConfigurations ({len(filtered_configs)}):")
for cfg in filtered_configs:
# List configs for each requested datatype (fixes bf16 -> fp16 bug)
for dt in args.datatype:
print(f" - {cfg.name(dt)}")
print(f" Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}")
print(f" Warp: {cfg.tile.warp_m}x{cfg.tile.warp_n}x{cfg.tile.warp_k}")
print(
f" WarpTile: {cfg.tile.warp_tile_m}x{cfg.tile.warp_tile_n}x{cfg.tile.warp_tile_k}"
)
print(
f" Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}, Scheduler: {cfg.trait.scheduler}"
)
print(
f" Padding: M={cfg.trait.pad_m}, N={cfg.trait.pad_n}, K={cfg.trait.pad_k}"
)
return
# Generate
codegen = UnifiedGroupedConvCodegen(
output_dir=args.output,
gpu_target=args.arch,
enable_arch_filter=True,
)
results = codegen.generate_all(
configs=filtered_configs, datatypes=args.datatype, parallel=True
)
print(
f"\nGenerated {len(results['kernels'])} grouped convolution kernel files "
f"for {args.arch} in {args.output}"
)
if results["failed"]:
print(f" Failed: {len(results['failed'])}")
for err in results["failed"][:5]:
print(f" - {err}")
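# Illustrative sketch only: how main() resolves DoubleSmemBuffer for a custom
# config. An explicit --double-smem-buffer value wins; otherwise only compv4
# defaults to True. `_sketch_resolve_dsb` is a hypothetical helper that restates
# that rule and is not called by the CLI.
def _sketch_resolve_dsb(cli_value, pipeline: str) -> bool:
    if cli_value is not None:
        return cli_value.lower() == "true"
    return pipeline == "compv4"


# With no flag, "comp_async" resolves to False, so a pipeline that needs the
# double buffer has to be given --double-smem-buffer true explicitly.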
if __name__ == "__main__":
main()