mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
use more constexpr for Array
[ROCm/composable_kernel commit: 0a386c46a9]
This commit is contained in:
@@ -34,14 +34,6 @@ struct Array
|
||||
|
||||
// Mutable element access: returns a reference to mData[i].
// No bounds check on i is performed in this function.
__host__ __device__ TData& operator()(index_t i) { return mData[i]; }
|
||||
|
||||
// Compile-time-indexed read: returns a copy of mData[I].
// The index is carried by the Number<I> tag type; the argument value itself is unused.
template <index_t I>
|
||||
__host__ __device__ constexpr TData Get(Number<I>) const
|
||||
{
|
||||
// Rejects an index at or above NSize at compile time (only the upper bound is checked here).
static_assert(I < NSize, "wrong!");
|
||||
|
||||
return mData[I];
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr void Set(Number<I>, TData x)
|
||||
{
|
||||
@@ -50,16 +42,33 @@ struct Array
|
||||
mData[I] = x;
|
||||
}
|
||||
|
||||
// Runtime-indexed write: stores x into mData[I].
// I is a plain index_t here; no range check is done in this function.
__host__ __device__ constexpr void Set(index_t I, TData x) { mData[I] = x; }
|
||||
|
||||
// Function object that copies one element from an Array of NSize elements into an Array of
// NSize + 1 elements. It is written as a named struct instead of a lambda expression; PushBack
// passes it to static_for.
// Both arrays are held by reference, so they must outlive this object.
struct lambda_PushBack // emulate constexpr lambda
|
||||
{
|
||||
// Source array (read only).
const Array<TData, NSize>& old_array;
|
||||
// Destination array, one element longer than the source.
Array<TData, NSize + 1>& new_array;
|
||||
|
||||
// Binds the source and destination arrays; nothing is copied at construction time.
__host__ __device__ constexpr lambda_PushBack(const Array<TData, NSize>& old_array_,
|
||||
Array<TData, NSize + 1>& new_array_)
|
||||
: old_array(old_array_), new_array(new_array_)
|
||||
{
|
||||
}
|
||||
|
||||
// Copies element I of old_array into slot I of new_array.
// The call operator is const but still writes through the new_array reference member.
template <index_t I>
|
||||
__host__ __device__ constexpr void operator()(Number<I>) const
|
||||
{
|
||||
new_array.Set(Number<I>{}, old_array[I]);
|
||||
}
|
||||
};
|
||||
|
||||
// Returns a new Array with NSize + 1 elements: the NSize elements of this array followed by x.
// This array is not modified (the method is const).
// NOTE(review): this span appears to be a diff rendering whose +/- markers were lost, so two
// versions of the copy loop and of the final store are both shown below -- confirm which
// lines are actually present in the file.
__host__ __device__ constexpr auto PushBack(TData x) const
|
||||
{
|
||||
Array<TData, NSize + 1> new_array;
|
||||
|
||||
// Variant 1: element-wise copy through a generic lambda and operator().
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
new_array(i) = mData[i];
|
||||
});
|
||||
// Variant 2: the same copy performed by the lambda_PushBack function object.
static_for<0, NSize, 1>{}(lambda_PushBack(*this, new_array));
|
||||
|
||||
// Store x in the last slot (index NSize); shown once via operator() and once via Set.
new_array(NSize) = x;
|
||||
new_array.Set(Number<NSize>{}, x);
|
||||
|
||||
return new_array;
|
||||
}
|
||||
@@ -81,18 +90,13 @@ __host__ __device__ constexpr auto make_zero_array()
|
||||
|
||||
// Builds a reordered copy of old_array where new element i is taken from old element IRs[i]
// (the Sequence lists, for each new position, the old position to read from).
// NOTE(review): this span appears to be a diff rendering whose +/- markers were lost; two
// parameter spellings, two implementations and two return statements are shown -- confirm
// which lines are actually present in the file.
template <class TData, index_t NSize, index_t... IRs>
|
||||
__host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
|
||||
Sequence<IRs...> new2old)
|
||||
Sequence<IRs...> /*new2old*/)
|
||||
{
|
||||
Array<TData, NSize> new_array;
|
||||
|
||||
// The map must name exactly one old index per array element.
static_assert(NSize == sizeof...(IRs), "NSize not consistent");
|
||||
|
||||
// Loop variant: reads the old index from the new2old argument for each new position.
static_for<0, NSize, 1>{}([&](auto IDim) {
|
||||
constexpr index_t idim = IDim.Get();
|
||||
new_array[idim] = old_array[new2old.Get(IDim)];
|
||||
});
|
||||
static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");
|
||||
|
||||
return new_array;
|
||||
// Pack-expansion variant: builds the result directly from the IRs... indices.
// NOTE(review): `mSize` is not a data member seen anywhere else in this view; every other
// accessor reads `mData`. This likely should be `old_array.mData[IRs]...` (or
// `old_array[IRs]...`) and would fail to compile once instantiated -- confirm and fix.
return Array<TData, NSize>{old_array.mSize[IRs]...};
|
||||
}
|
||||
|
||||
template <class TData, index_t NSize, class MapOld2New>
|
||||
@@ -120,12 +124,14 @@ struct lambda_reorder_array_given_old2new
|
||||
|
||||
template <class TData, index_t NSize, index_t... IRs>
|
||||
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
|
||||
Sequence<IRs...> old2new)
|
||||
Sequence<IRs...> /*old2new*/)
|
||||
{
|
||||
Array<TData, NSize> new_array;
|
||||
|
||||
static_assert(NSize == sizeof...(IRs), "NSize not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");
|
||||
|
||||
static_for<0, NSize, 1>{}(
|
||||
lambda_reorder_array_given_old2new<TData, NSize, Sequence<IRs...>>(old_array, new_array));
|
||||
|
||||
@@ -141,25 +147,44 @@ __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_
|
||||
|
||||
static_assert(new_size <= NSize, "wrong! too many extract");
|
||||
|
||||
static_for<0, new_size, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
new_array(i) = old_array[ExtractSeq::Get(I)];
|
||||
});
|
||||
static_for<0, new_size, 1>{}([&](auto I) { new_array(I) = old_array[ExtractSeq::Get(I)]; });
|
||||
|
||||
return new_array;
|
||||
}
|
||||
|
||||
// Function object applying a binary operation element by element: z[IDim] = f(x[IDim], y[IDim]).
// It is written as a named struct instead of a lambda expression and is passed to static_for
// by the arithmetic operators in this file.
// All four operands are held by reference, so they must outlive this object.
template <class F, class X, class Y, class Z> // emulate constexpr lambda for array math
|
||||
struct lambda_array_math
|
||||
{
|
||||
// Binary operation, called as f(x[IDim], y[IDim]).
const F& f;
|
||||
// Left operand, indexed with operator[].
const X& x;
|
||||
// Right operand, indexed with operator[].
const Y& y;
|
||||
// Destination, written through Set(Number<>, value).
Z& z;
|
||||
|
||||
// Binds the operation, both operands and the destination; nothing is computed here.
__host__ __device__ constexpr lambda_array_math(const F& f_, const X& x_, const Y& y_, Z& z_)
|
||||
: f(f_), x(x_), y(y_), z(z_)
|
||||
{
|
||||
}
|
||||
|
||||
// Computes one output element. The call operator is const but still writes through the
// z reference member.
template <index_t IDim_>
|
||||
__host__ __device__ constexpr void operator()(Number<IDim_>) const
|
||||
{
|
||||
constexpr auto IDim = Number<IDim_>{};
|
||||
|
||||
// x and y are indexed with a Number<> object, so both operand types must accept that
// as a subscript -- presumably through conversion to index_t; verify against Number.
z.Set(IDim, f(x[IDim], y[IDim]));
|
||||
}
|
||||
};
|
||||
|
||||
// Array = Array + Array
// Element-wise sum of two arrays of the same element type and size; returns a new Array.
// Both operands are taken by value.
// NOTE(review): this span appears to be a diff rendering whose +/- markers were lost; a
// lambda-based loop and a lambda_array_math-based loop are interleaved below -- confirm
// which lines are actually present in the file.
template <class TData, index_t NSize>
|
||||
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, NSize> b)
|
||||
{
|
||||
Array<TData, NSize> result;
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
// NOTE(review): the functor is instantiated for index_t while the elements are TData;
// presumably every caller uses TData == index_t -- confirm, or use plus<TData>.
auto f = mod_conv::plus<index_t>{};
|
||||
|
||||
result(i) = a[i] + b[i];
|
||||
});
|
||||
// Functor-based loop: result.Set(I, f(a[I], b[I])) for each I in [0, NSize).
static_for<0, NSize, 1>{}(
|
||||
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
|
||||
f, a, b, result));
|
||||
|
||||
return result;
|
||||
}
|
||||
@@ -170,11 +195,11 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData,
|
||||
{
|
||||
Array<TData, NSize> result;
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
auto f = mod_conv::minus<index_t>{};
|
||||
|
||||
result(i) = a[i] - b[i];
|
||||
});
|
||||
static_for<0, NSize, 1>{}(
|
||||
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
|
||||
f, a, b, result));
|
||||
|
||||
return result;
|
||||
}
|
||||
@@ -187,11 +212,11 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is.
|
||||
|
||||
Array<TData, NSize> result;
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
auto f = mod_conv::plus<index_t>{};
|
||||
|
||||
result(i) = a[i] + b.Get(I);
|
||||
});
|
||||
static_for<0, NSize, 1>{}(
|
||||
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
|
||||
f, a, b, result));
|
||||
|
||||
return result;
|
||||
}
|
||||
@@ -204,11 +229,11 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is.
|
||||
|
||||
Array<TData, NSize> result;
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
auto f = mod_conv::minus<index_t>{};
|
||||
|
||||
result(i) = a[i] - b.Get(I);
|
||||
});
|
||||
static_for<0, NSize, 1>{}(
|
||||
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
|
||||
f, a, b, result));
|
||||
|
||||
return result;
|
||||
}
|
||||
@@ -221,11 +246,11 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is.
|
||||
|
||||
Array<TData, NSize> result;
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
auto f = mod_conv::multiplies<index_t>{};
|
||||
|
||||
result(i) = a[i] * b.Get(I);
|
||||
});
|
||||
static_for<0, NSize, 1>{}(
|
||||
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
|
||||
f, a, b, result));
|
||||
|
||||
return result;
|
||||
}
|
||||
@@ -238,11 +263,11 @@ __host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSi
|
||||
|
||||
Array<TData, NSize> result;
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
auto f = mod_conv::minus<index_t>{};
|
||||
|
||||
result(i) = a.Get(I) - b[i];
|
||||
});
|
||||
static_for<0, NSize, 1>{}(
|
||||
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
|
||||
f, a, b, result));
|
||||
|
||||
return result;
|
||||
}
|
||||
@@ -255,10 +280,7 @@ accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
|
||||
|
||||
static_assert(NSize > 0, "wrong");
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
result = f(result, a[i]);
|
||||
});
|
||||
static_for<0, NSize, 1>{}([&](auto I) { result = f(result, a[I]); });
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user