Gemm multiple d multiple r (#335)

* Imitate XXX_gemm_multiple_d, add XXX_gemm_multiple_d_multiple_r for gemm + reduction

* Implement run of kernel

* Add example

* Fix typo in parameter

* Rewrite the reduceMax example

* Rewrite the reduceMean + reduceMeanSquare example

* Refine naming

* Refine folder name

* refine naming

* Rewrite the gemm + bias + relu + add + layernorm example

* Rewrite the gemm + layernorm example

* clang-format

* Fix LDS sync bug

* Fix compile error
This commit is contained in:
rocking5566
2022-08-13 14:07:12 +08:00
committed by GitHub
parent cac014f173
commit 6c3c06bf1f
14 changed files with 2940 additions and 902 deletions

View File

@@ -130,6 +130,35 @@ struct AddHardswishAdd
}
};
// C = A * B
// E = C + D0 + D1
//
// Functor that adds the two extra operands d0 and d1 to c and stores the
// result in e. The sum is carried out in C's type and converted to E at the end.
struct AddAdd
{
    template <typename E, typename C, typename D0, typename D1>
    __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const
    {
        // Only half_t, float and double are accepted for every operand so far.
        constexpr bool e_is_supported =
            is_same<E, half_t>::value || is_same<E, float>::value || is_same<E, double>::value;
        constexpr bool c_is_supported =
            is_same<C, half_t>::value || is_same<C, float>::value || is_same<C, double>::value;
        constexpr bool d0_is_supported =
            is_same<D0, half_t>::value || is_same<D0, float>::value || is_same<D0, double>::value;
        constexpr bool d1_is_supported =
            is_same<D1, half_t>::value || is_same<D1, float>::value || is_same<D1, double>::value;

        static_assert(e_is_supported, "Data type is not supported by this operation!");
        static_assert(c_is_supported, "Data type is not supported by this operation!");
        static_assert(d0_is_supported, "Data type is not supported by this operation!");
        static_assert(d1_is_supported, "Data type is not supported by this operation!");

        // Bring both extra operands into C's type before summing.
        const C d0_as_c = type_convert<C>(d0);
        const C d1_as_c = type_convert<C>(d1);

        // Keep the sum as one left-to-right expression: (c + d0) + d1.
        const C sum = c + d0_as_c + d1_as_c;

        e = type_convert<E>(sum);
    }
};
// C = A * B
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu