Generalize the composes<> template

This commit is contained in:
Po Yen Chen
2024-04-09 10:14:56 +00:00
parent 6ed739f913
commit ecc64bce12
3 changed files with 27 additions and 20 deletions

View File

@@ -330,7 +330,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto oacc_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
return ck_tile::compose(ck_tile::saturates<ck_tile::fp8_t>{}, ck_tile::scales{0.1f});
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{}, ck_tile::scales{0.1f});
else
return ck_tile::identity{};
}();

View File

@@ -70,7 +70,7 @@ struct FmhaF8StaticQuantizationElementFunctions
{
using PComputeElementFunction = ck_tile::scales<float>;
using OAccElementFunction =
ck_tile::composer<ck_tile::saturates<ck_tile::fp8_t>, ck_tile::scales<float>>;
ck_tile::composes<ck_tile::saturates<ck_tile::fp8_t>, ck_tile::scales<float>>;
};
template <>

View File

@@ -8,39 +8,46 @@
namespace ck_tile {
template <typename F, typename... Fs>
struct composer
struct composes : private composes<F>, private composes<Fs...>
{
composer(F f, Fs... fs) : f_(f), tail_(fs...) {}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& arg) const
template <typename FirstArg, typename... RestArgs>
CK_TILE_HOST_DEVICE constexpr explicit composes(FirstArg&& firstArg, RestArgs&&... restArgs)
: composes<F>(std::forward<FirstArg>(firstArg)),
composes<Fs...>(std::forward<RestArgs>(restArgs)...)
{
return f_(tail_(arg));
}
F f_;
composer<Fs...> tail_;
template <typename Arg>
CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(Arg&& arg) const
{
return static_cast<const composes<F>&>(*this)(
static_cast<const composes<Fs...>&>(*this)(std::forward<Arg>(arg)));
}
};
template <typename F>
struct composer<F>
struct composes<F>
{
composer(F f) : f_(f) {}
static_assert(!std::is_reference_v<F>);
template <typename T>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& arg) const
template <typename Arg, typename = std::enable_if_t<std::is_constructible_v<Arg, F>>>
CK_TILE_HOST_DEVICE constexpr explicit composes(Arg&& arg) : f_(std::forward<Arg>(arg))
{
return f_(arg);
}
template <typename Arg,
typename = std::enable_if_t<std::is_invocable_v<std::add_const_t<F>&, Arg>>>
CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(Arg&& arg) const
{
return f_(std::forward<Arg>(arg));
}
private:
F f_;
};
template <typename... F>
CK_TILE_HOST auto compose(F... f)
{
return composer<F...>(f...);
}
template <typename... Ts>
__host__ __device__ composes(Ts&&...)->composes<remove_cvref_t<Ts>...>;
template <typename To>
struct saturates