From c868964f6a4d52623401e56c687c438e8d20ae72 Mon Sep 17 00:00:00 2001 From: John Shumway Date: Wed, 10 Dec 2025 12:25:23 -0800 Subject: [PATCH] Improve sequence sorting and add unit tests (#3376) Old sequence sort code was showing up on build profiles. Convert it to constexpr functions for much more efficient build-time execution. The sorting is still O(N^2), but our sequences are small enough it executes quickly. This reduced compilation time of a small convolution by more than 10% and time overall time spent in the compiler on a narrow build by %6. [ROCm/composable_kernel commit: 15ed65db35e6702593cd8ed1d603222fb11684e4] --- include/ck/utility/sequence.hpp | 325 ++++++--------- test/CMakeLists.txt | 1 + test/util/CMakeLists.txt | 7 + test/util/unit_sequence.cpp | 684 ++++++++++++++++++++++++++++++++ 4 files changed, 808 insertions(+), 209 deletions(-) create mode 100644 test/util/CMakeLists.txt create mode 100644 test/util/unit_sequence.cpp diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 9f97d44a4a..6e68690048 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -380,236 +380,143 @@ struct sequence_reduce }; #endif -template -struct sequence_sort_impl +// Implement sequence_sort and sequence_unique_sort using constexpr functions (C++17) +namespace sort_impl { + +// Temporary arrays to hold values during operations with capacity N and mutable size. +template +struct IndexedValueArray { - template - struct sorted_sequence_merge_impl - { - static constexpr bool choose_left = LeftValues::Front() < RightValues::Front(); - - static constexpr index_t chosen_value = - choose_left ? LeftValues::Front() : RightValues::Front(); - static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front(); - - using new_merged_values = decltype(MergedValues::PushBack(Number{})); - using new_merged_ids = decltype(MergedIds::PushBack(Number{})); - - using new_left_values = - typename conditional::type; - using new_left_ids = - typename conditional::type; - - using new_right_values = - typename conditional::type; - using new_right_ids = - typename conditional::type; - - using merge = sorted_sequence_merge_impl; - // this is output - using merged_values = typename merge::merged_values; - using merged_ids = typename merge::merged_ids; - }; - - template - struct sorted_sequence_merge_impl, - Sequence<>, - MergedValues, - MergedIds, - Comp> - { - using merged_values = typename sequence_merge::type; - using merged_ids = typename sequence_merge::type; - }; - - template - struct sorted_sequence_merge_impl, - Sequence<>, - RightValues, - RightIds, - MergedValues, - MergedIds, - Comp> - { - using merged_values = typename sequence_merge::type; - using merged_ids = typename sequence_merge::type; - }; - - template - struct sorted_sequence_merge - { - using merge = sorted_sequence_merge_impl, - Sequence<>, - Comp>; - - using merged_values = typename merge::merged_values; - using merged_ids = typename merge::merged_ids; - }; - - static constexpr index_t nsize = Values::Size(); - - using split_unsorted_values = sequence_split; - using split_unsorted_ids = sequence_split; - - using left_unsorted_values = typename split_unsorted_values::left_type; - using left_unsorted_ids = typename split_unsorted_ids::left_type; - using left_sort = sequence_sort_impl; - using left_sorted_values = typename left_sort::sorted_values; - using left_sorted_ids = typename left_sort::sorted_ids; - - using right_unsorted_values = typename split_unsorted_values::right_type; - using right_unsorted_ids = typename split_unsorted_ids::right_type; - using right_sort = sequence_sort_impl; - using right_sorted_values = typename right_sort::sorted_values; - using right_sorted_ids = typename right_sort::sorted_ids; - - using merged_sorted = sorted_sequence_merge; - - using sorted_values = typename merged_sorted::merged_values; - using sorted_ids = typename merged_sorted::merged_ids; + index_t values[N > 0 ? N : 1]; + index_t ids[N > 0 ? N : 1]; + index_t size = 0; }; -template -struct sequence_sort_impl, Sequence, Compare> +template +constexpr auto make_indexed_value_array(Sequence) { - static constexpr bool choose_x = Compare{}(ValueX, ValueY); + constexpr index_t N = sizeof...(Is); + IndexedValueArray result = {{Is...}, {}, N}; + for(index_t i = 0; i < N; ++i) + { + result.ids[i] = i; + } + return result; +} - using sorted_values = - typename conditional, Sequence>::type; - using sorted_ids = typename conditional, Sequence>::type; +enum class SortField +{ + Values, + Ids }; -template -struct sequence_sort_impl, Sequence, Compare> +// Perform an insertion sort on an IndexedValueArray. +template +constexpr auto insertion_sort(IndexedValueArray arr, Compare comp) { - using sorted_values = Sequence; - using sorted_ids = Sequence; + for(index_t i = 1; i < arr.size; ++i) + { + index_t key_val = arr.values[i]; + index_t key_id = arr.ids[i]; + index_t j = i - 1; + while(j >= 0 && comp(key_val, arr.values[j])) + { + arr.values[j + 1] = arr.values[j]; + arr.ids[j + 1] = arr.ids[j]; + --j; + } + arr.values[j + 1] = key_val; + arr.ids[j + 1] = key_id; + } + return arr; +} + +// Remove duplicates from a sorted IndexedValueArray. +template +constexpr auto unique(const IndexedValueArray& sorted, Equal eq) +{ + IndexedValueArray result{}; + if constexpr(N == 0) + { + return result; + } + result.size = 1; + result.values[0] = sorted.values[0]; + result.ids[0] = sorted.ids[0]; + for(index_t i = 1; i < sorted.size; ++i) + { + if(!eq(sorted.values[i], sorted.values[i - 1])) + { + result.values[result.size] = sorted.values[i]; + result.ids[result.size] = sorted.ids[i]; + ++result.size; + } + } + return result; +} + +// Compute sorted (and optionally unique) IndexedValueArray from input Sequence. +template +constexpr auto compute_sorted(Sequence seq, Compare comp, Equal eq) +{ + auto sorted = insertion_sort(make_indexed_value_array(seq), comp); + return Unique ? unique(sorted, eq) : sorted; +} + +// Cache the sorted results to avoid recomputation. +template +struct SortedCache +{ + static constexpr auto data = compute_sorted(Seq{}, Compare{}, Equal{}); }; -template -struct sequence_sort_impl, Sequence<>, Compare> +// Build sorted value and ID sequences from cached sorted data +template +constexpr index_t get_sorted_field() { - using sorted_values = Sequence<>; - using sorted_ids = Sequence<>; + constexpr auto& data = SortedCache::data; + return (Field == SortField::Values) ? data.values[I] : data.ids[I]; +} + +template +struct SortedSequences; + +template +struct SortedSequences> +{ + using values_type = + Sequence()...>; + using ids_type = + Sequence()...>; }; +template +using sorted_sequences_t = SortedSequences< + Unique, + Seq, + Compare, + Equal, + typename arithmetic_sequence_gen<0, SortedCache::data.size, 1>:: + type>; + +using Equal = ck::math::equal; + +} // namespace sort_impl + template struct sequence_sort { - using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type; - using sort = sequence_sort_impl; - - // this is output - using type = typename sort::sorted_values; - using sorted2unsorted_map = typename sort::sorted_ids; + using sorted_seqs = sort_impl::sorted_sequences_t; + using type = typename sorted_seqs::values_type; + using sorted2unsorted_map = typename sorted_seqs::ids_type; }; template struct sequence_unique_sort { - template - struct sorted_sequence_uniquify_impl - { - static constexpr index_t current_value = RemainValues::Front(); - static constexpr index_t current_id = RemainIds::Front(); - - static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back()); - - using new_remain_values = decltype(RemainValues::PopFront()); - using new_remain_ids = decltype(RemainIds::PopFront()); - - using new_uniquified_values = - typename conditional{})), - UniquifiedValues>::type; - - using new_uniquified_ids = - typename conditional{})), - UniquifiedIds>::type; - - using uniquify = sorted_sequence_uniquify_impl; - - // this is output - using uniquified_values = typename uniquify::uniquified_values; - using uniquified_ids = typename uniquify::uniquified_ids; - }; - - template - struct sorted_sequence_uniquify_impl, - Sequence<>, - UniquifiedValues, - UniquifiedIds, - Eq> - { - using uniquified_values = UniquifiedValues; - using uniquified_ids = UniquifiedIds; - }; - - template - struct sorted_sequence_uniquify - { - using uniquify = sorted_sequence_uniquify_impl, - Sequence, - Eq>; - - using uniquified_values = typename uniquify::uniquified_values; - using uniquified_ids = typename uniquify::uniquified_ids; - }; - - using sort = sequence_sort; - using sorted_values = typename sort::type; - using sorted_ids = typename sort::sorted2unsorted_map; - - using uniquify = sorted_sequence_uniquify; - - // this is output - using type = typename uniquify::uniquified_values; - using sorted2unsorted_map = typename uniquify::uniquified_ids; + using sorted_seqs = sort_impl::sorted_sequences_t; + using type = typename sorted_seqs::values_type; + using sorted2unsorted_map = typename sorted_seqs::ids_type; }; template diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index c221f11f46..b7db14945d 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -310,3 +310,4 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx12") endif() add_subdirectory(position_embedding) add_subdirectory(scatter_gather) +add_subdirectory(util) diff --git a/test/util/CMakeLists.txt b/test/util/CMakeLists.txt new file mode 100644 index 0000000000..bf0a444f18 --- /dev/null +++ b/test/util/CMakeLists.txt @@ -0,0 +1,7 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +add_gtest_executable(unit_sequence unit_sequence.cpp) +if(result EQUAL 0) + target_link_libraries(unit_sequence PRIVATE utility) +endif() diff --git a/test/util/unit_sequence.cpp b/test/util/unit_sequence.cpp new file mode 100644 index 0000000000..f09fd86e06 --- /dev/null +++ b/test/util/unit_sequence.cpp @@ -0,0 +1,684 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include "ck/utility/sequence.hpp" +#include "ck/utility/functional.hpp" + +using namespace ck; + +// Test basic Sequence construction and properties +TEST(Sequence, BasicConstruction) +{ + using Seq = Sequence<1, 2, 3, 4, 5>; + EXPECT_EQ(Seq::Size(), 5); + EXPECT_EQ(Seq::mSize, 5); +} + +TEST(Sequence, EmptySequence) +{ + using Seq = Sequence<>; + EXPECT_EQ(Seq::Size(), 0); + EXPECT_EQ(Seq::mSize, 0); +} + +// Test At() method +TEST(Sequence, AtRuntime) +{ + using Seq = Sequence<10, 20, 30, 40>; + EXPECT_EQ(Seq::At(0), 10); + EXPECT_EQ(Seq::At(1), 20); + EXPECT_EQ(Seq::At(2), 30); + EXPECT_EQ(Seq::At(3), 40); +} + +TEST(Sequence, AtCompileTime) +{ + using Seq = Sequence<10, 20, 30, 40>; + EXPECT_EQ(Seq::At(Number<0>{}), 10); + EXPECT_EQ(Seq::At(Number<1>{}), 20); + EXPECT_EQ(Seq::At(Number<2>{}), 30); + EXPECT_EQ(Seq::At(Number<3>{}), 40); +} + +TEST(Sequence, OperatorBracket) +{ + constexpr auto seq = Sequence<5, 10, 15>{}; + EXPECT_EQ(seq[Number<0>{}], 5); + EXPECT_EQ(seq[Number<1>{}], 10); + EXPECT_EQ(seq[Number<2>{}], 15); +} + +// Test Front() and Back() +TEST(Sequence, FrontBack) +{ + using Seq = Sequence<100, 200, 300>; + EXPECT_EQ(Seq::Front(), 100); + EXPECT_EQ(Seq::Back(), 300); +} + +TEST(Sequence, FrontBackSingleElement) +{ + using Seq = Sequence<42>; + EXPECT_EQ(Seq::Front(), 42); + EXPECT_EQ(Seq::Back(), 42); +} + +// Test PushFront and PushBack +TEST(Sequence, PushFront) +{ + using Seq = Sequence<2, 3, 4>; + using Result = decltype(Seq::PushFront(Sequence<1>{})); + using Expected = Sequence<1, 2, 3, 4>; + EXPECT_TRUE((is_same::value)); +} + +TEST(Sequence, PushFrontNumbers) +{ + using Seq = Sequence<3, 4>; + using Result = decltype(Seq::PushFront(Number<1>{}, Number<2>{})); + using Expected = Sequence<1, 2, 3, 4>; + EXPECT_TRUE((is_same::value)); +} + +TEST(Sequence, PushBack) +{ + using Seq = Sequence<1, 2, 3>; + using Result = decltype(Seq::PushBack(Sequence<4, 5>{})); + using Expected = Sequence<1, 2, 3, 4, 5>; + EXPECT_TRUE((is_same::value)); +} + +TEST(Sequence, PushBackNumbers) +{ + using Seq = Sequence<1, 2>; + using Result = decltype(Seq::PushBack(Number<3>{}, Number<4>{})); + using Expected = Sequence<1, 2, 3, 4>; + EXPECT_TRUE((is_same::value)); +} + +// Test PopFront and PopBack +TEST(Sequence, PopFront) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Result = decltype(Seq::PopFront()); + using Expected = Sequence<2, 3, 4>; + EXPECT_TRUE((is_same::value)); +} + +TEST(Sequence, PopBack) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Result = decltype(Seq::PopBack()); + using Expected = Sequence<1, 2, 3>; + EXPECT_TRUE((is_same::value)); +} + +// Test Extract +TEST(Sequence, ExtractByNumbers) +{ + using Seq = Sequence<10, 20, 30, 40, 50>; + using Result = decltype(Seq::Extract(Number<0>{}, Number<2>{}, Number<4>{})); + using Expected = Sequence<10, 30, 50>; + EXPECT_TRUE((is_same::value)); +} + +TEST(Sequence, ExtractBySequence) +{ + using Seq = Sequence<10, 20, 30, 40, 50>; + using Result = decltype(Seq::Extract(Sequence<1, 3>{})); + using Expected = Sequence<20, 40>; + EXPECT_TRUE((is_same::value)); +} + +// Test Modify +TEST(Sequence, Modify) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Result = decltype(Seq::Modify(Number<2>{}, Number<99>{})); + using Expected = Sequence<1, 2, 99, 4>; + EXPECT_TRUE((is_same::value)); +} + +// Test Transform +TEST(Sequence, Transform) +{ + using Seq = Sequence<1, 2, 3, 4>; + auto double_it = [](auto x) { return 2 * x; }; + using Result = decltype(Seq::Transform(double_it)); + using Expected = Sequence<2, 4, 6, 8>; + EXPECT_TRUE((is_same::value)); +} + +// Test Reverse +TEST(Sequence, Reverse) +{ + using Seq = Sequence<1, 2, 3, 4, 5>; + using Result = decltype(Seq::Reverse()); + using Expected = Sequence<5, 4, 3, 2, 1>; + EXPECT_TRUE((is_same::value)); +} + +TEST(Sequence, ReverseSingleElement) +{ + using Seq = Sequence<42>; + using Result = decltype(Seq::Reverse()); + using Expected = Sequence<42>; + EXPECT_TRUE((is_same::value)); +} + +// Test ReorderGivenNew2Old +TEST(Sequence, ReorderGivenNew2Old) +{ + using Seq = Sequence<10, 20, 30, 40>; + using Result = decltype(Seq::ReorderGivenNew2Old(Sequence<3, 1, 2, 0>{})); + using Expected = Sequence<40, 20, 30, 10>; + EXPECT_TRUE((is_same::value)); +} + +// Test ReorderGivenOld2New +TEST(Sequence, ReorderGivenOld2New) +{ + using Seq = Sequence<10, 20, 30, 40>; + using Result = decltype(Seq::ReorderGivenOld2New(Sequence<3, 1, 2, 0>{})); + using Expected = Sequence<40, 20, 30, 10>; + EXPECT_TRUE((is_same::value)); +} + +// Test arithmetic_sequence_gen +TEST(SequenceGen, ArithmeticSequence) +{ + using Result = typename arithmetic_sequence_gen<0, 5, 1>::type; + using Expected = Sequence<0, 1, 2, 3, 4>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, ArithmeticSequenceWithIncrement) +{ + using Result = typename arithmetic_sequence_gen<0, 10, 2>::type; + using Expected = Sequence<0, 2, 4, 6, 8>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, ArithmeticSequenceNegativeIncrement) +{ + using Result = typename arithmetic_sequence_gen<10, 5, -1>::type; + using Expected = Sequence<10, 9, 8, 7, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, ArithmeticSequenceEmpty) +{ + using Result = typename arithmetic_sequence_gen<5, 5, 1>::type; + using Expected = Sequence<>; + EXPECT_TRUE((is_same::value)); +} + +// Test uniform_sequence_gen +TEST(SequenceGen, UniformSequence) +{ + using Result = typename uniform_sequence_gen<5, 42>::type; + using Expected = Sequence<42, 42, 42, 42, 42>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, UniformSequenceZeroSize) +{ + using Result = typename uniform_sequence_gen<0, 42>::type; + using Expected = Sequence<>; + EXPECT_TRUE((is_same::value)); +} + +// Test make_index_sequence +TEST(SequenceGen, MakeIndexSequence) +{ + using Result = make_index_sequence<5>; + using Expected = Sequence<0, 1, 2, 3, 4>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, MakeIndexSequenceZero) +{ + using Result = make_index_sequence<0>; + using Expected = Sequence<>; + EXPECT_TRUE((is_same::value)); +} + +// Test sequence_merge +TEST(SequenceMerge, MergeTwoSequences) +{ + using Seq1 = Sequence<1, 2, 3>; + using Seq2 = Sequence<4, 5, 6>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2, 3, 4, 5, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceMerge, MergeMultipleSequences) +{ + using Seq1 = Sequence<1, 2>; + using Seq2 = Sequence<3, 4>; + using Seq3 = Sequence<5, 6>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2, 3, 4, 5, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceMerge, MergeSingleSequence) +{ + using Seq = Sequence<1, 2, 3>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2, 3>; + EXPECT_TRUE((is_same::value)); +} + +// Test sequence_split +TEST(SequenceSplit, SplitInMiddle) +{ + using Seq = Sequence<1, 2, 3, 4, 5, 6>; + using Split = sequence_split; + using ExpectedLeft = Sequence<1, 2, 3>; + using ExpectedRight = Sequence<4, 5, 6>; + EXPECT_TRUE((is_same::value)); + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSplit, SplitAtBeginning) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Split = sequence_split; + using ExpectedLeft = Sequence<>; + using ExpectedRight = Sequence<1, 2, 3, 4>; + EXPECT_TRUE((is_same::value)); + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSplit, SplitAtEnd) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Split = sequence_split; + using ExpectedLeft = Sequence<1, 2, 3, 4>; + using ExpectedRight = Sequence<>; + EXPECT_TRUE((is_same::value)); + EXPECT_TRUE((is_same::value)); +} + +// Test sequence_sort +TEST(SequenceSort, SortAscending) +{ + using Seq = Sequence<5, 2, 8, 1, 9>; + using Result = typename sequence_sort>::type; + using Expected = Sequence<1, 2, 5, 8, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSort, SortDescending) +{ + // Create a greater-than comparator + struct greater + { + __host__ __device__ constexpr bool operator()(index_t x, index_t y) const { return x > y; } + }; + using Seq = Sequence<5, 2, 8, 1, 9>; + using Result = typename sequence_sort::type; + using Expected = Sequence<9, 8, 5, 2, 1>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSort, SortAlreadySorted) +{ + using Seq = Sequence<1, 2, 3, 4, 5>; + using Result = typename sequence_sort>::type; + using Expected = Sequence<1, 2, 3, 4, 5>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSort, SortWithDuplicates) +{ + using Seq = Sequence<3, 1, 4, 1, 5, 9, 2, 6, 5>; + using Result = typename sequence_sort>::type; + using Expected = Sequence<1, 1, 2, 3, 4, 5, 5, 6, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSort, SortEmptySequence) +{ + using Seq = Sequence<>; + using Result = typename sequence_sort>::type; + using Expected = Sequence<>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSort, SortSingleElement) +{ + using Seq = Sequence<42>; + using Result = typename sequence_sort>::type; + using Expected = Sequence<42>; + EXPECT_TRUE((is_same::value)); +} + +// Test sequence_unique_sort +TEST(SequenceUniqueSort, UniqueSort) +{ + using Seq = Sequence<3, 1, 4, 1, 5, 9, 2, 6, 5>; + using Result = + typename sequence_unique_sort, math::equal>::type; + using Expected = Sequence<1, 2, 3, 4, 5, 6, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceUniqueSort, UniqueSortNoDuplicates) +{ + using Seq = Sequence<5, 2, 8, 1, 9>; + using Result = + typename sequence_unique_sort, math::equal>::type; + using Expected = Sequence<1, 2, 5, 8, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceUniqueSort, UniqueSortAllSame) +{ + using Seq = Sequence<5, 5, 5, 5>; + using Result = + typename sequence_unique_sort, math::equal>::type; + using Expected = Sequence<5>; + EXPECT_TRUE((is_same::value)); +} + +// Test is_valid_sequence_map +TEST(SequenceMap, ValidMap) +{ + using Map = Sequence<0, 1, 2, 3>; + EXPECT_TRUE((is_valid_sequence_map::value)); +} + +TEST(SequenceMap, ValidMapPermuted) +{ + using Map = Sequence<2, 0, 3, 1>; + EXPECT_TRUE((is_valid_sequence_map::value)); +} + +TEST(SequenceMap, InvalidMapDuplicate) +{ + using Map = Sequence<0, 1, 1, 3>; + EXPECT_FALSE((is_valid_sequence_map::value)); +} + +TEST(SequenceMap, InvalidMapMissing) +{ + using Map = Sequence<0, 1, 3, 4>; + EXPECT_FALSE((is_valid_sequence_map::value)); +} + +// Test sequence_map_inverse +// Note: sequence_map_inverse inverts a mapping where Map[i] = j means old position i maps to new +// position j The inverse gives us new position i came from old position inverse[i] +TEST(SequenceMapInverse, InverseMap) +{ + // Map = <2, 0, 3, 1> means: old[0]->new[2], old[1]->new[0], old[2]->new[3], old[3]->new[1] + // Inverse should be: new[0]<-old[1], new[1]<-old[3], new[2]<-old[0], new[3]<-old[2] + using Map = Sequence<2, 0, 3, 1>; + using Result = typename sequence_map_inverse::type; + // Verify by checking that Map[Result[i]] == i for all i + EXPECT_EQ((Map::At(Number{})>{}) == 0), true); + EXPECT_EQ((Map::At(Number{})>{}) == 1), true); + EXPECT_EQ((Map::At(Number{})>{}) == 2), true); + EXPECT_EQ((Map::At(Number{})>{}) == 3), true); +} + +TEST(SequenceMapInverse, InverseIdentityMap) +{ + using Map = Sequence<0, 1, 2, 3>; + using Result = typename sequence_map_inverse::type; + // Verify by checking that Map[Result[i]] == i for all i (same as the other test) + EXPECT_EQ((Map::At(Number{})>{}) == 0), true); + EXPECT_EQ((Map::At(Number{})>{}) == 1), true); + EXPECT_EQ((Map::At(Number{})>{}) == 2), true); + EXPECT_EQ((Map::At(Number{})>{}) == 3), true); +} + +// Test sequence operators +TEST(SequenceOperators, Equality) +{ + constexpr auto seq1 = Sequence<1, 2, 3>{}; + constexpr auto seq2 = Sequence<1, 2, 3>{}; + constexpr auto seq3 = Sequence<1, 2, 4>{}; + EXPECT_TRUE(seq1 == seq2); + EXPECT_FALSE(seq1 == seq3); +} + +TEST(SequenceOperators, Addition) +{ + using Seq1 = Sequence<1, 2, 3>; + using Seq2 = Sequence<4, 5, 6>; + using Result = decltype(Seq1{} + Seq2{}); + using Expected = Sequence<5, 7, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, Subtraction) +{ + using Seq1 = Sequence<10, 20, 30>; + using Seq2 = Sequence<1, 2, 3>; + using Result = decltype(Seq1{} - Seq2{}); + using Expected = Sequence<9, 18, 27>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, Multiplication) +{ + using Seq1 = Sequence<2, 3, 4>; + using Seq2 = Sequence<5, 6, 7>; + using Result = decltype(Seq1{} * Seq2{}); + using Expected = Sequence<10, 18, 28>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, Division) +{ + using Seq1 = Sequence<10, 20, 30>; + using Seq2 = Sequence<2, 4, 5>; + using Result = decltype(Seq1{} / Seq2{}); + using Expected = Sequence<5, 5, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, Modulo) +{ + using Seq1 = Sequence<10, 20, 30>; + using Seq2 = Sequence<3, 7, 8>; + using Result = decltype(Seq1{} % Seq2{}); + using Expected = Sequence<1, 6, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, AdditionWithNumber) +{ + using Seq = Sequence<1, 2, 3>; + using Result = decltype(Seq{} + Number<10>{}); + using Expected = Sequence<11, 12, 13>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, SubtractionWithNumber) +{ + using Seq = Sequence<10, 20, 30>; + using Result = decltype(Seq{} - Number<5>{}); + using Expected = Sequence<5, 15, 25>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, MultiplicationWithNumber) +{ + using Seq = Sequence<2, 3, 4>; + using Result = decltype(Seq{} * Number<3>{}); + using Expected = Sequence<6, 9, 12>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, DivisionWithNumber) +{ + using Seq = Sequence<10, 20, 30>; + using Result = decltype(Seq{} / Number<5>{}); + using Expected = Sequence<2, 4, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, NumberAddition) +{ + using Seq = Sequence<1, 2, 3>; + using Result = decltype(Number<10>{} + Seq{}); + using Expected = Sequence<11, 12, 13>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, NumberMultiplication) +{ + using Seq = Sequence<2, 3, 4>; + using Result = decltype(Number<3>{} * Seq{}); + using Expected = Sequence<6, 9, 12>; + EXPECT_TRUE((is_same::value)); +} + +// Test helper functions +TEST(SequenceHelpers, MergeSequences) +{ + using Seq1 = Sequence<1, 2>; + using Seq2 = Sequence<3, 4>; + using Seq3 = Sequence<5, 6>; + using Result = decltype(merge_sequences(Seq1{}, Seq2{}, Seq3{})); + using Expected = Sequence<1, 2, 3, 4, 5, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceHelpers, TransformSequencesSingle) +{ + auto double_it = [](auto x) { return 2 * x; }; + using Seq = Sequence<1, 2, 3>; + using Result = decltype(transform_sequences(double_it, Seq{})); + using Expected = Sequence<2, 4, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceHelpers, TransformSequencesTwo) +{ + auto add = [](auto x, auto y) { return x + y; }; + using Seq1 = Sequence<1, 2, 3>; + using Seq2 = Sequence<4, 5, 6>; + using Result = decltype(transform_sequences(add, Seq1{}, Seq2{})); + using Expected = Sequence<5, 7, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceHelpers, TransformSequencesThree) +{ + auto add3 = [](auto x, auto y, auto z) { return x + y + z; }; + using Seq1 = Sequence<1, 2, 3>; + using Seq2 = Sequence<4, 5, 6>; + using Seq3 = Sequence<7, 8, 9>; + using Result = decltype(transform_sequences(add3, Seq1{}, Seq2{}, Seq3{})); + using Expected = Sequence<12, 15, 18>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceHelpers, ReduceOnSequence) +{ + auto add = [](auto x, auto y) { return x + y; }; + constexpr auto seq = Sequence<1, 2, 3, 4, 5>{}; + constexpr auto result = reduce_on_sequence(seq, add, Number<0>{}); + EXPECT_EQ(result, 15); +} + +TEST(SequenceHelpers, SequenceAnyOf) +{ + auto is_even = [](auto x) { return x % 2 == 0; }; + constexpr auto seq1 = Sequence<1, 3, 5, 7>{}; + constexpr auto seq2 = Sequence<1, 3, 4, 7>{}; + EXPECT_FALSE(sequence_any_of(seq1, is_even)); + EXPECT_TRUE(sequence_any_of(seq2, is_even)); +} + +TEST(SequenceHelpers, SequenceAllOf) +{ + auto is_positive = [](auto x) { return x > 0; }; + constexpr auto seq1 = Sequence<1, 2, 3, 4>{}; + constexpr auto seq2 = Sequence<1, -2, 3, 4>{}; + EXPECT_TRUE(sequence_all_of(seq1, is_positive)); + EXPECT_FALSE(sequence_all_of(seq2, is_positive)); +} + +// Test scan operations +TEST(SequenceScan, ReverseInclusiveScan) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Result = + decltype(reverse_inclusive_scan_sequence(Seq{}, math::plus{}, Number<0>{})); + using Expected = Sequence<10, 9, 7, 4>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceScan, ReverseExclusiveScan) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Result = + decltype(reverse_exclusive_scan_sequence(Seq{}, math::plus{}, Number<0>{})); + using Expected = Sequence<9, 7, 4, 0>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceScan, InclusiveScan) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Result = decltype(inclusive_scan_sequence(Seq{}, math::plus{}, Number<0>{})); + using Expected = Sequence<1, 3, 6, 10>; + EXPECT_TRUE((is_same::value)); +} + +// Test pick and modify operations +TEST(SequencePickModify, PickElementsByIds) +{ + using Seq = Sequence<10, 20, 30, 40, 50>; + using Ids = Sequence<0, 2, 4>; + using Result = decltype(pick_sequence_elements_by_ids(Seq{}, Ids{})); + using Expected = Sequence<10, 30, 50>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequencePickModify, PickElementsByMask) +{ + using Seq = Sequence<10, 20, 30, 40, 50>; + using Mask = Sequence<1, 0, 1, 0, 1>; + using Result = decltype(pick_sequence_elements_by_mask(Seq{}, Mask{})); + using Expected = Sequence<10, 30, 50>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequencePickModify, ModifyElementsByIds) +{ + using Seq = Sequence<10, 20, 30, 40, 50>; + using Values = Sequence<99, 88>; + using Ids = Sequence<1, 3>; + using Result = decltype(modify_sequence_elements_by_ids(Seq{}, Values{}, Ids{})); + using Expected = Sequence<10, 99, 30, 88, 50>; + EXPECT_TRUE((is_same::value)); +} + +// Test sequence_reduce +TEST(SequenceReduce, ReduceTwoSequences) +{ + using Seq1 = Sequence<1, 2, 3>; + using Seq2 = Sequence<4, 5, 6>; + using Result = typename sequence_reduce, Seq1, Seq2>::type; + using Expected = Sequence<5, 7, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceReduce, ReduceMultipleSequences) +{ + using Seq1 = Sequence<1, 2>; + using Seq2 = Sequence<3, 4>; + using Seq3 = Sequence<5, 6>; + using Result = typename sequence_reduce, Seq1, Seq2, Seq3>::type; + using Expected = Sequence<9, 12>; + EXPECT_TRUE((is_same::value)); +}