mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
324 lines
9.1 KiB
C++
324 lines
9.1 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include "ck/utility/is_static.hpp"
#include "ck/utility/print.hpp"
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/enable_if.hpp"

#include <utility>
|
|
|
|
namespace ck {
|
|
|
|
namespace detail {
|
|
|
|
template <index_t>
|
|
struct TupleElementKey
|
|
{
|
|
__host__ __device__ constexpr TupleElementKey() = default;
|
|
};
|
|
|
|
template <typename Key, typename Data>
|
|
struct TupleElementKeyData
|
|
{
|
|
using DataType = Data;
|
|
|
|
#if 0 // workaround compiler complaint about implicitly-deleted default constructor
|
|
__host__ __device__ constexpr TupleElementKeyData() = default;
|
|
#else
|
|
__host__ __device__ constexpr TupleElementKeyData() : mData{} {}
|
|
#endif
|
|
|
|
template <typename T,
|
|
typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value,
|
|
bool>::type = false>
|
|
__host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward<T>(v))
|
|
{
|
|
}
|
|
|
|
DataType mData;
|
|
};
|
|
|
|
// for read access of tuple element
|
|
template <typename Key, typename Data>
|
|
__host__ __device__ constexpr const Data&
|
|
get_tuple_element_data_reference(const TupleElementKeyData<Key, Data>& x)
|
|
{
|
|
return static_cast<const Data&>(x.mData);
|
|
}
|
|
|
|
// for write access of tuple element
|
|
template <typename Key, typename Data>
|
|
__host__ __device__ constexpr Data&
|
|
get_tuple_element_data_reference(TupleElementKeyData<Key, Data>& x)
|
|
{
|
|
return x.mData;
|
|
}
|
|
|
|
// TODO: not sure the use of reference is correct
|
|
template <typename Key, typename Data>
|
|
__host__ __device__ constexpr Data&&
|
|
get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
|
|
{
|
|
return static_cast<Data&&>(x.mData);
|
|
}
|
|
|
|
// for infering type of tuple element
|
|
template <typename Key, typename Data>
|
|
__host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
|
|
{
|
|
return std::forward(x.mData);
|
|
}
|
|
|
|
template <typename Indices, typename... Xs>
|
|
struct TupleImpl;
|
|
|
|
template <index_t... Is, typename... Xs>
|
|
struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<Is>, Xs>...
|
|
{
|
|
__host__ __device__ constexpr TupleImpl() = default;
|
|
|
|
template <typename Y,
|
|
typename enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
|
|
!is_same<remove_cvref_t<Y>, TupleImpl>::value,
|
|
bool>::type = false>
|
|
__host__ __device__ constexpr TupleImpl(Y&& y)
|
|
: TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
|
|
{
|
|
}
|
|
|
|
template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
|
__host__ __device__ constexpr TupleImpl(Ys&&... ys)
|
|
: TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
|
|
{
|
|
static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
|
|
"wrong! inconsistent size");
|
|
}
|
|
|
|
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
|
|
|
|
template <index_t I>
|
|
__host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey<I>) const
|
|
{
|
|
return get_tuple_element_data_reference<TupleElementKey<I>>(*this);
|
|
}
|
|
|
|
template <index_t I>
|
|
__host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey<I>)
|
|
{
|
|
return get_tuple_element_data_reference<TupleElementKey<I>>(*this);
|
|
}
|
|
};
|
|
|
|
} // namespace detail
|
|
|
|
template <typename... Xs>
|
|
struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>
|
|
{
|
|
using base =
|
|
detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>;
|
|
|
|
__host__ __device__ constexpr Tuple() = default;
|
|
|
|
template <typename Y,
|
|
typename enable_if<sizeof...(Xs) == 1 && !is_same<remove_cvref_t<Y>, Tuple>::value,
|
|
bool>::type = false>
|
|
__host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
|
|
{
|
|
}
|
|
|
|
template <typename... Ys,
|
|
typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type =
|
|
false>
|
|
__host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
|
|
{
|
|
}
|
|
|
|
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
|
|
|
|
// read access
|
|
template <index_t I>
|
|
__host__ __device__ constexpr const auto& At() const
|
|
{
|
|
static_assert(I < base::Size(), "wrong! out of range");
|
|
return base::GetElementDataByKey(detail::TupleElementKey<I>{});
|
|
}
|
|
|
|
// write access
|
|
template <index_t I>
|
|
__host__ __device__ constexpr auto& At()
|
|
{
|
|
static_assert(I < base::Size(), "wrong! out of range");
|
|
return base::GetElementDataByKey(detail::TupleElementKey<I>{});
|
|
}
|
|
|
|
// read access
|
|
template <index_t I>
|
|
__host__ __device__ constexpr const auto& At(Number<I>) const
|
|
{
|
|
static_assert(I < base::Size(), "wrong! out of range");
|
|
return base::GetElementDataByKey(detail::TupleElementKey<I>{});
|
|
}
|
|
|
|
// write access
|
|
template <index_t I>
|
|
__host__ __device__ constexpr auto& At(Number<I>)
|
|
{
|
|
static_assert(I < base::Size(), "wrong! out of range");
|
|
return base::GetElementDataByKey(detail::TupleElementKey<I>{});
|
|
}
|
|
|
|
// read access
|
|
template <index_t I>
|
|
__host__ __device__ constexpr const auto& operator[](Number<I> i) const
|
|
{
|
|
return At(i);
|
|
}
|
|
|
|
// write access
|
|
template <index_t I>
|
|
__host__ __device__ constexpr auto& operator()(Number<I> i)
|
|
{
|
|
return At(i);
|
|
}
|
|
|
|
// WARNING: needed by compiler for C++ structured binding support only, don't use this function!
|
|
template <std::size_t I>
|
|
__host__ __device__ constexpr const auto& get() const
|
|
{
|
|
return this->template At<I>();
|
|
}
|
|
|
|
// WARNING: needed bu compiler for C++ structured binding support only, don't use this function!
|
|
template <std::size_t I>
|
|
__host__ __device__ constexpr auto& get()
|
|
{
|
|
return this->template At<I>();
|
|
}
|
|
|
|
template <typename T>
|
|
__host__ __device__ constexpr auto operator=(const T& a)
|
|
{
|
|
static_assert(T::Size() == Size(), "wrong! size not the same");
|
|
|
|
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
|
|
|
|
return *this;
|
|
}
|
|
|
|
__host__ __device__ static constexpr bool IsStatic()
|
|
{
|
|
bool flag = true;
|
|
|
|
static_for<0, sizeof...(Xs), 1>{}([&flag](auto i) {
|
|
flag &= is_static_v<remove_cvref_t<type_pack_element<i.value, Xs...>>>;
|
|
});
|
|
|
|
return flag;
|
|
}
|
|
|
|
// FIXME: remove
|
|
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
|
|
|
|
__host__ __device__ void Print() const
|
|
{
|
|
printf("Tuple{size: %d, data: [", static_cast<index_t>(Size()));
|
|
|
|
static_for<0, Size(), 1>{}([&](auto i) {
|
|
print(At(i));
|
|
|
|
if(i < Size() - 1)
|
|
{
|
|
printf(", ");
|
|
}
|
|
});
|
|
|
|
printf("]}");
|
|
}
|
|
};
|
|
|
|
template <>
|
|
struct Tuple<>
|
|
{
|
|
__host__ __device__ constexpr Tuple() = default;
|
|
|
|
__host__ __device__ static constexpr index_t Size() { return 0; }
|
|
|
|
template <typename T>
|
|
__host__ __device__ constexpr auto operator=(const T&)
|
|
{
|
|
return *this;
|
|
}
|
|
|
|
__host__ __device__ static constexpr bool IsStatic() { return true; }
|
|
|
|
// FIXME: remove
|
|
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
|
|
|
|
__host__ __device__ void Print() const { printf("Tuple{size: 0, data: []}"); }
|
|
};
|
|
|
|
template <typename... Xs>
|
|
__host__ __device__ constexpr bool operator==(const Tuple<Xs...>& a, const Tuple<Xs...>& b)
|
|
{
|
|
bool same = true;
|
|
|
|
static_for<0, sizeof...(Xs), 1>{}([&](auto i) {
|
|
if(a[i] != b[i])
|
|
{
|
|
same = false;
|
|
}
|
|
});
|
|
|
|
return same;
|
|
}
|
|
|
|
template <typename... Xs>
|
|
__host__ __device__ constexpr bool operator!=(const Tuple<Xs...>& a, const Tuple<Xs...>& b)
|
|
{
|
|
return !(a == b);
|
|
}
|
|
|
|
template <index_t I, typename TTuple>
|
|
struct tuple_element
|
|
{
|
|
// type should keep the cv/ref qualifier of original tuple element
|
|
using type = decltype(detail::get_tuple_element_data<detail::TupleElementKey<I>>(TTuple{}));
|
|
};
|
|
|
|
template <index_t I, typename TTuple>
|
|
using tuple_element_t = typename tuple_element<I, TTuple>::type;
|
|
|
|
template <typename... Xs>
|
|
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
|
|
{
|
|
return Tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
|
|
}
|
|
|
|
// https://en.cppreference.com/w/cpp/utility/tuple/tie
|
|
template <typename... Args>
|
|
constexpr Tuple<Args&...> tie(Args&... args) noexcept
|
|
{
|
|
return {args...};
|
|
}
|
|
|
|
} // namespace ck
|
|
|
|
namespace std {
|
|
|
|
// WARNING: needed by compiler for C++ structured binding support only, don't use this
|
|
template <typename... Ts>
|
|
struct tuple_size<ck::Tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
|
|
{
|
|
};
|
|
|
|
// WARNING: needed by compiler for C++ structured binding support only, don't use this
|
|
template <std::size_t I, typename... Ts>
|
|
struct tuple_element<I, ck::Tuple<Ts...>> : ck::tuple_element<I, ck::Tuple<Ts...>>
|
|
{
|
|
};
|
|
|
|
} // namespace std
|