GTestSuite: computediff improvements (#264)

* GTestSuite: computediff improvements

- Using lazy evaluation to only create error strings when the comparison failed
- Check if increments are zero, then only check first element
This commit is contained in:
Vlachopoulou, Eleni
2025-11-17 08:30:20 +00:00
committed by GitHub
parent 50f3520c33
commit a8daea04ea

View File

@@ -105,25 +105,25 @@ struct ComparisonHelper{
};
// Generic comparison of f.p. numbers that doesn't check for NaN's and Infs:
template<typename T>
template<typename T, typename ErrorMessageFunc>
testing::AssertionResult NumericalComparisonFPOnly(const char* blis_sol_char,
const char* ref_sol_char,
const char* comp_helper_char,
const T blis_sol,
const T ref_sol,
const ComparisonHelper comp_helper,
const std::string error_message)
ErrorMessageFunc error_message_func)
{
if (comp_helper.binary_comparison)
{
if (blis_sol == ref_sol) return testing::AssertionSuccess();
return testing::AssertionFailure() << error_message;
return testing::AssertionFailure() << error_message_func();
}
else {
double error = testinghelpers::getError(blis_sol,ref_sol);
if (error <= comp_helper.threshold) return testing::AssertionSuccess();
using RT = typename testinghelpers::type_info<T>::real_type;
return testing::AssertionFailure() << error_message
return testing::AssertionFailure() << error_message_func()
<< ", thresh = " << comp_helper.threshold
<< ", error = " << error
<< " (" << error/std::numeric_limits<RT>::epsilon() << " * eps)";
@@ -131,39 +131,39 @@ testing::AssertionResult NumericalComparisonFPOnly(const char* blis_sol_char,
}
// NaN/Inf comparison for real numbers
template<typename T>
template<typename T, typename ErrorMessageFunc>
testing::AssertionResult NumericalComparisonRealNaNInf(const char* blis_sol_char,
const char* ref_sol_char,
const char* comp_helper_char,
const T blis_sol,
const T ref_sol,
const ComparisonHelper comp_helper,
const std::string error_message)
ErrorMessageFunc error_message_func)
{
// if both are NaN return SUCCESS
if ((std::isnan(ref_sol)) && (std::isnan(blis_sol)))
return testing::AssertionSuccess();
// if only one of them is NaN, return FAILURE
else if ((std::isnan(ref_sol)) || (std::isnan(blis_sol)))
return testing::AssertionFailure() << error_message;
return testing::AssertionFailure() << error_message_func();
// if both are inf check the sign
else if ((std::isinf(ref_sol)) && (std::isinf(blis_sol)))
{
// check the sign of infs
if( ref_sol == blis_sol ) return testing::AssertionSuccess();
// both are infs but have different signs, return FAILURE.
else return testing::AssertionFailure() << error_message;
else return testing::AssertionFailure() << error_message_func();
}
// if only one of them is Inf
else if ((std::isinf(ref_sol)) || (std::isinf(blis_sol)))
return testing::AssertionFailure() << error_message;
return testing::AssertionFailure() << error_message_func();
// If neither reference nor BLIS sol is NaN/Inf do simple comparison, based on relative or absolute error.
else return NumericalComparisonFPOnly<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, error_message);
else return NumericalComparisonFPOnly<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, error_message_func);
}
// Comparison for complex numbers in the case of NaNs.
// Will be re-used for comparison of real and imaginary components.
template<typename T, typename RT = typename testinghelpers::type_info<T>::real_type>
template<typename T, typename RT = typename testinghelpers::type_info<T>::real_type, typename ErrorMessageFunc>
testing::AssertionResult NumericalComparisonNaN(const char* blis_sol_char,
const char* ref_sol_char,
const char* comp_helper_char,
@@ -171,7 +171,7 @@ testing::AssertionResult NumericalComparisonNaN(const char* blis_sol_char,
const T ref_sol,
const ComparisonHelper comp_helper,
const ComplexPart complex_part,
const std::string error_message)
ErrorMessageFunc error_message_func)
{
// Assign values to intermediate variables as if we are comparing the real part.
RT ref_sol_1 = ref_sol.real, ref_sol_2 = ref_sol.imag, blis_sol_1 = blis_sol.real, blis_sol_2 = blis_sol.imag;
@@ -186,14 +186,14 @@ testing::AssertionResult NumericalComparisonNaN(const char* blis_sol_char,
// Check if the both parts are NaNs.
if ((std::isnan(ref_sol_1)) && (std::isnan(blis_sol_1)))
// Check second part for equality based on real NaN/Inf comparison.
return NumericalComparisonRealNaNInf<RT>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol_2, ref_sol_2, comp_helper, error_message);
return NumericalComparisonRealNaNInf<RT>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol_2, ref_sol_2, comp_helper, error_message_func);
// if only one of the first parts is NaN
return testing::AssertionFailure() << error_message;
return testing::AssertionFailure() << error_message_func();
}
// Comparison for complex numbers in the case of Infs.
// Will be re-used for comparison of real and imaginary components.
template<typename T, typename RT = typename testinghelpers::type_info<T>::real_type>
template<typename T, typename RT = typename testinghelpers::type_info<T>::real_type, typename ErrorMessageFunc>
testing::AssertionResult NumericalComparisonInf(const char* blis_sol_char,
const char* ref_sol_char,
const char* comp_helper_char,
@@ -201,7 +201,7 @@ testing::AssertionResult NumericalComparisonInf(const char* blis_sol_char,
const T ref_sol,
const ComparisonHelper comp_helper,
const ComplexPart complex_part,
const std::string error_message)
ErrorMessageFunc error_message_func)
{
// Assign values to intermediate variables as if we are comparing the real part.
RT ref_sol_1 = ref_sol.real, ref_sol_2 = ref_sol.imag, blis_sol_1 = blis_sol.real, blis_sol_2 = blis_sol.imag;
@@ -219,12 +219,12 @@ testing::AssertionResult NumericalComparisonInf(const char* blis_sol_char,
// check the sign of infs
if( ref_sol_1 == blis_sol_1 )
// Check second part for equality based on real NaN/Inf comparison.
return NumericalComparisonRealNaNInf<RT>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol_2, ref_sol_2, comp_helper, error_message);
return NumericalComparisonRealNaNInf<RT>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol_2, ref_sol_2, comp_helper, error_message_func);
// if both are infs but have different signs, return FAILURE.
else return testing::AssertionFailure() << error_message;
else return testing::AssertionFailure() << error_message_func();
}
// if only one of them is Inf
return testing::AssertionFailure() << error_message;
return testing::AssertionFailure() << error_message_func();
}
// Comparisons that take into account the presence of NaNs and Infs, printing variable name:
@@ -238,26 +238,30 @@ testing::AssertionResult NumericalComparison(const char* var_name_char,
const T ref_sol,
const ComparisonHelper comp_helper)
{
// Base error message used for scalar values
std::string error_message = var_name_char;
error_message += " = " + var_name + ", ";
error_message += blis_sol_char;
error_message += " = " + testinghelpers::to_string(blis_sol) + ", ";
error_message += ref_sol_char;
error_message += " = " + testinghelpers::to_string(ref_sol);
// If we are comparing a vector, update error message to include the current index
if(comp_helper.object_type == VECTOR)
error_message += ", i = " + std::to_string(comp_helper.i);
// If we are comparing a matrix, update error message to include the current indices
else if(comp_helper.object_type == MATRIX)
error_message += ", i = " + std::to_string(comp_helper.i) + ", j = " + std::to_string(comp_helper.j);
// Lazy string construction - only create error message when actually needed
auto create_error_message = [&]() -> std::string {
// Base error message used for scalar values
std::string error_message = var_name_char;
error_message += " = " + var_name + ", ";
error_message += blis_sol_char;
error_message += " = " + testinghelpers::to_string(blis_sol) + ", ";
error_message += ref_sol_char;
error_message += " = " + testinghelpers::to_string(ref_sol);
// If we are comparing a vector, update error message to include the current index
if(comp_helper.object_type == VECTOR)
error_message += ", i = " + std::to_string(comp_helper.i);
// If we are comparing a matrix, update error message to include the current indices
else if(comp_helper.object_type == MATRIX)
error_message += ", i = " + std::to_string(comp_helper.i) + ", j = " + std::to_string(comp_helper.j);
return error_message;
};
// Check if NaN/Inf comparison is necessary and if so, proceed.
// Otherwise, call numerical comparison only, without considering NaNs and Infs.
if (comp_helper.nan_inf_check)
{
if constexpr (testinghelpers::type_info<T>::is_real)
return NumericalComparisonRealNaNInf<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, error_message);
return NumericalComparisonRealNaNInf<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, create_error_message);
// If it's complex we need to check real and imaginary parts.
else
{
@@ -265,35 +269,35 @@ testing::AssertionResult NumericalComparison(const char* var_name_char,
if ((std::isnan(ref_sol.real)) || (std::isnan(blis_sol.real)))
{
ComplexPart complex_part = REAL;
return NumericalComparisonNaN<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, complex_part, error_message);
return NumericalComparisonNaN<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, complex_part, create_error_message);
}
// Check if any of the imag parts is NaN, and if so, call into NaN comparator.
else if ((std::isnan(ref_sol.imag)) || (std::isnan(blis_sol.imag)))
{
ComplexPart complex_part = IMAGINARY;
return NumericalComparisonNaN<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, complex_part, error_message);
return NumericalComparisonNaN<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, complex_part, create_error_message);
}
// Check if any of the real parts is Inf, and if so, call into Inf comparator.
else if ((std::isinf(ref_sol.real)) || (std::isinf(blis_sol.real)))
{
ComplexPart complex_part = REAL;
return NumericalComparisonInf<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, complex_part, error_message);
return NumericalComparisonInf<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, complex_part, create_error_message);
}
// Check if any of the imag parts is NaN or Inf, and if so, call into Inf comparator.
else if ((std::isinf(ref_sol.imag)) || (std::isinf(blis_sol.imag)))
{
ComplexPart complex_part = IMAGINARY;
return NumericalComparisonInf<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, complex_part, error_message);
return NumericalComparisonInf<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, complex_part, create_error_message);
}
// If neither reference nor BLIS sol is NaN or Inf, or if NaN/Inf checks are not necessary,
// do simple comparison, based on relative or absolute error.
else
return NumericalComparisonFPOnly<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, error_message);
return NumericalComparisonFPOnly<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, create_error_message);
}
}
// If NaN/Inf checks are not necessary, do simple comparison, based on relative or absolute error.
else
return NumericalComparisonFPOnly<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, error_message);
return NumericalComparisonFPOnly<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, create_error_message);
}
/**
@@ -331,6 +335,10 @@ void computediff( std::string var_name, gtint_t n, T *blis_sol, T *ref_sol, gtin
comp_helper.nan_inf_check = nan_inf_check;
comp_helper.binary_comparison = true;
// If increment is zero, we just have one element to compare.
if (abs_inc == 0)
n = 1;
// In case inc is negative in a call to BLIS APIs, we just access it from the end to the beginning,
// so practically nothing changes. Access from beginning to end to optimize memory operations.
for (gtint_t i = 0; i < n; i++)
@@ -359,6 +367,9 @@ void computediff( std::string var_name, gtint_t n, T *blis_sol, T *ref_sol, gtin
ComparisonHelper comp_helper(VECTOR, thresh);
comp_helper.nan_inf_check = nan_inf_check;
// If increment is zero, we just have one element to compare.
if (abs_inc == 0)
n = 1;
// In case inc is negative in a call to BLIS APIs, we just access it from the end to the beginning,
// so practically nothing changes. Access from beginning to end to optimize memory operations.
for (gtint_t i = 0; i < n; i++)
@@ -512,16 +523,21 @@ testing::AssertionResult EqualityComparison(const char* var_name_char,
const T ref_sol,
const ComparisonHelper comp_helper)
{
// Base error message used for scalar values
std::string error_message = var_name_char;
error_message += " = " + var_name + ", ";
error_message += blis_sol_char;
error_message += " = " + testinghelpers::to_string(blis_sol) + ", ";
error_message += ref_sol_char;
error_message += " = " + testinghelpers::to_string(ref_sol);
if (blis_sol == ref_sol) return testing::AssertionSuccess();
return testing::AssertionFailure() << error_message;
// Lazy string construction - only create when actually needed
auto create_error_message = [&]() -> std::string {
std::string error_message = var_name_char;
error_message += " = " + var_name + ", ";
error_message += blis_sol_char;
error_message += " = " + testinghelpers::to_string(blis_sol) + ", ";
error_message += ref_sol_char;
error_message += " = " + testinghelpers::to_string(ref_sol);
return error_message;
};
return testing::AssertionFailure() << create_error_message();
}
/**
@@ -542,5 +558,4 @@ inline void computediff<char>( std::string var_name, char blis_sol, char ref_sol
{
ComparisonHelper comp_helper(SCALAR);
ASSERT_PRED_FORMAT4(EqualityComparison<char>, var_name, blis_sol, ref_sol, comp_helper);
}
}