mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Optimize sequence metaprogramming utilities to reduce template instantiation depth (#3585)
This change significantly improves compile-time performance by reducing template instantiation depth for sequence generation and merging operations.

Optimizations:
- sequence_gen: Reduce instantiation depth from O(log N) to O(1) by using __make_integer_seq to generate indices in a single step, then applying the functor via pack expansion.
- uniform_sequence_gen: Similarly optimized to O(1) depth using __make_integer_seq with a helper that applies a constant value via pack expansion.
- sequence_merge: Reduce depth from O(N) to O(log N) using a binary tree reduction strategy. Added direct concatenation specializations for 1-4 sequences to avoid recursion in common cases, falling back to binary tree merging for 5+ sequences.

Documentation:
- Added extensive inline comments explaining why sequence_merge cannot achieve O(1) depth like sequence_gen (it requires computing cumulative sequence lengths from heterogeneous inputs, which inherently requires recursion).
- Documented the binary tree reduction approach and why it is superior to fold expressions for this use case.

Testing:
- Added comprehensive unit tests for uniform_sequence_gen with different values, sizes, and edge cases.
- Added tests for sequence_gen with custom functors (double, square, identity, constant) to verify that the new implementation works with arbitrary functors.
- Added tests for sequence_merge with 4, 5, and many sequences to verify both the direct concatenation path and the binary tree reduction path.
- Added tests for empty sequence edge cases.
This commit is contained in:
@@ -229,6 +229,32 @@ TEST(SequenceGen, UniformSequenceZeroSize)
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// A uniform sequence of length one must contain exactly the requested value.
TEST(SequenceGen, UniformSequenceSingleElement)
{
    using Expected = Sequence<99>;
    using Result   = uniform_sequence_gen<1, 99>::type;

    EXPECT_TRUE((is_same<Result, Expected>::value));
}
|
||||
|
||||
// uniform_sequence_gen must replicate whichever value it is handed; checked
// here with zero and with a negative value.
TEST(SequenceGen, UniformSequenceDifferentValues)
{
    // Three copies of 0.
    using Expected1 = Sequence<0, 0, 0>;
    using Result1   = uniform_sequence_gen<3, 0>::type;
    EXPECT_TRUE((is_same<Result1, Expected1>::value));

    // Four copies of -5.
    using Expected2 = Sequence<-5, -5, -5, -5>;
    using Result2   = uniform_sequence_gen<4, -5>::type;
    EXPECT_TRUE((is_same<Result2, Expected2>::value));
}
|
||||
|
||||
// Sixteen elements: a larger size to exercise the __make_integer_seq based
// implementation of uniform_sequence_gen.
TEST(SequenceGen, UniformSequenceLargeSize)
{
    using Expected = Sequence<7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7>;
    using Result   = uniform_sequence_gen<16, 7>::type;

    EXPECT_TRUE((is_same<Result, Expected>::value));
}
|
||||
|
||||
// Test make_index_sequence
|
||||
TEST(SequenceGen, MakeIndexSequence)
|
||||
{
|
||||
@@ -244,6 +270,54 @@ TEST(SequenceGen, MakeIndexSequenceZero)
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test sequence_gen with custom functors
|
||||
// sequence_gen driven by a functor that maps each index to twice its value.
TEST(SequenceGen, SequenceGenWithDoubleFunctor)
{
    // Element generator: idx -> idx * 2.
    struct DoubleFunctor
    {
        __host__ __device__ constexpr index_t operator()(index_t idx) const
        {
            return idx * 2;
        }
    };

    using Expected = Sequence<0, 2, 4, 6, 8>;
    using Result   = sequence_gen<5, DoubleFunctor>::type;

    EXPECT_TRUE((is_same<Result, Expected>::value));
}
|
||||
|
||||
// sequence_gen driven by a functor that maps each index to its square.
TEST(SequenceGen, SequenceGenWithSquareFunctor)
{
    // Element generator: idx -> idx * idx.
    struct SquareFunctor
    {
        __host__ __device__ constexpr index_t operator()(index_t idx) const
        {
            return idx * idx;
        }
    };

    using Expected = Sequence<0, 1, 4, 9, 16>;
    using Result   = sequence_gen<5, SquareFunctor>::type;

    EXPECT_TRUE((is_same<Result, Expected>::value));
}
|
||||
|
||||
// Edge case: sequence_gen with zero elements must yield the empty Sequence.
// The same identity functor is also checked at size 5 as a sanity baseline.
TEST(SequenceGen, SequenceGenZeroSize)
{
    // Element generator: idx -> idx.
    struct IdentityFunctor
    {
        __host__ __device__ constexpr index_t operator()(index_t idx) const
        {
            return idx;
        }
    };

    // Size 0 produces an empty sequence.
    using Expected = Sequence<>;
    using Result   = sequence_gen<0, IdentityFunctor>::type;
    EXPECT_TRUE((is_same<Result, Expected>::value));

    // Non-zero size with the identity functor produces 0..4.
    using Result5 = sequence_gen<5, IdentityFunctor>::type;
    EXPECT_TRUE((is_same<Result5, Sequence<0, 1, 2, 3, 4>>::value));
}
|
||||
|
||||
// sequence_gen of length one, using a functor that ignores its index.
TEST(SequenceGen, SequenceGenSingleElement)
{
    // Element generator: always 42, whatever the index.
    struct ConstantFunctor
    {
        __host__ __device__ constexpr index_t operator()(index_t) const
        {
            return 42;
        }
    };

    using Expected = Sequence<42>;
    using Result   = sequence_gen<1, ConstantFunctor>::type;

    EXPECT_TRUE((is_same<Result, Expected>::value));
}
|
||||
|
||||
// Test sequence_merge
|
||||
TEST(SequenceMerge, MergeTwoSequences)
|
||||
{
|
||||
@@ -272,6 +346,66 @@ TEST(SequenceMerge, MergeSingleSequence)
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Four inputs of differing lengths: targets the 4-sequence specialization.
TEST(SequenceMerge, MergeFourSequences)
{
    using Expected = Sequence<1, 2, 3, 4, 5, 6, 7, 8>;
    using Result =
        sequence_merge<Sequence<1>, Sequence<2, 3>, Sequence<4, 5, 6>, Sequence<7, 8>>::type;

    EXPECT_TRUE((is_same<Result, Expected>::value));
}
|
||||
|
||||
// Five single-element inputs: targets the binary tree reduction path that
// handles 5 or more sequences.
TEST(SequenceMerge, MergeFiveSequences)
{
    using Expected = Sequence<1, 2, 3, 4, 5>;
    using Result =
        sequence_merge<Sequence<1>, Sequence<2>, Sequence<3>, Sequence<4>, Sequence<5>>::type;

    EXPECT_TRUE((is_same<Result, Expected>::value));
}
|
||||
|
||||
// Eight inputs of mixed lengths to stress the binary tree reduction; the
// merged result must keep the elements in input order.
TEST(SequenceMerge, MergeManySequences)
{
    using Expected = Sequence<1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>;
    using Result   = sequence_merge<Sequence<1>,
                                  Sequence<2>,
                                  Sequence<3, 4>,
                                  Sequence<5>,
                                  Sequence<6, 7>,
                                  Sequence<8>,
                                  Sequence<9, 10>,
                                  Sequence<11, 12>>::type;

    EXPECT_TRUE((is_same<Result, Expected>::value));
}
|
||||
|
||||
// Empty sequences on either side of a non-empty one must contribute nothing
// to the merged result.
TEST(SequenceMerge, MergeEmptySequences)
{
    using Expected = Sequence<1, 2>;
    using Result   = sequence_merge<Sequence<>, Sequence<1, 2>, Sequence<>>::type;

    EXPECT_TRUE((is_same<Result, Expected>::value));
}
|
||||
|
||||
// No inputs at all: targets the empty specialization, which must yield the
// empty Sequence.
TEST(SequenceMerge, MergeZeroSequences)
{
    using Expected = Sequence<>;
    using Result   = sequence_merge<>::type;

    EXPECT_TRUE((is_same<Result, Expected>::value));
}
|
||||
|
||||
// Test sequence_split
|
||||
TEST(SequenceSplit, SplitInMiddle)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user