diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index f21a45938f..d45ddb4233 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -121,19 +121,6 @@ __global__ void static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; }); - if constexpr(is_same_v) - { - a_element_op.InitUnaryOpPtrOnDevice(); - } - if constexpr(is_same_v) - { - b_element_op.InitUnaryOpPtrOnDevice(); - } - if constexpr(is_same_v) - { - cde_element_op.InitUnaryOpPtrOnDevice(); - } - if constexpr(isMultiA || isMultiB) { AsPointer p_as_grid_grp; diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index c0b4471748..5e522fb2ea 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -247,32 +247,6 @@ struct DequantPack8 constexpr const static bool is_pack8_invocable = true; }; -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wnon-virtual-dtor" -struct UnaryOpBase -{ - public: - __host__ __device__ ~UnaryOpBase() = default; - - __host__ __device__ constexpr UnaryOpBase() = default; - __host__ __device__ constexpr UnaryOpBase(const UnaryOpBase&) = default; - __host__ __device__ constexpr UnaryOpBase(UnaryOpBase&&) = default; - __host__ __device__ UnaryOpBase& operator=(const UnaryOpBase&) = default; - __host__ __device__ UnaryOpBase& operator=(UnaryOpBase&&) = default; - - __host__ __device__ virtual inline void operator()(float& y, const float& x) const = 0; - - __host__ __device__ virtual inline void operator()(double& y, const double& x) const = 0; - - __host__ __device__ virtual inline void operator()(int32_t& y, const int32_t& x) const = 0; - - __host__ __device__ virtual inline void operator()(int8_t& y, const int8_t& x) const = 0; - - __host__ __device__ virtual inline void operator()(half_t& y, const half_t& x) const = 0; - - __host__ __device__ virtual inline void operator()(bhalf_t& y, const bhalf_t& x) const = 0; -}; - struct PassThroughPack2 { template @@ -304,27 +278,8 @@ struct PassThroughPack2 constexpr const static bool is_pack2_invocable = true; }; -struct PassThrough final : public UnaryOpBase +struct PassThrough { - __host__ __device__ constexpr PassThrough() = default; - __host__ __device__ constexpr PassThrough(const PassThrough&) = default; - __host__ __device__ constexpr PassThrough(PassThrough&&) = default; - __host__ __device__ PassThrough& operator=(const PassThrough&) = default; - __host__ __device__ PassThrough& operator=(PassThrough&&) = default; - __host__ __device__ ~PassThrough() = default; - - __host__ __device__ inline void operator()(float& y, const float& x) const final { y = x; } - - __host__ __device__ inline void operator()(double& y, const double& x) const final { y = x; } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final { y = x; } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final { y = x; } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final { y = x; } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final { y = x; } - template __host__ __device__ void operator()(Y& y, const X& x) const; @@ -334,6 +289,12 @@ struct PassThrough final : public UnaryOpBase y = x; } + template <> + __host__ __device__ void operator()(double& y, const double& x) const + { + y = x; + } + template <> __host__ __device__ void operator()(float& y, const double& x) const { @@ -346,12 +307,36 @@ struct PassThrough final : public UnaryOpBase y = type_convert(x); } + template <> + __host__ __device__ void operator()(float& y, const float& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(half_t& y, const half_t& x) const + { + y = x; + } + template <> __host__ __device__ void operator()(half_t& y, const float& x) const { y = type_convert(x); } + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(int32_t& y, const int32_t& x) const + { + y = x; + } + template <> __host__ __device__ void operator()(bhalf_t& y, const float& x) const { @@ -376,6 +361,12 @@ struct PassThrough final : public UnaryOpBase y = type_convert(x); } + template <> + __host__ __device__ void operator()(int8_t& y, const int8_t& x) const + { + y = x; + } + template <> __host__ __device__ void operator()(half_t& y, const int8_t& x) const { @@ -675,45 +666,20 @@ struct UnarySquare }; }; -struct UnaryAbs final : public UnaryOpBase +struct UnaryAbs { - __host__ __device__ constexpr UnaryAbs() = default; - __host__ __device__ constexpr UnaryAbs(const UnaryAbs&) = default; - __host__ __device__ constexpr UnaryAbs(UnaryAbs&&) = default; - __host__ __device__ UnaryAbs& operator=(const UnaryAbs&) = default; - __host__ __device__ UnaryAbs& operator=(UnaryAbs&&) = default; - __host__ __device__ ~UnaryAbs() = default; - - __host__ __device__ inline void operator()(float& y, const float& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { - y = ck::math::abs(x); - } + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); - __host__ __device__ inline void operator()(double& y, const double& x) const final - { y = ck::math::abs(x); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - y = ck::math::abs(x); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - y = ck::math::abs(x); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - y = ck::math::abs(x); - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final - { - y = ck::math::abs(x); - } + }; + template <> __host__ __device__ void operator()(f8_t& y, const f8_t& x) const { y = ck::type_convert(ck::math::abs(ck::type_convert(x))); @@ -732,41 +698,20 @@ struct UnarySqrt }; }; -struct Relu final : public UnaryOpBase +struct Relu { - __host__ __device__ constexpr Relu() = default; - __host__ __device__ constexpr Relu(const Relu&) = default; - __host__ __device__ constexpr Relu(Relu&&) = default; - __host__ __device__ Relu& operator=(const Relu&) = default; - __host__ __device__ Relu& operator=(Relu&&) = default; - __host__ __device__ ~Relu() = default; - - __host__ __device__ inline void operator()(float& y, const float& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); y = x > 0 ? x : 0; } - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - y = x > 0 ? x : 0; - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - y = x > 0 ? x : 0; - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - y = x > 0 ? x : 0; - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - y = x > 0 ? x : 0; - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { float x_f32 = ck::type_convert(x); float y_f32 = x_f32 > 0 ? x_f32 : 0; @@ -913,52 +858,18 @@ struct Gelu } }; -struct Sigmoid final : public UnaryOpBase +struct Sigmoid { - __host__ __device__ constexpr Sigmoid() = default; - __host__ __device__ constexpr Sigmoid(const Sigmoid&) = default; - __host__ __device__ constexpr Sigmoid(Sigmoid&&) = default; - __host__ __device__ Sigmoid& operator=(const Sigmoid&) = default; - __host__ __device__ Sigmoid& operator=(Sigmoid&&) = default; - __host__ __device__ ~Sigmoid() = default; - - __host__ __device__ inline void operator()(float& y, const float& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { - constexpr float one = type_convert(1); - y = one / (one + ck::math::exp(-x)); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - constexpr double one = type_convert(1); - y = one / (one + ck::math::exp(-x)); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - constexpr int32_t one = type_convert(1); - y = one / (one + ck::math::exp(-x)); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - constexpr int8_t one = type_convert(1); - y = one / (one + ck::math::exp(-x)); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - constexpr half_t one = type_convert(1); - y = one / (one + ck::math::exp(-x)); - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final - { - constexpr float one = type_convert(1); - float x_f32 = ck::type_convert(x); - float y_f32 = one / (one + ck::math::exp(x_f32)); - y = ck::type_convert(y_f32); - } + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + constexpr T one = type_convert(1); + y = one / (one + ck::math::exp(-x)); + }; }; struct Silu @@ -974,44 +885,18 @@ struct Silu }; }; -struct TanH final : public UnaryOpBase +struct TanH { - __host__ __device__ constexpr TanH() = default; - __host__ __device__ constexpr TanH(const TanH&) = default; - __host__ __device__ constexpr TanH(TanH&&) = default; - __host__ __device__ TanH& operator=(const TanH&) = default; - __host__ __device__ TanH& operator=(TanH&&) = default; - __host__ __device__ ~TanH() = default; - - __host__ __device__ inline void operator()(float& y, const float& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { - y = ck::math::tanh(x); - } + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); - __host__ __device__ inline void operator()(double& y, const double& x) const final - { y = ck::math::tanh(x); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - y = ck::math::tanh(x); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - y = ck::math::tanh(x); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - y = ck::math::tanh(x); - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final - { - y = ck::math::tanh(x); - } + }; }; struct ACos @@ -1252,418 +1137,138 @@ struct Rcp }; }; -struct Swish final : public UnaryOpBase +struct Swish { - __host__ __device__ constexpr Swish(const Swish&) = default; - __host__ __device__ constexpr Swish(Swish&&) = default; - __host__ __device__ ~Swish() = default; - - __host__ __device__ Swish(float beta = 1.0f) : beta_(beta) {} - - __host__ __device__ float get_beta() const { return beta_; } - - const float beta_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final - { - float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } + Swish(float beta = 1.0f) : beta_(beta) {} template __host__ __device__ void operator()(Y& y, const X& x) const { static_assert(is_same::value || is_same::value || - is_same::value, + is_same::value || is_same::value, "Data type is not supported by this operation!"); static_assert(is_same::value || is_same::value || - is_same::value, + is_same::value || is_same::value, "Data type is not supported by this operation!"); float bx = -beta_ * type_convert(x); y = type_convert(x / (1.f + ck::math::exp(bx))); - } + }; + + const float beta_; }; -struct SoftRelu final : public UnaryOpBase +struct SoftRelu { - __host__ __device__ constexpr SoftRelu(const SoftRelu&) = default; - __host__ __device__ constexpr SoftRelu(SoftRelu&&) = default; - __host__ __device__ ~SoftRelu() = default; - - __host__ __device__ SoftRelu(float alpha = 1.0f) : alpha_(alpha) {} - - __host__ __device__ float get_alpha() const { return alpha_; } + SoftRelu(float alpha = 1.f) : alpha_(alpha){}; + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + constexpr T one = type_convert(1); + y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; + } const float alpha_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float casted_alpha = type_convert(alpha_); - constexpr float one = type_convert(1); - y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - double casted_alpha = type_convert(alpha_); - constexpr double one = type_convert(1); - y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - int32_t casted_alpha = type_convert(alpha_); - constexpr int32_t one = type_convert(1); - y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - int8_t casted_alpha = type_convert(alpha_); - constexpr int8_t one = type_convert(1); - y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - half_t casted_alpha = type_convert(alpha_); - constexpr half_t one = type_convert(1); - y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final - { - bhalf_t casted_alpha = type_convert(alpha_); - constexpr bhalf_t one = type_convert(1); - y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; - } }; -struct Power final : public UnaryOpBase +struct Power { - __host__ __device__ constexpr Power(const Power&) = default; - __host__ __device__ constexpr Power(Power&&) = default; - __host__ __device__ ~Power() = default; + Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) + : alpha_(alpha), beta_(beta), gamma_(gamma){}; - __host__ __device__ Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) - : alpha_(alpha), beta_(beta), gamma_(gamma) + template + __host__ __device__ void operator()(T& y, const T& x) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + T casted_beta = type_convert(beta_); + T casted_gamma = type_convert(gamma_); + T shifted_scaled_x = casted_alpha + casted_beta * x; + y = ck::math::pow(shifted_scaled_x, casted_gamma); } - - __host__ __device__ float get_alpha() const { return alpha_; } - - __host__ __device__ float get_beta() const { return beta_; } - - __host__ __device__ float get_gamma() const { return gamma_; } - const float alpha_; const float beta_; const float gamma_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float casted_alpha = type_convert(alpha_); - float casted_beta = type_convert(beta_); - float casted_gamma = type_convert(gamma_); - - float shifted_scaled_x = casted_alpha + casted_beta * x; - y = ck::math::pow(shifted_scaled_x, casted_gamma); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - double casted_alpha = type_convert(alpha_); - double casted_beta = type_convert(beta_); - double casted_gamma = type_convert(gamma_); - - double shifted_scaled_x = casted_alpha + casted_beta * x; - y = ck::math::pow(shifted_scaled_x, casted_gamma); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - int32_t casted_alpha = type_convert(alpha_); - int32_t casted_beta = type_convert(beta_); - int32_t casted_gamma = type_convert(gamma_); - - int32_t shifted_scaled_x = casted_alpha + casted_beta * x; - y = ck::math::pow(shifted_scaled_x, casted_gamma); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - int8_t casted_alpha = type_convert(alpha_); - int8_t casted_beta = type_convert(beta_); - int8_t casted_gamma = type_convert(gamma_); - - int8_t shifted_scaled_x = casted_alpha + casted_beta * x; - y = ck::math::pow(shifted_scaled_x, casted_gamma); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - half_t casted_alpha = type_convert(alpha_); - half_t casted_beta = type_convert(beta_); - half_t casted_gamma = type_convert(gamma_); - - half_t shifted_scaled_x = casted_alpha + casted_beta * x; - y = ck::math::pow(shifted_scaled_x, casted_gamma); - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final - { - bhalf_t casted_alpha = type_convert(alpha_); - bhalf_t casted_beta = type_convert(beta_); - bhalf_t casted_gamma = type_convert(gamma_); - - bhalf_t shifted_scaled_x = casted_alpha + casted_beta * x; - y = ck::math::pow(shifted_scaled_x, casted_gamma); - } }; -struct ClippedRelu final : public UnaryOpBase +struct ClippedRelu { - __host__ __device__ constexpr ClippedRelu(const ClippedRelu&) = default; - __host__ __device__ constexpr ClippedRelu(ClippedRelu&&) = default; - __host__ __device__ ~ClippedRelu() = default; + ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; - __host__ __device__ ClippedRelu(float alpha = 0.f, float beta = 1.f) - : alpha_(alpha), beta_(beta) + template + __host__ __device__ void operator()(T& y, const T& x) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + T casted_beta = type_convert(beta_); + y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); } - - __host__ __device__ float get_alpha() const { return alpha_; } - - __host__ __device__ float get_beta() const { return beta_; } - const float alpha_; const float beta_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float casted_alpha = type_convert(alpha_); - float casted_beta = type_convert(beta_); - y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - double casted_alpha = type_convert(alpha_); - double casted_beta = type_convert(beta_); - y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - int32_t casted_alpha = type_convert(alpha_); - int32_t casted_beta = type_convert(beta_); - y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - int8_t casted_alpha = type_convert(alpha_); - int8_t casted_beta = type_convert(beta_); - y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - half_t casted_alpha = type_convert(alpha_); - half_t casted_beta = type_convert(beta_); - y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final - { - bhalf_t casted_alpha = type_convert(alpha_); - bhalf_t casted_beta = type_convert(beta_); - y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); - } }; -struct LeakyRelu final : public UnaryOpBase +struct LeakyRelu { - __host__ __device__ constexpr LeakyRelu(const LeakyRelu&) = default; - __host__ __device__ constexpr LeakyRelu(LeakyRelu&&) = default; - __host__ __device__ ~LeakyRelu() = default; - - __host__ __device__ LeakyRelu(float alpha = 0.f) : alpha_(alpha) {} - - __host__ __device__ float get_alpha() const { return alpha_; } + LeakyRelu(float alpha = 0.01f) : alpha_(alpha){}; + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + y = x >= 0 ? x : x * casted_alpha; + } const float alpha_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float casted_alpha = type_convert(alpha_); - y = x >= 0 ? x : x * casted_alpha; - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - double casted_alpha = type_convert(alpha_); - y = x >= 0 ? x : x * casted_alpha; - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - int32_t casted_alpha = type_convert(alpha_); - y = x >= 0 ? x : x * casted_alpha; - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - int8_t casted_alpha = type_convert(alpha_); - y = x >= 0 ? x : x * casted_alpha; - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - half_t casted_alpha = type_convert(alpha_); - y = x >= 0 ? x : x * casted_alpha; - } - - __host__ __device__ inline void operator()([[maybe_unused]] bhalf_t& y, - [[maybe_unused]] const bhalf_t& x) const final - { - } }; -struct Elu final : public UnaryOpBase +struct Elu { - __host__ __device__ constexpr Elu(const Elu&) = default; - __host__ __device__ constexpr Elu(Elu&&) = default; - __host__ __device__ ~Elu() = default; - - __host__ __device__ Elu(float alpha = 1.f) : alpha_(alpha) {} - - __host__ __device__ float get_alpha() const { return alpha_; } + Elu(float alpha = 1.f) : alpha_(alpha){}; + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + y = x > 0 ? x : casted_alpha * ck::math::expm1(x); + } const float alpha_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float casted_alpha = type_convert(alpha_); - y = x > 0 ? x : casted_alpha * ck::math::expm1(x); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - double casted_alpha = type_convert(alpha_); - y = x > 0 ? x : casted_alpha * ck::math::expm1(x); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - int32_t casted_alpha = type_convert(alpha_); - y = x > 0 ? x : casted_alpha * ck::math::expm1(x); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - int8_t casted_alpha = type_convert(alpha_); - y = x > 0 ? x : casted_alpha * ck::math::expm1(x); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - half_t casted_alpha = type_convert(alpha_); - y = x > 0 ? x : casted_alpha * ck::math::expm1(x); - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final - { - bhalf_t casted_alpha = type_convert(alpha_); - y = x > 0 ? x : casted_alpha * ck::math::expm1(x); - } }; -struct Logistic final : public UnaryOpBase +struct Logistic { - __host__ __device__ constexpr Logistic(const Logistic&) = default; - __host__ __device__ constexpr Logistic(Logistic&&) = default; - __host__ __device__ ~Logistic() = default; - - __host__ __device__ Logistic(float alpha = 1.0f) : alpha_(alpha) {} - - __host__ __device__ float get_alpha() const { return alpha_; } + Logistic(float alpha = 1.f) : alpha_(alpha){}; + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + constexpr T one = type_convert(1); + y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); + } const float alpha_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float casted_alpha = type_convert(alpha_); - constexpr float one = type_convert(1); - y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - double casted_alpha = type_convert(alpha_); - constexpr double one = type_convert(1); - y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - int32_t casted_alpha = type_convert(alpha_); - constexpr int32_t one = type_convert(1); - y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - int8_t casted_alpha = type_convert(alpha_); - constexpr int8_t one = type_convert(1); - y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - half_t casted_alpha = type_convert(alpha_); - constexpr half_t one = type_convert(1); - y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final - { - bhalf_t casted_alpha = type_convert(alpha_); - constexpr bhalf_t one = type_convert(1); - y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); - } }; struct ConvInvscale @@ -1728,7 +1333,7 @@ struct ConvScaleRelu __host__ __device__ void operator()(f8_t& e, const float& c) const { float x; - Relu{}(x, c * scale_in_ * scale_wei_); + Relu{}.template operator()(x, c * scale_in_ * scale_wei_); e = type_convert(x * scale_out_); }; @@ -1809,225 +1414,138 @@ struct FastNumericArrayConverter struct DynamicUnaryOp { - - DynamicUnaryOp& operator=(const DynamicUnaryOp& other) - { - if(this != &other) - { - unary_op_ptr_ = other.unary_op_ptr_; - unary_op_type_ = other.unary_op_type_; - } - return *this; - } - __host__ __device__ DynamicUnaryOp() = delete; __host__ __device__ DynamicUnaryOp(const Swish& swish) + : unary_op_type_(UnaryOpType::Swish), swish_{swish.beta_} { - unary_op_type_ = UnaryOpType::Swish; - beta = swish.get_beta(); } __host__ __device__ DynamicUnaryOp(const Swish&& swish) + : unary_op_type_(UnaryOpType::Swish), swish_{swish.beta_} { - unary_op_type_ = UnaryOpType::Swish; - beta = swish.get_beta(); } - __host__ __device__ DynamicUnaryOp(const Sigmoid&) { unary_op_type_ = UnaryOpType::Sigmoid; } + __host__ __device__ DynamicUnaryOp(const Sigmoid&) : unary_op_type_(UnaryOpType::Sigmoid) {} - __host__ __device__ DynamicUnaryOp(const Sigmoid&&) { unary_op_type_ = UnaryOpType::Sigmoid; } + __host__ __device__ DynamicUnaryOp(const Sigmoid&&) : unary_op_type_(UnaryOpType::Sigmoid) {} __host__ __device__ DynamicUnaryOp(const PassThrough&) + : unary_op_type_(UnaryOpType::PassThrough) { - unary_op_type_ = UnaryOpType::PassThrough; } __host__ __device__ DynamicUnaryOp(const PassThrough&&) + : unary_op_type_(UnaryOpType::PassThrough) { - unary_op_type_ = UnaryOpType::PassThrough; } __host__ __device__ DynamicUnaryOp(const Logistic& logistic) + : unary_op_type_(UnaryOpType::Logistic), logistic_{logistic.alpha_} { - unary_op_type_ = UnaryOpType::Logistic; - alpha = logistic.get_alpha(); } __host__ __device__ DynamicUnaryOp(const Logistic&& logistic) + : unary_op_type_(UnaryOpType::Logistic), logistic_{logistic.alpha_} { - unary_op_type_ = UnaryOpType::Logistic; - alpha = logistic.get_alpha(); } - __host__ __device__ DynamicUnaryOp(const TanH&) { unary_op_type_ = UnaryOpType::TanH; } + __host__ __device__ DynamicUnaryOp(const TanH&) : unary_op_type_(UnaryOpType::TanH) {} - __host__ __device__ DynamicUnaryOp(const TanH&&) { unary_op_type_ = UnaryOpType::TanH; } + __host__ __device__ DynamicUnaryOp(const TanH&&) : unary_op_type_(UnaryOpType::TanH) {} - __host__ __device__ DynamicUnaryOp(const Relu&) { unary_op_type_ = UnaryOpType::Relu; } + __host__ __device__ DynamicUnaryOp(const Relu&) : unary_op_type_(UnaryOpType::Relu) {} - __host__ __device__ DynamicUnaryOp(const Relu&&) { unary_op_type_ = UnaryOpType::Relu; } + __host__ __device__ DynamicUnaryOp(const Relu&&) : unary_op_type_(UnaryOpType::Relu) {} __host__ __device__ DynamicUnaryOp(const SoftRelu& softrelu) + : unary_op_type_(UnaryOpType::SoftRelu), soft_relu_{softrelu.alpha_} { - unary_op_type_ = UnaryOpType::SoftRelu; - alpha = softrelu.get_alpha(); } __host__ __device__ DynamicUnaryOp(const SoftRelu&& softrelu) + : unary_op_type_(UnaryOpType::SoftRelu), soft_relu_{softrelu.alpha_} { - unary_op_type_ = UnaryOpType::SoftRelu; - alpha = softrelu.get_alpha(); } - __host__ __device__ DynamicUnaryOp(const UnaryAbs&) { unary_op_type_ = UnaryOpType::UnaryAbs; } + __host__ __device__ DynamicUnaryOp(const UnaryAbs&) : unary_op_type_(UnaryOpType::UnaryAbs) {} - __host__ __device__ DynamicUnaryOp(const UnaryAbs&&) { unary_op_type_ = UnaryOpType::UnaryAbs; } + __host__ __device__ DynamicUnaryOp(const UnaryAbs&&) : unary_op_type_(UnaryOpType::UnaryAbs) {} __host__ __device__ DynamicUnaryOp(const Power& pow) + : unary_op_type_(UnaryOpType::Power), power_(pow.alpha_, pow.beta_, pow.gamma_) { - unary_op_type_ = UnaryOpType::Power; - alpha = pow.get_alpha(); - beta = pow.get_beta(); - gamma = pow.get_gamma(); } __host__ __device__ DynamicUnaryOp(const Power&& pow) + : unary_op_type_(UnaryOpType::Power), power_(pow.alpha_, pow.beta_, pow.gamma_) { - unary_op_type_ = UnaryOpType::Power; - alpha = pow.get_alpha(); - beta = pow.get_beta(); - gamma = pow.get_gamma(); } __host__ __device__ DynamicUnaryOp(const ClippedRelu& clippedrelu) + : unary_op_type_(UnaryOpType::ClippedRelu), + clipped_relu_{clippedrelu.alpha_, clippedrelu.beta_} { - unary_op_type_ = UnaryOpType::ClippedRelu; - alpha = clippedrelu.get_alpha(); - beta = clippedrelu.get_beta(); } __host__ __device__ DynamicUnaryOp(const ClippedRelu&& clippedrelu) + : unary_op_type_(UnaryOpType::ClippedRelu), + clipped_relu_{clippedrelu.alpha_, clippedrelu.beta_} { - unary_op_type_ = UnaryOpType::ClippedRelu; - alpha = clippedrelu.get_alpha(); - beta = clippedrelu.get_beta(); } __host__ __device__ DynamicUnaryOp(const LeakyRelu& leakyrelu) + : unary_op_type_(UnaryOpType::LeakyRelu), leaky_relu_{leakyrelu.alpha_} { - unary_op_type_ = UnaryOpType::LeakyRelu; - alpha = leakyrelu.get_alpha(); } __host__ __device__ DynamicUnaryOp(const LeakyRelu&& leakyrelu) + : unary_op_type_(UnaryOpType::LeakyRelu), leaky_relu_{leakyrelu.alpha_} { - unary_op_type_ = UnaryOpType::LeakyRelu; - alpha = leakyrelu.get_alpha(); } __host__ __device__ DynamicUnaryOp(const Elu& elu) + : unary_op_type_(UnaryOpType::Elu), elu_{elu.alpha_} { - unary_op_type_ = UnaryOpType::Elu; - alpha = elu.get_alpha(); } __host__ __device__ DynamicUnaryOp(const Elu&& elu) - { - unary_op_type_ = UnaryOpType::Elu; - alpha = elu.get_alpha(); - } - - __host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op) - : unary_op_type_(dynamic_op.unary_op_type_), - unary_op_ptr_(dynamic_op.unary_op_ptr_), - alpha(dynamic_op.alpha), - beta(dynamic_op.beta), - gamma(dynamic_op.gamma) + : unary_op_type_(UnaryOpType::Elu), elu_{elu.alpha_} { } - __host__ __device__ ~DynamicUnaryOp() + __host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op) = default; + + __host__ __device__ ~DynamicUnaryOp() {} + + template + __host__ __device__ void operator()(Y& y, const X& x) const { switch(unary_op_type_) { - case(UnaryOpType::Swish): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::Sigmoid): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::PassThrough): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::Logistic): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::TanH): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::Relu): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::SoftRelu): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::UnaryAbs): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::Power): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::ClippedRelu): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::LeakyRelu): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::Elu): delete static_cast(unary_op_ptr_); break; - + case(UnaryOpType::Swish): swish_(y, x); break; + case(UnaryOpType::Sigmoid): sigmoid_(y, x); break; + case(UnaryOpType::PassThrough): pass_through_(y, x); break; + case(UnaryOpType::Logistic): logistic_(y, x); break; + case(UnaryOpType::TanH): tanh_(y, x); break; + case(UnaryOpType::Relu): relu_(y, x); break; + case(UnaryOpType::SoftRelu): soft_relu_(y, x); break; + case(UnaryOpType::UnaryAbs): unary_abs_(y, x); break; + case(UnaryOpType::Power): power_(y, x); break; + case(UnaryOpType::ClippedRelu): clipped_relu_(y, x); break; + case(UnaryOpType::LeakyRelu): leaky_relu_(y, x); break; + case(UnaryOpType::Elu): elu_(y, x); break; default: break; } } - __device__ void InitUnaryOpPtrOnDevice() + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { - switch(unary_op_type_) - { - case(UnaryOpType::Swish): unary_op_ptr_ = new Swish(beta); break; - case(UnaryOpType::Sigmoid): unary_op_ptr_ = new Sigmoid; break; - case(UnaryOpType::PassThrough): unary_op_ptr_ = new PassThrough; break; - case(UnaryOpType::Logistic): unary_op_ptr_ = new Logistic(alpha); break; - case(UnaryOpType::TanH): unary_op_ptr_ = new TanH; break; - case(UnaryOpType::Relu): unary_op_ptr_ = new Relu; break; - case(UnaryOpType::SoftRelu): unary_op_ptr_ = new SoftRelu(alpha); break; - case(UnaryOpType::UnaryAbs): unary_op_ptr_ = new UnaryAbs; break; - case(UnaryOpType::Power): unary_op_ptr_ = new Power(alpha, beta, gamma); break; - case(UnaryOpType::ClippedRelu): unary_op_ptr_ = new ClippedRelu(alpha, beta); break; - case(UnaryOpType::LeakyRelu): unary_op_ptr_ = new LeakyRelu(alpha); break; - case(UnaryOpType::Elu): unary_op_ptr_ = new Elu(alpha); break; - - default: unary_op_ptr_ = nullptr; break; - } - } - - template - __device__ void operator()(Y& y, const X& x) const - { - isSupported(); - unary_op_ptr_->operator()(y, x); - } - - template - __host__ void operator()(Y& y, const X& x) const - { - isSupported(); - switch(unary_op_type_) - { - case(UnaryOpType::Swish): Swish{}.operator()(y, x); break; - case(UnaryOpType::Sigmoid): Sigmoid{}.operator()(y, x); break; - case(UnaryOpType::PassThrough): PassThrough{}.operator()(y, x); break; - case(UnaryOpType::Logistic): Logistic{}.operator()(y, x); break; - case(UnaryOpType::TanH): TanH{}.operator()(y, x); break; - case(UnaryOpType::Relu): Relu{}.operator()(y, x); break; - case(UnaryOpType::SoftRelu): SoftRelu{}.operator()(y, x); break; - case(UnaryOpType::UnaryAbs): UnaryAbs{}.operator()(y, x); break; - case(UnaryOpType::Power): Power{}.operator()(y, x); break; - case(UnaryOpType::ClippedRelu): ClippedRelu{}.operator()(y, x); break; - case(UnaryOpType::LeakyRelu): LeakyRelu{}.operator()(y, x); break; - case(UnaryOpType::Elu): Elu{}.operator()(y, x); break; - default: break; - } - } - - template - __device__ __host__ constexpr void isSupported() const - { - - static_assert(std::is_same::value, "X and Y must be of the same type"); - - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Data type is not supported by this operation!"); + float y_float; + float x_float = type_convert(x); + this->operator()(y_float, x_float); + y = type_convert(y_float); } private: @@ -2049,12 +1567,20 @@ struct DynamicUnaryOp public: UnaryOpType unary_op_type_; - UnaryOpBase* unary_op_ptr_ = nullptr; - float alpha; - float beta; - float gamma; + + Swish swish_; + Sigmoid sigmoid_; + PassThrough pass_through_; + Logistic logistic_; + TanH tanh_; + Relu relu_; + SoftRelu soft_relu_; + UnaryAbs unary_abs_; + Power power_; + ClippedRelu clipped_relu_; + LeakyRelu leaky_relu_; + Elu elu_; }; -#pragma clang diagnostic pop } // namespace element_wise } // namespace tensor_operation