From ecc64bce12fd50d6a95444c7008f584b58457780 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Tue, 9 Apr 2024 10:14:56 +0000 Subject: [PATCH] Generalize the composes<> template --- example/ck_tile/01_fmha/fmha_fwd.cpp | 2 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 2 +- .../core/utility/unary_element_function.hpp | 43 +++++++++++-------- 3 files changed, 27 insertions(+), 20 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 6126cc9059..34898c4f64 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -330,7 +330,7 @@ bool run(const ck_tile::ArgParser& arg_parser) auto oacc_element_func = [&]() { if constexpr(std::is_same_v) - return ck_tile::compose(ck_tile::saturates{}, ck_tile::scales{0.1f}); + return ck_tile::composes(ck_tile::saturates{}, ck_tile::scales{0.1f}); else return ck_tile::identity{}; }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 6f21b9bffc..4837f2c996 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -70,7 +70,7 @@ struct FmhaF8StaticQuantizationElementFunctions { using PComputeElementFunction = ck_tile::scales; using OAccElementFunction = - ck_tile::composer, ck_tile::scales>; + ck_tile::composes, ck_tile::scales>; }; template <> diff --git a/include/ck_tile/core/utility/unary_element_function.hpp b/include/ck_tile/core/utility/unary_element_function.hpp index 199d225ba4..c882cbe5fe 100644 --- a/include/ck_tile/core/utility/unary_element_function.hpp +++ b/include/ck_tile/core/utility/unary_element_function.hpp @@ -8,39 +8,46 @@ namespace ck_tile { template -struct composer +struct composes : private composes, private composes { - composer(F f, Fs... fs) : f_(f), tail_(fs...) {} - - template - CK_TILE_HOST_DEVICE constexpr T operator()(const T& arg) const + template + CK_TILE_HOST_DEVICE constexpr explicit composes(FirstArg&& firstArg, RestArgs&&... restArgs) + : composes(std::forward(firstArg)), + composes(std::forward(restArgs)...) { - return f_(tail_(arg)); } - F f_; - composer tail_; + template + CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(Arg&& arg) const + { + return static_cast&>(*this)( + static_cast&>(*this)(std::forward(arg))); + } }; template -struct composer +struct composes { - composer(F f) : f_(f) {} + static_assert(!std::is_reference_v); - template - CK_TILE_HOST_DEVICE constexpr T operator()(const T& arg) const + template >> + CK_TILE_HOST_DEVICE constexpr explicit composes(Arg&& arg) : f_(std::forward(arg)) { - return f_(arg); } + template &, Arg>>> + CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(Arg&& arg) const + { + return f_(std::forward(arg)); + } + + private: F f_; }; -template -CK_TILE_HOST auto compose(F... f) -{ - return composer(f...); -} +template +__host__ __device__ composes(Ts&&...)->composes...>; template struct saturates