mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_BUILDER] conv bwd weight testing (#3618)
* ck-builder: restructure testing conv In order to prepare for bwd of conv testing, this commit moves some files and types around so that we can reuse ckt::Args for both forward and backwards convolution. * ck-builder: decouple fwd_ck.hpp and fwd_reference.hpp from fwd.hpp This will allow us to more easily include fwd.hpp from backwards definitions, which is required for initializing bwd values. * ck-builder: fix layout of test_ckb_conv_bwd_weight_xdl_cshuffle_v3 Turns out that the supplied layout isn't actually supported... * ck-builder: ck and reference conv integration for bwd weight * ck-builder: ck bwd weight execution test * ck-builder: ckt::run support for ck-tile bwd weight * ck-builder: ck tile bwd weight execution test * ck-builder: extra debug printing in MatchesReference * ck-builder: make ckt::run return RunResult This type is more convenient than std::tuple, as it will allow us to use google test matchers with this in the future. * ck-builder: RunResult matcher Using EXPECT_THAT(..., SuccessfulRun()) will generate a check and a nice error message about how and why running an algorithm failed. * ck-builder: doc fixes * ck-builder: add missing headers
This commit is contained in:
@@ -161,6 +161,23 @@ struct HipStatusMatcher : public ::testing::MatcherInterface<hipError_t>
|
||||
/// @param error The error to expect.
|
||||
::testing::Matcher<hipError_t> HipError(hipError_t error);
|
||||
|
||||
/// @brief RunResult matcher
|
||||
///
|
||||
/// `ckt::run` returns a RunResult which indicates whether there was any
|
||||
/// problem while running the algorithm. This matcher is used to match those
|
||||
/// values.
|
||||
struct RunResultMatcher : public ::testing::MatcherInterface<builder::test::RunResult>
|
||||
{
|
||||
bool MatchAndExplain(builder::test::RunResult actual,
|
||||
::testing::MatchResultListener* listener) const override;
|
||||
void DescribeTo(std::ostream* os) const override;
|
||||
void DescribeNegationTo(std::ostream* os) const override;
|
||||
};
|
||||
|
||||
/// @brief Construct a Google Test matcher that checks that a ckt::run result
|
||||
/// was successful.
|
||||
::testing::Matcher<builder::test::RunResult> SuccessfulRun();
|
||||
|
||||
template <auto SIGNATURE>
|
||||
struct ReferenceOutputMatcher
|
||||
: public ::testing::MatcherInterface<builder::test::Outputs<SIGNATURE>>
|
||||
@@ -180,6 +197,21 @@ struct ReferenceOutputMatcher
|
||||
if(listener->IsInterested() && !errors.empty())
|
||||
{
|
||||
*listener << errors.size() << " tensors failed to validate";
|
||||
|
||||
for(const auto& e : errors)
|
||||
{
|
||||
*listener << "\n - " << e.tensor_name << ": ";
|
||||
|
||||
if(e.is_all_zero())
|
||||
*listener << "all elements in actual and expected tensors are zero";
|
||||
else
|
||||
{
|
||||
// Round to 2 digits
|
||||
const float percentage = e.wrong_elements * 10000 / e.total_elements / 100.f;
|
||||
*listener << e.wrong_elements << "/" << e.total_elements
|
||||
<< " incorrect elements (~" << percentage << "%)";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errors.empty();
|
||||
|
||||
Reference in New Issue
Block a user