mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-05 20:55:59 +00:00
unary fixes
This commit is contained in:
@@ -730,6 +730,15 @@ struct UnaryAbs
|
||||
{
|
||||
y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
y = ck::type_convert<bhalf_t>(ck::math::abs(x));
|
||||
};
|
||||
};
|
||||
|
||||
struct UnarySqrt
|
||||
@@ -829,6 +838,9 @@ struct Relu
|
||||
y = x > 0 ? x : 0;
|
||||
}
|
||||
|
||||
// Two-type call operator for Relu: this is a declaration only (no body here).
// NOTE(review): definitions presumably come from explicit specializations such
// as <bhalf_t, float>; confirm no other (Y, X) pair is expected to link.
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
|
||||
{
|
||||
@@ -836,6 +848,13 @@ struct Relu
|
||||
float y_f32 = x_f32 > 0 ? x_f32 : 0;
|
||||
y = type_convert<bhalf_t>(y_f32);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
float y_f32 = x > 0 ? x : 0;
|
||||
y = type_convert<bhalf_t>(y_f32);
|
||||
};
|
||||
};
|
||||
|
||||
// Fast GeLU
|
||||
@@ -988,6 +1007,16 @@ struct Sigmoid
|
||||
constexpr T one = type_convert<T>(1);
|
||||
y = one / (one + math::exp(-x));
|
||||
};
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
constexpr float one = 1.f;
|
||||
y = type_convert<bhalf_t>(one / (one + math::exp(-x)));
|
||||
};
|
||||
};
|
||||
|
||||
struct Silu
|
||||
@@ -1015,6 +1044,15 @@ struct TanH
|
||||
|
||||
y = math::tanh(x);
|
||||
};
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<bhalf_t>(math::tanh(x));
|
||||
};
|
||||
};
|
||||
|
||||
struct ACos
|
||||
@@ -1274,6 +1312,13 @@ struct Swish
|
||||
y = type_convert<Y>(x / (1.f + math::exp(bx)));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
float bx = -beta_ * x;
|
||||
y = type_convert<bhalf_t>(x / (1.f + math::exp(bx)));
|
||||
};
|
||||
|
||||
const float beta_;
|
||||
};
|
||||
|
||||
@@ -1292,6 +1337,16 @@ struct SoftRelu
|
||||
constexpr T one = type_convert<T>(1);
|
||||
y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha;
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
constexpr float one = 1.f;
|
||||
y = type_convert<bhalf_t>(math::log(one + math::exp(x * alpha_)) / alpha_);
|
||||
};
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
@@ -1313,6 +1368,17 @@ struct Power
|
||||
T shifted_scaled_x = casted_alpha + casted_beta * x;
|
||||
y = math::pow(shifted_scaled_x, casted_gamma);
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
const float shifted_scaled_x = alpha_ + beta_ * x;
|
||||
y = type_convert<bhalf_t>(math::pow(shifted_scaled_x, gamma_));
|
||||
};
|
||||
|
||||
const float alpha_;
|
||||
const float beta_;
|
||||
const float gamma_;
|
||||
@@ -1333,6 +1399,16 @@ struct ClippedRelu
|
||||
T casted_beta = type_convert<T>(beta_);
|
||||
y = math::min(casted_beta, math::max(casted_alpha, x));
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<bhalf_t>(math::min(beta_, math::max(alpha_, x)));
|
||||
};
|
||||
|
||||
const float alpha_;
|
||||
const float beta_;
|
||||
};
|
||||
@@ -1351,6 +1427,16 @@ struct LeakyRelu
|
||||
T casted_alpha = type_convert<T>(alpha_);
|
||||
y = x >= 0 ? x : x * casted_alpha;
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<bhalf_t>(x >= 0 ? x : x * alpha_);
|
||||
};
|
||||
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
@@ -1368,6 +1454,16 @@ struct Elu
|
||||
T casted_alpha = type_convert<T>(alpha_);
|
||||
y = x > 0 ? x : casted_alpha * math::expm1(x);
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<bhalf_t>(x > 0 ? x : alpha_ * math::expm1(x));
|
||||
};
|
||||
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
@@ -1386,6 +1482,16 @@ struct Logistic
|
||||
constexpr T one = type_convert<T>(1);
|
||||
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
constexpr float one = 1.f;
|
||||
y = type_convert<bhalf_t>(alpha_ / (one + ck::math::exp(-x) * alpha_));
|
||||
};
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user