mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[rocm-libraries] ROCm/rocm-libraries#5038 (commit 6e74de7)
[CK_BUILDER] Update developer notes in the CK Builder source directories (#5038) ## Motivation This PR updates the developer notes for the CK Tile builder. It captures the current state of the implementation in more detail, and frames the description around the need to have true facade. There is no functional change, only better alignment of developer notes with the current code. This doc clearly explains the current technical debt: that we have created many facades that expose the implementation details. There is an expanded section on reflection that explains how unified reflection will help clarify the unified builder design. Additional changes are just better accounting for the current state of the code, including previously undocumented operations. A few typos and cosmetic issues are cleaned up, too.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
f1746955fd
commit
54861f1f49
@@ -10,11 +10,26 @@ The builder provides a high-level, semantically-clear interface for constructing
|
||||
|
||||
This project is a prototype for a more general builder pattern for all of composable_kernel (CK) and CK Tile, but is currently limited to formalizing the interface between MIOpen and CK.
|
||||
|
||||
## Design Direction
|
||||
|
||||
The builder's primary goal is transparent dispatch across two backend implementations: old CK (template-heavy device operations) and CK Tile (modern tile-based API). MIOpen, the consumer library, should construct kernels through the builder without needing to know which backend provides the implementation.
|
||||
|
||||
**Current state:** The builder dispatches correctly, but each kernel variant (forward XDL, forward WMMA, backward weight XDL V3, etc.) has its own factory and its own algorithm descriptor shape. The result is 16+ per-variant facades rather than one unified facade. Unification across three axes — CK vs CK Tile backend, MFMA vs WMMA instruction set, and forward vs backward direction — is the central design challenge.
|
||||
|
||||
Three principles guide the design toward that unification:
|
||||
|
||||
1. **Unified vocabulary through reflection.** The reflection system (`reflect/`) extracts kernel traits from both backends into a common `ConvTraits` representation. This shared vocabulary is the mechanism for discovering what algorithm parameters are truly variant-specific versus what can be expressed once and mapped to multiple backends.
|
||||
|
||||
2. **Expert overrides.** Power users can pin to a specific backend or device operation when needed, bypassing automatic dispatch.
|
||||
|
||||
3. **Versioned API evolution.** The builder uses semantic version strings (`"0.0.0"`, `"0.1.0"`) to manage API changes predictably. The `ConvBuilder` template defaults to the latest version but accepts explicit version pinning.
|
||||
|
||||
## Design descriptions
|
||||
|
||||
- [CK Builder design description](include/ck_tile/builder/README.md)
|
||||
- [CK Builder factory design](include/ck_tile/builder/factory/README.md)
|
||||
- [CK Builder testing design](include/ck_tile/builder/testing/README.md)
|
||||
- [CK Builder reflection design](include/ck_tile/builder/reflect/README.md)
|
||||
|
||||
## Directory Structure
|
||||
|
||||
@@ -23,7 +38,7 @@ This project is a prototype for a more general builder pattern for all of compos
|
||||
- `include/ck_tile/builder/reflect`
|
||||
Reflection mechanism.
|
||||
- `include/ck_tile/builder/factory`
|
||||
Compile-time dispatch from builder descriptors to our exisitng specialized convolution kernel implementations.
|
||||
Compile-time dispatch from builder descriptors to our existing specialized convolution kernel implementations.
|
||||
- `test/`
|
||||
Unit tests and example usage of the builder pattern.
|
||||
- `CMakeLists.txt`
|
||||
@@ -62,8 +77,7 @@ ninja smoke-builder
|
||||
```
|
||||
|
||||
### Regression Tests (Integration Tests)
|
||||
Integration tests that compile actual GPU kernels to verify that the builder generates valid, compilable code. These are more expensive than smoke tests (can take minutes to compile) but cover more fuctionality.
|
||||
)
|
||||
Integration tests that compile actual GPU kernels to verify that the builder generates valid, compilable code. These are more expensive than smoke tests (can take minutes to compile) but cover more functionality.
|
||||
|
||||
```sh
|
||||
ninja regression-builder
|
||||
|
||||
@@ -51,9 +51,8 @@ The signature system is organized into a hierarchical structure:
|
||||
│ ╔═════════════════════════════════════╗ │
|
||||
│ ║ TensorConfig (required) ║ │
|
||||
│ ╠═════════════════════════════════════╣ │
|
||||
│ ║ • layout: ConvLayout ║ │
|
||||
│ ║ • layout: TensorLayout ║ │
|
||||
│ ║ • data_type: DataType (optional) ║ │
|
||||
│ ║ • compute_type: DataType (optional)║ │
|
||||
│ ╚═════════════════════════════════════╝ │
|
||||
│ │
|
||||
│ ┌─────────────────────────────────────┐ │
|
||||
@@ -127,7 +126,7 @@ Describes the memory layout and data types:
|
||||
```cpp
|
||||
template <typename T>
|
||||
concept TensorConfigDescriptor = requires(T t) {
|
||||
{ t.layout } -> std::convertible_to<ConvLayout>;
|
||||
{ t.layout } -> std::convertible_to<TensorLayout>;
|
||||
requires detail::DataTypeWellDefinedIfProvided<T>; // Override data type (Optional, default provided by ConvSignatureDescriptor)
|
||||
};
|
||||
```
|
||||
@@ -175,11 +174,15 @@ concept TensorOperatorDescriptor = requires(T t) {
|
||||
```
|
||||
|
||||
**Supported Operations:**
|
||||
- `PASS_THROUGH`: No operation (identity)
|
||||
- `SCALE`: Multiply by a scalar
|
||||
- `CLAMP`: Clamp values to a range
|
||||
- `BIAS_BNORM_CLAMP`: Bias addition + batch normalization + clamp
|
||||
- `SCALEADD_SCALEADD_RELU`: Fused scale-add operations + ReLU activation
|
||||
|
||||
The `ElementwiseOperation` enum in `types.hpp` defines 35 operations:
|
||||
|
||||
- **Identity**: `PASS_THROUGH`
|
||||
- **Scaling and arithmetic**: `SCALE`, `SCALE_ADD`, `CLAMP`, `ADD_CLAMP`, `BILINEAR`
|
||||
- **Convolution-specific scaling**: `CONV_SCALE`, `CONV_SCALE_ADD`, `CONV_SCALE_RELU`, `CONV_INVSCALE`
|
||||
- **Activations**: `RELU`, `LEAKY_RELU`, `CLIPPED_RELU`, `SOFT_RELU`, `GELU`, `SILU`, `SIGMOID`, `TANH`, `ELU`, `SWISH`, `LOGISTIC`, `POWER`, `UNARY_ABS`
|
||||
- **Composite fused operations**: `BIAS_BNORM_CLAMP`, `SCALEADD_SCALEADD_RELU`, `ADD_RELU_ADD`, `ACTIVATION_MUL_CLAMP`, `ACTIVATION_MUL2_CLAMP`, `ADD_ACTIVATION_MUL_CLAMP`, `ADD_ACTIVATION_MUL2_CLAMP`, `ADD_MUL_ACTIVATION_MUL_CLAMP`, `ADD_MUL2_ACTIVATION_MUL_CLAMP`
|
||||
- **Dynamic and generic**: `DYNAMIC_UNARY_OP`, `UNARY_COMBINED_OP`, `UNARY_CONVERT`
|
||||
|
||||
**Auxiliary Operands:**
|
||||
Some operations require additional tensor inputs (e.g., bias tensors, scaling factors). These are specified through `auxiliary_operand_configs`, which is an array of `TensorConfigDescriptor` objects describing the layout and data type of each auxiliary input.
|
||||
@@ -232,7 +235,43 @@ This design follows the principle of "make the common case simple, the complex c
|
||||
|
||||
## Convolution Algorithm
|
||||
|
||||
The algorithm descriptor specifies **how** a convolution is computed — the implementation strategy including tile sizes, hardware instruction variant, pipeline scheduling, and memory access patterns. It is the complement to the signature, which specifies **what** is computed.
|
||||
|
||||
### Algorithm Descriptor Concept
|
||||
|
||||
An algorithm descriptor is any struct satisfying the `ConvAlgorithmDescriptor` concept (`conv_algorithm_concepts.hpp`). The required fields depend on the target kernel variant. The dispatcher (`conv_dispatcher.hpp`) uses predicate concepts to classify each algorithm descriptor into one of the supported variants:
|
||||
|
||||
- **ReferenceAlgorithm**: Requires only a `specialization` field set to `REFERENCE`. Used for correctness validation.
|
||||
- **TileAlgorithm**: CK Tile backend. Requires tile-level configuration: block shape, warp tile, block GEMM pipeline, transfer vectorization, and optimizations.
|
||||
- **Forward-specific** (old CK): XDL V3, XDL, WMMA, DL, Large Tensor. Each requires progressively different fields (thread block, GEMM config, transfer, scheduling, prefetch stages).
|
||||
- **Backward weight-specific** (old CK): XDL, XDL V3, Two-Stage XDL, DL, Multi-D XDL, WMMA V3, Two-Stage WMMA V3, WMMA, Multi-D WMMA V3.
|
||||
|
||||
The `ConvAlgorithmSpecialization` enum provides broad algorithm classes (`REFERENCE`, `LARGE_TENSOR`, `TWO_STAGE`, `MULTIPLE_D`) for requesting a category of algorithm without specifying the full descriptor.
|
||||
|
||||
### Algorithm Descriptor Fragmentation
|
||||
|
||||
The builder currently requires a different algorithm descriptor shape for each kernel variant. This fragmentation exists along three axes:
|
||||
|
||||
1. **Backend** (CK vs CK Tile): The old CK backend flattens ~49 template parameters into a single device operation type (explicit thread block dimensions, block transfer descriptors with LDS configurations, thread cluster arrangements, per-tensor access orders). The CK Tile backend composes higher-level objects — tile partitioner, GEMM pipeline, epilogue pipeline — with ~31 parameters distributed across four composed types.
|
||||
|
||||
2. **Instruction set** (MFMA vs WMMA): Within the old CK backend, XDL (MFMA) and WMMA variants require different algorithm descriptor fields. The dispatcher uses separate predicate concepts (`FwdXdlAlgorithm` vs `FwdWmmaAlgorithm`) to classify them, and separate factories to instantiate them.
|
||||
|
||||
3. **Direction** (forward vs backward weight vs backward data): Each direction has its own set of factories. Backward weight alone has 9 old CK factory variants (XDL, XDL V3, two-stage XDL, DL, multi-D XDL, WMMA V3, two-stage WMMA V3, WMMA, multi-D WMMA V3).
|
||||
|
||||
The result is 16+ per-variant factories, each accepting a different algorithm descriptor shape. MIOpen must currently know which variant it wants and construct the matching descriptor — the builder dispatches but does not unify.
|
||||
|
||||
The path toward a single algorithm descriptor runs through the reflection system. `ConvTraits` already provides a common representation for old CK instances across both MFMA and WMMA. Extending this to CK Tile instances will reveal which parameters are genuinely variant-specific versus which can be expressed in a single descriptor and mapped to multiple backends by the dispatcher. See the [reflection documentation](reflect/README.md) for the current state of this bridge.
|
||||
|
||||
## Convolution Factory
|
||||
|
||||
Convolution factory builds the instance based on the convolution signature and convolution algorithm.
|
||||
The signature and the algorithm descriptions are dispatched to the relevant algorithm specific factory for instance creation. The convolution factory design is described in a separate [Readme](factory/README.md).
|
||||
The factory system translates a (signature, algorithm) pair into a concrete kernel instance. The entry point is `make_conv_instance<SIGNATURE, ALGORITHM, VERSION>()` in `conv_dispatcher.hpp`.
|
||||
|
||||
The dispatch proceeds in two phases:
|
||||
|
||||
1. **Algorithm classification**: Predicate concepts (`ReferenceAlgorithm`, `TileAlgorithm`, `FwdXdlV3Algorithm`, etc.) inspect the algorithm descriptor's structure to determine which kernel variant it satisfies.
|
||||
|
||||
2. **Direction routing**: An `if constexpr` chain routes to the appropriate factory based on convolution direction (forward, backward data, backward weight) and classified algorithm type.
|
||||
|
||||
Each factory (e.g., `ConvFwdXdlV3Factory`, `ConvBwdWeightWmmaV3Factory`) transforms builder descriptors into the underlying device operation's template parameters and instantiates the kernel.
|
||||
|
||||
The factory design is described in detail in the [factory README](factory/README.md).
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Convolution Builder Factory Directory
|
||||
|
||||
This directory implements compile-time dispatch from high-level signature algorithm descriptors to our exisitng specialized convolution kernel implementations.
|
||||
This directory implements compile-time dispatch from high-level signature and algorithm descriptors to our existing specialized convolution kernel implementations.
|
||||
|
||||
See the [main builder documentation](../README.md) for an overview.
|
||||
|
||||
@@ -8,24 +8,61 @@ See the [main builder documentation](../README.md) for an overview.
|
||||
|
||||
The factory system operates in two phases:
|
||||
|
||||
1. **Algorithm Classification**: The function `make_conv_instance` in `conv_dispatcher.hpp` inspects the signature and algorithm descriptors to determine which kernel variant they satisfy (XDL V3, XDL, WMMA, DL, or Large Tensor)
|
||||
1. **Algorithm Classification**: Predicate concepts in `conv_dispatcher.hpp` inspect the algorithm descriptor to determine which kernel variant it satisfies. The predicates are evaluated in a specific order using `if constexpr`:
|
||||
|
||||
2. **Factory Instantiation**: Each factory (`conv_fwd_*_factory.hpp`) transforms builder descriptors into CK device operation template parameters and instantiates the corresponding kernel device operation.
|
||||
- **Cross-direction** (checked first, supports all convolution directions):
|
||||
- `ReferenceAlgorithm` — simple reference implementation for validation
|
||||
- `TileAlgorithm` — CK Tile backend, dispatches via `ConvTileFactory`
|
||||
|
||||
- **Forward direction** (old CK):
|
||||
- `FwdXdlV3Algorithm` — newer XDL pipeline using block GEMM structure
|
||||
- `FwdXdlAlgorithm` — standard XDL using AMD XDLops instructions
|
||||
- `FwdWmmaAlgorithm` — WMMA variant for gfx11/gfx12 hardware
|
||||
- `FwdDlAlgorithm` — vectorized dot-product kernel (non-XDLops)
|
||||
- `LargeTensorAlgorithm` — XDL with extended tensor support
|
||||
|
||||
- **Backward weight direction** (old CK):
|
||||
- `BwdXdlAlgorithm`, `BwdXdlV3Algorithm`, `BwdTwoStageXdlAlgorithm`, `BwdDlAlgorithm`, `BwdMultiDXdlAlgorithm`, `BwdWmmaV3Algorithm`, `BwdTwoStageWmmaV3Algorithm`, `BwdWmmaAlgorithm`, `BwdMultiDWmmaV3Algorithm`
|
||||
|
||||
- **Backward data direction**: Currently supports only Reference and Tile algorithms. Optimized old CK kernels are not yet implemented.
|
||||
|
||||
2. **Factory Instantiation**: Each factory transforms builder descriptors into backend-specific template parameters and instantiates the corresponding kernel.
|
||||
|
||||
## Key Files
|
||||
|
||||
- **`conv_dispatcher.hpp`**: Entry point with `make_conv_instance()` function. Contains dispatch logic and algorithm classification predicates. **Start here** to understand the overall flow.
|
||||
|
||||
- **`conv_fwd_*_factory.hpp`**: Individual factories for each kernel variant. Each extracts configuration from descriptors, validates parameters, and instantiates the underlying CK device operation.
|
||||
- **Forward factories** (old CK):
|
||||
`conv_fwd_v3_factory.hpp`, `conv_fwd_xdl_factory.hpp`, `conv_fwd_wmma_factory.hpp`, `conv_fwd_dl_factory.hpp`, `conv_fwd_large_tensor_factory.hpp`
|
||||
|
||||
- **`helpers/`**: Transformation utilities that map builder types to CK device operation parameters (layouts, data types, elementwise ops, block configurations, etc.)
|
||||
- **Backward weight factories** (old CK):
|
||||
`conv_bwd_weight_xdl_factory.hpp`, `conv_bwd_weight_xdl_v3_factory.hpp`, `conv_bwd_weight_two_stage_xdl_factory.hpp`, `conv_bwd_weight_dl_factory.hpp`, `conv_bwd_weight_multi_d_xdl_factory.hpp`, `conv_bwd_weight_wmma_v3_factory.hpp`, `conv_bwd_weight_two_stage_wmma_v3_factory.hpp`, `conv_bwd_weight_wmma_factory.hpp`, `conv_bwd_weight_multi_d_wmma_v3_factory.hpp`
|
||||
|
||||
- **Cross-direction factories**:
|
||||
`reference_factory.hpp` (reference implementation), `conv_tile_factory.hpp` (CK Tile backend)
|
||||
|
||||
- **`helpers/`**: Transformation utilities that map builder types to backend-specific parameters. Organized into `helpers/ck/` (old CK mappings) and `helpers/ck_tile/` (CK Tile mappings).
|
||||
|
||||
## Usage
|
||||
|
||||
```cpp
|
||||
#include "ck_tile/builder/factory/conv_dispatcher.hpp"
|
||||
|
||||
using Factory = decltype(make_conv_instance<signature, algorithm, "v1">());
|
||||
// Uses latest version by default (currently "0.1.0")
|
||||
auto kernel = make_conv_instance<SIGNATURE, ALGORITHM>();
|
||||
|
||||
// Or pin to a specific version
|
||||
auto kernel_v0 = make_conv_instance<SIGNATURE, ALGORITHM, "0.0.0">();
|
||||
```
|
||||
|
||||
The dispatcher automatically selects the appropriate factory following explicit logic.
|
||||
The dispatcher automatically selects the appropriate factory at compile time.
|
||||
|
||||
## Factory Architecture and the Unification Gap
|
||||
|
||||
Each factory is a self-contained facade: it accepts builder descriptors and produces a kernel instance, but it does so with its own algorithm descriptor shape and its own parameter mapping logic. The 16+ factories share no common infrastructure for parameter transformation.
|
||||
|
||||
**Old CK factories** (e.g., `ConvFwdXdlV3Factory`) flatten all algorithm parameters into a single device operation template instantiation with approximately 49 template arguments. The factory's primary job is mapping builder enum values (layouts, data types, elementwise ops) to CK's internal types. Within old CK, the XDL and WMMA factories duplicate much of this mapping logic despite sharing the same underlying parameter concepts.
|
||||
|
||||
**The CK Tile factory** (`ConvTileFactory`) composes modern objects — a traits type, a tile partitioner, a GEMM pipeline, and an epilogue pipeline — each with its own configuration. This results in approximately 31 parameters distributed across four composed types rather than one flat template.
|
||||
|
||||
Both factory paths produce a kernel `Instance` type that satisfies the same usage interface (construction, argument setup, invocation). The dispatcher abstracts this difference from the caller. However, the algorithm descriptor accepted by each factory is different — the unification burden currently falls on the caller (MIOpen), not the dispatcher. Collapsing these per-variant descriptors into a single algorithm format that the dispatcher decomposes internally is the key step toward making the builder a true unified facade.
|
||||
|
||||
@@ -9,7 +9,7 @@ See the [main builder documentation](../README.md) for an overview.
|
||||
The reflection system works by extracting properties from a convolution kernel *type* and formatting them into a string. This is useful for debugging, performance tuning, and generating documentation.
|
||||
|
||||
1. **Trait Extraction**: The `ConvTraits` template (in `conv_traits.hpp`) is specialized for each kernel instance. It extracts low-level details like tile sizes, data layouts, and pipeline versions from the kernel's type definition.
|
||||
This template is common for xld and wmma, fwd and backwards weight kernels. std::optional is used for parameters that are only used by some kernels
|
||||
This template is common for XDL and WMMA, forward and backward weight kernels. `std::optional` is used for parameters that are only used by some kernels.
|
||||
|
||||
2. **Description Generation**: The `describe<Instance>()` function (in `conv_description.hpp`) uses `ConvTraits` to populate a `ConvDescription` (`Description`) object.
|
||||
|
||||
@@ -49,7 +49,7 @@ The reflection system (`ckr::describe`) currently supports the following convolu
|
||||
- **Standard XDL Forward Convolution** (`DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle`)
|
||||
- **Large Tensor XDL Forward Convolution** (`DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor`)
|
||||
- **V3 XDL Forward Convolution** (`DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3`)
|
||||
- **V3 WMMA Forward Convolution** (`DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3`)
|
||||
- **WMMA Forward Convolution** (`DeviceGroupedConvFwdMultipleD_Wmma_CShuffle`)
|
||||
- **XDL Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Xdl_CShuffle`)
|
||||
- **V3 XDL Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Xdl_CShuffleV3`)
|
||||
- **XDL Multiple D Backward Weight Convolution** (`DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle`)
|
||||
@@ -61,21 +61,39 @@ The reflection system (`ckr::describe`) currently supports the following convolu
|
||||
|
||||
These variants all share similar template parameter structures and are compatible with the current `ConvTraits` implementation.
|
||||
|
||||
### Unsupported Instance Types
|
||||
#### CK Tile Instance Types
|
||||
|
||||
The following instance types are **not yet supported** by the reflection system:
|
||||
The reflection system also provides `InstanceTraits` specializations for CK Tile kernel instances:
|
||||
|
||||
- **DL (pre-XDL) Variants** (`DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK`)
|
||||
- Uses different internal structure with parameters like `K0PerBlock`, `K1`, `M1PerThread`, etc.
|
||||
- Missing standard members like `kKPerBlock`, `kMPerXDL`, `kAK1`
|
||||
- **Tile Forward Convolution** (`GroupedConvolutionForwardKernel`)
|
||||
- **Tile Backward Weight Convolution** (`GroupedConvolutionBackwardWeightKernel`)
|
||||
- **Tile Backward Data Convolution** (`GroupedConvolutionBackwardDataKernel`)
|
||||
- **Reference Convolution** (reference implementation)
|
||||
|
||||
#### Unsupported Instance Types
|
||||
|
||||
- **DL (non-XDLops) Forward** (`DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK`) has `InstanceTraits` but uses a different internal parameter structure (`K0PerBlock`, `K1`, `M1PerThread` instead of standard block/warp parameters). It can use `GetInstanceString()` through the base class pointer but cannot use `describe()`.
|
||||
|
||||
### Reflection Coverage: ConvTraits Bridge
|
||||
|
||||
The reflection system operates at two levels:
|
||||
|
||||
1. **`InstanceTraits`** (compile-time): Extracts raw template parameters from a kernel type. Specializations exist for both old CK and CK Tile instances.
|
||||
|
||||
2. **`ConvTraits`** (runtime): A unified, type-erased data structure representing kernel configuration in convolution-specific terms. Populated by `instance_to_conv_traits<Instance>()` specializations.
|
||||
|
||||
`ConvTraits` captures the common ground shared by both backends: spatial dimensions, tensor layouts, data types, elementwise operations, tile dimensions, pipeline version/scheduler, and memory access patterns. Within old CK, `ConvTraits` already unifies across the MFMA/WMMA instruction set boundary — XDL and WMMA forward instances both produce the same `ConvTraits` representation, demonstrating that instruction-set differences can be abstracted at this level.
|
||||
|
||||
Currently, `instance_to_conv_traits()` specializations exist only for old CK instances (forward XDL, XDL V3, WMMA, large tensor, and 8 backward weight variants). CK Tile instances have `InstanceTraits` but lack `instance_to_conv_traits()` specializations — there is no bridge from CK Tile's `InstanceTraits` to the unified `ConvTraits` representation.
|
||||
|
||||
This is the critical gap in the reflection system. Today the builder has 16+ per-variant factories, each with its own algorithm descriptor shape. `ConvTraits` is the mechanism for discovering which parameters are genuinely variant-specific versus which can be expressed in a single unified algorithm descriptor. Closing the CK Tile bridge means writing `instance_to_conv_traits()` specializations for the CK Tile kernel types that map their `InstanceTraits` fields to the `ConvTraits` struct. Once this bridge exists, both backends produce the same `ConvTraits` output — making it possible to define a single algorithm descriptor format that the dispatcher decomposes into variant-specific parameters internally.
|
||||
|
||||
### Future Work
|
||||
|
||||
To support these additional instance types, the reflection system would need:
|
||||
The priorities for the reflection system are:
|
||||
|
||||
1. Specialized `ConvTraits` templates for each variant type
|
||||
2. Updated `conv_layout`, `conv_data_type`, and other helper functions to handle different parameter structures
|
||||
3. Conditional compilation or SFINAE techniques to select the appropriate trait extraction logic based on instance type
|
||||
4. Customize `ConvDescription` methods for more general kernels.
|
||||
1. **CK Tile ConvTraits bridge.** Write `instance_to_conv_traits()` specializations for `GroupedConvolutionForwardKernel`, `GroupedConvolutionBackwardWeightKernel`, and `GroupedConvolutionBackwardDataKernel`. This is the prerequisite for unified algorithm descriptors.
|
||||
|
||||
For now, these unsupported types can still use `GetInstanceString()` through the base class pointer, but cannot use the `ckr::describe` reflection API.
|
||||
2. **DL variant support.** The DL forward kernel needs a specialized `ConvTraits` mapping due to its different internal parameter structure.
|
||||
|
||||
3. **Generalization beyond convolution.** `ConvTraits` is designed to evolve toward a more general `KernelTraits` covering GEMM, flash attention, and other operations.
|
||||
|
||||
@@ -41,34 +41,27 @@ The "signature" defines the **mathematical contract** that the kernel must satis
|
||||
- Data types (FP32, FP16, BF16, etc.)
|
||||
- Fused element-wise operations (e.g., Bias, ReLU)
|
||||
|
||||
The format of the signature struct is enforced at compile time using C++20 concepts by the CK-Builder API, ensuring type safety and enabling compile-time optimizations. The design of these concepts and the required constraints are discussed in the [CK Builder design description](../include/ck_tile/builder/README.md).
|
||||
The format of the signature struct is enforced at compile time using C++20 concepts by the CK-Builder API, ensuring type safety and enabling compile-time optimizations. The design of these concepts and the required constraints are discussed in the [CK Builder design description](../README.md).
|
||||
|
||||
```cpp
|
||||
// Define our custom signature struct.
|
||||
struct ConvSignature {
|
||||
int spatial_dim = 2;
|
||||
ck_tile::builder::ConvDirection direction =
|
||||
ck_tile::builder::ConvDirection::FORWARD;
|
||||
ck_tile::builder::GroupConvLayout2D layout =
|
||||
ck_tile::builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK;
|
||||
ck_tile::builder::DataType data_type =
|
||||
ck_tile::builder::DataType::FP16;
|
||||
ck_tile::builder::ElementwiseOperation elementwise_operation =
|
||||
ck_tile::builder::ElementwiseOperation::PASS_THROUGH;
|
||||
};
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
|
||||
// Double-check that out structure is well-defined according to the CK-Builder API.
|
||||
static_assert(ck_tile::builder::ConvSignatureDescriptor<ConvSignature>);
|
||||
// A signature specifies per-tensor layouts via ConvTensorDescriptor fields.
|
||||
// Each tensor has a config (with layout and optional data type override)
|
||||
// and an optional operation (with elementwise op and auxiliary operands).
|
||||
// See test/impl/conv_signature_types.hpp for a reusable ConvSignature template.
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::FORWARD,
|
||||
.data_type = ckb::DataType::FP16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
// Instantiate the signature with a configuration. These values are again checked
|
||||
// by the CK-Builder API when a device operation is built.
|
||||
constexpr auto SIGNATURE = ConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ck_tile::builder::ConvDirection::FORWARD,
|
||||
.layout = ck_tile::builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = ck_tile::builder::DataType::FP16,
|
||||
.elementwise_operation = ck_tile::builder::ElementwiseOperation::PASS_THROUGH,
|
||||
};
|
||||
// The ConvSignatureDescriptor concept validates the structure at compile time.
|
||||
static_assert(ckb::ConvSignatureDescriptor<decltype(SIGNATURE)>);
|
||||
```
|
||||
|
||||
#### Run-time Arguments
|
||||
@@ -122,7 +115,7 @@ The "algorithm" defines the **implementation strategy** for the kernel. It speci
|
||||
- Data transfer vectorization
|
||||
- Pipeline scheduling
|
||||
|
||||
As with the signature struct, the format of the algorithm struct is enforced at compile time using C++20 concepts by the CK-Builder API. The design of these concepts and the required constraints are discussed in the [CK Builder factory design description](../include/ck_tile/builder/factory/README.md).
|
||||
As with the signature struct, the format of the algorithm struct is enforced at compile time using C++20 concepts by the CK-Builder API. The design of these concepts and the required constraints are discussed in the [CK Builder factory design description](../factory/README.md).
|
||||
|
||||
|
||||
```cpp
|
||||
@@ -198,7 +191,7 @@ This instance can then be invoked using `ck_tile::builder::test::run()`, the sam
|
||||
|
||||
```cpp
|
||||
auto reference_outputs = ck_tile::builder::test::allocate_outputs(args);
|
||||
ck_tile::builder::test::run(conv, args, inputs.get(), reference_outputs.get());
|
||||
ck_tile::builder::test::run(reference_conv, args, inputs.get(), reference_outputs.get());
|
||||
```
|
||||
|
||||
#### Validating Results
|
||||
@@ -236,50 +229,33 @@ Here's a complete test that demonstrates the Given-When-Then pattern:
|
||||
#include "ck_tile/testing/validator.hpp"
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
// Define the convolution signature
|
||||
struct ConvSignature {
|
||||
int spatial_dim = 2;
|
||||
ck_tile::builder::ConvDirection direction =
|
||||
ck_tile::builder::ConvDirection::FORWARD;
|
||||
ck_tile::builder::GroupConvLayout2D layout =
|
||||
ck_tile::builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK;
|
||||
ck_tile::builder::DataType data_type =
|
||||
ck_tile::builder::DataType::FP16;
|
||||
ck_tile::builder::ElementwiseOperation elementwise_operation =
|
||||
ck_tile::builder::ElementwiseOperation::PASS_THROUGH;
|
||||
};
|
||||
static_assert(ck_tile::builder::ConvSignatureDescriptor<ConvSignature>);
|
||||
constexpr auto SIGNATURE = ConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ck_tile::builder::ConvDirection::FORWARD,
|
||||
.layout = ck_tile::builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = ck_tile::builder::DataType::FP16,
|
||||
.elementwise_operation = ck_tile::builder::ElementwiseOperation::PASS_THROUGH,
|
||||
};
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
|
||||
// Define the convolution algorithm
|
||||
struct ConvAlgorithm {
|
||||
// Algorithm configuration details...
|
||||
// (Omitted for brevity)
|
||||
};
|
||||
static_assert(ck_tile::builder::ConvAlgorithmDescriptor<ConvAlgorithm>);
|
||||
constexpr auto ALGORITHM = ConvAlgorithm{/* ... */};
|
||||
// Define the convolution signature with per-tensor layout specification
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::FORWARD,
|
||||
.data_type = ckb::DataType::FP16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
static_assert(ckb::ConvSignatureDescriptor<decltype(SIGNATURE)>);
|
||||
|
||||
// Define the reference convolution algorithm
|
||||
struct ReferenceAlgorithm {
|
||||
ck_tile::builder::ConvAlgorithmSpecialization specialization;
|
||||
};
|
||||
static_assert(ck_tile::builder::ConvAlgorithmDescriptor<ReferenceAlgorithm>);
|
||||
constexpr auto REFERENCE_ALGORITHM = ReferenceAlgorithm{
|
||||
.specialization = ck_tile::builder::ConvAlgorithmSpecialization::REFERENCE;
|
||||
};
|
||||
// Define the convolution algorithm (omitted for brevity — see conv_algorithm_concepts.hpp
|
||||
// for the required fields and test/impl/conv_algorithm_types.hpp for examples)
|
||||
constexpr auto ALGORITHM = /* ... */;
|
||||
|
||||
// Define the reference algorithm
|
||||
constexpr auto REFERENCE_ALGORITHM = ckt::ConvAlgorithm_Reference{};
|
||||
|
||||
// The actual test
|
||||
TEST(ConvolutionTest, Forward2D_FP16) {
|
||||
// ===== GIVEN: Set up the test case =====
|
||||
|
||||
// Define runtime parameters
|
||||
ck_tile::builder::test::Args<ConvSignature> args = {
|
||||
ckt::Args<SIGNATURE> args = {
|
||||
.lengths = {
|
||||
.batch_size = 128,
|
||||
.groups = 1,
|
||||
@@ -295,30 +271,30 @@ TEST(ConvolutionTest, Forward2D_FP16) {
|
||||
};
|
||||
|
||||
// Allocate GPU memory
|
||||
auto inputs = ck_tile::builder::test::allocate_inputs(args);
|
||||
auto outputs = ck_tile::builder::test::allocate_outputs(args);
|
||||
auto reference_outputs = ck_tile::builder::test::allocate_outputs(args);
|
||||
auto inputs = ckt::allocate_inputs(args);
|
||||
auto outputs = ckt::allocate_outputs(args);
|
||||
auto reference_outputs = ckt::allocate_outputs(args);
|
||||
|
||||
// Initialize inputs
|
||||
ck_tile::builder::test::init_inputs(args, inputs);
|
||||
ckt::init_inputs(args, inputs);
|
||||
|
||||
// ===== WHEN: Execute the kernel =====
|
||||
|
||||
// Build the kernel
|
||||
using Conv = ck_tile::builder::ConvBuilder<SIGNATURE, ALGORITHM>::Instance;
|
||||
using Conv = ckb::ConvBuilder<SIGNATURE, ALGORITHM>::Instance;
|
||||
auto conv = Conv{};
|
||||
|
||||
// Compute actual results
|
||||
ck_tile::builder::test::run(conv, args, inputs.get(), outputs.get());
|
||||
ckt::run(conv, args, inputs.get(), outputs.get());
|
||||
|
||||
// ===== THEN: Verify the results =====
|
||||
|
||||
// Build the reference kernel
|
||||
using ReferenceConv = ck_tile::builder::ConvBuilder<SIGNATURE, REFERENCE_ALGORITHM>::Instance;
|
||||
using ReferenceConv = ckb::ConvBuilder<SIGNATURE, REFERENCE_ALGORITHM>::Instance;
|
||||
auto reference_conv = ReferenceConv{};
|
||||
|
||||
// Compute reference results
|
||||
ck_tile::builder::test::run(conv, args, inputs.get(), reference_outputs.get());
|
||||
ckt::run(reference_conv, args, inputs.get(), reference_outputs.get());
|
||||
|
||||
// Check the results
|
||||
EXPECT_THAT(outputs.get(), ck_tile::test::MatchesReference(args, reference_outputs.get()));
|
||||
|
||||
Reference in New Issue
Block a user