// Mirror of https://github.com/ROCm/composable_kernel.git
// Synced 2026-05-12 01:10:17 +00:00 (160 lines, 4.4 KiB, C++).
#ifndef CK_TUPLE_HPP
|
|
#define CK_TUPLE_HPP
|
|
|
|
#include <type_traits>
#include <utility>

#include "integral_constant.hpp"
#include "type.hpp"
#include "sequence.hpp"
|
|
|
|
namespace ck {
|
|
|
|
namespace detail {
|
|
|
|
// Empty tag type naming the I-th slot of a tuple. TupleImpl pairs every element with its own
// TupleElementKey<I>, so two elements of the same data type still live in distinct
// TupleElement base classes and can be selected by index.
template <index_t I>
struct TupleElementKey {};
|
|
|
|
template <typename Key, typename Data>
|
|
struct TupleElement
|
|
{
|
|
__host__ __device__ explicit constexpr TupleElement() : mData() {}
|
|
|
|
template <typename T>
|
|
__host__ __device__ explicit constexpr TupleElement(T&& v) : mData(static_cast<T&&>(v))
|
|
{
|
|
}
|
|
|
|
Data mData;
|
|
};
|
|
|
|
template <typename Key, typename Data>
|
|
__host__ __device__ constexpr const Data& get_tuple_element(const TupleElement<Key, Data>& x)
|
|
{
|
|
return x.mData;
|
|
}
|
|
|
|
template <typename Key, typename Data>
|
|
__host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x)
|
|
{
|
|
return x.mData;
|
|
}
|
|
|
|
template <typename Key, typename Data>
|
|
__host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x)
|
|
{
|
|
return static_cast<Data&&>(x.mData);
|
|
}
|
|
|
|
template <typename Indices, typename... Xs>
|
|
struct TupleImpl;
|
|
|
|
template <index_t... Is, typename... Xs>
|
|
struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>...
|
|
{
|
|
__host__ __device__ explicit constexpr TupleImpl() : TupleElement<TupleElementKey<Is>, Xs>()...
|
|
{
|
|
}
|
|
|
|
template <typename... Ys>
|
|
__host__ __device__ explicit constexpr TupleImpl(Ys&&... ys)
|
|
: TupleElement<TupleElementKey<Is>, Xs>(static_cast<Ys&&>(ys))...
|
|
{
|
|
}
|
|
|
|
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
|
|
|
|
template <index_t I>
|
|
__host__ __device__ constexpr const auto& GetElementByKey(TupleElementKey<I>) const
|
|
{
|
|
return get_tuple_element<TupleElementKey<I>>(*this);
|
|
}
|
|
|
|
template <index_t I>
|
|
__host__ __device__ constexpr auto& GetElementByKey(TupleElementKey<I>)
|
|
{
|
|
return get_tuple_element<TupleElementKey<I>>(*this);
|
|
}
|
|
};
|
|
|
|
} // namespace detail
|
|
|
|
template <typename... Xs>
|
|
struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>
|
|
{
|
|
using base =
|
|
detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>;
|
|
|
|
template <typename... Ys>
|
|
__host__ __device__ explicit constexpr Tuple(Ys&&... ys) : base(static_cast<Ys&&>(ys)...)
|
|
{
|
|
}
|
|
|
|
template <index_t I>
|
|
__host__ __device__ constexpr const auto& At(Number<I>) const
|
|
{
|
|
static_assert(I < base::Size(), "wrong! out of range");
|
|
return base::GetElementByKey(detail::TupleElementKey<I>{});
|
|
}
|
|
|
|
template <index_t I>
|
|
__host__ __device__ constexpr auto& At(Number<I>)
|
|
{
|
|
static_assert(I < base::Size(), "wrong! out of range");
|
|
return base::GetElementByKey(detail::TupleElementKey<I>{});
|
|
}
|
|
};
|
|
|
|
template <typename... Xs>
|
|
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
|
|
{
|
|
return Tuple<remove_cv_t<remove_reference_t<Xs>>...>(std::forward<Xs>(xs)...);
|
|
}
|
|
|
|
namespace detail {
|
|
|
|
template <typename F, typename X, index_t... Is>
|
|
__host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, Sequence<Is...>)
|
|
{
|
|
return make_tuple(f(x.At(Number<Is>{}))...);
|
|
}
|
|
|
|
template <typename F, typename X, typename Y, index_t... Is>
|
|
__host__ __device__ constexpr auto
|
|
transform_tuples_impl(F f, const X& x, const Y& y, Sequence<Is...>)
|
|
{
|
|
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}))...);
|
|
}
|
|
|
|
template <typename F, typename X, typename Y, typename Z, index_t... Is>
|
|
__host__ __device__ constexpr auto
|
|
transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, Sequence<Is...>)
|
|
{
|
|
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}), z.At(Number<Is>{}))...);
|
|
}
|
|
|
|
} // namespace detail
|
|
|
|
template <typename F, typename X>
|
|
__host__ __device__ constexpr auto transform_tuples(F f, const X& x)
|
|
{
|
|
return detail::transform_tuples_impl(
|
|
f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
|
|
}
|
|
|
|
template <typename F, typename X, typename Y>
|
|
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y)
|
|
{
|
|
return detail::transform_tuples_impl(
|
|
f, x, y, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
|
|
}
|
|
|
|
template <typename F, typename X, typename Y, typename Z>
|
|
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
|
|
{
|
|
return detail::transform_tuples_impl(
|
|
f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
|
|
}
|
|
|
|
} // namespace ck
|
|
#endif
|