diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index 90e42528e1..01333833dd 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -40,6 +40,8 @@ consteval std::string_view type_name() return "fp16"; else if constexpr(std::is_same_v) return "fp32"; + else if constexpr(std::is_same_v) + return "tf32"; else if constexpr(std::is_same_v) return "fp64"; else if constexpr(std::is_same_v) @@ -60,40 +62,25 @@ consteval std::string_view type_name() template constexpr std::string_view layout_name() { - if constexpr(requires { + if constexpr(std::is_base_of_v && requires { { T::name } -> std::convertible_to; }) return T::name; else - static_assert(false, "layout type is missing name attribute"); + static_assert(false, + "Layout type must derive from BaseTensorLayout and have name attribute"); } // Convert element-wise operation types to string names template constexpr std::string_view elementwise_op_name() { - namespace element_wise = ck::tensor_operation::element_wise; - - if constexpr(std::is_same_v) - return "PassThrough"; - else if constexpr(std::is_same_v) - return "Scale"; - else if constexpr(std::is_same_v) - return "Bilinear"; - else if constexpr(std::is_same_v) - return "Add"; - else if constexpr(std::is_same_v) - return "AddRelu"; - else if constexpr(std::is_same_v) - return "Relu"; - else if constexpr(std::is_same_v) - return "BiasNormalizeInInferClamp"; - else if constexpr(std::is_same_v) - return "Clamp"; - else if constexpr(std::is_same_v) - return "AddClamp"; + if constexpr(requires { + { T::name } -> std::convertible_to; + }) + return T::name; else - static_assert(false, "unknown_op"); + static_assert(false, "Elementwise operation is missing name attribute"); } // Convert ConvolutionForwardSpecialization enum to string diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 04b63b7823..f77219d019 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -17,7 +17,10 @@ endfunction() add_ck_builder_test(test_conv_builder test_conv_builder.cpp - test_instance_traits.cpp) + test_instance_traits.cpp + testing_utils.cpp) add_ck_builder_test(test_get_instance_string test_get_instance_string.cpp) + +add_ck_builder_test(test_inline_diff test_inline_diff.cpp testing_utils.cpp) diff --git a/experimental/builder/test/test_inline_diff.cpp b/experimental/builder/test/test_inline_diff.cpp new file mode 100644 index 0000000000..41692fb40e --- /dev/null +++ b/experimental/builder/test/test_inline_diff.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "testing_utils.hpp" + +namespace ck_tile::builder { +namespace { + +TEST(InlineDiff, simpleColorDiff) +{ + std::string str1{"hello"}; + std::string str2{"hello"}; + std::string str3{"world"}; + + // some easy tests + // you can veryfy the ungodly strings are meaningful by running echo -e "" + EXPECT_THAT(test::inlineDiff(str1, str2, true), "hello"); + EXPECT_THAT(test::inlineDiff(str1, str3, true), + "[\x1B[36mwor\x1B[0m|\x1B[35mhel\x1B[0m]l[\x1B[36md\x1B[0m|\x1B[35mo\x1B[0m]"); +} + +TEST(InlineDiff, noColorDiff) +{ + std::string str1{"hello"}; + std::string str2{"hello"}; + std::string str3{"world"}; + + // some easy tests without color + EXPECT_THAT(test::inlineDiff(str1, str2, false), "hello"); + EXPECT_THAT(test::inlineDiff(str1, str3, false), "[wor|hel]l[d|o]"); +} + +TEST(InlineDiff, complexColorDiff) +{ + + // now something more interesting + std::string str4{"this part has changed, this part has been left out, this part, this part has " + "an extra letter"}; + std::string str5{ + "this part has degeahc, this part has, this part added, this part has ana extra letter"}; + + EXPECT_THAT( + test::inlineDiff(str5, str4, true), + "this part has [\x1B[36mchanged\x1B[0m|\x1B[35mdegeahc\x1B[0m], this part has[\x1B[36m " + "been left out\x1B[0m|\x1B[35m\x1B[0m], this part[\x1B[36m\x1B[0m|\x1B[35m added\x1B[0m], " + "this part has an[\x1B[36m\x1B[0m|\x1B[35ma\x1B[0m] extra letter"); +}; + +} // namespace +} // namespace ck_tile::builder diff --git a/experimental/builder/test/testing_utils.cpp b/experimental/builder/test/testing_utils.cpp new file mode 100644 index 0000000000..c99d56ef56 --- /dev/null +++ b/experimental/builder/test/testing_utils.cpp @@ -0,0 +1,219 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include +#include "testing_utils.hpp" + +namespace ck_tile::test { + +namespace { + +} // namespace + +// Wagner-Fischer Algorithm for Computing Edit Distance and Inline Diff +// +// OUTPUT FORMAT: [expected|actual] for differences, plain text for matches +// Example: "hello world" vs "hello earth" → "hello [world|earth]" +// +// This function implements the Wagner-Fischer algorithm (1974), which is the classic +// dynamic programming solution for computing the minimum edit distance (Levenshtein distance) +// between two strings. The algorithm has O(n*m) time and space complexity. +// +// ALGORITHM OVERVIEW: +// 1. Build a 2D DP table where dp[i][j] represents the minimum edit distance +// between the first i characters of 'expected' and first j characters of 'actual' +// 2. Fill the table using the recurrence relation: +// dp[i][j] = min( +// dp[i-1][j] + 1, // deletion (remove char from expected) +// dp[i][j-1] + 1, // insertion (add char to expected) +// dp[i-1][j-1] + cost // substitution (cost=0 if chars match, 1 if different) +// ) +// 3. Backtrack through the table to reconstruct the optimal edit sequence +// +// REFERENCES: +// - Wagner, R. A.; Fischer, M. J. (1974). "The String-to-String Correction Problem" +// - Also known as: Levenshtein distance, edit distance, string alignment +// - Similar to sequence alignment algorithms used in bioinformatics (Needleman-Wunsch) +std::string inlineDiff(const std::string& actual, const std::string& expected, bool use_color) +{ + + const char* EXPECTED_COLOR = use_color ? "\033[36m" : ""; // Cyan + const char* ACTUAL_COLOR = use_color ? "\033[35m" : ""; // Magenta + const char* RESET = use_color ? "\033[0m" : ""; + + const size_t n = expected.length(); // Length of expected string + const size_t m = actual.length(); // Length of actual string + + // PHASE 1: Build the Dynamic Programming Table + // dp[i][j] = minimum edit distance between expected[0..i-1] and actual[0..j-1] + std::vector> dp(n + 1, std::vector(m + 1)); + + // Base cases: transforming empty string to/from prefixes + for(size_t i = 0; i <= n; ++i) + { + dp[i][0] = i; // Delete i characters from expected to get empty string + } + for(size_t j = 0; j <= m; ++j) + { + dp[0][j] = j; // Insert j characters to empty string to get actual[0..j-1] + } + + // Fill the DP table using the Wagner-Fischer recurrence relation + for(size_t i = 1; i <= n; ++i) + { + for(size_t j = 1; j <= m; ++j) + { + // Cost is 0 if characters match, 1 if they need substitution + int cost = (expected[i - 1] == actual[j - 1]) ? 0 : 1; + + // Choose the minimum cost operation: + dp[i][j] = std::min({ + dp[i - 1][j] + 1, // Deletion: remove expected[i-1] + dp[i][j - 1] + 1, // Insertion: add actual[j-1] + dp[i - 1][j - 1] + cost // Substitution/Match + }); + } + } + + // PHASE 2: Backtrack to Reconstruct the Optimal Edit Sequence + // We trace back from dp[n][m] to dp[0][0] to find which operations were used + std::vector operations; // 'M'atch, 'S'ubstitution, 'I'nsertion, 'D'eletion + std::vector> diff_chars; // Character pairs for each operation + + size_t i = n, j = m; // Start from bottom-right corner of DP table + while(i > 0 || j > 0) + { + // Determine which operation led to the current cell's value + int cost = (i > 0 && j > 0 && expected[i - 1] == actual[j - 1]) ? 0 : 1; + + // Check if we came from diagonal (substitution/match) + if(i > 0 && j > 0 && dp[i][j] == dp[i - 1][j - 1] + cost) + { + if(cost == 0) + { + operations.push_back('M'); // Characters match + diff_chars.push_back({expected[i - 1], actual[j - 1]}); + } + else + { + operations.push_back('S'); // Substitution needed + diff_chars.push_back({expected[i - 1], actual[j - 1]}); + } + --i; + --j; // Move diagonally up-left + } + // Check if we came from left (insertion) + else if(j > 0 && dp[i][j] == dp[i][j - 1] + 1) + { + operations.push_back('I'); // Insertion: actual has extra character + diff_chars.push_back({'\0', actual[j - 1]}); + --j; // Move left + } + // Must have come from above (deletion) + else if(i > 0 && dp[i][j] == dp[i - 1][j] + 1) + { + operations.push_back('D'); // Deletion: expected has extra character + diff_chars.push_back({expected[i - 1], '\0'}); + --i; // Move up + } + } + + // PHASE 3: Reverse and Build the Human-Readable Diff String + // Backtracking gives us operations in reverse order, so we reverse to get forward order + std::reverse(operations.begin(), operations.end()); + std::reverse(diff_chars.begin(), diff_chars.end()); + + // Build the final diff string with color highlighting + std::ostringstream diff; + std::string expected_diff, actual_diff; // Accumulate consecutive differences + bool in_diff = false; // Track whether we're inside a diff section + + for(size_t k = 0; k < operations.size(); ++k) + { + char op = operations[k]; + char exp_char = diff_chars[k].first; // Expected character ('\0' for insertions) + char act_char = diff_chars[k].second; // Actual character ('\0' for deletions) + + if(op == 'M') // Match - characters are identical + { + if(in_diff) + { + // Close the current diff section and output it + diff << "[" << EXPECTED_COLOR << expected_diff << RESET << "|" << ACTUAL_COLOR + << actual_diff << RESET << "]"; + expected_diff.clear(); + actual_diff.clear(); + in_diff = false; + } + diff << exp_char; // Output the matching character as-is + } + else // Difference (substitution, insertion, or deletion) + { + in_diff = true; + // Accumulate characters for the diff section + if(exp_char != '\0') + expected_diff += exp_char; // Add to expected side + if(act_char != '\0') + actual_diff += act_char; // Add to actual side + } + } + + // Close any remaining diff section at the end + if(in_diff) + { + diff << "[" << EXPECTED_COLOR << expected_diff << RESET << "|" << ACTUAL_COLOR + << actual_diff << RESET << "]"; + } + + return diff.str(); +} + +std::string formatInlineDiff(const std::string& actual, const std::string& expected) +{ + return std::string("Inline diff: \"") + inlineDiff(actual, expected) + "\""; +} + +// StringEqWithDiffMatcher implementation +StringEqWithDiffMatcher::StringEqWithDiffMatcher(const std::string& expected) : expected_(expected) +{ +} + +bool StringEqWithDiffMatcher::MatchAndExplain(std::string actual, + ::testing::MatchResultListener* listener) const +{ + if(actual == expected_) + { + return true; + } + + // On failure, provide detailed diff information + if(listener->IsInterested()) + { + *listener << "\n Diff: \"" << inlineDiff(actual, expected_) << "\""; + } + return false; +} + +void StringEqWithDiffMatcher::DescribeTo(std::ostream* os) const +{ + *os << "\"" << expected_ << "\""; +} + +void StringEqWithDiffMatcher::DescribeNegationTo(std::ostream* os) const +{ + *os << "is not equal to \"" << expected_ << "\""; +} + +// Factory function for the StringEqWithDiff matcher +::testing::Matcher StringEqWithDiff(const std::string& expected) +{ + return ::testing::MakeMatcher(new StringEqWithDiffMatcher(expected)); +} + +} // namespace ck_tile::test diff --git a/experimental/builder/test/testing_utils.hpp b/experimental/builder/test/testing_utils.hpp new file mode 100644 index 0000000000..3e8772a080 --- /dev/null +++ b/experimental/builder/test/testing_utils.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +namespace ck_tile::test { + +static bool isTerminalOutput() { return isatty(fileno(stdout)) || isatty(fileno(stderr)); } + +// Returns a string highlighting differences between actual and expected. +// Differences are enclosed in brackets with actual and expected parts separated by '|'. +std::string inlineDiff(const std::string& actual, + const std::string& expected, + bool use_color = isTerminalOutput()); + +// A convenience alias for inlineDiff to improve readability in test assertions. +// Note that the function has O(n^2) complexity both in compute and in memory - do not use for very +// long strings +std::string formatInlineDiff(const std::string& actual, const std::string& expected); + +// Gmock matcher for string equality with inline diff output on failure +class StringEqWithDiffMatcher : public ::testing::MatcherInterface +{ + public: + explicit StringEqWithDiffMatcher(const std::string& expected); + + bool MatchAndExplain(std::string actual, + ::testing::MatchResultListener* listener) const override; + + void DescribeTo(std::ostream* os) const override; + void DescribeNegationTo(std::ostream* os) const override; + + private: + std::string expected_; +}; + +// Factory function for the StringEqWithDiff matcher +::testing::Matcher StringEqWithDiff(const std::string& expected); + +} // namespace ck_tile::test diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index ea8ba4557e..c6f2db639c 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -349,6 +349,8 @@ CK_TILE_DEVICE bf8x4_t i4_to_bf8x4(int q) struct PassThroughPack8 { + static constexpr const char* name = "PassThroughPack8"; + template CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const; @@ -388,6 +390,8 @@ struct PassThroughPack8 struct DequantPack8 { + static constexpr const char* name = "DequantPack8"; + template CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x, const Z& z) const; @@ -403,6 +407,8 @@ struct DequantPack8 struct PassThroughPack2 { + static constexpr const char* name = "PassThroughPack2"; + template CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const; @@ -429,6 +435,8 @@ struct PassThroughPack2 struct PassThrough { + static constexpr const char* name = "PassThrough"; + template using raw_t = std::remove_cv_t>; @@ -465,6 +473,8 @@ struct PassThrough struct AddScale { + static constexpr const char* name = "AddScale"; + template CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const As&... as) const { @@ -482,6 +492,8 @@ struct AddScale struct MultiDMultiply { + static constexpr const char* name = "MultiDMultiply"; + template CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void { @@ -497,6 +509,8 @@ struct MultiDMultiply struct MultiDAdd { + static constexpr const char* name = "MultiDAdd"; + template CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void { @@ -512,6 +526,8 @@ struct MultiDAdd struct UnaryConvert { + static constexpr const char* name = "UnaryConvert"; + template CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const { @@ -576,6 +592,8 @@ struct ConvertF8RNE struct Scale { + static constexpr const char* name = "Scale"; + CK_TILE_HOST_DEVICE Scale(float scale = 1.f) : scale_(scale) {} template @@ -623,6 +641,8 @@ struct Scale struct ScaleAndResetNaNToMinusInfinity { + static constexpr const char* name = "ScaleAndResetNaNToMinusInfinity"; + CK_TILE_HOST_DEVICE ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {} template @@ -639,6 +659,8 @@ struct ScaleAndResetNaNToMinusInfinity struct UnaryDivide { + static constexpr const char* name = "UnaryDivide"; + CK_TILE_HOST_DEVICE UnaryDivide(const int32_t divider = 1) : divider_(divider) {} template @@ -656,6 +678,8 @@ struct UnaryDivide struct UnarySquare { + static constexpr const char* name = "UnarySquare"; + template CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const { @@ -673,6 +697,8 @@ struct UnarySquare struct UnaryAbs { + static constexpr const char* name = "UnaryAbs"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -687,6 +713,8 @@ struct UnaryAbs struct UnarySqrt { + static constexpr const char* name = "UnarySqrt"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -699,6 +727,8 @@ struct UnarySqrt struct Relu { + static constexpr const char* name = "Relu"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -725,6 +755,8 @@ struct Relu // gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function struct FastGelu { + static constexpr const char* name = "FastGelu"; + template CK_TILE_HOST void operator()(Y& y, const X& x) const; @@ -842,6 +874,8 @@ struct FastGelu struct FastGeluAsm { + static constexpr const char* name = "FastGeluAsm"; + template CK_TILE_HOST void operator()(Y& y, const X& x) const; @@ -943,6 +977,8 @@ struct FastGeluAsm // y = 0.5*x*(1+erf(x/sqrt(2))) struct Gelu { + static constexpr const char* name = "Gelu"; + template CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const; @@ -963,6 +999,8 @@ struct Gelu struct Sigmoid { + static constexpr const char* name = "Sigmoid"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -977,6 +1015,8 @@ struct Sigmoid struct Silu { + static constexpr const char* name = "Silu"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1066,6 +1106,8 @@ struct SiluAsm struct TanH { + static constexpr const char* name = "TanH"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1080,6 +1122,8 @@ struct TanH struct ACos { + static constexpr const char* name = "ACos"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1094,6 +1138,8 @@ struct ACos struct Neg { + static constexpr const char* name = "Neg"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1108,6 +1154,8 @@ struct Neg struct ATan { + static constexpr const char* name = "ATan"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1122,6 +1170,8 @@ struct ATan struct Sin { + static constexpr const char* name = "Sin"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1136,6 +1186,8 @@ struct Sin struct ASinH { + static constexpr const char* name = "ASinH"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1150,6 +1202,8 @@ struct ASinH struct Cos { + static constexpr const char* name = "Cos"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1164,6 +1218,8 @@ struct Cos struct ACosH { + static constexpr const char* name = "ACosH"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1178,6 +1234,8 @@ struct ACosH struct Tan { + static constexpr const char* name = "Tan"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1192,6 +1250,8 @@ struct Tan struct ATanH { + static constexpr const char* name = "ATanH"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1206,6 +1266,8 @@ struct ATanH struct SinH { + static constexpr const char* name = "SinH"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1220,6 +1282,8 @@ struct SinH struct Ceil { + static constexpr const char* name = "Ceil"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1234,6 +1298,8 @@ struct Ceil struct Exp { + static constexpr const char* name = "Exp"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1248,6 +1314,8 @@ struct Exp struct CosH { + static constexpr const char* name = "CosH"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1262,6 +1330,8 @@ struct CosH struct Floor { + static constexpr const char* name = "Floor"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1276,6 +1346,8 @@ struct Floor struct Log { + static constexpr const char* name = "Log"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1290,6 +1362,8 @@ struct Log struct ASin { + static constexpr const char* name = "ASin"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1304,6 +1378,8 @@ struct ASin struct Rcp { + static constexpr const char* name = "Rcp"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1318,6 +1394,8 @@ struct Rcp struct Swish { + static constexpr const char* name = "Swish"; + Swish(float beta = 1.0f) : beta_(beta) {} template @@ -1340,6 +1418,8 @@ struct Swish struct SoftRelu { + static constexpr const char* name = "SoftRelu"; + SoftRelu(float alpha = 1.f) : alpha_(alpha){}; template @@ -1358,6 +1438,8 @@ struct SoftRelu struct Power { + static constexpr const char* name = "Power"; + Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) : alpha_(alpha), beta_(beta), gamma_(gamma){}; @@ -1381,6 +1463,8 @@ struct Power struct ClippedRelu { + static constexpr const char* name = "ClippedRelu"; + ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; template @@ -1400,6 +1484,8 @@ struct ClippedRelu struct LeakyRelu { + static constexpr const char* name = "LeakyRelu"; + LeakyRelu(float alpha = 0.01f) : alpha_(alpha){}; template @@ -1417,6 +1503,8 @@ struct LeakyRelu struct Elu { + static constexpr const char* name = "Elu"; + Elu(float alpha = 1.f) : alpha_(alpha){}; template @@ -1434,6 +1522,8 @@ struct Elu struct Logistic { + static constexpr const char* name = "Logistic"; + Logistic(float alpha = 1.f) : alpha_(alpha){}; template @@ -1452,6 +1542,8 @@ struct Logistic struct ConvInvscale { + static constexpr const char* name = "ConvInvscale"; + CK_TILE_HOST_DEVICE ConvInvscale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f) : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) @@ -1475,6 +1567,8 @@ struct ConvInvscale struct ConvScale { + static constexpr const char* name = "ConvScale"; + CK_TILE_HOST_DEVICE ConvScale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f) : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) @@ -1498,6 +1592,8 @@ struct ConvScale struct ConvScaleRelu { + static constexpr const char* name = "ConvScaleRelu"; + CK_TILE_HOST_DEVICE ConvScaleRelu(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f) : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) @@ -1524,6 +1620,8 @@ struct ConvScaleRelu template struct Cast { + static constexpr const char* name = "Cast"; + template CK_TILE_HOST_DEVICE void operator()(DstType& y, const SrcType& x) const {