From d6cb104d0f3cacf22b628123d389d5a65eda25dd Mon Sep 17 00:00:00 2001 From: rocking Date: Thu, 4 Apr 2024 03:18:39 +0000 Subject: [PATCH] Add some elementwise op, prepare to quantization --- .../core/utility/unary_element_function.hpp | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 include/ck_tile/core/utility/unary_element_function.hpp diff --git a/include/ck_tile/core/utility/unary_element_function.hpp b/include/ck_tile/core/utility/unary_element_function.hpp new file mode 100644 index 0000000000..7eb77c01da --- /dev/null +++ b/include/ck_tile/core/utility/unary_element_function.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct composer +{ + composer(F f, Fs... fs) : f_(f), tail_(fs...) {} + + template + CK_TILE_HOST_DEVICE constexpr T operator()(const T& arg) const + { + return f_(tail_(arg)); + } + + F f_; + composer tail_; +}; + +template +struct composer +{ + composer(F f) : f_(f) {} + + template + CK_TILE_HOST_DEVICE constexpr T operator()(const T& arg) const + { + return f_(arg); + } + + F f_; +}; + +template +CK_TILE_HOST auto compose(F... f) +{ + return composer(f...); +} + +// start of unary element function + +struct scale +{ + CK_TILE_HOST_DEVICE scale(float factor) : factor_(factor) {} + + template + CK_TILE_HOST_DEVICE constexpr T operator()(const T& x) const; + + template <> + CK_TILE_HOST_DEVICE constexpr float operator()(const float& x) const + { + return factor_ * x; + }; + + float factor_; +}; + +// TODO: Overload numeric::min() and numeric::max() +struct saturate_f8 +{ + template + CK_TILE_HOST_DEVICE constexpr T operator()(const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + T y = clamp(x, static_cast(-448), static_cast(448)); + return y; + } +}; + +} // namespace ck_tile