mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
adding implicit gemm v4 (nchw, kcyx)
This commit is contained in:
@@ -105,6 +105,7 @@ __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_
|
||||
return new_array;
|
||||
}
|
||||
|
||||
// Array = Array + Array
|
||||
template <class TData, index_t NSize>
|
||||
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, NSize> b)
|
||||
{
|
||||
@@ -119,6 +120,55 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData,
|
||||
return result;
|
||||
}
|
||||
|
||||
// Array = Array - Array
|
||||
template <class TData, index_t NSize>
|
||||
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData, NSize> b)
|
||||
{
|
||||
Array<TData, NSize> result;
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
|
||||
result[i] = a[i] - b[i];
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Array = Array + Sequence
|
||||
template <class TData, index_t NSize, index_t... Is>
|
||||
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is...> b)
|
||||
{
|
||||
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
|
||||
|
||||
Array<TData, NSize> result;
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
|
||||
result[i] = a[i] + b.Get(I);
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Array = Array - Sequence
|
||||
template <class TData, index_t NSize, index_t... Is>
|
||||
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is...> b)
|
||||
{
|
||||
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
|
||||
|
||||
Array<TData, NSize> result;
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
|
||||
result[i] = a[i] - b.Get(I);
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Array = Array * Sequence
|
||||
template <class TData, index_t NSize, index_t... Is>
|
||||
__host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is...> b)
|
||||
@@ -136,15 +186,119 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is.
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class TData, index_t NSize, class F>
|
||||
__host__ __device__ constexpr TData reduce_on_array(Array<TData, NSize> a, F f)
|
||||
// Array = Sequence - Array
|
||||
template <class TData, index_t NSize, index_t... Is>
|
||||
__host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSize> b)
|
||||
{
|
||||
TData result = a[0];
|
||||
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
|
||||
|
||||
static_for<1, NSize, 1>{}([&](auto I) {
|
||||
Array<TData, NSize> result;
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
|
||||
result[i] = a.Get(I) - b[i];
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class TData, index_t NSize, class Reduce>
|
||||
__host__ __device__ constexpr TData
|
||||
accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
|
||||
{
|
||||
TData result = init;
|
||||
|
||||
static_assert(NSize > 0, "wrong");
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
result = f(result, a[i]);
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class T, index_t NSize>
|
||||
__host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
|
||||
{
|
||||
constexpr index_t nsize = a.GetSize();
|
||||
|
||||
static_assert(nsize > 0 && nsize <= 10, "wrong!");
|
||||
|
||||
static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, a[0]); });
|
||||
|
||||
static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]); });
|
||||
|
||||
static_if<nsize == 3>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]); });
|
||||
|
||||
static_if<nsize == 4>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]); });
|
||||
|
||||
static_if<nsize == 5>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]);
|
||||
});
|
||||
|
||||
static_if<nsize == 6>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]);
|
||||
});
|
||||
|
||||
static_if<nsize == 7>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nsize,
|
||||
a[0],
|
||||
a[1],
|
||||
a[2],
|
||||
a[3],
|
||||
a[4],
|
||||
a[5],
|
||||
a[6]);
|
||||
});
|
||||
|
||||
static_if<nsize == 8>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nsize,
|
||||
a[0],
|
||||
a[1],
|
||||
a[2],
|
||||
a[3],
|
||||
a[4],
|
||||
a[5],
|
||||
a[6],
|
||||
a[7]);
|
||||
});
|
||||
|
||||
static_if<nsize == 9>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nsize,
|
||||
a[0],
|
||||
a[1],
|
||||
a[2],
|
||||
a[3],
|
||||
a[4],
|
||||
a[5],
|
||||
a[6],
|
||||
a[7],
|
||||
a[8]);
|
||||
});
|
||||
|
||||
static_if<nsize == 10>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nsize,
|
||||
a[0],
|
||||
a[1],
|
||||
a[2],
|
||||
a[3],
|
||||
a[4],
|
||||
a[5],
|
||||
a[6],
|
||||
a[7],
|
||||
a[8],
|
||||
a[9]);
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user