add looping Orders into ford and static_ford

This commit is contained in:
Chao Liu
2019-08-06 20:23:11 -05:00
parent 0271338ed4
commit 41cdde99e5
2 changed files with 107 additions and 92 deletions

View File

@@ -24,105 +24,120 @@ struct is_static<Sequence<Is...>> : integral_constant<bool, true>
};
// RemainLengths: Sequence<...>
template <class RemainLengths>
// Orders: Sequence<...>
template <class RemainLengths, class Orders>
struct static_ford_impl
{
// F signature: F(Sequence<...> multi_id)
// CurrentMultiIndex: Sequence<...>
template <class F, class CurrentMultiIndex>
__host__ __device__ constexpr void operator()(F f, CurrentMultiIndex) const
__host__ __device__ constexpr static_ford_impl()
{
static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
}
// F signature: F(Sequence<...>)
// CurrentOrderedId: Sequence<...>
template <class F, class CurrentOrderedId>
__host__ __device__ constexpr void operator()(F f, CurrentOrderedId) const
{
static_for<0, RemainLengths::Front(), 1>{}([=](auto I) {
static_ford_impl<decltype(RemainLengths::PopFront())>{}(f,
CurrentMultiIndex::PushBack(I));
static_ford_impl<decltype(RemainLengths::PopFront()), Orders>{}(
f, CurrentOrderedId::PushBack(I));
});
}
};
template <>
struct static_ford_impl<Sequence<>>
template <class Orders>
struct static_ford_impl<Sequence<>, Orders>
{
// F signature: F(Sequence<...> multi_id)
// CurrentMultiIndex: Sequence<...>
template <class F, class CurrentMultiIndex>
__host__ __device__ constexpr void operator()(F f, CurrentMultiIndex) const
// F signature: F(Sequence<...>)
// OrderedId: Sequence<...>
template <class F, class OrderedId>
__host__ __device__ constexpr void operator()(F f, OrderedId) const
{
f(CurrentMultiIndex{});
// retrive unordered Id
f(OrderedId::ReorderGivenOld2New(Orders{}));
}
};
// Lengths is Sequence<...>
template <class Lengths>
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which static_ford will loop over each
// dimension
template <class Lengths,
class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
struct static_ford
{
// F signature: F(Sequence<...> multi_id)
template <class F>
__host__ __device__ constexpr void operator()(F f) const
__host__ __device__ constexpr static_ford()
{
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
static_ford_impl<Lengths>{}(f, Sequence<>{});
static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
}
};
template <index_t RemainDim>
struct ford_impl
{
// F signature: F(Array<...> multi_id)
// CurrentMultiIndex: Array<...>
// RemainLengths: Sequence<...>
template <class F, class CurrentMultiIndex, class RemainLengths>
__host__ __device__ constexpr void
operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
{
static_assert(RemainLengths::GetSize() == RemainDim, "wrong!");
static_assert(RemainDim > 1, "wrong!");
constexpr auto next_length = RemainLengths{}.Front();
for(index_t i = 0; i < next_length; ++i)
{
ford_impl<RemainDim - 1>{}(f, current_multi_id.PushBack(i), RemainLengths{}.PopFront());
}
}
};
template <>
struct ford_impl<1>
{
// F signature: F(Array<...> multi_id)
// CurrentMultiIndex: Array<...>
// RemainLengths: Sequence<...>
template <class F, class CurrentMultiIndex, class RemainLengths>
__host__ __device__ constexpr void
operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
{
static_assert(RemainLengths::GetSize() == 1, "wrong!");
constexpr index_t last_length = RemainLengths{}.Front();
for(index_t i = 0; i < last_length; ++i)
{
f(current_multi_id.PushBack(i));
}
}
};
// Lengths is Sequence<...>
template <class Lengths>
struct ford
{
// F signature: F(Array<...> multi_id)
// F signature: F(Sequence<...> multi_id)
// multi_id is the unordered multi-index
template <class F>
__host__ __device__ constexpr void operator()(F f) const
{
constexpr index_t first_length = Lengths{}.Front();
constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
static_ford_impl<decltype(ordered_lengths), Orders>{}(f, Sequence<>{});
}
};
for(index_t i = 0; i < first_length; ++i)
// RemainLengths: Sequence<...>
// Orders: Sequence<...>
template <class RemainLengths, class Orders>
struct ford_impl
{
__host__ __device__ constexpr ford_impl()
{
static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
}
// F signature: F(Array<...> multi_id)
// CurrentOrderdId: Array<...>
template <class F, class CurrentOrderedId>
__host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
{
for(index_t i = 0; i < RemainLengths::Front(); ++i)
{
ford_impl<Lengths::GetSize() - 1>{}(f, Array<index_t, 1>{i}, Lengths{}.PopFront());
ford_impl<decltype(RemainLengths::PopFront()), Orders>{}(
f, current_ordered_id.PushBack(i));
}
}
};
template <class Orders>
struct ford_impl<Sequence<>, Orders>
{
// F signature: F(Array<...> multi_id)
// CurrentOrderdId: Array<...>
template <class F, class CurrentOrderedId>
__host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
{
// retrive unordered Id
f(reorder_array_given_old2new(current_ordered_id, Orders{}));
}
};
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which ford will loop over each
// dimension
template <class Lengths,
class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
struct ford
{
__host__ __device__ constexpr ford()
{
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
}
// F signature: F(Array<...> multi_id)
// multi_id is the unordered multi-index
template <class F>
__host__ __device__ constexpr void operator()(F f) const
{
for(index_t i = 0; i < Lengths::Front(); ++i)
{
ford_impl<decltype(Lengths::PopFront()), Orders>{}(f, Array<index_t, 1>{i});
}
}
};