From df73287b820c5eb801480a1e6b957b8c717d35b8 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 17 May 2019 14:56:39 -0500 Subject: [PATCH] rework sequence --- src/include/ConstantTensorDescriptor.hip.hpp | 69 +++++------ src/include/Sequence.hip.hpp | 113 ++++++------------- 2 files changed, 62 insertions(+), 120 deletions(-) diff --git a/src/include/ConstantTensorDescriptor.hip.hpp b/src/include/ConstantTensorDescriptor.hip.hpp index d61632a389..06223f8cc8 100644 --- a/src/include/ConstantTensorDescriptor.hip.hpp +++ b/src/include/ConstantTensorDescriptor.hip.hpp @@ -57,7 +57,7 @@ __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence struct ConstantTensorDescriptor { - using Type = ConstantTensorDescriptor; + using Type = ConstantTensorDescriptor; static constexpr index_t nDim = Lengths::GetSize(); __host__ __device__ constexpr ConstantTensorDescriptor() @@ -195,19 +195,14 @@ struct ConstantTensorDescriptor Number{} * reverse_scan_sequence(fold_intervals.PushBack(Number<1>{}), std::multiplies{}); - // left and right lengths - constexpr auto lengths_pair = GetLengths().Split(Number{}); - constexpr auto left_lengths = lengths_pair.first; - constexpr auto right_lengths = lengths_pair.second.PopFront(); - - // left and right strides - constexpr auto strides_pair = GetStrides().Split(Number{}); - constexpr auto left_strides = strides_pair.first; - constexpr auto right_strides = strides_pair.second.PopFront(); + // left and right + constexpr auto left = make_increasing_sequence(Number<0>{}, Number{}, Number<1>{}); + constexpr auto right = make_increasing_sequence( + Number{}, Number{}, Number<1>{}); return make_ConstantTensorDescriptor( - left_lengths.Append(fold_lengths).Append(right_lengths), - left_strides.Append(fold_strides).Append(right_strides)); + GetLengths().Extract(left).Append(fold_lengths).Append(GetLengths().Extract(right)), + GetStrides().Extract(left).Append(fold_strides).Append(GetStrides().Extract(right))); } template @@ -228,40 +223,28 @@ struct ConstantTensorDescriptor "wrong! dimensions to be unfolded need to be packed"); }); - // lengths - constexpr auto lens_pair1 = Lengths{}.Split(Number{}); + // left and right + constexpr auto left = + make_increasing_sequence(Number<0>{}, Number{}, Number<1>{}); + constexpr auto middle = make_increasing_sequence( + Number{}, Number{}, Number<1>{}); + constexpr auto right = make_increasing_sequence( + Number{}, Number{}, Number<1>{}); - constexpr auto right_lengths = lens_pair1.second; + // length and stride + constexpr index_t unfold_length = accumulate_on_sequence( + GetLengths().Extract(middle), std::multiplies{}, Number<1>{}); - constexpr auto lens_pair2 = lens_pair1.first.Split(Number{}); + constexpr index_t unfold_stride = GetStride(Number{}); - constexpr auto left_lengths = lens_pair2.first; - - constexpr auto fold_lengths = lens_pair2.second; - - constexpr index_t unfold_length = - accumulate_on_sequence(fold_lengths, std::multiplies{}, Number<1>{}); - - constexpr auto new_lengths = - left_lengths.PopBack(Number{}).Append(right_lengths); - - // strides - constexpr auto strides_pair1 = Strides{}.Split(Number{}); - - constexpr auto right_strides = strides_pair1.second; - - constexpr auto strides_pair2 = strides_pair1.first.Split(Number{}); - - constexpr auto left_strides = strides_pair2.first; - - constexpr auto fold_strides = strides_pair2.second; - - constexpr index_t unfold_stride = fold_strides.Back(); - - constexpr auto new_strides = - left_strides.PushBack(Number{}).Append(right_strides); - - return make_ConstantTensorDescriptor(new_lengths, new_strides); + return make_ConstantTensorDescriptor(GetLengths() + .Extract(left) + .PushBack(Number{}) + .Append(GetLengths().Extract(right)), + GetStrides() + .Extract(left) + .PushBack(Number{}) + .Append(GetStrides().Extract(right))); } template diff --git a/src/include/Sequence.hip.hpp b/src/include/Sequence.hip.hpp index ad9010fc0f..ae91b2fa29 100644 --- a/src/include/Sequence.hip.hpp +++ b/src/include/Sequence.hip.hpp @@ -2,12 +2,10 @@ #include "constant_integral.hip.hpp" #include "functional.hip.hpp" -struct EmptySequence; - template struct Sequence { - using Type = Sequence; + using Type = Sequence; static constexpr index_t mSize = sizeof...(Is); @@ -72,101 +70,62 @@ struct Sequence return Sequence{}; } - __host__ __device__ constexpr auto Append(EmptySequence) const; - template __host__ __device__ constexpr auto Extract(Number...) const { return Sequence{})...>{}; } - template - struct split_impl + template + __host__ __device__ constexpr auto Extract(Sequence) const { - template - __host__ __device__ constexpr auto operator()(FirstSeq, SecondSeq) const - { - constexpr index_t new_first = FirstSeq{}.PushBack(Number{}); - constexpr index_t new_second = SecondSeq{}.PopFront(); - - static_if<(N > 0)>{}([&](auto fwd) { - return split_impl{}(new_first, fwd(new_second)); - }).else_([&](auto fwd) { return std::make_pair(new_first, fwd(new_second)); }); - } - }; - - // split one sequence to two sequnces: [0, I) and [I, mSize) - // return type is std::pair - template - __host__ __device__ constexpr auto Split(Number) const; - - template - __host__ __device__ constexpr auto Modify(Number, Number) const - { - constexpr auto first_second = Split(Number{}); - - constexpr auto left = first_second.first; - constexpr auto right = first_second.second.PopFront(); - - return left.PushBack(Number{}).Append(right); + return Sequence{})...>{}; } }; -struct EmptySequence +template +struct sequence_merge; + +template +struct sequence_merge, Sequence> { - __host__ __device__ static constexpr index_t GetSize() { return 0; } - - template - __host__ __device__ constexpr auto PushFront(Number) const - { - return Sequence{}; - } - - template - __host__ __device__ constexpr auto PushBack(Number) const - { - return Sequence{}; - } - - template - __host__ __device__ constexpr Seq Append(Seq) const - { - return Seq{}; - } + using Type = Sequence; }; -template -__host__ __device__ constexpr auto Sequence::Append(EmptySequence) const +template +struct increasing_sequence_gen { - return Type{}; -} + static constexpr index_t NSizeLeft = NSize / 2; -// split one sequence to two sequnces: [0, I) and [I, mSize) -// return type is std::pair -template -template -__host__ __device__ constexpr auto Sequence::Split(Number) const + using Type = + sequence_merge::Type, + typename increasing_sequence_gen::Type>; +}; + +template +struct increasing_sequence_gen { - static_assert(I <= GetSize(), "wrong! split position is too high!"); + using Type = Sequence; +}; - static_if<(I == 0)>{}([&](auto fwd) { return std::make_pair(EmptySequence{}, fwd(Type{})); }); +template +struct increasing_sequence_gen +{ + using Type = Sequence<>; +}; - static_if<(I == GetSize())>{}( - [&](auto fwd) { return std::make_pair(Type{}, fwd(EmptySequence{})); }); - - static_if<(I > 0 && I < GetSize())>{}( - [&](auto fwd) { return split_impl{}(EmptySequence{}, fwd(Type{})); }); -} - -#if 0 template -__host__ __device__ auto make_increasing_sequence(Number, Number, Number) +__host__ __device__ constexpr auto + make_increasing_sequence(Number, Number, Number) { - static_assert(IBegin < IEnd, (IEnd - IBegin) % Increment == 0, "wrong!"); + static_assert(IBegin <= IEnd && Increment > 0, "wrong!"); - // not implemented + constexpr index_t NSize = (IEnd - IBegin) / Increment; + + return increasing_sequence_gen{}; } -#endif template __host__ __device__ constexpr auto operator+(Sequence, Sequence) @@ -222,7 +181,7 @@ __host__ __device__ constexpr auto operator-(Sequence, Number) { constexpr auto seq_x = Sequence{}; -#if 0 +#if 0 // doesn't compile static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) { constexpr auto I = decltype(Iter){}; static_assert(seq_x.Get(I) >= Y, "wrong! going to underflow");