diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 5b3fb0e6a8..d839518285 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -563,7 +563,7 @@ struct ConvFactory SPATIAL_DIM, ConvDirection::FORWARD>()); using Types = factory_internal::ConvTensorTypes; - using Ops = factory_internal::ElementwiseOps; + using Ops = factory_internal::ElementwiseOps()>; using AlgorithmType = decltype(ALGORITHM); static_assert(ALGORITHM.block_transfer.lds_transfer_a.is_direct_load == diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp index 983273b439..404bef9082 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp @@ -43,14 +43,41 @@ concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == D template concept ConvLayout = std::same_as, GroupConvLayout>; +template +concept HasElementwiseOp = requires(T t) { + { t.elementwise_operation }; +}; + +template +concept HasConvolutionDirection = requires(T t) { + { t.direction }; +}; + +// Note: it is not required to provide an ElementwiseOp, but if one is provided, check if well +// defined +template +concept ElementwiseOpWellDefinedIfProvided = requires(T t) { + requires !HasElementwiseOp || requires { + { t.elementwise_operation } -> std::convertible_to; + }; +}; + +// Note: it is not required to provide a convolution, but if one is provided, check if well defined +template +concept ConvolutionDirectionWellDefinedIfProvided = requires(T t) { + requires !HasConvolutionDirection || requires { + { t.direction } -> std::convertible_to; + }; +}; + // Concept for a type that defines a convolution's operational signature. template concept ConvSignatureDescriptor = requires(T t) { { t.spatial_dim } -> std::convertible_to; - { t.direction } -> std::convertible_to; { t.layout } -> ConvLayout; { t.data_type } -> std::convertible_to; - { t.elementwise_operation } -> std::convertible_to; + requires ElementwiseOpWellDefinedIfProvided; + requires ConvolutionDirectionWellDefinedIfProvided; }; // Concept to validate a convolution signature's values. diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 733359d491..b53cdc39c7 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -19,6 +19,17 @@ namespace ckt = ck_tile::test; // Defines the signature of the convolution operation to be tested. // This includes dimensionality, direction, data layout, and data type. struct ConvSignature +{ + int spatial_dim = 2; + ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; + ckb::DataType data_type = ckb::DataType::FP16; + ckb::GroupConvDeviceOp device_operation = + ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; +}; +static_assert(ckb::ConvSignatureDescriptor); + +// Compile time tests for concepts +struct ConvSignatureWithOptionalParams { int spatial_dim = 2; ckb::ConvDirection direction = ckb::ConvDirection::FORWARD; @@ -26,7 +37,19 @@ struct ConvSignature ckb::DataType data_type = ckb::DataType::FP16; ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH; }; -static_assert(ckb::ConvSignatureDescriptor); +static_assert(ckb::ConvSignatureDescriptor); + +struct ConvSignatureWithInvalidOptionalParams +{ + int spatial_dim = 2; + ckb::ConvDirection direction = ckb::ConvDirection::FORWARD; + ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; + ckb::DataType data_type = ckb::DataType::FP16; + int elementwise_operation = 7; // this should fail + ckb::GroupConvDeviceOp device_operation = + ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; +}; +static_assert(!ckb::ConvSignatureDescriptor); struct DefaultAlgorithm {