// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once #include #include #include #include #include #include #include #include #include #include "ck_tile/core.hpp" #include "ck_tile/host/ranges.hpp" namespace ck_tile { /** @brief Maximum number of error values to display when checking errors */ constexpr int ERROR_DETAIL_LIMIT = 16; /** @brief 8-bit floating point type */ using F8 = ck_tile::fp8_t; /** @brief 8-bit brain floating point type */ using BF8 = ck_tile::bf8_t; /** @brief 16-bit floating point (half precision) type */ using F16 = ck_tile::half_t; /** @brief 16-bit brain floating point type */ using BF16 = ck_tile::bf16_t; /** @brief 32-bit floating point (single precision) type */ using F32 = float; /** @brief 8-bit signed integer type */ using I8 = int8_t; /** @brief 32-bit signed integer type */ using I32 = int32_t; /** * @brief Calculate relative error threshold for numerical comparisons * * Calculates the relative error threshold based on the mantissa bits and characteristics * of the data types involved in the computation. 
*
 * The result is the largest of three error terms, each built from 2^-mant where mant is read
 * from numeric_traits (presumably the explicit mantissa bit count of the type — confirm):
 *  - ComputeDataType: 0.5 * 2^-mant
 *  - OutDataType:     0.5 * 2^-mant
 *  - AccDataType:     0.5 * 2^-mant * number_of_accumulations
 * If a type matches the corresponding `if constexpr(is_any_of...)` branch, 0 is returned
 * immediately and no relative tolerance is applied.
 *
 * NOTE(review): in this copy the template parameter list and the template arguments of
 * is_any_of / numeric_traits are missing; they must be restored from upstream to compile.
 *
 * @tparam ComputeDataType Type used for computation
 * @tparam OutDataType Type used for output
 * @tparam AccDataType Type used for accumulation (defaults to ComputeDataType)
 * @param number_of_accumulations Number of accumulation operations performed
 * @return Relative error threshold based on data type characteristics
 */
template CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1)
{
    // Compile-time guard: reject compute types this routine has no rule for.
    static_assert(is_any_of::value,
                  "Warning: Unhandled ComputeDataType for setting up the relative threshold!");

    double compute_error = 0;
    if constexpr(is_any_of::value)
    {
        // Types in this branch short-circuit the whole threshold to 0.
        return 0;
    }
    else
    {
        // Half of 2^-mant for the compute type.
        compute_error = std::pow(2, -numeric_traits::mant) * 0.5;
    }

    static_assert(
        is_any_of::value,
        "Warning: Unhandled OutDataType for setting up the relative threshold!");

    double output_error = 0;
    if constexpr(is_any_of::value)
    {
        return 0;
    }
    else
    {
        // Half of 2^-mant for the output type (rounding on the final store).
        output_error = std::pow(2, -numeric_traits::mant) * 0.5;
    }
    // Larger of the compute and output rounding terms.
    double midway_error = std::max(compute_error, output_error);

    static_assert(
        is_any_of::value,
        "Warning: Unhandled AccDataType for setting up the relative threshold!");

    double acc_error = 0;
    if constexpr(is_any_of::value)
    {
        return 0;
    }
    else
    {
        // Accumulation term grows linearly with the number of accumulations.
        acc_error = std::pow(2, -numeric_traits::mant) * 0.5 * number_of_accumulations;
    }
    return std::max(acc_error, midway_error);
}
/**
 * @brief Calculate absolute error threshold for numerical comparisons
 *
 * Calculates the absolute error threshold based on the maximum possible value and
 * the characteristics of the data types involved in the computation.
*
 * With e = floor(log2(|max_possible_num|)), each term uses 2^(e - mant), where mant is read
 * from numeric_traits (presumably the explicit mantissa bit count of the type — confirm):
 *  - ComputeDataType: 0.5 * 2^(e - mant)
 *  - OutDataType:     1.0 * 2^(e - mant)   (a full unit, see the comment in the body)
 *  - AccDataType:     0.5 * 2^(e - mant) * number_of_accumulations
 * The largest of the three is returned. If a type matches the corresponding
 * `if constexpr(is_any_of...)` branch, 0 is returned immediately.
 *
 * NOTE(review): if max_possible_num is 0 (or NaN), std::log2 yields -inf (or NaN), and storing
 * that value into the int discrete_expo is undefined behavior. Callers should pass a nonzero,
 * finite value.
 *
 * NOTE(review): in this copy the template parameter list and the template arguments of
 * is_any_of / numeric_traits / static_cast are missing; they must be restored from upstream.
 *
 * @tparam ComputeDataType Type used for computation
 * @tparam OutDataType Type used for output
 * @tparam AccDataType Type used for accumulation (defaults to ComputeDataType)
 * @param max_possible_num Maximum possible value in the computation
 * @param number_of_accumulations Number of accumulation operations performed
 * @return Absolute error threshold based on data type characteristics and maximum value
 */
template CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
                                                    const int number_of_accumulations = 1)
{
    // Compile-time guard: reject compute types this routine has no rule for.
    static_assert(is_any_of::value,
                  "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");

    // Use discrete exponent (floor of log2) to match actual floating-point exponent levels
    // This ensures ULP calculation matches the discrete precision levels of FP representation
    int discrete_expo =
        std::floor(static_cast(std::floor(std::log2(std::abs(max_possible_num)))));

    double compute_error = 0;
    if constexpr(is_any_of::value)
    {
        // Types in this branch short-circuit the whole threshold to 0.
        return 0;
    }
    else
    {
        compute_error = std::pow(2, discrete_expo - numeric_traits::mant) * 0.5;
    }

    static_assert(
        is_any_of::value,
        "Warning: Unhandled OutDataType for setting up the absolute threshold!");

    double output_error = 0;
    if constexpr(is_any_of::value)
    {
        return 0;
    }
    else
    {
        // Use full ULP (1.0) instead of half ULP (0.5) for output_error to account for
        // hardware vs software conversion differences (e.g., hardware __bf16 vs software
        // float_to_bf16 can differ by up to 1 ULP at tie cases)
        output_error = std::pow(2, discrete_expo - numeric_traits::mant) * 1.0;
    }
    // Larger of the compute and output rounding terms.
    double midway_error = std::max(compute_error, output_error);

    static_assert(
        is_any_of::value,
        "Warning: Unhandled AccDataType for setting up the absolute threshold!");

    double acc_error = 0;
    if constexpr(is_any_of::value)
    {
        return 0;
    }
    else
    {
        // Accumulation term grows linearly with the number of accumulations.
        acc_error = std::pow(2, discrete_expo - numeric_traits::mant) * 0.5 *
                    number_of_accumulations;
    }
    return std::max(acc_error, midway_error);
}
/**
 * @brief Stream operator overload for vector output
 *
Provides a formatted string representation of a vector, useful for debugging and logging. * * @tparam T Type of vector elements * @param os Output stream * @param v Vector to output * @return Reference to the output stream */ template std::ostream& operator<<(std::ostream& os, const std::vector& v) { using size_type = typename std::vector::size_type; os << "["; for(size_type idx = 0; idx < v.size(); ++idx) { if(0 < idx) { os << ", "; } os << v[idx]; } return os << "]"; } /** * @brief Check for size mismatch between output and reference ranges * * Verifies that the output and reference ranges are the same size. * * @tparam Range Type of output range * @tparam RefRange Type of reference range * @param out Output range to check * @param ref Reference range to check against * @param msg Error message to display if sizes mismatch * @return True if sizes mismatch, false otherwise */ template CK_TILE_HOST bool check_size_mismatch(const Range& out, const RefRange& ref, const std::string& msg = "Error: Incorrect results!") { if(out.size() != ref.size()) { std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() << std::endl; return true; } return false; } /** * @brief Report error statistics for numerical comparisons * * Outputs statistics about numerical comparison errors including count and maximum error. * * @param err_count Number of errors found * @param max_err Maximum error value encountered * @param total_size Total number of elements compared */ CK_TILE_HOST void report_error_stats(int err_count, double max_err, std::size_t total_size) { const float error_percent = static_cast(err_count) / static_cast(total_size) * 100.f; std::cerr << "max err: " << max_err; std::cerr << ", number of errors: " << err_count; std::cerr << ", " << error_percent << "% wrong values" << std::endl; } /** * @brief Check errors between floating point ranges using the specified tolerances. 
* * Compares two ranges of floating point values within specified relative and absolute tolerances. * This overload handles standard floating point types except half precision floating point. * * @tparam Range Type of output range * @tparam RefRange Type of reference range * @param out Output range to check * @param ref Reference range to check against * @param msg Error message to display if check fails * @param rtol Relative tolerance * @param atol Absolute tolerance * @param allow_infinity_ref Whether to allow infinity in reference values * @return True if check passes, false otherwise */ template typename std::enable_if< std::is_same_v, ranges::range_value_t> && std::is_floating_point_v> && !std::is_same_v, half_t>, bool>::type CK_TILE_HOST check_err(const Range& out, const RefRange& ref, const std::string& msg = "Error: Incorrect results!", double rtol = 1e-5, double atol = 3e-6, bool allow_infinity_ref = false) { if(check_size_mismatch(out, ref, msg)) return false; const auto is_infinity_error = [=](auto o, auto r) { const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; bool res{true}; int err_count = 0; double err = 0; double max_err = std::numeric_limits::min(); for(std::size_t i = 0; i < ref.size(); ++i) { const double o = *std::next(std::begin(out), i); const double r = *std::next(std::begin(ref), i); err = std::abs(o - r); if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) { max_err = err > max_err ? 
err : max_err; err_count++; if(err_count < ERROR_DETAIL_LIMIT) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; } res = false; } } if(!res) { report_error_stats(err_count, max_err, ref.size()); } return res; } /** * @brief Check errors between floating point ranges using the specified tolerances * * Compares two ranges of brain floating point values within specified relative and absolute * tolerances. * * @tparam Range Type of output range * @tparam RefRange Type of reference range * @param out Output range to check * @param ref Reference range to check against * @param msg Error message to display if check fails * @param rtol Relative tolerance * @param atol Absolute tolerance * @param allow_infinity_ref Whether to allow infinity in reference values * @return True if check passes, false otherwise */ template typename std::enable_if< std::is_same_v, ranges::range_value_t> && std::is_same_v, bf16_t>, bool>::type CK_TILE_HOST check_err(const Range& out, const RefRange& ref, const std::string& msg = "Error: Incorrect results!", double rtol = 1e-3, double atol = 1e-3, bool allow_infinity_ref = false) { if(check_size_mismatch(out, ref, msg)) return false; const auto is_infinity_error = [=](auto o, auto r) { const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; bool res{true}; int err_count = 0; double err = 0; // TODO: This is a hack. We should have proper specialization for bf16_t data type. 
double max_err = std::numeric_limits::min(); for(std::size_t i = 0; i < ref.size(); ++i) { const double o = type_convert(*std::next(std::begin(out), i)); const double r = type_convert(*std::next(std::begin(ref), i)); err = std::abs(o - r); if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) { max_err = err > max_err ? err : max_err; err_count++; if(err_count < ERROR_DETAIL_LIMIT) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; } res = false; } } if(!res) { report_error_stats(err_count, max_err, ref.size()); } return res; } /** * @brief Check errors between half precision floating point ranges * * Compares two ranges of half precision floating point values within specified tolerances. * This specialization handles the specific requirements and characteristics of half precision * floating point comparisons. * * @tparam Range Type of output range * @tparam RefRange Type of reference range * @param out Output range to check * @param ref Reference range to check against * @param msg Error message to display if check fails * @param rtol Relative tolerance * @param atol Absolute tolerance * @param allow_infinity_ref Whether to allow infinity in reference values * @return True if check passes, false otherwise */ template typename std::enable_if< std::is_same_v, ranges::range_value_t> && std::is_same_v, half_t>, bool>::type CK_TILE_HOST check_err(const Range& out, const RefRange& ref, const std::string& msg = "Error: Incorrect results!", double rtol = 1e-3, double atol = 1e-3, bool allow_infinity_ref = false) { if(check_size_mismatch(out, ref, msg)) return false; const auto is_infinity_error = [=](auto o, auto r) { const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; bool 
res{true}; int err_count = 0; double err = 0; double max_err = static_cast(std::numeric_limits>::min()); for(std::size_t i = 0; i < ref.size(); ++i) { const double o = type_convert(*std::next(std::begin(out), i)); const double r = type_convert(*std::next(std::begin(ref), i)); err = std::abs(o - r); if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) { max_err = err > max_err ? err : max_err; err_count++; if(err_count < ERROR_DETAIL_LIMIT) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; } res = false; } } if(!res) { report_error_stats(err_count, max_err, ref.size()); } return res; } /** * @brief Check errors between integer ranges * * Compares two ranges of integer values with an absolute tolerance. * This specialization handles integer types and optionally int4_t when the * experimental bit int extension is enabled. * * @tparam Range Type of output range * @tparam RefRange Type of reference range * @param out Output range to check * @param ref Reference range to check against * @param msg Error message to display if check fails * @param atol Absolute tolerance * @return True if check passes, false otherwise */ template std::enable_if_t<(std::is_same_v, ranges::range_value_t> && std::is_integral_v> && !std::is_same_v, bf16_t>) #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 || std::is_same_v, int4_t> #endif , bool> CK_TILE_HOST check_err(const Range& out, const RefRange& ref, const std::string& msg = "Error: Incorrect results!", double = 0, double atol = 0) { if(check_size_mismatch(out, ref, msg)) return false; bool res{true}; int err_count = 0; int64_t err = 0; int64_t max_err = std::numeric_limits::min(); for(std::size_t i = 0; i < ref.size(); ++i) { const int64_t o = *std::next(std::begin(out), i); const int64_t r = *std::next(std::begin(ref), i); err = std::abs(o - r); if(err > atol) { max_err = err > max_err ? 
err : max_err; err_count++; if(err_count < ERROR_DETAIL_LIMIT) { std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; } res = false; } } if(!res) { report_error_stats(err_count, static_cast(max_err), ref.size()); } return res; } /** * @brief Check errors between FP8 ranges * * Specialized comparison for 8-bit floating point values that takes into account * the unique characteristics and limitations of FP8 arithmetic, including * rounding point distances and special handling of infinity values. * * @tparam Range Type of output range * @tparam RefRange Type of reference range * @param out Output range to check * @param ref Reference range to check against * @param msg Error message to display if check fails * @param max_rounding_point_distance Maximum allowed distance between rounding points * @param atol Absolute tolerance * @param allow_infinity_ref Whether to allow infinity in reference values * @return True if check passes, false otherwise */ template std::enable_if_t<(std::is_same_v, ranges::range_value_t> && std::is_same_v, fp8_t>), bool> CK_TILE_HOST check_err(const Range& out, const RefRange& ref, const std::string& msg = "Error: Incorrect results!", unsigned max_rounding_point_distance = 1, double atol = 1e-1, bool allow_infinity_ref = false) { if(check_size_mismatch(out, ref, msg)) return false; const auto is_infinity_error = [=](auto o, auto r) { const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; static const auto get_rounding_point_distance = [](fp8_t o, fp8_t r) -> unsigned { static const auto get_sign_bit = [](fp8_t v) -> bool { return 0x80 & bit_cast(v); }; if(get_sign_bit(o) ^ get_sign_bit(r)) { return std::numeric_limits::max(); } else { return std::abs(bit_cast(o) - bit_cast(r)); } }; bool res{true}; int 
err_count = 0; double err = 0; double max_err = std::numeric_limits::min(); for(std::size_t i = 0; i < ref.size(); ++i) { const fp8_t o_fp8 = *std::next(std::begin(out), i); const fp8_t r_fp8 = *std::next(std::begin(ref), i); const double o_fp64 = type_convert(o_fp8); const double r_fp64 = type_convert(r_fp8); err = std::abs(o_fp64 - r_fp64); if(!(less_equal{}(err, atol) || get_rounding_point_distance(o_fp8, r_fp8) <= max_rounding_point_distance) || is_infinity_error(o_fp64, r_fp64)) { max_err = err > max_err ? err : max_err; err_count++; if(err_count < ERROR_DETAIL_LIMIT) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl; } res = false; } } if(!res) { report_error_stats(err_count, max_err, ref.size()); } return res; } /** * @brief Check errors between BF8 ranges * * Specialized comparison for 8-bit brain floating point values that considers * the specific numerical properties and error characteristics of the BF8 format. 
* * @tparam Range Type of output range * @tparam RefRange Type of reference range * @param out Output range to check * @param ref Reference range to check against * @param msg Error message to display if check fails * @param rtol Relative tolerance * @param atol Absolute tolerance * @param allow_infinity_ref Whether to allow infinity in reference values * @return True if check passes, false otherwise */ template std::enable_if_t<(std::is_same_v, ranges::range_value_t> && std::is_same_v, bf8_t>), bool> CK_TILE_HOST check_err(const Range& out, const RefRange& ref, const std::string& msg = "Error: Incorrect results!", double rtol = 1e-3, double atol = 1e-3, bool allow_infinity_ref = false) { if(check_size_mismatch(out, ref, msg)) return false; const auto is_infinity_error = [=](auto o, auto r) { const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; bool res{true}; int err_count = 0; double err = 0; double max_err = std::numeric_limits::min(); for(std::size_t i = 0; i < ref.size(); ++i) { const double o = type_convert(*std::next(std::begin(out), i)); const double r = type_convert(*std::next(std::begin(ref), i)); err = std::abs(o - r); if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) { max_err = err > max_err ? err : max_err; err_count++; if(err_count < ERROR_DETAIL_LIMIT) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; } res = false; } } if(!res) { report_error_stats(err_count, max_err, ref.size()); } return res; } /** * @brief Check errors between pk_fp4_t ranges * * Compares two ranges of pk_fp4_t without tolerance. * This specialization handles ck_tile::pk_fp4_t type. 
* * @tparam Range Type of output range * @tparam RefRange Type of reference range * @param out Output range to check * @param ref Reference range to check against * @param msg Error message to display if check fails * @return True if check passes, false otherwise */ template std::enable_if_t<(std::is_same_v, ranges::range_value_t> && std::is_same_v, pk_fp4_t>), bool> CK_TILE_HOST check_err(const Range& out, const RefRange& ref, const std::string& msg = "Error: Incorrect results!", double = 0, double = 0) { if(check_size_mismatch(out, ref, msg)) return false; int err_count = 0; auto update_err = [&](pk_fp4_raw_t o, pk_fp4_raw_t r, std::size_t index) { if(o != r) { std::cerr << msg << " out[" << index << "] != ref[" << index << "]: " << type_convert(pk_fp4_t{o}) << " != " << type_convert(pk_fp4_t{r}) << std::endl; ++err_count; } }; for(std::size_t i = 0; i < ref.size(); ++i) { const pk_fp4_t o = *std::next(std::begin(out), i); const pk_fp4_t r = *std::next(std::begin(ref), i); update_err(o._unpack(number<0>{}), r._unpack(number<0>{}), i * 2); update_err(o._unpack(number<1>{}), r._unpack(number<1>{}), i * 2 + 1); } if(err_count > 0) { report_error_stats(err_count, numeric::max(), ref.size()); } return err_count == 0; } /** * @brief Check errors between pk_fp6x16_t ranges * * Compares two ranges of pk_fp6x16_t without tolerance. * This specialization handles ck_tile::pk_fp6x16_t type. 
* * @tparam Range Type of output range * @tparam RefRange Type of reference range * @param out Output range to check * @param ref Reference range to check against * @param msg Error message to display if check fails * @return True if check passes, false otherwise */ template std::enable_if_t<(std::is_same_v, ranges::range_value_t> && std::is_same_v, pk_fp6x16_t>), bool> CK_TILE_HOST check_err(const Range& out, const RefRange& ref, const std::string& msg = "Error: Incorrect results!", double = 0, double = 0) { if(check_size_mismatch(out, ref, msg)) return false; int err_count = 0; float max_err = 0.0f; auto update_err = [&](float o, float r, std::size_t index) { if(std::fabs(o - r) > 1e-8) { std::cerr << msg << " out[" << index << "] != ref[" << index << "]: " << o << " != " << r << std::endl; ++err_count; max_err = max_err < std::fabs(o - r) ? o : max_err; } }; for(std::size_t i = 0; i < ref.size(); ++i) { const pk_fp6x16_t o = *std::next(std::begin(out), i); const pk_fp6x16_t r = *std::next(std::begin(ref), i); for(std::size_t j = 0; j < numeric_traits::PackedSize; j++) { update_err(o.unpack(j), r.unpack(j), i * numeric_traits::PackedSize + j); } } if(err_count > 0) { report_error_stats(err_count, max_err, ref.size()); } return err_count == 0; } } // namespace ck_tile