Support heterogeneous argument for binary function types

This commit is contained in:
Po Yen Chen
2024-04-09 07:41:30 +00:00
parent db0d7c6a99
commit ad45cf8613

View File

@@ -19,27 +19,75 @@ struct scales
CK_TILE_HOST_DEVICE constexpr T operator()(T a) const { return s * a; }
};
template <typename T>
template <typename Left = void, typename Right = Left>
struct plus
{
CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a + b; }
};
template <typename T>
struct minus
{
CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a - b; }
};
struct multiplies
{
template <typename A, typename B>
CK_TILE_HOST_DEVICE constexpr auto operator()(const A& a, const B& b) const
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs + rhs)
{
return a * b;
return lhs + rhs;
}
};
template <>
struct plus<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs + rhs)
{
return lhs + rhs;
}
};
__host__ __device__ plus()->plus<void, void>;
template <typename Left = void, typename Right = Left>
struct minus
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs - rhs)
{
return lhs - rhs;
}
};
template <>
struct minus<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs - rhs)
{
return lhs - rhs;
}
};
__host__ __device__ minus()->minus<void, void>;
template <typename Left = void, typename Right = Left>
struct multiplies
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs * rhs)
{
return lhs * rhs;
}
};
template <>
struct multiplies<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs * rhs)
{
return lhs * rhs;
}
};
__host__ __device__ multiplies()->multiplies<void, void>;
template <typename T>
struct maximize
{