[CK_BUILDER] Update developer notes in the CK Builder source directories (#5038) ## Motivation This PR updates the developer notes for the CK Tile builder. It captures the current state of the implementation in more detail and frames the description around the need for a true facade. There is no functional change, only better alignment of the developer notes with the current code. The doc now clearly explains the current technical debt: we have created many facades that expose implementation details. An expanded section on reflection explains how unified reflection will help clarify the unified builder design. The remaining changes simply account better for the current state of the code, including previously undocumented operations. A few typos and cosmetic issues are cleaned up, too.
Composable Kernel Builder Design Documentation
This directory contains the builder framework for Composable Kernel, which provides a compile-time, type-safe interface for constructing convolution operations with various configurations.
Table of Contents
- Convolution Signature
- Convolution Algorithm
- Convolution Factory
Overview
The convolution signature system provides a compile-time description of grouped convolution operations. A signature is a collection of properties that fully characterize a convolution kernel's mathematical and operational behavior, enabling:
- Compile-time validation: Ensures type safety and correctness before kernel instantiation
- Kernel selection: Matches user requirements to optimized implementations
- Specialization: Enables optimized code paths for specific configurations
- Composability: Supports building complex operations from simpler components
The signature leverages modern C++20 features, particularly concepts, to provide expressive, self-documenting interfaces with compile-time guarantees.
Architecture
The signature system is organized into a hierarchical structure:
```
┌───────────────────────────────────────────────────────┐
│                     ConvSignature                     │
├───────────────────────────────────────────────────────┤
│ Properties:                                           │
│  • spatial_dim: int (1D, 2D, or 3D)                   │
│  • direction: ConvDirection (Fwd/BwdData/BwdWeight)   │
│  • data_type: DataType (default data type)            │
│  • accumulation_data_type: DataType                   │
│  • input:  ConvTensor ──┐                             │
│  • weight: ConvTensor ──┤                             │
│  • output: ConvTensor ──┤                             │
└─────────────────────────┼─────────────────────────────┘
                          │
                          ▼
        ┌─────────────────────────────────────────┐
        │               ConvTensor                │
        ├─────────────────────────────────────────┤
        │ ╔═════════════════════════════════════╗ │
        │ ║ TensorConfig (required)             ║ │
        │ ╠═════════════════════════════════════╣ │
        │ ║ • layout: TensorLayout              ║ │
        │ ║ • data_type: DataType (optional)    ║ │
        │ ╚═════════════════════════════════════╝ │
        │                                         │
        │ ┌─────────────────────────────────────┐ │
        │ ├─────────────────────────────────────┤ │
        │ │ TensorOperation (optional)          │ │
        │ │ • elementwise_operation             │ │
        │ │ • auxiliary_operand_configs[]       │ │
        │ │   (each is also ConvTensor) ◄───────┼─┼──┐
        │ └─────────────────────────────────────┘ │  │
        └─────────────────────────────────────────┘  │
                                                     │
                              Recursive ─────────────┘
```
Key Design Points:
- ConvSignature contains three ConvTensor instances (input, weight, output)
- All tensors share the same ConvTensor structure
- Each ConvTensor has:
- TensorConfig (required): Defines layout as well as optional data and compute type overrides
- TensorOperation (optional): Defines fused elementwise operations
- Auxiliary operands (e.g., bias) in TensorOperation also use the ConvTensor type
Core Components
1. Signature Level
The top-level signature contains global properties that apply to the entire convolution operation:
```cpp
template <typename T>
concept ConvSignatureDescriptor = requires(T t) {
    { t.spatial_dim } -> std::convertible_to<unsigned int>; // 1, 2, or 3
    { t.input } -> ConvTensorDescriptor;
    { t.weight } -> ConvTensorDescriptor;
    { t.output } -> ConvTensorDescriptor;
    requires ConvolutionDirectionWellDefinedIfProvided<T>;  // Optional direction
    requires detail::DataTypeWellDefinedIfProvided<T>;      // Optional default data type
    requires detail::ElementwiseOpWellDefinedIfProvided<T>; // Optional default elementwise operation
};
```
Properties:
- `spatial_dim`: Dimensionality of the convolution (1D, 2D, or 3D)
- `direction`: Operation type (optional, defaults to `FORWARD`)
  - `FORWARD`: Standard forward convolution
  - `BACKWARD_DATA`: Gradient computation w.r.t. input
  - `BACKWARD_WEIGHT`: Gradient computation w.r.t. weights
- `data_type`: Default data type for all tensors (`FP32`, `FP16`, `BF16`, `FP8`, `I8`, `U8`). Optional, defaults to `UNDEFINED_DATA_TYPE`, which indicates the type should be inferred or specified per tensor; may be overridden by individual tensors
- `elementwise_operation`: Default elementwise operation for all tensors (optional, defaults to `PASS_THROUGH`; may be overridden by individual tensors via their `operation` field)
- `accumulation_data_type`: Type used for internal accumulation
2. Tensor Level
Each tensor (input, weight, output) has its own descriptor:
```cpp
template <typename T>
concept ConvTensorDescriptor = requires(T t) {
    { t.config } -> TensorConfigDescriptor;
    requires ElementwiseOpWellDefinedIfProvided<T>;
};
```
A tensor descriptor encapsulates:
- Configuration: Layout and data type information
- `operation`: Fused elementwise operations on this tensor (optional; default provided by `ConvSignatureDescriptor`)
3. Tensor Configuration
Describes the memory layout and data types:
```cpp
template <typename T>
concept TensorConfigDescriptor = requires(T t) {
    { t.layout } -> std::convertible_to<TensorLayout>;
    // Override data type (optional; default provided by ConvSignatureDescriptor)
    requires detail::DataTypeWellDefinedIfProvided<T>;
};
```
Layout Types (dimension-specific):
- Special Values:
  - `UNDEFINED_TENSOR_LAYOUT`: Placeholder value indicating layout is not yet specified or should be inferred
- 1D Convolution:
  - Input: `GNCW`, `GNWC`, `NWGC`, `NGCW`, `G_NW_C_strided`
  - Weight: `GKXC`, `GKCX`, `KXGC`, `G_K_X_C_strided`
  - Output: `GNKW`, `GNWK`, `NWGK`, `NGKW`, `G_NW_K_strided`
- 2D Convolution:
  - Input: `GNCHW`, `GNHWC`, `NHWGC`, `NGCHW`, `G_NHW_C_strided`
  - Weight: `GKYXC`, `GKCYX`, `KYXGC`, `G_K_YX_C_strided`
  - Output: `GNKHW`, `GNHWK`, `NHWGK`, `NGKHW`, `G_NHW_K_strided`
- 3D Convolution:
  - Input: `GNCDHW`, `GNDHWC`, `NDHWGC`, `NGCDHW`, `G_NDHW_C_strided`
  - Weight: `GKZYXC`, `GKCZYX`, `KZYXGC`, `G_K_ZYX_C_strided`
  - Output: `GNKDHW`, `GNDHWK`, `NDHWGK`, `NGKDHW`, `G_NDHW_K_strided`
- Bias Tensors: `GC`, `G_C_strided`, `G_K_strided`

Where:
- `G` = Groups
- `N` = Batch size
- `C` = Input channels
- `K` = Output channels (filters)
- `W`, `H`, `D` = Width, Height, Depth (spatial dimensions)
- `X`, `Y`, `Z` = Filter dimensions
4. Tensor Operations
Describes fused elementwise operations applied to a tensor:
```cpp
template <typename T>
concept TensorOperatorDescriptor = requires(T t) {
    { t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
    requires AuxiliaryOperandConfigsWellDefinedIfProvided<T>;
};
```
Supported Operations:
The `ElementwiseOperation` enum in `types.hpp` defines 35 operations:
- Identity: `PASS_THROUGH`
- Scaling and arithmetic: `SCALE`, `SCALE_ADD`, `CLAMP`, `ADD_CLAMP`, `BILINEAR`
- Convolution-specific scaling: `CONV_SCALE`, `CONV_SCALE_ADD`, `CONV_SCALE_RELU`, `CONV_INVSCALE`
- Activations: `RELU`, `LEAKY_RELU`, `CLIPPED_RELU`, `SOFT_RELU`, `GELU`, `SILU`, `SIGMOID`, `TANH`, `ELU`, `SWISH`, `LOGISTIC`, `POWER`, `UNARY_ABS`
- Composite fused operations: `BIAS_BNORM_CLAMP`, `SCALEADD_SCALEADD_RELU`, `ADD_RELU_ADD`, `ACTIVATION_MUL_CLAMP`, `ACTIVATION_MUL2_CLAMP`, `ADD_ACTIVATION_MUL_CLAMP`, `ADD_ACTIVATION_MUL2_CLAMP`, `ADD_MUL_ACTIVATION_MUL_CLAMP`, `ADD_MUL2_ACTIVATION_MUL_CLAMP`
- Dynamic and generic: `DYNAMIC_UNARY_OP`, `UNARY_COMBINED_OP`, `UNARY_CONVERT`
Auxiliary Operands:
Some operations require additional tensor inputs (e.g., bias tensors, scaling factors). These are specified through `auxiliary_operand_configs`, an array of `TensorConfigDescriptor` objects describing the layout and data type of each auxiliary input.
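For example, a fused bias add can be described by an operation that carries one auxiliary operand holding the bias tensor's layout and data type. The sketch below uses reduced stand-in enums and a hypothetical `BiasAddOperation` struct; the field names follow the concept above, while the fixed-size `std::array` is an assumption about how the configs are stored:

```cpp
#include <array>

// Reduced stand-ins for the CK enums (illustrative subsets only).
enum class ElementwiseOperation { PASS_THROUGH, SCALE_ADD, ADD_RELU_ADD };
enum class TensorLayout { NHWGK, G_K_strided };
enum class DataType { UNDEFINED_DATA_TYPE, FP32, FP16 };

struct TensorConfig {
    TensorLayout layout;
    DataType data_type;
};

// Hypothetical output-tensor operation: a fused add with one auxiliary
// operand, a per-output-channel bias stored with a G_K_strided layout.
struct BiasAddOperation {
    ElementwiseOperation elementwise_operation = ElementwiseOperation::SCALE_ADD;
    std::array<TensorConfig, 1> auxiliary_operand_configs{
        TensorConfig{TensorLayout::G_K_strided, DataType::FP32}};
};

// The auxiliary operand is visible at compile time, so a factory can
// validate its layout before instantiating a kernel.
static_assert(BiasAddOperation{}.auxiliary_operand_configs.size() == 1);
```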
Concepts and Validation
The signature system uses C++20 concepts for compile-time validation at multiple levels:
Constraint Concepts
```cpp
// Spatial dimension must be 1, 2, or 3
template <auto N>
concept ConvSpatialDim = std::is_integral_v<decltype(N)> && (N == 1 || N == 2 || N == 3);

// Valid data types for convolution
template <DataType T>
concept ValidConvDataType =
    (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) ||
    (T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8);
```
Validation Concept
```cpp
// Validates a complete signature
template <auto Sig>
concept ValidConvSignature = requires {
    requires ConvSpatialDim<Sig.spatial_dim>;
    requires ValidConvDataType<Sig.data_type>;
};
```
Tensor Descriptors
The layout/data type/elementwise operation are described per tensor. This multi-level hierarchy allows:
- Flexibility: Each tensor can have independent layout and data type
- Reusability: Common configurations can be shared across different signatures
- Extensibility: New properties can be added to specific levels without affecting others
- Clarity: Separates concerns (global properties vs. tensor-specific properties)
Optional Signature Fields
Several fields in the signature are optional:
- `direction`: Defaults to `FORWARD` if not specified, reducing boilerplate for the common case
- Tensor `data_type`: Falls back to the signature's default, allowing mixed precision with minimal specification
- Tensor `operation`: Defaults to `PASS_THROUGH`, supporting both fused and non-fused operations with the same interface
This design follows the principle of "make the common case simple, the complex case possible."
Convolution Algorithm
The algorithm descriptor specifies how a convolution is computed — the implementation strategy including tile sizes, hardware instruction variant, pipeline scheduling, and memory access patterns. It is the complement to the signature, which specifies what is computed.
Algorithm Descriptor Concept
An algorithm descriptor is any struct satisfying the ConvAlgorithmDescriptor concept (conv_algorithm_concepts.hpp). The required fields depend on the target kernel variant. The dispatcher (conv_dispatcher.hpp) uses predicate concepts to classify each algorithm descriptor into one of the supported variants:
- `ReferenceAlgorithm`: Requires only a `specialization` field set to `REFERENCE`. Used for correctness validation.
- `TileAlgorithm`: CK Tile backend. Requires tile-level configuration: block shape, warp tile, block GEMM pipeline, transfer vectorization, and optimizations.
- Forward-specific (old CK): XDL V3, XDL, WMMA, DL, Large Tensor. Each requires progressively different fields (thread block, GEMM config, transfer, scheduling, prefetch stages).
- Backward weight-specific (old CK): XDL, XDL V3, Two-Stage XDL, DL, Multi-D XDL, WMMA V3, Two-Stage WMMA V3, WMMA, Multi-D WMMA V3.
The `ConvAlgorithmSpecialization` enum provides broad algorithm classes (`REFERENCE`, `LARGE_TENSOR`, `TWO_STAGE`, `MULTIPLE_D`) for requesting a category of algorithm without specifying the full descriptor.
Algorithm Descriptor Fragmentation
The builder currently requires a different algorithm descriptor shape for each kernel variant. This fragmentation exists along three axes:
1. Backend (CK vs CK Tile): The old CK backend flattens ~49 template parameters into a single device operation type (explicit thread block dimensions, block transfer descriptors with LDS configurations, thread cluster arrangements, per-tensor access orders). The CK Tile backend composes higher-level objects (tile partitioner, GEMM pipeline, epilogue pipeline) with ~31 parameters distributed across four composed types.
2. Instruction set (MFMA vs WMMA): Within the old CK backend, XDL (MFMA) and WMMA variants require different algorithm descriptor fields. The dispatcher uses separate predicate concepts (`FwdXdlAlgorithm` vs `FwdWmmaAlgorithm`) to classify them, and separate factories to instantiate them.
3. Direction (forward vs backward weight vs backward data): Each direction has its own set of factories. Backward weight alone has 9 old CK factory variants (XDL, XDL V3, two-stage XDL, DL, multi-D XDL, WMMA V3, two-stage WMMA V3, WMMA, multi-D WMMA V3).
The result is 16+ per-variant factories, each accepting a different algorithm descriptor shape. MIOpen must currently know which variant it wants and construct the matching descriptor — the builder dispatches but does not unify.
The path toward a single algorithm descriptor runs through the reflection system. ConvTraits already provides a common representation for old CK instances across both MFMA and WMMA. Extending this to CK Tile instances will reveal which parameters are genuinely variant-specific versus which can be expressed in a single descriptor and mapped to multiple backends by the dispatcher. See the reflection documentation for the current state of this bridge.
Convolution Factory
The factory system translates a (signature, algorithm) pair into a concrete kernel instance. The entry point is `make_conv_instance<SIGNATURE, ALGORITHM, VERSION>()` in conv_dispatcher.hpp.
The dispatch proceeds in two phases:
1. Algorithm classification: Predicate concepts (`ReferenceAlgorithm`, `TileAlgorithm`, `FwdXdlV3Algorithm`, etc.) inspect the algorithm descriptor's structure to determine which kernel variant it satisfies.
2. Direction routing: An `if constexpr` chain routes to the appropriate factory based on the convolution direction (forward, backward data, backward weight) and the classified algorithm type.
Each factory (e.g., `ConvFwdXdlV3Factory`, `ConvBwdWeightWmmaV3Factory`) transforms builder descriptors into the underlying device operation's template parameters and instantiates the kernel.
The factory design is described in detail in the factory README.