Files
composable_kernel/experimental/builder/test/testing_utils.cpp
Robin Voetter cc75948d1c [CK_BUILDER] conv bwd weight testing (#3618)
* ck-builder: restructure testing conv

In order to prepare for bwd of conv testing, this commit moves some
files and types around so that we can reuse ckt::Args for both forward
and backwards convolution.

* ck-builder: decouple fwd_ck.hpp and fwd_reference.hpp from fwd.hpp

This will allow us to more easily include fwd.hpp from backwards
definitions, which is required for initializing bwd values.

* ck-builder: fix layout of test_ckb_conv_bwd_weight_xdl_cshuffle_v3

Turns out that the supplied layout isn't actually supported...

* ck-builder: ck and reference conv integration for bwd weight

* ck-builder: ck bwd weight execution test

* ck-builder: ckt::run support for ck-tile bwd weight

* ck-builder: ck tile bwd weight execution test

* ck-builder: extra debug printing in MatchesReference

* ck-builder: make ckt::run return RunResult

This type is more convenient than std::tuple, as it will allow us to
use google test matchers with this in the future.

* ck-builder: RunResult matcher

Using EXPECT_THAT(..., SuccessfulRun()) will generate a check and a nice error
message about how and why running an algorithm failed.

* ck-builder: doc fixes

* ck-builder: add missing headers
2026-01-26 23:50:15 +01:00

361 lines
12 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "testing_utils.hpp"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <unistd.h>
#include <string>
#include <sstream>
#include <ostream>
#include <vector>
#include <algorithm>
// Stream a HIP status code as the text returned by hipGetErrorString().
// Returns the same stream so that insertions can be chained.
std::ostream& operator<<(std::ostream& os, hipError_t status)
{
    os << hipGetErrorString(status);
    return os;
}
namespace ck_tile::test {
// Wagner-Fischer Algorithm for Computing Edit Distance and Inline Diff
//
// OUTPUT FORMAT: [expected|actual] for differences, plain text for matches.
// Example: expected "hello world" vs actual "hello earth" -> "hello [wo|ea]r[ld|th]"
// (the shared 'r' is an optimal match, so the differing runs on either side of it
// are reported as two separate sections).
//
// This function implements the Wagner-Fischer algorithm (1974), the classic
// dynamic programming solution for computing the minimum edit distance
// (Levenshtein distance) between two strings. Both time and space are O(n*m),
// where n = expected.length() and m = actual.length().
//
// ALGORITHM OVERVIEW:
// 1. Build a 2D DP table where dp[i][j] is the minimum edit distance between the
//    first i characters of 'expected' and the first j characters of 'actual'.
// 2. Fill the table using the recurrence relation:
//      dp[i][j] = min(
//          dp[i-1][j] + 1,        // deletion (remove char from expected)
//          dp[i][j-1] + 1,        // insertion (add char to expected)
//          dp[i-1][j-1] + cost    // substitution (cost=0 if chars match, 1 if different)
//      )
// 3. Backtrack through the table to reconstruct one optimal edit sequence. Ties are
//    resolved in the order diagonal, then insertion, then deletion.
//
// PARAMETERS:
//   actual    - the string that was produced (right-hand side of each [..|..] section)
//   expected  - the string that was wanted (left-hand side of each [..|..] section)
//   use_color - when true, the expected side is wrapped in the ANSI cyan escape and the
//               actual side in the ANSI magenta escape, each followed by a reset
//
// RETURNS: the rendered diff; identical inputs yield the input unchanged.
//
// REFERENCES:
// - Wagner, R. A.; Fischer, M. J. (1974). "The String-to-String Correction Problem"
// - Also known as: Levenshtein distance, edit distance, string alignment
// - Similar to sequence alignment algorithms used in bioinformatics (Needleman-Wunsch)
std::string inlineDiff(const std::string& actual, const std::string& expected, bool use_color)
{
    const char* EXPECTED_COLOR = use_color ? "\033[36m" : ""; // Cyan
    const char* ACTUAL_COLOR   = use_color ? "\033[35m" : ""; // Magenta
    const char* RESET          = use_color ? "\033[0m" : "";

    const size_t n = expected.length(); // Length of expected string
    const size_t m = actual.length();   // Length of actual string

    // PHASE 1: Build the Dynamic Programming Table
    // dp[i][j] = minimum edit distance between expected[0..i-1] and actual[0..j-1].
    // The table holds size_t so string lengths are stored without narrowing.
    std::vector<std::vector<size_t>> dp(n + 1, std::vector<size_t>(m + 1));

    // Base cases: transforming to/from the empty prefix
    for(size_t i = 0; i <= n; ++i)
    {
        dp[i][0] = i; // Delete i characters from expected to get the empty string
    }
    for(size_t j = 0; j <= m; ++j)
    {
        dp[0][j] = j; // Insert j characters to get actual[0..j-1]
    }

    // Fill the DP table using the Wagner-Fischer recurrence relation
    for(size_t i = 1; i <= n; ++i)
    {
        for(size_t j = 1; j <= m; ++j)
        {
            // Cost is 0 if characters match, 1 if they need substitution
            const size_t cost = (expected[i - 1] == actual[j - 1]) ? 0 : 1;
            dp[i][j]          = std::min({
                dp[i - 1][j] + 1,       // Deletion: remove expected[i-1]
                dp[i][j - 1] + 1,       // Insertion: add actual[j-1]
                dp[i - 1][j - 1] + cost // Substitution/Match
            });
        }
    }

    // PHASE 2: Backtrack to Reconstruct the Optimal Edit Sequence
    // We trace back from dp[n][m] to dp[0][0] to find which operations were used.
    std::vector<char> operations; // 'M'atch, 'S'ubstitution, 'I'nsertion, 'D'eletion
    std::vector<std::pair<char, char>> diff_chars; // (expected, actual) character per operation
    size_t i = n;
    size_t j = m; // Start from the bottom-right corner of the DP table
    while(i > 0 || j > 0)
    {
        const bool can_go_diagonal = (i > 0 && j > 0);
        const size_t cost = (can_go_diagonal && expected[i - 1] == actual[j - 1]) ? 0 : 1;

        if(can_go_diagonal && dp[i][j] == dp[i - 1][j - 1] + cost)
        {
            // Came from the diagonal: match (cost 0) or substitution (cost 1)
            operations.push_back(cost == 0 ? 'M' : 'S');
            diff_chars.push_back({expected[i - 1], actual[j - 1]});
            --i;
            --j;
        }
        else if(j > 0 && dp[i][j] == dp[i][j - 1] + 1)
        {
            // Came from the left: actual has an extra character.
            // The first element of the pair is a placeholder and is never rendered.
            operations.push_back('I');
            diff_chars.push_back({'\0', actual[j - 1]});
            --j;
        }
        else
        {
            // Came from above: expected has an extra character.
            // i > 0 is guaranteed here: with i == 0 the insertion branch always applies
            // (dp[0][j] == dp[0][j-1] + 1), and with j == 0 the loop condition implies
            // i > 0. Using a plain else means every iteration makes progress.
            operations.push_back('D');
            diff_chars.push_back({expected[i - 1], '\0'});
            --i;
        }
    }

    // PHASE 3: Reverse and Build the Human-Readable Diff String
    // Backtracking yields operations in reverse order, so reverse to get forward order.
    std::reverse(operations.begin(), operations.end());
    std::reverse(diff_chars.begin(), diff_chars.end());

    std::ostringstream diff;
    std::string expected_diff; // Accumulated expected side of the current diff section
    std::string actual_diff;   // Accumulated actual side of the current diff section
    bool in_diff = false;      // Whether we are inside a diff section

    for(size_t k = 0; k < operations.size(); ++k)
    {
        const char op       = operations[k];
        const char exp_char = diff_chars[k].first;
        const char act_char = diff_chars[k].second;

        if(op == 'M') // Match - characters are identical
        {
            if(in_diff)
            {
                // Close the current diff section and output it
                diff << "[" << EXPECTED_COLOR << expected_diff << RESET << "|" << ACTUAL_COLOR
                     << actual_diff << RESET << "]";
                expected_diff.clear();
                actual_diff.clear();
                in_diff = false;
            }
            diff << exp_char; // Output the matching character as-is
        }
        else // Difference (substitution, insertion, or deletion)
        {
            in_diff = true;
            // Decide which side a character belongs to from the operation itself rather
            // than from a '\0' sentinel, so that genuine NUL characters embedded in the
            // inputs are not silently dropped from the diff.
            if(op != 'I')
                expected_diff += exp_char;
            if(op != 'D')
                actual_diff += act_char;
        }
    }

    // Close any remaining diff section at the end
    if(in_diff)
    {
        diff << "[" << EXPECTED_COLOR << expected_diff << RESET << "|" << ACTUAL_COLOR
             << actual_diff << RESET << "]";
    }
    return diff.str();
}
// Wrap the inline diff of `actual` against `expected` in a labelled, double-quoted
// string of the form: Inline diff: "<diff>"
std::string formatInlineDiff(const std::string& actual, const std::string& expected)
{
    std::string message("Inline diff: \"");
    message += inlineDiff(actual, expected);
    message += "\"";
    return message;
}
// StringEqWithDiffMatcher implementation
//
// Constructor: stores a copy of the string that matched values are compared against.
StringEqWithDiffMatcher::StringEqWithDiffMatcher(const std::string& expected) : expected_(expected)
{
}
// Returns true when `actual` equals the stored expected string. On a mismatch, and only
// if the listener is interested, an inline diff of the two strings is written to it.
bool StringEqWithDiffMatcher::MatchAndExplain(std::string actual,
                                              ::testing::MatchResultListener* listener) const
{
    const bool is_equal = (actual == expected_);
    if(!is_equal && listener->IsInterested())
    {
        // Explain the failure with a character-level diff
        *listener << "\n Diff: \"" << inlineDiff(actual, expected_) << "\"";
    }
    return is_equal;
}
// Describes the matcher by writing the expected string in double quotes.
void StringEqWithDiffMatcher::DescribeTo(std::ostream* os) const
{
    std::ostream& out = *os;
    out << "\"" << expected_ << "\"";
}
// Describes the negated matcher: "is not equal to" followed by the quoted expected string.
void StringEqWithDiffMatcher::DescribeNegationTo(std::ostream* os) const
{
    std::ostream& out = *os;
    out << "is not equal to \"" << expected_ << "\"";
}
// Factory for the StringEqWithDiff matcher: builds a StringEqWithDiffMatcher for
// `expected` and wraps it in a ::testing::Matcher. The heap-allocated implementation is
// handed to MakeMatcher (which, per GoogleTest's documented contract, takes ownership).
::testing::Matcher<std::string> StringEqWithDiff(const std::string& expected)
{
    auto* impl = new StringEqWithDiffMatcher(expected);
    return ::testing::MakeMatcher(impl);
}
// Stream a short summary of an InstanceSet.
// These sets can grow very large, so printing every element on a mismatch is neither
// pleasant nor useful. Only the element count is written here; InstanceMatcher's
// MatchAndExplain reports the individual differences.
std::ostream& operator<<(std::ostream& os, const InstanceSet& set)
{
    os << "(set of " << set.instances.size() << " instances)";
    return os;
}
// Constructor: stores a copy of the instance set that matched values are compared against.
InstanceMatcher::InstanceMatcher(const InstanceSet& expected) : expected_(expected) {}
// Factory for the InstancesMatch matcher: builds an InstanceMatcher for `expected` and
// wraps it in a ::testing::Matcher. The heap-allocated implementation is handed to
// MakeMatcher (which, per GoogleTest's documented contract, takes ownership).
::testing::Matcher<InstanceSet> InstancesMatch(const InstanceSet& expected)
{
    auto* impl = new InstanceMatcher(expected);
    return ::testing::MakeMatcher(impl);
}
// Returns true when `actual` holds exactly the same instances as the expected set.
// On a mismatch, and only if the listener is interested, two lists are reported:
//   - "Missing":    instances in the expected set that are absent from `actual`
//   - "Unexpected": instances in `actual` that are absent from the expected set
// Each list is preceded by its element count and is omitted entirely when empty.
// NOTE(review): std::set_difference requires both ranges to be sorted; this presumably
// holds because `instances` is an ordered container - confirm against its declaration.
bool InstanceMatcher::MatchAndExplain(InstanceSet actual,
                                      ::testing::MatchResultListener* listener) const
{
    if(actual.instances == expected_.instances)
    {
        return true;
    }
    if(!listener->IsInterested())
    {
        return false;
    }

    // Prints one labelled list; an empty instance name gets a visible placeholder.
    const auto report = [listener](const char* heading, const std::vector<std::string>& names) {
        if(names.empty())
        {
            return;
        }
        *listener << heading << names.size() << "\n";
        for(const std::string& name : names)
        {
            if(name.empty())
            {
                *listener << "- (empty string)\n";
            }
            else
            {
                *listener << "- " << name << "\n";
            }
        }
    };

    // Expected but not present in actual
    std::vector<std::string> only_in_expected;
    std::set_difference(expected_.instances.begin(),
                        expected_.instances.end(),
                        actual.instances.begin(),
                        actual.instances.end(),
                        std::back_inserter(only_in_expected));
    *listener << "\n";
    report(" Missing: ", only_in_expected);

    // Present in actual but not expected
    std::vector<std::string> only_in_actual;
    std::set_difference(actual.instances.begin(),
                        actual.instances.end(),
                        expected_.instances.begin(),
                        expected_.instances.end(),
                        std::back_inserter(only_in_actual));
    report("Unexpected: ", only_in_actual);

    return false;
}
// Describes the matcher by streaming the expected set (its brief summary form).
void InstanceMatcher::DescribeTo(std::ostream* os) const
{
    std::ostream& out = *os;
    out << expected_;
}
// Describes the negated matcher: "is not equal to" followed by the expected set summary.
void InstanceMatcher::DescribeNegationTo(std::ostream* os) const
{
    std::ostream& out = *os;
    out << "is not equal to " << expected_;
}
// Returns true when `actual` is exactly the expected HIP status.
// No explanation is written; the listener is intentionally unused.
bool HipStatusMatcher::MatchAndExplain(hipError_t actual,
                                       ::testing::MatchResultListener* listener) const
{
    static_cast<void>(listener);
    return actual == expected_;
}
// Describes the matcher with the hipGetErrorString() text of the expected status.
void HipStatusMatcher::DescribeTo(std::ostream* os) const
{
    std::ostream& out = *os;
    out << hipGetErrorString(expected_);
}
// Describes the negated matcher. The negation of "hipSuccess" reads as "any error";
// for every other expected status the text names the status that must not occur.
void HipStatusMatcher::DescribeNegationTo(std::ostream* os) const
{
    if(expected_ != hipSuccess)
    {
        *os << "isn't equal to " << hipGetErrorString(expected_);
        return;
    }
    *os << "any error";
}
// Factory for a matcher that accepts only hipSuccess. The heap-allocated implementation
// is handed to MakeMatcher (which, per GoogleTest's documented contract, takes ownership).
::testing::Matcher<hipError_t> HipSuccess()
{
    auto* impl = new HipStatusMatcher(hipSuccess);
    return ::testing::MakeMatcher(impl);
}
// Factory for a matcher that accepts only the given HIP status `error`. The
// heap-allocated implementation is handed to MakeMatcher (which, per GoogleTest's
// documented contract, takes ownership).
::testing::Matcher<hipError_t> HipError(hipError_t error)
{
    auto* impl = new HipStatusMatcher(error);
    return ::testing::MakeMatcher(impl);
}
// The match result is whatever actual.is_supported() reports. Independently of that,
// when the run result carries an error value and a listener was supplied, the error is
// written to the listener prefixed with "run failed: ".
bool RunResultMatcher::MatchAndExplain(builder::test::RunResult actual,
                                       ::testing::MatchResultListener* listener) const
{
    const bool has_error = actual.error.has_value();
    if(has_error && listener != nullptr)
    {
        *listener << "run failed: " << actual.error.value();
    }
    return actual.is_supported();
}
// Describes the matcher as "successful run".
void RunResultMatcher::DescribeTo(std::ostream* os) const
{
    std::ostream& out = *os;
    out << "successful run";
}
// Describes the negated matcher as "unsuccessful run".
void RunResultMatcher::DescribeNegationTo(std::ostream* os) const
{
    std::ostream& out = *os;
    out << "unsuccessful run";
}
// Factory for the SuccessfulRun matcher, for use as EXPECT_THAT(result, SuccessfulRun()).
// The heap-allocated implementation is handed to MakeMatcher (which, per GoogleTest's
// documented contract, takes ownership).
::testing::Matcher<builder::test::RunResult> SuccessfulRun()
{
    auto* impl = new RunResultMatcher();
    return ::testing::MakeMatcher(impl);
}
} // namespace ck_tile::test