* Add placeholder test.
* Initial conv bwd weight factory.
* Conv builder test refactoring.
* Add missing pieces to bwd weight factory.
* Improve compile time error message when no matching factory is found.
* Use a macro to ensure automatic matching between concepts and their string representations.
* Improve compile time diagnostics.
* Small improvements.
* Improve missing member/wrong type compile-time errors.
* Improve compile time diagnostics.
* Concept bug fixes.
* Remove debug assert.
* Update algorithm signature diagnostics.
* Factory bug fixes.
* First functional version of bwd weight conv factory.
* Refactor handling of GEMM-K batch template parameter in conv bwd weight factory.
* Concept improvements.
* Improve concept diagnostics.
* Introduce a common size type for concepts.
* Update compiletime diagnostics to use the size type.
* Update conv specialization enum.
* Fix fwd conv builder tests.
* Fix smoke tests.
* Split bwd weight and bwd data tests into separate targets.
* Clean-up CK Tile builder tests.
* Add bwd weight XDL CShuffle V3 factory.
* Build conv bwd weight v3 instances successfully.
* Add instance traits for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3.
* Test fix.
* Add instance traits for bwd weight algorithms.
* Add unit tests for instance strings.
* Build new instance traits unit tests but exclude WMMA for now.
* Added factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle.
* Conv bwd weight DL factory.
* Final implementation for bwd weight DL factory.
* Add test for creating DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle instance.
* Add factory for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
* Treat ref algorithm the same way as real algorithms in the dispatcher.
* Refactor large tensor support and WMMA configuration.
* Add factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffleV3.
* Update Readme.
* Fix WMMA bwd weight tests.
* Added factory and tests for DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3.
* Factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffle.
* Dispatching for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle.
* Add factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
* Fix DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 factory and compute types for input and output tensors in bwd weight convs.
* Fix fwd factories after refactoring.
* clang-format
* Move compile-time diagnostics to a separate branch.
* Fix ref algorithm dispatching.
* Fix smoke tests.
* clang-format
* Fix factory for regular WMMA conv bwd weight.
* Clarify builder Readme.
* Remove obsolete test file.
* Fix test after merge.
* clang-format
* Remove the C++26 extensions.
* Unify conv elementwise ops and layout definitions for fwd and bwd directions.
* Remove old layout and elementwise ops.
* Unify handling of conv tensor types between fwd and bwd directions.
* Unify block transfer for fwd and bwd directions. Rename ThreadSliceDim to ThreadClusterRank.
* Make BlockTransferDescriptor concept parametrized. Introduce a common TileTransferParameters concept for conv algorithms.
* clang-format
---------
Co-authored-by: Ville Pietilä <>
[ROCm/composable_kernel commit: 9908a87c31]
Composable Kernel Builder Design Documentation
This directory contains the builder framework for Composable Kernel, which provides a compile-time, type-safe interface for constructing convolution operations with various configurations.
Convolution Signature
Overview
The convolution signature system provides a compile-time description of grouped convolution operations. A signature is a collection of properties that fully characterize a convolution kernel's mathematical and operational behavior, enabling:
- Compile-time validation: Ensures type safety and correctness before kernel instantiation
- Kernel selection: Matches user requirements to optimized implementations
- Specialization: Enables optimized code paths for specific configurations
- Composability: Supports building complex operations from simpler components
The signature leverages modern C++20 features, particularly concepts, to provide expressive, self-documenting interfaces with compile-time guarantees.
Architecture
The signature system is organized into a hierarchical structure:
┌─────────────────────────────────────────────────────────┐
│ ConvSignature                                           │
├─────────────────────────────────────────────────────────┤
│ Properties:                                             │
│  • spatial_dim: int (1D, 2D, or 3D)                     │
│  • direction: ConvDirection (Fwd/BwdData/BwdWeight)     │
│  • data_type: DataType (default data type)              │
│  • accumulation_data_type: DataType                     │
│  • input:  ConvTensor ──┐                               │
│  • weight: ConvTensor ──┤                               │
│  • output: ConvTensor ──┤                               │
└─────────────────────────┼───────────────────────────────┘
                          │
                          ▼
     ┌─────────────────────────────────────────┐
     │ ConvTensor                              │
     ├─────────────────────────────────────────┤
     │ ╔═════════════════════════════════════╗ │
     │ ║ TensorConfig (required)             ║ │
     │ ╠═════════════════════════════════════╣ │
     │ ║ • layout: ConvLayout                ║ │
     │ ║ • data_type: DataType (optional)    ║ │
     │ ║ • compute_type: DataType (optional) ║ │
     │ ╚═════════════════════════════════════╝ │
     │                                         │
     │ ┌─────────────────────────────────────┐ │
     │ │ TensorOperation (optional)          │ │
     │ ├─────────────────────────────────────┤ │
     │ │ • elementwise_operation             │ │
     │ │ • auxiliary_operand_configs[]       │ │
     │ │   (each is also ConvTensor) ◄─────────┼─┐
     │ └─────────────────────────────────────┘ │ │
     └─────────────────────────────────────────┘ │
                                                 │
                              Recursive ─────────┘
Key Design Points:
- ConvSignature contains three ConvTensor instances (input, weight, output)
- All tensors share the same ConvTensor structure
- Each ConvTensor has:
  - TensorConfig (required): Defines the layout as well as optional data type and compute type overrides
  - TensorOperation (optional): Defines fused elementwise operations
- Auxiliary operands (e.g., bias) in TensorOperation also use the ConvTensor type
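Here is a minimal C++20 sketch that writes the hierarchy out as plain aggregates. The enums are trimmed, hypothetical stand-ins for the builder's types. Modeling ConvTensor as a template with a `NoOperation` default is an assumption, used here to express the recursion (each auxiliary operand is again a ConvTensor). The operand count and layouts in the example signature are chosen purely for illustration.

```cpp
#include <array>
#include <cstddef>

// Hypothetical, trimmed-down stand-ins for the builder's enums (illustration only).
enum class DataType { UNDEFINED_DATA_TYPE, FP32, FP16, BF16, FP8, I8, U8 };
enum class ConvDirection { FORWARD, BACKWARD_DATA, BACKWARD_WEIGHT };
enum class ConvLayout { NHWGC, KYXGC, NHWGK, G_NHW_K_strided };
enum class ElementwiseOperation { PASS_THROUGH, SCALE, CLAMP, SCALEADD_SCALEADD_RELU };

// TensorConfig: required layout plus optional data/compute type overrides.
// In this sketch UNDEFINED_DATA_TYPE means "not overridden".
struct TensorConfig {
    ConvLayout layout;
    DataType data_type = DataType::UNDEFINED_DATA_TYPE;
    DataType compute_type = DataType::UNDEFINED_DATA_TYPE;
};

// Marker for "no fused operation" (the optional TensorOperation is absent).
struct NoOperation {};

// ConvTensor: a required TensorConfig and an optional TensorOperation.
template <typename Op = NoOperation>
struct ConvTensor {
    TensorConfig config;
    Op operation{};
};

// TensorOperation: fused elementwise op whose auxiliary operands are ConvTensors again.
template <std::size_t NumAux>
struct TensorOperation {
    ElementwiseOperation elementwise_operation;
    std::array<ConvTensor<>, NumAux> auxiliary_operand_configs;
};

// ConvSignature: global properties plus the three tensors.
template <typename In, typename Wei, typename Out>
struct ConvSignature {
    int spatial_dim;
    ConvDirection direction;
    DataType data_type;
    DataType accumulation_data_type;
    In input;
    Wei weight;
    Out output;
};

// 2D forward FP16 convolution; the output fuses an op with two auxiliary tensors
// (operand count and layouts chosen for illustration).
inline constexpr ConvSignature<ConvTensor<>, ConvTensor<>, ConvTensor<TensorOperation<2>>> kSig{
    .spatial_dim = 2,
    .direction = ConvDirection::FORWARD,
    .data_type = DataType::FP16,
    .accumulation_data_type = DataType::FP32,
    .input = {.config = {.layout = ConvLayout::NHWGC}},
    .weight = {.config = {.layout = ConvLayout::KYXGC}},
    .output = {.config = {.layout = ConvLayout::NHWGK},
               .operation = {.elementwise_operation = ElementwiseOperation::SCALEADD_SCALEADD_RELU,
                             .auxiliary_operand_configs = {{{.config = {.layout = ConvLayout::G_NHW_K_strided}},
                                                            {.config = {.layout = ConvLayout::G_NHW_K_strided}}}}}},
};
```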
Core Components
1. Signature Level
The top-level signature contains global properties that apply to the entire convolution operation:
template <typename T>
concept ConvSignatureDescriptor = requires(T t) {
{ t.spatial_dim } -> std::convertible_to<unsigned int>; // 1, 2, or 3
{ t.input } -> ConvTensorDescriptor;
{ t.weight } -> ConvTensorDescriptor;
{ t.output } -> ConvTensorDescriptor;
requires ConvolutionDirectionWellDefinedIfProvided<T>; // Optional direction
requires detail::DataTypeWellDefinedIfProvided<T>; // Optional default data type
requires detail::ElementwiseOpWellDefinedIfProvided<T>; // Optional default elementwise operation
};
Properties:
- spatial_dim: Dimensionality of the convolution (1D, 2D, or 3D)
- direction: Operation type (optional, defaults to FORWARD)
  - FORWARD: Standard forward convolution
  - BACKWARD_DATA: Gradient computation w.r.t. input
  - BACKWARD_WEIGHT: Gradient computation w.r.t. weights
- data_type: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8). Optional; defaults to UNDEFINED_DATA_TYPE and may be overridden by individual tensors.
- operation: Default elementwise operation. Optional; defaults to PASS_THROUGH and may be overridden by individual tensors.
- accumulation_data_type: Type used for internal accumulation
2. Tensor Level
Each tensor (input, weight, output) has its own descriptor:
template <typename T>
concept ConvTensorDescriptor = requires(T t) {
{ t.config } -> TensorConfigDescriptor;
requires ElementwiseOpWellDefinedIfProvided<T>;
};
A tensor descriptor encapsulates:
- Configuration: Layout and data type information
- Operation: Fused elementwise operation applied to this tensor (optional; defaults to the operation provided by the ConvSignatureDescriptor)
3. Tensor Configuration
Describes the memory layout and data types:
template <typename T>
concept TensorConfigDescriptor = requires(T t) {
{ t.layout } -> std::convertible_to<ConvLayout>;
requires detail::DataTypeWellDefinedIfProvided<T>; // Override data type (Optional, default provided by ConvSignatureDescriptor)
};
Layout Types (dimension-specific):
- 1D Convolution:
  - Input: GNCW, GNWC, NWGC, NGCW, G_NW_C_strided
  - Weight: GKXC, GKCX, KXGC, G_K_X_C_strided
  - Output: GNKW, GNWK, NWGK, NGKW, G_NW_K_strided
- 2D Convolution:
  - Input: GNCHW, GNHWC, NHWGC, NGCHW, G_NHW_C_strided
  - Weight: GKYXC, GKCYX, KYXGC, G_K_YX_C_strided
  - Output: GNKHW, GNHWK, NHWGK, NGKHW, G_NHW_K_strided
- 3D Convolution:
  - Input: GNCDHW, GNDHWC, NDHWGC, NGCDHW, G_NDHW_C_strided
  - Weight: GKZYXC, GKCZYX, KZYXGC, G_K_ZYX_C_strided
  - Output: GNKDHW, GNDHWK, NDHWGK, NGKDHW, G_NDHW_K_strided
Where:
- G = Groups
- N = Batch size
- C = Input channels
- K = Output channels (filters)
- W, H, D = Width, Height, Depth (spatial dimensions)
- X, Y, Z = Filter dimensions (filter width, height, and depth, respectively)
4. Tensor Operations
Describes fused elementwise operations applied to a tensor:
template <typename T>
concept TensorOperatorDescriptor = requires(T t) {
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
requires AuxiliaryOperandConfigsWellDefinedIfProvided<T>;
};
Supported Operations:
- PASS_THROUGH: No operation (identity)
- SCALE: Multiply by a scalar
- CLAMP: Clamp values to a range
- BIAS_BNORM_CLAMP: Bias addition + batch normalization + clamp
- SCALEADD_SCALEADD_RELU: Fused scale-add operations + ReLU activation
Auxiliary Operands:
Some operations require additional tensor inputs (e.g., bias tensors, scaling factors). These are specified through auxiliary_operand_configs, an array whose elements each satisfy TensorConfigDescriptor and describe the layout and data type of one auxiliary input.
Concepts and Validation
The signature system uses C++20 concepts for compile-time validation at multiple levels:
Constraint Concepts
// Spatial dimension must be 1, 2, or 3
template <auto N>
concept ConvSpatialDim = std::is_integral_v<decltype(N)> && (N == 1 || N == 2 || N == 3);
// Valid data types for convolution
template <DataType T>
concept ValidConvDataType =
(T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) ||
(T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8);
Validation Concept
// Validates a complete signature
template <auto Sig>
concept ValidConvSignature = requires {
requires ConvSpatialDim<Sig.spatial_dim>;
requires ValidConvDataType<Sig.data_type>;
};
Tensor Descriptors
The layout/data type/elementwise operation are described per tensor. This multi-level hierarchy allows:
- Flexibility: Each tensor can have independent layout and data type
- Reusability: Common configurations can be shared across different signatures
- Extensibility: New properties can be added to specific levels without affecting others
- Clarity: Separates concerns (global properties vs. tensor-specific properties)
Optional Signature Fields
Several fields in the signature are optional:
direction: Defaults toFORWARDif not specified, reducing boilerplate for the common case- Tensor
data_type: Falls back to signature's default, allowing mixed-precision with minimal specification - Tensor
operation: Defaults toPASS_THROUGH, supporting both fused and non-fused operations with the same interface
This design follows the principle of "make the common case simple, the complex case possible."
Convolution Algorithm
Convolution Factory
The convolution factory builds an instance from the convolution signature and the convolution algorithm. Both descriptions are dispatched to the relevant algorithm-specific factory, which creates the instance. The convolution factory design is described in a separate Readme.