[CK_BUILDER] CK Tile header installation for builder, algorithm concept improvements (#3419)

* Added install of CK_Tile headers when using CK_EXPERIMENTAL_BUILDER. MIOpen needs this since the builder uses features from CK Tile and the CK Tile install is excluded when doing a narrow build for MIOpen
* Changed algorithm concept type checks to be concepts instead of constexpr bool functions. This improves compiler error messages when using these concepts in static_asserts

---------

Co-authored-by: Daryl Hawkins <DarylHawkins@amd.com>
This commit is contained in:
DarylHawkinsAMD
2025-12-15 16:24:36 -07:00
committed by GitHub
parent 2544e394cf
commit 1e6bbed1fb
2 changed files with 34 additions and 48 deletions

View File

@@ -84,63 +84,46 @@ namespace ck_tile::builder::factory {
// CK Tile kernel
template <typename T>
consteval bool IsTileAlgorithm()
{
return ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> && SpecifiesTileTransfer<T> &&
SpecifiesTileConvSpecialization<T> && SpecifiesTileBlockGemm<T> &&
SpecifiesTileOptimizations<T>;
}
concept IsTileAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> &&
SpecifiesTileTransfer<T> && SpecifiesTileConvSpecialization<T> &&
SpecifiesTileBlockGemm<T> && SpecifiesTileOptimizations<T>;
// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline)
template <typename T>
consteval bool IsXdlV3Algorithm()
{
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesBlockGemm<T>;
}
concept IsXdlV3Algorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
// Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply)
template <typename T>
consteval bool IsXdlAlgorithm()
{
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesNumPrefetchStages<T> && SpecifiesNumGroupsToMerge<T> &&
SpecifiesLoopScheduler<T>;
}
concept IsXdlAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> &&
SpecifiesNumGroupsToMerge<T> && SpecifiesLoopScheduler<T>;
// WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions)
template <typename T>
consteval bool IsWmmaAlgorithm()
{
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
}
concept IsWmmaAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
// Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts
template <typename T>
consteval bool IsDlAlgorithm()
{
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
}
concept IsDlAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
// XDL-based kernel with large tensor support
template <typename T>
consteval bool IsLargeTensorAlgorithm()
{
return IsXdlAlgorithm<decltype(T::base_algorithm)>() && SpecifiesLargeTensorSupport<T>;
}
concept IsLargeTensorAlgorithm =
IsXdlAlgorithm<decltype(T::base_algorithm)> && SpecifiesLargeTensorSupport<T>;
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
@@ -150,29 +133,29 @@ constexpr auto make_conv_instance()
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
// CK Tile supports common factory for each direction
if constexpr(IsTileAlgorithm<AlgoType>())
if constexpr(IsTileAlgorithm<AlgoType>)
{
return typename ConvTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(ConvDirectionIsForward<SIGNATURE>)
{
if constexpr(IsXdlV3Algorithm<AlgoType>())
if constexpr(IsXdlV3Algorithm<AlgoType>)
{
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(IsXdlAlgorithm<AlgoType>())
else if constexpr(IsXdlAlgorithm<AlgoType>)
{
return typename ConvFwdXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(IsWmmaAlgorithm<AlgoType>())
else if constexpr(IsWmmaAlgorithm<AlgoType>)
{
return typename ConvFwdWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(IsDlAlgorithm<AlgoType>())
else if constexpr(IsDlAlgorithm<AlgoType>)
{
return typename ConvFwdDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(IsLargeTensorAlgorithm<AlgoType>())
else if constexpr(IsLargeTensorAlgorithm<AlgoType>)
{
return typename ConvFwdLargeTensorFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}