mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Generalize the composes<> template
This commit is contained in:
@@ -330,7 +330,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
auto oacc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
|
||||
return ck_tile::compose(ck_tile::saturates<ck_tile::fp8_t>{}, ck_tile::scales{0.1f});
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{}, ck_tile::scales{0.1f});
|
||||
else
|
||||
return ck_tile::identity{};
|
||||
}();
|
||||
|
||||
@@ -70,7 +70,7 @@ struct FmhaF8StaticQuantizationElementFunctions
|
||||
{
|
||||
using PComputeElementFunction = ck_tile::scales<float>;
|
||||
using OAccElementFunction =
|
||||
ck_tile::composer<ck_tile::saturates<ck_tile::fp8_t>, ck_tile::scales<float>>;
|
||||
ck_tile::composes<ck_tile::saturates<ck_tile::fp8_t>, ck_tile::scales<float>>;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
||||
@@ -8,39 +8,46 @@
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename F, typename... Fs>
|
||||
struct composer
|
||||
struct composes : private composes<F>, private composes<Fs...>
|
||||
{
|
||||
composer(F f, Fs... fs) : f_(f), tail_(fs...) {}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& arg) const
|
||||
template <typename FirstArg, typename... RestArgs>
|
||||
CK_TILE_HOST_DEVICE constexpr explicit composes(FirstArg&& firstArg, RestArgs&&... restArgs)
|
||||
: composes<F>(std::forward<FirstArg>(firstArg)),
|
||||
composes<Fs...>(std::forward<RestArgs>(restArgs)...)
|
||||
{
|
||||
return f_(tail_(arg));
|
||||
}
|
||||
|
||||
F f_;
|
||||
composer<Fs...> tail_;
|
||||
template <typename Arg>
|
||||
CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(Arg&& arg) const
|
||||
{
|
||||
return static_cast<const composes<F>&>(*this)(
|
||||
static_cast<const composes<Fs...>&>(*this)(std::forward<Arg>(arg)));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename F>
|
||||
struct composer<F>
|
||||
struct composes<F>
|
||||
{
|
||||
composer(F f) : f_(f) {}
|
||||
static_assert(!std::is_reference_v<F>);
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& arg) const
|
||||
template <typename Arg, typename = std::enable_if_t<std::is_constructible_v<Arg, F>>>
|
||||
CK_TILE_HOST_DEVICE constexpr explicit composes(Arg&& arg) : f_(std::forward<Arg>(arg))
|
||||
{
|
||||
return f_(arg);
|
||||
}
|
||||
|
||||
template <typename Arg,
|
||||
typename = std::enable_if_t<std::is_invocable_v<std::add_const_t<F>&, Arg>>>
|
||||
CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(Arg&& arg) const
|
||||
{
|
||||
return f_(std::forward<Arg>(arg));
|
||||
}
|
||||
|
||||
private:
|
||||
F f_;
|
||||
};
|
||||
|
||||
template <typename... F>
|
||||
CK_TILE_HOST auto compose(F... f)
|
||||
{
|
||||
return composer<F...>(f...);
|
||||
}
|
||||
template <typename... Ts>
|
||||
__host__ __device__ composes(Ts&&...)->composes<remove_cvref_t<Ts>...>;
|
||||
|
||||
template <typename To>
|
||||
struct saturates
|
||||
|
||||
Reference in New Issue
Block a user