Check fp8 rounding error in check_err()

This commit is contained in:
Po Yen Chen
2024-04-08 12:31:33 +00:00
parent 92d45d1681
commit 641ae96215

View File

@@ -273,8 +273,8 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
CK_TILE_HOST check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3,
double atol = 1e-3,
unsigned rounding_error = 1,
double atol = 1e-1,
bool allow_infinity_ref = false)
{
if(out.size() != ref.size())
@@ -291,23 +291,45 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
};
static const auto less_equal = [](double lhs, double rhs) {
return lhs < rhs || bit_cast<uint64_t>(lhs) == bit_cast<uint64_t>(rhs);
};
static const auto get_rounding_error = [](fp8_t o, fp8_t r) -> unsigned {
static const auto get_sign_bit = [](fp8_t v) -> bool {
return 0x80 & bit_cast<uint8_t>(v);
};
if(get_sign_bit(o) ^ get_sign_bit(r))
{
return std::numeric_limits<unsigned>::max();
}
else
{
return std::abs(bit_cast<int8_t>(o) - bit_cast<int8_t>(r));
}
};
bool res{true};
int err_count = 0;
double err = 0;
double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i));
err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
const fp8_t o_fp8 = *std::next(std::begin(out), i);
const fp8_t r_fp8 = *std::next(std::begin(ref), i);
const double o_fp64 = type_convert<float>(o_fp8);
const double r_fp64 = type_convert<float>(r_fp8);
err = std::abs(o_fp64 - r_fp64);
if(!(less_equal(err, atol) || get_rounding_error(o_fp8, r_fp8) <= rounding_error) ||
is_infinity_error(o_fp64, r_fp64))
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
<< "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl;
}
res = false;
}