mirror of
https://github.com/amd/blis.git
synced 2026-04-19 23:28:52 +00:00
GTestSuite: computediff improvements (#264)
* GTestSuite: computediff improvements - Using lazy evaluation to only create error strings when the comparison failed - Check if increments are zero, then only check first element
This commit is contained in:
committed by
GitHub
parent
50f3520c33
commit
a8daea04ea
@@ -105,25 +105,25 @@ struct ComparisonHelper{
|
||||
};
|
||||
|
||||
// Generic comparison of f.p. numbers that doesn't check for NaN's and Infs:
|
||||
template<typename T>
|
||||
template<typename T, typename ErrorMessageFunc>
|
||||
testing::AssertionResult NumericalComparisonFPOnly(const char* blis_sol_char,
|
||||
const char* ref_sol_char,
|
||||
const char* comp_helper_char,
|
||||
const T blis_sol,
|
||||
const T ref_sol,
|
||||
const ComparisonHelper comp_helper,
|
||||
const std::string error_message)
|
||||
ErrorMessageFunc error_message_func)
|
||||
{
|
||||
if (comp_helper.binary_comparison)
|
||||
{
|
||||
if (blis_sol == ref_sol) return testing::AssertionSuccess();
|
||||
return testing::AssertionFailure() << error_message;
|
||||
return testing::AssertionFailure() << error_message_func();
|
||||
}
|
||||
else {
|
||||
double error = testinghelpers::getError(blis_sol,ref_sol);
|
||||
if (error <= comp_helper.threshold) return testing::AssertionSuccess();
|
||||
using RT = typename testinghelpers::type_info<T>::real_type;
|
||||
return testing::AssertionFailure() << error_message
|
||||
return testing::AssertionFailure() << error_message_func()
|
||||
<< ", thresh = " << comp_helper.threshold
|
||||
<< ", error = " << error
|
||||
<< " (" << error/std::numeric_limits<RT>::epsilon() << " * eps)";
|
||||
@@ -131,39 +131,39 @@ testing::AssertionResult NumericalComparisonFPOnly(const char* blis_sol_char,
|
||||
}
|
||||
|
||||
// NaN/Inf comparison for real numbers
|
||||
template<typename T>
|
||||
template<typename T, typename ErrorMessageFunc>
|
||||
testing::AssertionResult NumericalComparisonRealNaNInf(const char* blis_sol_char,
|
||||
const char* ref_sol_char,
|
||||
const char* comp_helper_char,
|
||||
const T blis_sol,
|
||||
const T ref_sol,
|
||||
const ComparisonHelper comp_helper,
|
||||
const std::string error_message)
|
||||
ErrorMessageFunc error_message_func)
|
||||
{
|
||||
// if both are NaN return SUCCESS
|
||||
if ((std::isnan(ref_sol)) && (std::isnan(blis_sol)))
|
||||
return testing::AssertionSuccess();
|
||||
// if only one of them is NaN, return FAILURE
|
||||
else if ((std::isnan(ref_sol)) || (std::isnan(blis_sol)))
|
||||
return testing::AssertionFailure() << error_message;
|
||||
return testing::AssertionFailure() << error_message_func();
|
||||
// if both are inf check the sign
|
||||
else if ((std::isinf(ref_sol)) && (std::isinf(blis_sol)))
|
||||
{
|
||||
// check the sign of infs
|
||||
if( ref_sol == blis_sol ) return testing::AssertionSuccess();
|
||||
// both are infs but have different signs, return FAILURE.
|
||||
else return testing::AssertionFailure() << error_message;
|
||||
else return testing::AssertionFailure() << error_message_func();
|
||||
}
|
||||
// if only one of them is Inf
|
||||
else if ((std::isinf(ref_sol)) || (std::isinf(blis_sol)))
|
||||
return testing::AssertionFailure() << error_message;
|
||||
return testing::AssertionFailure() << error_message_func();
|
||||
// If neither reference nor BLIS sol is NaN/Inf do simple comparison, based on relative or absolute error.
|
||||
else return NumericalComparisonFPOnly<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, error_message);
|
||||
else return NumericalComparisonFPOnly<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, error_message_func);
|
||||
}
|
||||
|
||||
// Comparison for complex numbers in the case of NaNs.
|
||||
// Will be re-used for comparison of real and imaginary components.
|
||||
template<typename T, typename RT = typename testinghelpers::type_info<T>::real_type>
|
||||
template<typename T, typename RT = typename testinghelpers::type_info<T>::real_type, typename ErrorMessageFunc>
|
||||
testing::AssertionResult NumericalComparisonNaN(const char* blis_sol_char,
|
||||
const char* ref_sol_char,
|
||||
const char* comp_helper_char,
|
||||
@@ -171,7 +171,7 @@ testing::AssertionResult NumericalComparisonNaN(const char* blis_sol_char,
|
||||
const T ref_sol,
|
||||
const ComparisonHelper comp_helper,
|
||||
const ComplexPart complex_part,
|
||||
const std::string error_message)
|
||||
ErrorMessageFunc error_message_func)
|
||||
{
|
||||
// Assign values to intermediate variables as if we are comparing the real part.
|
||||
RT ref_sol_1 = ref_sol.real, ref_sol_2 = ref_sol.imag, blis_sol_1 = blis_sol.real, blis_sol_2 = blis_sol.imag;
|
||||
@@ -186,14 +186,14 @@ testing::AssertionResult NumericalComparisonNaN(const char* blis_sol_char,
|
||||
// Check if the both parts are NaNs.
|
||||
if ((std::isnan(ref_sol_1)) && (std::isnan(blis_sol_1)))
|
||||
// Check second part for equality based on real NaN/Inf comparison.
|
||||
return NumericalComparisonRealNaNInf<RT>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol_2, ref_sol_2, comp_helper, error_message);
|
||||
return NumericalComparisonRealNaNInf<RT>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol_2, ref_sol_2, comp_helper, error_message_func);
|
||||
// if only one of the first parts is NaN
|
||||
return testing::AssertionFailure() << error_message;
|
||||
return testing::AssertionFailure() << error_message_func();
|
||||
}
|
||||
|
||||
// Comparison for complex numbers in the case of Infs.
|
||||
// Will be re-used for comparison of real and imaginary components.
|
||||
template<typename T, typename RT = typename testinghelpers::type_info<T>::real_type>
|
||||
template<typename T, typename RT = typename testinghelpers::type_info<T>::real_type, typename ErrorMessageFunc>
|
||||
testing::AssertionResult NumericalComparisonInf(const char* blis_sol_char,
|
||||
const char* ref_sol_char,
|
||||
const char* comp_helper_char,
|
||||
@@ -201,7 +201,7 @@ testing::AssertionResult NumericalComparisonInf(const char* blis_sol_char,
|
||||
const T ref_sol,
|
||||
const ComparisonHelper comp_helper,
|
||||
const ComplexPart complex_part,
|
||||
const std::string error_message)
|
||||
ErrorMessageFunc error_message_func)
|
||||
{
|
||||
// Assign values to intermediate variables as if we are comparing the real part.
|
||||
RT ref_sol_1 = ref_sol.real, ref_sol_2 = ref_sol.imag, blis_sol_1 = blis_sol.real, blis_sol_2 = blis_sol.imag;
|
||||
@@ -219,12 +219,12 @@ testing::AssertionResult NumericalComparisonInf(const char* blis_sol_char,
|
||||
// check the sign of infs
|
||||
if( ref_sol_1 == blis_sol_1 )
|
||||
// Check second part for equality based on real NaN/Inf comparison.
|
||||
return NumericalComparisonRealNaNInf<RT>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol_2, ref_sol_2, comp_helper, error_message);
|
||||
return NumericalComparisonRealNaNInf<RT>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol_2, ref_sol_2, comp_helper, error_message_func);
|
||||
// if both are infs but have different signs, return FAILURE.
|
||||
else return testing::AssertionFailure() << error_message;
|
||||
else return testing::AssertionFailure() << error_message_func();
|
||||
}
|
||||
// if only one of them is Inf
|
||||
return testing::AssertionFailure() << error_message;
|
||||
return testing::AssertionFailure() << error_message_func();
|
||||
}
|
||||
|
||||
// Comparisons that take into account the presence of NaNs and Infs, printing variable name:
|
||||
@@ -238,26 +238,30 @@ testing::AssertionResult NumericalComparison(const char* var_name_char,
|
||||
const T ref_sol,
|
||||
const ComparisonHelper comp_helper)
|
||||
{
|
||||
// Base error message used for scalar values
|
||||
std::string error_message = var_name_char;
|
||||
error_message += " = " + var_name + ", ";
|
||||
error_message += blis_sol_char;
|
||||
error_message += " = " + testinghelpers::to_string(blis_sol) + ", ";
|
||||
error_message += ref_sol_char;
|
||||
error_message += " = " + testinghelpers::to_string(ref_sol);
|
||||
// If we are comparing a vector, update error message to include the current index
|
||||
if(comp_helper.object_type == VECTOR)
|
||||
error_message += ", i = " + std::to_string(comp_helper.i);
|
||||
// If we are comparing a matrix, update error message to include the current indices
|
||||
else if(comp_helper.object_type == MATRIX)
|
||||
error_message += ", i = " + std::to_string(comp_helper.i) + ", j = " + std::to_string(comp_helper.j);
|
||||
// Lazy string construction - only create error message when actually needed
|
||||
auto create_error_message = [&]() -> std::string {
|
||||
// Base error message used for scalar values
|
||||
std::string error_message = var_name_char;
|
||||
error_message += " = " + var_name + ", ";
|
||||
error_message += blis_sol_char;
|
||||
error_message += " = " + testinghelpers::to_string(blis_sol) + ", ";
|
||||
error_message += ref_sol_char;
|
||||
error_message += " = " + testinghelpers::to_string(ref_sol);
|
||||
// If we are comparing a vector, update error message to include the current index
|
||||
if(comp_helper.object_type == VECTOR)
|
||||
error_message += ", i = " + std::to_string(comp_helper.i);
|
||||
// If we are comparing a matrix, update error message to include the current indices
|
||||
else if(comp_helper.object_type == MATRIX)
|
||||
error_message += ", i = " + std::to_string(comp_helper.i) + ", j = " + std::to_string(comp_helper.j);
|
||||
return error_message;
|
||||
};
|
||||
|
||||
// Check if NaN/Inf comparison is necessary and if so, proceed.
|
||||
// Otherwise, call numerical comparison only, without considering NaNs and Infs.
|
||||
if (comp_helper.nan_inf_check)
|
||||
{
|
||||
if constexpr (testinghelpers::type_info<T>::is_real)
|
||||
return NumericalComparisonRealNaNInf<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, error_message);
|
||||
return NumericalComparisonRealNaNInf<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, create_error_message);
|
||||
// If it's complex we need to check real and imaginary parts.
|
||||
else
|
||||
{
|
||||
@@ -265,35 +269,35 @@ testing::AssertionResult NumericalComparison(const char* var_name_char,
|
||||
if ((std::isnan(ref_sol.real)) || (std::isnan(blis_sol.real)))
|
||||
{
|
||||
ComplexPart complex_part = REAL;
|
||||
return NumericalComparisonNaN<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, complex_part, error_message);
|
||||
return NumericalComparisonNaN<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, complex_part, create_error_message);
|
||||
}
|
||||
// Check if any of the imag parts is NaN, and if so, call into NaN comparator.
|
||||
else if ((std::isnan(ref_sol.imag)) || (std::isnan(blis_sol.imag)))
|
||||
{
|
||||
ComplexPart complex_part = IMAGINARY;
|
||||
return NumericalComparisonNaN<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, complex_part, error_message);
|
||||
return NumericalComparisonNaN<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, complex_part, create_error_message);
|
||||
}
|
||||
// Check if any of the real parts is Inf, and if so, call into Inf comparator.
|
||||
else if ((std::isinf(ref_sol.real)) || (std::isinf(blis_sol.real)))
|
||||
{
|
||||
ComplexPart complex_part = REAL;
|
||||
return NumericalComparisonInf<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, complex_part, error_message);
|
||||
return NumericalComparisonInf<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, complex_part, create_error_message);
|
||||
}
|
||||
// Check if any of the imag parts is NaN or Inf, and if so, call into Inf comparator.
|
||||
else if ((std::isinf(ref_sol.imag)) || (std::isinf(blis_sol.imag)))
|
||||
{
|
||||
ComplexPart complex_part = IMAGINARY;
|
||||
return NumericalComparisonInf<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, complex_part, error_message);
|
||||
return NumericalComparisonInf<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, complex_part, create_error_message);
|
||||
}
|
||||
// If neither reference nor BLIS sol is NaN or Inf, or if NaN/Inf checks are not necessary,
|
||||
// do simple comparison, based on relative or absolute error.
|
||||
else
|
||||
return NumericalComparisonFPOnly<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, error_message);
|
||||
return NumericalComparisonFPOnly<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, create_error_message);
|
||||
}
|
||||
}
|
||||
// If NaN/Inf checks are not necessary, do simple comparison, based on relative or absolute error.
|
||||
else
|
||||
return NumericalComparisonFPOnly<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, error_message);
|
||||
return NumericalComparisonFPOnly<T>(blis_sol_char, ref_sol_char, comp_helper_char, blis_sol, ref_sol, comp_helper, create_error_message);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -331,6 +335,10 @@ void computediff( std::string var_name, gtint_t n, T *blis_sol, T *ref_sol, gtin
|
||||
comp_helper.nan_inf_check = nan_inf_check;
|
||||
comp_helper.binary_comparison = true;
|
||||
|
||||
// If increment is zero, we just have one element to compare.
|
||||
if (abs_inc == 0)
|
||||
n = 1;
|
||||
|
||||
// In case inc is negative in a call to BLIS APIs, we just access it from the end to the beginning,
|
||||
// so practically nothing changes. Access from beginning to end to optimize memory operations.
|
||||
for (gtint_t i = 0; i < n; i++)
|
||||
@@ -359,6 +367,9 @@ void computediff( std::string var_name, gtint_t n, T *blis_sol, T *ref_sol, gtin
|
||||
ComparisonHelper comp_helper(VECTOR, thresh);
|
||||
comp_helper.nan_inf_check = nan_inf_check;
|
||||
|
||||
// If increment is zero, we just have one element to compare.
|
||||
if (abs_inc == 0)
|
||||
n = 1;
|
||||
// In case inc is negative in a call to BLIS APIs, we just access it from the end to the beginning,
|
||||
// so practically nothing changes. Access from beginning to end to optimize memory operations.
|
||||
for (gtint_t i = 0; i < n; i++)
|
||||
@@ -512,16 +523,21 @@ testing::AssertionResult EqualityComparison(const char* var_name_char,
|
||||
const T ref_sol,
|
||||
const ComparisonHelper comp_helper)
|
||||
{
|
||||
// Base error message used for scalar values
|
||||
std::string error_message = var_name_char;
|
||||
error_message += " = " + var_name + ", ";
|
||||
error_message += blis_sol_char;
|
||||
error_message += " = " + testinghelpers::to_string(blis_sol) + ", ";
|
||||
error_message += ref_sol_char;
|
||||
error_message += " = " + testinghelpers::to_string(ref_sol);
|
||||
|
||||
|
||||
if (blis_sol == ref_sol) return testing::AssertionSuccess();
|
||||
return testing::AssertionFailure() << error_message;
|
||||
|
||||
// Lazy string construction - only create when actually needed
|
||||
auto create_error_message = [&]() -> std::string {
|
||||
std::string error_message = var_name_char;
|
||||
error_message += " = " + var_name + ", ";
|
||||
error_message += blis_sol_char;
|
||||
error_message += " = " + testinghelpers::to_string(blis_sol) + ", ";
|
||||
error_message += ref_sol_char;
|
||||
error_message += " = " + testinghelpers::to_string(ref_sol);
|
||||
return error_message;
|
||||
};
|
||||
|
||||
return testing::AssertionFailure() << create_error_message();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -542,5 +558,4 @@ inline void computediff<char>( std::string var_name, char blis_sol, char ref_sol
|
||||
{
|
||||
ComparisonHelper comp_helper(SCALAR);
|
||||
ASSERT_PRED_FORMAT4(EqualityComparison<char>, var_name, blis_sol, ref_sol, comp_helper);
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user