From c6eac9746f2097a3c4776d230aa92e6e7a697b88 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Tue, 9 Apr 2024 13:18:17 +0000 Subject: [PATCH] Fix type errors in composes<> --- example/ck_tile/01_fmha/fmha_fwd.cpp | 2 +- .../ck_tile/core/utility/unary_element_function.hpp | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index e9ff29caac..508d38a43d 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -508,7 +508,7 @@ bool run(const ck_tile::ArgParser& arg_parser) s_host_ref, ck_tile::identity{}, ck_tile::identity{}, - [&](SaccDataType x) { return pcompute_element_func(scale * x); }); + ck_tile::composes(pcompute_element_func, ck_tile::scales(scale))); if(use_bias) { diff --git a/include/ck_tile/core/utility/unary_element_function.hpp b/include/ck_tile/core/utility/unary_element_function.hpp index 1596519efc..10cbcb9864 100644 --- a/include/ck_tile/core/utility/unary_element_function.hpp +++ b/include/ck_tile/core/utility/unary_element_function.hpp @@ -8,21 +8,22 @@ namespace ck_tile { template -struct composes : private composes, private composes +struct composes : private composes { template CK_TILE_HOST_DEVICE constexpr explicit composes(FirstArg&& firstArg, RestArgs&&... restArgs) - : composes(std::forward(firstArg)), - composes(std::forward(restArgs)...) + : composes(std::forward(firstArg)), inner_(std::forward(restArgs)...) { } template CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(Arg&& arg) const { - return static_cast&>(*this)( - static_cast&>(*this)(std::forward(arg))); + return static_cast&>(*this)(inner_(std::forward(arg))); } + + private: + composes inner_; }; template @@ -30,7 +31,7 @@ struct composes { static_assert(!std::is_reference_v); - template >> + template >> CK_TILE_HOST_DEVICE constexpr explicit composes(Arg&& arg) : f_(std::forward(arg)) { }