mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
adding implicit gemm v3
This commit is contained in:
@@ -12,7 +12,7 @@ struct Array
|
||||
index_t mData[nSize];
|
||||
|
||||
template <class... Xs>
|
||||
__host__ __device__ Array(Xs... xs) : mData{static_cast<TData>(xs)...}
|
||||
__host__ __device__ constexpr Array(Xs... xs) : mData{static_cast<TData>(xs)...}
|
||||
{
|
||||
}
|
||||
|
||||
@@ -37,6 +37,25 @@ struct Array
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
|
||||
{
|
||||
return Array<index_t, sizeof...(Is)>{Is...};
|
||||
}
|
||||
|
||||
template <class TData, index_t NSize>
|
||||
__host__ __device__ constexpr auto make_zero_array()
|
||||
{
|
||||
Array<TData, NSize> a;
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
a[i] = static_cast<TData>(0);
|
||||
});
|
||||
|
||||
return a;
|
||||
}
|
||||
|
||||
template <class TData, index_t NSize, index_t... IRs>
|
||||
__host__ __device__ auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
|
||||
Sequence<IRs...> new2old)
|
||||
@@ -80,15 +99,14 @@ __host__ __device__ auto extract_array(const Array<TData, NSize>& old_array, Ext
|
||||
|
||||
static_for<0, new_size, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
new_array[i] = old_array[ExtractSeq{}.Get(I)];
|
||||
new_array[i] = old_array[ExtractSeq::Get(I)];
|
||||
});
|
||||
|
||||
return new_array;
|
||||
}
|
||||
|
||||
template <class TData, index_t NSize>
|
||||
__host__ __device__ constexpr auto operator+(const Array<TData, NSize>& a,
|
||||
const Array<TData, NSize>& b)
|
||||
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, NSize> b)
|
||||
{
|
||||
Array<TData, NSize> result;
|
||||
|
||||
@@ -99,3 +117,20 @@ __host__ __device__ constexpr auto operator+(const Array<TData, NSize>& a,
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Array = Array * Sequence
|
||||
template <class TData, index_t NSize, index_t... Is>
|
||||
__host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is...> b)
|
||||
{
|
||||
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
|
||||
|
||||
Array<TData, NSize> result;
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
|
||||
result[i] = a[i] + b.Get(I);
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user