From 19c573fb70380b390db03a8f2ca1b26647da381a Mon Sep 17 00:00:00 2001 From: John Shumway Date: Mon, 8 Sep 2025 19:56:33 +0000 Subject: [PATCH] Fix concepts for convolution signature. We split the concepts to a check on the signature type (ConvSignatureDescriptor) as well as a check on the value (ValidConvSignature). --- .../include/ck_tile/builder/conv_builder.hpp | 6 +++-- .../include/ck_tile/builder/conv_factory.hpp | 2 +- .../ck_tile/builder/conv_signature.hpp | 22 +++++++++---------- .../builder/test/test_conv_builder.cpp | 2 +- .../builder/test/test_conv_instances.cpp | 2 +- 5 files changed, 18 insertions(+), 16 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_builder.hpp b/experimental/builder/include/ck_tile/builder/conv_builder.hpp index c1e1b3c6d2..224b2b34c2 100644 --- a/experimental/builder/include/ck_tile/builder/conv_builder.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_builder.hpp @@ -22,8 +22,10 @@ namespace ck_tile::builder { * @tparam ALGORITHM The specific convolution algorithm to be used for the implementation. * @tparam VERSION The version of the builder implementation. */ -template - requires SupportedVersion +template + requires SupportedVersion && ValidConvSignature struct ConvBuilder { static constexpr auto kVersion = VERSION; diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 3fc6a14fd9..5978f43376 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -236,7 +236,7 @@ constexpr ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion() } // Factory builds an instance of a grouped convolution kernel. -template +template requires SupportedVersion struct GroupedConvForwardXldCShuffleFactoryV3 { diff --git a/experimental/builder/include/ck_tile/builder/conv_signature.hpp b/experimental/builder/include/ck_tile/builder/conv_signature.hpp index 7a419a2559..940757cf97 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature.hpp @@ -45,18 +45,18 @@ enum class ElementwiseOperation // Operational signature of a convolution. template -concept ConvSignature = requires { - // Dimensionality of the convolution (e.g., 1, 2, or 3). - requires ConvSpatialDim; +concept ConvSignatureDescriptor = requires(T t) { + { t.spatial_dim } -> std::convertible_to; + { t.direction } -> std::convertible_to; + { t.layout } -> std::convertible_to; + { t.data_type } -> std::convertible_to; +}; - // Direction of the convolition (fwd, bwd, or weights). - { T::direction } -> std::same_as; - - // Memory layout of the tensors. - { T::layout } -> std::same_as; - - // Tensor datatype for input and output. - requires ConvDataType; +// Valid values for a convolution signature. +template +concept ValidConvSignature = requires { + requires ConvSpatialDim; + requires ConvDataType; }; } // namespace ck_tile::builder diff --git a/experimental/builder/test/test_conv_builder.cpp b/experimental/builder/test/test_conv_builder.cpp index 2012aa1170..2e1463f281 100644 --- a/experimental/builder/test/test_conv_builder.cpp +++ b/experimental/builder/test/test_conv_builder.cpp @@ -13,7 +13,7 @@ struct FwdConvSignature static constexpr auto layout = ckb::GroupConvLayout::NHWGC_GKYXC_NHWGK; static constexpr auto data_type = ckb::DataType::FP16; }; -static_assert(ckb::ConvSignature); +static_assert(ckb::ConvSignatureDescriptor); struct DefaultFwdConvAlgorithm { diff --git a/experimental/builder/test/test_conv_instances.cpp b/experimental/builder/test/test_conv_instances.cpp index 0cc194b213..5bbf6e3300 100644 --- a/experimental/builder/test/test_conv_instances.cpp +++ b/experimental/builder/test/test_conv_instances.cpp @@ -22,7 +22,7 @@ struct FwdConvSignature static constexpr auto layout = ckb::GroupConvLayout::NHWGC_GKYXC_NHWGK; static constexpr auto data_type = ckb::DataType::FP16; }; -static_assert(ckb::ConvSignature); +static_assert(ckb::ConvSignatureDescriptor); constexpr char API_VERSION[] = "0.1.0"; static_assert(ckb::SupportedVersion);