[CK_BUILDER] Integrate reference conv with testing (#3511)

* ck-builder: explicitly delete forward declarations

Before, these functions were seen as a forward declaration for an existing function.
If no actual implementation overload could be found, these would be selected and
a linker error or warning would be generated. By marking these functions as explicitly
deleted, incorrect invocations are reported as compile errors instead.

* ck-builder: ckt::run plumbing for reference conv

This implements the ckt::run plumbing for the reference convolution
implementation and sets up the first complete end-to-end test.

* ck-builder: make validation system check for all-zeros

When the actual and reference outputs are both all zero bits,
there is probably something wrong in the test framework.

* ck-builder: proper implementation+tests for TensorDescriptor::is_packed

* ck-builder: fix typos
This commit is contained in:
Robin Voetter
2026-01-06 09:29:06 +01:00
committed by GitHub
parent b78563b3d3
commit 1c433c64ec
9 changed files with 349 additions and 60 deletions

View File

@@ -13,6 +13,7 @@
#include <vector>
#include <algorithm>
#include <functional>
#include <bit>
/// This file implements functionality related to "validation", i.e., functionality
/// to compare tensors. The functionality in this file should be testing-framework
@@ -48,12 +49,22 @@ struct ValidationReport
/// The total number of elements in each tensor.
uint64_t total_elements;
/// The number of elements which were bitwise 0 in both the actual and the reference tensor.
uint64_t zero_elements;
/// @brief Check whether the output and reference tensors were both all zeros.
///
/// If both tensors are all zero, it indicates either an incorrect testing setup
/// or an issue with the testing framework. For that reason we also consider that
/// a failure.
///
/// @return true when zero_elements equals total_elements.
/// @note The comparison is also true when both counters are 0.
bool is_all_zero() const { return zero_elements == total_elements; }
/// @brief Return whether the check associated to this case was successful.
///
/// This function returns whether the check associated to this case was successful,
/// which is directly derived from checking whether the number of incorrect elements
/// was 0.
bool is_ok() const { return wrong_elements == 0; }
/// was 0 AND whether the tensor was not all zero.
bool is_ok() const { return wrong_elements == 0 && !is_all_zero(); }
};
/// @brief Get comparison cases which were incorrect.
@@ -123,10 +134,13 @@ bool ValidationReport::check(std::string_view tensor_name,
// Initial pass: count errors
// Allocate and reset counter
auto d_error_count = alloc_buffer(sizeof(uint64_t));
check_hip(hipMemset(d_error_count.get(), 0, sizeof(uint64_t)));
auto d_counters = alloc_buffer(sizeof(uint64_t) * 2);
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 2));
tensor_foreach(descriptor.get_lengths(), [=, error_count = d_error_count.get()](auto index) {
auto d_error_count = &reinterpret_cast<uint64_t*>(d_counters.get())[0];
auto d_zero_count = &reinterpret_cast<uint64_t*>(d_counters.get())[1];
tensor_foreach(descriptor.get_lengths(), [=](auto index) {
using CKType = typename factory::internal::DataTypeToCK<DT>::type;
const auto* actual = static_cast<const CKType*>(actual_data);
@@ -137,21 +151,44 @@ bool ValidationReport::check(std::string_view tensor_name,
const auto offset = calculate_offset(index, strides);
const auto o = static_cast<double>(type_convert<float>(actual[offset]));
const auto r = static_cast<double>(type_convert<float>(expected[offset]));
const auto a = actual[offset];
const auto b = expected[offset];
const auto o = static_cast<double>(type_convert<float>(a));
const auto r = static_cast<double>(type_convert<float>(b));
const auto err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{
// We expect the number of errors to be very low, so just use an atomic
// for now.
atomicAdd(reinterpret_cast<uint64_t*>(error_count), 1);
atomicAdd(d_error_count, 1);
}
// Now compare the numbers as bitwise too.
// Update the counter if they're both zero.
using Bytes = std::array<std::byte, sizeof(CKType)>;
bool all_zero = true;
for(auto x : std::bit_cast<Bytes>(a))
{
if(x != std::byte{0})
all_zero = false;
}
for(auto x : std::bit_cast<Bytes>(b))
{
if(x != std::byte{0})
all_zero = false;
}
if(all_zero)
{
atomicAdd(d_zero_count, 1);
}
});
uint64_t error_count = 0;
check_hip(
hipMemcpy(&error_count, d_error_count.get(), sizeof(uint64_t), hipMemcpyDeviceToHost));
check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
uint64_t zero_count = 0;
check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
// TODO: Gather detailed coordinates.
@@ -159,9 +196,10 @@ bool ValidationReport::check(std::string_view tensor_name,
.tensor_name = std::string(tensor_name),
.wrong_elements = error_count,
.total_elements = descriptor.get_element_size(),
.zero_elements = zero_count,
});
return error_count == 0;
return reports_.back().is_ok();
}
} // namespace ck_tile::builder::test