From 724312aea3387cd6d2ef03d6dfae089b1c96e72e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 30 Oct 2024 17:42:50 +0100 Subject: [PATCH] Remove virtual destructors from unary ops (#1610) * Remove virtual destructors from unary ops * Fixes * Fixes * clang format fixes [ROCm/composable_kernel commit: 9a8a52130d780ca449ae261bb03ae4783f18f296] --- .../element/unary_element_wise_operation.hpp | 112 +++++++++++++++--- include/ck_tile/core/numeric/math.hpp | 2 +- .../host/reference/reference_elementwise.hpp | 2 +- .../host/reference/reference_permute.hpp | 2 +- .../reference/reference_rmsnorm2d_fwd.hpp | 2 +- .../add_rmsnorm2d_rdquant_fwd_shape.hpp | 2 +- ...rmsnorm2d_rdquant_fwd_pipeline_problem.hpp | 2 +- .../ops/fmha/pipeline/tile_fmha_shape.hpp | 2 +- .../pipeline/generic_petmute_problem.hpp | 2 +- .../rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp | 2 +- .../rmsnorm2d_fwd_pipeline_problem.hpp | 2 +- .../ops/welford/block/block_welford.hpp | 2 +- 12 files changed, 105 insertions(+), 29 deletions(-) diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 712b886183..39b81ca573 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -13,15 +13,17 @@ namespace ck { namespace tensor_operation { namespace element_wise { +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wnon-virtual-dtor" struct UnaryOpBase { public: - __host__ __device__ virtual ~UnaryOpBase() = default; + __host__ __device__ ~UnaryOpBase() = default; - __host__ __device__ UnaryOpBase() = default; - __host__ __device__ UnaryOpBase(const UnaryOpBase&) = default; + __host__ __device__ constexpr UnaryOpBase() = default; + __host__ __device__ constexpr UnaryOpBase(const UnaryOpBase&) = default; + __host__ __device__ constexpr 
UnaryOpBase(UnaryOpBase&&) = default; __host__ __device__ UnaryOpBase& operator=(const UnaryOpBase&) = default; - __host__ __device__ UnaryOpBase(UnaryOpBase&&) = default; __host__ __device__ UnaryOpBase& operator=(UnaryOpBase&&) = default; __host__ __device__ virtual inline void operator()(float& y, const float& x) const = 0; @@ -50,8 +52,14 @@ struct PassThroughPack2 constexpr const static bool is_pack2_invocable = true; }; -struct PassThrough : public UnaryOpBase +struct PassThrough final : public UnaryOpBase { + __host__ __device__ constexpr PassThrough() = default; + __host__ __device__ constexpr PassThrough(const PassThrough&) = default; + __host__ __device__ constexpr PassThrough(PassThrough&&) = default; + __host__ __device__ PassThrough& operator=(const PassThrough&) = default; + __host__ __device__ PassThrough& operator=(PassThrough&&) = default; + __host__ __device__ ~PassThrough() = default; __host__ __device__ inline void operator()(float& y, const float& x) const final { y = x; } @@ -409,8 +417,15 @@ struct UnarySquare }; }; -struct UnaryAbs : public UnaryOpBase +struct UnaryAbs final : public UnaryOpBase { + __host__ __device__ constexpr UnaryAbs() = default; + __host__ __device__ constexpr UnaryAbs(const UnaryAbs&) = default; + __host__ __device__ constexpr UnaryAbs(UnaryAbs&&) = default; + __host__ __device__ UnaryAbs& operator=(const UnaryAbs&) = default; + __host__ __device__ UnaryAbs& operator=(UnaryAbs&&) = default; + __host__ __device__ ~UnaryAbs() = default; + __host__ __device__ inline void operator()(float& y, const float& x) const final { y = ck::math::abs(x); @@ -459,8 +474,15 @@ struct UnarySqrt }; }; -struct Relu : public UnaryOpBase +struct Relu final : public UnaryOpBase { + __host__ __device__ constexpr Relu() = default; + __host__ __device__ constexpr Relu(const Relu&) = default; + __host__ __device__ constexpr Relu(Relu&&) = default; + __host__ __device__ Relu& operator=(const Relu&) = default; + __host__ __device__ Relu& 
operator=(Relu&&) = default; + __host__ __device__ ~Relu() = default; + __host__ __device__ inline void operator()(float& y, const float& x) const final { y = x > 0 ? x : 0; @@ -633,8 +655,14 @@ struct Gelu } }; -struct Sigmoid : public UnaryOpBase +struct Sigmoid final : public UnaryOpBase { + __host__ __device__ constexpr Sigmoid() = default; + __host__ __device__ constexpr Sigmoid(const Sigmoid&) = default; + __host__ __device__ constexpr Sigmoid(Sigmoid&&) = default; + __host__ __device__ Sigmoid& operator=(const Sigmoid&) = default; + __host__ __device__ Sigmoid& operator=(Sigmoid&&) = default; + __host__ __device__ ~Sigmoid() = default; __host__ __device__ inline void operator()(float& y, const float& x) const final { @@ -688,8 +716,15 @@ struct Silu }; }; -struct TanH : public UnaryOpBase +struct TanH final : public UnaryOpBase { + __host__ __device__ constexpr TanH() = default; + __host__ __device__ constexpr TanH(const TanH&) = default; + __host__ __device__ constexpr TanH(TanH&&) = default; + __host__ __device__ TanH& operator=(const TanH&) = default; + __host__ __device__ TanH& operator=(TanH&&) = default; + __host__ __device__ ~TanH() = default; + __host__ __device__ inline void operator()(float& y, const float& x) const final { y = ck::math::tanh(x); @@ -959,8 +994,12 @@ struct Rcp }; }; -struct Swish : public UnaryOpBase +struct Swish final : public UnaryOpBase { + __host__ __device__ constexpr Swish(const Swish&) = default; + __host__ __device__ constexpr Swish(Swish&&) = default; + __host__ __device__ ~Swish() = default; + __host__ __device__ Swish(float beta = 1.0f) : beta_(beta) {} __host__ __device__ float get_beta() const { return beta_; } @@ -1019,8 +1058,12 @@ struct Swish : public UnaryOpBase } }; -struct SoftRelu : public UnaryOpBase +struct SoftRelu final : public UnaryOpBase { + __host__ __device__ constexpr SoftRelu(const SoftRelu&) = default; + __host__ __device__ constexpr SoftRelu(SoftRelu&&) = default; + __host__ __device__ 
~SoftRelu() = default; + __host__ __device__ SoftRelu(float alpha = 1.0f) : alpha_(alpha) {} __host__ __device__ float get_alpha() const { return alpha_; } @@ -1070,8 +1113,12 @@ struct SoftRelu : public UnaryOpBase } }; -struct Power : public UnaryOpBase +struct Power final : public UnaryOpBase { + __host__ __device__ constexpr Power(const Power&) = default; + __host__ __device__ constexpr Power(Power&&) = default; + __host__ __device__ ~Power() = default; + __host__ __device__ Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) : alpha_(alpha), beta_(beta), gamma_(gamma) { @@ -1148,8 +1195,12 @@ struct Power : public UnaryOpBase } }; -struct ClippedRelu : public UnaryOpBase +struct ClippedRelu final : public UnaryOpBase { + __host__ __device__ constexpr ClippedRelu(const ClippedRelu&) = default; + __host__ __device__ constexpr ClippedRelu(ClippedRelu&&) = default; + __host__ __device__ ~ClippedRelu() = default; + __host__ __device__ ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta) { @@ -1205,8 +1256,11 @@ struct ClippedRelu : public UnaryOpBase } }; -struct LeakyRelu : public UnaryOpBase +struct LeakyRelu final : public UnaryOpBase { + __host__ __device__ constexpr LeakyRelu(const LeakyRelu&) = default; + __host__ __device__ constexpr LeakyRelu(LeakyRelu&&) = default; + __host__ __device__ ~LeakyRelu() = default; __host__ __device__ LeakyRelu(float alpha = 0.f) : alpha_(alpha) {} @@ -1250,8 +1304,11 @@ struct LeakyRelu : public UnaryOpBase } }; -struct Elu : public UnaryOpBase +struct Elu final : public UnaryOpBase { + __host__ __device__ constexpr Elu(const Elu&) = default; + __host__ __device__ constexpr Elu(Elu&&) = default; + __host__ __device__ ~Elu() = default; __host__ __device__ Elu(float alpha = 1.f) : alpha_(alpha) {} @@ -1296,8 +1353,11 @@ struct Elu : public UnaryOpBase } }; -struct Logistic : public UnaryOpBase +struct Logistic final : public UnaryOpBase { + __host__ __device__ constexpr Logistic(const 
Logistic&) = default; + __host__ __device__ constexpr Logistic(Logistic&&) = default; + __host__ __device__ ~Logistic() = default; __host__ __device__ Logistic(float alpha = 1.0f) : alpha_(alpha) {} @@ -1631,8 +1691,23 @@ struct DynamicUnaryOp __host__ __device__ ~DynamicUnaryOp() { - if(unary_op_ptr_) - delete unary_op_ptr_; + switch(unary_op_type_) + { + case(UnaryOpType::Swish): delete static_cast<Swish*>(unary_op_ptr_); break; + case(UnaryOpType::Sigmoid): delete static_cast<Sigmoid*>(unary_op_ptr_); break; + case(UnaryOpType::PassThrough): delete static_cast<PassThrough*>(unary_op_ptr_); break; + case(UnaryOpType::Logistic): delete static_cast<Logistic*>(unary_op_ptr_); break; + case(UnaryOpType::TanH): delete static_cast<TanH*>(unary_op_ptr_); break; + case(UnaryOpType::Relu): delete static_cast<Relu*>(unary_op_ptr_); break; + case(UnaryOpType::SoftRelu): delete static_cast<SoftRelu*>(unary_op_ptr_); break; + case(UnaryOpType::UnaryAbs): delete static_cast<UnaryAbs*>(unary_op_ptr_); break; + case(UnaryOpType::Power): delete static_cast<Power*>(unary_op_ptr_); break; + case(UnaryOpType::ClippedRelu): delete static_cast<ClippedRelu*>(unary_op_ptr_); break; + case(UnaryOpType::LeakyRelu): delete static_cast<LeakyRelu*>(unary_op_ptr_); break; + case(UnaryOpType::Elu): delete static_cast<Elu*>(unary_op_ptr_); break; + + default: break; + } } __device__ void InitUnaryOpPtrOnDevice() @@ -1721,6 +1796,7 @@ struct DynamicUnaryOp float beta; float gamma; }; +#pragma clang diagnostic pop } // namespace element_wise } // namespace tensor_operation diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 0faf1aa043..6bdcb509b0 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once diff --git a/include/ck_tile/host/reference/reference_elementwise.hpp b/include/ck_tile/host/reference/reference_elementwise.hpp index 809049fa64..65303279b8 100644 --- a/include/ck_tile/host/reference/reference_elementwise.hpp +++ b/include/ck_tile/host/reference/reference_elementwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_permute.hpp b/include/ck_tile/host/reference/reference_permute.hpp index 1c82483407..14ed4f815e 100644 --- a/include/ck_tile/host/reference/reference_permute.hpp +++ b/include/ck_tile/host/reference/reference_permute.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp index db6e92f4c0..b14e25a85b 100644 --- a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp +++ b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp index a17c53c73f..4bc7db434e 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp index 106e5086be..2e64060038 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index 570754b22e..bb33b5f021 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/include/ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp b/include/ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp index e504ed7472..17f18acb5e 100644 --- a/include/ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp +++ b/include/ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp index fb484a1069..fc4b9f470c 100644 --- a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp +++ b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp index 87cab34631..2820e18133 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/include/ck_tile/ops/welford/block/block_welford.hpp b/include/ck_tile/ops/welford/block/block_welford.hpp index 623e1e16d8..ce73c183e1 100644 --- a/include/ck_tile/ops/welford/block/block_welford.hpp +++ b/include/ck_tile/ops/welford/block/block_welford.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once