use more constexpr

[ROCm/composable_kernel commit: 709f13a6d7]
This commit is contained in:
Chao Liu
2019-06-04 20:00:48 -05:00
parent f11222a3ac
commit c26ccd3d07
11 changed files with 310 additions and 183 deletions

View File

@@ -1,6 +1,6 @@
#pragma once
#include "Sequence.hip.hpp"
#include "functional.hip.hpp"
#include "functional2.hip.hpp"
template <class TData, index_t NSize>
struct Array
@@ -25,14 +25,17 @@ struct Array
template <index_t I>
__host__ __device__ constexpr TData Get(Number<I>) const
{
static_assert(I < NSize, "wrong!");
return mData[I];
}
template <index_t I>
__host__ __device__ constexpr bool Set(Number<I>, TData x)
__host__ __device__ constexpr void Set(Number<I>, TData x)
{
static_assert(I < NSize, "wrong!");
mData[I] = x;
return true; // for constexpr
}
__host__ __device__ constexpr auto PushBack(TData x) const
@@ -59,6 +62,7 @@ __host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
template <class TData, index_t NSize>
__host__ __device__ constexpr auto make_zero_array()
{
#if 0
Array<TData, NSize> a;
static_for<0, NSize, 1>{}([&](auto I) {
@@ -67,6 +71,11 @@ __host__ __device__ constexpr auto make_zero_array()
});
return a;
#else
constexpr auto zero_sequence = typename uniform_sequence_gen<NSize, 0>::SeqType{};
constexpr auto zero_array = sequence2array(zero_sequence);
return zero_array;
#endif
}
template <class TData, index_t NSize, index_t... IRs>
@@ -85,6 +94,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData
return new_array;
}
#if 0
template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
Sequence<IRs...> old2new)
@@ -100,6 +110,45 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData
return new_array;
}
#else
template <class TData, index_t NSize, class MapOld2New>
struct reorder_array_given_old2new_impl
{
const Array<TData, NSize>& old_array_ref;
Array<TData, NSize>& new_array_ref;
__host__
__device__ constexpr reorder_array_given_old2new_impl(const Array<TData, NSize>& old_array,
Array<TData, NSize>& new_array)
: old_array_ref(old_array), new_array_ref(new_array)
{
}
template <index_t IOldDim>
__host__ __device__ constexpr void operator()(Number<IOldDim>) const
{
TData old_data = old_array_ref.Get(Number<IOldDim>{});
constexpr index_t INewDim = MapOld2New::Get(Number<IOldDim>{});
new_array_ref.Set(Number<INewDim>{}, old_data);
}
};
template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
Sequence<IRs...> old2new)
{
Array<TData, NSize> new_array;
static_assert(NSize == sizeof...(IRs), "NSize not consistent");
static_for<0, NSize, 1>{}(
reorder_array_given_old2new_impl<TData, NSize, Sequence<IRs...>>(old_array, new_array));
return new_array;
}
#endif
template <class TData, index_t NSize, class ExtractSeq>
__host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq)