Extend less_equal<>

This commit is contained in:
Po Yen Chen
2024-04-09 13:38:40 +00:00
parent c6eac9746f
commit 3f57b3068a

View File

@@ -320,16 +320,91 @@ CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Ys... ys)
return lcm(x, lcm(ys...));
}
template <typename T>
template <typename Left = void, typename Right = Left>
struct equal
{
CK_TILE_HOST_DEVICE constexpr bool operator()(T x, T y) const { return x == y; }
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs == rhs)
{
return lhs == rhs;
}
};
template <typename T>
template <>
struct equal<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs == rhs)
{
return lhs == rhs;
}
};
__host__ __device__ equal()->equal<void, void>;
template <typename Left = void, typename Right = Left>
struct less
{
CK_TILE_HOST_DEVICE constexpr bool operator()(T x, T y) const { return x < y; }
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs < rhs)
{
return lhs < rhs;
}
};
template <>
struct less<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs < rhs)
{
return lhs < rhs;
}
};
__host__ __device__ less()->less<void, void>;
template <typename Left = void, typename Right = Left>
struct less_equal
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs <= rhs)
{
return lhs <= rhs;
}
};
template <>
struct less_equal<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs <= rhs)
{
return lhs <= rhs;
}
};
__host__ __device__ less_equal()->less_equal<void, void>;
template <>
struct less_equal<float, float>
{
CK_TILE_HOST_DEVICE constexpr bool operator()(float lhs, float rhs) const
{
return lhs < rhs || bit_cast<uint32_t>(lhs) == bit_cast<uint32_t>(rhs);
}
};
template <>
struct less_equal<double, double>
{
CK_TILE_HOST_DEVICE constexpr bool operator()(double lhs, double rhs) const
{
return lhs < rhs || bit_cast<uint64_t>(lhs) == bit_cast<uint64_t>(rhs);
}
};
CK_TILE_HOST_DEVICE constexpr int32_t next_power_of_two(int32_t x)