Simplify static_cast if-lands (#1828)

This commit is contained in:
Mateusz Ozga
2025-01-21 23:23:19 +01:00
committed by GitHub
parent 3c93d3c444
commit 3db77bc4f2
3 changed files with 63 additions and 93 deletions

View File

@@ -28,14 +28,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
using I8 = int8_t;
using I32 = int32_t;
static_assert(std::is_same_v<ComputeDataType, F8> || std::is_same_v<ComputeDataType, F16> ||
std::is_same_v<ComputeDataType, BF16> ||
std::is_same_v<ComputeDataType, F32> || std::is_same_v<ComputeDataType, I8> ||
std::is_same_v<ComputeDataType, I32> || std::is_same_v<ComputeDataType, int>,
static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
double compute_error = 0;
if constexpr(std::is_same_v<ComputeDataType, I8> || std::is_same_v<ComputeDataType, I32> ||
std::is_same_v<ComputeDataType, int>)
if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
{
return 0;
}
@@ -44,14 +41,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
}
static_assert(std::is_same_v<OutDataType, F8> || std::is_same_v<OutDataType, F16> ||
std::is_same_v<OutDataType, BF16> || std::is_same_v<OutDataType, F32> ||
std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> ||
std::is_same_v<OutDataType, int>,
static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled OutDataType for setting up the relative threshold!");
double output_error = 0;
if constexpr(std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> ||
std::is_same_v<OutDataType, int>)
if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
{
return 0;
}
@@ -61,14 +55,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
}
double midway_error = std::max(compute_error, output_error);
static_assert(std::is_same_v<AccDataType, F8> || std::is_same_v<AccDataType, F16> ||
std::is_same_v<AccDataType, BF16> || std::is_same_v<AccDataType, F32> ||
std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> ||
std::is_same_v<AccDataType, int>,
static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled AccDataType for setting up the relative threshold!");
double acc_error = 0;
if constexpr(std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> ||
std::is_same_v<AccDataType, int>)
if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
{
return 0;
}
@@ -89,15 +80,12 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
using I8 = int8_t;
using I32 = int32_t;
static_assert(std::is_same_v<ComputeDataType, F8> || std::is_same_v<ComputeDataType, F16> ||
std::is_same_v<ComputeDataType, BF16> ||
std::is_same_v<ComputeDataType, F32> || std::is_same_v<ComputeDataType, I8> ||
std::is_same_v<ComputeDataType, I32> || std::is_same_v<ComputeDataType, int>,
static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
auto expo = std::log2(std::abs(max_possible_num));
double compute_error = 0;
if constexpr(std::is_same_v<ComputeDataType, I8> || std::is_same_v<ComputeDataType, I32> ||
std::is_same_v<ComputeDataType, int>)
if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
{
return 0;
}
@@ -106,14 +94,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
}
static_assert(std::is_same_v<OutDataType, F8> || std::is_same_v<OutDataType, F16> ||
std::is_same_v<OutDataType, BF16> || std::is_same_v<OutDataType, F32> ||
std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> ||
std::is_same_v<OutDataType, int>,
static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
double output_error = 0;
if constexpr(std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> ||
std::is_same_v<OutDataType, int>)
if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
{
return 0;
}
@@ -123,14 +108,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
}
double midway_error = std::max(compute_error, output_error);
static_assert(std::is_same_v<AccDataType, F8> || std::is_same_v<AccDataType, F16> ||
std::is_same_v<AccDataType, BF16> || std::is_same_v<AccDataType, F32> ||
std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> ||
std::is_same_v<AccDataType, int>,
static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
double acc_error = 0;
if constexpr(std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> ||
std::is_same_v<AccDataType, int>)
if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
{
return 0;
}