mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
[CK_BUILDER] CK Tile header installation for builder, algorithm concept improvements (#3419)
* Added installation of the CK Tile headers when using CK_EXPERIMENTAL_BUILDER. MIOpen needs this because the builder uses features from CK Tile, and the CK Tile install is excluded when doing a narrow build for MIOpen. * Changed the algorithm concept type checks to be concepts instead of consteval bool functions. This improves compiler error messages when these concepts are used in static_asserts. --------- Co-authored-by: Daryl Hawkins <DarylHawkins@amd.com>
This commit is contained in:
@@ -84,63 +84,46 @@ namespace ck_tile::builder::factory {
|
||||
|
||||
// CK Tile kernel
|
||||
template <typename T>
|
||||
consteval bool IsTileAlgorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> && SpecifiesTileTransfer<T> &&
|
||||
SpecifiesTileConvSpecialization<T> && SpecifiesTileBlockGemm<T> &&
|
||||
SpecifiesTileOptimizations<T>;
|
||||
}
|
||||
concept IsTileAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> &&
|
||||
SpecifiesTileTransfer<T> && SpecifiesTileConvSpecialization<T> &&
|
||||
SpecifiesTileBlockGemm<T> && SpecifiesTileOptimizations<T>;
|
||||
|
||||
// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline)
|
||||
template <typename T>
|
||||
consteval bool IsXdlV3Algorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
|
||||
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesBlockGemm<T>;
|
||||
}
|
||||
concept IsXdlV3Algorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
|
||||
|
||||
// Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply)
|
||||
template <typename T>
|
||||
consteval bool IsXdlAlgorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
|
||||
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesNumPrefetchStages<T> && SpecifiesNumGroupsToMerge<T> &&
|
||||
SpecifiesLoopScheduler<T>;
|
||||
}
|
||||
concept IsXdlAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> &&
|
||||
SpecifiesNumGroupsToMerge<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
// WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions)
|
||||
template <typename T>
|
||||
consteval bool IsWmmaAlgorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
|
||||
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
|
||||
}
|
||||
concept IsWmmaAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
// Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts
|
||||
template <typename T>
|
||||
consteval bool IsDlAlgorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> &&
|
||||
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
|
||||
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
|
||||
}
|
||||
concept IsDlAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
|
||||
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
|
||||
|
||||
// XDL-based kernel with large tensor support
|
||||
template <typename T>
|
||||
consteval bool IsLargeTensorAlgorithm()
|
||||
{
|
||||
return IsXdlAlgorithm<decltype(T::base_algorithm)>() && SpecifiesLargeTensorSupport<T>;
|
||||
}
|
||||
concept IsLargeTensorAlgorithm =
|
||||
IsXdlAlgorithm<decltype(T::base_algorithm)> && SpecifiesLargeTensorSupport<T>;
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
@@ -150,29 +133,29 @@ constexpr auto make_conv_instance()
|
||||
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
|
||||
|
||||
// CK Tile supports common factory for each direction
|
||||
if constexpr(IsTileAlgorithm<AlgoType>())
|
||||
if constexpr(IsTileAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(ConvDirectionIsForward<SIGNATURE>)
|
||||
{
|
||||
if constexpr(IsXdlV3Algorithm<AlgoType>())
|
||||
if constexpr(IsXdlV3Algorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsXdlAlgorithm<AlgoType>())
|
||||
else if constexpr(IsXdlAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsWmmaAlgorithm<AlgoType>())
|
||||
else if constexpr(IsWmmaAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsDlAlgorithm<AlgoType>())
|
||||
else if constexpr(IsDlAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsLargeTensorAlgorithm<AlgoType>())
|
||||
else if constexpr(IsLargeTensorAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdLargeTensorFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user