diff --git a/include/ck/utility/array.hpp b/include/ck/utility/array.hpp index cfda7ee0c5..1cec9b4a77 100644 --- a/include/ck/utility/array.hpp +++ b/include/ck/utility/array.hpp @@ -4,8 +4,6 @@ #ifndef CK_ARRAY_HPP #define CK_ARRAY_HPP -#include "functional2.hpp" -#include "sequence.hpp" #include #include #include "type.hpp" @@ -44,8 +42,7 @@ struct Array { static_assert(T::Size() == Size(), "wrong! size not the same"); -#pragma unroll - for(int i = 0; i < NSize; i++) + for(index_t i = 0; i < NSize; i++) { mData[i] = a[i]; } diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 49a2493d89..ecd8c6ddee 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -375,6 +375,19 @@ struct sequence_reverse_inclusive_scan_impl; template struct sequence_reverse_inclusive_scan_impl, Reduce, Init> { + template + static constexpr Array compute_array() + { + Array values = {Is...}; + Array result = {0}; + result.At(Size - 1) = Reduce{}(values[Size - 1], Init); + for(index_t i = Size - 1; i > 0; --i) + { + result.At(i - 1) = Reduce{}(values[i - 1], result[i]); + } + return result; + } + template static constexpr auto compute(Sequence) { @@ -384,18 +397,21 @@ struct sequence_reverse_inclusive_scan_impl, Reduce, Init> { return Sequence<>{}; } + else if constexpr(size == 1) + { + constexpr index_t values[1] = {Is...}; + return Sequence{}; + } + else if constexpr(size == 2) + { + constexpr index_t values[2] = {Is...}; + constexpr index_t r1 = Reduce{}(values[1], Init); + constexpr index_t r0 = Reduce{}(values[0], r1); + return Sequence{}; + } else { - constexpr Array arr = []() { - Array values = {Is...}; - Array result = {0}; - result.At(size - 1) = Reduce{}(values[size - 1], Init); - for(index_t i = size - 1; i > 0; --i) - { - result.At(i - 1) = Reduce{}(values[i - 1], result[i]); - } - return result; - }(); + constexpr Array arr = compute_array(); return Sequence{}; } } diff --git a/test/util/unit_array.cpp b/test/util/unit_array.cpp index 436ac9bf0c..b30ff09872 100644 --- a/test/util/unit_array.cpp +++ b/test/util/unit_array.cpp @@ -6,7 +6,7 @@ using namespace ck; -// Test basic Sequence construction and properties +// Test basic Array construction and properties TEST(Array, BasicConstruction) { using Arr = Array;