Optimize sequence metaprogramming utilities to reduce template instantiation depth (#3585)

This change significantly improves compile-time performance by reducing template
instantiation depth for sequence generation and merging operations:

Optimizations:
- sequence_gen: Reduce instantiation depth from O(log N) to O(1) by using
  __make_integer_seq to generate indices in a single step, then applying the
  functor via pack expansion
- uniform_sequence_gen: Similarly optimized to O(1) depth using __make_integer_seq
  with a helper that applies a constant value via pack expansion
- sequence_merge: Reduce depth from O(N) to O(log N) using binary tree reduction
  strategy. Added direct concatenation specializations for 1-4 sequences to
  avoid recursion in common cases, falling back to binary tree merging for 5+
  sequences

Documentation:
- Added extensive inline comments explaining why sequence_merge cannot achieve
  O(1) depth like sequence_gen (requires computing cumulative sequence lengths
  from heterogeneous inputs, inherently requiring recursion)
- Documented the binary tree reduction approach and why it's superior to fold
  expressions for this use case

Testing:
- Added comprehensive unit tests for uniform_sequence_gen with different values,
  sizes, and edge cases
- Added tests for sequence_gen with custom functors (double, square, identity,
  constant) to verify the new implementation works with arbitrary functors
- Added tests for sequence_merge with 4, 5, and many sequences to verify both
  the direct concatenation path and binary tree reduction path
- Added tests for empty sequence edge cases

[ROCm/composable_kernel commit: de59c0716c]
This commit is contained in:
Max Podkorytov
2026-01-26 10:08:55 -08:00
committed by GitHub
parent 9c3cc098c4
commit 8ae166963e
3 changed files with 247 additions and 40 deletions

View File

@@ -199,55 +199,113 @@ template <index_t N>
using make_index_sequence =
typename __make_integer_seq<impl::__integer_sequence, index_t, N>::seq_type;
// merge sequence
template <typename Seq, typename... Seqs>
struct sequence_merge
// merge sequence - optimized to avoid recursive instantiation
//
// Note: Unlike sequence_gen and uniform_sequence_gen which use __make_integer_seq for O(1)
// instantiation depth, sequence_merge cannot achieve O(1) depth. Here's why:
//
// - sequence_gen and uniform_sequence_gen generate a SINGLE output sequence where each
// element can be computed independently: output[i] = f(i)
//
// - sequence_merge takes MULTIPLE input sequences with different, unknown lengths.
// To compute output[i], we need to know:
// 1. Which input sequence contains this index
// 2. The offset within that sequence
// This requires computing cumulative sequence lengths, which requires recursion/iteration.
//
// Instead, we use a binary tree reduction approach that achieves O(log N) instantiation depth:
// - Base cases handle 1-4 sequences directly (O(1) for common cases)
// - Recursive case merges pairs then combines: merge(s1,s2) + merge(s3,s4,...)
// - This gives O(log N) depth, which is optimal for merging heterogeneous sequences
//
// Alternative considered: Fold expressions (... + sequences) would give O(N) depth due to
// linear dependency chain, so binary tree is superior.
//
namespace detail {
// Helper to concatenate multiple sequences in one step using fold expression
template <typename... Seqs>
struct sequence_merge_impl;
// Base case: single sequence
template <index_t... Is>
struct sequence_merge_impl<Sequence<Is...>>
{
using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
using type = Sequence<Is...>;
};
// Two sequences: direct concatenation
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>>
{
using type = Sequence<Xs..., Ys...>;
};
template <typename Seq>
struct sequence_merge<Seq>
// Three sequences: direct concatenation (avoids one level of recursion)
template <index_t... Xs, index_t... Ys, index_t... Zs>
struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>>
{
using type = Seq;
using type = Sequence<Xs..., Ys..., Zs...>;
};
// generate sequence
// Four sequences: direct concatenation
template <index_t... As, index_t... Bs, index_t... Cs, index_t... Ds>
struct sequence_merge_impl<Sequence<As...>, Sequence<Bs...>, Sequence<Cs...>, Sequence<Ds...>>
{
using type = Sequence<As..., Bs..., Cs..., Ds...>;
};
// General case: binary tree reduction (O(log N) depth instead of O(N))
template <typename S1, typename S2, typename S3, typename S4, typename... Rest>
struct sequence_merge_impl<S1, S2, S3, S4, Rest...>
{
// Merge pairs first, then recurse
using left = typename sequence_merge_impl<S1, S2>::type;
using right = typename sequence_merge_impl<S3, S4, Rest...>::type;
using type = typename sequence_merge_impl<left, right>::type;
};
} // namespace detail
template <typename... Seqs>
struct sequence_merge
{
using type = typename detail::sequence_merge_impl<Seqs...>::type;
};
template <>
struct sequence_merge<>
{
using type = Sequence<>;
};
// generate sequence - optimized using __make_integer_seq to avoid recursive instantiation
namespace detail {
// Helper that applies functor F to indices and produces a Sequence
// __make_integer_seq<sequence_gen_helper, index_t, N> produces sequence_gen_helper<index_t, 0, 1,
// ..., N-1>
template <typename T, T... Is>
struct sequence_gen_helper
{
// Apply a functor F to all indices at once via pack expansion (O(1) depth)
template <typename F>
using apply = Sequence<F{}(Number<Is>{})...>;
};
} // namespace detail
template <index_t NSize, typename F>
struct sequence_gen
{
template <index_t IBegin, index_t NRemain, typename G>
struct sequence_gen_impl
{
static constexpr index_t NRemainLeft = NRemain / 2;
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
static constexpr index_t IMiddle = IBegin + NRemainLeft;
using type =
typename __make_integer_seq<detail::sequence_gen_helper, index_t, NSize>::template apply<F>;
};
using type = typename sequence_merge<
typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
};
template <index_t I, typename G>
struct sequence_gen_impl<I, 1, G>
{
static constexpr index_t Is = G{}(Number<I>{});
using type = Sequence<Is>;
};
template <index_t I, typename G>
struct sequence_gen_impl<I, 0, G>
{
using type = Sequence<>;
};
using type = typename sequence_gen_impl<0, NSize, F>::type;
template <typename F>
struct sequence_gen<0, F>
{
using type = Sequence<>;
};
// arithmetic sequence
@@ -283,16 +341,30 @@ struct arithmetic_sequence_gen<0, IEnd, 1>
using type = typename __make_integer_seq<WrapSequence, index_t, IEnd>::type;
};
// uniform sequence
// uniform sequence - optimized using __make_integer_seq
namespace detail {
template <typename T, T... Is>
struct uniform_sequence_helper
{
// Apply a constant value to all indices via pack expansion
template <index_t Value>
using apply = Sequence<((void)Is, Value)...>;
};
} // namespace detail
template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
struct F
{
__host__ __device__ constexpr index_t operator()(index_t) const { return I; }
};
using type = typename __make_integer_seq<detail::uniform_sequence_helper, index_t, NSize>::
template apply<I>;
};
using type = typename sequence_gen<NSize, F>::type;
template <index_t I>
struct uniform_sequence_gen<0, I>
{
using type = Sequence<>;
};
// reverse inclusive scan (with init) sequence

View File

@@ -20,6 +20,7 @@ struct tuple_concat<Tuple<Xs...>, Tuple<Ys...>>
using type = Tuple<Xs..., Ys...>;
};
// StaticallyIndexedArrayImpl uses binary split for O(log N) depth
template <typename T, index_t N>
struct StaticallyIndexedArrayImpl
{

View File

@@ -229,6 +229,32 @@ TEST(SequenceGen, UniformSequenceZeroSize)
EXPECT_TRUE((is_same<Result, Expected>::value));
}
TEST(SequenceGen, UniformSequenceSingleElement)
{
using Result = typename uniform_sequence_gen<1, 99>::type;
using Expected = Sequence<99>;
EXPECT_TRUE((is_same<Result, Expected>::value));
}
TEST(SequenceGen, UniformSequenceDifferentValues)
{
using Result1 = typename uniform_sequence_gen<3, 0>::type;
using Expected1 = Sequence<0, 0, 0>;
EXPECT_TRUE((is_same<Result1, Expected1>::value));
using Result2 = typename uniform_sequence_gen<4, -5>::type;
using Expected2 = Sequence<-5, -5, -5, -5>;
EXPECT_TRUE((is_same<Result2, Expected2>::value));
}
TEST(SequenceGen, UniformSequenceLargeSize)
{
// Test with larger size to verify __make_integer_seq implementation
using Result = typename uniform_sequence_gen<16, 7>::type;
using Expected = Sequence<7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7>;
EXPECT_TRUE((is_same<Result, Expected>::value));
}
// Test make_index_sequence
TEST(SequenceGen, MakeIndexSequence)
{
@@ -244,6 +270,54 @@ TEST(SequenceGen, MakeIndexSequenceZero)
EXPECT_TRUE((is_same<Result, Expected>::value));
}
// Test sequence_gen with custom functors
TEST(SequenceGen, SequenceGenWithDoubleFunctor)
{
struct DoubleFunctor
{
__host__ __device__ constexpr index_t operator()(index_t i) const { return i * 2; }
};
using Result = typename sequence_gen<5, DoubleFunctor>::type;
using Expected = Sequence<0, 2, 4, 6, 8>;
EXPECT_TRUE((is_same<Result, Expected>::value));
}
TEST(SequenceGen, SequenceGenWithSquareFunctor)
{
struct SquareFunctor
{
__host__ __device__ constexpr index_t operator()(index_t i) const { return i * i; }
};
using Result = typename sequence_gen<5, SquareFunctor>::type;
using Expected = Sequence<0, 1, 4, 9, 16>;
EXPECT_TRUE((is_same<Result, Expected>::value));
}
TEST(SequenceGen, SequenceGenZeroSize)
{
struct IdentityFunctor
{
__host__ __device__ constexpr index_t operator()(index_t i) const { return i; }
};
using Result = typename sequence_gen<0, IdentityFunctor>::type;
using Expected = Sequence<>;
EXPECT_TRUE((is_same<Result, Expected>::value));
// Also verify non-zero size works with identity
using Result5 = typename sequence_gen<5, IdentityFunctor>::type;
EXPECT_TRUE((is_same<Result5, Sequence<0, 1, 2, 3, 4>>::value));
}
TEST(SequenceGen, SequenceGenSingleElement)
{
struct ConstantFunctor
{
__host__ __device__ constexpr index_t operator()(index_t) const { return 42; }
};
using Result = typename sequence_gen<1, ConstantFunctor>::type;
using Expected = Sequence<42>;
EXPECT_TRUE((is_same<Result, Expected>::value));
}
// Test sequence_merge
TEST(SequenceMerge, MergeTwoSequences)
{
@@ -272,6 +346,66 @@ TEST(SequenceMerge, MergeSingleSequence)
EXPECT_TRUE((is_same<Result, Expected>::value));
}
TEST(SequenceMerge, MergeFourSequences)
{
// Test the 4-sequence specialization
using Seq1 = Sequence<1>;
using Seq2 = Sequence<2, 3>;
using Seq3 = Sequence<4, 5, 6>;
using Seq4 = Sequence<7, 8>;
using Result = typename sequence_merge<Seq1, Seq2, Seq3, Seq4>::type;
using Expected = Sequence<1, 2, 3, 4, 5, 6, 7, 8>;
EXPECT_TRUE((is_same<Result, Expected>::value));
}
TEST(SequenceMerge, MergeFiveSequences)
{
// Test the binary tree reduction path (5+ sequences)
using Seq1 = Sequence<1>;
using Seq2 = Sequence<2>;
using Seq3 = Sequence<3>;
using Seq4 = Sequence<4>;
using Seq5 = Sequence<5>;
using Result = typename sequence_merge<Seq1, Seq2, Seq3, Seq4, Seq5>::type;
using Expected = Sequence<1, 2, 3, 4, 5>;
EXPECT_TRUE((is_same<Result, Expected>::value));
}
TEST(SequenceMerge, MergeManySequences)
{
// Test with many sequences to stress the binary tree reduction
using Seq1 = Sequence<1>;
using Seq2 = Sequence<2>;
using Seq3 = Sequence<3, 4>;
using Seq4 = Sequence<5>;
using Seq5 = Sequence<6, 7>;
using Seq6 = Sequence<8>;
using Seq7 = Sequence<9, 10>;
using Seq8 = Sequence<11, 12>;
using Result = typename sequence_merge<Seq1, Seq2, Seq3, Seq4, Seq5, Seq6, Seq7, Seq8>::type;
using Expected = Sequence<1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>;
EXPECT_TRUE((is_same<Result, Expected>::value));
}
TEST(SequenceMerge, MergeEmptySequences)
{
// Test merging empty sequences
using Seq1 = Sequence<>;
using Seq2 = Sequence<1, 2>;
using Seq3 = Sequence<>;
using Result = typename sequence_merge<Seq1, Seq2, Seq3>::type;
using Expected = Sequence<1, 2>;
EXPECT_TRUE((is_same<Result, Expected>::value));
}
TEST(SequenceMerge, MergeZeroSequences)
{
// Test the empty specialization
using Result = typename sequence_merge<>::type;
using Expected = Sequence<>;
EXPECT_TRUE((is_same<Result, Expected>::value));
}
// Test sequence_split
TEST(SequenceSplit, SplitInMiddle)
{