mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
Re-use already-existing scales<> functor template
This commit is contained in:
@@ -13,12 +13,37 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T, T s>
|
||||
template <typename Scale, Scale lhs>
|
||||
struct scales_c
|
||||
{
|
||||
template <typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const -> decltype(lhs * rhs)
|
||||
{
|
||||
return lhs * rhs;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Scale>
|
||||
struct scales
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T a) const { return s * a; }
|
||||
static_assert(std::is_copy_constructible_v<Scale>);
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr explicit scales(Scale lhs) : lhs_(lhs) {}
|
||||
|
||||
template <typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const
|
||||
-> decltype(std::declval<const Scale&>() * rhs)
|
||||
{
|
||||
return lhs_ * rhs;
|
||||
}
|
||||
|
||||
private:
|
||||
Scale lhs_;
|
||||
};
|
||||
|
||||
template <typename Scale>
|
||||
__host__ __device__ scales(Scale)->scales<Scale>;
|
||||
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct plus
|
||||
{
|
||||
|
||||
@@ -42,24 +42,6 @@ CK_TILE_HOST auto compose(F... f)
|
||||
return composer<F...>(f...);
|
||||
}
|
||||
|
||||
// start of unary element function
|
||||
|
||||
struct scale
|
||||
{
|
||||
CK_TILE_HOST_DEVICE scale(float factor) : factor_(factor) {}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& x) const;
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE constexpr float operator()<float>(const float& x) const
|
||||
{
|
||||
return factor_ * x;
|
||||
};
|
||||
|
||||
float factor_;
|
||||
};
|
||||
|
||||
// TODO: Overload numeric::min() and numeric::max()
|
||||
struct saturate_f8
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user