[ROCm/composable_kernel commit: 7a89684f92]
This commit is contained in:
Chao Liu
2019-06-06 16:50:35 -05:00
parent 92ae4c49bc
commit 5e7dff691c
26 changed files with 299 additions and 517 deletions

View File

@@ -18,9 +18,21 @@ struct Array
// Returns NSize, the compile-time size this array was declared with.
__host__ __device__ constexpr index_t GetSize() const
{
    return NSize;
}
// Compile-time indexed read: returns a copy of element I.
// The index travels in the Number<I> tag, so it is known during compilation;
// the static_assert rejects an out-of-range I at build time instead of letting
// the access read past the underlying storage.
template <index_t I>
__host__ __device__ constexpr TData operator[](Number<I>) const
{
    static_assert(I < NSize, "Array index out of range");
    return mData[I];
}
// Read-only access with a runtime index: returns a copy of element i.
// No bounds check is performed on i.
__host__ __device__ constexpr TData operator[](index_t i) const
{
    return mData[i];
}
// Mutable access with a runtime index: returns a reference to element i so the
// caller can assign through it. No bounds check is performed on i.
__host__ __device__ TData& operator[](index_t i)
{
    return mData[i];
}
// Compile-time indexed mutable access: returns a reference to element I so the
// caller can assign through it.
// The index travels in the Number<I> tag, so it is known during compilation;
// the static_assert rejects an out-of-range I at build time instead of handing
// out a reference past the underlying storage.
template <index_t I>
__host__ __device__ TData& operator()(Number<I>)
{
    static_assert(I < NSize, "Array index out of range");
    return mData[I];
}
// Mutable access with a runtime index, spelled with call syntax: returns a
// reference to element i. No bounds check is performed on i.
__host__ __device__ TData& operator()(index_t i)
{
    return mData[i];
}
template <index_t I>
__host__ __device__ constexpr TData Get(Number<I>) const
@@ -44,10 +56,10 @@ struct Array
static_for<0, NSize, 1>{}([&](auto I) {
constexpr index_t i = I.Get();
new_array[i] = mData[i];
new_array(i) = mData[i];
});
new_array[NSize] = x;
new_array(NSize) = x;
return new_array;
}
@@ -62,20 +74,9 @@ __host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
// Produces an array of NSize zeros by converting an all-zero Sequence into an
// array at compile time.
// NOTE(review): the active branch never mentions TData, so the element type of
// the result is whatever sequence2array yields — confirm callers do not depend
// on getting Array<TData, NSize> back.
template <class TData, index_t NSize>
__host__ __device__ constexpr auto make_zero_array()
{
#if 0
    // Disabled variant: fill an Array<TData, NSize> element by element.
    Array<TData, NSize> zeros;

    static_for<0, NSize, 1>{}([&](auto I) {
        constexpr index_t i = I.Get();
        zeros[i] = static_cast<TData>(0);
    });

    return zeros;
#else
    // Sequence type holding NSize copies of the value 0.
    using ZeroSeq = typename uniform_sequence_gen<NSize, 0>::SeqType;

    constexpr auto zeros = sequence2array(ZeroSeq{});

    return zeros;
#endif
}
template <class TData, index_t NSize, index_t... IRs>
@@ -94,44 +95,26 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData
return new_array;
}
#if 0
// Reorders an array using an old-to-new index map: for every index i of
// old_array, the element is written to slot old2new.Get(i) of the result.
// The map must list exactly NSize entries (checked at compile time).
template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
                                                               Sequence<IRs...> old2new)
{
    static_assert(NSize == sizeof...(IRs), "NSize not consistent");

    Array<TData, NSize> reordered;

    static_for<0, NSize, 1>{}([&](auto IOld) {
        constexpr index_t src = IOld.Get();
        reordered[old2new.Get(IOld)] = old_array[src];
    });

    return reordered;
}
#else
// Functor meant to be invoked once per old index (as Number<IOldDim>): it reads
// element IOldDim of the source array and stores it in the destination array at
// the index MapOld2New::Get returns for IOldDim.
// NOTE(review): the lines below carry two spellings of the same functor —
// reorder_array_given_old2new_impl with *_ref members, and
// lambda_reorder_array_given_old2new with plain member names. This looks like
// the before/after text of a rename shown together; confirm which one is live.
template <class TData, index_t NSize, class MapOld2New>
struct reorder_array_given_old2new_impl
struct lambda_reorder_array_given_old2new
{
// Source (read-only) and destination arrays, both held by reference.
const Array<TData, NSize>& old_array_ref;
Array<TData, NSize>& new_array_ref;
const Array<TData, NSize>& old_array;
Array<TData, NSize>& new_array;
// Binds the functor to the source and destination arrays; nothing is copied.
__host__
__device__ constexpr reorder_array_given_old2new_impl(const Array<TData, NSize>& old_array,
Array<TData, NSize>& new_array)
: old_array_ref(old_array), new_array_ref(new_array)
__host__ __device__ constexpr lambda_reorder_array_given_old2new(
const Array<TData, NSize>& old_array_, Array<TData, NSize>& new_array_)
: old_array(old_array_), new_array(new_array_)
{
}
// Moves one element: old index IOldDim -> new index MapOld2New::Get(IOldDim).
// Declared const, yet it writes through the destination reference member.
template <index_t IOldDim>
__host__ __device__ constexpr void operator()(Number<IOldDim>) const
{
TData old_data = old_array_ref.Get(Number<IOldDim>{});
TData old_data = old_array[IOldDim];
// Destination index is resolved at compile time from the map type.
constexpr index_t INewDim = MapOld2New::Get(Number<IOldDim>{});
new_array_ref.Set(Number<INewDim>{}, old_data);
new_array.Set(Number<INewDim>{}, old_data);
}
};
@@ -144,11 +127,10 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData
static_assert(NSize == sizeof...(IRs), "NSize not consistent");
static_for<0, NSize, 1>{}(
reorder_array_given_old2new_impl<TData, NSize, Sequence<IRs...>>(old_array, new_array));
lambda_reorder_array_given_old2new<TData, NSize, Sequence<IRs...>>(old_array, new_array));
return new_array;
}
#endif
template <class TData, index_t NSize, class ExtractSeq>
__host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq)
@@ -161,7 +143,7 @@ __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_
static_for<0, new_size, 1>{}([&](auto I) {
constexpr index_t i = I.Get();
new_array[i] = old_array[ExtractSeq::Get(I)];
new_array(i) = old_array[ExtractSeq::Get(I)];
});
return new_array;
@@ -176,7 +158,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData,
static_for<0, NSize, 1>{}([&](auto I) {
constexpr index_t i = I.Get();
result[i] = a[i] + b[i];
result(i) = a[i] + b[i];
});
return result;
@@ -191,7 +173,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData,
static_for<0, NSize, 1>{}([&](auto I) {
constexpr index_t i = I.Get();
result[i] = a[i] - b[i];
result(i) = a[i] - b[i];
});
return result;
@@ -208,7 +190,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is.
static_for<0, NSize, 1>{}([&](auto I) {
constexpr index_t i = I.Get();
result[i] = a[i] + b.Get(I);
result(i) = a[i] + b.Get(I);
});
return result;
@@ -225,7 +207,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is.
static_for<0, NSize, 1>{}([&](auto I) {
constexpr index_t i = I.Get();
result[i] = a[i] - b.Get(I);
result(i) = a[i] - b.Get(I);
});
return result;
@@ -242,7 +224,7 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is.
static_for<0, NSize, 1>{}([&](auto I) {
constexpr index_t i = I.Get();
result[i] = a[i] * b.Get(I);
result(i) = a[i] * b.Get(I);
});
return result;
@@ -259,7 +241,7 @@ __host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSi
static_for<0, NSize, 1>{}([&](auto I) {
constexpr index_t i = I.Get();
result[i] = a.Get(I) - b[i];
result(i) = a.Get(I) - b[i];
});
return result;