mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Fix concepts for convolution signature.
We split the concepts to a check on the signature type (ConvSignatureDescriptor) as well as a check on the value (ValidConvSignature).
This commit is contained in:
@@ -22,8 +22,10 @@ namespace ck_tile::builder {
|
||||
* @tparam ALGORITHM The specific convolution algorithm to be used for the implementation.
|
||||
* @tparam VERSION The version of the builder implementation.
|
||||
*/
|
||||
template <ConvSignature auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM, StringLiteral VERSION>
|
||||
requires SupportedVersion<VERSION>
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires SupportedVersion<VERSION> && ValidConvSignature<SIGNATURE>
|
||||
struct ConvBuilder
|
||||
{
|
||||
static constexpr auto kVersion = VERSION;
|
||||
|
||||
@@ -236,7 +236,7 @@ constexpr ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
}
|
||||
|
||||
// Factory builds an instance of a grouped convolution kernel.
|
||||
template <ConvSignature auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM, auto Version>
|
||||
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM, auto Version>
|
||||
requires SupportedVersion<Version>
|
||||
struct GroupedConvForwardXldCShuffleFactoryV3
|
||||
{
|
||||
|
||||
@@ -45,18 +45,18 @@ enum class ElementwiseOperation
|
||||
|
||||
// Operational signature of a convolution.
|
||||
template <typename T>
|
||||
concept ConvSignature = requires {
|
||||
// Dimensionality of the convolution (e.g., 1, 2, or 3).
|
||||
requires ConvSpatialDim<T::spatial_dim>;
|
||||
concept ConvSignatureDescriptor = requires(T t) {
|
||||
{ t.spatial_dim } -> std::convertible_to<int>;
|
||||
{ t.direction } -> std::convertible_to<ConvDirection>;
|
||||
{ t.layout } -> std::convertible_to<GroupConvLayout>;
|
||||
{ t.data_type } -> std::convertible_to<DataType>;
|
||||
};
|
||||
|
||||
// Direction of the convolition (fwd, bwd, or weights).
|
||||
{ T::direction } -> std::same_as<const ConvDirection&>;
|
||||
|
||||
// Memory layout of the tensors.
|
||||
{ T::layout } -> std::same_as<const GroupConvLayout&>;
|
||||
|
||||
// Tensor datatype for input and output.
|
||||
requires ConvDataType<T::data_type>;
|
||||
// Valid values for a convolution signature.
|
||||
template <auto Sig>
|
||||
concept ValidConvSignature = requires {
|
||||
requires ConvSpatialDim<Sig.spatial_dim>;
|
||||
requires ConvDataType<Sig.data_type>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -13,7 +13,7 @@ struct FwdConvSignature
|
||||
static constexpr auto layout = ckb::GroupConvLayout::NHWGC_GKYXC_NHWGK;
|
||||
static constexpr auto data_type = ckb::DataType::FP16;
|
||||
};
|
||||
static_assert(ckb::ConvSignature<FwdConvSignature>);
|
||||
static_assert(ckb::ConvSignatureDescriptor<FwdConvSignature>);
|
||||
|
||||
struct DefaultFwdConvAlgorithm
|
||||
{
|
||||
|
||||
@@ -22,7 +22,7 @@ struct FwdConvSignature
|
||||
static constexpr auto layout = ckb::GroupConvLayout::NHWGC_GKYXC_NHWGK;
|
||||
static constexpr auto data_type = ckb::DataType::FP16;
|
||||
};
|
||||
static_assert(ckb::ConvSignature<FwdConvSignature>);
|
||||
static_assert(ckb::ConvSignatureDescriptor<FwdConvSignature>);
|
||||
|
||||
constexpr char API_VERSION[] = "0.1.0";
|
||||
static_assert(ckb::SupportedVersion<API_VERSION>);
|
||||
|
||||
Reference in New Issue
Block a user