Merge commit '4ce7d4c511c7e98a9ac01580ed1e9112e59061a0' into develop

This commit is contained in:
assistant-librarian[bot]
2025-12-23 10:13:44 +00:00
parent b8269a8c17
commit 3e31171d74
3 changed files with 300 additions and 4 deletions

View File

@@ -80,12 +80,24 @@ concept ConvOutputLayout3D =
(L == TensorLayout::GNKDHW) || (L == TensorLayout::GNDHWK) || (L == TensorLayout::NDHWGK) ||
(L == TensorLayout::NGKDHW) || (L == TensorLayout::G_NDHW_K_strided);
template <typename T>
concept HasDataType = requires(T t) {
{ t.data_type };
};
// Note: for signature and TensorConfigDescriptor,
// it is not required to provide a default data type, but if one is provided, check if well defined
template <typename T>
concept DataTypeWellDefinedIfProvided = requires(T t) {
requires !HasDataType<T> || requires {
{ t.data_type } -> std::convertible_to<DataType>;
};
};
template <typename T>
concept TensorConfigDescriptor = requires(T t) {
{ t.layout } -> std::convertible_to<TensorLayout>;
// Only require that data type is defined. It might be set to undefined value, in which case the
// signature's data type is used.
{ t.data_type } -> std::convertible_to<DataType>;
requires DataTypeWellDefinedIfProvided<T>;
};
template <typename T>
@@ -164,11 +176,11 @@ concept HasElementwiseOpWithAuxiliaryOperands = requires(T t) {
template <typename T>
concept ConvSignatureDescriptor = requires(T t) {
{ t.spatial_dim } -> std::convertible_to<unsigned int>;
{ t.data_type } -> std::convertible_to<DataType>;
{ t.input } -> ConvTensorDescriptor;
{ t.weight } -> ConvTensorDescriptor;
{ t.output } -> ConvTensorDescriptor;
requires ConvolutionDirectionWellDefinedIfProvided<T>;
requires DataTypeWellDefinedIfProvided<T>;
};
// Concept to validate a convolution signature's values.

View File

@@ -0,0 +1,192 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include <type_traits>
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder {
/**********************************************
* constexpr helper functions for optional parameters
**********************************************/
template <auto Sig>
concept ProvidesElementwiseOperation = requires { Sig.elementwise_operation; };
template <auto Sig>
concept ProvidesDataType = requires { Sig.data_type; };
template <auto ConvTensor>
concept ConvTensorHasOp = requires { ConvTensor.operation; };
template <auto Sig>
concept ProvidesConvolutionDirection = requires { Sig.direction; };
// returns elementwise operation for input tensor
// will defalut to signature's generic type if provided
// otherwise, default to PASS_THROUGH
template <auto Sig>
requires ValidConvSignature<Sig>
constexpr auto getInputElementwiseOperation()
{
if constexpr(ConvTensorHasOp<Sig.input>)
{
return Sig.input.operation.elementwise_operation;
}
else if constexpr(ProvidesElementwiseOperation<Sig>)
{
return Sig.elementwise_operation;
}
else
{
return ElementwiseOperation::PASS_THROUGH;
}
}
// returns elementwise operation for weight tensor
// will defalut to signature's generic type if provided
// otherwise, default to PASS_THROUGH
template <auto Sig>
requires ValidConvSignature<Sig>
constexpr auto getWeightElementwiseOperation()
{
if constexpr(ConvTensorHasOp<Sig.weight>)
{
return Sig.weight.operation.elementwise_operation;
}
else if constexpr(ProvidesElementwiseOperation<Sig>)
{
return Sig.elementwise_operation;
}
else
{
return ElementwiseOperation::PASS_THROUGH;
}
}
// returns elementwise operation for output tensor
// will defalut to signature's generic type if provided
// otherwise, default to PASS_THROUGH
template <auto Sig>
requires ValidConvSignature<Sig>
constexpr auto getOutputElementwiseOperation()
{
if constexpr(ConvTensorHasOp<Sig.output>)
{
return Sig.output.operation.elementwise_operation;
}
else if constexpr(ProvidesElementwiseOperation<Sig>)
{
return Sig.elementwise_operation;
}
else
{
return ElementwiseOperation::PASS_THROUGH;
}
}
// returns convolution direction for signature. Will default to FORWARD if not provided by signature
template <auto Sig>
requires ValidConvSignature<Sig>
constexpr auto getConvDirection()
{
if constexpr(ProvidesConvolutionDirection<Sig>)
{
return Sig.direction;
}
else
{
return ConvDirection::FORWARD;
}
}
// generic helper that returns data_type if provided and UNDEFINED otherwise
// can be used on both signature and TensorConfigDescriptor objects
template <auto TensorConfigOrSig>
constexpr auto getDataType()
{
if constexpr(ProvidesDataType<TensorConfigOrSig>)
{
return TensorConfigOrSig.data_type;
}
else
{
return DataType::UNDEFINED_DATA_TYPE;
}
}
// return data type of input tensor
template <auto Sig>
requires ValidConvSignature<Sig>
consteval auto getInputDataType()
{
constexpr auto tensorDataType = getDataType<Sig.input.config>();
constexpr auto universalDataType = getDataType<Sig>();
if constexpr(tensorDataType != DataType::UNDEFINED_DATA_TYPE)
{
return tensorDataType;
}
else
{
return universalDataType;
}
}
template <auto Sig>
requires ValidConvSignature<Sig>
consteval auto getWeightDataType()
{
constexpr auto tensorDataType = getDataType<Sig.weight.config>();
constexpr auto universalDataType = getDataType<Sig>();
if constexpr(tensorDataType != DataType::UNDEFINED_DATA_TYPE)
{
return tensorDataType;
}
else
{
return universalDataType;
}
}
template <auto Sig>
requires ValidConvSignature<Sig>
consteval auto getOutputDataType()
{
constexpr auto tensorDataType = getDataType<Sig.output.config>();
constexpr auto universalDataType = getDataType<Sig>();
if constexpr(tensorDataType != DataType::UNDEFINED_DATA_TYPE)
{
return tensorDataType;
}
else
{
return universalDataType;
}
}
// returns data type if and only if all tensors have the same type.
// Otherwise, return DataType::UNDEFINED_DATA_TYPE
template <auto Sig>
requires ValidConvSignature<Sig>
consteval auto getDataTypeIfCommon()
{
auto inputDataType = getInputDataType<Sig>();
auto weightDataType = getWeightDataType<Sig>();
auto outputDataType = getOutputDataType<Sig>();
if(inputDataType == weightDataType && inputDataType == outputDataType)
{
return inputDataType;
}
else
{
return DataType::UNDEFINED_DATA_TYPE;
}
}
} // namespace ck_tile::builder