Rewrite sequence_map_inverse using O(1) depth pack expansion

Replace the recursive template sequence_map_inverse_impl, whose instantiation
depth grows as O(N), with a constexpr function and a pack expansion that need
only O(1) template depth.

Results:
- sequence_map_inverse: 45 instances, 187ms → 7 instances, 10ms (95% reduction)
This commit is contained in:
Max Podkorytov
2026-01-16 11:19:35 -06:00
parent 02e42dcaa1
commit a8c9be9378

View File

@@ -576,31 +576,35 @@ struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMa
{
};
// Helpers for inverting a permutation sequence without recursive template instantiation.
namespace detail {

// Locate Target inside the permutation given as Sequence<Is...>.
//
// Returns the first index i for which the i-th element of the sequence equals Target.
// The search is an ordinary constexpr loop, so no per-element template recursion is needed.
// If Target is absent (the sequence is not a valid permutation) the result is 0.
template <index_t Target, index_t... Is>
__host__ __device__ constexpr index_t find_source_index(Sequence<Is...>)
{
    constexpr index_t count   = static_cast<index_t>(sizeof...(Is));
    constexpr index_t elems[] = {Is...};

    // Advance until the element matches or the sequence is exhausted.
    index_t pos = 0;
    while(pos < count && elems[pos] != Target)
    {
        ++pos;
    }

    // Running off the end means no match; fall back to 0 in that case.
    return pos < count ? pos : 0;
}

// Build the inverse of SeqMap in a single pack expansion.
//
// For every value in Positions, the resulting Sequence holds the index at which SeqMap
// contains that value (as computed by find_source_index).
template <typename SeqMap, index_t... Positions>
__host__ __device__ constexpr auto invert_permutation_impl(Sequence<Positions...>)
{
    using inverse_t = Sequence<find_source_index<Positions>(SeqMap{})...>;
    return inverse_t{};
}

} // namespace detail
// Invert a permutation sequence using O(1) template depth pack expansion.
//
// For X2Y = {a, b, c, ...}, `type` is the sequence Y2X satisfying Y2X[X2Y[i]] = i.
//
// SeqMap is the permutation to invert. Its Size() determines the index range
// [0, Size()) that is fed to detail::invert_permutation_impl, which looks up the source
// position of every target index in one pack expansion.
//
// The former recursive helper sequence_map_inverse_impl (one template instantiation per
// element) and its `using type` are removed; keeping both definitions would declare the
// member alias `type` twice, which does not compile.
template <typename SeqMap>
struct sequence_map_inverse
{
    using type = decltype(detail::invert_permutation_impl<SeqMap>(
        typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type{}));
};
template <index_t... Xs, index_t... Ys>