From f2a0430ce18319ea0b5af372004f0d3936cf296d Mon Sep 17 00:00:00 2001 From: John Shumway Date: Mon, 6 Oct 2025 12:00:26 +0000 Subject: [PATCH] Add initial reflection capabilities to the builder. This PR introduces a Description class as well as ck_tile ConvTraits to add reflection. This is helpful for users, but more critically, it will help us write better tests for the builder. Too many details of the convolutions are hidden or obscured. --- experimental/builder/test/CMakeLists.txt | 3 +- .../builder/test/test_conv_builder.cpp | 12 ++-- experimental/builder/test/testing_utils.cpp | 59 ++++++++++--------- 3 files changed, 40 insertions(+), 34 deletions(-) diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 5a0a99cc02..75c19ce228 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -13,11 +13,12 @@ function(add_ck_builder_test test_name) -Wno-global-constructors -Wno-c++20-compat ) - target_link_libraries(${test_name} PRIVATE GTest::gtest GTest::gtest_main) + target_link_libraries(${test_name} PRIVATE GTest::gtest GTest::gtest_main GTest::gmock) endfunction() add_ck_builder_test(test_conv_builder test_conv_builder.cpp + test_conv_description.cpp test_conv_grp_fwd_2d.cpp test_conv_grp_fwd_3d.cpp test_conv_grp_bwd_2d.cpp diff --git a/experimental/builder/test/test_conv_builder.cpp b/experimental/builder/test/test_conv_builder.cpp index ec4903c9d0..7ee4c262f5 100644 --- a/experimental/builder/test/test_conv_builder.cpp +++ b/experimental/builder/test/test_conv_builder.cpp @@ -1,4 +1,5 @@ #include +#include #include #include "testing_utils.hpp" @@ -6,6 +7,7 @@ namespace { namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::test; using P = ckb::BlockGemmPipelineVersion; // Defines the signature of the convolution operation to be tested. @@ -29,11 +31,11 @@ TEST(ConvBuilderTest, TestDefaultInstance) static constexpr const ConvSignature SIGNATURE; static constexpr const DefaultAlgorithm ALGORITHM; using Builder = ckb::ConvBuilder; - std::string expected = - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, " - "8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>"; - EXPECT_EQ(Builder::Instance::TypeString(), expected) - << ck_tile::test::formatInlineDiff(Builder::Instance::TypeString(), expected); + EXPECT_THAT( + Builder::Instance::TypeString(), + ckt::StringEqWithDiff("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, " + "Default, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: " + "Intrawave, BlkGemmPipelineVersion: v4>")); } } // namespace diff --git a/experimental/builder/test/testing_utils.cpp b/experimental/builder/test/testing_utils.cpp index e8412bbdf9..178afca166 100644 --- a/experimental/builder/test/testing_utils.cpp +++ b/experimental/builder/test/testing_utils.cpp @@ -48,8 +48,8 @@ const char* RESET = isTerminalOutput() ? "\033[0m" : ""; // - Similar to sequence alignment algorithms used in bioinformatics (Needleman-Wunsch) std::string inlineDiff(const std::string& actual, const std::string& expected) { - const size_t n = expected.length(); // Length of expected string - const size_t m = actual.length(); // Length of actual string + const size_t n = expected.length(); // Length of expected string + const size_t m = actual.length(); // Length of actual string // PHASE 1: Build the Dynamic Programming Table // dp[i][j] = minimum edit distance between expected[0..i-1] and actual[0..j-1] @@ -58,11 +58,11 @@ std::string inlineDiff(const std::string& actual, const std::string& expected) // Base cases: transforming empty string to/from prefixes for(size_t i = 0; i <= n; ++i) { - dp[i][0] = i; // Delete i characters from expected to get empty string + dp[i][0] = i; // Delete i characters from expected to get empty string } for(size_t j = 0; j <= m; ++j) { - dp[0][j] = j; // Insert j characters to empty string to get actual[0..j-1] + dp[0][j] = j; // Insert j characters to empty string to get actual[0..j-1] } // Fill the DP table using the Wagner-Fischer recurrence relation @@ -72,22 +72,22 @@ std::string inlineDiff(const std::string& actual, const std::string& expected) { // Cost is 0 if characters match, 1 if they need substitution int cost = (expected[i - 1] == actual[j - 1]) ? 0 : 1; - + // Choose the minimum cost operation: dp[i][j] = std::min({ - dp[i - 1][j] + 1, // Deletion: remove expected[i-1] - dp[i][j - 1] + 1, // Insertion: add actual[j-1] - dp[i - 1][j - 1] + cost // Substitution/Match + dp[i - 1][j] + 1, // Deletion: remove expected[i-1] + dp[i][j - 1] + 1, // Insertion: add actual[j-1] + dp[i - 1][j - 1] + cost // Substitution/Match }); } } // PHASE 2: Backtrack to Reconstruct the Optimal Edit Sequence // We trace back from dp[n][m] to dp[0][0] to find which operations were used - std::vector operations; // 'M'atch, 'S'ubstitution, 'I'nsertion, 'D'eletion - std::vector> diff_chars; // Character pairs for each operation - - size_t i = n, j = m; // Start from bottom-right corner of DP table + std::vector operations; // 'M'atch, 'S'ubstitution, 'I'nsertion, 'D'eletion + std::vector> diff_chars; // Character pairs for each operation + + size_t i = n, j = m; // Start from bottom-right corner of DP table while(i > 0 || j > 0) { // Determine which operation led to the current cell's value @@ -106,21 +106,22 @@ std::string inlineDiff(const std::string& actual, const std::string& expected) operations.push_back('S'); // Substitution needed diff_chars.push_back({expected[i - 1], actual[j - 1]}); } - --i; --j; // Move diagonally up-left + --i; + --j; // Move diagonally up-left } // Check if we came from left (insertion) else if(j > 0 && dp[i][j] == dp[i][j - 1] + 1) { operations.push_back('I'); // Insertion: actual has extra character diff_chars.push_back({'\0', actual[j - 1]}); - --j; // Move left + --j; // Move left } // Must have come from above (deletion) else if(i > 0 && dp[i][j] == dp[i - 1][j] + 1) { operations.push_back('D'); // Deletion: expected has extra character diff_chars.push_back({expected[i - 1], '\0'}); - --i; // Move up + --i; // Move up } } @@ -131,42 +132,44 @@ std::string inlineDiff(const std::string& actual, const std::string& expected) // Build the final diff string with color highlighting std::ostringstream diff; - std::string expected_diff, actual_diff; // Accumulate consecutive differences - bool in_diff = false; // Track whether we're inside a diff section + std::string expected_diff, actual_diff; // Accumulate consecutive differences + bool in_diff = false; // Track whether we're inside a diff section for(size_t k = 0; k < operations.size(); ++k) { - char op = operations[k]; - char exp_char = diff_chars[k].first; // Expected character ('\0' for insertions) - char act_char = diff_chars[k].second; // Actual character ('\0' for deletions) + char op = operations[k]; + char exp_char = diff_chars[k].first; // Expected character ('\0' for insertions) + char act_char = diff_chars[k].second; // Actual character ('\0' for deletions) if(op == 'M') // Match - characters are identical { if(in_diff) { // Close the current diff section and output it - diff << "[" << EXPECTED_COLOR << expected_diff << RESET << "|" - << ACTUAL_COLOR << actual_diff << RESET << "]"; + diff << "[" << EXPECTED_COLOR << expected_diff << RESET << "|" << ACTUAL_COLOR + << actual_diff << RESET << "]"; expected_diff.clear(); actual_diff.clear(); in_diff = false; } - diff << exp_char; // Output the matching character as-is + diff << exp_char; // Output the matching character as-is } else // Difference (substitution, insertion, or deletion) { in_diff = true; // Accumulate characters for the diff section - if(exp_char != '\0') expected_diff += exp_char; // Add to expected side - if(act_char != '\0') actual_diff += act_char; // Add to actual side + if(exp_char != '\0') + expected_diff += exp_char; // Add to expected side + if(act_char != '\0') + actual_diff += act_char; // Add to actual side } } // Close any remaining diff section at the end if(in_diff) { - diff << "[" << EXPECTED_COLOR << expected_diff << RESET << "|" - << ACTUAL_COLOR << actual_diff << RESET << "]"; + diff << "[" << EXPECTED_COLOR << expected_diff << RESET << "|" << ACTUAL_COLOR + << actual_diff << RESET << "]"; } return diff.str(); @@ -200,7 +203,7 @@ bool StringEqWithDiffMatcher::MatchAndExplain(std::string actual, void StringEqWithDiffMatcher::DescribeTo(std::ostream* os) const { - *os << "is equal to \"" << expected_ << "\""; + *os << "\"" << expected_ << "\""; } void StringEqWithDiffMatcher::DescribeNegationTo(std::ostream* os) const