mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_BUILDER] Refactor convolution signature to provide data type/layout/elementwise op per tensor (#3331)
* Separate layouts into separate entities for input, weight, and output tensors.
* Add test for handling bias tensor layouts.
* Use instance string in builder tests.
* Add handling of output bias data types and layouts.
* Generalize handling of the elementwise ops.
* Test fix.
* Create builder for layouts.
* Layout builder improvements.
* Improve layout builder.
* Simplify bias layout handling.
* Code clean-up.
* Move layout utils into separate file.
* Remove hard-coded layout combinations.
* Small code clean-up.
* Move data type utils into a separate file.
* Add data types, layouts, and elementwise ops per conv tensor.
* Builder bug fixes after refactoring.
* Working baseline.
* Make signature definition look nice in the test code.
* Move TensorConfig into test implementations.
* Fix all fwd conv builder tests.
* Fix conv traits and descriptors tests.
* More factory assets under a separate directory.
* Fix building conv traits.
* Fix clang-format.
* Add Readme doc to describe the design.
* Add link to main Readme. Fix links in the builder design doc.
* Clean-up data type/layout/elementwise op conversions.
* Switch from dimension and tensor type specific layouts to a flat list of tensor layouts.
* Fix clang-formatting.
* Fix clang-format for test code.
* Simplify fwd conv signature definitions in the test code.
* Remove accidental edits.
* Fix comment string.
* Fix instance factory after rebase.
* Fix tests after rebase.
* Unify layout handling.
* Add more conv layout unit tests.
* Clang-format.
* Fix merge conflicts.
* Improve elementwise op handling.

---------

Co-authored-by: Ville Pietilä <>
@@ -10,6 +10,10 @@ The builder provides a high-level, semantically-clear interface for constructing

This project is a prototype for a more general builder pattern for all of composable_kernel (CK) and CKTile, but is currently limited to formalizing the interface between MIOpen and CK.

## Design descriptions

- [CK Builder design description](include/ck_tile/builder/README.md)

## Directory Structure

- `include/ck_tile/builder/`
244
experimental/builder/include/ck_tile/builder/README.md
Normal file
@@ -0,0 +1,244 @@

# Composable Kernel Builder Design Documentation

This directory contains the builder framework for Composable Kernel, which provides a compile-time, type-safe interface for constructing convolution operations with various configurations.

## Table of Contents

- [Convolution Signature Design](#convolution-signature-design)
  - [Overview](#overview)
  - [Architecture](#architecture)
  - [Core Components](#core-components)
  - [Concepts and Validation](#concepts-and-validation)

---
## Convolution Signature Design

### Overview

The convolution signature system provides a **compile-time description** of grouped convolution operations. A signature is a collection of properties that fully characterize a convolution kernel's mathematical and operational behavior, enabling:

- **Compile-time validation**: Ensures type safety and correctness before kernel instantiation
- **Kernel selection**: Matches user requirements to optimized implementations
- **Specialization**: Enables optimized code paths for specific configurations
- **Composability**: Supports building complex operations from simpler components

The signature leverages modern C++20 features, particularly **concepts**, to provide expressive, self-documenting interfaces with compile-time guarantees.

### Architecture

The signature system is organized into a hierarchical structure:
```
┌─────────────────────────────────────────────────────────┐
│                      ConvSignature                      │
├─────────────────────────────────────────────────────────┤
│ Properties:                                             │
│  • spatial_dim: int (1D, 2D, or 3D)                     │
│  • direction: ConvDirection (Fwd/BwdData/BwdWeight)     │
│  • data_type: DataType (default data type)              │
│  • accumulation_data_type: DataType                     │
│  • input:  ConvTensor ──┐                               │
│  • weight: ConvTensor ──│                               │
│  • output: ConvTensor ──│                               │
└─────────────────────────┼───────────────────────────────┘
                          │
                          ▼
      ┌──────────────────────────────────────────┐
      │                ConvTensor                │
      ├──────────────────────────────────────────┤
      │ ╔════════════════════════════════════╗   │
      │ ║ TensorConfig (required)            ║   │
      │ ╠════════════════════════════════════╣   │
      │ ║ • layout: ConvLayout               ║   │
      │ ║ • data_type: DataType (optional)   ║   │
      │ ║ • compute_type: DataType (optional)║   │
      │ ╚════════════════════════════════════╝   │
      │                                          │
      │ ┌────────────────────────────────────┐   │
      │ │ TensorOperation (optional)         │   │
      │ ├────────────────────────────────────┤   │
      │ │ • elementwise_operation            │   │
      │ │ • auxiliary_operand_configs[]      │   │
      │ │   (each is also ConvTensor) ◄──────┼───┼─┐
      │ └────────────────────────────────────┘   │ │
      └──────────────────────────────────────────┘ │
                                                   │
                                    Recursive ─────┘
```
Key Design Points:

- ConvSignature contains three ConvTensor instances (input, weight, output)
- All tensors share the same ConvTensor structure
- Each ConvTensor has:
  - TensorConfig (required): Defines the layout as well as optional data and compute type overrides
  - TensorOperation (optional): Defines fused elementwise operations
- Auxiliary operands (e.g., bias) in TensorOperation also use the ConvTensor type
### Core Components

#### 1. Signature Level

The top-level signature contains global properties that apply to the entire convolution operation:
```cpp
template <typename T>
concept ConvSignatureDescriptor = requires(T t) {
    { t.spatial_dim } -> std::convertible_to<unsigned int>; // 1, 2, or 3
    { t.data_type } -> std::convertible_to<DataType>;       // Default data type
    { t.input } -> ConvTensorDescriptor;
    { t.weight } -> ConvTensorDescriptor;
    { t.output } -> ConvTensorDescriptor;
    requires ConvolutionDirectionWellDefinedIfProvided<T>; // Optional direction
};
```
**Properties:**

- **`spatial_dim`**: Dimensionality of the convolution (1D, 2D, or 3D)
- **`direction`**: Operation type (optional, defaults to `FORWARD`)
  - `FORWARD`: Standard forward convolution
  - `BACKWARD_DATA`: Gradient computation w.r.t. input
  - `BACKWARD_WEIGHT`: Gradient computation w.r.t. weights
- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8)
- **`accumulation_data_type`**: Type used for internal accumulation

#### 2. Tensor Level

Each tensor (input, weight, output) has its own descriptor:
```cpp
template <typename T>
concept ConvTensorDescriptor = requires(T t) {
    { t.config } -> TensorConfigDescriptor;
    requires ElementwiseOpWellDefinedIfProvided<T>;
};
```
A tensor descriptor encapsulates:

- **Configuration**: Layout and data type information
- **Operation** (optional): Fused elementwise operations on this tensor

#### 3. Tensor Configuration

Describes the memory layout and data types:
```cpp
template <typename T>
concept TensorConfigDescriptor = requires(T t) {
    { t.layout } -> std::convertible_to<ConvLayout>;
    { t.data_type } -> std::convertible_to<DataType>; // Optional override
};
```
**Layout Types** (dimension-specific):

- **1D Convolution**:
  - Input: `GNCW`, `GNWC`, `NWGC`, `NGCW`, `G_NW_C_strided`
  - Weight: `GKXC`, `GKCX`, `KXGC`, `G_K_X_C_strided`
  - Output: `GNKW`, `GNWK`, `NWGK`, `NGKW`, `G_NW_K_strided`
- **2D Convolution**:
  - Input: `GNCHW`, `GNHWC`, `NHWGC`, `NGCHW`, `G_NHW_C_strided`
  - Weight: `GKYXC`, `GKCYX`, `KYXGC`, `G_K_YX_C_strided`
  - Output: `GNKHW`, `GNHWK`, `NHWGK`, `NGKHW`, `G_NHW_K_strided`
- **3D Convolution**:
  - Input: `GNCDHW`, `GNDHWC`, `NDHWGC`, `NGCDHW`, `G_NDHW_C_strided`
  - Weight: `GKZYXC`, `GKCZYX`, `KZYXGC`, `G_K_ZYX_C_strided`
  - Output: `GNKDHW`, `GNDHWK`, `NDHWGK`, `NGKDHW`, `G_NDHW_K_strided`

Where:

- `G` = Groups
- `N` = Batch size
- `C` = Input channels
- `K` = Output channels (filters)
- `W`, `H`, `D` = Width, Height, Depth (spatial dimensions)
- `X`, `Y`, `Z` = Filter dimensions

#### 4. Tensor Operations

Describes fused elementwise operations applied to a tensor:
```cpp
template <typename T>
concept TensorOperatorDescriptor = requires(T t) {
    { t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
    requires AuxiliaryOperandConfigsWellDefinedIfProvided<T>;
};
```
**Supported Operations:**

- `PASS_THROUGH`: No operation (identity)
- `SCALE`: Multiply by a scalar
- `CLAMP`: Clamp values to a range
- `BIAS_BNORM_CLAMP`: Bias addition + batch normalization + clamp
- `SCALEADD_SCALEADD_RELU`: Fused scale-add operations + ReLU activation

**Auxiliary Operands:**

Some operations require additional tensor inputs (e.g., bias tensors, scaling factors). These are specified through `auxiliary_operand_configs`, which is an array of `TensorConfigDescriptor` objects describing the layout and data type of each auxiliary input.
### Concepts and Validation

The signature system uses C++20 concepts for compile-time validation at multiple levels:

#### Constraint Concepts
```cpp
// Spatial dimension must be 1, 2, or 3
template <auto N>
concept ConvSpatialDim = std::is_integral_v<decltype(N)> && (N == 1 || N == 2 || N == 3);

// Valid data types for convolution
template <DataType T>
concept ValidConvDataType =
    (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) ||
    (T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8);
```
#### Validation Concept

```cpp
// Validates a complete signature
template <auto Sig>
concept ValidConvSignature = requires {
    requires ConvSpatialDim<Sig.spatial_dim>;
    requires ValidConvDataType<Sig.data_type>;
};
```
#### Tensor Descriptors

The layout, data type, and elementwise operation are described per tensor. This multi-level hierarchy allows:

- **Flexibility**: Each tensor can have an independent layout and data type
- **Reusability**: Common configurations can be shared across different signatures
- **Extensibility**: New properties can be added to specific levels without affecting others
- **Clarity**: Separates concerns (global properties vs. tensor-specific properties)

#### Optional Signature Fields

Several fields in the signature are optional:

- **`direction`**: Defaults to `FORWARD` if not specified, reducing boilerplate for the common case
- **Tensor `data_type`**: Falls back to the signature's default, allowing mixed precision with minimal specification
- **Tensor `operation`**: Defaults to `PASS_THROUGH`, supporting both fused and non-fused operations with the same interface

This design follows the principle of "make the common case simple, the complex case possible."
#### Union-Based Layout Representation

The `ConvLayout` type uses unions to support dimension-agnostic code:

```cpp
struct ConvLayout {
    union {
        ConvInputLayout _input_layout;
        ConvWeightLayout _weight_layout;
        ConvOutputLayout _output_layout;
        ConvAuxiliaryTensorLayout _aux_tensor_layout;
    };
    // ... constructors for each type
};
```
This allows:

- A single type to represent all layout variants
- Type-safe construction through overloaded constructors
- Compile-time enforcement of valid combinations through concepts

---
@@ -28,24 +28,104 @@ namespace ck_tile::builder {
template <auto N>
concept ConvSpatialDim = std::is_integral_v<decltype(N)> && (N == 1 || N == 2 || N == 3);

// Constraints for forward convolution layouts.
template <auto LayoutValue, size_t SpatialDim>
concept ValidConvLayoutForSpatialDim =
    (SpatialDim == 1 && std::same_as<decltype(LayoutValue), GroupConvLayout1D>) ||
    (SpatialDim == 2 && std::same_as<decltype(LayoutValue), GroupConvLayout2D>) ||
    (SpatialDim == 3 && std::same_as<decltype(LayoutValue), GroupConvLayout3D>);

// Constrains convolution data types to common floating-point types.
template <DataType T>
concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) ||
                       (T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8);
concept ValidConvDataType =
    (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) ||
    (T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8);

template <TensorLayout L>
concept BiasTensorLayout =
    (L == TensorLayout::GC) || (L == TensorLayout::G_C_strided) || (L == TensorLayout::G_K_strided);

template <TensorLayout L>
concept ConvInputLayout1D =
    (L == TensorLayout::GNCW) || (L == TensorLayout::GNWC) || (L == TensorLayout::NWGC) ||
    (L == TensorLayout::NGCW) || (L == TensorLayout::G_NW_C_strided);

template <TensorLayout L>
concept ConvInputLayout2D =
    (L == TensorLayout::GNCHW) || (L == TensorLayout::GNHWC) || (L == TensorLayout::NHWGC) ||
    (L == TensorLayout::NGCHW) || (L == TensorLayout::G_NHW_C_strided);

template <TensorLayout L>
concept ConvInputLayout3D =
    (L == TensorLayout::GNCDHW) || (L == TensorLayout::GNDHWC) || (L == TensorLayout::NDHWGC) ||
    (L == TensorLayout::NGCDHW) || (L == TensorLayout::G_NDHW_C_strided);

template <TensorLayout L>
concept ConvWeightLayout1D = (L == TensorLayout::GKXC) || (L == TensorLayout::GKCX) ||
                             (L == TensorLayout::KXGC) || (L == TensorLayout::G_K_X_C_strided);

template <TensorLayout L>
concept ConvWeightLayout2D = (L == TensorLayout::GKYXC) || (L == TensorLayout::GKCYX) ||
                             (L == TensorLayout::KYXGC) || (L == TensorLayout::G_K_YX_C_strided);

template <TensorLayout L>
concept ConvWeightLayout3D = (L == TensorLayout::GKZYXC) || (L == TensorLayout::GKCZYX) ||
                             (L == TensorLayout::KZYXGC) || (L == TensorLayout::G_K_ZYX_C_strided);

template <TensorLayout L>
concept ConvOutputLayout1D =
    (L == TensorLayout::GNKW) || (L == TensorLayout::GNWK) || (L == TensorLayout::NWGK) ||
    (L == TensorLayout::NGKW) || (L == TensorLayout::G_NW_K_strided);

template <TensorLayout L>
concept ConvOutputLayout2D =
    (L == TensorLayout::GNKHW) || (L == TensorLayout::GNHWK) || (L == TensorLayout::NHWGK) ||
    (L == TensorLayout::NGKHW) || (L == TensorLayout::G_NHW_K_strided);

template <TensorLayout L>
concept ConvOutputLayout3D =
    (L == TensorLayout::GNKDHW) || (L == TensorLayout::GNDHWK) || (L == TensorLayout::NDHWGK) ||
    (L == TensorLayout::NGKDHW) || (L == TensorLayout::G_NDHW_K_strided);

template <typename T>
concept ConvLayout = std::same_as<std::remove_cvref_t<T>, GroupConvLayout>;
concept TensorConfigDescriptor = requires(T t) {
    { t.layout } -> std::convertible_to<TensorLayout>;
    // Only require that data type is defined. It might be set to undefined value, in which case the
    // signature's data type is used.
    { t.data_type } -> std::convertible_to<DataType>;
};
template <typename T>
concept HasElementwiseOp = requires(T t) {
    { t.elementwise_operation };
concept HasAuxiliaryOperandConfigs = requires(T t) {
    { t.auxiliary_operand_configs };
};

namespace detail {
template <typename T>
struct IsArrayOfTensorConfigDescriptors : std::false_type
{
};

template <typename T, std::size_t N>
    requires TensorConfigDescriptor<T>
struct IsArrayOfTensorConfigDescriptors<std::array<T, N>> : std::true_type
{
};
} // namespace detail

template <typename T>
concept ConvertibleToArrayOfTensorConfigs =
    detail::IsArrayOfTensorConfigDescriptors<std::remove_cvref_t<T>>::value;

template <typename T>
concept AuxiliaryOperandConfigsWellDefinedIfProvided = requires(T t) {
    requires !HasAuxiliaryOperandConfigs<T> || requires {
        { t.auxiliary_operand_configs } -> ConvertibleToArrayOfTensorConfigs;
    };
};

template <typename T>
concept TensorOperatorDescriptor = requires(T t) {
    { t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
    requires AuxiliaryOperandConfigsWellDefinedIfProvided<T>;
};

template <typename T>
concept HasTensorOp = requires(T t) {
    { t.operation };
};

template <typename T>
@@ -56,11 +136,8 @@ concept HasConvolutionDirection = requires(T t) {
// Note: it is not required to provide an ElementwiseOp, but if one is provided, check if well
// defined
template <typename T>
concept ElementwiseOpWellDefinedIfProvided = requires(T t) {
    requires !HasElementwiseOp<T> || requires {
        { t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
    };
};
concept ElementwiseOpWellDefinedIfProvided =
    !HasTensorOp<T> || requires(T t) { requires TensorOperatorDescriptor<decltype(t.operation)>; };

// Note: it is not required to provide a convolution, but if one is provided, check if well defined
template <typename T>
@@ -70,13 +147,27 @@ concept ConvolutionDirectionWellDefinedIfProvided = requires(T t) {
    };
};

// Concept for the convolution tensor
template <typename T>
concept ConvTensorDescriptor = requires(T t) {
    { t.config } -> TensorConfigDescriptor;
    requires ElementwiseOpWellDefinedIfProvided<T>;
};

template <typename T>
concept HasElementwiseOpWithAuxiliaryOperands = requires(T t) {
    requires HasTensorOp<T>;
    requires HasAuxiliaryOperandConfigs<decltype(t.operation)>;
};

// Concept for a type that defines a convolution's operational signature.
template <typename T>
concept ConvSignatureDescriptor = requires(T t) {
    { t.spatial_dim } -> std::convertible_to<unsigned int>;
    { t.layout } -> ConvLayout;
    { t.data_type } -> std::convertible_to<DataType>;
    requires ElementwiseOpWellDefinedIfProvided<T>;
    { t.input } -> ConvTensorDescriptor;
    { t.weight } -> ConvTensorDescriptor;
    { t.output } -> ConvTensorDescriptor;
    requires ConvolutionDirectionWellDefinedIfProvided<T>;
};
@@ -84,7 +175,7 @@ concept ConvSignatureDescriptor = requires(T t) {
template <auto Sig>
concept ValidConvSignature = requires {
    requires ConvSpatialDim<Sig.spatial_dim>;
    requires ConvDataType<Sig.data_type>;
    requires ValidConvDataType<Sig.data_type>;
};

// Predicate for forward convolution (default if direction is not included).
@@ -100,4 +191,22 @@ concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_
template <auto Sig>
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);

// Constraints for forward convolution input layouts.
template <TensorLayout L, size_t SpatialDim>
concept ValidConvInputLayoutForSpatialDim =
    (SpatialDim == 1 && ConvInputLayout1D<L>) || (SpatialDim == 2 && ConvInputLayout2D<L>) ||
    (SpatialDim == 3 && ConvInputLayout3D<L>);

// Constraints for forward convolution output layouts.
template <TensorLayout L, size_t SpatialDim>
concept ValidConvOutputLayoutForSpatialDim =
    (SpatialDim == 1 && ConvOutputLayout1D<L>) || (SpatialDim == 2 && ConvOutputLayout2D<L>) ||
    (SpatialDim == 3 && ConvOutputLayout3D<L>);

// Constraints for forward convolution weight layouts.
template <TensorLayout L, size_t SpatialDim>
concept ValidConvWeightLayoutForSpatialDim =
    (SpatialDim == 1 && ConvWeightLayout1D<L>) || (SpatialDim == 2 && ConvWeightLayout2D<L>) ||
    (SpatialDim == 3 && ConvWeightLayout3D<L>);

} // namespace ck_tile::builder
@@ -1,47 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
/**********************************************
|
||||
* constexpr helper functions for optional parameters
|
||||
**********************************************/
|
||||
|
||||
template <auto Sig>
|
||||
concept ProvidesElementwiseOperation = requires { Sig.elementwiseOperation; };
|
||||
|
||||
template <auto Sig>
|
||||
concept ProvidesConvolutionDirection = requires { Sig.direction; };
|
||||
|
||||
template <auto Sig>
|
||||
constexpr auto get_elementwise_operation()
|
||||
{
|
||||
if constexpr(ProvidesElementwiseOperation<Sig>)
|
||||
{
|
||||
return Sig.elementwise_operation;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ElementwiseOperation::PASS_THROUGH;
|
||||
}
|
||||
}
|
||||
|
||||
template <auto Sig>
|
||||
constexpr auto get_conv_direction()
|
||||
{
|
||||
if constexpr(ProvidesConvolutionDirection<Sig>)
|
||||
{
|
||||
return Sig.direction;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ConvDirection::FORWARD;
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile::builder
|
||||
@@ -7,7 +7,6 @@
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/conv_signature_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
@@ -25,11 +24,9 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdDlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = decltype(internal::GetTensorLayout<SIGNATURE.layout,
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
|
||||
@@ -8,7 +8,6 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/conv_signature_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
@@ -27,11 +26,9 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdLargeTensorFactory
{
    static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
    using Layouts = decltype(internal::GetTensorLayout<SIGNATURE.layout,
                                                       SPATIAL_DIM,
                                                       ConvDirection::FORWARD>());
    using Types = internal::ConvTensorTypes<SIGNATURE.data_type>;
    using Ops = internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
    using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
    using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
    using Ops = internal::ElementwiseOps<SIGNATURE>;
    using AlgorithmType = decltype(ALGORITHM);

    static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm;
@@ -8,7 +8,6 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/conv_signature_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
@@ -27,11 +26,9 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdXdlV3Factory
{
    static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
    using Layouts = decltype(internal::GetTensorLayout<SIGNATURE.layout,
                                                       SPATIAL_DIM,
                                                       ConvDirection::FORWARD>());
    using Types = internal::ConvTensorTypes<SIGNATURE.data_type>;
    using Ops = internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
    using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
    using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
    using Ops = internal::ElementwiseOps<SIGNATURE>;
    using AlgorithmType = decltype(ALGORITHM);

    static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load ==
@@ -8,7 +8,6 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/conv_signature_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
@@ -27,11 +26,9 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdWmmaFactory
{
    static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
    using Layouts = decltype(internal::GetTensorLayout<SIGNATURE.layout,
                                                       SPATIAL_DIM,
                                                       ConvDirection::FORWARD>());
    using Types = internal::ConvTensorTypes<SIGNATURE.data_type>;
    using Ops = internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
    using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
    using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
    using Ops = internal::ElementwiseOps<SIGNATURE>;
    using AlgorithmType = decltype(ALGORITHM);

    static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
@@ -8,7 +8,6 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/conv_signature_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
@@ -27,11 +26,9 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdXdlFactory
{
    static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
    using Layouts = decltype(internal::GetTensorLayout<SIGNATURE.layout,
                                                       SPATIAL_DIM,
                                                       ConvDirection::FORWARD>());
    using Types = internal::ConvTensorTypes<SIGNATURE.data_type>;
    using Ops = internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
    using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
    using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
    using Ops = internal::ElementwiseOps<SIGNATURE>;
    using AlgorithmType = decltype(ALGORITHM);

    static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
@@ -6,32 +6,70 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/types.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"

namespace ck_tile::builder::factory::internal {

template <ElementwiseOperation T>
template <ElementwiseOperation Op>
struct ElementwiseOpToCK
{
    static_assert(sizeof(UnsupportedEnumValue<Op>) == 0,
                  "Unsupported elementwise operation conversion to CK.");
};

template <>
struct ElementwiseOpToCK<ElementwiseOperation::PASS_THROUGH>
{
    using Op = ck::tensor_operation::element_wise::PassThrough;
};

template <>
struct ElementwiseOpToCK<ElementwiseOperation::SCALE>
{
    using Op = ck::tensor_operation::element_wise::Scale;
};

template <>
struct ElementwiseOpToCK<ElementwiseOperation::CLAMP>
{
    using Op = ck::tensor_operation::element_wise::Clamp;
};

template <>
struct ElementwiseOpToCK<ElementwiseOperation::SCALEADD_SCALEADD_RELU>
{
    using Op = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu;
};

template <>
struct ElementwiseOpToCK<ElementwiseOperation::BIAS_BNORM_CLAMP>
{
    using Op = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
};

template <auto TensorDesc>
consteval auto GetElementwiseOp()
{
    if constexpr(HasTensorOp<decltype(TensorDesc)>)
    {
        constexpr auto op = TensorDesc.operation.elementwise_operation;
        return ElementwiseOpToCK<op>{};
    }
    else
    {
        return ElementwiseOpToCK<ElementwiseOperation::PASS_THROUGH>{};
    }
}

template <auto Sig>
struct ElementwiseOps
{
    // This will trigger if a specialization for the given DataType is not found.
    // We should always catch this in an earlier validation check.
    static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
                  "Internal error. Unsupported elementwise operation for convolution factory.");
};

template <>
struct ElementwiseOps<ElementwiseOperation::PASS_THROUGH>
{
    using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
    using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
    using CDEElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
};

template <>
struct ElementwiseOps<ElementwiseOperation::SCALE>
{
    using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
    using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
    using CDEElementwiseOp = ck::tensor_operation::element_wise::Scale;
    static constexpr auto input_op = GetElementwiseOp<Sig.input>();
    static constexpr auto weight_op = GetElementwiseOp<Sig.weight>();
    static constexpr auto output_op = GetElementwiseOp<Sig.output>();
    using AElementwiseOp = typename decltype(input_op)::Op;
    using BElementwiseOp = typename decltype(weight_op)::Op;
    using CDEElementwiseOp = typename decltype(output_op)::Op;
};

} // namespace ck_tile::builder::factory::internal
@@ -6,141 +6,228 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/tuple.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/builder_utils.hpp"

namespace ck_tile::builder::factory::internal {

// Type mappings from the builder tensor layout enums to the CK tensor layouts.
template <auto LayoutValue, size_t SPATIAL_DIM, ConvDirection DIR>
    requires(ConvSpatialDim<SPATIAL_DIM> && ValidConvLayoutForSpatialDim<LayoutValue, SPATIAL_DIM>)
struct ConvTensorLayouts
template <TensorLayout Layout>
struct LayoutToCK
{
    // This will trigger if a specialization for the given layout is not found.
    // We should always catch this in an earlier validation check.
    using Layout = decltype(LayoutValue);
    static_assert(sizeof(Layout) == 0,
                  "Internal error. Unsupported layout for convolution factory.");
    static_assert(sizeof(UnsupportedEnumValue<Layout>) == 0,
                  "Unsupported layout conversion to CK.");
};

// 1D Forward Convolution Layout Specializations
// Bias layouts
template <>
struct ConvTensorLayouts<GroupConvLayout1D::NWGC_GKXC_NWGK, 1, ConvDirection::FORWARD>
struct LayoutToCK<TensorLayout::G_K_strided>
{
    using ALayout  = ck::tensor_layout::convolution::NWGC;
    using BLayout  = ck::tensor_layout::convolution::GKXC;
    using DsLayout = ck::Tuple<>;
    using ELayout  = ck::tensor_layout::convolution::NWGK;
    using type = ck::tensor_layout::convolution::G_K;
};

template <>
struct ConvTensorLayouts<GroupConvLayout1D::NGCW_GKXC_NGKW, 1, ConvDirection::FORWARD>
struct LayoutToCK<TensorLayout::GC>
{
    using ALayout  = ck::tensor_layout::convolution::NGCW;
    using BLayout  = ck::tensor_layout::convolution::GKXC;
    using DsLayout = ck::Tuple<>;
    using ELayout  = ck::tensor_layout::convolution::NGKW;
    using type = ck::tensor_layout::convolution::GC;
};

template <>
struct ConvTensorLayouts<GroupConvLayout1D::GNWC_GKXC_GNWK, 1, ConvDirection::FORWARD>
struct LayoutToCK<TensorLayout::G_C_strided>
{
    using ALayout  = ck::tensor_layout::convolution::GNWC;
    using BLayout  = ck::tensor_layout::convolution::GKXC;
    using DsLayout = ck::Tuple<>;
    using ELayout  = ck::tensor_layout::convolution::GNWK;
    using type = ck::tensor_layout::convolution::G_C;
};

// Input 1D
template <>
struct ConvTensorLayouts<GroupConvLayout1D::NGCW_GKCX_NGKW, 1, ConvDirection::FORWARD>
struct LayoutToCK<TensorLayout::NWGC>
{
    using ALayout  = ck::tensor_layout::convolution::NGCW;
    using BLayout  = ck::tensor_layout::convolution::GKCX;
    using DsLayout = ck::Tuple<>;
    using ELayout  = ck::tensor_layout::convolution::NGKW;
    using type = ck::tensor_layout::convolution::NWGC;
};

template <>
struct ConvTensorLayouts<GroupConvLayout2D::NGCHW_GKYXC_NGKHW, 2, ConvDirection::FORWARD>
struct LayoutToCK<TensorLayout::NGCW>
{
    using ALayout  = ck::tensor_layout::convolution::NGCHW;
    using BLayout  = ck::tensor_layout::convolution::GKYXC;
    using DsLayout = ck::Tuple<>;
    using ELayout  = ck::tensor_layout::convolution::NGKHW;
    using type = ck::tensor_layout::convolution::NGCW;
};

template <>
struct ConvTensorLayouts<GroupConvLayout2D::NHWGC_GKYXC_NHWGK, 2, ConvDirection::FORWARD>
struct LayoutToCK<TensorLayout::GNWC>
{
    using ALayout  = ck::tensor_layout::convolution::NHWGC;
    using BLayout  = ck::tensor_layout::convolution::GKYXC;
    using DsLayout = ck::Tuple<>;
    using ELayout  = ck::tensor_layout::convolution::NHWGK;
    using type = ck::tensor_layout::convolution::GNWC;
};

// Input 2D
template <>
struct ConvTensorLayouts<GroupConvLayout2D::GNHWC_GKYXC_GNHWK, 2, ConvDirection::FORWARD>
struct LayoutToCK<TensorLayout::NGCHW>
{
    using ALayout  = ck::tensor_layout::convolution::GNHWC;
    using BLayout  = ck::tensor_layout::convolution::GKYXC;
    using DsLayout = ck::Tuple<>;
    using ELayout  = ck::tensor_layout::convolution::GNHWK;
    using type = ck::tensor_layout::convolution::NGCHW;
};

template <>
struct ConvTensorLayouts<GroupConvLayout2D::NGCHW_GKCYX_NGKHW, 2, ConvDirection::FORWARD>
struct LayoutToCK<TensorLayout::NHWGC>
{
    using ALayout  = ck::tensor_layout::convolution::NGCHW;
    using BLayout  = ck::tensor_layout::convolution::GKCYX;
    using DsLayout = ck::Tuple<>;
    using ELayout  = ck::tensor_layout::convolution::NGKHW;
    using type = ck::tensor_layout::convolution::NHWGC;
};

template <>
struct ConvTensorLayouts<GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, 3, ConvDirection::FORWARD>
struct LayoutToCK<TensorLayout::GNHWC>
{
    using ALayout  = ck::tensor_layout::convolution::NGCDHW;
    using BLayout  = ck::tensor_layout::convolution::GKCZYX;
    using DsLayout = ck::Tuple<>;
    using ELayout  = ck::tensor_layout::convolution::NGKDHW;
    using type = ck::tensor_layout::convolution::GNHWC;
};

// Input 3D
template <>
struct ConvTensorLayouts<GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, 3, ConvDirection::FORWARD>
struct LayoutToCK<TensorLayout::NGCDHW>
{
    using ALayout  = ck::tensor_layout::convolution::NDHWGC;
    using BLayout  = ck::tensor_layout::convolution::GKZYXC;
    using DsLayout = ck::Tuple<>;
    using ELayout  = ck::tensor_layout::convolution::NDHWGK;
    using type = ck::tensor_layout::convolution::NGCDHW;
};

template <>
struct ConvTensorLayouts<GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, 3, ConvDirection::FORWARD>
struct LayoutToCK<TensorLayout::NDHWGC>
{
    using ALayout  = ck::tensor_layout::convolution::GNDHWC;
    using BLayout  = ck::tensor_layout::convolution::GKZYXC;
    using DsLayout = ck::Tuple<>;
    using ELayout  = ck::tensor_layout::convolution::GNDHWK;
    using type = ck::tensor_layout::convolution::NDHWGC;
};
template <>
struct LayoutToCK<TensorLayout::GNDHWC>
{
    using type = ck::tensor_layout::convolution::GNDHWC;
};

template <GroupConvLayout Layout, size_t SPATIAL_DIM, ConvDirection DIR>
consteval auto GetTensorLayout()
// Weight 1D
template <>
struct LayoutToCK<TensorLayout::GKXC>
{
    using type = ck::tensor_layout::convolution::GKXC;
};
template <>
struct LayoutToCK<TensorLayout::GKCX>
{
    using type = ck::tensor_layout::convolution::GKCX;
};

    if constexpr(SPATIAL_DIM == 1)
    {
        return internal::ConvTensorLayouts<Layout._1d, 1, DIR>{};
    }
    else if constexpr(SPATIAL_DIM == 2)
    {
        return internal::ConvTensorLayouts<Layout._2d, 2, DIR>{};
    }
    else if constexpr(SPATIAL_DIM == 3)
    {
        return internal::ConvTensorLayouts<Layout._3d, 3, DIR>{};
    }
    else
    {
        static_assert(false, "Unsupported spatial dimension for convolution layout.");
    }
// Weight 2D
template <>
struct LayoutToCK<TensorLayout::GKYXC>
{
    using type = ck::tensor_layout::convolution::GKYXC;
};
template <>
struct LayoutToCK<TensorLayout::GKCYX>
{
    using type = ck::tensor_layout::convolution::GKCYX;
};

// Weight 3D
template <>
struct LayoutToCK<TensorLayout::GKCZYX>
{
    using type = ck::tensor_layout::convolution::GKCZYX;
};
template <>
struct LayoutToCK<TensorLayout::GKZYXC>
{
    using type = ck::tensor_layout::convolution::GKZYXC;
};

// Output 1D
template <>
struct LayoutToCK<TensorLayout::NWGK>
{
    using type = ck::tensor_layout::convolution::NWGK;
};
template <>
struct LayoutToCK<TensorLayout::NGKW>
{
    using type = ck::tensor_layout::convolution::NGKW;
};
template <>
struct LayoutToCK<TensorLayout::GNWK>
{
    using type = ck::tensor_layout::convolution::GNWK;
};

// Output 2D
template <>
struct LayoutToCK<TensorLayout::NGKHW>
{
    using type = ck::tensor_layout::convolution::NGKHW;
};
template <>
struct LayoutToCK<TensorLayout::NHWGK>
{
    using type = ck::tensor_layout::convolution::NHWGK;
};
template <>
struct LayoutToCK<TensorLayout::GNHWK>
{
    using type = ck::tensor_layout::convolution::GNHWK;
};

// Output 3D
template <>
struct LayoutToCK<TensorLayout::NGKDHW>
{
    using type = ck::tensor_layout::convolution::NGKDHW;
};
template <>
struct LayoutToCK<TensorLayout::NDHWGK>
{
    using type = ck::tensor_layout::convolution::NDHWGK;
};
template <>
struct LayoutToCK<TensorLayout::GNDHWK>
{
    using type = ck::tensor_layout::convolution::GNDHWK;
};

template <TensorLayout Layout>
consteval auto TensorLayoutToCK()
{
    return typename LayoutToCK<Layout>::type{};
}

struct EmptyAuxiliaryTensorLayout
{
    using type = ck::Tuple<>;
};

template <auto AuxiliaryTensorConfigsArray, size_t... Indices>
consteval auto GetAuxiliaryTensorLayoutTuple(std::index_sequence<Indices...>)
{
    return ck::Tuple<
        decltype(TensorLayoutToCK<AuxiliaryTensorConfigsArray[Indices].layout>())...>{};
}

template <auto AuxiliaryTensorConfigsValue, size_t SPATIAL_DIM, ConvDirection DIR>
    requires(ConvSpatialDim<SPATIAL_DIM>)
struct AuxiliaryTensorLayouts
{
    static constexpr auto Size = AuxiliaryTensorConfigsValue.size();
    using type = decltype(GetAuxiliaryTensorLayoutTuple<AuxiliaryTensorConfigsValue>(
        std::make_index_sequence<Size>{}));
};

// TODO: Currently only the output tensor can have auxiliary tensors (e.g., bias).
template <auto Signature, size_t SPATIAL_DIM, ConvDirection DIR>
    requires(HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
consteval auto GetAuxiliaryTensorLayouts()
{
    return AuxiliaryTensorLayouts<Signature.output.operation.auxiliary_operand_configs,
                                  SPATIAL_DIM,
                                  DIR>{};
}

template <auto Signature, size_t SPATIAL_DIM, ConvDirection DIR>
    requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
consteval auto GetAuxiliaryTensorLayouts()
{
    return EmptyAuxiliaryTensorLayout{};
}

template <auto Signature, size_t SPATIAL_DIM, ConvDirection DIR>
    requires(ConvSpatialDim<SPATIAL_DIM> &&
             ValidConvInputLayoutForSpatialDim<Signature.input.config.layout, SPATIAL_DIM> &&
             ValidConvWeightLayoutForSpatialDim<Signature.weight.config.layout, SPATIAL_DIM> &&
             ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
struct ConvTensorLayouts
{
    static_assert(DIR == ConvDirection::FORWARD, "Only Forward convolution is supported.");
    using ALayout  = decltype(TensorLayoutToCK<Signature.input.config.layout>());
    using BLayout  = decltype(TensorLayoutToCK<Signature.weight.config.layout>());
    using ELayout  = decltype(TensorLayoutToCK<Signature.output.config.layout>());
    using DsLayout = decltype(GetAuxiliaryTensorLayouts<Signature, SPATIAL_DIM, DIR>())::type;
};

} // namespace ck_tile::builder::factory::internal
@@ -6,82 +6,172 @@
#include "ck/utility/data_type.hpp"
#include "ck_tile/builder/types.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"

namespace ck_tile::builder::factory::internal {

// Type mappings from builder convolution data type to CK tensor types.
template <DataType T>
struct ConvTensorTypes
template <DataType DT>
struct DataTypeToCK
{
    // This will trigger if a specialization for the given DataType is not found.
    // We should always catch this in an earlier validation check.
    static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
                  "Internal error. Unsupported data type for convolution factory.");
    // Catch unsupported data types at compile time
    static_assert(sizeof(UnsupportedEnumValue<DT>) == 0, "Unsupported data type conversion to CK.");
};

template <>
struct ConvTensorTypes<DataType::FP16>
struct DataTypeToCK<DataType::FP16>
{
    using ADataType        = ck::half_t;
    using AComputeType     = ck::half_t;
    using BDataType        = ck::half_t;
    using BComputeType     = ck::half_t;
    using CShuffleDataType = ck::half_t;
    using DsDataTypes      = ck::Tuple<>;
    using AccDataType      = float;
    using EDataType        = ck::half_t;
    using type = ck::half_t;
};
template <>
struct DataTypeToCK<DataType::BF16>
{
    using type = ck::bhalf_t;
};
template <>
struct DataTypeToCK<DataType::FP32>
{
    using type = float;
};
template <>
struct DataTypeToCK<DataType::INT32>
{
    using type = int32_t;
};
template <>
struct DataTypeToCK<DataType::I8>
{
    using type = int8_t;
};
template <>
struct DataTypeToCK<DataType::FP8>
{
    using type = ck::f8_t;
};

template <>
struct ConvTensorTypes<DataType::BF16>
struct CK_empty_tuple
{
    using ADataType        = ck::bhalf_t;
    using AComputeType     = ck::bhalf_t;
    using BDataType        = ck::bhalf_t;
    using BComputeType     = ck::bhalf_t;
    using CShuffleDataType = ck::bhalf_t;
    using DsDataTypes      = ck::Tuple<>;
    using AccDataType      = float;
    using EDataType        = ck::bhalf_t;
    using type = ck::Tuple<>;
};

template <>
struct ConvTensorTypes<DataType::FP32>
template <DataType dt>
consteval auto ConvertDataTypeToCK()
{
    using ADataType        = float;
    using AComputeType     = float;
    using BDataType        = float;
    using BComputeType     = float;
    using CShuffleDataType = float;
    using DsDataTypes      = ck::Tuple<>;
    using AccDataType      = float;
    using EDataType        = float;
    return DataTypeToCK<dt>{};
}

template <auto Config, DataType SignatureDataType>
consteval auto GetTensorDataAndComputeTypes()
{
    constexpr auto data_type    = Config.data_type;
    constexpr auto compute_type = Config.compute_type;

    if constexpr(data_type == DataType::UNDEFINDED && compute_type == DataType::UNDEFINDED)
    {
        return std::make_pair(ConvertDataTypeToCK<SignatureDataType>(),
                              ConvertDataTypeToCK<SignatureDataType>());
    }
    else if constexpr(data_type == DataType::UNDEFINDED)
    {
        return std::make_pair(ConvertDataTypeToCK<SignatureDataType>(),
                              ConvertDataTypeToCK<compute_type>());
    }
    else if constexpr(compute_type == DataType::UNDEFINDED)
    {
        return std::make_pair(ConvertDataTypeToCK<data_type>(),
                              ConvertDataTypeToCK<SignatureDataType>());
    }
    else
    {
        return std::make_pair(ConvertDataTypeToCK<data_type>(),
                              ConvertDataTypeToCK<compute_type>());
    }
}

template <DataType SignatureAccDataType, DataType SignatureDataType>
consteval auto GetTensorAccumulationType()
{
    constexpr auto data_type = SignatureAccDataType;
    if constexpr(data_type == DataType::UNDEFINDED)
    {
        return ConvertDataTypeToCK<SignatureDataType>();
    }
    else
    {
        return ConvertDataTypeToCK<data_type>();
    }
}

template <auto Config, DataType SignatureDataType>
consteval auto GetAuxiliaryTensorDataTypeValue()
{
    constexpr auto data_type = Config.data_type;
    if constexpr(data_type == DataType::UNDEFINDED)
    {
        return ConvertDataTypeToCK<SignatureDataType>();
    }
    else
    {
        return ConvertDataTypeToCK<data_type>();
    }
}

template <auto AuxiliaryTensorConfigsArray, DataType SignatureDataType, size_t... Indices>
consteval auto GetAuxiliaryTensorDataTypeTuple(std::index_sequence<Indices...>)
{
    return ck::Tuple<
        typename decltype(GetAuxiliaryTensorDataTypeValue<AuxiliaryTensorConfigsArray[Indices],
                                                          SignatureDataType>())::type...>{};
}

template <auto AuxiliaryTensorConfigsValue, DataType SignatureDataType>
struct AuxiliaryTensorDataTypes
{
    static constexpr auto Size = AuxiliaryTensorConfigsValue.size();
    using type =
        decltype(GetAuxiliaryTensorDataTypeTuple<AuxiliaryTensorConfigsValue, SignatureDataType>(
            std::make_index_sequence<Size>{}));
};

template <>
struct ConvTensorTypes<DataType::I8>
// TODO: Currently only the output tensor can have auxiliary tensors (e.g., bias).
template <auto Signature>
    requires(HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
consteval auto GetAuxiliaryTensorDataTypes()
{
    using ADataType        = int8_t;
    using AComputeType     = int8_t;
    using BDataType        = int8_t;
    using BComputeType     = int8_t;
    using CShuffleDataType = int8_t;
    using DsDataTypes      = ck::Tuple<>;
    using AccDataType      = int32_t;
    using EDataType        = int8_t;
};
    return AuxiliaryTensorDataTypes<Signature.output.operation.auxiliary_operand_configs,
                                    Signature.data_type>{};
}

template <>
struct ConvTensorTypes<DataType::FP8>
template <auto Signature>
    requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
consteval auto GetAuxiliaryTensorDataTypes()
{
    using ADataType        = ck::f8_t;
    using AComputeType     = ck::f8_t;
    using BDataType        = ck::f8_t;
    using BComputeType     = ck::f8_t;
    using CShuffleDataType = ck::f8_t;
    using DsDataTypes      = ck::Tuple<>;
    using AccDataType      = float;
    using EDataType        = ck::f8_t;
    return CK_empty_tuple{};
}

template <auto Signature>
struct FwdConvTensorDataTypes
{
    static constexpr auto input_types =
        GetTensorDataAndComputeTypes<Signature.input.config, Signature.data_type>();
    static constexpr auto weight_types =
        GetTensorDataAndComputeTypes<Signature.weight.config, Signature.data_type>();
    static constexpr auto output_types =
        GetTensorDataAndComputeTypes<Signature.output.config, Signature.data_type>();

    using ADataType    = typename decltype(input_types.first)::type;
    using AComputeType = typename decltype(input_types.second)::type;
    using BDataType    = typename decltype(weight_types.first)::type;
    using BComputeType = typename decltype(weight_types.second)::type;
    using AccDataType =
        typename decltype(GetTensorAccumulationType<Signature.accumulation_data_type,
                                                    Signature.data_type>())::type;
    using EDataType = typename decltype(output_types.first)::type;

    // This is the "compute" type for output.
    using CShuffleDataType = typename decltype(output_types.second)::type;

    // Data types for the auxiliary tensors (e.g., bias).
    using DsDataTypes = typename decltype(GetAuxiliaryTensorDataTypes<Signature>())::type;
};

} // namespace ck_tile::builder::factory::internal
@@ -41,8 +41,9 @@ struct ConvSignatureInfo
{
    int spatial_dim;
    builder::ConvDirection direction;
    std::variant<builder::GroupConvLayout1D, builder::GroupConvLayout2D, builder::GroupConvLayout3D>
        layout;
    builder::TensorLayout input_layout;
    builder::TensorLayout weight_layout;
    builder::TensorLayout output_layout;
    builder::DataType data_type;
    builder::ElementwiseOperation input_element_op;
    builder::ElementwiseOperation weight_element_op;
@@ -106,7 +107,9 @@ class ConvDescription : public Description
    f.writeLine(0, signature_.spatial_dim, "D ", signature_.direction, " Convolution Kernel");
    f.writeLine(1, "Signature");
    f.writeLine(2, "Tensor Type: ", signature_.data_type);
    f.writeLine(2, "Memory Layout: ", signature_.layout);
    f.writeLine(2, "Input Layout: ", signature_.input_layout);
    f.writeLine(2, "Weight Layout: ", signature_.weight_layout);
    f.writeLine(2, "Output Layout: ", signature_.output_layout);
    f.writeLine(2, "Input elementwise operation: ", signature_.input_element_op);
    f.writeLine(2, "Weights elementwise operation: ", signature_.weight_element_op);
    f.writeLast(2, "Output elementwise operation: ", signature_.output_element_op);
@@ -264,7 +267,9 @@ conv::ConvDescription describe()
    conv::ConvSignatureInfo{
        .spatial_dim = Traits::spatial_dim,
        .direction = Traits::direction,
        .layout = Traits::layout,
        .input_layout = Traits::layout[0],
        .weight_layout = Traits::layout[1],
        .output_layout = Traits::layout[2],
        .data_type = Traits::data_type,
        .input_element_op = Traits::input_element_op,
        .weight_element_op = Traits::weight_element_op,
@@ -298,7 +298,10 @@ constexpr auto conv_spec()

/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::GroupConvLayout{1D|2D|3D}` enum value corresponding to the tensor layouts.
/// @return An std::array corresponding to the tensor layouts:
///         index 0 -> Input layout
///         index 1 -> Weight layout
///         index 2 -> Output layout
template <typename Instance>
constexpr auto conv_layout()
{
@@ -314,22 +317,30 @@ constexpr auto conv_layout()
    if constexpr(std::is_same_v<ALayout, ctc::GNWC> && std::is_same_v<BLayout, ctc::GKXC> &&
                 std::is_same_v<ELayout, ctc::GNWK>)
    {
        return builder::GroupConvLayout1D::GNWC_GKXC_GNWK;
        return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNWC,
                                                    builder::TensorLayout::GKXC,
                                                    builder::TensorLayout::GNWK};
    }
    else if constexpr(std::is_same_v<ALayout, ctc::NWGC> &&
                      std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NWGK>)
    {
        return builder::GroupConvLayout1D::NWGC_GKXC_NWGK;
        return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NWGC,
                                                    builder::TensorLayout::GKXC,
                                                    builder::TensorLayout::NWGK};
    }
    else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
                      std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NGKW>)
    {
        return builder::GroupConvLayout1D::NGCW_GKXC_NGKW;
        return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCW,
                                                    builder::TensorLayout::GKXC,
                                                    builder::TensorLayout::NGKW};
    }
    else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
                      std::is_same_v<BLayout, ctc::GKCX> && std::is_same_v<ELayout, ctc::NGKW>)
    {
        return builder::GroupConvLayout1D::NGCW_GKCX_NGKW;
        return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCW,
                                                    builder::TensorLayout::GKCX,
                                                    builder::TensorLayout::NGKW};
    }
}
else if constexpr(InstTraits::kSpatialDim == 2)
@@ -337,25 +348,33 @@ constexpr auto conv_layout()
    if constexpr(std::is_same_v<ALayout, ctc::GNHWC> && std::is_same_v<BLayout, ctc::GKYXC> &&
                 std::is_same_v<ELayout, ctc::GNHWK>)
    {
        return builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
        return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNHWC,
                                                    builder::TensorLayout::GKYXC,
                                                    builder::TensorLayout::GNHWK};
    }
    else if constexpr(std::is_same_v<ALayout, ctc::NHWGC> &&
                      std::is_same_v<BLayout, ctc::GKYXC> &&
                      std::is_same_v<ELayout, ctc::NHWGK>)
    {
        return builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK;
        return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NHWGC,
                                                    builder::TensorLayout::GKYXC,
                                                    builder::TensorLayout::NHWGK};
    }
    else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
                      std::is_same_v<BLayout, ctc::GKYXC> &&
                      std::is_same_v<ELayout, ctc::NGKHW>)
    {
        return builder::GroupConvLayout2D::NGCHW_GKYXC_NGKHW;
        return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCHW,
                                                    builder::TensorLayout::GKYXC,
                                                    builder::TensorLayout::NGKHW};
    }
    else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
                      std::is_same_v<BLayout, ctc::GKCYX> &&
                      std::is_same_v<ELayout, ctc::NGKHW>)
    {
        return builder::GroupConvLayout2D::NGCHW_GKCYX_NGKHW;
        return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCHW,
                                                    builder::TensorLayout::GKCYX,
                                                    builder::TensorLayout::NGKHW};
    }
}
else if constexpr(InstTraits::kSpatialDim == 3)
@@ -363,25 +382,33 @@ constexpr auto conv_layout()
    if constexpr(std::is_same_v<ALayout, ctc::GNDHWC> && std::is_same_v<BLayout, ctc::GKZYXC> &&
                 std::is_same_v<ELayout, ctc::GNDHWK>)
    {
        return builder::GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK;
        return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNDHWC,
                                                    builder::TensorLayout::GKZYXC,
                                                    builder::TensorLayout::GNDHWK};
    }
    else if constexpr(std::is_same_v<ALayout, ctc::NDHWGC> &&
                      std::is_same_v<BLayout, ctc::GKZYXC> &&
                      std::is_same_v<ELayout, ctc::NDHWGK>)
    {
        return builder::GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK;
        return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NDHWGC,
                                                    builder::TensorLayout::GKZYXC,
                                                    builder::TensorLayout::NDHWGK};
    }
    else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
                      std::is_same_v<BLayout, ctc::GKZYXC> &&
                      std::is_same_v<ELayout, ctc::NGKDHW>)
    {
        return builder::GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW;
        return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCDHW,
                                                    builder::TensorLayout::GKZYXC,
                                                    builder::TensorLayout::NGKDHW};
    }
    else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
                      std::is_same_v<BLayout, ctc::GKCZYX> &&
                      std::is_same_v<ELayout, ctc::NGKDHW>)
    {
        return builder::GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW;
        return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCDHW,
                                                    builder::TensorLayout::GKCZYX,
                                                    builder::TensorLayout::NGKDHW};
    }
}
}
@@ -433,22 +460,10 @@ template <typename ElementwiseOp>
constexpr builder::ElementwiseOperation elementwise_op()
{
    constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
    if constexpr(detail::case_insensitive_equal(name, "Bias"))
    {
        return builder::ElementwiseOperation::BIAS;
    }
    else if constexpr(detail::case_insensitive_equal(name, "BiasClamp"))
    {
        return builder::ElementwiseOperation::BIAS_CLAMP;
    }
    else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
    if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
    {
        return builder::ElementwiseOperation::BIAS_BNORM_CLAMP;
    }
    else if constexpr(detail::case_insensitive_equal(name, "Bilinear"))
    {
        return builder::ElementwiseOperation::BILINEAR;
    }
    else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
    {
        return builder::ElementwiseOperation::CLAMP;
@@ -461,6 +476,10 @@ constexpr builder::ElementwiseOperation elementwise_op()
    {
        return builder::ElementwiseOperation::PASS_THROUGH;
    }
    else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
    {
        return builder::ElementwiseOperation::SCALEADD_SCALEADD_RELU;
    }
}

/// @brief Derives a gemm padding from a kernel instance type.
@@ -6,64 +6,91 @@
|
||||
#include <ostream>
|
||||
#include <string_view>
|
||||
#include <variant>
|
||||
#include <bit>
|
||||
#include <array>
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
enum class DataType
|
||||
{
|
||||
UNDEFINDED = 0,
|
||||
FP32,
|
||||
FP16,
|
||||
BF16,
|
||||
FP8,
|
||||
INT32,
|
||||
I8,
|
||||
U8
|
||||
};
|
||||
|
||||
// Memory layouts for 1D convolution tensors.
|
||||
// G: Group, N: Batch, K: Output Channel, C: Input Channel, W: Width
|
||||
// Enum defines Input, Weight, and Output tensor layouts respectively.
|
||||
enum class GroupConvLayout1D
|
||||
enum class TensorLayout
|
||||
{
|
||||
GNWC_GKXC_GNWK,
|
||||
NWGC_GKXC_NWGK,
|
||||
NGCW_GKXC_NGKW,
|
||||
NGCW_GKCX_NGKW
|
||||
};
|
||||
UNDEFINED,

// Memory layouts for 2D convolution tensors.
// G: Group, N: Batch, K: Output Channel, C: Input Channel, Y: Height, X: Width, H: Height
// Enum defines Input, Weight, and Output tensor layouts respectively.
enum class GroupConvLayout2D
{
GNHWC_GKYXC_GNHWK,
NHWGC_GKYXC_NHWGK,
NGCHW_GKYXC_NGKHW,
NGCHW_GKCYX_NGKHW
};
// Bias tensors
GC,
G_C_strided,
G_K_strided,

// Memory layouts for 3D convolution tensors.
// G: Group, N: Batch, K: Output Channel, C: Input Channel, Z: Depth, Y: Height, X: Width, D: Depth,
// H: Height Enum defines Input, Weight, and Output tensor layouts respectively.
enum class GroupConvLayout3D
{
GNDHWC_GKZYXC_GNDHWK,
NDHWGC_GKZYXC_NDHWGK,
NGCDHW_GKZYXC_NGKDHW,
NGCDHW_GKCZYX_NGKDHW,
};
// 1D conv input tensor
GNCW,
GNWC,
NWGC,
NGCW,
G_NW_C_strided,

struct GroupConvLayout
{
union
{
GroupConvLayout1D _1d;
GroupConvLayout2D _2d;
GroupConvLayout3D _3d;
};
// 2D conv input tensor
GNCHW,
GNHWC,
NHWGC,
NGCHW,
G_NHW_C_strided,

constexpr GroupConvLayout(GroupConvLayout1D layout) : _1d(layout) {}
constexpr GroupConvLayout(GroupConvLayout2D layout) : _2d(layout) {}
constexpr GroupConvLayout(GroupConvLayout3D layout) : _3d(layout) {}
// 3D conv input tensor
GNCDHW,
GNDHWC,
NDHWGC,
NGCDHW,
G_NDHW_C_strided,

// 1D conv weight tensor
GKXC,
GKCX,
KXGC,
G_K_X_C_strided,

// 2D conv weight tensor
GKYXC,
GKCYX,
KYXGC,
G_K_YX_C_strided,

// 3D conv weight tensor
GKZYXC,
GKCZYX,
KZYXGC,
G_K_ZYX_C_strided,

// 1D conv output tensor
GNKW,
GNWK,
NWGK,
NGKW,
G_NW_K_strided,

// 2D conv output tensor
GNKHW,
GNHWK,
NHWGK,
NGKHW,
G_NHW_K_strided,

// 3D conv output tensor
GNKDHW,
GNDHWK,
NDHWGK,
NGKDHW,
G_NDHW_K_strided
};

// Direction of the convolution operation.
@@ -77,13 +104,11 @@ enum class ConvDirection
// Fused element-wise operations.
enum class ElementwiseOperation
{
BIAS,
BIAS_CLAMP,
BIAS_BNORM_CLAMP,
BILINEAR,
CLAMP,
SCALE,
PASS_THROUGH
CLAMP,
PASS_THROUGH,
SCALEADD_SCALEADD_RELU
};

// Enums for pipeline versions & schedulers
@@ -188,8 +213,10 @@ inline std::ostream& operator<<(std::ostream& os, DataType dt)
case FP32: return os << "FP32";
case BF16: return os << "BF16";
case FP8: return os << "FP8";
case INT32: return os << "INT32";
case I8: return os << "I8";
case U8: return os << "U8";
case UNDEFINDED: return os << "UNDEFINDED";
default: return os << "Unknown";
}
}
@@ -206,57 +233,16 @@ inline std::ostream& operator<<(std::ostream& os, ConvDirection dir)
}
}

inline std::ostream& operator<<(std::ostream& os, GroupConvLayout1D layout)
{
using enum GroupConvLayout1D;
switch(layout)
{
case GNWC_GKXC_GNWK: return os << "GNWC_GKXC_GNWK";
case NWGC_GKXC_NWGK: return os << "NWGC_GKXC_NWGK";
case NGCW_GKXC_NGKW: return os << "NGCW_GKXC_NGKW";
case NGCW_GKCX_NGKW: return os << "NGCW_GKCX_NGKW";
default: return os << "Unknown";
}
}

inline std::ostream& operator<<(std::ostream& os, GroupConvLayout2D layout)
{
using enum GroupConvLayout2D;
switch(layout)
{
case GNHWC_GKYXC_GNHWK: return os << "GNHWC_GKYXC_GNHWK";
case NHWGC_GKYXC_NHWGK: return os << "NHWGC_GKYXC_NHWGK";
case NGCHW_GKYXC_NGKHW: return os << "NGCHW_GKYXC_NGKHW";
case NGCHW_GKCYX_NGKHW: return os << "NGCHW_GKCYX_NGKHW";
default: return os << "Unknown";
}
}

inline std::ostream& operator<<(std::ostream& os, GroupConvLayout3D layout)
{
using enum GroupConvLayout3D;
switch(layout)
{
case GNDHWC_GKZYXC_GNDHWK: return os << "GNDHWC_GKZYXC_GNDHWK";
case NDHWGC_GKZYXC_NDHWGK: return os << "NDHWGC_GKZYXC_NDHWGK";
case NGCDHW_GKZYXC_NGKDHW: return os << "NGCDHW_GKZYXC_NGKDHW";
case NGCDHW_GKCZYX_NGKDHW: return os << "NGCDHW_GKCZYX_NGKDHW";
default: return os << "Unknown";
}
}

inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op)
{
using enum ElementwiseOperation;
switch(op)
{
case BIAS: return os << "BIAS";
case BIAS_CLAMP: return os << "BIAS_CLAMP";
case BIAS_BNORM_CLAMP: return os << "BIAS_BNORM_CLAMP";
case BILINEAR: return os << "BILINEAR";
case CLAMP: return os << "CLAMP";
case SCALE: return os << "SCALE";
case PASS_THROUGH: return os << "PASS_THROUGH";
case BIAS_BNORM_CLAMP: return os << "BIAS_BNORM_CLAMP";
case SCALEADD_SCALEADD_RELU: return os << "SCALEADD_SCALEADD_RELU";
default: return os << "Unknown";
}
}
@@ -375,13 +361,59 @@ inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched)
}
}

// ostream operator overload for std::variant of layout types
inline std::ostream&
operator<<(std::ostream& os,
const std::variant<GroupConvLayout1D, GroupConvLayout2D, GroupConvLayout3D>& layout)
inline std::ostream& operator<<(std::ostream& os, TensorLayout layout)
{
std::visit([&os](const auto& l) { os << l; }, layout);
return os;
using enum TensorLayout;
switch(layout)
{
case GNCW: return os << "GNCW";
case GNWC: return os << "GNWC";
case NWGC: return os << "NWGC";
case NGCW: return os << "NGCW";
case G_NW_C_strided: return os << "G_NW_C_strided";
case GNCHW: return os << "GNCHW";
case GNHWC: return os << "GNHWC";
case NHWGC: return os << "NHWGC";
case NGCHW: return os << "NGCHW";
case G_NHW_C_strided: return os << "G_NHW_C_strided";
case GNCDHW: return os << "GNCDHW";
case GNDHWC: return os << "GNDHWC";
case NDHWGC: return os << "NDHWGC";
case NGCDHW: return os << "NGCDHW";
case G_NDHW_C_strided: return os << "G_NDHW_C_strided";
case GKXC: return os << "GKXC";
case GKCX: return os << "GKCX";
case KXGC: return os << "KXGC";
case G_K_X_C_strided: return os << "G_K_X_C_strided";
case GKYXC: return os << "GKYXC";
case GKCYX: return os << "GKCYX";
case KYXGC: return os << "KYXGC";
case G_K_YX_C_strided: return os << "G_K_YX_C_strided";
case GKZYXC: return os << "GKZYXC";
case GKCZYX: return os << "GKCZYX";
case KZYXGC: return os << "KZYXGC";
case G_K_ZYX_C_strided: return os << "G_K_ZYX_C_strided";
case GNKW: return os << "GNKW";
case GNWK: return os << "GNWK";
case NWGK: return os << "NWGK";
case NGKW: return os << "NGKW";
case G_NW_K_strided: return os << "G_NW_K_strided";
case GNKHW: return os << "GNKHW";
case GNHWK: return os << "GNHWK";
case NHWGK: return os << "NHWGK";
case NGKHW: return os << "NGKHW";
case G_NHW_K_strided: return os << "G_NHW_K_strided";
case GNKDHW: return os << "GNKDHW";
case GNDHWK: return os << "GNDHWK";
case NDHWGK: return os << "NDHWGK";
case NGKDHW: return os << "NGKDHW";
case G_NDHW_K_strided: return os << "G_NDHW_K_strided";
case GC: return os << "GC";
case G_C_strided: return os << "G_C_strided";
case G_K_strided: return os << "G_K_strided";
case UNDEFINED: return os << "UNDEFINED";
default: return os << "Unknown";
}
}

// ostream operator overload for std::variant of convolution specializations

@@ -119,6 +119,7 @@ add_ck_builder_test(test_ckb_instance_string
# Tests the forward convolution builder across multiple data types and dimensions.
# Individual tests are split into separate files to enable parallel compilation.
add_ck_builder_test(test_ckb_build_fwd_instances
conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp
conv/test_ckb_conv_fwd_1d_fp16.cpp
conv/test_ckb_conv_fwd_1d_bf16.cpp
conv/test_ckb_conv_fwd_1d_i8.cpp

@@ -13,11 +13,15 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_1D_BF16_ChannelsFirst_scale)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout1D::NGCW_GKXC_NGKW,
.data_type = DataType::BF16,
.elementwise_operation = ElementwiseOperation::SCALE};
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 1,
.direction = ConvDirection::FORWARD,
.data_type = DataType::BF16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NGCW}},
.weight = {.config = {.layout = TensorLayout::GKXC}},
.output = {.config = {.layout = TensorLayout::NGKW},
.operation = {.elementwise_operation = ElementwiseOperation::SCALE}}};

constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
@@ -30,10 +34,13 @@ TEST(FwdConvInstances,

using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 256, 256, 32",
"256,256,256,32",
"NGCW,GKXC,EmptyTuple,NGKW",
"PassThrough,PassThrough,Scale",
"Filter1x1Stride1Pad0",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v2"});
"MNKPadding",
"Intrawave",
"v2"});
}

} // namespace

@@ -10,14 +10,15 @@ using namespace ck_tile::builder::test_utils;

// 1D FP16 (channels-last) with DEFAULT specialization
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_1D_FP16_ChannelsFirst_scale)
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_1D_FP16_ChannelsFirst)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout1D::NWGC_GKXC_NWGK,
.data_type = DataType::FP16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NWGC}},
.weight = {.config = {.layout = TensorLayout::GKXC}},
.output = {.config = {.layout = TensorLayout::NWGK}}};

constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
@@ -28,8 +29,12 @@ TEST(FwdConvInstances,
.with_prefetch_config(1, 2, PipelineScheduler::DEFAULT);

using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>(
{"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", "64, 64, 32, 32", "Default"});
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle",
"NWGC,GKXC,EmptyTuple,NWGK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding",
"64,64,32,32",
"Default"});
}

} // namespace

@@ -14,12 +14,13 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Instance_1D_FP32_ChannelsFirst_scale)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout1D::GNWC_GKXC_GNWK,
.data_type = DataType::I8,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
.direction = ConvDirection::FORWARD,
.data_type = DataType::I8,
.accumulation_data_type = DataType::INT32,
.input = {.config = {.layout = TensorLayout::GNWC}},
.weight = {.config = {.layout = TensorLayout::GKXC}},
.output = {.config = {.layout = TensorLayout::GNWK}}};

constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{}
@@ -30,8 +31,11 @@ TEST(FwdConvInstances,
.with_prefetch_config(1, 0, PipelineScheduler::DEFAULT);

using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>(
{"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle", "128, 64, 64, 64", "Default"});
run_test<Builder>({"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle",
"128,64,64,64",
"GNWC,GKXC,EmptyTuple,GNWK",
"PassThrough,PassThrough,PassThrough",
"Default"});
}
#endif

@@ -12,12 +12,13 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
.data_type = DataType::BF16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::BF16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::NHWGK}}};

constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
@@ -29,22 +30,26 @@ TEST(FwdConvInstances,

using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 256, 256, 32",
"256,256,256,32",
"Default",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v1"});
"NHWGC,GKYXC,EmptyTuple,NHWGK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding",
"Intrawave",
"v1"});
}

// 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
.data_type = DataType::BF16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::BF16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::NHWGK}}};

constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
@@ -57,7 +62,10 @@ TEST(FwdConvInstances,
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"Filter3x3",
"BlkGemmPipelineVersion: v5"});
"NHWGC,GKYXC,EmptyTuple,NHWGK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding",
"v5"});
}

} // namespace

@@ -0,0 +1,46 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"

namespace {

using namespace ck_tile::builder;
using namespace ck_tile::builder::test_utils;

TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_2D_BF16_scale_add_relu)
{
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::BF16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
.weight = {.config = {.layout = TensorLayout::GKYXC, .data_type = DataType::BF16}},
.output = ConvolutionTensor{
.config = {.layout = TensorLayout::NHWGK},
.operation = TensorOperation<>{.elementwise_operation =
ElementwiseOperation::SCALEADD_SCALEADD_RELU}
.with_auxiliary_operand_configs<TensorLayout::NHWGK,
TensorLayout::G_K_strided>()}};

constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
.with_thread_block(FwdThreadBlock_64_64x32x32)
.with_gemm_config(FwdGemmParams_Xdl_2x2_per_wave)
.with_transfer(FwdTransfer_4x16x1)
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);

using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle",
"NHWGC,GKYXC,Tuple(NHWGK,G_K),NHWGK",
"PassThrough,PassThrough,ScaleAddScaleAddRelu",
"64,64,32,32",
"MNKPadding",
"Default"});
}

} // namespace
@@ -10,12 +10,13 @@ using namespace ck_tile::builder::test_utils;

TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
.data_type = DataType::FP16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::GNHWC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::GNHWK}}};

constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
@@ -26,19 +27,24 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins
.with_dl_transfer(DlFwdTransfer);

using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>(
{"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", "256, 128, 128, 16", "Default"});
run_test<Builder>({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK",
"256,128,128,16",
"Default",
"MNKPadding",
"GNHWC,GKYXC,EmptyTuple,GNHWK",
"PassThrough,PassThrough,PassThrough"});
}

TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_FILTER_1X1_PAD0)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
.data_type = DataType::FP16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::GNHWC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::GNHWK}}};

constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
@@ -50,8 +56,12 @@ TEST(FwdConvInstances,
.with_dl_transfer(DlFwdTransfer);

using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>(
{"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", "256, 128, 128, 16", "Filter1x1Pad0"});
run_test<Builder>({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK",
"256,128,128,16",
"Filter1x1Pad0",
"MNKPadding",
"GNHWC,GKYXC,EmptyTuple,GNHWK",
"PassThrough,PassThrough,PassThrough"});
}

} // namespace

@@ -11,12 +11,13 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
.data_type = DataType::FP16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::GNHWC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::GNHWK}}};

constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
@@ -29,10 +30,13 @@ TEST(FwdConvInstances,

using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 256, 256, 32",
"256,256,256,32",
"Filter1x1Pad0",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v3"});
"Intrawave",
"v3",
"GNHWC,GKYXC,EmptyTuple,GNHWK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
}

} // namespace

@@ -11,12 +11,13 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW,
.data_type = DataType::FP32,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP32,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NGCHW}},
.weight = {.config = {.layout = TensorLayout::GKCYX}},
.output = {.config = {.layout = TensorLayout::NGKHW}}};

constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
@@ -29,10 +30,13 @@ TEST(FwdConvInstances,

using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 128, 128, 32",
"256,128,128,32",
"Filter1x1Stride1Pad0",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v4"});
"Intrawave",
"v4",
"NGCHW,GKCYX,EmptyTuple,NGKHW",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
}

} // namespace

@@ -12,12 +12,13 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_2D_FP8_ChannelsLast)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
.data_type = DataType::FP8,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP8,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::NHWGK}}};

constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
@@ -28,8 +29,12 @@ TEST(FwdConvInstances,
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);

using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>(
{"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", "256, 256, 128, 32", "Default"});
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle",
"256,256,128,32",
"Default",
"NHWGC,GKYXC,EmptyTuple,NHWGK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
}

} // namespace

@@ -11,12 +11,13 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
.data_type = DataType::FP16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::GNHWC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::GNHWK}}};

constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{
@@ -30,20 +31,24 @@ TEST(FwdConvInstances,

using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor",
"256, 256, 128, 32",
"Default"});
"256,256,128,32",
"Default",
"GNHWC,GKYXC,EmptyTuple,GNHWK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
}

TEST(
FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC_Filter1x1Pad0)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
.data_type = DataType::FP16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::GNHWC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::GNHWK}}};

constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{
@@ -57,8 +62,11 @@ TEST(

using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor",
"128, 128, 128, 32",
"Filter1x1Pad0"});
"128,128,128,32",
"Filter1x1Pad0",
"GNHWC,GKYXC,EmptyTuple,GNHWK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
}

} // namespace

@@ -12,12 +12,14 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 3,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK,
.data_type = DataType::BF16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 3,
.direction = ConvDirection::FORWARD,
.data_type = DataType::BF16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::GNDHWC}},
.weight = {.config = {.layout = TensorLayout::GKZYXC}},
.output = {.config = {.layout = TensorLayout::GNDHWK}}};

constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
@@ -29,10 +31,13 @@ TEST(FwdConvInstances,

using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 256, 256, 32",
"256,256,256,32",
"Default",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v3"});
"Intrawave",
"v3",
"GNDHWC,GKZYXC,EmptyTuple,GNDHWK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
}

} // namespace

@@ -12,12 +12,14 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 3,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK,
.data_type = DataType::FP16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 3,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NDHWGC}},
.weight = {.config = {.layout = TensorLayout::GKZYXC}},
.output = {.config = {.layout = TensorLayout::NDHWGK}}};

constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
@@ -30,10 +32,13 @@ TEST(FwdConvInstances,

using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 128, 128, 32",
"256,128,128,32",
"Filter1x1Pad0",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v4"});
"Intrawave",
"v4",
"NDHWGC,GKZYXC,EmptyTuple,NDHWGK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
}

} // namespace

@@ -12,12 +12,14 @@ using namespace ck_tile::builder::test_utils;
 TEST(FwdConvInstances,
      Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst)
 {
-    constexpr ConvSignature FwdConvSignature{.spatial_dim = 3,
-                                             .direction = ConvDirection::FORWARD,
-                                             .layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW,
-                                             .data_type = DataType::FP32,
-                                             .elementwise_operation =
-                                                 ElementwiseOperation::PASS_THROUGH};
+    constexpr ConvSignature FwdConvSignature{
+        .spatial_dim = 3,
+        .direction = ConvDirection::FORWARD,
+        .data_type = DataType::FP32,
+        .accumulation_data_type = DataType::FP32,
+        .input = {.config = {.layout = TensorLayout::NGCDHW}},
+        .weight = {.config = {.layout = TensorLayout::GKCZYX}},
+        .output = {.config = {.layout = TensorLayout::NGKDHW}}};
 
     constexpr auto FwdConvAlgorithm =
         ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
@@ -30,10 +32,13 @@ TEST(FwdConvInstances,
 
     using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
    run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
-                       "256, 256, 256, 32",
+                       "256,256,256,32",
                        "Filter1x1Pad0",
-                       "BlkGemmPipelineScheduler: Intrawave",
-                       "BlkGemmPipelineVersion: v1"});
+                       "Intrawave",
+                       "v1",
+                       "NGCDHW,GKCZYX,EmptyTuple,NGKDHW",
+                       "PassThrough,PassThrough,PassThrough",
+                       "MNKPadding"});
 }
 
 } // namespace
@@ -85,7 +85,10 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
     // Verify signature information
     EXPECT_EQ(Traits::spatial_dim, 2);
     EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
-    EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
+    EXPECT_THAT(Traits::layout,
+                ::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC,
+                                       ck_tile::builder::TensorLayout::GKYXC,
+                                       ck_tile::builder::TensorLayout::GNHWK));
     EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
     EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
     EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
@@ -212,7 +215,10 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
     // Verify signature information
     EXPECT_EQ(Traits::spatial_dim, 2);
     EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
-    EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
+    EXPECT_THAT(Traits::layout,
+                ::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC,
+                                       ck_tile::builder::TensorLayout::GKYXC,
+                                       ck_tile::builder::TensorLayout::GNHWK));
     EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
     EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
     EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
@@ -295,7 +301,10 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
     // Verify signature information
     EXPECT_EQ(Traits::spatial_dim, 2);
     EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
-    EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
+    EXPECT_THAT(Traits::layout,
+                ::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC,
+                                       ck_tile::builder::TensorLayout::GKYXC,
+                                       ck_tile::builder::TensorLayout::GNHWK));
     EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
     EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
     EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);

@@ -10,14 +10,48 @@ namespace ck_tile::builder::test {
 
 using namespace ck_tile::builder;
 
+struct TensorConfig
+{
+    TensorLayout layout;
+    // Optional data types, override the type defined in the signature if provided.
+    DataType data_type{DataType::UNDEFINDED};
+    DataType compute_type{DataType::UNDEFINDED};
+};
+
+template <TensorConfig... Configs>
+struct TensorOperation
+{
+    ElementwiseOperation elementwise_operation{ElementwiseOperation::PASS_THROUGH};
+    std::array<TensorConfig, sizeof...(Configs)> auxiliary_operand_configs{Configs...};
+
+    // Add builder to add auxiliary tensor configs
+    template <auto... AuxiliaryConfigs>
+    constexpr auto with_auxiliary_operand_configs() const
+    {
+        return TensorOperation<Configs..., TensorConfig{AuxiliaryConfigs}...>{
+            .elementwise_operation = this->elementwise_operation};
+    }
+};
+
+template <typename Op = TensorOperation<>>
+struct ConvolutionTensor
+{
+    TensorConfig config;
+    Op operation{};
+};
+
+template <typename InputTensor  = ConvolutionTensor<>,
+          typename WeightTensor = ConvolutionTensor<>,
+          typename OutputTensor = ConvolutionTensor<>>
 struct ConvSignature
 {
     int spatial_dim;
     ConvDirection direction;
-    GroupConvLayout layout;
     DataType data_type;
-    ElementwiseOperation elementwise_operation;
+    DataType accumulation_data_type;
+    InputTensor input;
+    WeightTensor weight;
+    OutputTensor output;
 };
 static_assert(ConvSignatureDescriptor<ConvSignature>);
 
 } // namespace ck_tile::builder::test
@@ -16,40 +16,79 @@ namespace ckb = ck_tile::builder;
 namespace ckr = ck_tile::reflect;
 namespace ckt = ck_tile::test;
 
+struct TensorOp
+{
+    ckb::ElementwiseOperation elementwise_operation{ckb::ElementwiseOperation::PASS_THROUGH};
+};
+
+struct InvalidTensorOp
+{
+    int elementwise_operation = 7; // invalid value
+};
+static_assert(!ckb::TensorOperatorDescriptor<InvalidTensorOp>);
+
+struct TensorConfig
+{
+    ckb::TensorLayout layout;
+    ckb::DataType data_type{ckb::DataType::UNDEFINDED};
+    ckb::DataType compute_type{ckb::DataType::UNDEFINDED};
+};
+
+struct ConvTensorSimple
+{
+    TensorConfig config;
+};
+
+struct ConvTensorWithOp
+{
+    TensorConfig config;
+    TensorOp operation{};
+};
+
+struct ConvTensorWithInvalidOp
+{
+    TensorConfig config;
+    InvalidTensorOp operation{};
+};
+
 // Defines the signature of the convolution operation to be tested.
 // This includes dimensionality, direction, data layout, and data type.
 struct ConvSignature
 {
-    int spatial_dim = 2;
-    ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
-    ckb::DataType data_type = ckb::DataType::FP16;
-    // ckb::GroupConvDeviceOp device_operation =
-    //     ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
+    int spatial_dim = 2;
+    ckb::DataType data_type = ckb::DataType::FP16;
+    ckb::DataType accumulation_data_type = ckb::DataType::FP32;
+    ConvTensorSimple input  = {.config = {ckb::TensorLayout::GNHWC}};
+    ConvTensorSimple weight = {.config = {ckb::TensorLayout::GKYXC}};
+    ConvTensorSimple output = {.config = {ckb::TensorLayout::GNHWK}};
 };
 static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
 
 // Compile time tests for concepts
 struct ConvSignatureWithOptionalParams
 {
-    int spatial_dim = 2;
-    ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
-    ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
-    ckb::DataType data_type = ckb::DataType::FP16;
-    ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH;
+    int spatial_dim = 2;
+    ckb::DataType data_type = ckb::DataType::FP16;
+    ckb::DataType accumulation_data_type = ckb::DataType::FP32;
+    ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
+    ConvTensorWithOp input = {
+        .config = {ckb::TensorLayout::GNHWC, ckb::DataType::FP16},
+    };
+    ConvTensorWithOp weight = {.config = {ckb::TensorLayout::GKYXC, ckb::DataType::FP16}};
+    ConvTensorWithOp output = {.config = {ckb::TensorLayout::GNHWK, ckb::DataType::FP16},
+                               .operation = {ckb::ElementwiseOperation::SCALE}};
 };
 static_assert(ckb::ConvSignatureDescriptor<ConvSignatureWithOptionalParams>);
 
 struct ConvSignatureWithInvalidOptionalParams
 {
-    int spatial_dim = 2;
-    ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
-    ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
-    ckb::DataType data_type = ckb::DataType::FP16;
-    int elementwise_operation = 7; // this should fail
-    // ckb::GroupConvDeviceOp device_operation =
-    //     ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
+    int spatial_dim = 2;
+    ckb::DataType data_type = ckb::DataType::FP16;
+    ckb::DataType accumulation_data_type = ckb::DataType::FP32;
+    ConvTensorWithInvalidOp input  = {.config = {ckb::TensorLayout::GNHWC}};
+    ConvTensorWithInvalidOp weight = {.config = {ckb::TensorLayout::GKYXC}};
+    ConvTensorWithInvalidOp output = {.config = {ckb::TensorLayout::GNHWK}};
 };
 
 static_assert(!ckb::ConvSignatureDescriptor<ConvSignatureWithInvalidOptionalParams>);
 
 struct DefaultAlgorithm
@@ -123,7 +162,9 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
         "2D Forward Convolution Kernel\n"
         "├─ Signature\n"
         "│  ├─ Tensor Type: FP16\n"
-        "│  ├─ Memory Layout: GNHWC_GKYXC_GNHWK\n"
+        "│  ├─ Input Layout: GNHWC\n"
+        "│  ├─ Weight Layout: GKYXC\n"
+        "│  ├─ Output Layout: GNHWK\n"
         "│  ├─ Input elementwise operation: PASS_THROUGH\n"
         "│  ├─ Weights elementwise operation: PASS_THROUGH\n"
         "│  └─ Output elementwise operation: PASS_THROUGH\n"

@@ -8,30 +8,38 @@
 
 namespace {
 
-using ::ck_tile::builder::factory::internal::ElementwiseOps;
-using enum ::ck_tile::builder::ElementwiseOperation;
+using ::ck_tile::builder::ElementwiseOperation;
+using ::ck_tile::builder::factory::internal::ElementwiseOpToCK;
 
 TEST(ConvElementwiseOp, AssignsOpsForPassThrough)
 {
-    using Ops = ElementwiseOps<PASS_THROUGH>;
-
-    EXPECT_TRUE(
-        (std::is_same_v<Ops::AElementwiseOp, ck::tensor_operation::element_wise::PassThrough>));
-    EXPECT_TRUE(
-        (std::is_same_v<Ops::BElementwiseOp, ck::tensor_operation::element_wise::PassThrough>));
-    EXPECT_TRUE(
-        (std::is_same_v<Ops::CDEElementwiseOp, ck::tensor_operation::element_wise::PassThrough>));
+    using Op = ElementwiseOpToCK<ElementwiseOperation::PASS_THROUGH>::Op;
+    EXPECT_TRUE((std::is_same_v<Op, ck::tensor_operation::element_wise::PassThrough>));
 }
 
 TEST(ConvElementwiseOp, AssignsOpsForScale)
 {
-    using Ops = ElementwiseOps<SCALE>;
+    using Op = ElementwiseOpToCK<ElementwiseOperation::SCALE>::Op;
+    EXPECT_TRUE((std::is_same_v<Op, ck::tensor_operation::element_wise::Scale>));
+}
+
+TEST(ConvElementwiseOp, AssignsOpsForClamp)
+{
+    using Op = ElementwiseOpToCK<ElementwiseOperation::CLAMP>::Op;
+    EXPECT_TRUE((std::is_same_v<Op, ck::tensor_operation::element_wise::Clamp>));
+}
+
+TEST(ConvElementwiseOp, AssignsOpsForScaleAddScaleAddRelu)
+{
+    using Op = ElementwiseOpToCK<ElementwiseOperation::SCALEADD_SCALEADD_RELU>::Op;
+    EXPECT_TRUE((std::is_same_v<Op, ck::tensor_operation::element_wise::ScaleAddScaleAddRelu>));
+}
+
+TEST(ConvElementwiseOp, AssignsOpsForBiasNormClamp)
+{
+    using Op = ElementwiseOpToCK<ElementwiseOperation::BIAS_BNORM_CLAMP>::Op;
     EXPECT_TRUE(
-        (std::is_same_v<Ops::AElementwiseOp, ck::tensor_operation::element_wise::PassThrough>));
-    EXPECT_TRUE(
-        (std::is_same_v<Ops::BElementwiseOp, ck::tensor_operation::element_wise::PassThrough>));
-    EXPECT_TRUE((std::is_same_v<Ops::CDEElementwiseOp, ck::tensor_operation::element_wise::Scale>));
+        (std::is_same_v<Op, ck::tensor_operation::element_wise::BiasNormalizeInInferClamp>));
 }
 
 } // namespace
@@ -4,116 +4,481 @@
 #include <gtest/gtest.h>
 #include <type_traits>
 
 // Include the helper file we're testing
 #include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
+#include "impl/conv_signature_types.hpp"
 
 namespace {
 
 namespace ckb = ::ck_tile::builder;
+using ::ck_tile::builder::DataType;
+using ::ck_tile::builder::ElementwiseOperation;
+using ::ck_tile::builder::TensorLayout;
+using ::ck_tile::builder::factory::internal::AuxiliaryTensorLayouts;
 using ::ck_tile::builder::factory::internal::ConvTensorLayouts;
-using ::ck_tile::builder::factory::internal::GetTensorLayout;
+using ::ck_tile::builder::factory::internal::LayoutToCK;
 
+using namespace ::ck_tile::builder::test;
 using enum ::ck_tile::builder::ConvDirection;
 
 TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK)
 {
-    using TensorLayouts = ConvTensorLayouts<ckb::GroupConvLayout1D::NWGC_GKXC_NWGK, 1, FORWARD>;
+    static constexpr auto sig =
+        ConvSignature<>{.spatial_dim = 1,
+                        .direction = FORWARD,
+                        .data_type = DataType::FP16,
+                        .accumulation_data_type = DataType::FP32,
+                        .input = {.config = {.layout = TensorLayout::NWGC}},
+                        .weight = {.config = {.layout = TensorLayout::GKXC}},
+                        .output = {.config = {.layout = TensorLayout::NWGK}}};
+
+    using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
 
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NWGC>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NWGK>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
 }
 
 TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW)
 {
-    using TensorLayouts = ConvTensorLayouts<ckb::GroupConvLayout1D::NGCW_GKXC_NGKW, 1, FORWARD>;
+    static constexpr auto sig =
+        ConvSignature<>{.spatial_dim = 1,
+                        .direction = FORWARD,
+                        .data_type = DataType::FP16,
+                        .accumulation_data_type = DataType::FP32,
+                        .input = {.config = {.layout = TensorLayout::NGCW}},
+                        .weight = {.config = {.layout = TensorLayout::GKXC}},
+                        .output = {.config = {.layout = TensorLayout::NGKW}}};
+
+    using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
 
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCW>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKW>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
 }
 
 TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK)
 {
-    using TensorLayouts = ConvTensorLayouts<ckb::GroupConvLayout1D::GNWC_GKXC_GNWK, 1, FORWARD>;
+    static constexpr auto sig =
+        ConvSignature<>{.spatial_dim = 1,
+                        .direction = FORWARD,
+                        .data_type = DataType::FP16,
+                        .accumulation_data_type = DataType::FP32,
+                        .input = {.config = {.layout = TensorLayout::GNWC}},
+                        .weight = {.config = {.layout = TensorLayout::GKXC}},
+                        .output = {.config = {.layout = TensorLayout::GNWK}}};
+
+    using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
 
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNWC>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNWK>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
 }
 
 TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW)
 {
-    using TensorLayouts = ConvTensorLayouts<ckb::GroupConvLayout1D::NGCW_GKCX_NGKW, 1, FORWARD>;
+    static constexpr auto sig =
+        ConvSignature<>{.spatial_dim = 1,
+                        .direction = FORWARD,
+                        .data_type = DataType::FP16,
+                        .accumulation_data_type = DataType::FP32,
+                        .input = {.config = {.layout = TensorLayout::NGCW}},
+                        .weight = {.config = {.layout = TensorLayout::GKCX}},
+                        .output = {.config = {.layout = TensorLayout::NGKW}}};
+
+    using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
 
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCW>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCX>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKW>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
 }
 
 TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW)
 {
-    using TensorLayouts = ConvTensorLayouts<ckb::GroupConvLayout2D::NGCHW_GKYXC_NGKHW, 2, FORWARD>;
+    static constexpr auto sig =
+        ConvSignature<>{.spatial_dim = 2,
+                        .direction = FORWARD,
+                        .data_type = DataType::FP16,
+                        .accumulation_data_type = DataType::FP32,
+                        .input = {.config = {.layout = TensorLayout::NGCHW}},
+                        .weight = {.config = {.layout = TensorLayout::GKYXC}},
+                        .output = {.config = {.layout = TensorLayout::NGKHW}}};
+
+    using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
 
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
 }
 
 TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK)
 {
-    using TensorLayouts = ConvTensorLayouts<ckb::GroupConvLayout2D::NHWGC_GKYXC_NHWGK, 2, FORWARD>;
+    static constexpr auto sig =
+        ConvSignature<>{.spatial_dim = 2,
+                        .direction = FORWARD,
+                        .data_type = DataType::FP16,
+                        .accumulation_data_type = DataType::FP32,
+                        .input = {.config = {.layout = TensorLayout::NHWGC}},
+                        .weight = {.config = {.layout = TensorLayout::GKYXC}},
+                        .output = {.config = {.layout = TensorLayout::NHWGK}}};
+
+    using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
 
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NHWGC>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NHWGK>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
 }
 
 TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK)
 {
-    using TensorLayouts = ConvTensorLayouts<ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK, 2, FORWARD>;
+    static constexpr auto sig =
+        ConvSignature<>{.spatial_dim = 2,
+                        .direction = FORWARD,
+                        .data_type = DataType::FP16,
+                        .accumulation_data_type = DataType::FP32,
+                        .input = {.config = {.layout = TensorLayout::GNHWC}},
+                        .weight = {.config = {.layout = TensorLayout::GKYXC}},
+                        .output = {.config = {.layout = TensorLayout::GNHWK}}};
+
+    using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
 
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNHWC>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNHWK>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
 }
 
 TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW)
 {
-    using TensorLayouts = ConvTensorLayouts<ckb::GroupConvLayout2D::NGCHW_GKCYX_NGKHW, 2, FORWARD>;
+    static constexpr auto sig =
+        ConvSignature<>{.spatial_dim = 2,
+                        .direction = FORWARD,
+                        .data_type = DataType::FP16,
+                        .accumulation_data_type = DataType::FP32,
+                        .input = {.config = {.layout = TensorLayout::NGCHW}},
+                        .weight = {.config = {.layout = TensorLayout::GKCYX}},
+                        .output = {.config = {.layout = TensorLayout::NGKHW}}};
+
+    using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
 
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCYX>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
 }
 
 TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW)
 {
-    using TensorLayouts =
-        ConvTensorLayouts<ckb::GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, 3, FORWARD>;
+    static constexpr auto sig =
+        ConvSignature<>{.spatial_dim = 3,
+                        .direction = FORWARD,
+                        .data_type = DataType::FP16,
+                        .accumulation_data_type = DataType::FP32,
+                        .input = {.config = {.layout = TensorLayout::NGCDHW}},
+                        .weight = {.config = {.layout = TensorLayout::GKCZYX}},
+                        .output = {.config = {.layout = TensorLayout::NGKDHW}}};
+
+    using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
 
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCDHW>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCZYX>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKDHW>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
 }
 
 TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK)
 {
-    using TensorLayouts =
-        ConvTensorLayouts<ckb::GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, 3, FORWARD>;
+    static constexpr auto sig =
+        ConvSignature<>{.spatial_dim = 3,
+                        .direction = FORWARD,
+                        .data_type = DataType::FP16,
+                        .accumulation_data_type = DataType::FP32,
+                        .input = {.config = {.layout = TensorLayout::NDHWGC}},
+                        .weight = {.config = {.layout = TensorLayout::GKZYXC}},
+                        .output = {.config = {.layout = TensorLayout::NDHWGK}}};
+
+    using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
 
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NDHWGC>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NDHWGK>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
 }
 
 TEST(ConvTensorLayout, AssignsLayoutsFor3D_GNDHWC_GKZYXC_GNDHWK)
 {
-    using TensorLayouts =
-        ConvTensorLayouts<ckb::GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, 3, FORWARD>;
+    static constexpr auto sig =
+        ConvSignature<>{.spatial_dim = 3,
+                        .direction = FORWARD,
+                        .data_type = DataType::FP16,
+                        .accumulation_data_type = DataType::FP32,
+                        .input = {.config = {.layout = TensorLayout::GNDHWC}},
+                        .weight = {.config = {.layout = TensorLayout::GKZYXC}},
+                        .output = {.config = {.layout = TensorLayout::GNDHWK}}};
+
+    using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
 
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNDHWC>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNDHWK>));
     EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
 }
 
+TEST(AuxiliaryTensorLayout, AssignsLayoutForG_K_strided)
+{
+    using CKLayout = LayoutToCK<TensorLayout::G_K_strided>::type;
+    EXPECT_TRUE((std::is_same_v<CKLayout, ck::tensor_layout::convolution::G_K>));
+}
+
+TEST(AuxiliaryTensorLayout, AssignsLayoutForGC)
+{
+    using CKLayout = LayoutToCK<TensorLayout::GC>::type;
+    EXPECT_TRUE((std::is_same_v<CKLayout, ck::tensor_layout::convolution::GC>));
+}
+
+TEST(AuxiliaryTensorLayout, AssignsLayoutForG_C_strided)
+{
+    using CKLayout = LayoutToCK<TensorLayout::G_C_strided>::type;
+    EXPECT_TRUE((std::is_same_v<CKLayout, ck::tensor_layout::convolution::G_C>));
+}
+
+TEST(AuxiliaryTensorLayout, EmptyAuxiliaryTensorLayoutIsEmptyTuple)
+{
+    using ::ck_tile::builder::factory::internal::EmptyAuxiliaryTensorLayout;
+    using EmptyLayout = EmptyAuxiliaryTensorLayout::type;
+    EXPECT_TRUE((std::is_same_v<EmptyLayout, ck::Tuple<>>));
+}
+
+struct MockAuxiliaryTensorConfig
+{
+    TensorLayout layout;
+};
+
+TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_K_Layout)
+{
+    static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
+        MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided}};
+
+    using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
+
+    EXPECT_EQ(AuxLayouts::Size, 1);
+    using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_K>;
+    EXPECT_TRUE((std::is_same_v<AuxLayouts::type, ExpectedType>));
+}
+
+TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithGC_Layout)
+{
+    static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
+        MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}};
+
+    using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
+
+    EXPECT_EQ(AuxLayouts::Size, 1);
+    using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::GC>;
+    EXPECT_TRUE((std::is_same_v<AuxLayouts::type, ExpectedType>));
+}
+
+TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_C_Layout)
+{
+    static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
+        MockAuxiliaryTensorConfig{.layout = TensorLayout::G_C_strided}};
+
+    using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
+
+    EXPECT_EQ(AuxLayouts::Size, 1);
+    using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_C>;
+    EXPECT_TRUE((std::is_same_v<AuxLayouts::type, ExpectedType>));
+}
+
+TEST(AuxiliaryTensorLayoutIntegration, TwoAuxiliaryTensors)
+{
+    static constexpr std::array<MockAuxiliaryTensorConfig, 2> aux_configs = {
+        MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided},
+        MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}};
+
+    using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
+
+    EXPECT_EQ(AuxLayouts::Size, 2);
+    using ExpectedType =
+        ck::Tuple<ck::tensor_layout::convolution::G_K, ck::tensor_layout::convolution::GC>;
+    EXPECT_TRUE((std::is_same_v<AuxLayouts::type, ExpectedType>));
+}
+
+TEST(AuxiliaryTensorLayoutIntegration, ThreeAuxiliaryTensors)
+{
+    static constexpr std::array<MockAuxiliaryTensorConfig, 3> aux_configs = {
+        MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided},
+        MockAuxiliaryTensorConfig{.layout = TensorLayout::GC},
+        MockAuxiliaryTensorConfig{.layout = TensorLayout::G_C_strided}};
+
+    using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
+
+    EXPECT_EQ(AuxLayouts::Size, 3);
+    using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_K,
+                                   ck::tensor_layout::convolution::GC,
+                                   ck::tensor_layout::convolution::G_C>;
+    EXPECT_TRUE((std::is_same_v<AuxLayouts::type, ExpectedType>));
+}
+
+TEST(AuxiliaryTensorLayoutIntegration, WorksWith1DConvolution)
+{
+    static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
+        MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided}};
+
+    using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 1, FORWARD>;
+
+    EXPECT_EQ(AuxLayouts::Size, 1);
+    using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_K>;
+    EXPECT_TRUE((std::is_same_v<AuxLayouts::type, ExpectedType>));
+}
+
+TEST(AuxiliaryTensorLayoutIntegration, WorksWith3DConvolution)
+{
+    static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
+        MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}};
+
+    using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 3, FORWARD>;
+
+    EXPECT_EQ(AuxLayouts::Size, 1);
+    using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::GC>;
+    EXPECT_TRUE((std::is_same_v<AuxLayouts::type, ExpectedType>));
+}
+
+TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K)
+{
+    using OutputOp = TensorOperation<TensorConfig{.layout = TensorLayout::G_K_strided}>;
+
+    static constexpr auto sig =
+        ConvSignature<ConvolutionTensor<>, ConvolutionTensor<>, ConvolutionTensor<OutputOp>>{
+            .spatial_dim = 2,
+            .direction = FORWARD,
+            .data_type = DataType::FP16,
+            .accumulation_data_type = DataType::FP32,
+            .input = {.config = {.layout = TensorLayout::NGCHW}},
+            .weight = {.config = {.layout = TensorLayout::GKYXC}},
+            .output = {.config = {.layout = TensorLayout::NGKHW},
+                       .operation =
+                           OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
+
+    using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
+
+    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
+    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
+    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
+
+    using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_K>;
+    EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
+}
+
+TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC)
+{
+    using OutputOp = TensorOperation<TensorConfig{.layout = TensorLayout::GC}>;
+
+    static constexpr auto sig =
+        ConvSignature<ConvolutionTensor<>, ConvolutionTensor<>, ConvolutionTensor<OutputOp>>{
+            .spatial_dim = 2,
+            .direction = FORWARD,
+            .data_type = DataType::BF16,
+            .accumulation_data_type = DataType::FP32,
+            .input = {.config = {.layout = TensorLayout::NHWGC}},
+            .weight = {.config = {.layout = TensorLayout::GKYXC}},
+            .output = {.config = {.layout = TensorLayout::NHWGK},
+                       .operation =
+                           OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
+
+    using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
+
+    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NHWGC>));
+    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
+    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NHWGK>));
+
+    using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::GC>;
+    EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
+}
+
+TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors)
+{
+    using OutputOp = TensorOperation<TensorConfig{.layout = TensorLayout::G_K_strided},
+                                     TensorConfig{.layout = TensorLayout::GC}>;
+
+    static constexpr auto sig =
+        ConvSignature<ConvolutionTensor<>, ConvolutionTensor<>, ConvolutionTensor<OutputOp>>{
+            .spatial_dim = 2,
+            .direction = FORWARD,
+            .data_type = DataType::FP16,
+            .accumulation_data_type = DataType::FP32,
+            .input = {.config = {.layout = TensorLayout::GNHWC}},
+            .weight = {.config = {.layout = TensorLayout::GKYXC}},
+            .output = {.config = {.layout = TensorLayout::GNHWK},
+                       .operation = OutputOp{.elementwise_operation =
+                                                 ElementwiseOperation::SCALEADD_SCALEADD_RELU}}};
+
+    using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
+
+    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNHWC>));
+    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
+    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNHWK>));
+
+    using ExpectedDsLayout =
+        ck::Tuple<ck::tensor_layout::convolution::G_K, ck::tensor_layout::convolution::GC>;
+    EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
+}
+
+TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias)
+{
+    using OutputOp = TensorOperation<TensorConfig{.layout = TensorLayout::G_K_strided}>;
+
+    static constexpr auto sig =
+        ConvSignature<ConvolutionTensor<>, ConvolutionTensor<>, ConvolutionTensor<OutputOp>>{
+            .spatial_dim = 1,
+            .direction = FORWARD,
+            .data_type = DataType::FP32,
+            .accumulation_data_type = DataType::FP32,
+            .input = {.config = {.layout = TensorLayout::NWGC}},
+            .weight = {.config = {.layout = TensorLayout::GKXC}},
+            .output = {.config = {.layout = TensorLayout::NWGK},
+                       .operation =
+                           OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
+
+    using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
+
+    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NWGC>));
+    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
+    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NWGK>));
|
||||
|
||||
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_K>;
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
|
||||
}
|
||||
|
||||
TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias)
|
||||
{
|
||||
using OutputOp = TensorOperation<TensorConfig{.layout = TensorLayout::G_C_strided}>;
|
||||
|
||||
static constexpr auto sig =
|
||||
ConvSignature<ConvolutionTensor<>, ConvolutionTensor<>, ConvolutionTensor<OutputOp>>{
|
||||
.spatial_dim = 3,
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NDHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKZYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NDHWGK},
|
||||
.operation = OutputOp{.elementwise_operation =
|
||||
ElementwiseOperation::BIAS_BNORM_CLAMP}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NDHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NDHWGK>));
|
||||
|
||||
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_C>;
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
|
||||
}
|
||||
|
||||
} // namespace
@@ -9,71 +9,42 @@
namespace {

namespace ckb = ck_tile::builder;
using ck_tile::builder::factory::internal::ConvTensorTypes;
using ck_tile::builder::factory::internal::DataTypeToCK;

TEST(ConvTensorType, AssignsTypesForFP16)
{
    using Types = ConvTensorTypes<ckb::DataType::FP16>;

    EXPECT_TRUE((std::is_same_v<Types::ADataType, ck::half_t>));
    EXPECT_TRUE((std::is_same_v<Types::BDataType, ck::half_t>));
    EXPECT_TRUE((std::is_same_v<Types::EDataType, ck::half_t>));
    EXPECT_TRUE((std::is_same_v<Types::AccDataType, float>));
    EXPECT_TRUE((std::is_same_v<Types::AComputeType, ck::half_t>));
    EXPECT_TRUE((std::is_same_v<Types::BComputeType, ck::half_t>));
    EXPECT_TRUE((std::is_same_v<Types::CShuffleDataType, ck::half_t>));
    using CKType = DataTypeToCK<ckb::DataType::FP16>::type;
    EXPECT_TRUE((std::is_same_v<CKType, ck::half_t>));
}

TEST(ConvTensorType, AssignsTypesForBF16)
{
    using Types = ConvTensorTypes<ckb::DataType::BF16>;

    EXPECT_TRUE((std::is_same_v<Types::ADataType, ck::bhalf_t>));
    EXPECT_TRUE((std::is_same_v<Types::BDataType, ck::bhalf_t>));
    EXPECT_TRUE((std::is_same_v<Types::EDataType, ck::bhalf_t>));
    EXPECT_TRUE((std::is_same_v<Types::AccDataType, float>));
    EXPECT_TRUE((std::is_same_v<Types::AComputeType, ck::bhalf_t>));
    EXPECT_TRUE((std::is_same_v<Types::BComputeType, ck::bhalf_t>));
    EXPECT_TRUE((std::is_same_v<Types::CShuffleDataType, ck::bhalf_t>));
    using CKType = DataTypeToCK<ckb::DataType::BF16>::type;
    EXPECT_TRUE((std::is_same_v<CKType, ck::bhalf_t>));
}

TEST(ConvTensorType, AssignsTypesForFP32)
{
    using Types = ConvTensorTypes<ckb::DataType::FP32>;

    EXPECT_TRUE((std::is_same_v<Types::ADataType, float>));
    EXPECT_TRUE((std::is_same_v<Types::BDataType, float>));
    EXPECT_TRUE((std::is_same_v<Types::EDataType, float>));
    EXPECT_TRUE((std::is_same_v<Types::AccDataType, float>));
    EXPECT_TRUE((std::is_same_v<Types::AComputeType, float>));
    EXPECT_TRUE((std::is_same_v<Types::BComputeType, float>));
    EXPECT_TRUE((std::is_same_v<Types::CShuffleDataType, float>));
    using CKType = DataTypeToCK<ckb::DataType::FP32>::type;
    EXPECT_TRUE((std::is_same_v<CKType, float>));
}

TEST(ConvTensorType, AssignsTypesForINT32)
{
    using CKType = DataTypeToCK<ckb::DataType::INT32>::type;
    EXPECT_TRUE((std::is_same_v<CKType, int32_t>));
}

TEST(ConvTensorType, AssignsTypesForI8)
{
    using Types = ConvTensorTypes<ckb::DataType::I8>;

    EXPECT_TRUE((std::is_same_v<Types::ADataType, int8_t>));
    EXPECT_TRUE((std::is_same_v<Types::BDataType, int8_t>));
    EXPECT_TRUE((std::is_same_v<Types::EDataType, int8_t>));
    EXPECT_TRUE((std::is_same_v<Types::AccDataType, int32_t>));
    EXPECT_TRUE((std::is_same_v<Types::AComputeType, int8_t>));
    EXPECT_TRUE((std::is_same_v<Types::BComputeType, int8_t>));
    EXPECT_TRUE((std::is_same_v<Types::CShuffleDataType, int8_t>));
    using CKType = DataTypeToCK<ckb::DataType::I8>::type;
    EXPECT_TRUE((std::is_same_v<CKType, int8_t>));
}

TEST(ConvTensorType, AssignsTypesForFP8)
{
    using Types = ConvTensorTypes<ckb::DataType::FP8>;

    EXPECT_TRUE((std::is_same_v<Types::ADataType, ck::f8_t>));
    EXPECT_TRUE((std::is_same_v<Types::BDataType, ck::f8_t>));
    EXPECT_TRUE((std::is_same_v<Types::EDataType, ck::f8_t>));
    EXPECT_TRUE((std::is_same_v<Types::AccDataType, float>));
    EXPECT_TRUE((std::is_same_v<Types::AComputeType, ck::f8_t>));
    EXPECT_TRUE((std::is_same_v<Types::BComputeType, ck::f8_t>));
    EXPECT_TRUE((std::is_same_v<Types::CShuffleDataType, ck::f8_t>));
    using CKType = DataTypeToCK<ckb::DataType::FP8>::type;
    EXPECT_TRUE((std::is_same_v<CKType, ck::f8_t>));
}

} // namespace
@@ -178,6 +178,9 @@ constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x2_per_wave{
    .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2};

constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x2_per_wave{
    .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2};

constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{
    .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1};

@@ -15,7 +15,7 @@ constexpr void run_test(const std::vector<std::string>& kernel_instance_componen
{
    auto instance = typename Builder::Instance{};

    const auto kernel_string = instance.GetInstanceString();
    std::cout << "Generated kernel: " << kernel_string << std::endl;
    EXPECT_GT(kernel_string.size(), 0);