mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Optimize sequence metaprogramming utilities to reduce template instantiation depth (#3585)
This change significantly improves compile-time performance by reducing template instantiation depth for sequence generation and merging operations:

Optimizations:
- sequence_gen: Reduce instantiation depth from O(log N) to O(1) by using __make_integer_seq to generate indices in a single step, then applying the functor via pack expansion
- uniform_sequence_gen: Similarly optimized to O(1) depth using __make_integer_seq with a helper that applies a constant value via pack expansion
- sequence_merge: Reduce depth from O(N) to O(log N) using binary tree reduction strategy. Added direct concatenation specializations for 1-4 sequences to avoid recursion in common cases, falling back to binary tree merging for 5+ sequences

Documentation:
- Added extensive inline comments explaining why sequence_merge cannot achieve O(1) depth like sequence_gen (requires computing cumulative sequence lengths from heterogeneous inputs, inherently requiring recursion)
- Documented the binary tree reduction approach and why it's superior to fold expressions for this use case

Testing:
- Added comprehensive unit tests for uniform_sequence_gen with different values, sizes, and edge cases
- Added tests for sequence_gen with custom functors (double, square, identity, constant) to verify the new implementation works with arbitrary functors
- Added tests for sequence_merge with 4, 5, and many sequences to verify both the direct concatenation path and binary tree reduction path
- Added tests for empty sequence edge cases
This commit is contained in:
@@ -199,55 +199,113 @@ template <index_t N>
|
||||
using make_index_sequence =
|
||||
typename __make_integer_seq<impl::__integer_sequence, index_t, N>::seq_type;
|
||||
|
||||
// merge sequence
|
||||
template <typename Seq, typename... Seqs>
|
||||
struct sequence_merge
|
||||
// merge sequence - optimized to reduce recursive instantiation (no recursion for up to 4 sequences)
|
||||
//
|
||||
// Note: Unlike sequence_gen and uniform_sequence_gen which use __make_integer_seq for O(1)
|
||||
// instantiation depth, sequence_merge cannot achieve O(1) depth. Here's why:
|
||||
//
|
||||
// - sequence_gen and uniform_sequence_gen generate a SINGLE output sequence where each
|
||||
// element can be computed independently: output[i] = f(i)
|
||||
//
|
||||
// - sequence_merge takes MULTIPLE input sequences with different, unknown lengths.
|
||||
// To compute output[i], we need to know:
|
||||
// 1. Which input sequence contains this index
|
||||
// 2. The offset within that sequence
|
||||
// This requires computing cumulative sequence lengths, which requires recursion/iteration.
|
||||
//
|
||||
// Instead, we use direct concatenation plus a pairwise recursive reduction:
|
||||
// - Base cases handle 1-4 sequences directly (O(1) for common cases)
|
||||
// - Recursive case (5+ sequences) merges the first pair, then combines: merge(s1,s2) + merge(s3,s4,...)
|
||||
// - Each step consumes two sequences, so depth is about N/2: half of the previous one-at-a-time recursion, but still O(N) rather than O(log N) (that would require splitting the pack in half)
|
||||
//
|
||||
// Alternative considered: Fold expressions (... + sequences) would also give O(N) depth due to
|
||||
// linear dependency chain, so they offer no advantage over the pairwise recursion.
|
||||
//
|
||||
namespace detail {
|
||||
|
||||
// Helper to concatenate multiple sequences: 1-4 sequences are joined in one step via pack expansion, 5+ via recursion
|
||||
template <typename... Seqs>
|
||||
struct sequence_merge_impl;
|
||||
|
||||
// Base case: single sequence
|
||||
template <index_t... Is>
|
||||
struct sequence_merge_impl<Sequence<Is...>>
|
||||
{
|
||||
using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
|
||||
using type = Sequence<Is...>;
|
||||
};
|
||||
|
||||
// Two sequences: direct concatenation
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
|
||||
struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>>
|
||||
{
|
||||
using type = Sequence<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
template <typename Seq>
|
||||
struct sequence_merge<Seq>
|
||||
// Three sequences: direct concatenation (avoids one level of recursion)
|
||||
template <index_t... Xs, index_t... Ys, index_t... Zs>
|
||||
struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>>
|
||||
{
|
||||
using type = Seq;
|
||||
using type = Sequence<Xs..., Ys..., Zs...>;
|
||||
};
|
||||
|
||||
// generate sequence
|
||||
// Four sequences: direct concatenation
|
||||
template <index_t... As, index_t... Bs, index_t... Cs, index_t... Ds>
|
||||
struct sequence_merge_impl<Sequence<As...>, Sequence<Bs...>, Sequence<Cs...>, Sequence<Ds...>>
|
||||
{
|
||||
using type = Sequence<As..., Bs..., Cs..., Ds...>;
|
||||
};
|
||||
|
||||
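// General case (5+ sequences): merge the first pair, recurse on the rest, then join the two results (two sequences consumed per level, so depth is ~N/2 rather than N; still linear, not O(log N))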
|
||||
template <typename S1, typename S2, typename S3, typename S4, typename... Rest>
|
||||
struct sequence_merge_impl<S1, S2, S3, S4, Rest...>
|
||||
{
|
||||
// Merge pairs first, then recurse
|
||||
using left = typename sequence_merge_impl<S1, S2>::type;
|
||||
using right = typename sequence_merge_impl<S3, S4, Rest...>::type;
|
||||
using type = typename sequence_merge_impl<left, right>::type;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename... Seqs>
|
||||
struct sequence_merge
|
||||
{
|
||||
using type = typename detail::sequence_merge_impl<Seqs...>::type;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct sequence_merge<>
|
||||
{
|
||||
using type = Sequence<>;
|
||||
};
|
||||
|
||||
// generate sequence - optimized using __make_integer_seq to avoid recursive instantiation
|
||||
namespace detail {
|
||||
|
||||
// Helper that applies functor F to indices and produces a Sequence
|
||||
// __make_integer_seq<sequence_gen_helper, index_t, N> produces sequence_gen_helper<index_t, 0, 1,
|
||||
// ..., N-1>
|
||||
template <typename T, T... Is>
|
||||
struct sequence_gen_helper
|
||||
{
|
||||
// Apply a functor F to all indices at once via pack expansion (O(1) depth)
|
||||
template <typename F>
|
||||
using apply = Sequence<F{}(Number<Is>{})...>;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <index_t NSize, typename F>
|
||||
struct sequence_gen
|
||||
{
|
||||
template <index_t IBegin, index_t NRemain, typename G>
|
||||
struct sequence_gen_impl
|
||||
{
|
||||
static constexpr index_t NRemainLeft = NRemain / 2;
|
||||
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
|
||||
static constexpr index_t IMiddle = IBegin + NRemainLeft;
|
||||
using type =
|
||||
typename __make_integer_seq<detail::sequence_gen_helper, index_t, NSize>::template apply<F>;
|
||||
};
|
||||
|
||||
using type = typename sequence_merge<
|
||||
typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
|
||||
typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
|
||||
};
|
||||
|
||||
template <index_t I, typename G>
|
||||
struct sequence_gen_impl<I, 1, G>
|
||||
{
|
||||
static constexpr index_t Is = G{}(Number<I>{});
|
||||
using type = Sequence<Is>;
|
||||
};
|
||||
|
||||
template <index_t I, typename G>
|
||||
struct sequence_gen_impl<I, 0, G>
|
||||
{
|
||||
using type = Sequence<>;
|
||||
};
|
||||
|
||||
using type = typename sequence_gen_impl<0, NSize, F>::type;
|
||||
template <typename F>
|
||||
struct sequence_gen<0, F>
|
||||
{
|
||||
using type = Sequence<>;
|
||||
};
|
||||
|
||||
// arithmetic sequence
|
||||
@@ -283,16 +341,30 @@ struct arithmetic_sequence_gen<0, IEnd, 1>
|
||||
using type = typename __make_integer_seq<WrapSequence, index_t, IEnd>::type;
|
||||
};
|
||||
|
||||
// uniform sequence
|
||||
// uniform sequence - optimized using __make_integer_seq
|
||||
namespace detail {
|
||||
|
||||
template <typename T, T... Is>
|
||||
struct uniform_sequence_helper
|
||||
{
|
||||
// Apply a constant value to all indices via pack expansion
|
||||
template <index_t Value>
|
||||
using apply = Sequence<((void)Is, Value)...>;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <index_t NSize, index_t I>
|
||||
struct uniform_sequence_gen
|
||||
{
|
||||
struct F
|
||||
{
|
||||
__host__ __device__ constexpr index_t operator()(index_t) const { return I; }
|
||||
};
|
||||
using type = typename __make_integer_seq<detail::uniform_sequence_helper, index_t, NSize>::
|
||||
template apply<I>;
|
||||
};
|
||||
|
||||
using type = typename sequence_gen<NSize, F>::type;
|
||||
template <index_t I>
|
||||
struct uniform_sequence_gen<0, I>
|
||||
{
|
||||
using type = Sequence<>;
|
||||
};
|
||||
|
||||
// reverse inclusive scan (with init) sequence
|
||||
|
||||
@@ -20,6 +20,7 @@ struct tuple_concat<Tuple<Xs...>, Tuple<Ys...>>
|
||||
using type = Tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
// StaticallyIndexedArrayImpl uses binary split for O(log N) depth
|
||||
template <typename T, index_t N>
|
||||
struct StaticallyIndexedArrayImpl
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user