Fix type errors in composes<>

This commit is contained in:
Po Yen Chen
2024-04-09 13:18:17 +00:00
parent 35e2d18e5e
commit c6eac9746f
2 changed files with 8 additions and 7 deletions

View File

@@ -508,7 +508,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
s_host_ref,
ck_tile::identity{},
ck_tile::identity{},
[&](SaccDataType x) { return pcompute_element_func(scale * x); });
ck_tile::composes(pcompute_element_func, ck_tile::scales(scale)));
if(use_bias)
{

View File

@@ -8,21 +8,22 @@
namespace ck_tile {
template <typename F, typename... Fs>
struct composes : private composes<F>, private composes<Fs...>
struct composes : private composes<F>
{
template <typename FirstArg, typename... RestArgs>
CK_TILE_HOST_DEVICE constexpr explicit composes(FirstArg&& firstArg, RestArgs&&... restArgs)
: composes<F>(std::forward<FirstArg>(firstArg)),
composes<Fs...>(std::forward<RestArgs>(restArgs)...)
: composes<F>(std::forward<FirstArg>(firstArg)), inner_(std::forward<RestArgs>(restArgs)...)
{
}
template <typename Arg>
CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(Arg&& arg) const
{
return static_cast<const composes<F>&>(*this)(
static_cast<const composes<Fs...>&>(*this)(std::forward<Arg>(arg)));
return static_cast<const composes<F>&>(*this)(inner_(std::forward<Arg>(arg)));
}
private:
composes<Fs...> inner_;
};
template <typename F>
@@ -30,7 +31,7 @@ struct composes<F>
{
static_assert(!std::is_reference_v<F>);
template <typename Arg, typename = std::enable_if_t<std::is_constructible_v<Arg, F>>>
template <typename Arg, typename = std::enable_if_t<std::is_constructible_v<F, Arg>>>
CK_TILE_HOST_DEVICE constexpr explicit composes(Arg&& arg) : f_(std::forward<Arg>(arg))
{
}