From 641ae96215792bf45db1e4a70ddc2b8df477ef70 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Mon, 8 Apr 2024 12:31:33 +0000 Subject: [PATCH] Check fp8 rounding error in check_err() --- include/ck_tile/host/check_err.hpp | 36 ++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index f2d5be1628..baeb0b3537 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -273,8 +273,8 @@ std::enable_if_t<(std::is_same_v, ranges::range_val CK_TILE_HOST check_err(const Range& out, const RefRange& ref, const std::string& msg = "Error: Incorrect results!", - double rtol = 1e-3, - double atol = 1e-3, + unsigned rounding_error = 1, + double atol = 1e-1, bool allow_infinity_ref = false) { if(out.size() != ref.size()) @@ -291,23 +291,45 @@ std::enable_if_t<(std::is_same_v, ranges::range_val return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; + static const auto less_equal = [](double lhs, double rhs) { + return lhs < rhs || bit_cast(lhs) == bit_cast(rhs); + }; + + static const auto get_rounding_error = [](fp8_t o, fp8_t r) -> unsigned { + static const auto get_sign_bit = [](fp8_t v) -> bool { + return 0x80 & bit_cast(v); + }; + + if(get_sign_bit(o) ^ get_sign_bit(r)) + { + return std::numeric_limits::max(); + } + else + { + return std::abs(bit_cast(o) - bit_cast(r)); + } + }; + bool res{true}; int err_count = 0; double err = 0; double max_err = std::numeric_limits::min(); for(std::size_t i = 0; i < ref.size(); ++i) { - const double o = type_convert(*std::next(std::begin(out), i)); - const double r = type_convert(*std::next(std::begin(ref), i)); - err = std::abs(o - r); - if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) + const fp8_t o_fp8 = *std::next(std::begin(out), i); + const fp8_t r_fp8 = *std::next(std::begin(ref), i); + const double o_fp64 = type_convert(o_fp8); + const double r_fp64 = type_convert(r_fp8); + err = std::abs(o_fp64 - r_fp64); + if(!(less_equal(err, atol) || get_rounding_error(o_fp8, r_fp8) <= rounding_error) || + is_infinity_error(o_fp64, r_fp64)) { max_err = err > max_err ? err : max_err; err_count++; if(err_count < 5) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i - << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + << "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl; } res = false; }