mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 05:37:34 +00:00
Use amcro to ensure automatic macthing between concepts are their string representations.
This commit is contained in:
@@ -9,162 +9,223 @@ namespace ck_tile::builder::factory {
|
||||
|
||||
#define CHECK_MARK(cond) (cond ? "[✓]" : "[✗]")
|
||||
|
||||
// Macro to check a concept and generate both the boolean and the string representation
|
||||
#define CHECK_CONCEPT(Type, Concept) \
|
||||
static constexpr bool c_##Concept = Concept<Type>; \
|
||||
static constexpr const char* s_##Concept = #Concept;
|
||||
|
||||
// Helper to create diagnostic message line
|
||||
#define DIAGNOSTIC_LINE(Concept) \
|
||||
" " + std::string(s_##Concept) + ": " + std::string(CHECK_MARK(c_##Concept)) + "\n"
|
||||
|
||||
template <typename T>
|
||||
struct FwdXdlV3Algorithm {
|
||||
static constexpr bool c1 = ConvAlgorithmDescriptor<T>;
|
||||
static constexpr bool c2 = SpecifiesThreadBlock<T>;
|
||||
static constexpr bool c3 = SpecifiesBlockTransfer<T>;
|
||||
static constexpr bool c4 = SpecifiesLdsTransfer<T>;
|
||||
static constexpr bool c5 = SpecifiesThreadClusterAccessOrder<T>;
|
||||
static constexpr bool c6 = SpecifiesSourceAccessOrder<T>;
|
||||
static constexpr bool c7 = SpecifiesGridwiseFwdXdlGemm<T>;
|
||||
static constexpr bool c8 = SpecifiesFwdConvSpecialization<T>;
|
||||
static constexpr bool c9 = SpecifiesGemmSpecialization<T>;
|
||||
static constexpr bool c10 = SpecifiesBlockGemm<T>;
|
||||
CHECK_CONCEPT(T, ConvAlgorithmDescriptor)
|
||||
CHECK_CONCEPT(T, SpecifiesThreadBlock)
|
||||
CHECK_CONCEPT(T, SpecifiesBlockTransfer)
|
||||
CHECK_CONCEPT(T, SpecifiesLdsTransfer)
|
||||
CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder)
|
||||
CHECK_CONCEPT(T, SpecifiesSourceAccessOrder)
|
||||
CHECK_CONCEPT(T, SpecifiesGridwiseFwdXdlGemm)
|
||||
CHECK_CONCEPT(T, SpecifiesFwdConvSpecialization)
|
||||
CHECK_CONCEPT(T, SpecifiesGemmSpecialization)
|
||||
CHECK_CONCEPT(T, SpecifiesBlockGemm)
|
||||
|
||||
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
|
||||
static constexpr bool c2 = c_SpecifiesThreadBlock;
|
||||
static constexpr bool c3 = c_SpecifiesBlockTransfer;
|
||||
static constexpr bool c4 = c_SpecifiesLdsTransfer;
|
||||
static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder;
|
||||
static constexpr bool c6 = c_SpecifiesSourceAccessOrder;
|
||||
static constexpr bool c7 = c_SpecifiesGridwiseFwdXdlGemm;
|
||||
static constexpr bool c8 = c_SpecifiesFwdConvSpecialization;
|
||||
static constexpr bool c9 = c_SpecifiesGemmSpecialization;
|
||||
static constexpr bool c10 = c_SpecifiesBlockGemm;
|
||||
|
||||
static consteval bool is_valid() {
|
||||
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10;
|
||||
}
|
||||
|
||||
static consteval const std::string message() {
|
||||
return "\n=== Forward XDL V3 Algorithm Diagnostic (closest match) ===\n"
|
||||
"Concepts for FwdXdlV3 Algorithm:\n"
|
||||
" ConvAlgorithmDescriptor: " + std::string(CHECK_MARK(c1)) + "\n"
|
||||
" SpecifiesThreadBlock: " + std::string(CHECK_MARK(c2)) + "\n"
|
||||
" SpecifiesBlockTransfer: " + std::string(CHECK_MARK(c3)) + "\n"
|
||||
" SpecifiesLdsTransfer: " + std::string(CHECK_MARK(c4)) + "\n"
|
||||
" SpecifiesThreadClusterAccessOrder: " + std::string(CHECK_MARK(c5)) + "\n"
|
||||
" SpecifiesSourceAccessOrder: " + std::string(CHECK_MARK(c6)) + "\n"
|
||||
" SpecifiesGridwiseFwdXdlGemm: " + std::string(CHECK_MARK(c7)) + "\n"
|
||||
" SpecifiesFwdConvSpecialization: " + std::string(CHECK_MARK(c8)) + "\n"
|
||||
" SpecifiesGemmSpecialization: " + std::string(CHECK_MARK(c9)) + "\n"
|
||||
" SpecifiesBlockGemm: " + std::string(CHECK_MARK(c10)) + "\n";
|
||||
return std::string("\n=== Forward XDL V3 Algorithm Diagnostic (closest match) ===\n"
|
||||
"Concepts for FwdXdlV3 Algorithm:\n") +
|
||||
DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) +
|
||||
DIAGNOSTIC_LINE(SpecifiesThreadBlock) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBlockTransfer) +
|
||||
DIAGNOSTIC_LINE(SpecifiesLdsTransfer) +
|
||||
DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) +
|
||||
DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) +
|
||||
DIAGNOSTIC_LINE(SpecifiesGridwiseFwdXdlGemm) +
|
||||
DIAGNOSTIC_LINE(SpecifiesFwdConvSpecialization) +
|
||||
DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBlockGemm);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct FwdXdlAlgorithm {
|
||||
static constexpr bool c1 = ConvAlgorithmDescriptor<T>;
|
||||
static constexpr bool c2 = SpecifiesThreadBlock<T>;
|
||||
static constexpr bool c3 = SpecifiesBlockTransfer<T>;
|
||||
static constexpr bool c4 = SpecifiesLdsTransfer<T>;
|
||||
static constexpr bool c5 = SpecifiesThreadClusterAccessOrder<T>;
|
||||
static constexpr bool c6 = SpecifiesSourceAccessOrder<T>;
|
||||
static constexpr bool c7 = SpecifiesGridwiseFwdXdlGemm<T>;
|
||||
static constexpr bool c8 = SpecifiesFwdConvSpecialization<T>;
|
||||
static constexpr bool c9 = SpecifiesGemmSpecialization<T>;
|
||||
static constexpr bool c10 = SpecifiesNumPrefetchStages<T>;
|
||||
static constexpr bool c11 = SpecifiesNumGroupsToMerge<T>;
|
||||
static constexpr bool c12 = SpecifiesLoopScheduler<T>;
|
||||
CHECK_CONCEPT(T, ConvAlgorithmDescriptor)
|
||||
CHECK_CONCEPT(T, SpecifiesThreadBlock)
|
||||
CHECK_CONCEPT(T, SpecifiesBlockTransfer)
|
||||
CHECK_CONCEPT(T, SpecifiesLdsTransfer)
|
||||
CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder)
|
||||
CHECK_CONCEPT(T, SpecifiesSourceAccessOrder)
|
||||
CHECK_CONCEPT(T, SpecifiesGridwiseFwdXdlGemm)
|
||||
CHECK_CONCEPT(T, SpecifiesFwdConvSpecialization)
|
||||
CHECK_CONCEPT(T, SpecifiesGemmSpecialization)
|
||||
CHECK_CONCEPT(T, SpecifiesNumPrefetchStages)
|
||||
CHECK_CONCEPT(T, SpecifiesNumGroupsToMerge)
|
||||
CHECK_CONCEPT(T, SpecifiesLoopScheduler)
|
||||
|
||||
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
|
||||
static constexpr bool c2 = c_SpecifiesThreadBlock;
|
||||
static constexpr bool c3 = c_SpecifiesBlockTransfer;
|
||||
static constexpr bool c4 = c_SpecifiesLdsTransfer;
|
||||
static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder;
|
||||
static constexpr bool c6 = c_SpecifiesSourceAccessOrder;
|
||||
static constexpr bool c7 = c_SpecifiesGridwiseFwdXdlGemm;
|
||||
static constexpr bool c8 = c_SpecifiesFwdConvSpecialization;
|
||||
static constexpr bool c9 = c_SpecifiesGemmSpecialization;
|
||||
static constexpr bool c10 = c_SpecifiesNumPrefetchStages;
|
||||
static constexpr bool c11 = c_SpecifiesNumGroupsToMerge;
|
||||
static constexpr bool c12 = c_SpecifiesLoopScheduler;
|
||||
|
||||
static consteval bool is_valid() {
|
||||
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11 && c12;
|
||||
}
|
||||
|
||||
static consteval const std::string message() {
|
||||
return "\n=== Forward XDL Algorithm Diagnostic (closest match) ===\n"
|
||||
"Concepts for FwdXdl Algorithm:\n"
|
||||
" ConvAlgorithmDescriptor: " + std::string(CHECK_MARK(c1)) + "\n"
|
||||
" SpecifiesThreadBlock: " + std::string(CHECK_MARK(c2)) + "\n"
|
||||
" SpecifiesBlockTransfer: " + std::string(CHECK_MARK(c3)) + "\n"
|
||||
" SpecifiesLdsTransfer: " + std::string(CHECK_MARK(c4)) + "\n"
|
||||
" SpecifiesThreadClusterAccessOrder: " + std::string(CHECK_MARK(c5)) + "\n"
|
||||
" SpecifiesSourceAccessOrder: " + std::string(CHECK_MARK(c6)) + "\n"
|
||||
" SpecifiesGridwiseFwdXdlGemm: " + std::string(CHECK_MARK(c7)) + "\n"
|
||||
" SpecifiesFwdConvSpecialization: " + std::string(CHECK_MARK(c8)) + "\n"
|
||||
" SpecifiesGemmSpecialization: " + std::string(CHECK_MARK(c9)) + "\n"
|
||||
" SpecifiesNumPrefetchStages: " + std::string(CHECK_MARK(c10)) + "\n"
|
||||
" SpecifiesNumGroupsToMerge: " + std::string(CHECK_MARK(c11)) + "\n"
|
||||
" SpecifiesLoopScheduler: " + std::string(CHECK_MARK(c12)) + "\n";
|
||||
return std::string("\n=== Forward XDL Algorithm Diagnostic (closest match) ===\n"
|
||||
"Concepts for FwdXdl Algorithm:\n") +
|
||||
DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) +
|
||||
DIAGNOSTIC_LINE(SpecifiesThreadBlock) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBlockTransfer) +
|
||||
DIAGNOSTIC_LINE(SpecifiesLdsTransfer) +
|
||||
DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) +
|
||||
DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) +
|
||||
DIAGNOSTIC_LINE(SpecifiesGridwiseFwdXdlGemm) +
|
||||
DIAGNOSTIC_LINE(SpecifiesFwdConvSpecialization) +
|
||||
DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) +
|
||||
DIAGNOSTIC_LINE(SpecifiesNumPrefetchStages) +
|
||||
DIAGNOSTIC_LINE(SpecifiesNumGroupsToMerge) +
|
||||
DIAGNOSTIC_LINE(SpecifiesLoopScheduler);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct FwdWmmaAlgorithm {
|
||||
static constexpr bool c1 = ConvAlgorithmDescriptor<T>;
|
||||
static constexpr bool c2 = SpecifiesThreadBlock<T>;
|
||||
static constexpr bool c3 = SpecifiesBlockTransfer<T>;
|
||||
static constexpr bool c4 = SpecifiesLdsTransfer<T>;
|
||||
static constexpr bool c5 = SpecifiesThreadClusterAccessOrder<T>;
|
||||
static constexpr bool c6 = SpecifiesSourceAccessOrder<T>;
|
||||
static constexpr bool c7 = SpecifiesGridwiseWmmaGemm<T>;
|
||||
static constexpr bool c8 = SpecifiesFwdConvSpecialization<T>;
|
||||
static constexpr bool c9 = SpecifiesGemmSpecialization<T>;
|
||||
static constexpr bool c10 = SpecifiesNumPrefetchStages<T>;
|
||||
static constexpr bool c11 = SpecifiesLoopScheduler<T>;
|
||||
CHECK_CONCEPT(T, ConvAlgorithmDescriptor)
|
||||
CHECK_CONCEPT(T, SpecifiesThreadBlock)
|
||||
CHECK_CONCEPT(T, SpecifiesBlockTransfer)
|
||||
CHECK_CONCEPT(T, SpecifiesLdsTransfer)
|
||||
CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder)
|
||||
CHECK_CONCEPT(T, SpecifiesSourceAccessOrder)
|
||||
CHECK_CONCEPT(T, SpecifiesGridwiseWmmaGemm)
|
||||
CHECK_CONCEPT(T, SpecifiesFwdConvSpecialization)
|
||||
CHECK_CONCEPT(T, SpecifiesGemmSpecialization)
|
||||
CHECK_CONCEPT(T, SpecifiesNumPrefetchStages)
|
||||
CHECK_CONCEPT(T, SpecifiesLoopScheduler)
|
||||
|
||||
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
|
||||
static constexpr bool c2 = c_SpecifiesThreadBlock;
|
||||
static constexpr bool c3 = c_SpecifiesBlockTransfer;
|
||||
static constexpr bool c4 = c_SpecifiesLdsTransfer;
|
||||
static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder;
|
||||
static constexpr bool c6 = c_SpecifiesSourceAccessOrder;
|
||||
static constexpr bool c7 = c_SpecifiesGridwiseWmmaGemm;
|
||||
static constexpr bool c8 = c_SpecifiesFwdConvSpecialization;
|
||||
static constexpr bool c9 = c_SpecifiesGemmSpecialization;
|
||||
static constexpr bool c10 = c_SpecifiesNumPrefetchStages;
|
||||
static constexpr bool c11 = c_SpecifiesLoopScheduler;
|
||||
|
||||
static consteval bool is_valid() {
|
||||
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11;
|
||||
}
|
||||
|
||||
static consteval const std::string message() {
|
||||
return "\n=== Forward WMMA Algorithm Diagnostic (closest match) ===\n"
|
||||
"Concepts for FwdWmma Algorithm:\n"
|
||||
" ConvAlgorithmDescriptor: " + std::string(CHECK_MARK(c1)) + "\n"
|
||||
" SpecifiesThreadBlock: " + std::string(CHECK_MARK(c2)) + "\n"
|
||||
" SpecifiesBlockTransfer: " + std::string(CHECK_MARK(c3)) + "\n"
|
||||
" SpecifiesLdsTransfer: " + std::string(CHECK_MARK(c4)) + "\n"
|
||||
" SpecifiesThreadClusterAccessOrder: " + std::string(CHECK_MARK(c5)) + "\n"
|
||||
" SpecifiesSourceAccessOrder: " + std::string(CHECK_MARK(c6)) + "\n"
|
||||
" SpecifiesGridwiseWmmaGemm: " + std::string(CHECK_MARK(c7)) + "\n"
|
||||
" SpecifiesFwdConvSpecialization: " + std::string(CHECK_MARK(c8)) + "\n"
|
||||
" SpecifiesGemmSpecialization: " + std::string(CHECK_MARK(c9)) + "\n"
|
||||
" SpecifiesNumPrefetchStages: " + std::string(CHECK_MARK(c10)) + "\n"
|
||||
" SpecifiesLoopScheduler: " + std::string(CHECK_MARK(c11)) + "\n";
|
||||
return std::string("\n=== Forward WMMA Algorithm Diagnostic (closest match) ===\n"
|
||||
"Concepts for FwdWmma Algorithm:\n") +
|
||||
DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) +
|
||||
DIAGNOSTIC_LINE(SpecifiesThreadBlock) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBlockTransfer) +
|
||||
DIAGNOSTIC_LINE(SpecifiesLdsTransfer) +
|
||||
DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) +
|
||||
DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) +
|
||||
DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) +
|
||||
DIAGNOSTIC_LINE(SpecifiesFwdConvSpecialization) +
|
||||
DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) +
|
||||
DIAGNOSTIC_LINE(SpecifiesNumPrefetchStages) +
|
||||
DIAGNOSTIC_LINE(SpecifiesLoopScheduler);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct FwdDlAlgorithm {
|
||||
static constexpr bool c1 = ConvAlgorithmDescriptor<T>;
|
||||
static constexpr bool c2 = SpecifiesThreadBlock<T>;
|
||||
static constexpr bool c3 = SpecifiesFwdConvSpecialization<T>;
|
||||
static constexpr bool c4 = SpecifiesGemmSpecialization<T>;
|
||||
static constexpr bool c5 = SpecifiesDlThreadConfig<T>;
|
||||
static constexpr bool c6 = SpecifiesDlThreadCluster<T>;
|
||||
static constexpr bool c7 = SpecifiesDlBlockTransfer<T>;
|
||||
static constexpr bool c8 = SpecifiesDlEpilogue<T>;
|
||||
CHECK_CONCEPT(T, ConvAlgorithmDescriptor)
|
||||
CHECK_CONCEPT(T, SpecifiesThreadBlock)
|
||||
CHECK_CONCEPT(T, SpecifiesFwdConvSpecialization)
|
||||
CHECK_CONCEPT(T, SpecifiesGemmSpecialization)
|
||||
CHECK_CONCEPT(T, SpecifiesDlThreadConfig)
|
||||
CHECK_CONCEPT(T, SpecifiesDlThreadCluster)
|
||||
CHECK_CONCEPT(T, SpecifiesDlBlockTransfer)
|
||||
CHECK_CONCEPT(T, SpecifiesDlEpilogue)
|
||||
|
||||
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
|
||||
static constexpr bool c2 = c_SpecifiesThreadBlock;
|
||||
static constexpr bool c3 = c_SpecifiesFwdConvSpecialization;
|
||||
static constexpr bool c4 = c_SpecifiesGemmSpecialization;
|
||||
static constexpr bool c5 = c_SpecifiesDlThreadConfig;
|
||||
static constexpr bool c6 = c_SpecifiesDlThreadCluster;
|
||||
static constexpr bool c7 = c_SpecifiesDlBlockTransfer;
|
||||
static constexpr bool c8 = c_SpecifiesDlEpilogue;
|
||||
|
||||
static consteval bool is_valid() {
|
||||
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8;
|
||||
}
|
||||
|
||||
static consteval const std::string message() {
|
||||
return "\n=== Forward DL Algorithm Diagnostic (closest match) ===\n"
|
||||
"Concepts for FwdDl Algorithm:\n"
|
||||
" ConvAlgorithmDescriptor: " + std::string(CHECK_MARK(c1)) + "\n"
|
||||
" SpecifiesThreadBlock: " + std::string(CHECK_MARK(c2)) + "\n"
|
||||
" SpecifiesFwdConvSpecialization: " + std::string(CHECK_MARK(c3)) + "\n"
|
||||
" SpecifiesGemmSpecialization: " + std::string(CHECK_MARK(c4)) + "\n"
|
||||
" SpecifiesDlThreadConfig: " + std::string(CHECK_MARK(c5)) + "\n"
|
||||
" SpecifiesDlThreadCluster: " + std::string(CHECK_MARK(c6)) + "\n"
|
||||
" SpecifiesDlBlockTransfer: " + std::string(CHECK_MARK(c7)) + "\n"
|
||||
" SpecifiesDlEpilogue: " + std::string(CHECK_MARK(c8)) + "\n";
|
||||
return std::string("\n=== Forward DL Algorithm Diagnostic (closest match) ===\n"
|
||||
"Concepts for FwdDl Algorithm:\n") +
|
||||
DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) +
|
||||
DIAGNOSTIC_LINE(SpecifiesThreadBlock) +
|
||||
DIAGNOSTIC_LINE(SpecifiesFwdConvSpecialization) +
|
||||
DIAGNOSTIC_LINE(SpecifiesGemmSpecialization) +
|
||||
DIAGNOSTIC_LINE(SpecifiesDlThreadConfig) +
|
||||
DIAGNOSTIC_LINE(SpecifiesDlThreadCluster) +
|
||||
DIAGNOSTIC_LINE(SpecifiesDlBlockTransfer) +
|
||||
DIAGNOSTIC_LINE(SpecifiesDlEpilogue);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct TileAlgorithm {
|
||||
static constexpr bool c1 = ConvAlgorithmDescriptor<T>;
|
||||
static constexpr bool c2 = SpecifiesTileThreadBlock<T>;
|
||||
static constexpr bool c3 = SpecifiesTileTransfer<T>;
|
||||
static constexpr bool c4 = SpecifiesTileConvSpecialization<T>;
|
||||
static constexpr bool c5 = SpecifiesTileBlockGemm<T>;
|
||||
static constexpr bool c6 = SpecifiesTileOptimizations<T>;
|
||||
CHECK_CONCEPT(T, ConvAlgorithmDescriptor)
|
||||
CHECK_CONCEPT(T, SpecifiesTileThreadBlock)
|
||||
CHECK_CONCEPT(T, SpecifiesTileTransfer)
|
||||
CHECK_CONCEPT(T, SpecifiesTileConvSpecialization)
|
||||
CHECK_CONCEPT(T, SpecifiesTileBlockGemm)
|
||||
CHECK_CONCEPT(T, SpecifiesTileOptimizations)
|
||||
|
||||
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
|
||||
static constexpr bool c2 = c_SpecifiesTileThreadBlock;
|
||||
static constexpr bool c3 = c_SpecifiesTileTransfer;
|
||||
static constexpr bool c4 = c_SpecifiesTileConvSpecialization;
|
||||
static constexpr bool c5 = c_SpecifiesTileBlockGemm;
|
||||
static constexpr bool c6 = c_SpecifiesTileOptimizations;
|
||||
|
||||
static consteval bool is_valid() {
|
||||
return c1 && c2 && c3 && c4 && c5 && c6;
|
||||
}
|
||||
|
||||
static consteval const std::string message() {
|
||||
return "\n=== CK Tile Algorithm Diagnostic (closest match) ===\n"
|
||||
"Concepts for CK Tile Conv Algorithm:\n"
|
||||
" ConvAlgorithmDescriptor: " + std::string(CHECK_MARK(c1)) + "\n"
|
||||
" SpecifiesTileThreadBlock: " + std::string(CHECK_MARK(c2)) + "\n"
|
||||
" SpecifiesTileTransfer: " + std::string(CHECK_MARK(c3)) + "\n"
|
||||
" SpecifiesTileConvSpecialization: " + std::string(CHECK_MARK(c4)) + "\n"
|
||||
" SpecifiesTileBlockGemm: " + std::string(CHECK_MARK(c5)) + "\n"
|
||||
" SpecifiesTileOptimizations: " + std::string(CHECK_MARK(c6)) + "\n";
|
||||
return std::string("\n=== CK Tile Algorithm Diagnostic (closest match) ===\n"
|
||||
"Concepts for CK Tile Conv Algorithm:\n") +
|
||||
DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) +
|
||||
DIAGNOSTIC_LINE(SpecifiesTileThreadBlock) +
|
||||
DIAGNOSTIC_LINE(SpecifiesTileTransfer) +
|
||||
DIAGNOSTIC_LINE(SpecifiesTileConvSpecialization) +
|
||||
DIAGNOSTIC_LINE(SpecifiesTileBlockGemm) +
|
||||
DIAGNOSTIC_LINE(SpecifiesTileOptimizations);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -172,7 +233,9 @@ template <typename T>
|
||||
struct LargeTensorAlgorithm : public FwdXdlAlgorithm<decltype(T::base_algorithm)>
|
||||
{
|
||||
using BaseAlgorithmType = decltype(T::base_algorithm);
|
||||
static constexpr bool c13 = SpecifiesLargeTensorSupport<T>;
|
||||
CHECK_CONCEPT(T, SpecifiesLargeTensorSupport)
|
||||
|
||||
static constexpr bool c13 = c_SpecifiesLargeTensorSupport;
|
||||
|
||||
static consteval bool is_valid() {
|
||||
return FwdXdlAlgorithm<BaseAlgorithmType>::is_valid() && c13;
|
||||
@@ -180,38 +243,48 @@ struct LargeTensorAlgorithm : public FwdXdlAlgorithm<decltype(T::base_algorithm)
|
||||
|
||||
static consteval const std::string message() {
|
||||
return FwdXdlAlgorithm<BaseAlgorithmType>::message() +
|
||||
" SpecifiesLargeTensorSupport: " + std::string(CHECK_MARK(c13)) + "\n";
|
||||
DIAGNOSTIC_LINE(SpecifiesLargeTensorSupport);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct BwdXdlAlgorithm {
|
||||
static constexpr bool c1 = ConvAlgorithmDescriptor<T>;
|
||||
static constexpr bool c2 = SpecifiesThreadBlock<T>;
|
||||
static constexpr bool c3 = SpecifiesBlockTransfer<T>;
|
||||
static constexpr bool c4 = SpecifiesLdsTransfer<T>;
|
||||
static constexpr bool c5 = SpecifiesThreadClusterAccessOrder<T>;
|
||||
static constexpr bool c6 = SpecifiesSourceAccessOrder<T>;
|
||||
static constexpr bool c7 = SpecifiesGridwiseBwdXdlGemm<T>;
|
||||
static constexpr bool c8 = SpecifiesBwdWeightConvSpecialization<T>;
|
||||
static constexpr bool c9 = SpecifiesTransposeTransfer<T>;
|
||||
CHECK_CONCEPT(T, ConvAlgorithmDescriptor)
|
||||
CHECK_CONCEPT(T, SpecifiesThreadBlock)
|
||||
CHECK_CONCEPT(T, SpecifiesBlockTransfer)
|
||||
CHECK_CONCEPT(T, SpecifiesLdsTransfer)
|
||||
CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder)
|
||||
CHECK_CONCEPT(T, SpecifiesSourceAccessOrder)
|
||||
CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm)
|
||||
CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization)
|
||||
CHECK_CONCEPT(T, SpecifiesTransposeTransfer)
|
||||
|
||||
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
|
||||
static constexpr bool c2 = c_SpecifiesThreadBlock;
|
||||
static constexpr bool c3 = c_SpecifiesBlockTransfer;
|
||||
static constexpr bool c4 = c_SpecifiesLdsTransfer;
|
||||
static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder;
|
||||
static constexpr bool c6 = c_SpecifiesSourceAccessOrder;
|
||||
static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm;
|
||||
static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization;
|
||||
static constexpr bool c9 = c_SpecifiesTransposeTransfer;
|
||||
|
||||
static consteval bool is_valid() {
|
||||
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9;
|
||||
}
|
||||
|
||||
static consteval const std::string message() {
|
||||
return "\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n"
|
||||
"Concepts for BwdXdl Algorithm:\n"
|
||||
" ConvAlgorithmDescriptor: " + std::string(CHECK_MARK(c1)) + "\n"
|
||||
" SpecifiesThreadBlock: " + std::string(CHECK_MARK(c2)) + "\n"
|
||||
" SpecifiesBlockTransfer: " + std::string(CHECK_MARK(c3)) + "\n"
|
||||
" SpecifiesLdsTransfer: " + std::string(CHECK_MARK(c4)) + "\n"
|
||||
" SpecifiesThreadClusterAccessOrder: " + std::string(CHECK_MARK(c5)) + "\n"
|
||||
" SpecifiesSourceAccessOrder: " + std::string(CHECK_MARK(c6)) + "\n"
|
||||
" SpecifiesGridwiseBwdXdlGemm: " + std::string(CHECK_MARK(c7)) + "\n"
|
||||
" SpecifiesBwdWeightConvSpecialization: " + std::string(CHECK_MARK(c8)) + "\n"
|
||||
" SpecifiesTransposeTransfer: " + std::string(CHECK_MARK(c9)) + "\n";
|
||||
return std::string("\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n"
|
||||
"Concepts for BwdXdl Algorithm:\n") +
|
||||
DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) +
|
||||
DIAGNOSTIC_LINE(SpecifiesThreadBlock) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBlockTransfer) +
|
||||
DIAGNOSTIC_LINE(SpecifiesLdsTransfer) +
|
||||
DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) +
|
||||
DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) +
|
||||
DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) +
|
||||
DIAGNOSTIC_LINE(SpecifiesTransposeTransfer);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -315,4 +388,4 @@ consteval void diagnose_bwd_weight_algorithm_signature()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user