[CK TILE] Enable CK TILE Conv Fwd tests in CI and fix check_err (#3624)

* [CK TILE] Enable CK TILE Conv Fwd tests in CI and fix check_err

* Update test_grouped_convnd_fwd_tile.cpp

* Update test_grouped_convnd_fwd_tile.cpp

* Update conv_tuning_params.hpp

* clang format fix

* Update CMakeLists.txt
This commit is contained in:
Bartłomiej Kocot
2026-01-27 10:04:11 +01:00
committed by GitHub
parent c190d8d61f
commit 3d67e6c492
14 changed files with 114 additions and 46 deletions

View File

@@ -51,6 +51,9 @@ struct ValidationReport
/// The number of elements which were bitwise 0.
uint64_t zero_elements;
// Max error.
double max_error;
/// @brief Check whether both the output and reference tensor were both all zeros.
///
/// If both tensors are all zero, it indicates either an incorrect testing setup
@@ -133,11 +136,12 @@ bool ValidationReport::check(std::string_view tensor_name,
// Initial pass: count errors
// Allocate and reset counter
auto d_counters = alloc_buffer(sizeof(uint64_t) * 2);
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 2));
auto d_counters = alloc_buffer(sizeof(uint64_t) * 3);
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 3));
auto d_error_count = &reinterpret_cast<uint64_t*>(d_counters.get())[0];
auto d_zero_count = &reinterpret_cast<uint64_t*>(d_counters.get())[1];
auto d_max_error = &reinterpret_cast<double*>(d_counters.get())[2];
tensor_foreach(descriptor.get_lengths(), [=](auto index) {
using CKType = typename factory::internal::DataTypeToCK<DT>::type;
@@ -157,6 +161,7 @@ bool ValidationReport::check(std::string_view tensor_name,
const auto r = static_cast<double>(type_convert<float>(b));
const auto err = std::abs(o - r);
atomicMax(d_max_error, err);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{
// We expect the number of errors to be very low, so just use an atomic
@@ -188,6 +193,8 @@ bool ValidationReport::check(std::string_view tensor_name,
check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
uint64_t zero_count = 0;
check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
double max_error = 0;
check_hip(hipMemcpy(&max_error, d_max_error, sizeof(double), hipMemcpyDeviceToHost));
// TODO: Gather detailed coordinates.
@@ -196,6 +203,7 @@ bool ValidationReport::check(std::string_view tensor_name,
.wrong_elements = error_count,
.total_elements = descriptor.get_element_size(),
.zero_elements = zero_count,
.max_error = max_error,
});
return reports_.back().is_ok();