This commit is contained in:
Chao Liu
2019-05-31 01:25:14 -05:00
parent 9ef124cc15
commit 3a6044aa84
6 changed files with 135 additions and 60 deletions

View File

@@ -1,41 +1,33 @@
#pragma once
#include "Sequence.hip.hpp"
template <index_t RemainDim>
// RemainLengths: Sequence<...>
template <class RemainLengths>
struct static_ford_impl
{
// F signature: F(Sequence<...> multi_id)
// CurrentMultiIndex: Sequence<...>
// RemainLengths: Sequence<...>
template <class F, class CurrentMultiIndex, class RemainLengths>
__host__ __device__ void operator()(F f, CurrentMultiIndex, RemainLengths) const
template <class F, class CurrentMultiIndex>
__host__ __device__ void operator()(F f, CurrentMultiIndex) const
{
static_assert(RemainLengths::GetSize() == RemainDim, "wrong!");
static_assert(RemainDim > 1, "wrong!");
static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
constexpr auto next_length = RemainLengths{}.Front();
static_for<0, next_length, 1>{}([=](auto I) {
static_ford_impl<RemainDim - 1>{}(
f, CurrentMultiIndex{}.PushBack(I), RemainLengths{}.PopFront());
static_for<0, RemainLengths::Front(), 1>{}([=](auto I) {
static_ford_impl<decltype(RemainLengths::PopFront())>{}(f,
CurrentMultiIndex::PushBack(I));
});
}
};
template <>
struct static_ford_impl<1>
struct static_ford_impl<Sequence<>>
{
// F signature: F(Sequence<Is...> multi_id)
// F signature: F(Sequence<...> multi_id)
// CurrentMultiIndex: Sequence<...>
// RemainLengths: Sequence<...>
template <class F, class CurrentMultiIndex, class RemainLengths>
__host__ __device__ void operator()(F f, CurrentMultiIndex, RemainLengths) const
template <class F, class CurrentMultiIndex>
__host__ __device__ void operator()(F f, CurrentMultiIndex) const
{
static_assert(RemainLengths::GetSize() == 1, "wrong!");
constexpr index_t last_length = RemainLengths{}.Front();
static_for<0, last_length, 1>{}([=](auto I) { f(CurrentMultiIndex{}.PushBack(I)); });
f(CurrentMultiIndex{});
}
};
@@ -43,16 +35,13 @@ struct static_ford_impl<1>
template <class Lengths>
struct static_ford
{
// F signature: F(Sequence<Is...> multi_id)
// F signature: F(Sequence<...> multi_id)
template <class F>
__host__ __device__ void operator()(F f) const
{
constexpr index_t first_length = Lengths{}.Front();
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
static_for<0, first_length, 1>{}([=](auto I) {
static_ford_impl<Lengths::GetSize() - 1>{}(
f, Sequence<I.Get()>{}, Lengths{}.PopFront());
});
static_ford_impl<Lengths>{}(f, Sequence<>{});
}
};