refactor ConstantTensorDescriptor and functional

This commit is contained in:
Chao Liu
2019-04-16 17:36:18 -05:00
parent a2cf803c7e
commit 17f3d2d4bc
22 changed files with 390 additions and 276 deletions

View File

@@ -1,26 +1,41 @@
#pragma once
#include "constant_integral.hip.hpp"
template <index_t NLoop>
struct static_loop_n
template <index_t Iter, index_t Remaining, index_t Increment>
struct static_for_impl
{
template <class F>
__host__ __device__ void operator()(F f) const
{
static_assert(NLoop > 1, "out-of-range");
static_assert(Remaining % Increment == 0, "wrong! Remaining % Increment != 0");
static_assert(Increment <= Remaining, "will go out-of-range");
f(Number<NLoop - 1>{});
static_loop_n<NLoop - 1>{}(f);
f(Number<Iter>{});
static_for_impl<Iter + Increment, Remaining - Increment, Increment>{}(f);
}
};
template <>
struct static_loop_n<1>
template <index_t Iter, index_t Increment>
struct static_for_impl<Iter, 0, Increment>
{
template <class F>
__host__ __device__ void operator()(F) const
{
// do nothing
return;
}
};
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
{
template <class F>
__host__ __device__ void operator()(F f) const
{
f(Number<0>{});
static_assert(NBegin < NEnd, "Wrong! we should have NBegin < NEnd");
static_assert((NEnd - NBegin) % Increment == 0,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
static_for_impl<NBegin, NEnd - NBegin, Increment>{}(f);
}
};
@@ -54,4 +69,19 @@ __host__ __device__ constexpr auto unpacker(F f)
{
return [=](auto xs_array){ f(xs...); };
}
#endif
#endif
namespace mod_conv {
template <class T>
struct multiplies
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
};
template <class T>
struct plus
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
};
} // namespace mod_conv