unary fixes

This commit is contained in:
Bartlomiej Kocot
2025-06-13 13:00:40 +00:00
parent 2d4c5129ce
commit aeb0cb2760

View File

@@ -730,6 +730,15 @@ struct UnaryAbs
{
y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
};
template <typename Y, typename X>
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
y = ck::type_convert<bhalf_t>(ck::math::abs(x));
};
};
struct UnarySqrt
@@ -829,6 +838,9 @@ struct Relu
y = x > 0 ? x : 0;
}
template <typename Y, typename X>
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
{
@@ -836,6 +848,13 @@ struct Relu
float y_f32 = x_f32 > 0 ? x_f32 : 0;
y = type_convert<bhalf_t>(y_f32);
}
template <>
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
float y_f32 = x > 0 ? x : 0;
y = type_convert<bhalf_t>(y_f32);
};
};
// Fast GeLU
@@ -988,6 +1007,16 @@ struct Sigmoid
constexpr T one = type_convert<T>(1);
y = one / (one + math::exp(-x));
};
template <typename Y, typename X>
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
constexpr float one = 1.f;
y = type_convert<bhalf_t>(one / (one + math::exp(-x)));
};
};
struct Silu
@@ -1015,6 +1044,15 @@ struct TanH
y = math::tanh(x);
};
template <typename Y, typename X>
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
y = type_convert<bhalf_t>(math::tanh(x));
};
};
struct ACos
@@ -1274,6 +1312,13 @@ struct Swish
y = type_convert<Y>(x / (1.f + math::exp(bx)));
};
template <>
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
float bx = -beta_ * x;
y = type_convert<bhalf_t>(x / (1.f + math::exp(bx)));
};
const float beta_;
};
@@ -1292,6 +1337,16 @@ struct SoftRelu
constexpr T one = type_convert<T>(1);
y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha;
}
template <typename Y, typename X>
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
constexpr float one = 1.f;
y = type_convert<bhalf_t>(math::log(one + math::exp(x * alpha_)) / alpha_);
};
const float alpha_;
};
@@ -1313,6 +1368,17 @@ struct Power
T shifted_scaled_x = casted_alpha + casted_beta * x;
y = math::pow(shifted_scaled_x, casted_gamma);
}
template <typename Y, typename X>
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
const float shifted_scaled_x = alpha_ + beta_ * x;
y = type_convert<bhalf_t>(math::pow(shifted_scaled_x, gamma_));
};
const float alpha_;
const float beta_;
const float gamma_;
@@ -1333,6 +1399,16 @@ struct ClippedRelu
T casted_beta = type_convert<T>(beta_);
y = math::min(casted_beta, math::max(casted_alpha, x));
}
template <typename Y, typename X>
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
y = type_convert<bhalf_t>(math::min(beta_, math::max(alpha_, x)));
};
const float alpha_;
const float beta_;
};
@@ -1351,6 +1427,16 @@ struct LeakyRelu
T casted_alpha = type_convert<T>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
template <typename Y, typename X>
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
y = type_convert<bhalf_t>(x >= 0 ? x : x * alpha_);
};
const float alpha_;
};
@@ -1368,6 +1454,16 @@ struct Elu
T casted_alpha = type_convert<T>(alpha_);
y = x > 0 ? x : casted_alpha * math::expm1(x);
}
template <typename Y, typename X>
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
y = type_convert<bhalf_t>(x > 0 ? x : alpha_ * math::expm1(x));
};
const float alpha_;
};
@@ -1386,6 +1482,16 @@ struct Logistic
constexpr T one = type_convert<T>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
}
template <typename Y, typename X>
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
constexpr float one = 1.f;
y = type_convert<bhalf_t>(alpha_ / (one + ck::math::exp(-x) * alpha_));
};
const float alpha_;
};