Rewrite StaticallyIndexedArray to use C-array instead of Tuple

Replace the recursive template metaprogramming implementation of
StaticallyIndexedArray with a simple C-array based struct. This avoids
deep template instantiation while maintaining the same interface.

Key changes:
- StaticallyIndexedArray now stores `T data_[N]` instead of inheriting from Tuple
- Added constexpr conversion constructor to convert from any indexed container (Tuple, etc.)
- Added arithmetic operators (+, -, *, +=, -=) using C++20 concepts
- Added overloads for container_reorder_given_new2old/old2new
- Added overloads for get_container_subset and set_container_subset
- Added a specialization for the empty array (N=0)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Max Podkorytov
2026-01-15 21:54:58 -06:00
parent 3f735c127b
commit aef254ca0d
2 changed files with 224 additions and 27 deletions

View File

@@ -76,6 +76,25 @@ __host__ __device__ constexpr auto container_reorder_given_old2new(const Tuple<T
old_tuple, typename sequence_map_inverse<decltype(old2new)>::type{});
}
// Reorders a StaticallyIndexedArray according to a new2old map.
// Element k of the result is old_arr[IRs[k]], i.e. each entry of the Sequence names
// the position in the old array that supplies the corresponding new element.
// The map must have exactly N entries and must be a valid sequence map; both
// conditions are checked at compile time.
template <typename T, index_t N, index_t... IRs>
__host__ __device__ constexpr auto
container_reorder_given_new2old(const StaticallyIndexedArray<T, N>& old_arr,
                                Sequence<IRs...> /*new2old*/)
{
    using reorder_map = Sequence<IRs...>;

    static_assert(N == sizeof...(IRs), "wrong! size not consistent");
    static_assert(is_valid_sequence_map<reorder_map>{}, "wrong! invalid reorder map");

    // Gather the source elements in the order given by the map.
    return make_statically_indexed_array<T>(old_arr.At(Number<IRs>{})...);
}
// Reorders a StaticallyIndexedArray according to an old2new map.
// The map is inverted with sequence_map_inverse and the work is delegated to
// container_reorder_given_new2old.
template <typename T, index_t N, index_t... IRs>
__host__ __device__ constexpr auto
container_reorder_given_old2new(const StaticallyIndexedArray<T, N>& old_arr,
                                Sequence<IRs...> /*old2new*/)
{
    using new2old_map = typename sequence_map_inverse<Sequence<IRs...>>::type;

    return container_reorder_given_new2old(old_arr, new2old_map{});
}
template <index_t... Is, index_t... IRs>
__host__ __device__ constexpr auto container_reorder_given_new2old(Sequence<Is...> /* old_seq */,
Sequence<IRs...> /*new2old*/)
@@ -358,6 +377,15 @@ __host__ __device__ constexpr auto get_container_subset(const Tuple<Ts...>& tup,
return make_tuple(tup[Number<Is>{}]...);
}
// Picks the elements of arr at the compile-time positions Is... and returns them
// as a new StaticallyIndexedArray with sizeof...(Is) elements, in the order listed.
// The number of picks may not exceed N (checked at compile time).
template <typename T, index_t N, index_t... Is>
__host__ __device__ constexpr auto get_container_subset(const StaticallyIndexedArray<T, N>& arr,
                                                        Sequence<Is...>)
{
    static_assert(N >= sizeof...(Is), "wrong! size");

    using subset_type = StaticallyIndexedArray<T, sizeof...(Is)>;
    return subset_type{arr.At(Number<Is>{})...};
}
template <typename T, index_t N, index_t... Is>
__host__ __device__ constexpr void
set_container_subset(Array<T, N>& y, Sequence<Is...> picks, const Array<T, sizeof...(Is)>& x)
@@ -376,6 +404,29 @@ set_container_subset(Tuple<Ys...>& y, Sequence<Is...> picks, const Tuple<Xs...>&
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
}
// Scatters x into y: for every k in [0, sizeof...(Is)), the element of y at
// position picks[k] is assigned x[k]. Elements of y not named in picks are left
// untouched. The number of picks may not exceed N (checked at compile time).
template <typename T, index_t N, index_t... Is>
__host__ __device__ constexpr void
set_container_subset(StaticallyIndexedArray<T, N>& y,
                     Sequence<Is...> picks,
                     const StaticallyIndexedArray<T, sizeof...(Is)>& x)
{
    static_assert(N >= sizeof...(Is), "wrong! size");

    const auto assign_pick = [&](auto k) { y.At(picks[k]) = x.At(k); };
    static_for<0, sizeof...(Is), 1>{}(assign_pick);
}
// set_container_subset for a StaticallyIndexedArray destination and an arbitrary
// source type that exposes a static Size().
// For every k in [0, sizeof...(Is)), the element of y at position picks[k] is
// assigned x[k]. The source must have exactly sizeof...(Is) elements and the
// number of picks may not exceed N; both are checked at compile time.
template <typename T, index_t N, index_t... Is, typename Src>
    requires requires { Src::Size(); }
__host__ __device__ constexpr void
set_container_subset(StaticallyIndexedArray<T, N>& y, Sequence<Is...> picks, const Src& x)
{
    static_assert(N >= sizeof...(Is), "wrong! size");
    static_assert(Src::Size() == sizeof...(Is), "wrong! size mismatch");

    const auto assign_pick = [&](auto k) { y.At(picks[k]) = x[k]; };
    static_for<0, sizeof...(Is), 1>{}(assign_pick);
}
template <index_t... Is>
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
{

View File

@@ -10,51 +10,124 @@
namespace ck {
namespace detail {
// Type-level concatenation of two Tuple types.
// Only the Tuple/Tuple specialization below is defined; the primary template is
// declared but left incomplete.
template <typename Lhs, typename Rhs>
struct tuple_concat;

// tuple_concat<Tuple<A...>, Tuple<B...>>::type is Tuple<A..., B...>.
template <typename... Heads, typename... Tails>
struct tuple_concat<Tuple<Heads...>, Tuple<Tails...>>
{
    using type = Tuple<Heads..., Tails...>;
};
// StaticallyIndexedArray using simple C-array instead of template metaprogramming
// This avoids deep template instantiation while maintaining the same interface
//
// Holds N elements of type T in the member array `data_`. Elements are addressed
// with compile-time Number<I> tags through At(), operator[] (read) and
// operator() (write).
// NOTE(review): the `struct StaticallyIndexedArrayImpl` line and the
// `using type = typename tuple_concat<...>` lines below look like removed
// pre-change code shown interleaved by the diff view - confirm against the
// actual header before relying on this text.
template <typename T, index_t N>
struct StaticallyIndexedArrayImpl
struct StaticallyIndexedArray
{
using type =
typename tuple_concat<typename StaticallyIndexedArrayImpl<T, N / 2>::type,
typename StaticallyIndexedArrayImpl<T, N - N / 2>::type>::type;
// Default constructor: value-initializes every element of data_.
__host__ __device__ constexpr StaticallyIndexedArray() : data_{} {}
// Single-element constructor, enabled only when N == 1.
// Argument types that expose a static Size() are rejected (unless the type is T
// itself) so that such arguments are handled by the conversion constructor.
// The argument is forwarded and static_cast to T.
template <typename X>
requires(N == 1 &&
// Allow if X is same type as T or doesn't have Size() method
(is_same<remove_cvref_t<X>, T>::value || !requires { remove_cvref_t<X>::Size(); }))
__host__ __device__ constexpr StaticallyIndexedArray(X&& x)
: data_{static_cast<T>(ck::forward<X>(x))}
{
}
// Multi-element constructor, enabled only when exactly N arguments are given and
// N > 1. Each argument is forwarded and static_cast to T, in order.
template <typename... Xs>
requires(sizeof...(Xs) == N && N > 1)
__host__ __device__ constexpr StaticallyIndexedArray(Xs&&... xs)
: data_{static_cast<T>(ck::forward<Xs>(xs))...}
{
}
// Conversion constructor from any indexed container (Tuple, etc.)
// Enabled when Container is not this array type, exposes a static Size(), and
// that Size() equals N. Construction is delegated to the array produced by
// make_from_container over the index sequence 0..N-1.
template <typename Container>
requires(!is_same<remove_cvref_t<Container>, StaticallyIndexedArray>::value &&
requires { Container::Size(); } && Container::Size() == N)
__host__ __device__ constexpr StaticallyIndexedArray(const Container& src)
: StaticallyIndexedArray(
make_from_container(src, typename arithmetic_sequence_gen<0, N, 1>::type{}))
{
}
private:
// Builds an array from src by reading src[Number<Is>{}] for every index in the
// sequence and static_cast-ing each value to T.
template <typename Container, index_t... Is>
__host__ __device__ static constexpr StaticallyIndexedArray
make_from_container(const Container& src, Sequence<Is...>)
{
return StaticallyIndexedArray{static_cast<T>(src[Number<Is>{}])...};
}
public:
// Number of elements, as a compile-time constant.
__host__ __device__ static constexpr index_t Size() { return N; }
// read access
// Returns a const reference to element I. Only the upper bound (I < N) is
// checked by the static_assert.
template <index_t I>
__host__ __device__ constexpr const T& At(Number<I>) const
{
static_assert(I < N, "wrong! out of range");
return data_[I];
}
// write access
// Returns a mutable reference to element I. Only the upper bound (I < N) is
// checked by the static_assert.
template <index_t I>
__host__ __device__ constexpr T& At(Number<I>)
{
static_assert(I < N, "wrong! out of range");
return data_[I];
}
// read access
// Const subscript with a Number<I> tag; forwards to the const At().
template <index_t I>
__host__ __device__ constexpr const T& operator[](Number<I> i) const
{
return At(i);
}
// write access
// Call operator with a Number<I> tag; forwards to the mutable At().
template <index_t I>
__host__ __device__ constexpr T& operator()(Number<I> i)
{
return At(i);
}
// Element-wise assignment from any type U whose static Size() equals Size();
// element i of this array is assigned a[i].
// The return type is deduced as `auto` from `*this`, so the result is returned
// by value (a copy), not by reference.
template <typename U>
__host__ __device__ constexpr auto operator=(const U& a)
{
static_assert(U::Size() == Size(), "wrong! size not the same");
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
return *this;
}
// Always reports true.
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
// Element storage (public member).
T data_[N];
};
// Specialization for empty array
// The StaticallyIndexedArray<T, 0> members visible here are a defaulted
// constructor, Size(), a no-op templated operator= and IsStaticBuffer(); no
// element storage is declared.
// NOTE(review): the `StaticallyIndexedArrayImpl<T, 0>` / `<T, 1>` definitions,
// the `} // namespace detail` line and the `using StaticallyIndexedArray = ...`
// alias below look like removed pre-change code shown interleaved by the diff
// view - confirm against the actual header before relying on this text.
template <typename T>
struct StaticallyIndexedArrayImpl<T, 0>
struct StaticallyIndexedArray<T, 0>
{
using type = Tuple<>;
};
// Defaulted default constructor.
__host__ __device__ constexpr StaticallyIndexedArray() = default;
template <typename T>
struct StaticallyIndexedArrayImpl<T, 1>
{
using type = Tuple<T>;
};
} // namespace detail
// Number of elements: always 0 for this specialization.
__host__ __device__ static constexpr index_t Size() { return 0; }
template <typename T, index_t N>
using StaticallyIndexedArray = typename detail::StaticallyIndexedArrayImpl<T, N>::type;
// Assignment from any type U: the argument is ignored and nothing is checked
// about it. The return type is deduced as `auto` from `*this`, so the result is
// returned by value.
template <typename U>
__host__ __device__ constexpr auto operator=(const U&)
{
return *this;
}
// Always reports true.
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
};
// Builds a StaticallyIndexedArray from one or more values.
// The element type is X, the type of the first argument; the size is the total
// number of arguments. Every argument after the first is static_cast to X.
// NOTE(review): two return statements appear below (parenthesized and braced
// construction); this looks like the old and new line of the diff shown
// together, in which case only the braced form belongs to the new code - confirm.
template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
{
return StaticallyIndexedArray<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
return StaticallyIndexedArray<X, sizeof...(Xs) + 1>{x, static_cast<X>(xs)...};
}
// make empty StaticallyIndexedArray
// Takes no arguments; the element type X must be given explicitly and the result
// is a StaticallyIndexedArray<X, 0>.
// NOTE(review): two return statements appear below (parenthesized and braced
// construction); this looks like the old and new line of the diff shown
// together, in which case only the braced form belongs to the new code - confirm.
template <typename X>
__host__ __device__ constexpr auto make_statically_indexed_array()
{
return StaticallyIndexedArray<X, 0>();
return StaticallyIndexedArray<X, 0>{};
}
template <typename T, index_t N>
@@ -101,5 +174,78 @@ struct StaticallyIndexedArray_v2
T data_[N];
};
// Concepts for StaticallyIndexedArray arithmetic operators
// Scalar: satisfied when ck::is_integral<T> or ck::is_floating_point<T> is true.
template <typename T>
concept Scalar = ck::is_integral<T>::value || ck::is_floating_point<T>::value;
// IndexedContainer: satisfied by a non-Scalar type for which the expression
// T::Size() is well-formed. Only the presence of Size() is required here; element
// access is not part of the constraint.
template <typename T>
concept IndexedContainer = !Scalar<T> && requires { T::Size(); };
// Arithmetic operators for StaticallyIndexedArray (to match Tuple operators)
// StaticallyIndexedArray += X
// Adds x to y element by element, in place: y(i) += x[i] for every i in [0, N).
// X must satisfy IndexedContainer and its static Size() must equal N (checked at
// compile time).
// Returns a reference to y, so the result of the compound assignment refers to
// the left-hand operand itself (as with the built-in `+=`) and not to a copy;
// chained expressions such as (y += a) += b therefore keep updating y.
template <typename T, index_t N, IndexedContainer X>
__host__ __device__ constexpr auto& operator+=(StaticallyIndexedArray<T, N>& y, const X& x)
{
    static_assert(X::Size() == N, "wrong! size not the same");
    static_for<0, N, 1>{}([&](auto i) { y(i) += x[i]; });
    return y;
}
// StaticallyIndexedArray -= X
// Subtracts x from y element by element, in place: y(i) -= x[i] for every i in
// [0, N). X must satisfy IndexedContainer and its static Size() must equal N
// (checked at compile time).
// Returns a reference to y, so the result of the compound assignment refers to
// the left-hand operand itself (as with the built-in `-=`) and not to a copy;
// chained expressions such as (y -= a) -= b therefore keep updating y.
template <typename T, index_t N, IndexedContainer X>
__host__ __device__ constexpr auto& operator-=(StaticallyIndexedArray<T, N>& y, const X& x)
{
    static_assert(X::Size() == N, "wrong! size not the same");
    static_for<0, N, 1>{}([&](auto i) { y(i) -= x[i]; });
    return y;
}
// StaticallyIndexedArray + Y
// Element-wise sum of an array and any IndexedContainer of the same static size.
// The result is a new StaticallyIndexedArray<T, N> whose element i is x[i] + y[i].
template <typename T, index_t N, IndexedContainer Y>
__host__ __device__ constexpr auto operator+(const StaticallyIndexedArray<T, N>& x, const Y& y)
{
    static_assert(Y::Size() == N, "wrong! size not the same");

    StaticallyIndexedArray<T, N> sum;
    const auto add_at = [&](auto idx) { sum(idx) = x[idx] + y[idx]; };
    static_for<0, N, 1>{}(add_at);
    return sum;
}
// StaticallyIndexedArray - Y
// Element-wise difference of an array and any IndexedContainer of the same static
// size. The result is a new StaticallyIndexedArray<T, N> whose element i is
// x[i] - y[i].
template <typename T, index_t N, IndexedContainer Y>
__host__ __device__ constexpr auto operator-(const StaticallyIndexedArray<T, N>& x, const Y& y)
{
    static_assert(Y::Size() == N, "wrong! size not the same");

    StaticallyIndexedArray<T, N> diff;
    const auto subtract_at = [&](auto idx) { diff(idx) = x[idx] - y[idx]; };
    static_for<0, N, 1>{}(subtract_at);
    return diff;
}
// StaticallyIndexedArray * Y (element-wise)
// Element-wise product of an array and any IndexedContainer of the same static
// size. The result is a new StaticallyIndexedArray<T, N> whose element i is
// x[i] * y[i].
template <typename T, index_t N, IndexedContainer Y>
__host__ __device__ constexpr auto operator*(const StaticallyIndexedArray<T, N>& x, const Y& y)
{
    static_assert(Y::Size() == N, "wrong! size not the same");

    StaticallyIndexedArray<T, N> product;
    const auto multiply_at = [&](auto idx) { product(idx) = x[idx] * y[idx]; };
    static_for<0, N, 1>{}(multiply_at);
    return product;
}
// scalar * StaticallyIndexedArray
// Scales every element by a Scalar. The result is a new
// StaticallyIndexedArray<T, N> whose element i is a * x[i].
template <typename T, index_t N, Scalar S>
__host__ __device__ constexpr auto operator*(S a, const StaticallyIndexedArray<T, N>& x)
{
    StaticallyIndexedArray<T, N> scaled;
    const auto scale_at = [&](auto idx) { scaled(idx) = a * x[idx]; };
    static_for<0, N, 1>{}(scale_at);
    return scaled;
}
// StaticallyIndexedArray * scalar
// Scalar on the right-hand side. Produces the same result as `a * x`: a new
// StaticallyIndexedArray<T, N> whose element i is a * x[i] (scalar kept as the
// left factor of each product).
template <typename T, index_t N, Scalar S>
__host__ __device__ constexpr auto operator*(const StaticallyIndexedArray<T, N>& x, S a)
{
    StaticallyIndexedArray<T, N> scaled;
    const auto scale_at = [&](auto idx) { scaled(idx) = a * x[idx]; };
    static_for<0, N, 1>{}(scale_at);
    return scaled;
}
} // namespace ck
#endif