Update some more concept names.

This commit is contained in:
John Shumway
2025-09-07 19:35:54 +00:00
parent 32b4b27031
commit da140c434b
5 changed files with 14 additions and 14 deletions

View File

@@ -148,12 +148,12 @@ enum class BlockGemmPipelineVersion
// Concept to check if struct specifies block_gemm_pipeline_version.
template <typename T>
concept ProvidesBlockGemmPipelineVersion = requires {
concept SpecifiesGemmPipelineVersion = requires {
{ T::pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
};
// No requirements yet for a ConvAlogorithm concept.
template <typename T>
concept ConvAlgorithm = std::is_class_v<T>;
concept ConvAlgorithmDescriptor = std::is_class_v<T>;
} // namespace ck_tile::builder

View File

@@ -22,7 +22,7 @@ namespace ck_tile::builder {
* @tparam ALGORITHM The specific convolution algorithm to be used for the implementation.
* @tparam VERSION The version of the builder implementation.
*/
template <ConvSignature auto SIGNATURE, ConvAlgorithm auto ALGORITHM, StringLiteral VERSION>
template <ConvSignature auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM, StringLiteral VERSION>
requires SupportedVersion<VERSION>
struct ConvBuilder
{

View File

@@ -73,7 +73,7 @@ struct ConvBlock
MNK<int> per_block;
};
template <ConvAlgorithm auto ALGORITHM>
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr ConvBlock SetThreadBlockInfo()
{
using AlgorithmType = decltype(ALGORITHM);
@@ -102,7 +102,7 @@ struct ConvTuning
int n_xdl_per_wave = 0;
};
template <ConvAlgorithm auto ALGORITHM>
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr ConvTuning SetConvTuningInfo()
{
using AlgorithmType = decltype(ALGORITHM);
@@ -150,7 +150,7 @@ struct CBlockTransfer
int scaler_per_vector = 8;
};
template <ConvAlgorithm auto ALGORITHM>
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr BlockTransfer SetABlockTransfer()
{
BlockTransfer block_transfer{
@@ -172,7 +172,7 @@ constexpr BlockTransfer SetABlockTransfer()
return block_transfer;
}
template <ConvAlgorithm auto ALGORITHM>
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr BlockTransfer SetBBlockTransfer()
{
BlockTransfer block_transfer{
@@ -194,7 +194,7 @@ constexpr BlockTransfer SetBBlockTransfer()
return block_transfer;
}
template <ConvAlgorithm auto ALGORITHM>
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr CBlockTransfer SetCBlockTransfer()
{
CBlockTransfer block_transfer{
@@ -217,11 +217,11 @@ constexpr CBlockTransfer SetCBlockTransfer()
return block_transfer;
}
template <ConvAlgorithm auto ALGORITHM>
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
{
using AlgorithmType = decltype(ALGORITHM);
if constexpr(ProvidesBlockGemmPipelineVersion<AlgorithmType>)
if constexpr(SpecifiesGemmPipelineVersion<AlgorithmType>)
{
switch(ALGORITHM.pipeline_version)
{
@@ -236,7 +236,7 @@ constexpr ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
}
// Factory builds an instance of a grouped convolution kernel.
template <ConvSignature auto SIGNATURE, ConvAlgorithm auto ALGORITHM, auto Version>
template <ConvSignature auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM, auto Version>
requires SupportedVersion<Version>
struct GroupedConvForwardXldCShuffleFactoryV3
{

View File

@@ -18,7 +18,7 @@ static_assert(ckb::ConvSignature<FwdConvSignature>);
struct DefaultFwdConvAlgorithm
{
};
static_assert(ckb::ConvAlgorithm<DefaultFwdConvAlgorithm>);
static_assert(ckb::ConvAlgorithmDescriptor<DefaultFwdConvAlgorithm>);
constexpr char API_VERSION[] = "0.1.0";
static_assert(ckb::SupportedVersion<API_VERSION>);

View File

@@ -42,13 +42,13 @@ struct FwdConvAlgorithm
} block_transfer;
ckb::BlockGemmPipelineVersion pipeline_version;
};
static_assert(ckb::ConvAlgorithm<FwdConvAlgorithm>);
static_assert(ckb::ConvAlgorithmDescriptor<FwdConvAlgorithm>);
static_assert(ckb::SpecifiesThreadBlock<FwdConvAlgorithm>);
static_assert(ckb::SpecifiesConvTuning<FwdConvAlgorithm>);
static_assert(ckb::SpecifiesBlockATransfer<FwdConvAlgorithm>);
static_assert(ckb::SpecifiesBlockBTransfer<FwdConvAlgorithm>);
static_assert(ckb::SpecifiesBlockCTransfer<FwdConvAlgorithm>);
static_assert(ckb::ProvidesBlockGemmPipelineVersion<FwdConvAlgorithm>);
static_assert(ckb::SpecifiesGemmPipelineVersion<FwdConvAlgorithm>);
// A container for a single test case, bundling a descriptive name, the
// algorithm configuration, and the expected generated kernel type string.