mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Update some more concept names.
This commit is contained in:
@@ -148,12 +148,12 @@ enum class BlockGemmPipelineVersion
|
||||
|
||||
// Concept to check if struct specifies block_gemm_pipeline_version.
|
||||
template <typename T>
|
||||
concept ProvidesBlockGemmPipelineVersion = requires {
|
||||
concept SpecifiesGemmPipelineVersion = requires {
|
||||
{ T::pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
|
||||
};
|
||||
|
||||
// No requirements yet for a ConvAlogorithm concept.
|
||||
template <typename T>
|
||||
concept ConvAlgorithm = std::is_class_v<T>;
|
||||
concept ConvAlgorithmDescriptor = std::is_class_v<T>;
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -22,7 +22,7 @@ namespace ck_tile::builder {
|
||||
* @tparam ALGORITHM The specific convolution algorithm to be used for the implementation.
|
||||
* @tparam VERSION The version of the builder implementation.
|
||||
*/
|
||||
template <ConvSignature auto SIGNATURE, ConvAlgorithm auto ALGORITHM, StringLiteral VERSION>
|
||||
template <ConvSignature auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM, StringLiteral VERSION>
|
||||
requires SupportedVersion<VERSION>
|
||||
struct ConvBuilder
|
||||
{
|
||||
|
||||
@@ -73,7 +73,7 @@ struct ConvBlock
|
||||
MNK<int> per_block;
|
||||
};
|
||||
|
||||
template <ConvAlgorithm auto ALGORITHM>
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr ConvBlock SetThreadBlockInfo()
|
||||
{
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
@@ -102,7 +102,7 @@ struct ConvTuning
|
||||
int n_xdl_per_wave = 0;
|
||||
};
|
||||
|
||||
template <ConvAlgorithm auto ALGORITHM>
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr ConvTuning SetConvTuningInfo()
|
||||
{
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
@@ -150,7 +150,7 @@ struct CBlockTransfer
|
||||
int scaler_per_vector = 8;
|
||||
};
|
||||
|
||||
template <ConvAlgorithm auto ALGORITHM>
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr BlockTransfer SetABlockTransfer()
|
||||
{
|
||||
BlockTransfer block_transfer{
|
||||
@@ -172,7 +172,7 @@ constexpr BlockTransfer SetABlockTransfer()
|
||||
return block_transfer;
|
||||
}
|
||||
|
||||
template <ConvAlgorithm auto ALGORITHM>
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr BlockTransfer SetBBlockTransfer()
|
||||
{
|
||||
BlockTransfer block_transfer{
|
||||
@@ -194,7 +194,7 @@ constexpr BlockTransfer SetBBlockTransfer()
|
||||
return block_transfer;
|
||||
}
|
||||
|
||||
template <ConvAlgorithm auto ALGORITHM>
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr CBlockTransfer SetCBlockTransfer()
|
||||
{
|
||||
CBlockTransfer block_transfer{
|
||||
@@ -217,11 +217,11 @@ constexpr CBlockTransfer SetCBlockTransfer()
|
||||
return block_transfer;
|
||||
}
|
||||
|
||||
template <ConvAlgorithm auto ALGORITHM>
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
{
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
if constexpr(ProvidesBlockGemmPipelineVersion<AlgorithmType>)
|
||||
if constexpr(SpecifiesGemmPipelineVersion<AlgorithmType>)
|
||||
{
|
||||
switch(ALGORITHM.pipeline_version)
|
||||
{
|
||||
@@ -236,7 +236,7 @@ constexpr ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
}
|
||||
|
||||
// Factory builds an instance of a grouped convolution kernel.
|
||||
template <ConvSignature auto SIGNATURE, ConvAlgorithm auto ALGORITHM, auto Version>
|
||||
template <ConvSignature auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM, auto Version>
|
||||
requires SupportedVersion<Version>
|
||||
struct GroupedConvForwardXldCShuffleFactoryV3
|
||||
{
|
||||
|
||||
@@ -18,7 +18,7 @@ static_assert(ckb::ConvSignature<FwdConvSignature>);
|
||||
struct DefaultFwdConvAlgorithm
|
||||
{
|
||||
};
|
||||
static_assert(ckb::ConvAlgorithm<DefaultFwdConvAlgorithm>);
|
||||
static_assert(ckb::ConvAlgorithmDescriptor<DefaultFwdConvAlgorithm>);
|
||||
|
||||
constexpr char API_VERSION[] = "0.1.0";
|
||||
static_assert(ckb::SupportedVersion<API_VERSION>);
|
||||
|
||||
@@ -42,13 +42,13 @@ struct FwdConvAlgorithm
|
||||
} block_transfer;
|
||||
ckb::BlockGemmPipelineVersion pipeline_version;
|
||||
};
|
||||
static_assert(ckb::ConvAlgorithm<FwdConvAlgorithm>);
|
||||
static_assert(ckb::ConvAlgorithmDescriptor<FwdConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesThreadBlock<FwdConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesConvTuning<FwdConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesBlockATransfer<FwdConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesBlockBTransfer<FwdConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesBlockCTransfer<FwdConvAlgorithm>);
|
||||
static_assert(ckb::ProvidesBlockGemmPipelineVersion<FwdConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesGemmPipelineVersion<FwdConvAlgorithm>);
|
||||
|
||||
// A container for a single test case, bundling a descriptive name, the
|
||||
// algorithm configuration, and the expected generated kernel type string.
|
||||
|
||||
Reference in New Issue
Block a user