From aeb0cb2760f3582093d10e361f03183ff15c8ec9 Mon Sep 17 00:00:00 2001
From: Bartlomiej Kocot
Date: Fri, 13 Jun 2025 13:00:40 +0000
Subject: [PATCH] unary fixes

---
Reviewer notes (this section is not part of the commit message):
* Adds an operator() specialization taking (bhalf_t& y, const float& x) to
  UnaryAbs, Relu, Sigmoid, TanH, Swish, SoftRelu, Power, ClippedRelu,
  LeakyRelu, Elu and Logistic. Each one evaluates the expression on the float
  input and passes the result through type_convert before storing it in y.
* Every struct above except Swish also gains a body-less declaration of the
  generic operator()(Y&, const X&) template ahead of the new specialization.
* NOTE(review): this text was rebuilt from a copy whose newlines/indentation
  were collapsed and whose non-empty angle-bracket lists were missing. The
  restored pieces (<typename Y, typename X>, type_convert<bhalf_t>,
  type_convert<T>, type_convert<Y>, type_convert<f8_t>/<float> in the UnaryAbs
  context line) and all in-line whitespace are inferred - confirm against the
  original commit. Applying may need `git apply --ignore-whitespace`.
* NOTE(review): the author e-mail address was absent from the From: header
  and has not been invented here; restore it before using `git am`.

 .../element/unary_element_wise_operation.hpp  | 106 ++++++++++++++++++
 1 file changed, 106 insertions(+)

diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
index 13ea9aac34..8f829496da 100644
--- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
+++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
@@ -730,6 +730,15 @@ struct UnaryAbs
     {
         y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
     };
+
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()(bhalf_t& y, const float& x) const
+    {
+        y = ck::type_convert<bhalf_t>(ck::math::abs(x));
+    };
 };
 
 struct UnarySqrt
@@ -829,6 +838,9 @@ struct Relu
         y = x > 0 ? x : 0;
     }
 
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
     template <>
     __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
     {
@@ -836,6 +848,13 @@ struct Relu
         float y_f32 = x_f32 > 0 ? x_f32 : 0;
         y = type_convert<bhalf_t>(y_f32);
     }
+
+    template <>
+    __host__ __device__ void operator()(bhalf_t& y, const float& x) const
+    {
+        float y_f32 = x > 0 ? x : 0;
+        y = type_convert<bhalf_t>(y_f32);
+    };
 };
 
 // Fast GeLU
@@ -988,6 +1007,16 @@ struct Sigmoid
         constexpr T one = type_convert<T>(1);
         y = one / (one + math::exp(-x));
     };
+
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()(bhalf_t& y, const float& x) const
+    {
+        constexpr float one = 1.f;
+        y = type_convert<bhalf_t>(one / (one + math::exp(-x)));
+    };
 };
 
 struct Silu
@@ -1015,6 +1044,15 @@ struct TanH
 
         y = math::tanh(x);
     };
+
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()(bhalf_t& y, const float& x) const
+    {
+        y = type_convert<bhalf_t>(math::tanh(x));
+    };
 };
 
 struct ACos
@@ -1274,6 +1312,13 @@ struct Swish
         y = type_convert<Y>(x / (1.f + math::exp(bx)));
     };
 
+    template <>
+    __host__ __device__ void operator()(bhalf_t& y, const float& x) const
+    {
+        float bx = -beta_ * x;
+        y = type_convert<bhalf_t>(x / (1.f + math::exp(bx)));
+    };
+
     const float beta_;
 };
 
@@ -1292,6 +1337,16 @@ struct SoftRelu
         constexpr T one = type_convert<T>(1);
         y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha;
     }
+
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()(bhalf_t& y, const float& x) const
+    {
+        constexpr float one = 1.f;
+        y = type_convert<bhalf_t>(math::log(one + math::exp(x * alpha_)) / alpha_);
+    };
     const float alpha_;
 };
 
@@ -1313,6 +1368,17 @@ struct Power
         T shifted_scaled_x = casted_alpha + casted_beta * x;
         y = math::pow(shifted_scaled_x, casted_gamma);
     }
+
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()(bhalf_t& y, const float& x) const
+    {
+        const float shifted_scaled_x = alpha_ + beta_ * x;
+        y = type_convert<bhalf_t>(math::pow(shifted_scaled_x, gamma_));
+    };
+
     const float alpha_;
     const float beta_;
     const float gamma_;
@@ -1333,6 +1399,16 @@ struct ClippedRelu
         T casted_beta = type_convert<T>(beta_);
         y = math::min(casted_beta, math::max(casted_alpha, x));
     }
+
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()(bhalf_t& y, const float& x) const
+    {
+        y = type_convert<bhalf_t>(math::min(beta_, math::max(alpha_, x)));
+    };
+
     const float alpha_;
     const float beta_;
 };
@@ -1351,6 +1427,16 @@ struct LeakyRelu
         T casted_alpha = type_convert<T>(alpha_);
         y = x >= 0 ? x : x * casted_alpha;
     }
+
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()(bhalf_t& y, const float& x) const
+    {
+        y = type_convert<bhalf_t>(x >= 0 ? x : x * alpha_);
+    };
+
     const float alpha_;
 };
 
@@ -1368,6 +1454,16 @@ struct Elu
         T casted_alpha = type_convert<T>(alpha_);
         y = x > 0 ? x : casted_alpha * math::expm1(x);
     }
+
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()(bhalf_t& y, const float& x) const
+    {
+        y = type_convert<bhalf_t>(x > 0 ? x : alpha_ * math::expm1(x));
+    };
+
     const float alpha_;
 };
 
@@ -1386,6 +1482,16 @@ struct Logistic
         constexpr T one = type_convert<T>(1);
         y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
     }
+
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()(bhalf_t& y, const float& x) const
+    {
+        constexpr float one = 1.f;
+        y = type_convert<bhalf_t>(alpha_ / (one + ck::math::exp(-x) * alpha_));
+    };
     const float alpha_;
 };
 