mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
[CK_BUILDER] Refactor convolution signature to provide data type/layout/elementwise op per tensor (#3331)
* Separate layouts into separate entities for input, weight, and output tensors.
* Add test for handling bias tensor layouts.
* Use instance string in builder tests.
* Add handling of output bias data types and layouts.
* Generalize handling of the elementwise ops.
* Test fix.
* Create builder for layouts.
* Layout builder improvements.
* Improve layout builder.
* Simplify bias layout handling.
* Code clean-up.
* Move layout utils into separate file.
* Remove hard-coded layout combinations.
* Small code clean-up.
* Move data type utils into a separate file.
* Add data types, layouts, and elementwise ops per conv tensor.
* Builder bug fixes after refactoring.
* Working baseline.
* Make signature definition look nice in the test code.
* Move TensorConfig into test implementations.
* Fix all fwd conv builder tests.
* Fix conv traits and descriptors tests.
* Move factory assets under a separate directory.
* Fix building conv traits.
* Fix clang-format.
* Add Readme doc to describe the design.
* Add link to main Readme. Fix links in the builder design doc.
* Clean-up data type/layout/elementwise op conversions.
* Switch from dimension and tensor type specific layouts to a flat list of tensor layouts.
* Fix clang-formatting.
* Fix clang-format for test code.
* Simplify fwd conv signature definitions in the test code.
* Remove accidental edits.
* Fix comment string.
* Fix instance factory after rebase.
* Fix tests after rebase.
* Unify layout handling.
* Add more conv layout unit tests.
* Clang-format.
* Fix merge conflicts.
* Improve elementwise op handling.

---------

Co-authored-by: Ville Pietilä <>
This commit is contained in:

experimental/builder/include/ck_tile/builder/README.md  (new file, 244 additions)
@@ -0,0 +1,244 @@
# Composable Kernel Builder Design Documentation

This directory contains the builder framework for Composable Kernel, which provides a compile-time, type-safe interface for constructing convolution operations with various configurations.

## Table of Contents

- [Convolution Signature Design](#convolution-signature-design)
  - [Overview](#overview)
  - [Architecture](#architecture)
  - [Core Components](#core-components)
  - [Concepts and Validation](#concepts-and-validation)

---

## Convolution Signature Design

### Overview

The convolution signature system provides a **compile-time description** of grouped convolution operations. A signature is a collection of properties that fully characterize a convolution kernel's mathematical and operational behavior, enabling:

- **Compile-time validation**: Ensures type safety and correctness before kernel instantiation
- **Kernel selection**: Matches user requirements to optimized implementations
- **Specialization**: Enables optimized code paths for specific configurations
- **Composability**: Supports building complex operations from simpler components

The signature leverages modern C++20 features, particularly **concepts**, to provide expressive, self-documenting interfaces with compile-time guarantees.

### Architecture

The signature system is organized into a hierarchical structure:

```
┌─────────────────────────────────────────────────────────┐
│                      ConvSignature                      │
├─────────────────────────────────────────────────────────┤
│ Properties:                                             │
│ • spatial_dim: int (1D, 2D, or 3D)                      │
│ • direction: ConvDirection (Fwd/BwdData/BwdWeight)      │
│ • data_type: DataType (default data type)               │
│ • accumulation_data_type: DataType                      │
│ • input:  ConvTensor ──┐                                │
│ • weight: ConvTensor ──┤                                │
│ • output: ConvTensor ──┤                                │
└────────────────────────┼────────────────────────────────┘
                         │
                         ▼
    ┌─────────────────────────────────────────┐
    │               ConvTensor                │
    ├─────────────────────────────────────────┤
    │ ╔═════════════════════════════════════╗ │
    │ ║ TensorConfig (required)             ║ │
    │ ╠═════════════════════════════════════╣ │
    │ ║ • layout: ConvLayout                ║ │
    │ ║ • data_type: DataType (optional)    ║ │
    │ ║ • compute_type: DataType (optional) ║ │
    │ ╚═════════════════════════════════════╝ │
    │                                         │
    │ ┌─────────────────────────────────────┐ │
    │ │ TensorOperation (optional)          │ │
    │ ├─────────────────────────────────────┤ │
    │ │ • elementwise_operation             │ │
    │ │ • auxiliary_operand_configs[]       │ │
    │ │   (each is also ConvTensor) ◄───────┼─┼──┐
    │ └─────────────────────────────────────┘ │  │
    └─────────────────────────────────────────┘  │
                                                 │
                                 Recursive ──────┘
```

Key Design Points:

- ConvSignature contains three ConvTensor instances (input, weight, output)
- All tensors share the same ConvTensor structure
- Each ConvTensor has:
  - TensorConfig (required): Defines layout as well as optional data and compute type overrides
  - TensorOperation (optional): Defines fused elementwise operations
- Auxiliary operands (e.g., bias) in TensorOperation also use the ConvTensor type

### Core Components

#### 1. Signature Level

The top-level signature contains global properties that apply to the entire convolution operation:

```cpp
template <typename T>
concept ConvSignatureDescriptor = requires(T t) {
    { t.spatial_dim } -> std::convertible_to<unsigned int>; // 1, 2, or 3
    { t.data_type } -> std::convertible_to<DataType>;       // Default data type
    { t.input } -> ConvTensorDescriptor;
    { t.weight } -> ConvTensorDescriptor;
    { t.output } -> ConvTensorDescriptor;
    requires ConvolutionDirectionWellDefinedIfProvided<T>;  // Optional direction
};
```

**Properties:**

- **`spatial_dim`**: Dimensionality of the convolution (1D, 2D, or 3D)
- **`direction`**: Operation type (optional, defaults to `FORWARD`)
  - `FORWARD`: Standard forward convolution
  - `BACKWARD_DATA`: Gradient computation w.r.t. input
  - `BACKWARD_WEIGHT`: Gradient computation w.r.t. weights
- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8)
- **`accumulation_data_type`**: Type used for internal accumulation

#### 2. Tensor Level

Each tensor (input, weight, output) has its own descriptor:

```cpp
template <typename T>
concept ConvTensorDescriptor = requires(T t) {
    { t.config } -> TensorConfigDescriptor;
    requires ElementwiseOpWellDefinedIfProvided<T>;
};
```

A tensor descriptor encapsulates:

- **Configuration**: Layout and data type information
- **Operation** (optional): Fused elementwise operations on this tensor

#### 3. Tensor Configuration

Describes the memory layout and data types:

```cpp
template <typename T>
concept TensorConfigDescriptor = requires(T t) {
    { t.layout } -> std::convertible_to<ConvLayout>;
    { t.data_type } -> std::convertible_to<DataType>; // Optional override
};
```

**Layout Types** (dimension-specific):

- **1D Convolution**:
  - Input: `GNCW`, `GNWC`, `NWGC`, `NGCW`, `G_NW_C_strided`
  - Weight: `GKXC`, `GKCX`, `KXGC`, `G_K_X_C_strided`
  - Output: `GNKW`, `GNWK`, `NWGK`, `NGKW`, `G_NW_K_strided`

- **2D Convolution**:
  - Input: `GNCHW`, `GNHWC`, `NHWGC`, `NGCHW`, `G_NHW_C_strided`
  - Weight: `GKYXC`, `GKCYX`, `KYXGC`, `G_K_YX_C_strided`
  - Output: `GNKHW`, `GNHWK`, `NHWGK`, `NGKHW`, `G_NHW_K_strided`

- **3D Convolution**:
  - Input: `GNCDHW`, `GNDHWC`, `NDHWGC`, `NGCDHW`, `G_NDHW_C_strided`
  - Weight: `GKZYXC`, `GKCZYX`, `KZYXGC`, `G_K_ZYX_C_strided`
  - Output: `GNKDHW`, `GNDHWK`, `NDHWGK`, `NGKDHW`, `G_NDHW_K_strided`

Where:

- `G` = Groups
- `N` = Batch size
- `C` = Input channels
- `K` = Output channels (filters)
- `W`, `H`, `D` = Width, Height, Depth (spatial dimensions)
- `X`, `Y`, `Z` = Filter dimensions

#### 4. Tensor Operations

Describes fused elementwise operations applied to a tensor:

```cpp
template <typename T>
concept TensorOperatorDescriptor = requires(T t) {
    { t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
    requires AuxiliaryOperandConfigsWellDefinedIfProvided<T>;
};
```

**Supported Operations:**

- `PASS_THROUGH`: No operation (identity)
- `SCALE`: Multiply by a scalar
- `CLAMP`: Clamp values to a range
- `BIAS_BNORM_CLAMP`: Bias addition + batch normalization + clamp
- `SCALEADD_SCALEADD_RELU`: Fused scale-add operations + ReLU activation

**Auxiliary Operands:**

Some operations require additional tensor inputs (e.g., bias tensors, scaling factors). These are specified through `auxiliary_operand_configs`, which is an array of `TensorConfigDescriptor` objects describing the layout and data type of each auxiliary input.

### Concepts and Validation

The signature system uses C++20 concepts for compile-time validation at multiple levels:

#### Constraint Concepts

```cpp
// Spatial dimension must be 1, 2, or 3
template <auto N>
concept ConvSpatialDim = std::is_integral_v<decltype(N)> && (N == 1 || N == 2 || N == 3);

// Valid data types for convolution
template <DataType T>
concept ValidConvDataType =
    (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) ||
    (T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8);
```

#### Validation Concept

```cpp
// Validates a complete signature
template <auto Sig>
concept ValidConvSignature = requires {
    requires ConvSpatialDim<Sig.spatial_dim>;
    requires ValidConvDataType<Sig.data_type>;
};
```

#### Tensor Descriptors

The layout, data type, and elementwise operation are described per tensor. This multi-level hierarchy allows:

- **Flexibility**: Each tensor can have an independent layout and data type
- **Reusability**: Common configurations can be shared across different signatures
- **Extensibility**: New properties can be added to specific levels without affecting others
- **Clarity**: Separates concerns (global properties vs. tensor-specific properties)

#### Optional Signature Fields

Several fields in the signature are optional:

- **`direction`**: Defaults to `FORWARD` if not specified, reducing boilerplate for the common case
- **Tensor `data_type`**: Falls back to the signature's default, allowing mixed precision with minimal specification
- **Tensor `operation`**: Defaults to `PASS_THROUGH`, supporting both fused and non-fused operations with the same interface

This design follows the principle of "make the common case simple, the complex case possible."

#### Union-Based Layout Representation

The `ConvLayout` type uses unions to support dimension-agnostic code:

```cpp
struct ConvLayout {
    union {
        ConvInputLayout _input_layout;
        ConvWeightLayout _weight_layout;
        ConvOutputLayout _output_layout;
        ConvAuxiliaryTensorLayout _aux_tensor_layout;
    };
    // ... constructors for each type
};
```

This allows:

- A single type to represent all layout variants
- Type-safe construction through overloaded constructors
- Compile-time enforcement of valid combinations through concepts

---
@@ -28,24 +28,104 @@ namespace ck_tile::builder {
template <auto N>
concept ConvSpatialDim = std::is_integral_v<decltype(N)> && (N == 1 || N == 2 || N == 3);

// Constraints for forward convolution layouts.
template <auto LayoutValue, size_t SpatialDim>
concept ValidConvLayoutForSpatialDim =
    (SpatialDim == 1 && std::same_as<decltype(LayoutValue), GroupConvLayout1D>) ||
    (SpatialDim == 2 && std::same_as<decltype(LayoutValue), GroupConvLayout2D>) ||
    (SpatialDim == 3 && std::same_as<decltype(LayoutValue), GroupConvLayout3D>);

// Constrains convolution data types to common floating-point types.
template <DataType T>
concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) ||
                       (T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8);
concept ValidConvDataType =
    (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) ||
    (T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8);

template <TensorLayout L>
concept BiasTensorLayout =
    (L == TensorLayout::GC) || (L == TensorLayout::G_C_strided) || (L == TensorLayout::G_K_strided);

template <TensorLayout L>
concept ConvInputLayout1D =
    (L == TensorLayout::GNCW) || (L == TensorLayout::GNWC) || (L == TensorLayout::NWGC) ||
    (L == TensorLayout::NGCW) || (L == TensorLayout::G_NW_C_strided);

template <TensorLayout L>
concept ConvInputLayout2D =
    (L == TensorLayout::GNCHW) || (L == TensorLayout::GNHWC) || (L == TensorLayout::NHWGC) ||
    (L == TensorLayout::NGCHW) || (L == TensorLayout::G_NHW_C_strided);

template <TensorLayout L>
concept ConvInputLayout3D =
    (L == TensorLayout::GNCDHW) || (L == TensorLayout::GNDHWC) || (L == TensorLayout::NDHWGC) ||
    (L == TensorLayout::NGCDHW) || (L == TensorLayout::G_NDHW_C_strided);

template <TensorLayout L>
concept ConvWeightLayout1D = (L == TensorLayout::GKXC) || (L == TensorLayout::GKCX) ||
                             (L == TensorLayout::KXGC) || (L == TensorLayout::G_K_X_C_strided);

template <TensorLayout L>
concept ConvWeightLayout2D = (L == TensorLayout::GKYXC) || (L == TensorLayout::GKCYX) ||
                             (L == TensorLayout::KYXGC) || (L == TensorLayout::G_K_YX_C_strided);

template <TensorLayout L>
concept ConvWeightLayout3D = (L == TensorLayout::GKZYXC) || (L == TensorLayout::GKCZYX) ||
                             (L == TensorLayout::KZYXGC) || (L == TensorLayout::G_K_ZYX_C_strided);

template <TensorLayout L>
concept ConvOutputLayout1D =
    (L == TensorLayout::GNKW) || (L == TensorLayout::GNWK) || (L == TensorLayout::NWGK) ||
    (L == TensorLayout::NGKW) || (L == TensorLayout::G_NW_K_strided);

template <TensorLayout L>
concept ConvOutputLayout2D =
    (L == TensorLayout::GNKHW) || (L == TensorLayout::GNHWK) || (L == TensorLayout::NHWGK) ||
    (L == TensorLayout::NGKHW) || (L == TensorLayout::G_NHW_K_strided);

template <TensorLayout L>
concept ConvOutputLayout3D =
    (L == TensorLayout::GNKDHW) || (L == TensorLayout::GNDHWK) || (L == TensorLayout::NDHWGK) ||
    (L == TensorLayout::NGKDHW) || (L == TensorLayout::G_NDHW_K_strided);

template <typename T>
concept ConvLayout = std::same_as<std::remove_cvref_t<T>, GroupConvLayout>;
concept TensorConfigDescriptor = requires(T t) {
    { t.layout } -> std::convertible_to<TensorLayout>;
    // Only require that data type is defined. It might be set to undefined value, in which case the
    // signature's data type is used.
    { t.data_type } -> std::convertible_to<DataType>;
};

template <typename T>
concept HasElementwiseOp = requires(T t) {
    { t.elementwise_operation };
concept HasAuxiliaryOperandConfigs = requires(T t) {
    { t.auxiliary_operand_configs };
};

namespace detail {
template <typename T>
struct IsArrayOfTensorConfigDescriptors : std::false_type
{
};

template <typename T, std::size_t N>
    requires TensorConfigDescriptor<T>
struct IsArrayOfTensorConfigDescriptors<std::array<T, N>> : std::true_type
{
};
} // namespace detail

template <typename T>
concept ConvertibleToArrayOfTensorConfigs =
    detail::IsArrayOfTensorConfigDescriptors<std::remove_cvref_t<T>>::value;

template <typename T>
concept AuxiliaryOperandConfigsWellDefinedIfProvided = requires(T t) {
    requires !HasAuxiliaryOperandConfigs<T> || requires {
        { t.auxiliary_operand_configs } -> ConvertibleToArrayOfTensorConfigs;
    };
};

template <typename T>
concept TensorOperatorDescriptor = requires(T t) {
    { t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
    requires AuxiliaryOperandConfigsWellDefinedIfProvided<T>;
};

template <typename T>
concept HasTensorOp = requires(T t) {
    { t.operation };
};

template <typename T>
@@ -56,11 +136,8 @@ concept HasConvolutionDirection = requires(T t) {
// Note: it is not required to provide an ElementwiseOp, but if one is provided, check if well
// defined
template <typename T>
concept ElementwiseOpWellDefinedIfProvided = requires(T t) {
    requires !HasElementwiseOp<T> || requires {
        { t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
    };
};
concept ElementwiseOpWellDefinedIfProvided =
    !HasTensorOp<T> || requires(T t) { requires TensorOperatorDescriptor<decltype(t.operation)>; };

// Note: it is not required to provide a convolution, but if one is provided, check if well defined
template <typename T>
@@ -70,13 +147,27 @@ concept ConvolutionDirectionWellDefinedIfProvided = requires(T t) {
    };
};

// Concept for the convolution tensor
template <typename T>
concept ConvTensorDescriptor = requires(T t) {
    { t.config } -> TensorConfigDescriptor;
    requires ElementwiseOpWellDefinedIfProvided<T>;
};

template <typename T>
concept HasElementwiseOpWithAuxiliaryOperands = requires(T t) {
    requires HasTensorOp<T>;
    requires HasAuxiliaryOperandConfigs<decltype(t.operation)>;
};

// Concept for a type that defines a convolution's operational signature.
template <typename T>
concept ConvSignatureDescriptor = requires(T t) {
    { t.spatial_dim } -> std::convertible_to<unsigned int>;
    { t.layout } -> ConvLayout;
    { t.data_type } -> std::convertible_to<DataType>;
    requires ElementwiseOpWellDefinedIfProvided<T>;
    { t.input } -> ConvTensorDescriptor;
    { t.weight } -> ConvTensorDescriptor;
    { t.output } -> ConvTensorDescriptor;
    requires ConvolutionDirectionWellDefinedIfProvided<T>;
};

@@ -84,7 +175,7 @@ concept ConvSignatureDescriptor = requires(T t) {
template <auto Sig>
concept ValidConvSignature = requires {
    requires ConvSpatialDim<Sig.spatial_dim>;
    requires ConvDataType<Sig.data_type>;
    requires ValidConvDataType<Sig.data_type>;
};

// Predicate for forward convolution (default if direction is not included).
@@ -100,4 +191,22 @@ concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_
template <auto Sig>
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);

// Constraints for forward convolution input layouts.
template <TensorLayout L, size_t SpatialDim>
concept ValidConvInputLayoutForSpatialDim =
    (SpatialDim == 1 && ConvInputLayout1D<L>) || (SpatialDim == 2 && ConvInputLayout2D<L>) ||
    (SpatialDim == 3 && ConvInputLayout3D<L>);

// Constraints for forward convolution output layouts.
template <TensorLayout L, size_t SpatialDim>
concept ValidConvOutputLayoutForSpatialDim =
    (SpatialDim == 1 && ConvOutputLayout1D<L>) || (SpatialDim == 2 && ConvOutputLayout2D<L>) ||
    (SpatialDim == 3 && ConvOutputLayout3D<L>);

// Constraints for forward convolution weight layouts.
template <TensorLayout L, size_t SpatialDim>
concept ValidConvWeightLayoutForSpatialDim =
    (SpatialDim == 1 && ConvWeightLayout1D<L>) || (SpatialDim == 2 && ConvWeightLayout2D<L>) ||
    (SpatialDim == 3 && ConvWeightLayout3D<L>);

} // namespace ck_tile::builder

@@ -1,47 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <concepts>
#include <type_traits>

#include "ck_tile/builder/types.hpp"

namespace ck_tile::builder {
/**********************************************
 * constexpr helper functions for optional parameters
 **********************************************/

template <auto Sig>
concept ProvidesElementwiseOperation = requires { Sig.elementwiseOperation; };

template <auto Sig>
concept ProvidesConvolutionDirection = requires { Sig.direction; };

template <auto Sig>
constexpr auto get_elementwise_operation()
{
    if constexpr(ProvidesElementwiseOperation<Sig>)
    {
        return Sig.elementwise_operation;
    }
    else
    {
        return ElementwiseOperation::PASS_THROUGH;
    }
}

template <auto Sig>
constexpr auto get_conv_direction()
{
    if constexpr(ProvidesConvolutionDirection<Sig>)
    {
        return Sig.direction;
    }
    else
    {
        return ConvDirection::FORWARD;
    }
}
} // namespace ck_tile::builder
@@ -7,7 +7,6 @@
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/conv_signature_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
@@ -25,11 +24,9 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdDlFactory
{
    static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
    using Layouts = decltype(internal::GetTensorLayout<SIGNATURE.layout,
                                                       SPATIAL_DIM,
                                                       ConvDirection::FORWARD>());
    using Types = internal::ConvTensorTypes<SIGNATURE.data_type>;
    using Ops = internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
    using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
    using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
    using Ops = internal::ElementwiseOps<SIGNATURE>;
    using AlgorithmType = decltype(ALGORITHM);

    static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();

@@ -8,7 +8,6 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/conv_signature_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
@@ -27,11 +26,9 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdLargeTensorFactory
{
    static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
    using Layouts = decltype(internal::GetTensorLayout<SIGNATURE.layout,
                                                       SPATIAL_DIM,
                                                       ConvDirection::FORWARD>());
    using Types = internal::ConvTensorTypes<SIGNATURE.data_type>;
    using Ops = internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
    using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
    using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
    using Ops = internal::ElementwiseOps<SIGNATURE>;
    using AlgorithmType = decltype(ALGORITHM);

    static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm;

@@ -8,7 +8,6 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/conv_signature_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
@@ -27,11 +26,9 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdXdlV3Factory
{
    static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
    using Layouts = decltype(internal::GetTensorLayout<SIGNATURE.layout,
                                                       SPATIAL_DIM,
                                                       ConvDirection::FORWARD>());
    using Types = internal::ConvTensorTypes<SIGNATURE.data_type>;
    using Ops = internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
    using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
    using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
    using Ops = internal::ElementwiseOps<SIGNATURE>;
    using AlgorithmType = decltype(ALGORITHM);

    static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load ==

@@ -8,7 +8,6 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/conv_signature_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
@@ -27,11 +26,9 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdWmmaFactory
{
    static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
    using Layouts = decltype(internal::GetTensorLayout<SIGNATURE.layout,
                                                       SPATIAL_DIM,
                                                       ConvDirection::FORWARD>());
    using Types = internal::ConvTensorTypes<SIGNATURE.data_type>;
    using Ops = internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
    using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
    using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
    using Ops = internal::ElementwiseOps<SIGNATURE>;
    using AlgorithmType = decltype(ALGORITHM);

    static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();

@@ -8,7 +8,6 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/conv_signature_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
@@ -27,11 +26,9 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdXdlFactory
{
    static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
    using Layouts = decltype(internal::GetTensorLayout<SIGNATURE.layout,
                                                       SPATIAL_DIM,
                                                       ConvDirection::FORWARD>());
    using Types = internal::ConvTensorTypes<SIGNATURE.data_type>;
    using Ops = internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
    using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
    using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
    using Ops = internal::ElementwiseOps<SIGNATURE>;
    using AlgorithmType = decltype(ALGORITHM);

    static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();

@@ -6,32 +6,70 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
template <ElementwiseOperation T>
|
||||
template <ElementwiseOperation Op>
|
||||
struct ElementwiseOpToCK
|
||||
{
|
||||
static_assert(sizeof(UnsupportedEnumValue<Op>) == 0,
|
||||
"Unsupported elementwise operation conversion to CK.");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ElementwiseOpToCK<ElementwiseOperation::PASS_THROUGH>
|
||||
{
|
||||
using Op = ck::tensor_operation::element_wise::PassThrough;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ElementwiseOpToCK<ElementwiseOperation::SCALE>
|
||||
{
|
||||
using Op = ck::tensor_operation::element_wise::Scale;
|
||||
};
|
+template <>
+struct ElementwiseOpToCK<ElementwiseOperation::CLAMP>
+{
+    using Op = ck::tensor_operation::element_wise::Clamp;
+};
+
+template <>
+struct ElementwiseOpToCK<ElementwiseOperation::SCALEADD_SCALEADD_RELU>
+{
+    using Op = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu;
+};
+
+template <>
+struct ElementwiseOpToCK<ElementwiseOperation::BIAS_BNORM_CLAMP>
+{
+    using Op = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
+};
+
+template <auto TensorDesc>
+consteval auto GetElementwiseOp()
+{
+    if constexpr(HasTensorOp<decltype(TensorDesc)>)
+    {
+        constexpr auto op = TensorDesc.operation.elementwise_operation;
+        return ElementwiseOpToCK<op>{};
+    }
+    else
+    {
+        return ElementwiseOpToCK<ElementwiseOperation::PASS_THROUGH>{};
+    }
+}
+
+template <auto Sig>
 struct ElementwiseOps
 {
-    // This will trigger if a specialization for the given DataType is not found.
-    // We should always catch this in an earlier validation check.
-    static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
-                  "Internal error. Unsupported elementwise operation for convolution factory.");
-};
-
-template <>
-struct ElementwiseOps<ElementwiseOperation::PASS_THROUGH>
-{
-    using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
-    using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
-    using CDEElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
-};
-
-template <>
-struct ElementwiseOps<ElementwiseOperation::SCALE>
-{
-    using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
-    using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
-    using CDEElementwiseOp = ck::tensor_operation::element_wise::Scale;
+    static constexpr auto input_op  = GetElementwiseOp<Sig.input>();
+    static constexpr auto weight_op = GetElementwiseOp<Sig.weight>();
+    static constexpr auto output_op = GetElementwiseOp<Sig.output>();
+    using AElementwiseOp = typename decltype(input_op)::Op;
+    using BElementwiseOp = typename decltype(weight_op)::Op;
+    using CDEElementwiseOp = typename decltype(output_op)::Op;
 };

 } // namespace ck_tile::builder::factory::internal
@@ -6,141 +6,228 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/utility/tuple.hpp"
 #include "ck_tile/builder/conv_signature_concepts.hpp"
+#include "ck_tile/builder/builder_utils.hpp"

 namespace ck_tile::builder::factory::internal {

-// Type mappings from the builder FwdGroupConvLayout enum classes to the CK tensor data types.
-template <auto LayoutValue, size_t SPATIAL_DIM, ConvDirection DIR>
-    requires(ConvSpatialDim<SPATIAL_DIM> && ValidConvLayoutForSpatialDim<LayoutValue, SPATIAL_DIM>)
-struct ConvTensorLayouts
+template <TensorLayout Layout>
+struct LayoutToCK
 {
     // This will trigger if a specialization for the given layout is not found.
     // We should always catch this in an earlier validation check.
-    using Layout = decltype(LayoutValue);
-    static_assert(sizeof(Layout) == 0,
-                  "Internal error. Unsupported layout for convolution factory.");
+    static_assert(sizeof(UnsupportedEnumValue<Layout>) == 0,
+                  "Unsupported layout conversion to CK.");
 };

-// 1D Forward Convolution Layout Specializations
+// Bias layouts
 template <>
-struct ConvTensorLayouts<GroupConvLayout1D::NWGC_GKXC_NWGK, 1, ConvDirection::FORWARD>
+struct LayoutToCK<TensorLayout::G_K_strided>
 {
-    using ALayout = ck::tensor_layout::convolution::NWGC;
-    using BLayout = ck::tensor_layout::convolution::GKXC;
-    using DsLayout = ck::Tuple<>;
-    using ELayout = ck::tensor_layout::convolution::NWGK;
+    using type = ck::tensor_layout::convolution::G_K;
 };

 template <>
-struct ConvTensorLayouts<GroupConvLayout1D::NGCW_GKXC_NGKW, 1, ConvDirection::FORWARD>
+struct LayoutToCK<TensorLayout::GC>
 {
-    using ALayout = ck::tensor_layout::convolution::NGCW;
-    using BLayout = ck::tensor_layout::convolution::GKXC;
-    using DsLayout = ck::Tuple<>;
-    using ELayout = ck::tensor_layout::convolution::NGKW;
+    using type = ck::tensor_layout::convolution::GC;
 };

 template <>
-struct ConvTensorLayouts<GroupConvLayout1D::GNWC_GKXC_GNWK, 1, ConvDirection::FORWARD>
+struct LayoutToCK<TensorLayout::G_C_strided>
 {
-    using ALayout = ck::tensor_layout::convolution::GNWC;
-    using BLayout = ck::tensor_layout::convolution::GKXC;
-    using DsLayout = ck::Tuple<>;
-    using ELayout = ck::tensor_layout::convolution::GNWK;
+    using type = ck::tensor_layout::convolution::G_C;
 };

+// Input 1D
 template <>
-struct ConvTensorLayouts<GroupConvLayout1D::NGCW_GKCX_NGKW, 1, ConvDirection::FORWARD>
+struct LayoutToCK<TensorLayout::NWGC>
 {
-    using ALayout = ck::tensor_layout::convolution::NGCW;
-    using BLayout = ck::tensor_layout::convolution::GKCX;
-    using DsLayout = ck::Tuple<>;
-    using ELayout = ck::tensor_layout::convolution::NGKW;
+    using type = ck::tensor_layout::convolution::NWGC;
 };

 template <>
-struct ConvTensorLayouts<GroupConvLayout2D::NGCHW_GKYXC_NGKHW, 2, ConvDirection::FORWARD>
+struct LayoutToCK<TensorLayout::NGCW>
 {
-    using ALayout = ck::tensor_layout::convolution::NGCHW;
-    using BLayout = ck::tensor_layout::convolution::GKYXC;
-    using DsLayout = ck::Tuple<>;
-    using ELayout = ck::tensor_layout::convolution::NGKHW;
+    using type = ck::tensor_layout::convolution::NGCW;
 };

 template <>
-struct ConvTensorLayouts<GroupConvLayout2D::NHWGC_GKYXC_NHWGK, 2, ConvDirection::FORWARD>
+struct LayoutToCK<TensorLayout::GNWC>
 {
-    using ALayout = ck::tensor_layout::convolution::NHWGC;
-    using BLayout = ck::tensor_layout::convolution::GKYXC;
-    using DsLayout = ck::Tuple<>;
-    using ELayout = ck::tensor_layout::convolution::NHWGK;
+    using type = ck::tensor_layout::convolution::GNWC;
 };

+// Input 2D
 template <>
-struct ConvTensorLayouts<GroupConvLayout2D::GNHWC_GKYXC_GNHWK, 2, ConvDirection::FORWARD>
+struct LayoutToCK<TensorLayout::NGCHW>
 {
-    using ALayout = ck::tensor_layout::convolution::GNHWC;
-    using BLayout = ck::tensor_layout::convolution::GKYXC;
-    using DsLayout = ck::Tuple<>;
-    using ELayout = ck::tensor_layout::convolution::GNHWK;
+    using type = ck::tensor_layout::convolution::NGCHW;
 };

 template <>
-struct ConvTensorLayouts<GroupConvLayout2D::NGCHW_GKCYX_NGKHW, 2, ConvDirection::FORWARD>
+struct LayoutToCK<TensorLayout::NHWGC>
 {
-    using ALayout = ck::tensor_layout::convolution::NGCHW;
-    using BLayout = ck::tensor_layout::convolution::GKCYX;
-    using DsLayout = ck::Tuple<>;
-    using ELayout = ck::tensor_layout::convolution::NGKHW;
+    using type = ck::tensor_layout::convolution::NHWGC;
 };

 template <>
-struct ConvTensorLayouts<GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, 3, ConvDirection::FORWARD>
+struct LayoutToCK<TensorLayout::GNHWC>
 {
-    using ALayout = ck::tensor_layout::convolution::NGCDHW;
-    using BLayout = ck::tensor_layout::convolution::GKCZYX;
-    using DsLayout = ck::Tuple<>;
-    using ELayout = ck::tensor_layout::convolution::NGKDHW;
+    using type = ck::tensor_layout::convolution::GNHWC;
 };

+// Input 3D
 template <>
-struct ConvTensorLayouts<GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, 3, ConvDirection::FORWARD>
+struct LayoutToCK<TensorLayout::NGCDHW>
 {
-    using ALayout = ck::tensor_layout::convolution::NDHWGC;
-    using BLayout = ck::tensor_layout::convolution::GKZYXC;
-    using DsLayout = ck::Tuple<>;
-    using ELayout = ck::tensor_layout::convolution::NDHWGK;
+    using type = ck::tensor_layout::convolution::NGCDHW;
 };

 template <>
-struct ConvTensorLayouts<GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, 3, ConvDirection::FORWARD>
+struct LayoutToCK<TensorLayout::NDHWGC>
 {
-    using ALayout = ck::tensor_layout::convolution::GNDHWC;
-    using BLayout = ck::tensor_layout::convolution::GKZYXC;
-    using DsLayout = ck::Tuple<>;
-    using ELayout = ck::tensor_layout::convolution::GNDHWK;
+    using type = ck::tensor_layout::convolution::NDHWGC;
 };
+template <>
+struct LayoutToCK<TensorLayout::GNDHWC>
+{
+    using type = ck::tensor_layout::convolution::GNDHWC;
+};

-template <GroupConvLayout Layout, size_t SPATIAL_DIM, ConvDirection DIR>
-consteval auto GetTensorLayout()
+// Weight 1D
+template <>
+struct LayoutToCK<TensorLayout::GKXC>
+{
+    using type = ck::tensor_layout::convolution::GKXC;
+};
+template <>
+struct LayoutToCK<TensorLayout::GKCX>
+{
+    using type = ck::tensor_layout::convolution::GKCX;
+};

-    if constexpr(SPATIAL_DIM == 1)
-    {
-        return internal::ConvTensorLayouts<Layout._1d, 1, DIR>{};
-    }
-    else if constexpr(SPATIAL_DIM == 2)
-    {
-        return internal::ConvTensorLayouts<Layout._2d, 2, DIR>{};
-    }
-    else if constexpr(SPATIAL_DIM == 3)
-    {
-        return internal::ConvTensorLayouts<Layout._3d, 3, DIR>{};
-    }
-    else
-    {
-        static_assert(false, "Unsupported spatial dimension for convolution layout.");
-    }
+// Weight 2D
+template <>
+struct LayoutToCK<TensorLayout::GKYXC>
+{
+    using type = ck::tensor_layout::convolution::GKYXC;
+};
+template <>
+struct LayoutToCK<TensorLayout::GKCYX>
+{
+    using type = ck::tensor_layout::convolution::GKCYX;
+};

+// Weight 3D
+template <>
+struct LayoutToCK<TensorLayout::GKCZYX>
+{
+    using type = ck::tensor_layout::convolution::GKCZYX;
+};
+template <>
+struct LayoutToCK<TensorLayout::GKZYXC>
+{
+    using type = ck::tensor_layout::convolution::GKZYXC;
+};

+// Output 1D
+template <>
+struct LayoutToCK<TensorLayout::NWGK>
+{
+    using type = ck::tensor_layout::convolution::NWGK;
+};
+template <>
+struct LayoutToCK<TensorLayout::NGKW>
+{
+    using type = ck::tensor_layout::convolution::NGKW;
+};
+template <>
+struct LayoutToCK<TensorLayout::GNWK>
+{
+    using type = ck::tensor_layout::convolution::GNWK;
+};

+// Output 2D
+template <>
+struct LayoutToCK<TensorLayout::NGKHW>
+{
+    using type = ck::tensor_layout::convolution::NGKHW;
+};
+template <>
+struct LayoutToCK<TensorLayout::NHWGK>
+{
+    using type = ck::tensor_layout::convolution::NHWGK;
+};
+template <>
+struct LayoutToCK<TensorLayout::GNHWK>
+{
+    using type = ck::tensor_layout::convolution::GNHWK;
+};

+// Output 3D
+template <>
+struct LayoutToCK<TensorLayout::NGKDHW>
+{
+    using type = ck::tensor_layout::convolution::NGKDHW;
+};
+template <>
+struct LayoutToCK<TensorLayout::NDHWGK>
+{
+    using type = ck::tensor_layout::convolution::NDHWGK;
+};
+template <>
+struct LayoutToCK<TensorLayout::GNDHWK>
+{
+    using type = ck::tensor_layout::convolution::GNDHWK;
+};

+template <TensorLayout Layout>
+consteval auto TensorLayoutToCK()
+{
+    return typename LayoutToCK<Layout>::type{};
+}

+struct EmptyAuxiliaryTensorLayout
+{
+    using type = ck::Tuple<>;
+};

+template <auto AuxiliaryTensorConfigsArray, size_t... Indices>
+consteval auto GetAuxiliaryTensorLayoutTuple(std::index_sequence<Indices...>)
+{
+    return ck::Tuple<
+        decltype(TensorLayoutToCK<AuxiliaryTensorConfigsArray[Indices].layout>())...>{};
+}

+template <auto AuxiliaryTensorConfigsValue, size_t SPATIAL_DIM, ConvDirection DIR>
+    requires(ConvSpatialDim<SPATIAL_DIM>)
+struct AuxiliaryTensorLayouts
+{
+    static constexpr auto Size = AuxiliaryTensorConfigsValue.size();
+    using type = decltype(GetAuxiliaryTensorLayoutTuple<AuxiliaryTensorConfigsValue>(
+        std::make_index_sequence<Size>{}));
+};

+// TODO: Currently only the output tensor can have auxiliary tensors (e.g., bias).
+template <auto Signature, size_t SPATIAL_DIM, ConvDirection DIR>
+    requires(HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
+consteval auto GetAuxiliaryTensorLayouts()
+{
+    return AuxiliaryTensorLayouts<Signature.output.operation.auxiliary_operand_configs,
+                                  SPATIAL_DIM,
+                                  DIR>{};
+}

+template <auto Signature, size_t SPATIAL_DIM, ConvDirection DIR>
+    requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
+consteval auto GetAuxiliaryTensorLayouts()
+{
+    return EmptyAuxiliaryTensorLayout{};
+}

+template <auto Signature, size_t SPATIAL_DIM, ConvDirection DIR>
+    requires(ConvSpatialDim<SPATIAL_DIM> &&
+             ValidConvInputLayoutForSpatialDim<Signature.input.config.layout, SPATIAL_DIM> &&
+             ValidConvWeightLayoutForSpatialDim<Signature.weight.config.layout, SPATIAL_DIM> &&
+             ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
+struct ConvTensorLayouts
+{
+    static_assert(DIR == ConvDirection::FORWARD, "Only Forward convolution is supported.");
+    using ALayout = decltype(TensorLayoutToCK<Signature.input.config.layout>());
+    using BLayout = decltype(TensorLayoutToCK<Signature.weight.config.layout>());
+    using ELayout = decltype(TensorLayoutToCK<Signature.output.config.layout>());
+    using DsLayout = decltype(GetAuxiliaryTensorLayouts<Signature, SPATIAL_DIM, DIR>())::type;
+};

 } // namespace ck_tile::builder::factory::internal
@@ -6,82 +6,172 @@
 #include "ck/utility/data_type.hpp"
 #include "ck_tile/builder/types.hpp"
 #include "ck_tile/builder/builder_utils.hpp"
+#include "ck_tile/builder/conv_signature_concepts.hpp"

 namespace ck_tile::builder::factory::internal {

 // Type mappings from builder convolution data type to CK tensor types.
-template <DataType T>
-struct ConvTensorTypes
+template <DataType DT>
+struct DataTypeToCK
 {
-    // This will trigger if a specialization for the given DataType is not found.
-    // We should always catch this in an earlier validation check.
-    static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
-                  "Internal error. Unsupported data type for convolution factory.");
+    // Catch unsupported data types at compile time
+    static_assert(sizeof(UnsupportedEnumValue<DT>) == 0, "Unsupported data type conversion to CK.");
 };

 template <>
-struct ConvTensorTypes<DataType::FP16>
+struct DataTypeToCK<DataType::FP16>
 {
-    using ADataType = ck::half_t;
-    using AComputeType = ck::half_t;
-    using BDataType = ck::half_t;
-    using BComputeType = ck::half_t;
-    using CShuffleDataType = ck::half_t;
-    using DsDataTypes = ck::Tuple<>;
-    using AccDataType = float;
-    using EDataType = ck::half_t;
+    using type = ck::half_t;
 };
+template <>
+struct DataTypeToCK<DataType::BF16>
+{
+    using type = ck::bhalf_t;
+};
+template <>
+struct DataTypeToCK<DataType::FP32>
+{
+    using type = float;
+};
+template <>
+struct DataTypeToCK<DataType::INT32>
+{
+    using type = int32_t;
+};
+template <>
+struct DataTypeToCK<DataType::I8>
+{
+    using type = int8_t;
+};
+template <>
+struct DataTypeToCK<DataType::FP8>
+{
+    using type = ck::f8_t;
+};

-template <>
-struct ConvTensorTypes<DataType::BF16>
+struct CK_empty_tuple
 {
-    using ADataType = ck::bhalf_t;
-    using AComputeType = ck::bhalf_t;
-    using BDataType = ck::bhalf_t;
-    using BComputeType = ck::bhalf_t;
-    using CShuffleDataType = ck::bhalf_t;
-    using DsDataTypes = ck::Tuple<>;
-    using AccDataType = float;
-    using EDataType = ck::bhalf_t;
+    using type = ck::Tuple<>;
 };

-template <>
-struct ConvTensorTypes<DataType::FP32>
+template <DataType dt>
+consteval auto ConvertDataTypeToCK()
 {
-    using ADataType = float;
-    using AComputeType = float;
-    using BDataType = float;
-    using BComputeType = float;
-    using CShuffleDataType = float;
-    using DsDataTypes = ck::Tuple<>;
-    using AccDataType = float;
-    using EDataType = float;
+    return DataTypeToCK<dt>{};
+}

+template <auto Config, DataType SignatureDataType>
+consteval auto GetTensorDataAndComputeTypes()
+{
+    constexpr auto data_type = Config.data_type;
+    constexpr auto compute_type = Config.compute_type;
+
+    if constexpr(data_type == DataType::UNDEFINDED && compute_type == DataType::UNDEFINDED)
+    {
+        return std::make_pair(ConvertDataTypeToCK<SignatureDataType>(),
+                              ConvertDataTypeToCK<SignatureDataType>());
+    }
+    else if constexpr(data_type == DataType::UNDEFINDED)
+    {
+        return std::make_pair(ConvertDataTypeToCK<SignatureDataType>(),
+                              ConvertDataTypeToCK<compute_type>());
+    }
+    else if constexpr(compute_type == DataType::UNDEFINDED)
+    {
+        return std::make_pair(ConvertDataTypeToCK<data_type>(),
+                              ConvertDataTypeToCK<SignatureDataType>());
+    }
+    else
+    {
+        return std::make_pair(ConvertDataTypeToCK<data_type>(),
+                              ConvertDataTypeToCK<compute_type>());
+    }
+}

+template <DataType SignatureAccDataType, DataType SignatureDataType>
+consteval auto GetTensorAccumulationType()
+{
+    constexpr auto data_type = SignatureAccDataType;
+    if constexpr(data_type == DataType::UNDEFINDED)
+    {
+        return ConvertDataTypeToCK<SignatureDataType>();
+    }
+    else
+    {
+        return ConvertDataTypeToCK<data_type>();
+    }
+}

+template <auto Config, DataType SignatureDataType>
+consteval auto GetAuxiliaryTensorDataTypeValue()
+{
+    constexpr auto data_type = Config.data_type;
+    if constexpr(data_type == DataType::UNDEFINDED)
+    {
+        return ConvertDataTypeToCK<SignatureDataType>();
+    }
+    else
+    {
+        return ConvertDataTypeToCK<data_type>();
+    }
+}

+template <auto AuxiliaryTensorConfigsArray, DataType SignatureDataType, size_t... Indices>
+consteval auto GetAuxiliaryTensorDataTypeTuple(std::index_sequence<Indices...>)
+{
+    return ck::Tuple<
+        typename decltype(GetAuxiliaryTensorDataTypeValue<AuxiliaryTensorConfigsArray[Indices],
+                                                          SignatureDataType>())::type...>{};
+}

+template <auto AuxiliaryTensorConfigsValue, DataType SignatureDataType>
+struct AuxiliaryTensorDataTypes
+{
+    static constexpr auto Size = AuxiliaryTensorConfigsValue.size();
+    using type =
+        decltype(GetAuxiliaryTensorDataTypeTuple<AuxiliaryTensorConfigsValue, SignatureDataType>(
+            std::make_index_sequence<Size>{}));
+};

-template <>
-struct ConvTensorTypes<DataType::I8>
+// TODO: Currently only the output tensor can have auxiliary tensors (e.g., bias).
+template <auto Signature>
+    requires(HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
+consteval auto GetAuxiliaryTensorDataTypes()
 {
-    using ADataType = int8_t;
-    using AComputeType = int8_t;
-    using BDataType = int8_t;
-    using BComputeType = int8_t;
-    using CShuffleDataType = int8_t;
-    using DsDataTypes = ck::Tuple<>;
-    using AccDataType = int32_t;
-    using EDataType = int8_t;
-};
+    return AuxiliaryTensorDataTypes<Signature.output.operation.auxiliary_operand_configs,
+                                    Signature.data_type>{};
+}

-template <>
-struct ConvTensorTypes<DataType::FP8>
+template <auto Signature>
+    requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
+consteval auto GetAuxiliaryTensorDataTypes()
 {
-    using ADataType = ck::f8_t;
-    using AComputeType = ck::f8_t;
-    using BDataType = ck::f8_t;
-    using BComputeType = ck::f8_t;
-    using CShuffleDataType = ck::f8_t;
-    using DsDataTypes = ck::Tuple<>;
-    using AccDataType = float;
-    using EDataType = ck::f8_t;
+    return CK_empty_tuple{};
 }

+template <auto Signature>
+struct FwdConvTensorDataTypes
+{
+    static constexpr auto input_types =
+        GetTensorDataAndComputeTypes<Signature.input.config, Signature.data_type>();
+    static constexpr auto weight_types =
+        GetTensorDataAndComputeTypes<Signature.weight.config, Signature.data_type>();
+    static constexpr auto output_types =
+        GetTensorDataAndComputeTypes<Signature.output.config, Signature.data_type>();
+
+    using ADataType = typename decltype(input_types.first)::type;
+    using AComputeType = typename decltype(input_types.second)::type;
+    using BDataType = typename decltype(weight_types.first)::type;
+    using BComputeType = typename decltype(weight_types.second)::type;
+    using AccDataType =
+        typename decltype(GetTensorAccumulationType<Signature.accumulation_data_type,
+                                                    Signature.data_type>())::type;
+    using EDataType = typename decltype(output_types.first)::type;
+
+    // This is the "compute" type for output.
+    using CShuffleDataType = typename decltype(output_types.second)::type;
+
+    // Data types for the auxiliary tensors (e.g., bias).
+    using DsDataTypes = typename decltype(GetAuxiliaryTensorDataTypes<Signature>())::type;
+};

 } // namespace ck_tile::builder::factory::internal
@@ -41,8 +41,9 @@ struct ConvSignatureInfo
 {
     int spatial_dim;
     builder::ConvDirection direction;
-    std::variant<builder::GroupConvLayout1D, builder::GroupConvLayout2D, builder::GroupConvLayout3D>
-        layout;
+    builder::TensorLayout input_layout;
+    builder::TensorLayout weight_layout;
+    builder::TensorLayout output_layout;
     builder::DataType data_type;
     builder::ElementwiseOperation input_element_op;
     builder::ElementwiseOperation weight_element_op;
@@ -106,7 +107,9 @@ class ConvDescription : public Description
     f.writeLine(0, signature_.spatial_dim, "D ", signature_.direction, " Convolution Kernel");
     f.writeLine(1, "Signature");
     f.writeLine(2, "Tensor Type: ", signature_.data_type);
-    f.writeLine(2, "Memory Layout: ", signature_.layout);
+    f.writeLine(2, "Input Layout: ", signature_.input_layout);
+    f.writeLine(2, "Weight Layout: ", signature_.weight_layout);
+    f.writeLine(2, "Output Layout: ", signature_.output_layout);
     f.writeLine(2, "Input elementwise operation: ", signature_.input_element_op);
     f.writeLine(2, "Weights elementwise operation: ", signature_.weight_element_op);
     f.writeLast(2, "Output elementwise operation: ", signature_.output_element_op);
@@ -264,7 +267,9 @@ conv::ConvDescription describe()
     conv::ConvSignatureInfo{
         .spatial_dim = Traits::spatial_dim,
         .direction = Traits::direction,
-        .layout = Traits::layout,
+        .input_layout = Traits::layout[0],
+        .weight_layout = Traits::layout[1],
+        .output_layout = Traits::layout[2],
         .data_type = Traits::data_type,
         .input_element_op = Traits::input_element_op,
         .weight_element_op = Traits::weight_element_op,
@@ -298,7 +298,10 @@ constexpr auto conv_spec()

 /// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
 /// @tparam Instance The device kernel instance type.
-/// @return A `builder::GroupConvLayout{1D|2D|3D}` enum value corresponding to the tensor layouts.
+/// @return A `std::array` corresponding to the tensor layouts:
+///         index 0 -> Input layout
+///         index 1 -> Weight layout
+///         index 2 -> Output layout
 template <typename Instance>
 constexpr auto conv_layout()
 {
@@ -314,22 +317,30 @@ constexpr auto conv_layout()
         if constexpr(std::is_same_v<ALayout, ctc::GNWC> && std::is_same_v<BLayout, ctc::GKXC> &&
                      std::is_same_v<ELayout, ctc::GNWK>)
         {
-            return builder::GroupConvLayout1D::GNWC_GKXC_GNWK;
+            return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNWC,
+                                                        builder::TensorLayout::GKXC,
+                                                        builder::TensorLayout::GNWK};
         }
         else if constexpr(std::is_same_v<ALayout, ctc::NWGC> &&
                           std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NWGK>)
        {
-            return builder::GroupConvLayout1D::NWGC_GKXC_NWGK;
+            return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NWGC,
+                                                        builder::TensorLayout::GKXC,
+                                                        builder::TensorLayout::NWGK};
         }
         else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
                           std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NGKW>)
         {
-            return builder::GroupConvLayout1D::NGCW_GKXC_NGKW;
+            return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCW,
+                                                        builder::TensorLayout::GKXC,
+                                                        builder::TensorLayout::NGKW};
         }
         else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
                           std::is_same_v<BLayout, ctc::GKCX> && std::is_same_v<ELayout, ctc::NGKW>)
         {
-            return builder::GroupConvLayout1D::NGCW_GKCX_NGKW;
+            return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCW,
+                                                        builder::TensorLayout::GKCX,
+                                                        builder::TensorLayout::NGKW};
         }
     }
     else if constexpr(InstTraits::kSpatialDim == 2)
@@ -337,25 +348,33 @@ constexpr auto conv_layout()
         if constexpr(std::is_same_v<ALayout, ctc::GNHWC> && std::is_same_v<BLayout, ctc::GKYXC> &&
                      std::is_same_v<ELayout, ctc::GNHWK>)
         {
-            return builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
+            return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNHWC,
+                                                        builder::TensorLayout::GKYXC,
+                                                        builder::TensorLayout::GNHWK};
         }
         else if constexpr(std::is_same_v<ALayout, ctc::NHWGC> &&
                           std::is_same_v<BLayout, ctc::GKYXC> &&
                           std::is_same_v<ELayout, ctc::NHWGK>)
         {
-            return builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK;
+            return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NHWGC,
+                                                        builder::TensorLayout::GKYXC,
+                                                        builder::TensorLayout::NHWGK};
         }
         else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
                           std::is_same_v<BLayout, ctc::GKYXC> &&
                           std::is_same_v<ELayout, ctc::NGKHW>)
         {
-            return builder::GroupConvLayout2D::NGCHW_GKYXC_NGKHW;
+            return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCHW,
+                                                        builder::TensorLayout::GKYXC,
+                                                        builder::TensorLayout::NGKHW};
         }
         else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
                           std::is_same_v<BLayout, ctc::GKCYX> &&
                           std::is_same_v<ELayout, ctc::NGKHW>)
         {
-            return builder::GroupConvLayout2D::NGCHW_GKCYX_NGKHW;
+            return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCHW,
+                                                        builder::TensorLayout::GKCYX,
+                                                        builder::TensorLayout::NGKHW};
         }
     }
     else if constexpr(InstTraits::kSpatialDim == 3)
@@ -363,25 +382,33 @@ constexpr auto conv_layout()
         if constexpr(std::is_same_v<ALayout, ctc::GNDHWC> && std::is_same_v<BLayout, ctc::GKZYXC> &&
                      std::is_same_v<ELayout, ctc::GNDHWK>)
         {
-            return builder::GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK;
+            return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNDHWC,
+                                                        builder::TensorLayout::GKZYXC,
+                                                        builder::TensorLayout::GNDHWK};
         }
         else if constexpr(std::is_same_v<ALayout, ctc::NDHWGC> &&
                           std::is_same_v<BLayout, ctc::GKZYXC> &&
                           std::is_same_v<ELayout, ctc::NDHWGK>)
         {
-            return builder::GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK;
+            return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NDHWGC,
+                                                        builder::TensorLayout::GKZYXC,
+                                                        builder::TensorLayout::NDHWGK};
         }
         else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
                           std::is_same_v<BLayout, ctc::GKZYXC> &&
                           std::is_same_v<ELayout, ctc::NGKDHW>)
         {
-            return builder::GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW;
+            return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCDHW,
+                                                        builder::TensorLayout::GKZYXC,
+                                                        builder::TensorLayout::NGKDHW};
         }
         else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
                           std::is_same_v<BLayout, ctc::GKCZYX> &&
                           std::is_same_v<ELayout, ctc::NGKDHW>)
         {
-            return builder::GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW;
+            return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCDHW,
+                                                        builder::TensorLayout::GKCZYX,
+                                                        builder::TensorLayout::NGKDHW};
         }
     }
 }
@@ -433,22 +460,10 @@ template <typename ElementwiseOp>
 constexpr builder::ElementwiseOperation elementwise_op()
 {
     constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
-    if constexpr(detail::case_insensitive_equal(name, "Bias"))
-    {
-        return builder::ElementwiseOperation::BIAS;
-    }
-    else if constexpr(detail::case_insensitive_equal(name, "BiasClamp"))
-    {
-        return builder::ElementwiseOperation::BIAS_CLAMP;
-    }
-    else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
+    if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
     {
         return builder::ElementwiseOperation::BIAS_BNORM_CLAMP;
     }
     else if constexpr(detail::case_insensitive_equal(name, "Bilinear"))
     {
         return builder::ElementwiseOperation::BILINEAR;
     }
     else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
     {
         return builder::ElementwiseOperation::CLAMP;
@@ -461,6 +476,10 @@ constexpr builder::ElementwiseOperation elementwise_op()
     {
         return builder::ElementwiseOperation::PASS_THROUGH;
     }
+    else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
+    {
+        return builder::ElementwiseOperation::SCALEADD_SCALEADD_RELU;
+    }
 }

 /// @brief Derives a gemm padding from a kernel instance type.
@@ -6,64 +6,91 @@
|
||||
#include <ostream>
|
||||
#include <string_view>
|
||||
#include <variant>
|
||||
#include <bit>
|
||||
#include <array>
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
enum class DataType
|
||||
{
|
||||
UNDEFINDED = 0,
|
||||
FP32,
|
||||
FP16,
|
||||
BF16,
|
||||
FP8,
|
||||
INT32,
|
||||
I8,
|
||||
U8
|
||||
};
|
||||
|
||||
-// Memory layouts for 1D convolution tensors.
-// G: Group, N: Batch, K: Output Channel, C: Input Channel, W: Width
-// Enum defines Input, Weight, and Output tensor layouts respectively.
-enum class GroupConvLayout1D
-{
-    GNWC_GKXC_GNWK,
-    NWGC_GKXC_NWGK,
-    NGCW_GKXC_NGKW,
-    NGCW_GKCX_NGKW
-};
-
-// Memory layouts for 2D convolution tensors.
-// G: Group, N: Batch, K: Output Channel, C: Input Channel, Y: Height, X: Width, H: Height
-// Enum defines Input, Weight, and Output tensor layouts respectively.
-enum class GroupConvLayout2D
-{
-    GNHWC_GKYXC_GNHWK,
-    NHWGC_GKYXC_NHWGK,
-    NGCHW_GKYXC_NGKHW,
-    NGCHW_GKCYX_NGKHW
-};
-
-// Memory layouts for 3D convolution tensors.
-// G: Group, N: Batch, K: Output Channel, C: Input Channel, Z: Depth, Y: Height, X: Width, D: Depth,
-// H: Height
-// Enum defines Input, Weight, and Output tensor layouts respectively.
-enum class GroupConvLayout3D
-{
-    GNDHWC_GKZYXC_GNDHWK,
-    NDHWGC_GKZYXC_NDHWGK,
-    NGCDHW_GKZYXC_NGKDHW,
-    NGCDHW_GKCZYX_NGKDHW,
-};
-
-struct GroupConvLayout
-{
-    union
-    {
-        GroupConvLayout1D _1d;
-        GroupConvLayout2D _2d;
-        GroupConvLayout3D _3d;
-    };
-
-    constexpr GroupConvLayout(GroupConvLayout1D layout) : _1d(layout) {}
-    constexpr GroupConvLayout(GroupConvLayout2D layout) : _2d(layout) {}
-    constexpr GroupConvLayout(GroupConvLayout3D layout) : _3d(layout) {}
-};
+enum class TensorLayout
+{
+    UNDEFINED,
+
+    // Bias tensors
+    GC,
+    G_C_strided,
+    G_K_strided,
+
+    // 1D conv input tensor
+    GNCW,
+    GNWC,
+    NWGC,
+    NGCW,
+    G_NW_C_strided,
+
+    // 2D conv input tensor
+    GNCHW,
+    GNHWC,
+    NHWGC,
+    NGCHW,
+    G_NHW_C_strided,
+
+    // 3D conv input tensor
+    GNCDHW,
+    GNDHWC,
+    NDHWGC,
+    NGCDHW,
+    G_NDHW_C_strided,
+
+    // 1D conv weight tensor
+    GKXC,
+    GKCX,
+    KXGC,
+    G_K_X_C_strided,
+
+    // 2D conv weight tensor
+    GKYXC,
+    GKCYX,
+    KYXGC,
+    G_K_YX_C_strided,
+
+    // 3D conv weight tensor
+    GKZYXC,
+    GKCZYX,
+    KZYXGC,
+    G_K_ZYX_C_strided,
+
+    // 1D conv output tensor
+    GNKW,
+    GNWK,
+    NWGK,
+    NGKW,
+    G_NW_K_strided,
+
+    // 2D conv output tensor
+    GNKHW,
+    GNHWK,
+    NHWGK,
+    NGKHW,
+    G_NHW_K_strided,
+
+    // 3D conv output tensor
+    GNKDHW,
+    GNDHWK,
+    NDHWGK,
+    NGKDHW,
+    G_NDHW_K_strided
+};

 // Direction of the convolution operation.
@@ -77,13 +104,11 @@ enum class ConvDirection
 // Fused element-wise operations.
 enum class ElementwiseOperation
 {
     BIAS,
     BIAS_CLAMP,
-    BIAS_BNORM_CLAMP,
-    BILINEAR,
     CLAMP,
-    SCALE,
-    PASS_THROUGH
+    PASS_THROUGH,
+    SCALEADD_SCALEADD_RELU
 };

 // Enums for pipeline versions & schedulers
@@ -188,8 +213,10 @@ inline std::ostream& operator<<(std::ostream& os, DataType dt)
     case FP32: return os << "FP32";
     case BF16: return os << "BF16";
     case FP8: return os << "FP8";
     case INT32: return os << "INT32";
+    case I8: return os << "I8";
+    case U8: return os << "U8";
     case UNDEFINDED: return os << "UNDEFINDED";
     default: return os << "Unknown";
     }
 }
@@ -206,57 +233,16 @@ inline std::ostream& operator<<(std::ostream& os, ConvDirection dir)
     }
 }

-inline std::ostream& operator<<(std::ostream& os, GroupConvLayout1D layout)
-{
-    using enum GroupConvLayout1D;
-    switch(layout)
-    {
-    case GNWC_GKXC_GNWK: return os << "GNWC_GKXC_GNWK";
-    case NWGC_GKXC_NWGK: return os << "NWGC_GKXC_NWGK";
-    case NGCW_GKXC_NGKW: return os << "NGCW_GKXC_NGKW";
-    case NGCW_GKCX_NGKW: return os << "NGCW_GKCX_NGKW";
-    default: return os << "Unknown";
-    }
-}
-
-inline std::ostream& operator<<(std::ostream& os, GroupConvLayout2D layout)
-{
-    using enum GroupConvLayout2D;
-    switch(layout)
-    {
-    case GNHWC_GKYXC_GNHWK: return os << "GNHWC_GKYXC_GNHWK";
-    case NHWGC_GKYXC_NHWGK: return os << "NHWGC_GKYXC_NHWGK";
-    case NGCHW_GKYXC_NGKHW: return os << "NGCHW_GKYXC_NGKHW";
-    case NGCHW_GKCYX_NGKHW: return os << "NGCHW_GKCYX_NGKHW";
-    default: return os << "Unknown";
-    }
-}
-
-inline std::ostream& operator<<(std::ostream& os, GroupConvLayout3D layout)
-{
-    using enum GroupConvLayout3D;
-    switch(layout)
-    {
-    case GNDHWC_GKZYXC_GNDHWK: return os << "GNDHWC_GKZYXC_GNDHWK";
-    case NDHWGC_GKZYXC_NDHWGK: return os << "NDHWGC_GKZYXC_NDHWGK";
-    case NGCDHW_GKZYXC_NGKDHW: return os << "NGCDHW_GKZYXC_NGKDHW";
-    case NGCDHW_GKCZYX_NGKDHW: return os << "NGCDHW_GKCZYX_NGKDHW";
-    default: return os << "Unknown";
-    }
-}
-
 inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op)
 {
     using enum ElementwiseOperation;
     switch(op)
     {
     case BIAS: return os << "BIAS";
     case BIAS_CLAMP: return os << "BIAS_CLAMP";
-    case BIAS_BNORM_CLAMP: return os << "BIAS_BNORM_CLAMP";
-    case BILINEAR: return os << "BILINEAR";
     case CLAMP: return os << "CLAMP";
-    case SCALE: return os << "SCALE";
     case PASS_THROUGH: return os << "PASS_THROUGH";
+    case SCALEADD_SCALEADD_RELU: return os << "SCALEADD_SCALEADD_RELU";
     default: return os << "Unknown";
     }
 }
@@ -375,13 +361,59 @@ inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched)
     }
 }

-// ostream operator overload for std::variant of layout types
-inline std::ostream&
-operator<<(std::ostream& os,
-           const std::variant<GroupConvLayout1D, GroupConvLayout2D, GroupConvLayout3D>& layout)
-{
-    std::visit([&os](const auto& l) { os << l; }, layout);
-    return os;
-}
+inline std::ostream& operator<<(std::ostream& os, TensorLayout layout)
+{
+    using enum TensorLayout;
+    switch(layout)
+    {
+    case GNCW: return os << "GNCW";
+    case GNWC: return os << "GNWC";
+    case NWGC: return os << "NWGC";
+    case NGCW: return os << "NGCW";
+    case G_NW_C_strided: return os << "G_NW_C_strided";
+    case GNCHW: return os << "GNCHW";
+    case GNHWC: return os << "GNHWC";
+    case NHWGC: return os << "NHWGC";
+    case NGCHW: return os << "NGCHW";
+    case G_NHW_C_strided: return os << "G_NHW_C_strided";
+    case GNCDHW: return os << "GNCDHW";
+    case GNDHWC: return os << "GNDHWC";
+    case NDHWGC: return os << "NDHWGC";
+    case NGCDHW: return os << "NGCDHW";
+    case G_NDHW_C_strided: return os << "G_NDHW_C_strided";
+    case GKXC: return os << "GKXC";
+    case GKCX: return os << "GKCX";
+    case KXGC: return os << "KXGC";
+    case G_K_X_C_strided: return os << "G_K_X_C_strided";
+    case GKYXC: return os << "GKYXC";
+    case GKCYX: return os << "GKCYX";
+    case KYXGC: return os << "KYXGC";
+    case G_K_YX_C_strided: return os << "G_K_YX_C_strided";
+    case GKZYXC: return os << "GKZYXC";
+    case GKCZYX: return os << "GKCZYX";
+    case KZYXGC: return os << "KZYXGC";
+    case G_K_ZYX_C_strided: return os << "G_K_ZYX_C_strided";
+    case GNKW: return os << "GNKW";
+    case GNWK: return os << "GNWK";
+    case NWGK: return os << "NWGK";
+    case NGKW: return os << "NGKW";
+    case G_NW_K_strided: return os << "G_NW_K_strided";
+    case GNKHW: return os << "GNKHW";
+    case GNHWK: return os << "GNHWK";
+    case NHWGK: return os << "NHWGK";
+    case NGKHW: return os << "NGKHW";
+    case G_NHW_K_strided: return os << "G_NHW_K_strided";
+    case GNKDHW: return os << "GNKDHW";
+    case GNDHWK: return os << "GNDHWK";
+    case NDHWGK: return os << "NDHWGK";
+    case NGKDHW: return os << "NGKDHW";
+    case G_NDHW_K_strided: return os << "G_NDHW_K_strided";
+    case GC: return os << "GC";
+    case G_C_strided: return os << "G_C_strided";
+    case G_K_strided: return os << "G_K_strided";
+    case UNDEFINED: return os << "UNDEFINED";
+    default: return os << "Unknown";
+    }
+}

 // ostream operator overload for std::variant of convolution specializations