// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #include #include "ck_tile/builder/testing/testing.hpp" #include #include #include #include #include #include #include #include /// @brief ostream-overload for hipError /// /// Google Test likes to print errors to ostream, and this provides integration /// with that. Since we only expect to use this with CK-Builder's own tests, /// providing this implementation seems not problematic, but if it starts to /// clash with another implementation then we will need to provide this /// implementation another way. Unfortunately Google Test does not have a /// dedicated function to override to provide printing support. std::ostream& operator<<(std::ostream& os, hipError_t status); namespace ck_tile::builder::test { template std::ostream& operator<<(std::ostream& os, [[maybe_unused]] Outputs outputs) { return os << ""; } } // namespace ck_tile::builder::test namespace ck_tile::test { static bool isTerminalOutput() { return isatty(fileno(stdout)) || isatty(fileno(stderr)); } // Returns a string highlighting differences between actual and expected. // Differences are enclosed in brackets with actual and expected parts separated by '|'. std::string inlineDiff(const std::string& actual, const std::string& expected, bool use_color = isTerminalOutput()); // A convenience alias for inlineDiff to improve readability in test assertions. // Note that the function has O(n^2) complexity both in compute and in memory - do not use for very // long strings std::string formatInlineDiff(const std::string& actual, const std::string& expected); // Gmock matcher for string equality with inline diff output on failure class StringEqWithDiffMatcher : public ::testing::MatcherInterface { public: explicit StringEqWithDiffMatcher(const std::string& expected); bool MatchAndExplain(std::string actual, ::testing::MatchResultListener* listener) const override; void DescribeTo(std::ostream* os) const override; void DescribeNegationTo(std::ostream* os) const override; private: std::string expected_; }; // Factory function for the StringEqWithDiff matcher ::testing::Matcher StringEqWithDiff(const std::string& expected); using ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; // This utility concept checks whether a type is a valid "Device Operation" - // that is, there is a valid specialization of `DeviceOperationInstanceFactory` // for it available. template concept HasCkFactory = requires { { DeviceOperationInstanceFactory::GetInstances() } -> std::convertible_to>>; }; // This structure represents a (unique) set of instances, either a statically // defined one (for testing) or one obtained from DeviceOperationInstanceFactory. // The idea is that we use this structure as a utility to compare a set of // instances. Instances are stored in a set so that they can be lexicographically // compared, this helps generating readable error messages which just contain // the differenses between sets. struct InstanceSet { explicit InstanceSet() {} explicit InstanceSet(std::initializer_list items) : instances(items.begin(), items.end()) { } template static InstanceSet from_factory() { auto set = InstanceSet(); const auto ops = DeviceOperationInstanceFactory::GetInstances(); for(const auto& op : ops) { set.instances.insert(op->GetInstanceString()); } return set; } std::set instances; }; std::ostream& operator<<(std::ostream& os, const InstanceSet& set); // This is a custom Google Test matcher which can be used to compare two sets // of instance names, with utility functions that print a helpful error // message about the difference between the checked sets. Use `InstancesMatch` // to obtain an instance of this type. struct InstanceMatcher : public ::testing::MatcherInterface { explicit InstanceMatcher(const InstanceSet& expected); bool MatchAndExplain(InstanceSet actual, ::testing::MatchResultListener* listener) const override; void DescribeTo(std::ostream* os) const override; void DescribeNegationTo(std::ostream* os) const override; InstanceSet expected_; }; ::testing::Matcher InstancesMatch(const InstanceSet& expected); /// @brief Google Test hipError_t matcher. /// /// This is a custom Google Test matcher implementation which can be used to /// compare HIP status codes. Use `HipSuccess()` or `HipError()` to obtain /// an instance. /// /// @see HipSuccess /// @see HipError /// @see ::testing::MatcherInterface struct HipStatusMatcher : public ::testing::MatcherInterface { HipStatusMatcher(hipError_t expected) : expected_(expected) {} bool MatchAndExplain(hipError_t actual, ::testing::MatchResultListener* listener) const override; void DescribeTo(std::ostream* os) const override; void DescribeNegationTo(std::ostream* os) const override; hipError_t expected_; }; /// @brief Construct a Google Test matcher that checks that a HIP operation /// was successful. ::testing::Matcher HipSuccess(); /// @brief Construct a Google Test matcher that checks that a HIP operation /// returned a particular error code. /// /// @param error The error to expect. ::testing::Matcher HipError(hipError_t error); /// @brief RunResult matcher /// /// `ckt::run` returns a RunResult which indicates whether there was any /// problem while running the algorithm. This matcher is used to match those /// values. struct RunResultMatcher : public ::testing::MatcherInterface { bool MatchAndExplain(builder::test::RunResult actual, ::testing::MatchResultListener* listener) const override; void DescribeTo(std::ostream* os) const override; void DescribeNegationTo(std::ostream* os) const override; }; /// @brief Construct a Google Test matcher that checks that a ckt::run result /// was successful. ::testing::Matcher SuccessfulRun(); template struct ReferenceOutputMatcher : public ::testing::MatcherInterface> { ReferenceOutputMatcher(const builder::test::Args& args, builder::test::Outputs expected) : args_(&args), expected_(expected) { } bool MatchAndExplain(builder::test::Outputs actual, [[maybe_unused]] ::testing::MatchResultListener* listener) const override { const auto report = ck_tile::builder::test::validate(*args_, actual, expected_); const auto errors = report.get_errors(); if(listener->IsInterested() && !errors.empty()) { *listener << errors.size() << " tensors failed to validate"; for(const auto& e : errors) { *listener << "\n - " << e.tensor_name << ": "; if(e.is_all_zero()) *listener << "all elements in actual and expected tensors are zero"; else { // Round to 2 digits const float percentage = e.wrong_elements * 10000 / e.total_elements / 100.f; *listener << e.wrong_elements << "/" << e.total_elements << " incorrect elements (~" << percentage << "%)," << " max error " << e.max_error; } } } return errors.empty(); } void DescribeTo(std::ostream* os) const override { *os << ""; } void DescribeNegationTo(std::ostream* os) const override { *os << "isn't equal to "; } const builder::test::Args* args_; builder::test::Outputs expected_; }; template ::testing::Matcher> MatchesReference(const builder::test::Args& args, builder::test::Outputs expected) { return ::testing::MakeMatcher(new ReferenceOutputMatcher(args, expected)); } } // namespace ck_tile::test