diff --git a/CMakeLists.txt b/CMakeLists.txt index acae1f5ece..eaed7d3509 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -766,6 +766,9 @@ if(CK_EXPERIMENTAL_BUILDER) ${PROJECT_SOURCE_DIR}/experimental/builder/include/ck_tile/builder DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck_tile ) + + set(CK_TILE_SRC_FOLDER ${CMAKE_SOURCE_DIR}/include/ck_tile/) + rocm_install(DIRECTORY ${CK_TILE_SRC_FOLDER} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck_tile) endif() set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 9a9c2235e0..99e7479e36 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -84,63 +84,46 @@ namespace ck_tile::builder::factory { // CK Tile kernel template -consteval bool IsTileAlgorithm() -{ - return ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && SpecifiesTileTransfer && - SpecifiesTileConvSpecialization && SpecifiesTileBlockGemm && - SpecifiesTileOptimizations; -} +concept IsTileAlgorithm = ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && + SpecifiesTileTransfer && SpecifiesTileConvSpecialization && + SpecifiesTileBlockGemm && SpecifiesTileOptimizations; // XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline) template -consteval bool IsXdlV3Algorithm() -{ - return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && - SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesBlockGemm; -} +concept IsXdlV3Algorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesBlockGemm; // Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply) template -consteval bool IsXdlAlgorithm() -{ - return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && - SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && - SpecifiesLoopScheduler; -} +concept IsXdlAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && + SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; // WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions) template -consteval bool IsWmmaAlgorithm() -{ - return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && - SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; -} +concept IsWmmaAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; // Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts template -consteval bool IsDlAlgorithm() -{ - return ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && - SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; -} +concept IsDlAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && + SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; // XDL-based kernel with large tensor support template -consteval bool IsLargeTensorAlgorithm() -{ - return IsXdlAlgorithm() && SpecifiesLargeTensorSupport; -} +concept IsLargeTensorAlgorithm = + IsXdlAlgorithm && SpecifiesLargeTensorSupport; template ; // CK Tile supports common factory for each direction - if constexpr(IsTileAlgorithm()) + if constexpr(IsTileAlgorithm) { return typename ConvTileFactory::Instance{}; } else if constexpr(ConvDirectionIsForward) { - if constexpr(IsXdlV3Algorithm()) + if constexpr(IsXdlV3Algorithm) { return typename ConvFwdXdlV3Factory::Instance{}; } - else if constexpr(IsXdlAlgorithm()) + else if constexpr(IsXdlAlgorithm) { return typename ConvFwdXdlFactory::Instance{}; } - else if constexpr(IsWmmaAlgorithm()) + else if constexpr(IsWmmaAlgorithm) { return typename ConvFwdWmmaFactory::Instance{}; } - else if constexpr(IsDlAlgorithm()) + else if constexpr(IsDlAlgorithm) { return typename ConvFwdDlFactory::Instance{}; } - else if constexpr(IsLargeTensorAlgorithm()) + else if constexpr(IsLargeTensorAlgorithm) { return typename ConvFwdLargeTensorFactory::Instance{}; }