Code Refactor for check_err.hpp (#2284)

* refactor & add documentation

* removed return datatype from doxygen comments

* Update include/ck_tile/host/check_err.hpp

Co-authored-by: John Afaganis <john.afaganis@amd.com>

* Update include/ck_tile/host/check_err.hpp

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Update include/ck_tile/host/check_err.hpp

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Update include/ck_tile/host/check_err.hpp

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Update include/ck_tile/host/check_err.hpp

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

---------

Co-authored-by: John Afaganis <john.afaganis@amd.com>
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
This commit is contained in:
Aviral Goel
2025-06-08 16:41:27 -04:00
committed by GitHub
parent aece3c6700
commit 5a0bd157db

View File

@@ -18,16 +18,36 @@
namespace ck_tile {
/** @brief 8-bit floating point type */
using F8 = ck_tile::fp8_t;
/** @brief 8-bit brain floating point type */
using BF8 = ck_tile::bf8_t;
/** @brief 16-bit floating point (half precision) type */
using F16 = ck_tile::half_t;
/** @brief 16-bit brain floating point type */
using BF16 = ck_tile::bf16_t;
/** @brief 32-bit floating point (single precision) type */
using F32 = float;
/** @brief 8-bit signed integer type */
using I8 = int8_t;
/** @brief 32-bit signed integer type */
using I32 = int32_t;
/**
* @brief Calculate relative error threshold for numerical comparisons
*
* Calculates the relative error threshold based on the mantissa bits and characteristics
* of the data types involved in the computation.
*
* @tparam ComputeDataType Type used for computation
* @tparam OutDataType Type used for output
* @tparam AccDataType Type used for accumulation (defaults to ComputeDataType)
* @param number_of_accumulations Number of accumulation operations performed
* @return Relative error threshold based on data type characteristics
*/
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_relative_threshold(const int number_of_accumulations = 1)
{
using F8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using F16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
static_assert(
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
@@ -72,16 +92,22 @@ double get_relative_threshold(const int number_of_accumulations = 1)
return std::max(acc_error, midway_error);
}
/**
* @brief Calculate absolute error threshold for numerical comparisons
*
* Calculates the absolute error threshold based on the maximum possible value and
* the characteristics of the data types involved in the computation.
*
* @tparam ComputeDataType Type used for computation
* @tparam OutDataType Type used for output
* @tparam AccDataType Type used for accumulation (defaults to ComputeDataType)
* @param max_possible_num Maximum possible value in the computation
* @param number_of_accumulations Number of accumulation operations performed
* @return Absolute error threshold based on data type characteristics and maximum value
*/
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
{
using F8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using F16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
static_assert(
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
@@ -128,6 +154,16 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
return std::max(acc_error, midway_error);
}
/**
* @brief Stream operator overload for vector output
*
* Provides a formatted string representation of a vector, useful for debugging and logging.
*
* @tparam T Type of vector elements
* @param os Output stream
* @param v Vector to output
* @return Reference to the output stream
*/
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
@@ -145,6 +181,66 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
return os << "]";
}
/**
* @brief Check for size mismatch between output and reference ranges
*
* Verifies that the output and reference ranges are the same size.
*
* @tparam Range Type of output range
* @tparam RefRange Type of reference range
* @param out Output range to check
* @param ref Reference range to check against
* @param msg Error message to display if sizes mismatch
* @return True if sizes mismatch, false otherwise
*/
template <typename Range, typename RefRange>
bool check_size_mismatch(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!")
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return true;
}
return false;
}
/**
* @brief Report error statistics for numerical comparisons
*
* Outputs statistics about numerical comparison errors including count and maximum error.
*
* @param err_count Number of errors found
* @param max_err Maximum error value encountered
* @param total_size Total number of elements compared
*/
void report_error_stats(int err_count, double max_err, std::size_t total_size)
{
const float error_percent =
static_cast<float>(err_count) / static_cast<float>(total_size) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
}
/**
* @brief Check errors between floating point ranges using the specified tolerances.
*
* Compares two ranges of floating point values within specified relative and absolute tolerances.
* This overload handles standard floating point types except half precision floating point.
*
* @tparam Range Type of output range
* @tparam RefRange Type of reference range
* @param out Output range to check
* @param ref Reference range to check against
* @param msg Error message to display if check fails
* @param rtol Relative tolerance
* @param atol Absolute tolerance
* @param allow_infinity_ref Whether to allow infinity in reference values
* @return True if check passes, false otherwise
*/
template <typename Range, typename RefRange>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
@@ -158,12 +254,9 @@ check_err(const Range& out,
double atol = 3e-6,
bool allow_infinity_ref = false)
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
if(check_size_mismatch(out, ref, msg))
return false;
}
const auto is_infinity_error = [=](auto o, auto r) {
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
@@ -196,15 +289,27 @@ check_err(const Range& out,
}
if(!res)
{
const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
report_error_stats(err_count, max_err, ref.size());
}
return res;
}
/**
* @brief Check errors between floating point ranges using the specified tolerances
*
* Compares two ranges of brain floating point values within specified relative and absolute
* tolerances.
*
* @tparam Range Type of output range
* @tparam RefRange Type of reference range
* @param out Output range to check
* @param ref Reference range to check against
* @param msg Error message to display if check fails
* @param rtol Relative tolerance
* @param atol Absolute tolerance
* @param allow_infinity_ref Whether to allow infinity in reference values
* @return True if check passes, false otherwise
*/
template <typename Range, typename RefRange>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
@@ -217,12 +322,8 @@ check_err(const Range& out,
double atol = 1e-3,
bool allow_infinity_ref = false)
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
if(check_size_mismatch(out, ref, msg))
return false;
}
const auto is_infinity_error = [=](auto o, auto r) {
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
@@ -256,15 +357,28 @@ check_err(const Range& out,
}
if(!res)
{
const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
report_error_stats(err_count, max_err, ref.size());
}
return res;
}
/**
* @brief Check errors between half precision floating point ranges
*
* Compares two ranges of half precision floating point values within specified tolerances.
* This specialization handles the specific requirements and characteristics of half precision
* floating point comparisons.
*
* @tparam Range Type of output range
* @tparam RefRange Type of reference range
* @param out Output range to check
* @param ref Reference range to check against
* @param msg Error message to display if check fails
* @param rtol Relative tolerance
* @param atol Absolute tolerance
* @param allow_infinity_ref Whether to allow infinity in reference values
* @return True if check passes, false otherwise
*/
template <typename Range, typename RefRange>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
@@ -277,12 +391,8 @@ check_err(const Range& out,
double atol = 1e-3,
bool allow_infinity_ref = false)
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
if(check_size_mismatch(out, ref, msg))
return false;
}
const auto is_infinity_error = [=](auto o, auto r) {
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
@@ -315,15 +425,26 @@ check_err(const Range& out,
}
if(!res)
{
const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
report_error_stats(err_count, max_err, ref.size());
}
return res;
}
/**
* @brief Check errors between integer ranges
*
* Compares two ranges of integer values with an absolute tolerance.
* This specialization handles integer types and optionally int4_t when the
* experimental bit int extension is enabled.
*
* @tparam Range Type of output range
* @tparam RefRange Type of reference range
* @param out Output range to check
* @param ref Reference range to check against
* @param msg Error message to display if check fails
* @param atol Absolute tolerance
* @return True if check passes, false otherwise
*/
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_integral_v<ranges::range_value_t<Range>> &&
@@ -339,12 +460,8 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
double = 0,
double atol = 0)
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
if(check_size_mismatch(out, ref, msg))
return false;
}
bool res{true};
int err_count = 0;
@@ -370,15 +487,28 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
if(!res)
{
const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
report_error_stats(err_count, static_cast<double>(max_err), ref.size());
}
return res;
}
/**
* @brief Check errors between FP8 ranges
*
* Specialized comparison for 8-bit floating point values that takes into account
* the unique characteristics and limitations of FP8 arithmetic, including
* rounding point distances and special handling of infinity values.
*
* @tparam Range Type of output range
* @tparam RefRange Type of reference range
* @param out Output range to check
* @param ref Reference range to check against
* @param msg Error message to display if check fails
* @param max_rounding_point_distance Maximum allowed distance between rounding points
* @param atol Absolute tolerance
* @param allow_infinity_ref Whether to allow infinity in reference values
* @return True if check passes, false otherwise
*/
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, fp8_t>),
@@ -390,12 +520,8 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
double atol = 1e-1,
bool allow_infinity_ref = false)
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
if(check_size_mismatch(out, ref, msg))
return false;
}
const auto is_infinity_error = [=](auto o, auto r) {
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
@@ -447,15 +573,27 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
if(!res)
{
const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
report_error_stats(err_count, max_err, ref.size());
}
return res;
}
/**
* @brief Check errors between BF8 ranges
*
* Specialized comparison for 8-bit brain floating point values that considers
* the specific numerical properties and error characteristics of the BF8 format.
*
* @tparam Range Type of output range
* @tparam RefRange Type of reference range
* @param out Output range to check
* @param ref Reference range to check against
* @param msg Error message to display if check fails
* @param rtol Relative tolerance
* @param atol Absolute tolerance
* @param allow_infinity_ref Whether to allow infinity in reference values
* @return True if check passes, false otherwise
*/
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
@@ -467,12 +605,8 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
double atol = 1e-3,
bool allow_infinity_ref = false)
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
if(check_size_mismatch(out, ref, msg))
return false;
}
const auto is_infinity_error = [=](auto o, auto r) {
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
@@ -505,11 +639,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
if(!res)
{
const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
report_error_stats(err_count, max_err, ref.size());
}
return res;
}