mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
[CK_BUILDER] Improve CK Builder and CK Builder tests (#3382)
* Remove stale documentation. * Add placeholder for conv algorithm design description. Add link to conv factory description. * Improve testing transfer parameters. * Python script to check the block tilings. * Improve tests and conv types serialization. * Change representation of boolean values from 1/0 to true/false in instance strings. * Change representation of boolean values from 1/0 to true/false in conv algorithm types. * Test code improvements. * Improve covn descriptions tests. * Improve conv signature definition in conv fwd builder tests. * clang-format. * Remove obsolete script. * Revert StaticAssertTypeEq changes in conv layout tests. * Remove obsolete using declaration. --------- Co-authored-by: Ville Pietilä <>
This commit is contained in:
@@ -4,14 +4,16 @@ This directory contains the builder framework for Composable Kernel, which provi
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Convolution Signature Design](#convolution-signature-design)
|
||||
- [Convolution Signature](#convolution-signature)
|
||||
- [Overview](#overview)
|
||||
- [Architecture](#architecture)
|
||||
- [Core Components](#core-components)
|
||||
- [Concepts and Validation](#concepts-and-validation)
|
||||
- [Convolution Algorithm](#convolution-algorithm)
|
||||
- [Convolution Factory](#convolution-factory)
|
||||
---
|
||||
|
||||
## Convolution Signature Design
|
||||
## Convolution Signature
|
||||
|
||||
### Overview
|
||||
|
||||
@@ -220,25 +222,9 @@ Several fields in the signature are optional:
|
||||
|
||||
This design follows the principle of "make the common case simple, the complex case possible."
|
||||
|
||||
#### Union-Based Layout Representation
|
||||
## Convolution Algorithm
|
||||
|
||||
The `ConvLayout` type uses unions to support dimension-agnostic code:
|
||||
## Convolution Factory
|
||||
|
||||
```cpp
|
||||
struct ConvLayout {
|
||||
union {
|
||||
ConvInputLayout _input_layout;
|
||||
ConvWeightLayout _weight_layout;
|
||||
ConvOutputLayout _output_layout;
|
||||
ConvAuxiliaryTensorLayout _aux_tensor_layout;
|
||||
};
|
||||
// ... constructors for each type
|
||||
};
|
||||
```
|
||||
|
||||
This allows:
|
||||
- Single type to represent all layout variants
|
||||
- Type-safe construction through overloaded constructors
|
||||
- Compile-time enforcement of valid combinations through concepts
|
||||
|
||||
---
|
||||
Convolution factory builds the instance based on the convolution signature and convolution algorithm.
|
||||
The signature and the algorithm descriptions are dispatched to the relevant algorithm specific factory for instance creation. The convolution factory design is described in a separate [Readme](factory/README.md).
|
||||
|
||||
@@ -65,17 +65,19 @@ consteval auto GetTensorDataAndComputeTypes()
|
||||
constexpr auto data_type = Config.data_type;
|
||||
constexpr auto compute_type = Config.compute_type;
|
||||
|
||||
if constexpr(data_type == DataType::UNDEFINDED && compute_type == DataType::UNDEFINDED)
|
||||
using enum DataType;
|
||||
|
||||
if constexpr(data_type == UNDEFINED_DATA_TYPE && compute_type == UNDEFINED_DATA_TYPE)
|
||||
{
|
||||
return std::make_pair(ConvertDataTypeToCK<SignatureDataType>(),
|
||||
ConvertDataTypeToCK<SignatureDataType>());
|
||||
}
|
||||
else if constexpr(data_type == DataType::UNDEFINDED)
|
||||
else if constexpr(data_type == UNDEFINED_DATA_TYPE)
|
||||
{
|
||||
return std::make_pair(ConvertDataTypeToCK<SignatureDataType>(),
|
||||
ConvertDataTypeToCK<compute_type>());
|
||||
}
|
||||
else if constexpr(compute_type == DataType::UNDEFINDED)
|
||||
else if constexpr(compute_type == UNDEFINED_DATA_TYPE)
|
||||
{
|
||||
return std::make_pair(ConvertDataTypeToCK<data_type>(),
|
||||
ConvertDataTypeToCK<SignatureDataType>());
|
||||
@@ -91,7 +93,7 @@ template <DataType SignatureAccDataType, DataType SignatureDataType>
|
||||
consteval auto GetTensorAccumulationType()
|
||||
{
|
||||
constexpr auto data_type = SignatureAccDataType;
|
||||
if constexpr(data_type == DataType::UNDEFINDED)
|
||||
if constexpr(data_type == DataType::UNDEFINED_DATA_TYPE)
|
||||
{
|
||||
return ConvertDataTypeToCK<SignatureDataType>();
|
||||
}
|
||||
@@ -105,7 +107,7 @@ template <auto Config, DataType SignatureDataType>
|
||||
consteval auto GetAuxiliaryTensorDataTypeValue()
|
||||
{
|
||||
constexpr auto data_type = Config.data_type;
|
||||
if constexpr(data_type == DataType::UNDEFINDED)
|
||||
if constexpr(data_type == DataType::UNDEFINED_DATA_TYPE)
|
||||
{
|
||||
return ConvertDataTypeToCK<SignatureDataType>();
|
||||
}
|
||||
|
||||
@@ -316,7 +316,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 32. ABlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kABlockTransferDstScalarPerVectorK1; // 33. ABlockTransferDstScalarPerVector_AK1
|
||||
oss << "," << kABlockLdsExtraM; // 34. ABlockLdsExtraM
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 34. ABlockLdsExtraM
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterLengths); // 35. BBlockTransferThreadClusterLengths
|
||||
@@ -329,10 +329,10 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 38. BBlockTransferSrcVectorDim
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 39. BBlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1
|
||||
oss << "," << kBBlockLdsExtraN; // 41. BBlockLdsExtraN
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle
|
||||
<< kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 41. BBlockLdsExtraN
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kCThreadClusterLengths); // 44. CDEBlockTransferClusterLengths
|
||||
|
||||
@@ -316,7 +316,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 31. ABlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kABlockTransferDstScalarPerVectorK1; // 32. ABlockTransferDstScalarPerVector_AK1
|
||||
oss << "," << kABlockLdsExtraM; // 33. ABlockLdsExtraM
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 33. ABlockLdsExtraM
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterLengths); // 34. BBlockTransferThreadClusterLengths
|
||||
@@ -329,10 +329,10 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 37. BBlockTransferSrcVectorDim
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 38. BBlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kBBlockTransferDstScalarPerVectorK1; // 39. BBlockTransferDstScalarPerVector_BK1
|
||||
oss << "," << kBBlockLdsExtraN; // 40. BBlockLdsExtraN
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 41. CShuffleMXdlPerWavePerShuffle
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 42. CShuffleNXdlPerWavePerShuffle
|
||||
<< kBBlockTransferDstScalarPerVectorK1; // 39. BBlockTransferDstScalarPerVector_BK1
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 40. BBlockLdsExtraN
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 41. CShuffleMXdlPerWavePerShuffle
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 42. CShuffleNXdlPerWavePerShuffle
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kCThreadClusterLengths); // 43. CDEBlockTransferClusterLengths
|
||||
|
||||
@@ -311,7 +311,7 @@ struct InstanceTraits<
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 32. ABlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kABlockTransferDstScalarPerVectorK1; // 33. ABlockTransferDstScalarPerVector_AK1
|
||||
oss << "," << kABlockLdsExtraM; // 34. ABlockLdsExtraM
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 34. ABlockLdsExtraM
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterLengths); // 35. BBlockTransferThreadClusterLengths
|
||||
@@ -324,10 +324,10 @@ struct InstanceTraits<
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 38. BBlockTransferSrcVectorDim
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 39. BBlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1
|
||||
oss << "," << kBBlockLdsExtraN; // 41. BBlockLdsExtraN
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle
|
||||
<< kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 41. BBlockLdsExtraN
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kCThreadClusterLengths); // 44. CDEBlockTransferClusterLengths
|
||||
|
||||
@@ -13,7 +13,7 @@ namespace ck_tile::builder {
|
||||
|
||||
enum class DataType
|
||||
{
|
||||
UNDEFINDED = 0,
|
||||
UNDEFINED_DATA_TYPE = 0,
|
||||
FP32,
|
||||
FP16,
|
||||
BF16,
|
||||
@@ -25,7 +25,7 @@ enum class DataType
|
||||
|
||||
enum class TensorLayout
|
||||
{
|
||||
UNDEFINED,
|
||||
UNDEFINED_TENSOR_LAYOUT = 0,
|
||||
|
||||
// Bias tensors
|
||||
GC,
|
||||
@@ -212,219 +212,269 @@ enum class ConvAlgorithmSpecialization
|
||||
LARGE_TENSOR
|
||||
};
|
||||
|
||||
// ostream operator overloads for enum classes
|
||||
inline std::ostream& operator<<(std::ostream& os, DataType dt)
|
||||
// toString methods for enum classes
|
||||
inline std::string_view toString(DataType dt)
|
||||
{
|
||||
using enum DataType;
|
||||
switch(dt)
|
||||
{
|
||||
case FP16: return os << "FP16";
|
||||
case FP32: return os << "FP32";
|
||||
case BF16: return os << "BF16";
|
||||
case FP8: return os << "FP8";
|
||||
case INT32: return os << "INT32";
|
||||
case I8: return os << "I8";
|
||||
case U8: return os << "U8";
|
||||
case UNDEFINDED: return os << "UNDEFINDED";
|
||||
default: return os << "Unknown";
|
||||
case FP16: return "FP16";
|
||||
case FP32: return "FP32";
|
||||
case BF16: return "BF16";
|
||||
case FP8: return "FP8";
|
||||
case INT32: return "INT32";
|
||||
case I8: return "I8";
|
||||
case U8: return "U8";
|
||||
case UNDEFINED_DATA_TYPE: return "UNDEFINED_DATA_TYPE";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvDirection dir)
|
||||
inline std::string_view toString(ConvDirection dir)
|
||||
{
|
||||
using enum ConvDirection;
|
||||
switch(dir)
|
||||
{
|
||||
case FORWARD: return os << "Forward";
|
||||
case BACKWARD_DATA: return os << "Backward Data";
|
||||
case BACKWARD_WEIGHT: return os << "Backward Weight";
|
||||
default: return os << "Unknown";
|
||||
case FORWARD: return "Forward";
|
||||
case BACKWARD_DATA: return "Backward Data";
|
||||
case BACKWARD_WEIGHT: return "Backward Weight";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op)
|
||||
inline std::string_view toString(ElementwiseOperation op)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
switch(op)
|
||||
{
|
||||
case CLAMP: return os << "CLAMP";
|
||||
case SCALE: return os << "SCALE";
|
||||
case PASS_THROUGH: return os << "PASS_THROUGH";
|
||||
case BIAS_BNORM_CLAMP: return os << "BIAS_BNORM_CLAMP";
|
||||
case SCALEADD_SCALEADD_RELU: return os << "SCALEADD_SCALEADD_RELU";
|
||||
default: return os << "Unknown";
|
||||
case CLAMP: return "CLAMP";
|
||||
case SCALE: return "SCALE";
|
||||
case PASS_THROUGH: return "PASS_THROUGH";
|
||||
case BIAS_BNORM_CLAMP: return "BIAS_BNORM_CLAMP";
|
||||
case SCALEADD_SCALEADD_RELU: return "SCALEADD_SCALEADD_RELU";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver)
|
||||
inline std::string_view toString(PipelineVersion ver)
|
||||
{
|
||||
using enum PipelineVersion;
|
||||
switch(ver)
|
||||
{
|
||||
case V1: return os << "V1";
|
||||
case V2: return os << "V2";
|
||||
case V3: return os << "V3";
|
||||
case V4: return os << "V4";
|
||||
case V5: return os << "V5";
|
||||
case WEIGHT_ONLY: return os << "WEIGHT_ONLY";
|
||||
default: return os << "Unknown";
|
||||
case V1: return "V1";
|
||||
case V2: return "V2";
|
||||
case V3: return "V3";
|
||||
case V4: return "V4";
|
||||
case V5: return "V5";
|
||||
case WEIGHT_ONLY: return "WEIGHT_ONLY";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec)
|
||||
inline std::string_view toString(GemmSpecialization spec)
|
||||
{
|
||||
using enum GemmSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case Default: return os << "Default";
|
||||
case MPadding: return os << "MPadding";
|
||||
case NPadding: return os << "NPadding";
|
||||
case KPadding: return os << "KPadding";
|
||||
case MNPadding: return os << "MNPadding";
|
||||
case MKPadding: return os << "MKPadding";
|
||||
case NKPadding: return os << "NKPadding";
|
||||
case MNKPadding: return os << "MNKPadding";
|
||||
case OPadding: return os << "OPadding";
|
||||
case MOPadding: return os << "MOPadding";
|
||||
case NOPadding: return os << "NOPadding";
|
||||
case KOPadding: return os << "KOPadding";
|
||||
case MNOPadding: return os << "MNOPadding";
|
||||
case MKOPadding: return os << "MKOPadding";
|
||||
case NKOPadding: return os << "NKOPadding";
|
||||
case MNKOPadding: return os << "MNKOPadding";
|
||||
default: return os << "Unknown";
|
||||
case Default: return "Default";
|
||||
case MPadding: return "MPadding";
|
||||
case NPadding: return "NPadding";
|
||||
case KPadding: return "KPadding";
|
||||
case MNPadding: return "MNPadding";
|
||||
case MKPadding: return "MKPadding";
|
||||
case NKPadding: return "NKPadding";
|
||||
case MNKPadding: return "MNKPadding";
|
||||
case OPadding: return "OPadding";
|
||||
case MOPadding: return "MOPadding";
|
||||
case NOPadding: return "NOPadding";
|
||||
case KOPadding: return "KOPadding";
|
||||
case MNOPadding: return "MNOPadding";
|
||||
case MKOPadding: return "MKOPadding";
|
||||
case NKOPadding: return "NKOPadding";
|
||||
case MNKOPadding: return "MNKOPadding";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec)
|
||||
inline std::string_view toString(ConvFwdSpecialization spec)
|
||||
{
|
||||
using enum ConvFwdSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
|
||||
case FILTER_3x3: return os << "FILTER_3x3";
|
||||
default: return os << "Unknown";
|
||||
case DEFAULT: return "DEFAULT";
|
||||
case FILTER_1X1_PAD0: return "FILTER_1X1_PAD0";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0";
|
||||
case FILTER_3x3: return "FILTER_3x3";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec)
|
||||
inline std::string_view toString(ConvBwdDataSpecialization spec)
|
||||
{
|
||||
using enum ConvBwdDataSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
|
||||
default: return os << "Unknown";
|
||||
case DEFAULT: return "DEFAULT";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec)
|
||||
inline std::string_view toString(ConvBwdWeightSpecialization spec)
|
||||
{
|
||||
using enum ConvBwdWeightSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
|
||||
case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0";
|
||||
case ODD_C: return os << "ODD_C";
|
||||
default: return os << "Unknown";
|
||||
case DEFAULT: return "DEFAULT";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0";
|
||||
case FILTER_1X1_PAD0: return "FILTER_1X1_PAD0";
|
||||
case ODD_C: return "ODD_C";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GemmPadding padding)
|
||||
inline std::string_view toString(GemmPadding padding)
|
||||
{
|
||||
using enum GemmPadding;
|
||||
switch(padding)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case M_PADDING: return os << "M_PADDING";
|
||||
case N_PADDING: return os << "N_PADDING";
|
||||
case K_PADDING: return os << "K_PADDING";
|
||||
case MN_PADDING: return os << "MN_PADDING";
|
||||
case MK_PADDING: return os << "MK_PADDING";
|
||||
case NK_PADDING: return os << "NK_PADDING";
|
||||
case MNK_PADDING: return os << "MNK_PADDING";
|
||||
case O_PADDING: return os << "O_PADDING";
|
||||
case MO_PADDING: return os << "MO_PADDING";
|
||||
case NO_PADDING: return os << "NO_PADDING";
|
||||
case KO_PADDING: return os << "KO_PADDING";
|
||||
case MNO_PADDING: return os << "MNO_PADDING";
|
||||
case MKO_PADDING: return os << "MKO_PADDING";
|
||||
case NKO_PADDING: return os << "NKO_PADDING";
|
||||
case MNKO_PADDING: return os << "MNKO_PADDING";
|
||||
default: return os << "Unknown";
|
||||
case DEFAULT: return "DEFAULT";
|
||||
case M_PADDING: return "M_PADDING";
|
||||
case N_PADDING: return "N_PADDING";
|
||||
case K_PADDING: return "K_PADDING";
|
||||
case MN_PADDING: return "MN_PADDING";
|
||||
case MK_PADDING: return "MK_PADDING";
|
||||
case NK_PADDING: return "NK_PADDING";
|
||||
case MNK_PADDING: return "MNK_PADDING";
|
||||
case O_PADDING: return "O_PADDING";
|
||||
case MO_PADDING: return "MO_PADDING";
|
||||
case NO_PADDING: return "NO_PADDING";
|
||||
case KO_PADDING: return "KO_PADDING";
|
||||
case MNO_PADDING: return "MNO_PADDING";
|
||||
case MKO_PADDING: return "MKO_PADDING";
|
||||
case NKO_PADDING: return "NKO_PADDING";
|
||||
case MNKO_PADDING: return "MNKO_PADDING";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched)
|
||||
inline std::string_view toString(PipelineScheduler sched)
|
||||
{
|
||||
using enum PipelineScheduler;
|
||||
switch(sched)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case INTRAWAVE: return os << "INTRAWAVE";
|
||||
case INTERWAVE: return os << "INTERWAVE";
|
||||
default: return os << "Unknown";
|
||||
case DEFAULT: return "DEFAULT";
|
||||
case INTRAWAVE: return "INTRAWAVE";
|
||||
case INTERWAVE: return "INTERWAVE";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, TensorLayout layout)
|
||||
inline std::string_view toString(TensorLayout layout)
|
||||
{
|
||||
using enum TensorLayout;
|
||||
switch(layout)
|
||||
{
|
||||
case GNCW: return os << "GNCW";
|
||||
case GNWC: return os << "GNWC";
|
||||
case NWGC: return os << "NWGC";
|
||||
case NGCW: return os << "NGCW";
|
||||
case G_NW_C_strided: return os << "G_NW_C_strided";
|
||||
case GNCHW: return os << "GNCHW";
|
||||
case GNHWC: return os << "GNHWC";
|
||||
case NHWGC: return os << "NHWGC";
|
||||
case NGCHW: return os << "NGCHW";
|
||||
case G_NHW_C_strided: return os << "G_NHW_C_strided";
|
||||
case GNCDHW: return os << "GNCDHW";
|
||||
case GNDHWC: return os << "GNDHWC";
|
||||
case NDHWGC: return os << "NDHWGC";
|
||||
case NGCDHW: return os << "NGCDHW";
|
||||
case G_NDHW_C_strided: return os << "G_NDHW_C_strided";
|
||||
case GKXC: return os << "GKXC";
|
||||
case GKCX: return os << "GKCX";
|
||||
case KXGC: return os << "KXGC";
|
||||
case G_K_X_C_strided: return os << "G_K_X_C_strided";
|
||||
case GKYXC: return os << "GKYXC";
|
||||
case GKCYX: return os << "GKCYX";
|
||||
case KYXGC: return os << "KYXGC";
|
||||
case G_K_YX_C_strided: return os << "G_K_YX_C_strided";
|
||||
case GKZYXC: return os << "GKZYXC";
|
||||
case GKCZYX: return os << "GKCZYX";
|
||||
case KZYXGC: return os << "KZYXGC";
|
||||
case G_K_ZYX_C_strided: return os << "G_K_ZYX_C_strided";
|
||||
case GNKW: return os << "GNKW";
|
||||
case GNWK: return os << "GNWK";
|
||||
case NWGK: return os << "NWGK";
|
||||
case NGKW: return os << "NGKW";
|
||||
case G_NW_K_strided: return os << "G_NW_K_strided";
|
||||
case GNKHW: return os << "GNKHW";
|
||||
case GNHWK: return os << "GNHWK";
|
||||
case NHWGK: return os << "NHWGK";
|
||||
case NGKHW: return os << "NGKHW";
|
||||
case G_NHW_K_strided: return os << "G_NHW_K_strided";
|
||||
case GNKDHW: return os << "GNKDHW";
|
||||
case GNDHWK: return os << "GNDHWK";
|
||||
case NDHWGK: return os << "NDHWGK";
|
||||
case NGKDHW: return os << "NGKDHW";
|
||||
case G_NDHW_K_strided: return os << "G_NDHW_K_strided";
|
||||
case GC: return os << "GC";
|
||||
case G_C_strided: return os << "G_C_strided";
|
||||
case G_K_strided: return os << "G_K_strided";
|
||||
case UNDEFINED: return os << "UNDEFINED";
|
||||
default: return os << "Unknown";
|
||||
case GNCW: return "GNCW";
|
||||
case GNWC: return "GNWC";
|
||||
case NWGC: return "NWGC";
|
||||
case NGCW: return "NGCW";
|
||||
case G_NW_C_strided: return "G_NW_C_strided";
|
||||
case GNCHW: return "GNCHW";
|
||||
case GNHWC: return "GNHWC";
|
||||
case NHWGC: return "NHWGC";
|
||||
case NGCHW: return "NGCHW";
|
||||
case G_NHW_C_strided: return "G_NHW_C_strided";
|
||||
case GNCDHW: return "GNCDHW";
|
||||
case GNDHWC: return "GNDHWC";
|
||||
case NDHWGC: return "NDHWGC";
|
||||
case NGCDHW: return "NGCDHW";
|
||||
case G_NDHW_C_strided: return "G_NDHW_C_strided";
|
||||
case GKXC: return "GKXC";
|
||||
case GKCX: return "GKCX";
|
||||
case KXGC: return "KXGC";
|
||||
case G_K_X_C_strided: return "G_K_X_C_strided";
|
||||
case GKYXC: return "GKYXC";
|
||||
case GKCYX: return "GKCYX";
|
||||
case KYXGC: return "KYXGC";
|
||||
case G_K_YX_C_strided: return "G_K_YX_C_strided";
|
||||
case GKZYXC: return "GKZYXC";
|
||||
case GKCZYX: return "GKCZYX";
|
||||
case KZYXGC: return "KZYXGC";
|
||||
case G_K_ZYX_C_strided: return "G_K_ZYX_C_strided";
|
||||
case GNKW: return "GNKW";
|
||||
case GNWK: return "GNWK";
|
||||
case NWGK: return "NWGK";
|
||||
case NGKW: return "NGKW";
|
||||
case G_NW_K_strided: return "G_NW_K_strided";
|
||||
case GNKHW: return "GNKHW";
|
||||
case GNHWK: return "GNHWK";
|
||||
case NHWGK: return "NHWGK";
|
||||
case NGKHW: return "NGKHW";
|
||||
case G_NHW_K_strided: return "G_NHW_K_strided";
|
||||
case GNKDHW: return "GNKDHW";
|
||||
case GNDHWK: return "GNDHWK";
|
||||
case NDHWGK: return "NDHWGK";
|
||||
case NGKDHW: return "NGKDHW";
|
||||
case G_NDHW_K_strided: return "G_NDHW_K_strided";
|
||||
case GC: return "GC";
|
||||
case G_C_strided: return "G_C_strided";
|
||||
case G_K_strided: return "G_K_strided";
|
||||
case UNDEFINED_TENSOR_LAYOUT: return "UNDEFINED_TENSOR_LAYOUT";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
// ostream operator overloads for enum classes
|
||||
inline std::ostream& operator<<(std::ostream& os, DataType dt) { return os << toString(dt); }
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvDirection dir) { return os << toString(dir); }
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op)
|
||||
{
|
||||
return os << toString(op);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver)
|
||||
{
|
||||
return os << toString(ver);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec)
|
||||
{
|
||||
return os << toString(spec);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec)
|
||||
{
|
||||
return os << toString(spec);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec)
|
||||
{
|
||||
return os << toString(spec);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec)
|
||||
{
|
||||
return os << toString(spec);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GemmPadding padding)
|
||||
{
|
||||
return os << toString(padding);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched)
|
||||
{
|
||||
return os << toString(sched);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, TensorLayout layout)
|
||||
{
|
||||
return os << toString(layout);
|
||||
}
|
||||
|
||||
// ostream operator overload for std::variant of convolution specializations
|
||||
inline std::ostream& operator<<(std::ostream& os,
|
||||
const std::variant<ConvFwdSpecialization,
|
||||
|
||||
Reference in New Issue
Block a user