use more constexpr for Array

[ROCm/composable_kernel commit: 0a386c46a9]
This commit is contained in:
Chao Liu
2019-06-06 19:26:08 -05:00
parent 07ad21131b
commit d141b32528
7 changed files with 228 additions and 202 deletions

View File

@@ -34,14 +34,6 @@ struct Array
// Mutable element access: hands back a reference to the i-th entry of mData so the
// caller can assign through it. The index is not range-checked here.
__host__ __device__ TData& operator()(index_t i)
{
    return mData[i];
}
// Read access by compile-time index: returns a copy of element I of mData.
// The index is carried by the Number<I> tag and is checked against NSize at
// compile time, so an out-of-range I is rejected by the static_assert.
template <index_t I>
__host__ __device__ constexpr TData Get(Number<I>) const
{
    static_assert(I < NSize, "wrong!");

    return this->mData[I];
}
template <index_t I>
__host__ __device__ constexpr void Set(Number<I>, TData x)
{
@@ -50,16 +42,33 @@ struct Array
mData[I] = x;
}
// Stores x into slot I of mData, where I is an ordinary (run-time capable) index.
// Unlike the Number<I> overload, no range check is performed on I here.
__host__ __device__ constexpr void Set(index_t I, TData x)
{
    mData[I] = x;
}
// Function object that emulates a constexpr lambda for PushBack: called with a
// compile-time index I, it copies element I of old_array into slot I of new_array,
// which is one element longer.
struct lambda_PushBack
{
    const Array<TData, NSize>& old_array; // array read from (held by reference)
    Array<TData, NSize + 1>& new_array;   // array written to (held by reference)

    // Binds the functor to the array it reads and the array it fills.
    __host__ __device__ constexpr lambda_PushBack(const Array<TData, NSize>& src,
                                                  Array<TData, NSize + 1>& dst)
        : old_array(src), new_array(dst)
    {
    }

    // Copies one element; the tag argument both selects the index and is forwarded
    // to Set() as the destination position.
    template <index_t I>
    __host__ __device__ constexpr void operator()(Number<I> idx) const
    {
        new_array.Set(idx, old_array[I]);
    }
};
// Returns a new array that is one element longer than this one: the first NSize
// entries are copies of this array's entries and entry NSize is x.
// This array itself is not modified.
__host__ __device__ constexpr auto PushBack(TData x) const
{
    Array<TData, NSize + 1> new_array;

    // Copy the existing elements exactly once, through the lambda_PushBack functor.
    // A capturing generic lambda doing the same copy was redundant and, before C++17,
    // cannot be invoked inside a constant expression.
    static_for<0, NSize, 1>{}(lambda_PushBack(*this, new_array));

    // Append the new element through the compile-time-indexed setter.
    new_array.Set(Number<NSize>{}, x);

    return new_array;
}
@@ -81,18 +90,13 @@ __host__ __device__ constexpr auto make_zero_array()
// Returns a reordered copy of old_array using a new-to-old index map: element i of
// the result is old_array[IRs_i], where IRs_i is the i-th value of the Sequence.
// The map is taken purely from the Sequence's template arguments, so the function
// parameter itself is unnamed. Compilation fails if the map length differs from
// NSize or if is_valid_sequence_map rejects the Sequence.
template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
                                                               Sequence<IRs...> /*new2old*/)
{
    static_assert(NSize == sizeof...(IRs), "NSize not consistent");

    static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");

    // Expand the map into one element read per output slot. The reads go through the
    // array's element accessor; subscripting the size constant (mSize) is not an
    // element access and would not compile once this template is instantiated.
    return Array<TData, NSize>{old_array[IRs]...};
}
template <class TData, index_t NSize, class MapOld2New>
@@ -120,12 +124,14 @@ struct lambda_reorder_array_given_old2new
template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
Sequence<IRs...> old2new)
Sequence<IRs...> /*old2new*/)
{
Array<TData, NSize> new_array;
static_assert(NSize == sizeof...(IRs), "NSize not consistent");
static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");
static_for<0, NSize, 1>{}(
lambda_reorder_array_given_old2new<TData, NSize, Sequence<IRs...>>(old_array, new_array));
@@ -141,25 +147,44 @@ __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_
static_assert(new_size <= NSize, "wrong! too many extract");
static_for<0, new_size, 1>{}([&](auto I) {
constexpr index_t i = I.Get();
new_array(i) = old_array[ExtractSeq::Get(I)];
});
static_for<0, new_size, 1>{}([&](auto I) { new_array(I) = old_array[ExtractSeq::Get(I)]; });
return new_array;
}
// Function object that emulates a constexpr lambda for element-wise array math.
// Called with a compile-time index, it evaluates f on the matching elements of x and
// y and stores the result in the same position of z. All four operands are held by
// reference, so they must outlive the functor.
template <class F, class X, class Y, class Z>
struct lambda_array_math
{
    const F& f; // binary operation applied per element
    const X& x; // left-hand operand
    const Y& y; // right-hand operand
    Z& z;       // destination, written through Set()

    __host__ __device__ constexpr lambda_array_math(const F& op, const X& lhs, const Y& rhs, Z& out)
        : f(op), x(lhs), y(rhs), z(out)
    {
    }

    // One step of the element-wise operation for index IDim_.
    template <index_t IDim_>
    __host__ __device__ constexpr void operator()(Number<IDim_>) const
    {
        z.Set(Number<IDim_>{}, f(x[Number<IDim_>{}], y[Number<IDim_>{}]));
    }
};
// Array = Array + Array
// Element-wise sum of two arrays of the same element type and length; both operands
// are taken by value and a new array is returned.
template <class TData, index_t NSize>
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, NSize> b)
{
    Array<TData, NSize> result;

    // The binary operation must be declared at function scope: it is named in the
    // functor's template arguments and passed to its constructor below.
    // NOTE(review): the functor is instantiated for index_t rather than TData; confirm
    // this is intended for arrays whose element type is not index_t.
    auto f = mod_conv::plus<index_t>{};

    // Compute each element exactly once through the lambda_array_math functor (which
    // stands in for a constexpr lambda) instead of also looping with a generic lambda.
    static_for<0, NSize, 1>{}(
        lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
            f, a, b, result));

    return result;
}
@@ -170,11 +195,11 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData,
{
Array<TData, NSize> result;
static_for<0, NSize, 1>{}([&](auto I) {
constexpr index_t i = I.Get();
auto f = mod_conv::minus<index_t>{};
result(i) = a[i] - b[i];
});
static_for<0, NSize, 1>{}(
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
f, a, b, result));
return result;
}
@@ -187,11 +212,11 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is.
Array<TData, NSize> result;
static_for<0, NSize, 1>{}([&](auto I) {
constexpr index_t i = I.Get();
auto f = mod_conv::plus<index_t>{};
result(i) = a[i] + b.Get(I);
});
static_for<0, NSize, 1>{}(
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
f, a, b, result));
return result;
}
@@ -204,11 +229,11 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is.
Array<TData, NSize> result;
static_for<0, NSize, 1>{}([&](auto I) {
constexpr index_t i = I.Get();
auto f = mod_conv::minus<index_t>{};
result(i) = a[i] - b.Get(I);
});
static_for<0, NSize, 1>{}(
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
f, a, b, result));
return result;
}
@@ -221,11 +246,11 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is.
Array<TData, NSize> result;
static_for<0, NSize, 1>{}([&](auto I) {
constexpr index_t i = I.Get();
auto f = mod_conv::multiplies<index_t>{};
result(i) = a[i] * b.Get(I);
});
static_for<0, NSize, 1>{}(
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
f, a, b, result));
return result;
}
@@ -238,11 +263,11 @@ __host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSi
Array<TData, NSize> result;
static_for<0, NSize, 1>{}([&](auto I) {
constexpr index_t i = I.Get();
auto f = mod_conv::minus<index_t>{};
result(i) = a.Get(I) - b[i];
});
static_for<0, NSize, 1>{}(
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
f, a, b, result));
return result;
}
@@ -255,10 +280,7 @@ accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
static_assert(NSize > 0, "wrong");
static_for<0, NSize, 1>{}([&](auto I) {
constexpr index_t i = I.Get();
result = f(result, a[i]);
});
static_for<0, NSize, 1>{}([&](auto I) { result = f(result, a[I]); });
return result;
}