Fix concepts for convolution signature.

We split the concepts to a check on the signature type (ConvSignatureDescriptor) as well as a check on the value (ValidConvSignature).
This commit is contained in:
John Shumway
2025-09-08 19:56:33 +00:00
parent d85ba0965b
commit 19c573fb70
5 changed files with 18 additions and 16 deletions

View File

@@ -22,8 +22,10 @@ namespace ck_tile::builder {
* @tparam ALGORITHM The specific convolution algorithm to be used for the implementation.
* @tparam VERSION The version of the builder implementation.
*/
template <ConvSignature auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM, StringLiteral VERSION>
requires SupportedVersion<VERSION>
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires SupportedVersion<VERSION> && ValidConvSignature<SIGNATURE>
struct ConvBuilder
{
static constexpr auto kVersion = VERSION;

View File

@@ -236,7 +236,7 @@ constexpr ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
}
// Factory builds an instance of a grouped convolution kernel.
template <ConvSignature auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM, auto Version>
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM, auto Version>
requires SupportedVersion<Version>
struct GroupedConvForwardXldCShuffleFactoryV3
{

View File

@@ -45,18 +45,18 @@ enum class ElementwiseOperation
// Operational signature of a convolution.
template <typename T>
concept ConvSignature = requires {
// Dimensionality of the convolution (e.g., 1, 2, or 3).
requires ConvSpatialDim<T::spatial_dim>;
concept ConvSignatureDescriptor = requires(T t) {
{ t.spatial_dim } -> std::convertible_to<int>;
{ t.direction } -> std::convertible_to<ConvDirection>;
{ t.layout } -> std::convertible_to<GroupConvLayout>;
{ t.data_type } -> std::convertible_to<DataType>;
};
// Direction of the convolition (fwd, bwd, or weights).
{ T::direction } -> std::same_as<const ConvDirection&>;
// Memory layout of the tensors.
{ T::layout } -> std::same_as<const GroupConvLayout&>;
// Tensor datatype for input and output.
requires ConvDataType<T::data_type>;
// Valid values for a convolution signature.
template <auto Sig>
concept ValidConvSignature = requires {
requires ConvSpatialDim<Sig.spatial_dim>;
requires ConvDataType<Sig.data_type>;
};
} // namespace ck_tile::builder

View File

@@ -13,7 +13,7 @@ struct FwdConvSignature
static constexpr auto layout = ckb::GroupConvLayout::NHWGC_GKYXC_NHWGK;
static constexpr auto data_type = ckb::DataType::FP16;
};
static_assert(ckb::ConvSignature<FwdConvSignature>);
static_assert(ckb::ConvSignatureDescriptor<FwdConvSignature>);
struct DefaultFwdConvAlgorithm
{

View File

@@ -22,7 +22,7 @@ struct FwdConvSignature
static constexpr auto layout = ckb::GroupConvLayout::NHWGC_GKYXC_NHWGK;
static constexpr auto data_type = ckb::DataType::FP16;
};
static_assert(ckb::ConvSignature<FwdConvSignature>);
static_assert(ckb::ConvSignatureDescriptor<FwdConvSignature>);
constexpr char API_VERSION[] = "0.1.0";
static_assert(ckb::SupportedVersion<API_VERSION>);