[CK_Builder] removed direction and elementwise_operation from required parameters … (#3192)

Removed direction and elementwise operation from default values required for convolution signature concept. Added constexpr helpers to set default values. Add compile-time tests.
This commit is contained in:
kabrahamAMD
2025-11-18 00:23:48 +01:00
committed by GitHub
parent 22a934a229
commit 92498464f6
3 changed files with 54 additions and 4 deletions

View File

@@ -563,7 +563,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
SPATIAL_DIM,
ConvDirection::FORWARD>());
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
using AlgorithmType = decltype(ALGORITHM);
static_assert(ALGORITHM.block_transfer.lds_transfer_a.is_direct_load ==

View File

@@ -43,14 +43,41 @@ concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == D
template <typename T>
concept ConvLayout = std::same_as<std::remove_cvref_t<T>, GroupConvLayout>;
template <typename T>
concept HasElementwiseOp = requires(T t) {
{ t.elementwise_operation };
};
template <typename T>
concept HasConvolutionDirection = requires(T t) {
{ t.direction };
};
// Note: it is not required to provide an ElementwiseOp, but if one is provided, check if well
// defined
template <typename T>
concept ElementwiseOpWellDefinedIfProvided = requires(T t) {
requires !HasElementwiseOp<T> || requires {
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
};
};
// Note: it is not required to provide a convolution, but if one is provided, check if well defined
template <typename T>
concept ConvolutionDirectionWellDefinedIfProvided = requires(T t) {
requires !HasConvolutionDirection<T> || requires {
{ t.direction } -> std::convertible_to<ConvDirection>;
};
};
// Concept for a type that defines a convolution's operational signature.
template <typename T>
concept ConvSignatureDescriptor = requires(T t) {
{ t.spatial_dim } -> std::convertible_to<unsigned int>;
{ t.direction } -> std::convertible_to<ConvDirection>;
{ t.layout } -> ConvLayout;
{ t.data_type } -> std::convertible_to<DataType>;
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
requires ElementwiseOpWellDefinedIfProvided<T>;
requires ConvolutionDirectionWellDefinedIfProvided<T>;
};
// Concept to validate a convolution signature's values.