mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
Improve sequence sorting and add unit tests (#3376)
Old sequence sort code was showing up on build profiles. Convert it to constexpr functions for much more efficient build-time execution. The sorting is still O(N^2), but our sequences are small enough it executes quickly. This reduced compilation time of a small convolution by more than 10% and time overall time spent in the compiler on a narrow build by %6.
This commit is contained in:
@@ -380,236 +380,143 @@ struct sequence_reduce<Reduce, Seq>
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename Values, typename Ids, typename Compare>
|
||||
struct sequence_sort_impl
|
||||
// Implement sequence_sort and sequence_unique_sort using constexpr functions (C++17)
|
||||
namespace sort_impl {
|
||||
|
||||
// Temporary arrays to hold values during operations with capacity N and mutable size.
|
||||
template <index_t N>
|
||||
struct IndexedValueArray
|
||||
{
|
||||
template <typename LeftValues,
|
||||
typename LeftIds,
|
||||
typename RightValues,
|
||||
typename RightIds,
|
||||
typename MergedValues,
|
||||
typename MergedIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge_impl
|
||||
{
|
||||
static constexpr bool choose_left = LeftValues::Front() < RightValues::Front();
|
||||
|
||||
static constexpr index_t chosen_value =
|
||||
choose_left ? LeftValues::Front() : RightValues::Front();
|
||||
static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front();
|
||||
|
||||
using new_merged_values = decltype(MergedValues::PushBack(Number<chosen_value>{}));
|
||||
using new_merged_ids = decltype(MergedIds::PushBack(Number<chosen_id>{}));
|
||||
|
||||
using new_left_values =
|
||||
typename conditional<choose_left, decltype(LeftValues::PopFront()), LeftValues>::type;
|
||||
using new_left_ids =
|
||||
typename conditional<choose_left, decltype(LeftIds::PopFront()), LeftIds>::type;
|
||||
|
||||
using new_right_values =
|
||||
typename conditional<choose_left, RightValues, decltype(RightValues::PopFront())>::type;
|
||||
using new_right_ids =
|
||||
typename conditional<choose_left, RightIds, decltype(RightIds::PopFront())>::type;
|
||||
|
||||
using merge = sorted_sequence_merge_impl<new_left_values,
|
||||
new_left_ids,
|
||||
new_right_values,
|
||||
new_right_ids,
|
||||
new_merged_values,
|
||||
new_merged_ids,
|
||||
Comp>;
|
||||
// this is output
|
||||
using merged_values = typename merge::merged_values;
|
||||
using merged_ids = typename merge::merged_ids;
|
||||
};
|
||||
|
||||
template <typename LeftValues,
|
||||
typename LeftIds,
|
||||
typename MergedValues,
|
||||
typename MergedIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge_impl<LeftValues,
|
||||
LeftIds,
|
||||
Sequence<>,
|
||||
Sequence<>,
|
||||
MergedValues,
|
||||
MergedIds,
|
||||
Comp>
|
||||
{
|
||||
using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
|
||||
using merged_ids = typename sequence_merge<MergedIds, LeftIds>::type;
|
||||
};
|
||||
|
||||
template <typename RightValues,
|
||||
typename RightIds,
|
||||
typename MergedValues,
|
||||
typename MergedIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge_impl<Sequence<>,
|
||||
Sequence<>,
|
||||
RightValues,
|
||||
RightIds,
|
||||
MergedValues,
|
||||
MergedIds,
|
||||
Comp>
|
||||
{
|
||||
using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
|
||||
using merged_ids = typename sequence_merge<MergedIds, RightIds>::type;
|
||||
};
|
||||
|
||||
template <typename LeftValues,
|
||||
typename LeftIds,
|
||||
typename RightValues,
|
||||
typename RightIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge
|
||||
{
|
||||
using merge = sorted_sequence_merge_impl<LeftValues,
|
||||
LeftIds,
|
||||
RightValues,
|
||||
RightIds,
|
||||
Sequence<>,
|
||||
Sequence<>,
|
||||
Comp>;
|
||||
|
||||
using merged_values = typename merge::merged_values;
|
||||
using merged_ids = typename merge::merged_ids;
|
||||
};
|
||||
|
||||
static constexpr index_t nsize = Values::Size();
|
||||
|
||||
using split_unsorted_values = sequence_split<Values, nsize / 2>;
|
||||
using split_unsorted_ids = sequence_split<Ids, nsize / 2>;
|
||||
|
||||
using left_unsorted_values = typename split_unsorted_values::left_type;
|
||||
using left_unsorted_ids = typename split_unsorted_ids::left_type;
|
||||
using left_sort = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
|
||||
using left_sorted_values = typename left_sort::sorted_values;
|
||||
using left_sorted_ids = typename left_sort::sorted_ids;
|
||||
|
||||
using right_unsorted_values = typename split_unsorted_values::right_type;
|
||||
using right_unsorted_ids = typename split_unsorted_ids::right_type;
|
||||
using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
|
||||
using right_sorted_values = typename right_sort::sorted_values;
|
||||
using right_sorted_ids = typename right_sort::sorted_ids;
|
||||
|
||||
using merged_sorted = sorted_sequence_merge<left_sorted_values,
|
||||
left_sorted_ids,
|
||||
right_sorted_values,
|
||||
right_sorted_ids,
|
||||
Compare>;
|
||||
|
||||
using sorted_values = typename merged_sorted::merged_values;
|
||||
using sorted_ids = typename merged_sorted::merged_ids;
|
||||
index_t values[N > 0 ? N : 1];
|
||||
index_t ids[N > 0 ? N : 1];
|
||||
index_t size = 0;
|
||||
};
|
||||
|
||||
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
|
||||
struct sequence_sort_impl<Sequence<ValueX, ValueY>, Sequence<IdX, IdY>, Compare>
|
||||
template <index_t... Is>
|
||||
constexpr auto make_indexed_value_array(Sequence<Is...>)
|
||||
{
|
||||
static constexpr bool choose_x = Compare{}(ValueX, ValueY);
|
||||
constexpr index_t N = sizeof...(Is);
|
||||
IndexedValueArray<N> result = {{Is...}, {}, N};
|
||||
for(index_t i = 0; i < N; ++i)
|
||||
{
|
||||
result.ids[i] = i;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
using sorted_values =
|
||||
typename conditional<choose_x, Sequence<ValueX, ValueY>, Sequence<ValueY, ValueX>>::type;
|
||||
using sorted_ids = typename conditional<choose_x, Sequence<IdX, IdY>, Sequence<IdY, IdX>>::type;
|
||||
enum class SortField
|
||||
{
|
||||
Values,
|
||||
Ids
|
||||
};
|
||||
|
||||
template <index_t Value, index_t Id, typename Compare>
|
||||
struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
|
||||
// Perform an insertion sort on an IndexedValueArray.
|
||||
template <index_t N, typename Compare>
|
||||
constexpr auto insertion_sort(IndexedValueArray<N> arr, Compare comp)
|
||||
{
|
||||
using sorted_values = Sequence<Value>;
|
||||
using sorted_ids = Sequence<Id>;
|
||||
for(index_t i = 1; i < arr.size; ++i)
|
||||
{
|
||||
index_t key_val = arr.values[i];
|
||||
index_t key_id = arr.ids[i];
|
||||
index_t j = i - 1;
|
||||
while(j >= 0 && comp(key_val, arr.values[j]))
|
||||
{
|
||||
arr.values[j + 1] = arr.values[j];
|
||||
arr.ids[j + 1] = arr.ids[j];
|
||||
--j;
|
||||
}
|
||||
arr.values[j + 1] = key_val;
|
||||
arr.ids[j + 1] = key_id;
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
|
||||
// Remove duplicates from a sorted IndexedValueArray.
|
||||
template <index_t N, typename Equal>
|
||||
constexpr auto unique(const IndexedValueArray<N>& sorted, Equal eq)
|
||||
{
|
||||
IndexedValueArray<N> result{};
|
||||
if constexpr(N == 0)
|
||||
{
|
||||
return result;
|
||||
}
|
||||
result.size = 1;
|
||||
result.values[0] = sorted.values[0];
|
||||
result.ids[0] = sorted.ids[0];
|
||||
for(index_t i = 1; i < sorted.size; ++i)
|
||||
{
|
||||
if(!eq(sorted.values[i], sorted.values[i - 1]))
|
||||
{
|
||||
result.values[result.size] = sorted.values[i];
|
||||
result.ids[result.size] = sorted.ids[i];
|
||||
++result.size;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Compute sorted (and optionally unique) IndexedValueArray from input Sequence.
|
||||
template <bool Unique, typename Compare, typename Equal, index_t... Is>
|
||||
constexpr auto compute_sorted(Sequence<Is...> seq, Compare comp, Equal eq)
|
||||
{
|
||||
auto sorted = insertion_sort(make_indexed_value_array(seq), comp);
|
||||
return Unique ? unique(sorted, eq) : sorted;
|
||||
}
|
||||
|
||||
// Cache the sorted results to avoid recomputation.
|
||||
template <bool Unique, typename Seq, typename Compare, typename Equal>
|
||||
struct SortedCache
|
||||
{
|
||||
static constexpr auto data = compute_sorted<Unique>(Seq{}, Compare{}, Equal{});
|
||||
};
|
||||
|
||||
template <typename Compare>
|
||||
struct sequence_sort_impl<Sequence<>, Sequence<>, Compare>
|
||||
// Build sorted value and ID sequences from cached sorted data
|
||||
template <SortField Field, bool Unique, typename Seq, typename Compare, typename Equal, index_t I>
|
||||
constexpr index_t get_sorted_field()
|
||||
{
|
||||
using sorted_values = Sequence<>;
|
||||
using sorted_ids = Sequence<>;
|
||||
constexpr auto& data = SortedCache<Unique, Seq, Compare, Equal>::data;
|
||||
return (Field == SortField::Values) ? data.values[I] : data.ids[I];
|
||||
}
|
||||
|
||||
template <bool Unique, typename Seq, typename Compare, typename Equal, typename IndexSeq>
|
||||
struct SortedSequences;
|
||||
|
||||
template <bool Unique, typename Seq, typename Compare, typename Equal, index_t... Is>
|
||||
struct SortedSequences<Unique, Seq, Compare, Equal, Sequence<Is...>>
|
||||
{
|
||||
using values_type =
|
||||
Sequence<get_sorted_field<SortField::Values, Unique, Seq, Compare, Equal, Is>()...>;
|
||||
using ids_type =
|
||||
Sequence<get_sorted_field<SortField::Ids, Unique, Seq, Compare, Equal, Is>()...>;
|
||||
};
|
||||
|
||||
template <bool Unique, typename Seq, typename Compare, typename Equal>
|
||||
using sorted_sequences_t = SortedSequences<
|
||||
Unique,
|
||||
Seq,
|
||||
Compare,
|
||||
Equal,
|
||||
typename arithmetic_sequence_gen<0, SortedCache<Unique, Seq, Compare, Equal>::data.size, 1>::
|
||||
type>;
|
||||
|
||||
using Equal = ck::math::equal<index_t>;
|
||||
|
||||
} // namespace sort_impl
|
||||
|
||||
template <typename Values, typename Compare>
|
||||
struct sequence_sort
|
||||
{
|
||||
using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type;
|
||||
using sort = sequence_sort_impl<Values, unsorted_ids, Compare>;
|
||||
|
||||
// this is output
|
||||
using type = typename sort::sorted_values;
|
||||
using sorted2unsorted_map = typename sort::sorted_ids;
|
||||
using sorted_seqs = sort_impl::sorted_sequences_t<false, Values, Compare, sort_impl::Equal>;
|
||||
using type = typename sorted_seqs::values_type;
|
||||
using sorted2unsorted_map = typename sorted_seqs::ids_type;
|
||||
};
|
||||
|
||||
template <typename Values, typename Less, typename Equal>
|
||||
struct sequence_unique_sort
|
||||
{
|
||||
template <typename RemainValues,
|
||||
typename RemainIds,
|
||||
typename UniquifiedValues,
|
||||
typename UniquifiedIds,
|
||||
typename Eq>
|
||||
struct sorted_sequence_uniquify_impl
|
||||
{
|
||||
static constexpr index_t current_value = RemainValues::Front();
|
||||
static constexpr index_t current_id = RemainIds::Front();
|
||||
|
||||
static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back());
|
||||
|
||||
using new_remain_values = decltype(RemainValues::PopFront());
|
||||
using new_remain_ids = decltype(RemainIds::PopFront());
|
||||
|
||||
using new_uniquified_values =
|
||||
typename conditional<is_unique_value,
|
||||
decltype(UniquifiedValues::PushBack(Number<current_value>{})),
|
||||
UniquifiedValues>::type;
|
||||
|
||||
using new_uniquified_ids =
|
||||
typename conditional<is_unique_value,
|
||||
decltype(UniquifiedIds::PushBack(Number<current_id>{})),
|
||||
UniquifiedIds>::type;
|
||||
|
||||
using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
|
||||
new_remain_ids,
|
||||
new_uniquified_values,
|
||||
new_uniquified_ids,
|
||||
Eq>;
|
||||
|
||||
// this is output
|
||||
using uniquified_values = typename uniquify::uniquified_values;
|
||||
using uniquified_ids = typename uniquify::uniquified_ids;
|
||||
};
|
||||
|
||||
template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
|
||||
struct sorted_sequence_uniquify_impl<Sequence<>,
|
||||
Sequence<>,
|
||||
UniquifiedValues,
|
||||
UniquifiedIds,
|
||||
Eq>
|
||||
{
|
||||
using uniquified_values = UniquifiedValues;
|
||||
using uniquified_ids = UniquifiedIds;
|
||||
};
|
||||
|
||||
template <typename SortedValues, typename SortedIds, typename Eq>
|
||||
struct sorted_sequence_uniquify
|
||||
{
|
||||
using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::PopFront()),
|
||||
decltype(SortedIds::PopFront()),
|
||||
Sequence<SortedValues::Front()>,
|
||||
Sequence<SortedIds::Front()>,
|
||||
Eq>;
|
||||
|
||||
using uniquified_values = typename uniquify::uniquified_values;
|
||||
using uniquified_ids = typename uniquify::uniquified_ids;
|
||||
};
|
||||
|
||||
using sort = sequence_sort<Values, Less>;
|
||||
using sorted_values = typename sort::type;
|
||||
using sorted_ids = typename sort::sorted2unsorted_map;
|
||||
|
||||
using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;
|
||||
|
||||
// this is output
|
||||
using type = typename uniquify::uniquified_values;
|
||||
using sorted2unsorted_map = typename uniquify::uniquified_ids;
|
||||
using sorted_seqs = sort_impl::sorted_sequences_t<true, Values, Less, Equal>;
|
||||
using type = typename sorted_seqs::values_type;
|
||||
using sorted2unsorted_map = typename sorted_seqs::ids_type;
|
||||
};
|
||||
|
||||
template <typename SeqMap>
|
||||
|
||||
@@ -310,3 +310,4 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
endif()
|
||||
add_subdirectory(position_embedding)
|
||||
add_subdirectory(scatter_gather)
|
||||
add_subdirectory(util)
|
||||
|
||||
7
test/util/CMakeLists.txt
Normal file
7
test/util/CMakeLists.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_gtest_executable(unit_sequence unit_sequence.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(unit_sequence PRIVATE utility)
|
||||
endif()
|
||||
684
test/util/unit_sequence.cpp
Normal file
684
test/util/unit_sequence.cpp
Normal file
@@ -0,0 +1,684 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/utility/functional.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
// Test basic Sequence construction and properties
|
||||
TEST(Sequence, BasicConstruction)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3, 4, 5>;
|
||||
EXPECT_EQ(Seq::Size(), 5);
|
||||
EXPECT_EQ(Seq::mSize, 5);
|
||||
}
|
||||
|
||||
TEST(Sequence, EmptySequence)
|
||||
{
|
||||
using Seq = Sequence<>;
|
||||
EXPECT_EQ(Seq::Size(), 0);
|
||||
EXPECT_EQ(Seq::mSize, 0);
|
||||
}
|
||||
|
||||
// Test At() method
|
||||
TEST(Sequence, AtRuntime)
|
||||
{
|
||||
using Seq = Sequence<10, 20, 30, 40>;
|
||||
EXPECT_EQ(Seq::At(0), 10);
|
||||
EXPECT_EQ(Seq::At(1), 20);
|
||||
EXPECT_EQ(Seq::At(2), 30);
|
||||
EXPECT_EQ(Seq::At(3), 40);
|
||||
}
|
||||
|
||||
TEST(Sequence, AtCompileTime)
|
||||
{
|
||||
using Seq = Sequence<10, 20, 30, 40>;
|
||||
EXPECT_EQ(Seq::At(Number<0>{}), 10);
|
||||
EXPECT_EQ(Seq::At(Number<1>{}), 20);
|
||||
EXPECT_EQ(Seq::At(Number<2>{}), 30);
|
||||
EXPECT_EQ(Seq::At(Number<3>{}), 40);
|
||||
}
|
||||
|
||||
TEST(Sequence, OperatorBracket)
|
||||
{
|
||||
constexpr auto seq = Sequence<5, 10, 15>{};
|
||||
EXPECT_EQ(seq[Number<0>{}], 5);
|
||||
EXPECT_EQ(seq[Number<1>{}], 10);
|
||||
EXPECT_EQ(seq[Number<2>{}], 15);
|
||||
}
|
||||
|
||||
// Test Front() and Back()
|
||||
TEST(Sequence, FrontBack)
|
||||
{
|
||||
using Seq = Sequence<100, 200, 300>;
|
||||
EXPECT_EQ(Seq::Front(), 100);
|
||||
EXPECT_EQ(Seq::Back(), 300);
|
||||
}
|
||||
|
||||
TEST(Sequence, FrontBackSingleElement)
|
||||
{
|
||||
using Seq = Sequence<42>;
|
||||
EXPECT_EQ(Seq::Front(), 42);
|
||||
EXPECT_EQ(Seq::Back(), 42);
|
||||
}
|
||||
|
||||
// Test PushFront and PushBack
|
||||
TEST(Sequence, PushFront)
|
||||
{
|
||||
using Seq = Sequence<2, 3, 4>;
|
||||
using Result = decltype(Seq::PushFront(Sequence<1>{}));
|
||||
using Expected = Sequence<1, 2, 3, 4>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(Sequence, PushFrontNumbers)
|
||||
{
|
||||
using Seq = Sequence<3, 4>;
|
||||
using Result = decltype(Seq::PushFront(Number<1>{}, Number<2>{}));
|
||||
using Expected = Sequence<1, 2, 3, 4>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(Sequence, PushBack)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3>;
|
||||
using Result = decltype(Seq::PushBack(Sequence<4, 5>{}));
|
||||
using Expected = Sequence<1, 2, 3, 4, 5>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(Sequence, PushBackNumbers)
|
||||
{
|
||||
using Seq = Sequence<1, 2>;
|
||||
using Result = decltype(Seq::PushBack(Number<3>{}, Number<4>{}));
|
||||
using Expected = Sequence<1, 2, 3, 4>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test PopFront and PopBack
|
||||
TEST(Sequence, PopFront)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3, 4>;
|
||||
using Result = decltype(Seq::PopFront());
|
||||
using Expected = Sequence<2, 3, 4>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(Sequence, PopBack)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3, 4>;
|
||||
using Result = decltype(Seq::PopBack());
|
||||
using Expected = Sequence<1, 2, 3>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test Extract
|
||||
TEST(Sequence, ExtractByNumbers)
|
||||
{
|
||||
using Seq = Sequence<10, 20, 30, 40, 50>;
|
||||
using Result = decltype(Seq::Extract(Number<0>{}, Number<2>{}, Number<4>{}));
|
||||
using Expected = Sequence<10, 30, 50>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(Sequence, ExtractBySequence)
|
||||
{
|
||||
using Seq = Sequence<10, 20, 30, 40, 50>;
|
||||
using Result = decltype(Seq::Extract(Sequence<1, 3>{}));
|
||||
using Expected = Sequence<20, 40>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test Modify
|
||||
TEST(Sequence, Modify)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3, 4>;
|
||||
using Result = decltype(Seq::Modify(Number<2>{}, Number<99>{}));
|
||||
using Expected = Sequence<1, 2, 99, 4>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test Transform
|
||||
TEST(Sequence, Transform)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3, 4>;
|
||||
auto double_it = [](auto x) { return 2 * x; };
|
||||
using Result = decltype(Seq::Transform(double_it));
|
||||
using Expected = Sequence<2, 4, 6, 8>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test Reverse
|
||||
TEST(Sequence, Reverse)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3, 4, 5>;
|
||||
using Result = decltype(Seq::Reverse());
|
||||
using Expected = Sequence<5, 4, 3, 2, 1>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(Sequence, ReverseSingleElement)
|
||||
{
|
||||
using Seq = Sequence<42>;
|
||||
using Result = decltype(Seq::Reverse());
|
||||
using Expected = Sequence<42>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test ReorderGivenNew2Old
|
||||
TEST(Sequence, ReorderGivenNew2Old)
|
||||
{
|
||||
using Seq = Sequence<10, 20, 30, 40>;
|
||||
using Result = decltype(Seq::ReorderGivenNew2Old(Sequence<3, 1, 2, 0>{}));
|
||||
using Expected = Sequence<40, 20, 30, 10>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test ReorderGivenOld2New
|
||||
TEST(Sequence, ReorderGivenOld2New)
|
||||
{
|
||||
using Seq = Sequence<10, 20, 30, 40>;
|
||||
using Result = decltype(Seq::ReorderGivenOld2New(Sequence<3, 1, 2, 0>{}));
|
||||
using Expected = Sequence<40, 20, 30, 10>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test arithmetic_sequence_gen
|
||||
TEST(SequenceGen, ArithmeticSequence)
|
||||
{
|
||||
using Result = typename arithmetic_sequence_gen<0, 5, 1>::type;
|
||||
using Expected = Sequence<0, 1, 2, 3, 4>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceGen, ArithmeticSequenceWithIncrement)
|
||||
{
|
||||
using Result = typename arithmetic_sequence_gen<0, 10, 2>::type;
|
||||
using Expected = Sequence<0, 2, 4, 6, 8>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceGen, ArithmeticSequenceNegativeIncrement)
|
||||
{
|
||||
using Result = typename arithmetic_sequence_gen<10, 5, -1>::type;
|
||||
using Expected = Sequence<10, 9, 8, 7, 6>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceGen, ArithmeticSequenceEmpty)
|
||||
{
|
||||
using Result = typename arithmetic_sequence_gen<5, 5, 1>::type;
|
||||
using Expected = Sequence<>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test uniform_sequence_gen
|
||||
TEST(SequenceGen, UniformSequence)
|
||||
{
|
||||
using Result = typename uniform_sequence_gen<5, 42>::type;
|
||||
using Expected = Sequence<42, 42, 42, 42, 42>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceGen, UniformSequenceZeroSize)
|
||||
{
|
||||
using Result = typename uniform_sequence_gen<0, 42>::type;
|
||||
using Expected = Sequence<>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test make_index_sequence
|
||||
TEST(SequenceGen, MakeIndexSequence)
|
||||
{
|
||||
using Result = make_index_sequence<5>;
|
||||
using Expected = Sequence<0, 1, 2, 3, 4>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceGen, MakeIndexSequenceZero)
|
||||
{
|
||||
using Result = make_index_sequence<0>;
|
||||
using Expected = Sequence<>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test sequence_merge
|
||||
TEST(SequenceMerge, MergeTwoSequences)
|
||||
{
|
||||
using Seq1 = Sequence<1, 2, 3>;
|
||||
using Seq2 = Sequence<4, 5, 6>;
|
||||
using Result = typename sequence_merge<Seq1, Seq2>::type;
|
||||
using Expected = Sequence<1, 2, 3, 4, 5, 6>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceMerge, MergeMultipleSequences)
|
||||
{
|
||||
using Seq1 = Sequence<1, 2>;
|
||||
using Seq2 = Sequence<3, 4>;
|
||||
using Seq3 = Sequence<5, 6>;
|
||||
using Result = typename sequence_merge<Seq1, Seq2, Seq3>::type;
|
||||
using Expected = Sequence<1, 2, 3, 4, 5, 6>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceMerge, MergeSingleSequence)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3>;
|
||||
using Result = typename sequence_merge<Seq>::type;
|
||||
using Expected = Sequence<1, 2, 3>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test sequence_split
|
||||
TEST(SequenceSplit, SplitInMiddle)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3, 4, 5, 6>;
|
||||
using Split = sequence_split<Seq, 3>;
|
||||
using ExpectedLeft = Sequence<1, 2, 3>;
|
||||
using ExpectedRight = Sequence<4, 5, 6>;
|
||||
EXPECT_TRUE((is_same<typename Split::left_type, ExpectedLeft>::value));
|
||||
EXPECT_TRUE((is_same<typename Split::right_type, ExpectedRight>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceSplit, SplitAtBeginning)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3, 4>;
|
||||
using Split = sequence_split<Seq, 0>;
|
||||
using ExpectedLeft = Sequence<>;
|
||||
using ExpectedRight = Sequence<1, 2, 3, 4>;
|
||||
EXPECT_TRUE((is_same<typename Split::left_type, ExpectedLeft>::value));
|
||||
EXPECT_TRUE((is_same<typename Split::right_type, ExpectedRight>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceSplit, SplitAtEnd)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3, 4>;
|
||||
using Split = sequence_split<Seq, 4>;
|
||||
using ExpectedLeft = Sequence<1, 2, 3, 4>;
|
||||
using ExpectedRight = Sequence<>;
|
||||
EXPECT_TRUE((is_same<typename Split::left_type, ExpectedLeft>::value));
|
||||
EXPECT_TRUE((is_same<typename Split::right_type, ExpectedRight>::value));
|
||||
}
|
||||
|
||||
// Test sequence_sort
|
||||
TEST(SequenceSort, SortAscending)
|
||||
{
|
||||
using Seq = Sequence<5, 2, 8, 1, 9>;
|
||||
using Result = typename sequence_sort<Seq, math::less<index_t>>::type;
|
||||
using Expected = Sequence<1, 2, 5, 8, 9>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceSort, SortDescending)
|
||||
{
|
||||
// Create a greater-than comparator
|
||||
struct greater
|
||||
{
|
||||
__host__ __device__ constexpr bool operator()(index_t x, index_t y) const { return x > y; }
|
||||
};
|
||||
using Seq = Sequence<5, 2, 8, 1, 9>;
|
||||
using Result = typename sequence_sort<Seq, greater>::type;
|
||||
using Expected = Sequence<9, 8, 5, 2, 1>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceSort, SortAlreadySorted)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3, 4, 5>;
|
||||
using Result = typename sequence_sort<Seq, math::less<index_t>>::type;
|
||||
using Expected = Sequence<1, 2, 3, 4, 5>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceSort, SortWithDuplicates)
|
||||
{
|
||||
using Seq = Sequence<3, 1, 4, 1, 5, 9, 2, 6, 5>;
|
||||
using Result = typename sequence_sort<Seq, math::less<index_t>>::type;
|
||||
using Expected = Sequence<1, 1, 2, 3, 4, 5, 5, 6, 9>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceSort, SortEmptySequence)
|
||||
{
|
||||
using Seq = Sequence<>;
|
||||
using Result = typename sequence_sort<Seq, math::less<index_t>>::type;
|
||||
using Expected = Sequence<>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceSort, SortSingleElement)
|
||||
{
|
||||
using Seq = Sequence<42>;
|
||||
using Result = typename sequence_sort<Seq, math::less<index_t>>::type;
|
||||
using Expected = Sequence<42>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test sequence_unique_sort
|
||||
TEST(SequenceUniqueSort, UniqueSort)
|
||||
{
|
||||
using Seq = Sequence<3, 1, 4, 1, 5, 9, 2, 6, 5>;
|
||||
using Result =
|
||||
typename sequence_unique_sort<Seq, math::less<index_t>, math::equal<index_t>>::type;
|
||||
using Expected = Sequence<1, 2, 3, 4, 5, 6, 9>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceUniqueSort, UniqueSortNoDuplicates)
|
||||
{
|
||||
using Seq = Sequence<5, 2, 8, 1, 9>;
|
||||
using Result =
|
||||
typename sequence_unique_sort<Seq, math::less<index_t>, math::equal<index_t>>::type;
|
||||
using Expected = Sequence<1, 2, 5, 8, 9>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceUniqueSort, UniqueSortAllSame)
|
||||
{
|
||||
using Seq = Sequence<5, 5, 5, 5>;
|
||||
using Result =
|
||||
typename sequence_unique_sort<Seq, math::less<index_t>, math::equal<index_t>>::type;
|
||||
using Expected = Sequence<5>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test is_valid_sequence_map
|
||||
TEST(SequenceMap, ValidMap)
|
||||
{
|
||||
using Map = Sequence<0, 1, 2, 3>;
|
||||
EXPECT_TRUE((is_valid_sequence_map<Map>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceMap, ValidMapPermuted)
|
||||
{
|
||||
using Map = Sequence<2, 0, 3, 1>;
|
||||
EXPECT_TRUE((is_valid_sequence_map<Map>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceMap, InvalidMapDuplicate)
|
||||
{
|
||||
using Map = Sequence<0, 1, 1, 3>;
|
||||
EXPECT_FALSE((is_valid_sequence_map<Map>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceMap, InvalidMapMissing)
|
||||
{
|
||||
using Map = Sequence<0, 1, 3, 4>;
|
||||
EXPECT_FALSE((is_valid_sequence_map<Map>::value));
|
||||
}
|
||||
|
||||
// Test sequence_map_inverse
|
||||
// Note: sequence_map_inverse inverts a mapping where Map[i] = j means old position i maps to new
|
||||
// position j The inverse gives us new position i came from old position inverse[i]
|
||||
TEST(SequenceMapInverse, InverseMap)
|
||||
{
|
||||
// Map = <2, 0, 3, 1> means: old[0]->new[2], old[1]->new[0], old[2]->new[3], old[3]->new[1]
|
||||
// Inverse should be: new[0]<-old[1], new[1]<-old[3], new[2]<-old[0], new[3]<-old[2]
|
||||
using Map = Sequence<2, 0, 3, 1>;
|
||||
using Result = typename sequence_map_inverse<Map>::type;
|
||||
// Verify by checking that Map[Result[i]] == i for all i
|
||||
EXPECT_EQ((Map::At(Number<Result::At(Number<0>{})>{}) == 0), true);
|
||||
EXPECT_EQ((Map::At(Number<Result::At(Number<1>{})>{}) == 1), true);
|
||||
EXPECT_EQ((Map::At(Number<Result::At(Number<2>{})>{}) == 2), true);
|
||||
EXPECT_EQ((Map::At(Number<Result::At(Number<3>{})>{}) == 3), true);
|
||||
}
|
||||
|
||||
TEST(SequenceMapInverse, InverseIdentityMap)
|
||||
{
|
||||
using Map = Sequence<0, 1, 2, 3>;
|
||||
using Result = typename sequence_map_inverse<Map>::type;
|
||||
// Verify by checking that Map[Result[i]] == i for all i (same as the other test)
|
||||
EXPECT_EQ((Map::At(Number<Result::At(Number<0>{})>{}) == 0), true);
|
||||
EXPECT_EQ((Map::At(Number<Result::At(Number<1>{})>{}) == 1), true);
|
||||
EXPECT_EQ((Map::At(Number<Result::At(Number<2>{})>{}) == 2), true);
|
||||
EXPECT_EQ((Map::At(Number<Result::At(Number<3>{})>{}) == 3), true);
|
||||
}
|
||||
|
||||
// Test sequence operators
|
||||
TEST(SequenceOperators, Equality)
|
||||
{
|
||||
constexpr auto seq1 = Sequence<1, 2, 3>{};
|
||||
constexpr auto seq2 = Sequence<1, 2, 3>{};
|
||||
constexpr auto seq3 = Sequence<1, 2, 4>{};
|
||||
EXPECT_TRUE(seq1 == seq2);
|
||||
EXPECT_FALSE(seq1 == seq3);
|
||||
}
|
||||
|
||||
TEST(SequenceOperators, Addition)
|
||||
{
|
||||
using Seq1 = Sequence<1, 2, 3>;
|
||||
using Seq2 = Sequence<4, 5, 6>;
|
||||
using Result = decltype(Seq1{} + Seq2{});
|
||||
using Expected = Sequence<5, 7, 9>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceOperators, Subtraction)
|
||||
{
|
||||
using Seq1 = Sequence<10, 20, 30>;
|
||||
using Seq2 = Sequence<1, 2, 3>;
|
||||
using Result = decltype(Seq1{} - Seq2{});
|
||||
using Expected = Sequence<9, 18, 27>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceOperators, Multiplication)
|
||||
{
|
||||
using Seq1 = Sequence<2, 3, 4>;
|
||||
using Seq2 = Sequence<5, 6, 7>;
|
||||
using Result = decltype(Seq1{} * Seq2{});
|
||||
using Expected = Sequence<10, 18, 28>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceOperators, Division)
|
||||
{
|
||||
using Seq1 = Sequence<10, 20, 30>;
|
||||
using Seq2 = Sequence<2, 4, 5>;
|
||||
using Result = decltype(Seq1{} / Seq2{});
|
||||
using Expected = Sequence<5, 5, 6>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceOperators, Modulo)
|
||||
{
|
||||
using Seq1 = Sequence<10, 20, 30>;
|
||||
using Seq2 = Sequence<3, 7, 8>;
|
||||
using Result = decltype(Seq1{} % Seq2{});
|
||||
using Expected = Sequence<1, 6, 6>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceOperators, AdditionWithNumber)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3>;
|
||||
using Result = decltype(Seq{} + Number<10>{});
|
||||
using Expected = Sequence<11, 12, 13>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceOperators, SubtractionWithNumber)
|
||||
{
|
||||
using Seq = Sequence<10, 20, 30>;
|
||||
using Result = decltype(Seq{} - Number<5>{});
|
||||
using Expected = Sequence<5, 15, 25>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceOperators, MultiplicationWithNumber)
|
||||
{
|
||||
using Seq = Sequence<2, 3, 4>;
|
||||
using Result = decltype(Seq{} * Number<3>{});
|
||||
using Expected = Sequence<6, 9, 12>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceOperators, DivisionWithNumber)
|
||||
{
|
||||
using Seq = Sequence<10, 20, 30>;
|
||||
using Result = decltype(Seq{} / Number<5>{});
|
||||
using Expected = Sequence<2, 4, 6>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceOperators, NumberAddition)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3>;
|
||||
using Result = decltype(Number<10>{} + Seq{});
|
||||
using Expected = Sequence<11, 12, 13>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceOperators, NumberMultiplication)
|
||||
{
|
||||
using Seq = Sequence<2, 3, 4>;
|
||||
using Result = decltype(Number<3>{} * Seq{});
|
||||
using Expected = Sequence<6, 9, 12>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test helper functions
|
||||
TEST(SequenceHelpers, MergeSequences)
|
||||
{
|
||||
using Seq1 = Sequence<1, 2>;
|
||||
using Seq2 = Sequence<3, 4>;
|
||||
using Seq3 = Sequence<5, 6>;
|
||||
using Result = decltype(merge_sequences(Seq1{}, Seq2{}, Seq3{}));
|
||||
using Expected = Sequence<1, 2, 3, 4, 5, 6>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceHelpers, TransformSequencesSingle)
|
||||
{
|
||||
auto double_it = [](auto x) { return 2 * x; };
|
||||
using Seq = Sequence<1, 2, 3>;
|
||||
using Result = decltype(transform_sequences(double_it, Seq{}));
|
||||
using Expected = Sequence<2, 4, 6>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceHelpers, TransformSequencesTwo)
|
||||
{
|
||||
auto add = [](auto x, auto y) { return x + y; };
|
||||
using Seq1 = Sequence<1, 2, 3>;
|
||||
using Seq2 = Sequence<4, 5, 6>;
|
||||
using Result = decltype(transform_sequences(add, Seq1{}, Seq2{}));
|
||||
using Expected = Sequence<5, 7, 9>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceHelpers, TransformSequencesThree)
|
||||
{
|
||||
auto add3 = [](auto x, auto y, auto z) { return x + y + z; };
|
||||
using Seq1 = Sequence<1, 2, 3>;
|
||||
using Seq2 = Sequence<4, 5, 6>;
|
||||
using Seq3 = Sequence<7, 8, 9>;
|
||||
using Result = decltype(transform_sequences(add3, Seq1{}, Seq2{}, Seq3{}));
|
||||
using Expected = Sequence<12, 15, 18>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceHelpers, ReduceOnSequence)
|
||||
{
|
||||
auto add = [](auto x, auto y) { return x + y; };
|
||||
constexpr auto seq = Sequence<1, 2, 3, 4, 5>{};
|
||||
constexpr auto result = reduce_on_sequence(seq, add, Number<0>{});
|
||||
EXPECT_EQ(result, 15);
|
||||
}
|
||||
|
||||
TEST(SequenceHelpers, SequenceAnyOf)
|
||||
{
|
||||
auto is_even = [](auto x) { return x % 2 == 0; };
|
||||
constexpr auto seq1 = Sequence<1, 3, 5, 7>{};
|
||||
constexpr auto seq2 = Sequence<1, 3, 4, 7>{};
|
||||
EXPECT_FALSE(sequence_any_of(seq1, is_even));
|
||||
EXPECT_TRUE(sequence_any_of(seq2, is_even));
|
||||
}
|
||||
|
||||
TEST(SequenceHelpers, SequenceAllOf)
|
||||
{
|
||||
auto is_positive = [](auto x) { return x > 0; };
|
||||
constexpr auto seq1 = Sequence<1, 2, 3, 4>{};
|
||||
constexpr auto seq2 = Sequence<1, -2, 3, 4>{};
|
||||
EXPECT_TRUE(sequence_all_of(seq1, is_positive));
|
||||
EXPECT_FALSE(sequence_all_of(seq2, is_positive));
|
||||
}
|
||||
|
||||
// Test scan operations
|
||||
TEST(SequenceScan, ReverseInclusiveScan)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3, 4>;
|
||||
using Result =
|
||||
decltype(reverse_inclusive_scan_sequence(Seq{}, math::plus<index_t>{}, Number<0>{}));
|
||||
using Expected = Sequence<10, 9, 7, 4>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceScan, ReverseExclusiveScan)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3, 4>;
|
||||
using Result =
|
||||
decltype(reverse_exclusive_scan_sequence(Seq{}, math::plus<index_t>{}, Number<0>{}));
|
||||
using Expected = Sequence<9, 7, 4, 0>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceScan, InclusiveScan)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3, 4>;
|
||||
using Result = decltype(inclusive_scan_sequence(Seq{}, math::plus<index_t>{}, Number<0>{}));
|
||||
using Expected = Sequence<1, 3, 6, 10>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test pick and modify operations
|
||||
TEST(SequencePickModify, PickElementsByIds)
|
||||
{
|
||||
using Seq = Sequence<10, 20, 30, 40, 50>;
|
||||
using Ids = Sequence<0, 2, 4>;
|
||||
using Result = decltype(pick_sequence_elements_by_ids(Seq{}, Ids{}));
|
||||
using Expected = Sequence<10, 30, 50>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequencePickModify, PickElementsByMask)
|
||||
{
|
||||
using Seq = Sequence<10, 20, 30, 40, 50>;
|
||||
using Mask = Sequence<1, 0, 1, 0, 1>;
|
||||
using Result = decltype(pick_sequence_elements_by_mask(Seq{}, Mask{}));
|
||||
using Expected = Sequence<10, 30, 50>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequencePickModify, ModifyElementsByIds)
|
||||
{
|
||||
using Seq = Sequence<10, 20, 30, 40, 50>;
|
||||
using Values = Sequence<99, 88>;
|
||||
using Ids = Sequence<1, 3>;
|
||||
using Result = decltype(modify_sequence_elements_by_ids(Seq{}, Values{}, Ids{}));
|
||||
using Expected = Sequence<10, 99, 30, 88, 50>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test sequence_reduce
|
||||
TEST(SequenceReduce, ReduceTwoSequences)
|
||||
{
|
||||
using Seq1 = Sequence<1, 2, 3>;
|
||||
using Seq2 = Sequence<4, 5, 6>;
|
||||
using Result = typename sequence_reduce<math::plus<index_t>, Seq1, Seq2>::type;
|
||||
using Expected = Sequence<5, 7, 9>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceReduce, ReduceMultipleSequences)
|
||||
{
|
||||
using Seq1 = Sequence<1, 2>;
|
||||
using Seq2 = Sequence<3, 4>;
|
||||
using Seq3 = Sequence<5, 6>;
|
||||
using Result = typename sequence_reduce<math::plus<index_t>, Seq1, Seq2, Seq3>::type;
|
||||
using Expected = Sequence<9, 12>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
Reference in New Issue
Block a user