mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
[CK TILE] Enable CK TILE Conv Fwd tests in CI and fix check_err (#3624)
* [CK TILE] Enable CK TILE Conv Fwd tests in CI and fix check_err * Update test_grouped_convnd_fwd_tile.cpp * Update test_grouped_convnd_fwd_tile.cpp * Update conv_tuning_params.hpp * clang format fix * Update CMakeLists.txt
This commit is contained in:
@@ -51,6 +51,9 @@ struct ValidationReport
|
||||
/// The number of elements which were bitwise 0.
|
||||
uint64_t zero_elements;
|
||||
|
||||
// Max error.
|
||||
double max_error;
|
||||
|
||||
/// @brief Check whether both the output and reference tensor were both all zeros.
|
||||
///
|
||||
/// If both tensors are all zero, it indicates either an incorrect testing setup
|
||||
@@ -133,11 +136,12 @@ bool ValidationReport::check(std::string_view tensor_name,
|
||||
// Initial pass: count errors
|
||||
|
||||
// Allocate and reset counter
|
||||
auto d_counters = alloc_buffer(sizeof(uint64_t) * 2);
|
||||
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 2));
|
||||
auto d_counters = alloc_buffer(sizeof(uint64_t) * 3);
|
||||
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 3));
|
||||
|
||||
auto d_error_count = &reinterpret_cast<uint64_t*>(d_counters.get())[0];
|
||||
auto d_zero_count = &reinterpret_cast<uint64_t*>(d_counters.get())[1];
|
||||
auto d_max_error = &reinterpret_cast<double*>(d_counters.get())[2];
|
||||
|
||||
tensor_foreach(descriptor.get_lengths(), [=](auto index) {
|
||||
using CKType = typename factory::internal::DataTypeToCK<DT>::type;
|
||||
@@ -157,6 +161,7 @@ bool ValidationReport::check(std::string_view tensor_name,
|
||||
const auto r = static_cast<double>(type_convert<float>(b));
|
||||
const auto err = std::abs(o - r);
|
||||
|
||||
atomicMax(d_max_error, err);
|
||||
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
|
||||
{
|
||||
// We expect the number of errors to be very low, so just use an atomic
|
||||
@@ -188,6 +193,8 @@ bool ValidationReport::check(std::string_view tensor_name,
|
||||
check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
|
||||
uint64_t zero_count = 0;
|
||||
check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
|
||||
double max_error = 0;
|
||||
check_hip(hipMemcpy(&max_error, d_max_error, sizeof(double), hipMemcpyDeviceToHost));
|
||||
|
||||
// TODO: Gather detailed coordinates.
|
||||
|
||||
@@ -196,6 +203,7 @@ bool ValidationReport::check(std::string_view tensor_name,
|
||||
.wrong_elements = error_count,
|
||||
.total_elements = descriptor.get_element_size(),
|
||||
.zero_elements = zero_count,
|
||||
.max_error = max_error,
|
||||
});
|
||||
|
||||
return reports_.back().is_ok();
|
||||
|
||||
Reference in New Issue
Block a user