From ad45cf8613211a11308dcae5637c829e59710668 Mon Sep 17 00:00:00 2001
From: Po Yen Chen
Date: Tue, 9 Apr 2024 07:41:30 +0000
Subject: [PATCH] Support heterogeneous argument for binary function types

---
 include/ck_tile/core/numeric/math.hpp | 78 +++++++++++++++++++++------
 1 file changed, 63 insertions(+), 15 deletions(-)

diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp
index 0c67a640af..868fff05b7 100644
--- a/include/ck_tile/core/numeric/math.hpp
+++ b/include/ck_tile/core/numeric/math.hpp
@@ -19,27 +19,75 @@ struct scales
     CK_TILE_HOST_DEVICE constexpr T operator()(T a) const { return s * a; }
 };
 
-template <typename T>
+template <typename Left = void, typename Right = Left>
 struct plus
 {
-    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a + b; }
-};
-
-template <typename T>
-struct minus
-{
-    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a - b; }
-};
-
-struct multiplies
-{
-    template <typename A, typename B>
-    CK_TILE_HOST_DEVICE constexpr auto operator()(const A& a, const B& b) const
+    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
+        -> decltype(lhs + rhs)
     {
-        return a * b;
+        return lhs + rhs;
     }
 };
 
+template <>
+struct plus<void, void>
+{
+    template <typename Left, typename Right>
+    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
+        -> decltype(lhs + rhs)
+    {
+        return lhs + rhs;
+    }
+};
+
+__host__ __device__ plus()->plus<void, void>;
+
+template <typename Left = void, typename Right = Left>
+struct minus
+{
+    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
+        -> decltype(lhs - rhs)
+    {
+        return lhs - rhs;
+    }
+};
+
+template <>
+struct minus<void, void>
+{
+    template <typename Left, typename Right>
+    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
+        -> decltype(lhs - rhs)
+    {
+        return lhs - rhs;
+    }
+};
+
+__host__ __device__ minus()->minus<void, void>;
+
+template <typename Left = void, typename Right = Left>
+struct multiplies
+{
+    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
+        -> decltype(lhs * rhs)
+    {
+        return lhs * rhs;
+    }
+};
+
+template <>
+struct multiplies<void, void>
+{
+    template <typename Left, typename Right>
+    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
+        -> decltype(lhs * rhs)
+    {
+        return lhs * rhs;
+    }
+};
+
+__host__ __device__ multiplies()->multiplies<void, void>;
+
 template <typename T>
 struct maximize
 {