mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Rewrite StaticallyIndexedArray to use C-array instead of Tuple
Replace the recursive template-metaprogramming implementation of StaticallyIndexedArray with a simple C-array-based struct. This avoids deep template instantiation while maintaining the same interface.

Key changes:
- StaticallyIndexedArray now stores `T data_[N]` instead of being an alias for a `Tuple` of N `T`s
- Added a constexpr conversion constructor that converts from any indexed container (Tuple, etc.)
- Added arithmetic operators (+, -, *, +=, -=) using C++20 concepts
- Added overloads for container_reorder_given_new2old/old2new
- Added overloads for get_container_subset and set_container_subset
- Added a specialization for the empty array (N=0)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -76,6 +76,25 @@ __host__ __device__ constexpr auto container_reorder_given_old2new(const Tuple<T
|
||||
old_tuple, typename sequence_map_inverse<decltype(old2new)>::type{});
|
||||
}
|
||||
|
||||
template <typename T, index_t N, index_t... IRs>
|
||||
__host__ __device__ constexpr auto
|
||||
container_reorder_given_new2old(const StaticallyIndexedArray<T, N>& old_arr,
|
||||
Sequence<IRs...> /*new2old*/)
|
||||
{
|
||||
static_assert(N == sizeof...(IRs), "wrong! size not consistent");
|
||||
static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
return make_statically_indexed_array<T>(old_arr[Number<IRs>{}]...);
|
||||
}
|
||||
|
||||
template <typename T, index_t N, index_t... IRs>
|
||||
__host__ __device__ constexpr auto
|
||||
container_reorder_given_old2new(const StaticallyIndexedArray<T, N>& old_arr,
|
||||
Sequence<IRs...> old2new)
|
||||
{
|
||||
return container_reorder_given_new2old(
|
||||
old_arr, typename sequence_map_inverse<decltype(old2new)>::type{});
|
||||
}
|
||||
|
||||
template <index_t... Is, index_t... IRs>
|
||||
__host__ __device__ constexpr auto container_reorder_given_new2old(Sequence<Is...> /* old_seq */,
|
||||
Sequence<IRs...> /*new2old*/)
|
||||
@@ -358,6 +377,15 @@ __host__ __device__ constexpr auto get_container_subset(const Tuple<Ts...>& tup,
|
||||
return make_tuple(tup[Number<Is>{}]...);
|
||||
}
|
||||
|
||||
template <typename T, index_t N, index_t... Is>
|
||||
__host__ __device__ constexpr auto get_container_subset(const StaticallyIndexedArray<T, N>& arr,
|
||||
Sequence<Is...>)
|
||||
{
|
||||
static_assert(N >= sizeof...(Is), "wrong! size");
|
||||
|
||||
return StaticallyIndexedArray<T, sizeof...(Is)>{arr[Number<Is>{}]...};
|
||||
}
|
||||
|
||||
template <typename T, index_t N, index_t... Is>
|
||||
__host__ __device__ constexpr void
|
||||
set_container_subset(Array<T, N>& y, Sequence<Is...> picks, const Array<T, sizeof...(Is)>& x)
|
||||
@@ -376,6 +404,29 @@ set_container_subset(Tuple<Ys...>& y, Sequence<Is...> picks, const Tuple<Xs...>&
|
||||
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
|
||||
}
|
||||
|
||||
template <typename T, index_t N, index_t... Is>
|
||||
__host__ __device__ constexpr void
|
||||
set_container_subset(StaticallyIndexedArray<T, N>& y,
|
||||
Sequence<Is...> picks,
|
||||
const StaticallyIndexedArray<T, sizeof...(Is)>& x)
|
||||
{
|
||||
static_assert(N >= sizeof...(Is), "wrong! size");
|
||||
|
||||
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
|
||||
}
|
||||
|
||||
// Generic set_container_subset for StaticallyIndexedArray destination with any indexed source
|
||||
template <typename T, index_t N, index_t... Is, typename Src>
|
||||
requires requires { Src::Size(); }
|
||||
__host__ __device__ constexpr void
|
||||
set_container_subset(StaticallyIndexedArray<T, N>& y, Sequence<Is...> picks, const Src& x)
|
||||
{
|
||||
static_assert(N >= sizeof...(Is), "wrong! size");
|
||||
static_assert(Src::Size() == sizeof...(Is), "wrong! size mismatch");
|
||||
|
||||
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
|
||||
{
|
||||
|
||||
@@ -10,51 +10,124 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
namespace detail {
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<Tuple<Xs...>, Tuple<Ys...>>
|
||||
{
|
||||
using type = Tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
// StaticallyIndexedArray using simple C-array instead of template metaprogramming
|
||||
// This avoids deep template instantiation while maintaining the same interface
|
||||
template <typename T, index_t N>
|
||||
struct StaticallyIndexedArrayImpl
|
||||
struct StaticallyIndexedArray
|
||||
{
|
||||
using type =
|
||||
typename tuple_concat<typename StaticallyIndexedArrayImpl<T, N / 2>::type,
|
||||
typename StaticallyIndexedArrayImpl<T, N - N / 2>::type>::type;
|
||||
// Default constructor: value-initializes every element of data_.
__host__ __device__ constexpr StaticallyIndexedArray() : data_{} {}
|
||||
|
||||
// Single-element constructor - exclude containers with matching size (to prefer conversion
|
||||
// constructor)
|
||||
template <typename X>
|
||||
requires(N == 1 &&
|
||||
// Allow if X is same type as T or doesn't have Size() method
|
||||
(is_same<remove_cvref_t<X>, T>::value || !requires { remove_cvref_t<X>::Size(); }))
|
||||
__host__ __device__ constexpr StaticallyIndexedArray(X&& x)
|
||||
: data_{static_cast<T>(ck::forward<X>(x))}
|
||||
{
|
||||
}
|
||||
|
||||
// Multi-element constructor
|
||||
template <typename... Xs>
|
||||
requires(sizeof...(Xs) == N && N > 1)
|
||||
__host__ __device__ constexpr StaticallyIndexedArray(Xs&&... xs)
|
||||
: data_{static_cast<T>(ck::forward<Xs>(xs))...}
|
||||
{
|
||||
}
|
||||
|
||||
// Conversion constructor from any indexed container (Tuple, etc.)
|
||||
template <typename Container>
|
||||
requires(!is_same<remove_cvref_t<Container>, StaticallyIndexedArray>::value &&
|
||||
requires { Container::Size(); } && Container::Size() == N)
|
||||
__host__ __device__ constexpr StaticallyIndexedArray(const Container& src)
|
||||
: StaticallyIndexedArray(
|
||||
make_from_container(src, typename arithmetic_sequence_gen<0, N, 1>::type{}))
|
||||
{
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename Container, index_t... Is>
|
||||
__host__ __device__ static constexpr StaticallyIndexedArray
|
||||
make_from_container(const Container& src, Sequence<Is...>)
|
||||
{
|
||||
return StaticallyIndexedArray{static_cast<T>(src[Number<Is>{}])...};
|
||||
}
|
||||
|
||||
public:
|
||||
// Number of elements in the array (the template parameter N).
__host__ __device__ static constexpr index_t Size() { return N; }
|
||||
|
||||
// read access
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr const T& At(Number<I>) const
|
||||
{
|
||||
static_assert(I < N, "wrong! out of range");
|
||||
return data_[I];
|
||||
}
|
||||
|
||||
// write access
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr T& At(Number<I>)
|
||||
{
|
||||
static_assert(I < N, "wrong! out of range");
|
||||
return data_[I];
|
||||
}
|
||||
|
||||
// read access
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr const T& operator[](Number<I> i) const
|
||||
{
|
||||
return At(i);
|
||||
}
|
||||
|
||||
// write access
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr T& operator()(Number<I> i)
|
||||
{
|
||||
return At(i);
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
__host__ __device__ constexpr auto operator=(const U& a)
|
||||
{
|
||||
static_assert(U::Size() == Size(), "wrong! size not the same");
|
||||
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Compile-time flag; always reports true for this type.
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
|
||||
|
||||
T data_[N];
|
||||
};
|
||||
|
||||
// Specialization for empty array
|
||||
template <typename T>
|
||||
struct StaticallyIndexedArrayImpl<T, 0>
|
||||
struct StaticallyIndexedArray<T, 0>
|
||||
{
|
||||
using type = Tuple<>;
|
||||
};
|
||||
// Default constructor of the empty (N == 0) specialization; there are no elements to set up.
__host__ __device__ constexpr StaticallyIndexedArray() = default;
|
||||
|
||||
template <typename T>
|
||||
struct StaticallyIndexedArrayImpl<T, 1>
|
||||
{
|
||||
using type = Tuple<T>;
|
||||
};
|
||||
} // namespace detail
|
||||
// The empty specialization always has zero elements.
__host__ __device__ static constexpr index_t Size() { return 0; }
|
||||
|
||||
template <typename T, index_t N>
|
||||
using StaticallyIndexedArray = typename detail::StaticallyIndexedArrayImpl<T, N>::type;
|
||||
template <typename U>
|
||||
__host__ __device__ constexpr auto operator=(const U&)
|
||||
{
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Compile-time flag; always reports true for this type.
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
|
||||
};
|
||||
|
||||
template <typename X, typename... Xs>
|
||||
__host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
|
||||
{
|
||||
return StaticallyIndexedArray<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
|
||||
return StaticallyIndexedArray<X, sizeof...(Xs) + 1>{x, static_cast<X>(xs)...};
|
||||
}
|
||||
|
||||
// make empty StaticallyIndexedArray
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto make_statically_indexed_array()
|
||||
{
|
||||
return StaticallyIndexedArray<X, 0>();
|
||||
return StaticallyIndexedArray<X, 0>{};
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
@@ -101,5 +174,78 @@ struct StaticallyIndexedArray_v2
|
||||
T data_[N];
|
||||
};
|
||||
|
||||
// Concepts used to constrain the StaticallyIndexedArray arithmetic operators
|
||||
// Scalar: satisfied when ck::is_integral<T> or ck::is_floating_point<T> reports true.
template <typename T>
|
||||
concept Scalar = ck::is_integral<T>::value || ck::is_floating_point<T>::value;
|
||||
|
||||
// IndexedContainer: any non-Scalar type for which the expression T::Size() is well-formed.
template <typename T>
|
||||
concept IndexedContainer = !Scalar<T> && requires { T::Size(); };
|
||||
|
||||
// Arithmetic operators for StaticallyIndexedArray (to match Tuple operators)
|
||||
|
||||
// StaticallyIndexedArray += X
|
||||
template <typename T, index_t N, IndexedContainer X>
|
||||
__host__ __device__ constexpr auto operator+=(StaticallyIndexedArray<T, N>& y, const X& x)
|
||||
{
|
||||
static_assert(X::Size() == N, "wrong! size not the same");
|
||||
static_for<0, N, 1>{}([&](auto i) { y(i) += x[i]; });
|
||||
return y;
|
||||
}
|
||||
|
||||
// StaticallyIndexedArray -= X
|
||||
template <typename T, index_t N, IndexedContainer X>
|
||||
__host__ __device__ constexpr auto operator-=(StaticallyIndexedArray<T, N>& y, const X& x)
|
||||
{
|
||||
static_assert(X::Size() == N, "wrong! size not the same");
|
||||
static_for<0, N, 1>{}([&](auto i) { y(i) -= x[i]; });
|
||||
return y;
|
||||
}
|
||||
|
||||
// StaticallyIndexedArray + Y
|
||||
template <typename T, index_t N, IndexedContainer Y>
|
||||
__host__ __device__ constexpr auto operator+(const StaticallyIndexedArray<T, N>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == N, "wrong! size not the same");
|
||||
StaticallyIndexedArray<T, N> r;
|
||||
static_for<0, N, 1>{}([&](auto i) { r(i) = x[i] + y[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
// StaticallyIndexedArray - Y
|
||||
template <typename T, index_t N, IndexedContainer Y>
|
||||
__host__ __device__ constexpr auto operator-(const StaticallyIndexedArray<T, N>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == N, "wrong! size not the same");
|
||||
StaticallyIndexedArray<T, N> r;
|
||||
static_for<0, N, 1>{}([&](auto i) { r(i) = x[i] - y[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
// StaticallyIndexedArray * Y (element-wise)
|
||||
template <typename T, index_t N, IndexedContainer Y>
|
||||
__host__ __device__ constexpr auto operator*(const StaticallyIndexedArray<T, N>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == N, "wrong! size not the same");
|
||||
StaticallyIndexedArray<T, N> r;
|
||||
static_for<0, N, 1>{}([&](auto i) { r(i) = x[i] * y[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
// scalar * StaticallyIndexedArray
|
||||
template <typename T, index_t N, Scalar S>
|
||||
__host__ __device__ constexpr auto operator*(S a, const StaticallyIndexedArray<T, N>& x)
|
||||
{
|
||||
StaticallyIndexedArray<T, N> r;
|
||||
static_for<0, N, 1>{}([&](auto i) { r(i) = a * x[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
// StaticallyIndexedArray * scalar
|
||||
template <typename T, index_t N, Scalar S>
|
||||
__host__ __device__ constexpr auto operator*(const StaticallyIndexedArray<T, N>& x, S a)
|
||||
{
|
||||
return a * x;
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user