diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index bfc0c9c010..6bb6fc49f6 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -320,16 +320,91 @@ CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Ys... ys) return lcm(x, lcm(ys...)); } -template +template struct equal { - CK_TILE_HOST_DEVICE constexpr bool operator()(T x, T y) const { return x == y; } + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const + -> decltype(lhs == rhs) + { + return lhs == rhs; + } }; -template +template <> +struct equal +{ + template + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const + -> decltype(lhs == rhs) + { + return lhs == rhs; + } +}; + +__host__ __device__ equal()->equal; + +template struct less { - CK_TILE_HOST_DEVICE constexpr bool operator()(T x, T y) const { return x < y; } + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const + -> decltype(lhs < rhs) + { + return lhs < rhs; + } +}; + +template <> +struct less +{ + template + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const + -> decltype(lhs < rhs) + { + return lhs < rhs; + } +}; + +__host__ __device__ less()->less; + +template +struct less_equal +{ + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const + -> decltype(lhs <= rhs) + { + return lhs <= rhs; + } +}; + +template <> +struct less_equal +{ + template + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const + -> decltype(lhs <= rhs) + { + return lhs <= rhs; + } +}; + +__host__ __device__ less_equal()->less_equal; + +template <> +struct less_equal +{ + CK_TILE_HOST_DEVICE constexpr bool operator()(float lhs, float rhs) const + { + return lhs < rhs || bit_cast(lhs) == bit_cast(rhs); + } +}; + +template <> +struct less_equal +{ + CK_TILE_HOST_DEVICE constexpr bool operator()(double lhs, double rhs) const + { + return lhs < rhs || bit_cast(lhs) == bit_cast(rhs); + } }; CK_TILE_HOST_DEVICE constexpr int32_t next_power_of_two(int32_t x)