Avoid too much generalizing check_err()

This commit is contained in:
Po-Yen, Chen
2022-08-19 11:58:48 -04:00
parent 4d4a659cd6
commit c1fbabea04

View File

@@ -150,19 +150,17 @@ check_err(const std::vector<T>& out,
return res;
}
template <typename Out, typename Ref>
template <typename T>
std::enable_if_t<
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
(is_signed_integral_v<Out> || std::is_same_v<Out, ck::int4_t>)&&(
is_signed_integral_v<Ref> || std::is_same_v<Ref, ck::int4_t>)&&
(is_signed_integral_v<T> || std::is_same_v<T, ck::int4_t>)&&
#else
is_signed_integral_v<Out> && is_signed_integral_v<Ref> &&
is_signed_integral_v<T> &&
#endif
(sizeof(Out) <= sizeof(Ref) && sizeof(Ref) <= sizeof(int64_t)) &&
!std::is_same_v<Out, bhalf_t>,
sizeof(T) <= sizeof(int64_t),
bool>
check_err(const std::vector<Out>& out,
const std::vector<Ref>& ref,
check_err(const std::vector<T>& out,
const std::vector<T>& ref,
const std::string& msg = "Error: Incorrect results!",
double = 0,
double atol = 0)
@@ -180,12 +178,8 @@ check_err(const std::vector<Out>& out,
int64_t max_err = std::numeric_limits<int64_t>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
{
constexpr bool should_downcast_ref =
(sizeof(Out) < sizeof(Ref) || !std::is_same_v<Out, Ref>);
int64_t o = out[i];
/// TODO: clamp value if necessary
int64_t r = static_cast<std::conditional_t<should_downcast_ref, Out, Ref>>(ref[i]);
int64_t r = ref[i];
err = std::abs(o - r);
if(err > atol)