diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index a7ec568b03..902c3b3579 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -23,19 +23,12 @@ namespace ck_tile::reflect::detail { -// Metaprogramming helper to convert ck::Sequence to constexpr std::array -template -struct SequenceToArray; - -template -struct SequenceToArray> -{ - static constexpr std::array value = {static_cast(Is)...}; -}; - -// Convert data types to string names +// Implementation detail for type name mapping +// This is the single source of truth for supported data types that +// returns an empty string to indicate an unsupported type. +namespace impl { template -consteval std::string_view type_name() +consteval std::string_view type_name_impl() { if constexpr(std::is_same_v) return "fp16"; @@ -56,22 +49,38 @@ consteval std::string_view type_name() else if constexpr(std::is_same_v) return "bf8"; else - static_assert(false, "unknown_type"); + return std::string_view{}; // Return empty for supported types +} +} // namespace impl + +// Convert data types to string names +// Fails at compile time for unsupported types +template +consteval std::string_view type_name() +{ + constexpr auto name = impl::type_name_impl(); + static_assert(!name.empty(), "Unsupported data type"); + return name; } -// Convert layout types to string names +// Concept that checks if a type is a valid data type +// Uses the impl directly to avoid triggering static_assert during concept evaluation template +concept IsDataType = !impl::type_name_impl().empty(); + +// Concept that checks valid layout types +template +concept IsLayoutType = (std::is_base_of_v || + std::is_base_of_v) && + requires { + { T::name } -> std::convertible_to; + }; + +// Convert layout types to string names +template constexpr std::string_view layout_name() { - if constexpr((std::is_base_of_v || - std::is_base_of_v) && - requires { - { T::name } -> std::convertible_to; - }) - return T::name; - else - static_assert(false, - "Layout type must derive from BaseTensorLayout and have name attribute"); + return T::name; } // Convert element-wise operation types to string names @@ -90,64 +99,64 @@ constexpr std::string_view elementwise_op_name() constexpr std::string_view conv_fwd_spec_name(ck::tensor_operation::device::ConvolutionForwardSpecialization spec) { - using ck::tensor_operation::device::ConvolutionForwardSpecialization; + using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; switch(spec) { - case ConvolutionForwardSpecialization::Default: return "Default"; - case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; - case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0"; - case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3"; - case ConvolutionForwardSpecialization::OddC: return "OddC"; + case Default: return "Default"; + case Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; + case Filter1x1Pad0: return "Filter1x1Pad0"; + case Filter3x3: return "Filter3x3"; + case OddC: return "OddC"; } } // Convert GemmSpecialization enum to string constexpr std::string_view gemm_spec_name(ck::tensor_operation::device::GemmSpecialization spec) { - using ck::tensor_operation::device::GemmSpecialization; + using enum ck::tensor_operation::device::GemmSpecialization; switch(spec) { - case GemmSpecialization::Default: return "Default"; - case GemmSpecialization::MPadding: return "MPadding"; - case GemmSpecialization::NPadding: return "NPadding"; - case GemmSpecialization::KPadding: return "KPadding"; - case GemmSpecialization::MNPadding: return "MNPadding"; - case GemmSpecialization::MKPadding: return "MKPadding"; - case GemmSpecialization::NKPadding: return "NKPadding"; - case GemmSpecialization::MNKPadding: return "MNKPadding"; - case GemmSpecialization::OPadding: return "OPadding"; - case GemmSpecialization::MOPadding: return "MOPadding"; - case GemmSpecialization::NOPadding: return "NOPadding"; - case GemmSpecialization::KOPadding: return "KOPadding"; - case GemmSpecialization::MNOPadding: return "MNOPadding"; - case GemmSpecialization::MKOPadding: return "MKOPadding"; - case GemmSpecialization::NKOPadding: return "NKOPadding"; - case GemmSpecialization::MNKOPadding: return "MNKOPadding"; + case Default: return "Default"; + case MPadding: return "MPadding"; + case NPadding: return "NPadding"; + case KPadding: return "KPadding"; + case MNPadding: return "MNPadding"; + case MKPadding: return "MKPadding"; + case NKPadding: return "NKPadding"; + case MNKPadding: return "MNKPadding"; + case OPadding: return "OPadding"; + case MOPadding: return "MOPadding"; + case NOPadding: return "NOPadding"; + case KOPadding: return "KOPadding"; + case MNOPadding: return "MNOPadding"; + case MKOPadding: return "MKOPadding"; + case NKOPadding: return "NKOPadding"; + case MNKOPadding: return "MNKOPadding"; } } // Convert BlockGemmPipelineScheduler enum to string constexpr std::string_view pipeline_scheduler_name(ck::BlockGemmPipelineScheduler sched) { - using ck::BlockGemmPipelineScheduler; + using enum ck::BlockGemmPipelineScheduler; switch(sched) { - case BlockGemmPipelineScheduler::Intrawave: return "Intrawave"; - case BlockGemmPipelineScheduler::Interwave: return "Interwave"; + case Intrawave: return "Intrawave"; + case Interwave: return "Interwave"; } } // Convert BlockGemmPipelineVersion enum to string constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ver) { - using ck::BlockGemmPipelineVersion; + using enum ck::BlockGemmPipelineVersion; switch(ver) { - case BlockGemmPipelineVersion::v1: return "v1"; - case BlockGemmPipelineVersion::v2: return "v2"; - case BlockGemmPipelineVersion::v3: return "v3"; - case BlockGemmPipelineVersion::v4: return "v4"; - case BlockGemmPipelineVersion::v5: return "v5"; + case v1: return "v1"; + case v2: return "v2"; + case v3: return "v3"; + case v4: return "v4"; + case v5: return "v5"; } } @@ -167,12 +176,138 @@ inline std::string array_to_string(const std::array& arr) return oss.str(); } -// Handle ck::Tuple (empty tuple for DsLayout/DsDataType) -template -constexpr std::string_view tuple_name() +// Metaprogramming helper to convert ck::Sequence to constexpr std::array +template +struct SequenceToArray; + +template +struct SequenceToArray> { - // For now, just check if it's an empty tuple - return "EmptyTuple"; + static constexpr std::array value = {static_cast(Is)...}; +}; + +namespace detail { +// Generic helper to build list-like strings (Tuple, Seq, etc.) +// +// Example output: "Seq(1,2,3)" +// +// prefix: The list-like container name (e.g. "Tuple" or "Seq") +// converter_fn: A callable that converts each element to a string representation +// For types: converter_fn should be a template lambda like []() { return +// type_name(); } For values: converter_fn should be a regular lambda like [](auto value) { +// return std::to_string(value); } +template +constexpr std::string build_list_string(std::string_view prefix, const ConverterFn& converter_fn) +{ + if constexpr(sizeof...(Elements) == 0) + { + return std::string(prefix) + "()"; + } + else + { + std::string result = std::string(prefix) + "("; + std::size_t index = 0; + ((result += + (index++ > 0 ? "," : "") + std::string(converter_fn.template operator()())), + ...); + result += ")"; + return result; + } +} + +// Overload for value-based lists (sequences) +template +constexpr std::string build_list_string_values(std::string_view prefix, + const ConverterFn& converter_fn) +{ + if constexpr(sizeof...(Values) == 0) + { + return std::string(prefix) + "()"; + } + else + { + std::string result = std::string(prefix) + "("; + std::size_t index = 0; + ((result += (index++ > 0 ? "," : "") + converter_fn(Values)), ...); + result += ")"; + return result; + } +} +} // namespace detail + +// Convert ck::Sequence to string representation +// Converts a ck::Sequence type to a string in the format "Seq(v1,v2,...,vn)" +// where each value is converted using std::to_string. +// +// Template parameter: +// T: Must be a ck::Sequence<...> type +// +// Constraints: +// - Sequence elements must support std::to_string (typically ck::index_t) +// +// Examples: +// sequence_name>() returns "Seq()" +// sequence_name>() returns "Seq(42)" +// sequence_name>() returns "Seq(1,2,3)" +// sequence_name>() returns "Seq(256,128,64)" +template + requires requires { [](ck::Sequence*) {}(static_cast(nullptr)); } +constexpr std::string sequence_name() +{ + return [](ck::Sequence*) constexpr { + auto to_string_fn = [](auto value) { return std::to_string(value); }; + return detail::build_list_string_values("Seq", to_string_fn); + }(static_cast(nullptr)); +} + +// Convert ck::Tuple to string representation +// Converts a ck::Tuple type to a string in the format "Tuple(e1,e2,...,en)" +// where each element is converted based on its type (layout names or data type names). +// +// Template parameter: +// T: Must be a ck::Tuple<...> type +// +// Constraints: +// - Empty tuples are supported and return "EmptyTuple" +// - All tuple elements must be homogeneous: either all layouts (IsLayoutType) or all data types +// (IsDataType) +// - Mixed layouts and data types in the same tuple will cause a compile-time error +// +// Examples: +// tuple_name>() returns "EmptyTuple" +// tuple_name>() returns "Tuple(RowMajor)" +// tuple_name>() returns "Tuple(NCHW,NHWC)" +// tuple_name>() returns "Tuple(fp16)" +// tuple_name>() returns "Tuple(fp16,fp32,fp64)" +template + requires requires { [](ck::Tuple*) {}(static_cast(nullptr)); } +constexpr std::string tuple_name() +{ + return [](ck::Tuple*) constexpr { + if constexpr(sizeof...(Ts) == 0) + { + return std::string("EmptyTuple"); + } + else if constexpr((IsLayoutType && ...)) + { + // Lambda wrapper for layout_name + auto layout_name_fn = []() { return layout_name(); }; + return detail::build_list_string("Tuple", + layout_name_fn); + } + else if constexpr((IsDataType && ...)) + { + // Lambda wrapper for type_name + auto type_name_fn = []() { return type_name(); }; + return detail::build_list_string("Tuple", type_name_fn); + } + else + { + static_assert((IsLayoutType && ...) || (IsDataType && ...), + "Tuple elements must be all layouts or all data types, not mixed"); + return std::string{}; // unreachable + } + }(static_cast(nullptr)); } } // namespace ck_tile::reflect::detail diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index d44864938f..1dc508a0c3 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -16,20 +16,13 @@ function(add_ck_builder_test test_name) target_link_libraries(${test_name} PRIVATE GTest::gtest_main GTest::gmock) endfunction() +# The test_conv_builder target has all the unit tests (each test should run < 10 ms) add_ck_builder_test(test_conv_builder test_conv_builder.cpp test_instance_traits.cpp + test_instance_traits_util.cpp testing_utils.cpp) +# Testing the virtual GetInstanceString methods requires kernel compilation. add_ck_builder_test(test_get_instance_string test_get_instance_string.cpp) - -add_ck_builder_test(test_inline_diff test_inline_diff.cpp testing_utils.cpp) - -add_ck_builder_test(test_ckb_build_fwd_instances - conv/test_ckb_conv_fwd_2d_bf16.cpp - conv/test_ckb_conv_fwd_2d_fp16.cpp - conv/test_ckb_conv_fwd_2d_fp32.cpp - conv/test_ckb_conv_fwd_3d_bf16.cpp - conv/test_ckb_conv_fwd_3d_fp16.cpp - conv/test_ckb_conv_fwd_3d_fp32.cpp) \ No newline at end of file diff --git a/experimental/builder/test/test_instance_traits_util.cpp b/experimental/builder/test/test_instance_traits_util.cpp new file mode 100644 index 0000000000..4aa5ebf25e --- /dev/null +++ b/experimental/builder/test/test_instance_traits_util.cpp @@ -0,0 +1,263 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ck_tile::reflect::detail { +namespace { + +using ::testing::ElementsAre; +using ::testing::IsEmpty; + +TEST(InstanceTraitsUtil, SequenceToArrayReturnsEmptyArrayForEmptySequence) +{ + EXPECT_THAT(SequenceToArray>::value, IsEmpty()); +} + +TEST(InstanceTraitsUtil, SequenceToArrayReturnsArrayWithSingleElement) +{ + EXPECT_THAT(SequenceToArray>::value, ElementsAre(42)); +} + +TEST(InstanceTraitsUtil, SequenceToArrayReturnsArrayWithMultipleElements) +{ + EXPECT_THAT((SequenceToArray>::value), ElementsAre(1, 2, 3, 4, 5)); +} + +TEST(InstanceTraitsUtil, TypeNameReturnsCorrectStrings) +{ + EXPECT_THAT((std::vector{type_name(), + type_name(), + type_name(), + type_name(), + type_name(), + type_name(), + type_name(), + type_name()}), + ElementsAre("fp16", "fp32", "fp64", "s8", "s32", "bf16", "fp8", "bf8")); +} + +TEST(InstanceTraitsUtil, LayoutNameReturnsCorrectStringsForGemmLayouts) +{ + namespace gemm = ck::tensor_layout::gemm; + EXPECT_THAT((std::vector{layout_name(), + layout_name(), + layout_name()}), + ElementsAre("RowMajor", "ColumnMajor", "MFMA")); +} + +TEST(InstanceTraitsUtil, LayoutNameReturnsCorrectStringsForConvLayouts) +{ + namespace conv = ck::tensor_layout::convolution; + EXPECT_THAT((std::vector{ + // Input tensor layouts + // TODO(deprecated): Remove non-grouped layouts once instances are removed. + layout_name(), + layout_name(), + layout_name(), + layout_name(), + // Grouped input layouts + layout_name(), + layout_name(), + // Weight tensor layouts + layout_name(), + layout_name(), + layout_name(), + layout_name(), + // Output tensor layouts + layout_name(), + layout_name(), + layout_name(), + layout_name(), + // Strided layouts + // TODO(deprecated): Remove strided layouts once instances are removed. + layout_name(), + layout_name(), + layout_name(), + // Bias layouts + layout_name(), + layout_name()}), + ElementsAre("NCHW", + "NHWC", + "NCDHW", + "NDHWC", + "GNCHW", + "GNHWC", + "KCYX", + "KYXC", + "GKCYX", + "GKYXC", + "NKHW", + "NHWK", + "GNKHW", + "GNHWK", + "G_NHW_C", + "G_K_YX_C", + "G_NHW_K", + "G_C", + "G_K")); +} + +TEST(InstanceTraitsUtil, ElementwiseOpNameReturnsCorrectStrings) +{ + namespace element_wise = ck::tensor_operation::element_wise; + EXPECT_THAT((std::vector{ + elementwise_op_name(), + elementwise_op_name(), + elementwise_op_name(), + elementwise_op_name(), + elementwise_op_name(), + elementwise_op_name(), + elementwise_op_name(), + elementwise_op_name(), + elementwise_op_name()}), + ElementsAre("PassThrough", + "Scale", + "Bilinear", + "Add", + "AddRelu", + "Relu", + "BiasNormalizeInInferClamp", + "Clamp", + "AddClamp")); +} + +TEST(InstanceTraitsUtil, ConvFwdSpecNameReturnsCorrectStrings) +{ + using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; + EXPECT_THAT( + (std::vector{conv_fwd_spec_name(Default), + conv_fwd_spec_name(Filter1x1Stride1Pad0), + conv_fwd_spec_name(Filter1x1Pad0), + conv_fwd_spec_name(Filter3x3), + conv_fwd_spec_name(OddC)}), + ElementsAre("Default", "Filter1x1Stride1Pad0", "Filter1x1Pad0", "Filter3x3", "OddC")); +} + +TEST(InstanceTraitsUtil, GemmSpecNameReturnsCorrectStrings) +{ + using enum ck::tensor_operation::device::GemmSpecialization; + EXPECT_THAT((std::vector{gemm_spec_name(Default), + gemm_spec_name(MPadding), + gemm_spec_name(NPadding), + gemm_spec_name(KPadding), + gemm_spec_name(MNPadding), + gemm_spec_name(MKPadding), + gemm_spec_name(NKPadding), + gemm_spec_name(MNKPadding), + gemm_spec_name(OPadding), + gemm_spec_name(MOPadding), + gemm_spec_name(NOPadding), + gemm_spec_name(KOPadding), + gemm_spec_name(MNOPadding), + gemm_spec_name(MKOPadding), + gemm_spec_name(NKOPadding), + gemm_spec_name(MNKOPadding)}), + ElementsAre("Default", + "MPadding", + "NPadding", + "KPadding", + "MNPadding", + "MKPadding", + "NKPadding", + "MNKPadding", + "OPadding", + "MOPadding", + "NOPadding", + "KOPadding", + "MNOPadding", + "MKOPadding", + "NKOPadding", + "MNKOPadding")); +} + +TEST(InstanceTraitsUtil, PipelineSchedulerNameReturnsCorrectStrings) +{ + using enum ck::BlockGemmPipelineScheduler; + EXPECT_THAT((std::vector{pipeline_scheduler_name(Intrawave), + pipeline_scheduler_name(Interwave)}), + ElementsAre("Intrawave", "Interwave")); +} + +TEST(InstanceTraitsUtil, PipelineVersionNameReturnsCorrectStrings) +{ + using enum ck::BlockGemmPipelineVersion; + EXPECT_THAT((std::vector{pipeline_version_name(v1), + pipeline_version_name(v2), + pipeline_version_name(v3), + pipeline_version_name(v4), + pipeline_version_name(v5)}), + ElementsAre("v1", "v2", "v3", "v4", "v5")); +} + +TEST(InstanceTraitsUtil, TupleNameReturnsEmptyTupleForEmptyTuple) +{ + EXPECT_EQ(tuple_name>(), "EmptyTuple"); +} + +TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForSingleLayout) +{ + EXPECT_EQ(tuple_name>(), "Tuple(NCHW)"); +} + +TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForTwoLayouts) +{ + EXPECT_EQ((tuple_name>()), + "Tuple(NCHW,NHWC)"); +} + +TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForThreeLayouts) +{ + EXPECT_EQ((tuple_name>()), + "Tuple(NCHW,NHWC,NKHW)"); +} + +TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForSingleDataType) +{ + EXPECT_EQ(tuple_name>(), "Tuple(fp16)"); +} + +TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForTwoDataTypes) +{ + EXPECT_EQ((tuple_name>()), "Tuple(fp16,fp32)"); +} + +TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForThreeDataTypes) +{ + EXPECT_EQ((tuple_name>()), "Tuple(fp16,fp32,fp64)"); +} + +TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForEmptySequence) +{ + EXPECT_EQ(sequence_name>(), "Seq()"); +} + +TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForSingleValueSequence) +{ + EXPECT_EQ(sequence_name>(), "Seq(42)"); +} + +TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForTwoValueSequence) +{ + EXPECT_EQ((sequence_name>()), "Seq(1,2)"); +} + +TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForMultipleValueSequence) +{ + EXPECT_EQ((sequence_name>()), "Seq(256,128,64,32,16)"); +} + +} // namespace +} // namespace ck_tile::reflect::detail diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 61d249fc93..4954144aca 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -12,6 +12,8 @@ namespace element_wise { struct Add { + static constexpr const char* name = "Add"; + template __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; @@ -279,6 +281,8 @@ struct Subtract struct Bilinear { + static constexpr const char* name = "Bilinear"; + Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; template @@ -353,6 +357,8 @@ struct Bilinear struct AddClamp { + static constexpr const char* name = "AddClamp"; + AddClamp(float floor = 0.f, float ceil = NumericLimits::Max()) : floor_(floor), ceil_(ceil){}; @@ -442,6 +448,8 @@ struct AddClamp struct AddRelu { + static constexpr const char* name = "AddRelu"; + template __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 089d4c2a9d..5edcdd257b 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -565,6 +565,8 @@ struct NormalizeInInfer // used by Conv+Bias+BatchNorm+Clamp inference struct BiasNormalizeInInferClamp { + static constexpr const char* name = "BiasNormalizeInInferClamp"; + BiasNormalizeInInferClamp(float floor = 0.f, float ceil = NumericLimits::Max(), float epsilon = 1e-4) diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 4643c0bcb3..59292b30e2 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -332,6 +332,8 @@ struct PassThroughPack2 struct PassThrough { + static constexpr const char* name = "PassThrough"; + template __host__ __device__ void operator()(Y& y, const X& x) const; @@ -552,8 +554,6 @@ struct PassThrough { y = type_convert(x); } - - static constexpr const char* name = "PassThrough"; }; struct UnaryConvert @@ -620,6 +620,8 @@ struct ConvertF8RNE struct Scale { + static constexpr const char* name = "Scale"; + __host__ __device__ Scale(float scale = 1.f) : scale_(scale) {} template @@ -783,6 +785,8 @@ struct UnarySqrt struct Clamp { + static constexpr const char* name = "Clamp"; + Clamp(float floor = 0.f, float ceil = NumericLimits::Max()) : floor_(floor), ceil_(ceil){}; @@ -856,6 +860,8 @@ struct Clamp struct Relu { + static constexpr const char* name = "Relu"; + template __host__ __device__ void operator()(T& y, const T& x) const {