Add max error metric

This commit is contained in:
Bartlomiej Kocot
2026-01-20 08:47:47 -05:00
parent 51214187a1
commit 4010341092
2 changed files with 106 additions and 5 deletions

View File

@@ -22,6 +22,99 @@
/// etc, by the actual testing framework that the user has chosen.
namespace ck_tile::builder::test {
/// @brief Default relative tolerance for validating a tensor whose elements are of type @p DT.
///
/// BF16 receives a looser bound (5e-2). FP32, FP16 and every other data type
/// currently share 1e-3.
///
/// Values that were drafted for further cases but are not enabled:
///   - FP32 computed through TF32: 5e-3
///   - double:                     1e-6
///   - int32 / int8:               1e-1
///   - f8:                         1e-1   (240 and 224 are acceptable)
///   - bf8:                        1.5e-1 (57344 and 49152 are acceptable)
///
/// @tparam DT       Data type the tolerance is selected for.
/// @tparam GemmType Currently unused; it is only referenced by the disabled TF32 case.
/// @return The relative tolerance.
template <DataType DT, typename GemmType = DataType>
inline __host__ __device__ constexpr double get_rtol()
{
    constexpr double bf16_rtol    = 5e-2;
    constexpr double default_rtol = 1e-3;

    // Only BF16 deviates from the shared default at the moment.
    return (DT == DataType::BF16) ? bf16_rtol : default_rtol;
}
/// @brief Default absolute tolerance for validating a tensor whose elements are of type @p DT.
///
/// BF16 receives a looser bound (5e-2). FP32, FP16 and every other data type
/// currently share 1e-3.
///
/// Values that were drafted for further cases but are not enabled:
///   - FP32 computed through TF32: 1e-3
///   - double:                     1e-6
///   - int32 / int8:               1e-1
///   - f8:                         16.1   (240 and 224 are acceptable)
///   - bf8:                        8192.1 (57344 and 49152 are acceptable)
///
/// @tparam DT       Data type the tolerance is selected for.
/// @tparam GemmType Currently unused; it is only referenced by the disabled TF32 case.
/// @return The absolute tolerance.
template <DataType DT, typename GemmType = DataType>
inline __host__ __device__ constexpr double get_atol()
{
    constexpr double bf16_atol    = 5e-2;
    constexpr double default_atol = 1e-3;

    // Only BF16 deviates from the shared default at the moment.
    return (DT == DataType::BF16) ? bf16_atol : default_atol;
}
/// @brief Information about how a set of comparisons failed or succeeded.
///
@@ -51,6 +144,9 @@ struct ValidationReport
/// The number of elements which were bitwise 0.
uint64_t zero_elements;
/// The largest absolute difference between an output element and its reference element.
double max_error;
/// @brief Check whether both the output and reference tensor were both all zeros.
///
/// If both tensors are all zero, it indicates either an incorrect testing setup
@@ -108,8 +204,8 @@ struct ValidationReport
const TensorDescriptor<DT, RANK>& descriptor,
const void* actual,
const void* expected,
double rtol = 1e-3,
double atol = 1e-3);
double rtol = get_rtol<DT>(),
double atol = get_atol<DT>());
private:
std::vector<Case> reports_;
@@ -133,11 +229,12 @@ bool ValidationReport::check(std::string_view tensor_name,
// Initial pass: count errors
// Allocate and reset counter
auto d_counters = alloc_buffer(sizeof(uint64_t) * 2);
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 2));
auto d_counters = alloc_buffer(sizeof(uint64_t) * 3);
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 3));
auto d_error_count = &reinterpret_cast<uint64_t*>(d_counters.get())[0];
auto d_zero_count = &reinterpret_cast<uint64_t*>(d_counters.get())[1];
auto d_max_error = &reinterpret_cast<double*>(d_counters.get())[2];
tensor_foreach(descriptor.get_lengths(), [=](auto index) {
using CKType = typename factory::internal::DataTypeToCK<DT>::type;
@@ -157,6 +254,7 @@ bool ValidationReport::check(std::string_view tensor_name,
const auto r = static_cast<double>(type_convert<float>(b));
const auto err = std::abs(o - r);
atomicMax(d_max_error, err);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{
// We expect the number of errors to be very low, so just use an atomic
@@ -188,6 +286,8 @@ bool ValidationReport::check(std::string_view tensor_name,
check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
uint64_t zero_count = 0;
check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
double max_error = 0;
check_hip(hipMemcpy(&max_error, d_max_error, sizeof(double), hipMemcpyDeviceToHost));
// TODO: Gather detailed coordinates.
@@ -196,6 +296,7 @@ bool ValidationReport::check(std::string_view tensor_name,
.wrong_elements = error_count,
.total_elements = descriptor.get_element_size(),
.zero_elements = zero_count,
.max_error= max_error,
});
return reports_.back().is_ok();

View File

@@ -125,7 +125,7 @@ run_grouped_conv_forward_tile_algs(const ckt::Args<SIGNATURE>& args,
{
valid = false;
std::cout << "Number of incorrect values: " << error.wrong_elements
<< " Is all zero:" << error.is_all_zero() << std::endl;
<< " Is all zero:" << error.is_all_zero() << " max err: " << error.max_error << std::endl;
}
best_avg_time = std::min(best_avg_time, avg_time);
best_op_name = best_avg_time < avg_time ? best_op_name : op_name;