diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 6bb6fc49f6..41ecdd2b75 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -343,6 +343,24 @@ struct equal __host__ __device__ equal()->equal; +template <> +struct equal +{ + CK_TILE_HOST_DEVICE constexpr bool operator()(float lhs, float rhs) const + { + return bit_cast(lhs) == bit_cast(rhs); + } +}; + +template <> +struct equal +{ + CK_TILE_HOST_DEVICE constexpr bool operator()(double lhs, double rhs) const + { + return bit_cast(lhs) == bit_cast(rhs); + } +}; + template struct less {