Files
composable_kernel/experimental/builder/test/testing_utils.hpp
Robin Voetter aa35585fc7 [CK_BUILDER] conv bwd weight testing (#3618)
* ck-builder: restructure testing conv

In order to prepare for bwd of conv testing, this commit moves some
files and types around so that we can reuse ckt::Args for both forward
and backwards convolution.

* ck-builder: decouple fwd_ck.hpp and fwd_reference.hpp from fwd.hpp

This will allow us to more easily include fwd.hpp from backwards
definitions, which is required for initializing bwd values.

* ck-builder: fix layout of test_ckb_conv_bwd_weight_xdl_cshuffle_v3

Turns out that the supplied layout isn't actually supported...

* ck-builder: ck and reference conv integration for bwd weight

* ck-builder: ck bwd weight execution test

* ck-builder: ckt::run support for ck-tile bwd weight

* ck-builder: ck tile bwd weight execution test

* ck-builder: extra debug printing in MatchesReference

* ck-builder: make ckt::run return RunResult

This type is more convenient than std::tuple, as it will allow us to
use google test matchers with this in the future.

* ck-builder: RunResult matcher

Using EXPECT_THAT(..., SuccessfulRun()) will generate a check and a nice error
message about how and why running an algorithm failed.

* ck-builder: doc fixes

* ck-builder: add missing headers

[ROCm/composable_kernel commit: cc75948d1c]
2026-01-26 23:50:15 +01:00

240 lines
8.5 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <ck/library/tensor_operation_instance/device_operation_instance_factory.hpp>
#include "ck_tile/builder/testing/testing.hpp"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <string>
#include <sstream>
#include <iosfwd>
#include <set>
#include <vector>
#include <array>
/// @brief ostream-overload for hipError
///
/// Google Test likes to print errors to ostream, and this provides integration
/// with that. Since we only expect to use this with CK-Builder's own tests,
/// providing this implementation seems not problematic, but if it starts to
/// clash with another implementation then we will need to provide this
/// implementation another way. Unfortunately Google Test does not have a
/// dedicated function to override to provide printing support.
std::ostream& operator<<(std::ostream& os, hipError_t status);
namespace ck_tile::builder::test {
template <auto SIGNATURE>
std::ostream& operator<<(std::ostream& os, [[maybe_unused]] Outputs<SIGNATURE> outputs)
{
return os << "<tensor outputs>";
}
} // namespace ck_tile::builder::test
namespace ck_tile::test {
static bool isTerminalOutput() { return isatty(fileno(stdout)) || isatty(fileno(stderr)); }
// Returns a string highlighting differences between actual and expected.
// Differences are enclosed in brackets with actual and expected parts separated by '|'.
std::string inlineDiff(const std::string& actual,
const std::string& expected,
bool use_color = isTerminalOutput());
// A convenience alias for inlineDiff to improve readability in test assertions.
// Note that the function has O(n^2) complexity both in compute and in memory - do not use for very
// long strings
std::string formatInlineDiff(const std::string& actual, const std::string& expected);
// Gmock matcher for string equality with inline diff output on failure
class StringEqWithDiffMatcher : public ::testing::MatcherInterface<std::string>
{
public:
explicit StringEqWithDiffMatcher(const std::string& expected);
bool MatchAndExplain(std::string actual,
::testing::MatchResultListener* listener) const override;
void DescribeTo(std::ostream* os) const override;
void DescribeNegationTo(std::ostream* os) const override;
private:
std::string expected_;
};
// Factory function for the StringEqWithDiff matcher
::testing::Matcher<std::string> StringEqWithDiff(const std::string& expected);
using ck::tensor_operation::device::instance::DeviceOperationInstanceFactory;
// This utility concept checks whether a type is a valid "Device Operation" -
// that is, there is a valid specialization of `DeviceOperationInstanceFactory`
// for it available.
template <typename DeviceOp>
concept HasCkFactory = requires {
{
DeviceOperationInstanceFactory<DeviceOp>::GetInstances()
} -> std::convertible_to<std::vector<std::unique_ptr<DeviceOp>>>;
};
// This structure represents a (unique) set of instances, either a statically
// defined one (for testing) or one obtained from DeviceOperationInstanceFactory.
// The idea is that we use this structure as a utility to compare a set of
// instances. Instances are stored in a set so that they can be lexicographically
// compared, this helps generating readable error messages which just contain
// the differenses between sets.
struct InstanceSet
{
explicit InstanceSet() {}
explicit InstanceSet(std::initializer_list<const char*> items)
: instances(items.begin(), items.end())
{
}
template <HasCkFactory DeviceOp>
static InstanceSet from_factory()
{
auto set = InstanceSet();
const auto ops = DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
for(const auto& op : ops)
{
set.instances.insert(op->GetInstanceString());
}
return set;
}
std::set<std::string> instances;
};
std::ostream& operator<<(std::ostream& os, const InstanceSet& set);
// This is a custom Google Test matcher which can be used to compare two sets
// of instance names, with utility functions that print a helpful error
// message about the difference between the checked sets. Use `InstancesMatch`
// to obtain an instance of this type.
struct InstanceMatcher : public ::testing::MatcherInterface<InstanceSet>
{
explicit InstanceMatcher(const InstanceSet& expected);
bool MatchAndExplain(InstanceSet actual,
::testing::MatchResultListener* listener) const override;
void DescribeTo(std::ostream* os) const override;
void DescribeNegationTo(std::ostream* os) const override;
InstanceSet expected_;
};
::testing::Matcher<InstanceSet> InstancesMatch(const InstanceSet& expected);
/// @brief Google Test hipError_t matcher.
///
/// This is a custom Google Test matcher implementation which can be used to
/// compare HIP status codes. Use `HipSuccess()` or `HipError()` to obtain
/// an instance.
///
/// @see HipSuccess
/// @see HipError
/// @see ::testing::MatcherInterface
struct HipStatusMatcher : public ::testing::MatcherInterface<hipError_t>
{
HipStatusMatcher(hipError_t expected) : expected_(expected) {}
bool MatchAndExplain(hipError_t actual,
::testing::MatchResultListener* listener) const override;
void DescribeTo(std::ostream* os) const override;
void DescribeNegationTo(std::ostream* os) const override;
hipError_t expected_;
};
/// @brief Construct a Google Test matcher that checks that a HIP operation
/// was successful.
::testing::Matcher<hipError_t> HipSuccess();
/// @brief Construct a Google Test matcher that checks that a HIP operation
/// returned a particular error code.
///
/// @param error The error to expect.
::testing::Matcher<hipError_t> HipError(hipError_t error);
/// @brief RunResult matcher
///
/// `ckt::run` returns a RunResult which indicates whether there was any
/// problem while running the algorithm. This matcher is used to match those
/// values.
struct RunResultMatcher : public ::testing::MatcherInterface<builder::test::RunResult>
{
bool MatchAndExplain(builder::test::RunResult actual,
::testing::MatchResultListener* listener) const override;
void DescribeTo(std::ostream* os) const override;
void DescribeNegationTo(std::ostream* os) const override;
};
/// @brief Construct a Google Test matcher that checks that a ckt::run result
/// was successful.
::testing::Matcher<builder::test::RunResult> SuccessfulRun();
template <auto SIGNATURE>
struct ReferenceOutputMatcher
: public ::testing::MatcherInterface<builder::test::Outputs<SIGNATURE>>
{
ReferenceOutputMatcher(const builder::test::Args<SIGNATURE>& args,
builder::test::Outputs<SIGNATURE> expected)
: args_(&args), expected_(expected)
{
}
bool MatchAndExplain(builder::test::Outputs<SIGNATURE> actual,
[[maybe_unused]] ::testing::MatchResultListener* listener) const override
{
const auto report = ck_tile::builder::test::validate(*args_, actual, expected_);
const auto errors = report.get_errors();
if(listener->IsInterested() && !errors.empty())
{
*listener << errors.size() << " tensors failed to validate";
for(const auto& e : errors)
{
*listener << "\n - " << e.tensor_name << ": ";
if(e.is_all_zero())
*listener << "all elements in actual and expected tensors are zero";
else
{
// Round to 2 digits
const float percentage = e.wrong_elements * 10000 / e.total_elements / 100.f;
*listener << e.wrong_elements << "/" << e.total_elements
<< " incorrect elements (~" << percentage << "%)";
}
}
}
return errors.empty();
}
void DescribeTo(std::ostream* os) const override { *os << "<tensor outputs>"; }
void DescribeNegationTo(std::ostream* os) const override
{
*os << "isn't equal to <tensor outputs>";
}
const builder::test::Args<SIGNATURE>* args_;
builder::test::Outputs<SIGNATURE> expected_;
};
template <auto SIGNATURE>
::testing::Matcher<builder::test::Outputs<SIGNATURE>>
MatchesReference(const builder::test::Args<SIGNATURE>& args,
builder::test::Outputs<SIGNATURE> expected)
{
return ::testing::MakeMatcher(new ReferenceOutputMatcher<SIGNATURE>(args, expected));
}
} // namespace ck_tile::test