[CK_BUILDER] Improve CK Builder and CK Builder tests (#3382)

* Remove stale documentation.

* Add placeholder for conv algorithm design description. Add link to conv factory description.

* Improve testing transfer parameters.

* Python script to check the block tilings.

* Improve tests and conv types serialization.

* Change representation of boolean values from 1/0 to true/false in instance strings.

* Change representation of boolean values from 1/0 to true/false in conv algorithm types.

* Test code improvements.

* Improve conv description tests.

* Improve conv signature definition in conv fwd builder tests.

* clang-format.

* Remove obsolete script.

* Revert StaticAssertTypeEq changes in conv layout tests.

* Remove obsolete using declaration.

---------

Co-authored-by: Ville Pietilä <>
This commit is contained in:
Ville Pietilä
2025-12-11 09:50:00 +02:00
committed by GitHub
parent 6d25525adc
commit d66e5f667c
33 changed files with 1568 additions and 1042 deletions

View File

@@ -4,14 +4,16 @@ This directory contains the builder framework for Composable Kernel, which provi
## Table of Contents
- [Convolution Signature Design](#convolution-signature-design)
- [Convolution Signature](#convolution-signature)
- [Overview](#overview)
- [Architecture](#architecture)
- [Core Components](#core-components)
- [Concepts and Validation](#concepts-and-validation)
- [Convolution Algorithm](#convolution-algorithm)
- [Convolution Factory](#convolution-factory)
---
## Convolution Signature Design
## Convolution Signature
### Overview
@@ -220,25 +222,9 @@ Several fields in the signature are optional:
This design follows the principle of "make the common case simple, the complex case possible."
#### Union-Based Layout Representation
## Convolution Algorithm
The `ConvLayout` type uses unions to support dimension-agnostic code:
## Convolution Factory
```cpp
struct ConvLayout {
union {
ConvInputLayout _input_layout;
ConvWeightLayout _weight_layout;
ConvOutputLayout _output_layout;
ConvAuxiliaryTensorLayout _aux_tensor_layout;
};
// ... constructors for each type
};
```
This allows:
- Single type to represent all layout variants
- Type-safe construction through overloaded constructors
- Compile-time enforcement of valid combinations through concepts
---
The convolution factory builds the instance based on the convolution signature and the convolution algorithm.
The signature and algorithm descriptions are dispatched to the relevant algorithm-specific factory for instance creation. The convolution factory design is described in a separate [README](factory/README.md).

View File

@@ -65,17 +65,19 @@ consteval auto GetTensorDataAndComputeTypes()
constexpr auto data_type = Config.data_type;
constexpr auto compute_type = Config.compute_type;
if constexpr(data_type == DataType::UNDEFINDED && compute_type == DataType::UNDEFINDED)
using enum DataType;
if constexpr(data_type == UNDEFINED_DATA_TYPE && compute_type == UNDEFINED_DATA_TYPE)
{
return std::make_pair(ConvertDataTypeToCK<SignatureDataType>(),
ConvertDataTypeToCK<SignatureDataType>());
}
else if constexpr(data_type == DataType::UNDEFINDED)
else if constexpr(data_type == UNDEFINED_DATA_TYPE)
{
return std::make_pair(ConvertDataTypeToCK<SignatureDataType>(),
ConvertDataTypeToCK<compute_type>());
}
else if constexpr(compute_type == DataType::UNDEFINDED)
else if constexpr(compute_type == UNDEFINED_DATA_TYPE)
{
return std::make_pair(ConvertDataTypeToCK<data_type>(),
ConvertDataTypeToCK<SignatureDataType>());
@@ -91,7 +93,7 @@ template <DataType SignatureAccDataType, DataType SignatureDataType>
consteval auto GetTensorAccumulationType()
{
constexpr auto data_type = SignatureAccDataType;
if constexpr(data_type == DataType::UNDEFINDED)
if constexpr(data_type == DataType::UNDEFINED_DATA_TYPE)
{
return ConvertDataTypeToCK<SignatureDataType>();
}
@@ -105,7 +107,7 @@ template <auto Config, DataType SignatureDataType>
consteval auto GetAuxiliaryTensorDataTypeValue()
{
constexpr auto data_type = Config.data_type;
if constexpr(data_type == DataType::UNDEFINDED)
if constexpr(data_type == DataType::UNDEFINED_DATA_TYPE)
{
return ConvertDataTypeToCK<SignatureDataType>();
}

View File

@@ -316,7 +316,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
oss << "," << kABlockTransferSrcScalarPerVector; // 32. ABlockTransferSrcScalarPerVector
oss << ","
<< kABlockTransferDstScalarPerVectorK1; // 33. ABlockTransferDstScalarPerVector_AK1
oss << "," << kABlockLdsExtraM; // 34. ABlockLdsExtraM
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 34. ABlockLdsExtraM
oss << ","
<< detail::array_to_string(
kBThreadClusterLengths); // 35. BBlockTransferThreadClusterLengths
@@ -329,10 +329,10 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
oss << "," << kBBlockTransferSrcVectorDim; // 38. BBlockTransferSrcVectorDim
oss << "," << kBBlockTransferSrcScalarPerVector; // 39. BBlockTransferSrcScalarPerVector
oss << ","
<< kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1
oss << "," << kBBlockLdsExtraN; // 41. BBlockLdsExtraN
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle
<< kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 41. BBlockLdsExtraN
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle
oss << ","
<< detail::array_to_string(
kCThreadClusterLengths); // 44. CDEBlockTransferClusterLengths

View File

@@ -316,7 +316,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
oss << "," << kABlockTransferSrcScalarPerVector; // 31. ABlockTransferSrcScalarPerVector
oss << ","
<< kABlockTransferDstScalarPerVectorK1; // 32. ABlockTransferDstScalarPerVector_AK1
oss << "," << kABlockLdsExtraM; // 33. ABlockLdsExtraM
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 33. ABlockLdsExtraM
oss << ","
<< detail::array_to_string(
kBThreadClusterLengths); // 34. BBlockTransferThreadClusterLengths
@@ -329,10 +329,10 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
oss << "," << kBBlockTransferSrcVectorDim; // 37. BBlockTransferSrcVectorDim
oss << "," << kBBlockTransferSrcScalarPerVector; // 38. BBlockTransferSrcScalarPerVector
oss << ","
<< kBBlockTransferDstScalarPerVectorK1; // 39. BBlockTransferDstScalarPerVector_BK1
oss << "," << kBBlockLdsExtraN; // 40. BBlockLdsExtraN
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 41. CShuffleMXdlPerWavePerShuffle
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 42. CShuffleNXdlPerWavePerShuffle
<< kBBlockTransferDstScalarPerVectorK1; // 39. BBlockTransferDstScalarPerVector_BK1
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 40. BBlockLdsExtraN
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 41. CShuffleMXdlPerWavePerShuffle
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 42. CShuffleNXdlPerWavePerShuffle
oss << ","
<< detail::array_to_string(
kCThreadClusterLengths); // 43. CDEBlockTransferClusterLengths

View File

@@ -311,7 +311,7 @@ struct InstanceTraits<
oss << "," << kABlockTransferSrcScalarPerVector; // 32. ABlockTransferSrcScalarPerVector
oss << ","
<< kABlockTransferDstScalarPerVectorK1; // 33. ABlockTransferDstScalarPerVector_AK1
oss << "," << kABlockLdsExtraM; // 34. ABlockLdsExtraM
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 34. ABlockLdsExtraM
oss << ","
<< detail::array_to_string(
kBThreadClusterLengths); // 35. BBlockTransferThreadClusterLengths
@@ -324,10 +324,10 @@ struct InstanceTraits<
oss << "," << kBBlockTransferSrcVectorDim; // 38. BBlockTransferSrcVectorDim
oss << "," << kBBlockTransferSrcScalarPerVector; // 39. BBlockTransferSrcScalarPerVector
oss << ","
<< kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1
oss << "," << kBBlockLdsExtraN; // 41. BBlockLdsExtraN
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle
<< kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 41. BBlockLdsExtraN
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle
oss << ","
<< detail::array_to_string(
kCThreadClusterLengths); // 44. CDEBlockTransferClusterLengths

View File

@@ -13,7 +13,7 @@ namespace ck_tile::builder {
enum class DataType
{
UNDEFINDED = 0,
UNDEFINED_DATA_TYPE = 0,
FP32,
FP16,
BF16,
@@ -25,7 +25,7 @@ enum class DataType
enum class TensorLayout
{
UNDEFINED,
UNDEFINED_TENSOR_LAYOUT = 0,
// Bias tensors
GC,
@@ -212,219 +212,269 @@ enum class ConvAlgorithmSpecialization
LARGE_TENSOR
};
// ostream operator overloads for enum classes
inline std::ostream& operator<<(std::ostream& os, DataType dt)
// toString functions for enum classes
inline std::string_view toString(DataType dt)
{
using enum DataType;
switch(dt)
{
case FP16: return os << "FP16";
case FP32: return os << "FP32";
case BF16: return os << "BF16";
case FP8: return os << "FP8";
case INT32: return os << "INT32";
case I8: return os << "I8";
case U8: return os << "U8";
case UNDEFINDED: return os << "UNDEFINDED";
default: return os << "Unknown";
case FP16: return "FP16";
case FP32: return "FP32";
case BF16: return "BF16";
case FP8: return "FP8";
case INT32: return "INT32";
case I8: return "I8";
case U8: return "U8";
case UNDEFINED_DATA_TYPE: return "UNDEFINED_DATA_TYPE";
default: return "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, ConvDirection dir)
inline std::string_view toString(ConvDirection dir)
{
using enum ConvDirection;
switch(dir)
{
case FORWARD: return os << "Forward";
case BACKWARD_DATA: return os << "Backward Data";
case BACKWARD_WEIGHT: return os << "Backward Weight";
default: return os << "Unknown";
case FORWARD: return "Forward";
case BACKWARD_DATA: return "Backward Data";
case BACKWARD_WEIGHT: return "Backward Weight";
default: return "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op)
inline std::string_view toString(ElementwiseOperation op)
{
using enum ElementwiseOperation;
switch(op)
{
case CLAMP: return os << "CLAMP";
case SCALE: return os << "SCALE";
case PASS_THROUGH: return os << "PASS_THROUGH";
case BIAS_BNORM_CLAMP: return os << "BIAS_BNORM_CLAMP";
case SCALEADD_SCALEADD_RELU: return os << "SCALEADD_SCALEADD_RELU";
default: return os << "Unknown";
case CLAMP: return "CLAMP";
case SCALE: return "SCALE";
case PASS_THROUGH: return "PASS_THROUGH";
case BIAS_BNORM_CLAMP: return "BIAS_BNORM_CLAMP";
case SCALEADD_SCALEADD_RELU: return "SCALEADD_SCALEADD_RELU";
default: return "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver)
inline std::string_view toString(PipelineVersion ver)
{
using enum PipelineVersion;
switch(ver)
{
case V1: return os << "V1";
case V2: return os << "V2";
case V3: return os << "V3";
case V4: return os << "V4";
case V5: return os << "V5";
case WEIGHT_ONLY: return os << "WEIGHT_ONLY";
default: return os << "Unknown";
case V1: return "V1";
case V2: return "V2";
case V3: return "V3";
case V4: return "V4";
case V5: return "V5";
case WEIGHT_ONLY: return "WEIGHT_ONLY";
default: return "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec)
inline std::string_view toString(GemmSpecialization spec)
{
using enum GemmSpecialization;
switch(spec)
{
case Default: return os << "Default";
case MPadding: return os << "MPadding";
case NPadding: return os << "NPadding";
case KPadding: return os << "KPadding";
case MNPadding: return os << "MNPadding";
case MKPadding: return os << "MKPadding";
case NKPadding: return os << "NKPadding";
case MNKPadding: return os << "MNKPadding";
case OPadding: return os << "OPadding";
case MOPadding: return os << "MOPadding";
case NOPadding: return os << "NOPadding";
case KOPadding: return os << "KOPadding";
case MNOPadding: return os << "MNOPadding";
case MKOPadding: return os << "MKOPadding";
case NKOPadding: return os << "NKOPadding";
case MNKOPadding: return os << "MNKOPadding";
default: return os << "Unknown";
case Default: return "Default";
case MPadding: return "MPadding";
case NPadding: return "NPadding";
case KPadding: return "KPadding";
case MNPadding: return "MNPadding";
case MKPadding: return "MKPadding";
case NKPadding: return "NKPadding";
case MNKPadding: return "MNKPadding";
case OPadding: return "OPadding";
case MOPadding: return "MOPadding";
case NOPadding: return "NOPadding";
case KOPadding: return "KOPadding";
case MNOPadding: return "MNOPadding";
case MKOPadding: return "MKOPadding";
case NKOPadding: return "NKOPadding";
case MNKOPadding: return "MNKOPadding";
default: return "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec)
inline std::string_view toString(ConvFwdSpecialization spec)
{
using enum ConvFwdSpecialization;
switch(spec)
{
case DEFAULT: return os << "DEFAULT";
case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0";
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
case FILTER_3x3: return os << "FILTER_3x3";
default: return os << "Unknown";
case DEFAULT: return "DEFAULT";
case FILTER_1X1_PAD0: return "FILTER_1X1_PAD0";
case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0";
case FILTER_3x3: return "FILTER_3x3";
default: return "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec)
inline std::string_view toString(ConvBwdDataSpecialization spec)
{
using enum ConvBwdDataSpecialization;
switch(spec)
{
case DEFAULT: return os << "DEFAULT";
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
default: return os << "Unknown";
case DEFAULT: return "DEFAULT";
case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0";
default: return "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec)
inline std::string_view toString(ConvBwdWeightSpecialization spec)
{
using enum ConvBwdWeightSpecialization;
switch(spec)
{
case DEFAULT: return os << "DEFAULT";
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0";
case ODD_C: return os << "ODD_C";
default: return os << "Unknown";
case DEFAULT: return "DEFAULT";
case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0";
case FILTER_1X1_PAD0: return "FILTER_1X1_PAD0";
case ODD_C: return "ODD_C";
default: return "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, GemmPadding padding)
inline std::string_view toString(GemmPadding padding)
{
using enum GemmPadding;
switch(padding)
{
case DEFAULT: return os << "DEFAULT";
case M_PADDING: return os << "M_PADDING";
case N_PADDING: return os << "N_PADDING";
case K_PADDING: return os << "K_PADDING";
case MN_PADDING: return os << "MN_PADDING";
case MK_PADDING: return os << "MK_PADDING";
case NK_PADDING: return os << "NK_PADDING";
case MNK_PADDING: return os << "MNK_PADDING";
case O_PADDING: return os << "O_PADDING";
case MO_PADDING: return os << "MO_PADDING";
case NO_PADDING: return os << "NO_PADDING";
case KO_PADDING: return os << "KO_PADDING";
case MNO_PADDING: return os << "MNO_PADDING";
case MKO_PADDING: return os << "MKO_PADDING";
case NKO_PADDING: return os << "NKO_PADDING";
case MNKO_PADDING: return os << "MNKO_PADDING";
default: return os << "Unknown";
case DEFAULT: return "DEFAULT";
case M_PADDING: return "M_PADDING";
case N_PADDING: return "N_PADDING";
case K_PADDING: return "K_PADDING";
case MN_PADDING: return "MN_PADDING";
case MK_PADDING: return "MK_PADDING";
case NK_PADDING: return "NK_PADDING";
case MNK_PADDING: return "MNK_PADDING";
case O_PADDING: return "O_PADDING";
case MO_PADDING: return "MO_PADDING";
case NO_PADDING: return "NO_PADDING";
case KO_PADDING: return "KO_PADDING";
case MNO_PADDING: return "MNO_PADDING";
case MKO_PADDING: return "MKO_PADDING";
case NKO_PADDING: return "NKO_PADDING";
case MNKO_PADDING: return "MNKO_PADDING";
default: return "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched)
inline std::string_view toString(PipelineScheduler sched)
{
using enum PipelineScheduler;
switch(sched)
{
case DEFAULT: return os << "DEFAULT";
case INTRAWAVE: return os << "INTRAWAVE";
case INTERWAVE: return os << "INTERWAVE";
default: return os << "Unknown";
case DEFAULT: return "DEFAULT";
case INTRAWAVE: return "INTRAWAVE";
case INTERWAVE: return "INTERWAVE";
default: return "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, TensorLayout layout)
inline std::string_view toString(TensorLayout layout)
{
using enum TensorLayout;
switch(layout)
{
case GNCW: return os << "GNCW";
case GNWC: return os << "GNWC";
case NWGC: return os << "NWGC";
case NGCW: return os << "NGCW";
case G_NW_C_strided: return os << "G_NW_C_strided";
case GNCHW: return os << "GNCHW";
case GNHWC: return os << "GNHWC";
case NHWGC: return os << "NHWGC";
case NGCHW: return os << "NGCHW";
case G_NHW_C_strided: return os << "G_NHW_C_strided";
case GNCDHW: return os << "GNCDHW";
case GNDHWC: return os << "GNDHWC";
case NDHWGC: return os << "NDHWGC";
case NGCDHW: return os << "NGCDHW";
case G_NDHW_C_strided: return os << "G_NDHW_C_strided";
case GKXC: return os << "GKXC";
case GKCX: return os << "GKCX";
case KXGC: return os << "KXGC";
case G_K_X_C_strided: return os << "G_K_X_C_strided";
case GKYXC: return os << "GKYXC";
case GKCYX: return os << "GKCYX";
case KYXGC: return os << "KYXGC";
case G_K_YX_C_strided: return os << "G_K_YX_C_strided";
case GKZYXC: return os << "GKZYXC";
case GKCZYX: return os << "GKCZYX";
case KZYXGC: return os << "KZYXGC";
case G_K_ZYX_C_strided: return os << "G_K_ZYX_C_strided";
case GNKW: return os << "GNKW";
case GNWK: return os << "GNWK";
case NWGK: return os << "NWGK";
case NGKW: return os << "NGKW";
case G_NW_K_strided: return os << "G_NW_K_strided";
case GNKHW: return os << "GNKHW";
case GNHWK: return os << "GNHWK";
case NHWGK: return os << "NHWGK";
case NGKHW: return os << "NGKHW";
case G_NHW_K_strided: return os << "G_NHW_K_strided";
case GNKDHW: return os << "GNKDHW";
case GNDHWK: return os << "GNDHWK";
case NDHWGK: return os << "NDHWGK";
case NGKDHW: return os << "NGKDHW";
case G_NDHW_K_strided: return os << "G_NDHW_K_strided";
case GC: return os << "GC";
case G_C_strided: return os << "G_C_strided";
case G_K_strided: return os << "G_K_strided";
case UNDEFINED: return os << "UNDEFINED";
default: return os << "Unknown";
case GNCW: return "GNCW";
case GNWC: return "GNWC";
case NWGC: return "NWGC";
case NGCW: return "NGCW";
case G_NW_C_strided: return "G_NW_C_strided";
case GNCHW: return "GNCHW";
case GNHWC: return "GNHWC";
case NHWGC: return "NHWGC";
case NGCHW: return "NGCHW";
case G_NHW_C_strided: return "G_NHW_C_strided";
case GNCDHW: return "GNCDHW";
case GNDHWC: return "GNDHWC";
case NDHWGC: return "NDHWGC";
case NGCDHW: return "NGCDHW";
case G_NDHW_C_strided: return "G_NDHW_C_strided";
case GKXC: return "GKXC";
case GKCX: return "GKCX";
case KXGC: return "KXGC";
case G_K_X_C_strided: return "G_K_X_C_strided";
case GKYXC: return "GKYXC";
case GKCYX: return "GKCYX";
case KYXGC: return "KYXGC";
case G_K_YX_C_strided: return "G_K_YX_C_strided";
case GKZYXC: return "GKZYXC";
case GKCZYX: return "GKCZYX";
case KZYXGC: return "KZYXGC";
case G_K_ZYX_C_strided: return "G_K_ZYX_C_strided";
case GNKW: return "GNKW";
case GNWK: return "GNWK";
case NWGK: return "NWGK";
case NGKW: return "NGKW";
case G_NW_K_strided: return "G_NW_K_strided";
case GNKHW: return "GNKHW";
case GNHWK: return "GNHWK";
case NHWGK: return "NHWGK";
case NGKHW: return "NGKHW";
case G_NHW_K_strided: return "G_NHW_K_strided";
case GNKDHW: return "GNKDHW";
case GNDHWK: return "GNDHWK";
case NDHWGK: return "NDHWGK";
case NGKDHW: return "NGKDHW";
case G_NDHW_K_strided: return "G_NDHW_K_strided";
case GC: return "GC";
case G_C_strided: return "G_C_strided";
case G_K_strided: return "G_K_strided";
case UNDEFINED_TENSOR_LAYOUT: return "UNDEFINED_TENSOR_LAYOUT";
default: return "Unknown";
}
}
// ostream operator overloads for enum classes.
// Each overload inserts the text returned by the matching toString() overload
// for the given enum value and returns the stream, so insertions can be chained.
inline std::ostream& operator<<(std::ostream& os, DataType dt) { return os << toString(dt); }
inline std::ostream& operator<<(std::ostream& os, ConvDirection dir) { return os << toString(dir); }
// Streams the toString() name of an ElementwiseOperation value.
inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op)
{
return os << toString(op);
}
// Streams the toString() name of a PipelineVersion value.
inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver)
{
return os << toString(ver);
}
// Streams the toString() name of a GemmSpecialization value.
inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec)
{
return os << toString(spec);
}
// Streams the toString() name of a ConvFwdSpecialization value.
inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec)
{
return os << toString(spec);
}
// Streams the toString() name of a ConvBwdDataSpecialization value.
inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec)
{
return os << toString(spec);
}
// Streams the toString() name of a ConvBwdWeightSpecialization value.
inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec)
{
return os << toString(spec);
}
// Streams the toString() name of a GemmPadding value.
inline std::ostream& operator<<(std::ostream& os, GemmPadding padding)
{
return os << toString(padding);
}
// Streams the toString() name of a PipelineScheduler value.
inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched)
{
return os << toString(sched);
}
// Streams the toString() name of a TensorLayout value.
inline std::ostream& operator<<(std::ostream& os, TensorLayout layout)
{
return os << toString(layout);
}
// ostream operator overload for std::variant of convolution specializations
inline std::ostream& operator<<(std::ostream& os,
const std::variant<ConvFwdSpecialization,