mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Remove virtual destructors from unary ops (#1610)
* Remove virtual destructors from unary ops * Fixes * Fixes * clang format fixes
This commit is contained in:
@@ -13,15 +13,17 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace element_wise {
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
|
||||
struct UnaryOpBase
|
||||
{
|
||||
public:
|
||||
__host__ __device__ virtual ~UnaryOpBase() = default;
|
||||
__host__ __device__ ~UnaryOpBase() = default;
|
||||
|
||||
__host__ __device__ UnaryOpBase() = default;
|
||||
__host__ __device__ UnaryOpBase(const UnaryOpBase&) = default;
|
||||
__host__ __device__ constexpr UnaryOpBase() = default;
|
||||
__host__ __device__ constexpr UnaryOpBase(const UnaryOpBase&) = default;
|
||||
__host__ __device__ constexpr UnaryOpBase(UnaryOpBase&&) = default;
|
||||
__host__ __device__ UnaryOpBase& operator=(const UnaryOpBase&) = default;
|
||||
__host__ __device__ UnaryOpBase(UnaryOpBase&&) = default;
|
||||
__host__ __device__ UnaryOpBase& operator=(UnaryOpBase&&) = default;
|
||||
|
||||
__host__ __device__ virtual inline void operator()(float& y, const float& x) const = 0;
|
||||
@@ -50,8 +52,14 @@ struct PassThroughPack2
|
||||
constexpr const static bool is_pack2_invocable = true;
|
||||
};
|
||||
|
||||
struct PassThrough : public UnaryOpBase
|
||||
struct PassThrough final : public UnaryOpBase
|
||||
{
|
||||
__host__ __device__ constexpr PassThrough() = default;
|
||||
__host__ __device__ constexpr PassThrough(const PassThrough&) = default;
|
||||
__host__ __device__ constexpr PassThrough(PassThrough&&) = default;
|
||||
__host__ __device__ PassThrough& operator=(const PassThrough&) = default;
|
||||
__host__ __device__ PassThrough& operator=(PassThrough&&) = default;
|
||||
__host__ __device__ ~PassThrough() = default;
|
||||
|
||||
__host__ __device__ inline void operator()(float& y, const float& x) const final { y = x; }
|
||||
|
||||
@@ -409,8 +417,15 @@ struct UnarySquare
|
||||
};
|
||||
};
|
||||
|
||||
struct UnaryAbs : public UnaryOpBase
|
||||
struct UnaryAbs final : public UnaryOpBase
|
||||
{
|
||||
__host__ __device__ constexpr UnaryAbs() = default;
|
||||
__host__ __device__ constexpr UnaryAbs(const UnaryAbs&) = default;
|
||||
__host__ __device__ constexpr UnaryAbs(UnaryAbs&&) = default;
|
||||
__host__ __device__ UnaryAbs& operator=(const UnaryAbs&) = default;
|
||||
__host__ __device__ UnaryAbs& operator=(UnaryAbs&&) = default;
|
||||
__host__ __device__ ~UnaryAbs() = default;
|
||||
|
||||
__host__ __device__ inline void operator()(float& y, const float& x) const final
|
||||
{
|
||||
y = ck::math::abs(x);
|
||||
@@ -459,8 +474,15 @@ struct UnarySqrt
|
||||
};
|
||||
};
|
||||
|
||||
struct Relu : public UnaryOpBase
|
||||
struct Relu final : public UnaryOpBase
|
||||
{
|
||||
__host__ __device__ constexpr Relu() = default;
|
||||
__host__ __device__ constexpr Relu(const Relu&) = default;
|
||||
__host__ __device__ constexpr Relu(Relu&&) = default;
|
||||
__host__ __device__ Relu& operator=(const Relu&) = default;
|
||||
__host__ __device__ Relu& operator=(Relu&&) = default;
|
||||
__host__ __device__ ~Relu() = default;
|
||||
|
||||
__host__ __device__ inline void operator()(float& y, const float& x) const final
|
||||
{
|
||||
y = x > 0 ? x : 0;
|
||||
@@ -633,8 +655,14 @@ struct Gelu
|
||||
}
|
||||
};
|
||||
|
||||
struct Sigmoid : public UnaryOpBase
|
||||
struct Sigmoid final : public UnaryOpBase
|
||||
{
|
||||
__host__ __device__ constexpr Sigmoid() = default;
|
||||
__host__ __device__ constexpr Sigmoid(const Sigmoid&) = default;
|
||||
__host__ __device__ constexpr Sigmoid(Sigmoid&&) = default;
|
||||
__host__ __device__ Sigmoid& operator=(const Sigmoid&) = default;
|
||||
__host__ __device__ Sigmoid& operator=(Sigmoid&&) = default;
|
||||
__host__ __device__ ~Sigmoid() = default;
|
||||
|
||||
__host__ __device__ inline void operator()(float& y, const float& x) const final
|
||||
{
|
||||
@@ -688,8 +716,15 @@ struct Silu
|
||||
};
|
||||
};
|
||||
|
||||
struct TanH : public UnaryOpBase
|
||||
struct TanH final : public UnaryOpBase
|
||||
{
|
||||
__host__ __device__ constexpr TanH() = default;
|
||||
__host__ __device__ constexpr TanH(const TanH&) = default;
|
||||
__host__ __device__ constexpr TanH(TanH&&) = default;
|
||||
__host__ __device__ TanH& operator=(const TanH&) = default;
|
||||
__host__ __device__ TanH& operator=(TanH&&) = default;
|
||||
__host__ __device__ ~TanH() = default;
|
||||
|
||||
__host__ __device__ inline void operator()(float& y, const float& x) const final
|
||||
{
|
||||
y = ck::math::tanh(x);
|
||||
@@ -959,8 +994,12 @@ struct Rcp
|
||||
};
|
||||
};
|
||||
|
||||
struct Swish : public UnaryOpBase
|
||||
struct Swish final : public UnaryOpBase
|
||||
{
|
||||
__host__ __device__ constexpr Swish(const Swish&) = default;
|
||||
__host__ __device__ constexpr Swish(Swish&&) = default;
|
||||
__host__ __device__ ~Swish() = default;
|
||||
|
||||
__host__ __device__ Swish(float beta = 1.0f) : beta_(beta) {}
|
||||
|
||||
__host__ __device__ float get_beta() const { return beta_; }
|
||||
@@ -1019,8 +1058,12 @@ struct Swish : public UnaryOpBase
|
||||
}
|
||||
};
|
||||
|
||||
struct SoftRelu : public UnaryOpBase
|
||||
struct SoftRelu final : public UnaryOpBase
|
||||
{
|
||||
__host__ __device__ constexpr SoftRelu(const SoftRelu&) = default;
|
||||
__host__ __device__ constexpr SoftRelu(SoftRelu&&) = default;
|
||||
__host__ __device__ ~SoftRelu() = default;
|
||||
|
||||
__host__ __device__ SoftRelu(float alpha = 1.0f) : alpha_(alpha) {}
|
||||
|
||||
__host__ __device__ float get_alpha() const { return alpha_; }
|
||||
@@ -1070,8 +1113,12 @@ struct SoftRelu : public UnaryOpBase
|
||||
}
|
||||
};
|
||||
|
||||
struct Power : public UnaryOpBase
|
||||
struct Power final : public UnaryOpBase
|
||||
{
|
||||
__host__ __device__ constexpr Power(const Power&) = default;
|
||||
__host__ __device__ constexpr Power(Power&&) = default;
|
||||
__host__ __device__ ~Power() = default;
|
||||
|
||||
__host__ __device__ Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
|
||||
: alpha_(alpha), beta_(beta), gamma_(gamma)
|
||||
{
|
||||
@@ -1148,8 +1195,12 @@ struct Power : public UnaryOpBase
|
||||
}
|
||||
};
|
||||
|
||||
struct ClippedRelu : public UnaryOpBase
|
||||
struct ClippedRelu final : public UnaryOpBase
|
||||
{
|
||||
__host__ __device__ constexpr ClippedRelu(const ClippedRelu&) = default;
|
||||
__host__ __device__ constexpr ClippedRelu(ClippedRelu&&) = default;
|
||||
__host__ __device__ ~ClippedRelu() = default;
|
||||
|
||||
__host__ __device__ ClippedRelu(float alpha = 0.f, float beta = 1.f)
|
||||
: alpha_(alpha), beta_(beta)
|
||||
{
|
||||
@@ -1205,8 +1256,11 @@ struct ClippedRelu : public UnaryOpBase
|
||||
}
|
||||
};
|
||||
|
||||
struct LeakyRelu : public UnaryOpBase
|
||||
struct LeakyRelu final : public UnaryOpBase
|
||||
{
|
||||
__host__ __device__ constexpr LeakyRelu(const LeakyRelu&) = default;
|
||||
__host__ __device__ constexpr LeakyRelu(LeakyRelu&&) = default;
|
||||
__host__ __device__ ~LeakyRelu() = default;
|
||||
|
||||
__host__ __device__ LeakyRelu(float alpha = 0.f) : alpha_(alpha) {}
|
||||
|
||||
@@ -1250,8 +1304,11 @@ struct LeakyRelu : public UnaryOpBase
|
||||
}
|
||||
};
|
||||
|
||||
struct Elu : public UnaryOpBase
|
||||
struct Elu final : public UnaryOpBase
|
||||
{
|
||||
__host__ __device__ constexpr Elu(const Elu&) = default;
|
||||
__host__ __device__ constexpr Elu(Elu&&) = default;
|
||||
__host__ __device__ ~Elu() = default;
|
||||
|
||||
__host__ __device__ Elu(float alpha = 1.f) : alpha_(alpha) {}
|
||||
|
||||
@@ -1296,8 +1353,11 @@ struct Elu : public UnaryOpBase
|
||||
}
|
||||
};
|
||||
|
||||
struct Logistic : public UnaryOpBase
|
||||
struct Logistic final : public UnaryOpBase
|
||||
{
|
||||
__host__ __device__ constexpr Logistic(const Logistic&) = default;
|
||||
__host__ __device__ constexpr Logistic(Logistic&&) = default;
|
||||
__host__ __device__ ~Logistic() = default;
|
||||
|
||||
__host__ __device__ Logistic(float alpha = 1.0f) : alpha_(alpha) {}
|
||||
|
||||
@@ -1631,8 +1691,23 @@ struct DynamicUnaryOp
|
||||
|
||||
__host__ __device__ ~DynamicUnaryOp()
|
||||
{
|
||||
if(unary_op_ptr_)
|
||||
delete unary_op_ptr_;
|
||||
switch(unary_op_type_)
|
||||
{
|
||||
case(UnaryOpType::Swish): delete static_cast<Swish*>(unary_op_ptr_); break;
|
||||
case(UnaryOpType::Sigmoid): delete static_cast<Sigmoid*>(unary_op_ptr_); break;
|
||||
case(UnaryOpType::PassThrough): delete static_cast<PassThrough*>(unary_op_ptr_); break;
|
||||
case(UnaryOpType::Logistic): delete static_cast<Logistic*>(unary_op_ptr_); break;
|
||||
case(UnaryOpType::TanH): delete static_cast<TanH*>(unary_op_ptr_); break;
|
||||
case(UnaryOpType::Relu): delete static_cast<Relu*>(unary_op_ptr_); break;
|
||||
case(UnaryOpType::SoftRelu): delete static_cast<SoftRelu*>(unary_op_ptr_); break;
|
||||
case(UnaryOpType::UnaryAbs): delete static_cast<UnaryAbs*>(unary_op_ptr_); break;
|
||||
case(UnaryOpType::Power): delete static_cast<Power*>(unary_op_ptr_); break;
|
||||
case(UnaryOpType::ClippedRelu): delete static_cast<ClippedRelu*>(unary_op_ptr_); break;
|
||||
case(UnaryOpType::LeakyRelu): delete static_cast<LeakyRelu*>(unary_op_ptr_); break;
|
||||
case(UnaryOpType::Elu): delete static_cast<Elu*>(unary_op_ptr_); break;
|
||||
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void InitUnaryOpPtrOnDevice()
|
||||
@@ -1721,6 +1796,7 @@ struct DynamicUnaryOp
|
||||
float beta;
|
||||
float gamma;
|
||||
};
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
} // namespace element_wise
|
||||
} // namespace tensor_operation
|
||||
|
||||
Reference in New Issue
Block a user