Add multiD Gemm client APIs (#534)

* start add example

* fix config

* fix showinfo bug

* add an elementop

* change to padding

* add xdl example

* change elementwiseop

* add instance

* add instance to profiler

* change file name

* fix deive not support issue

* add client example

* fix client gemm_add_multiply name

* change AddMultiply elementwiseop

* fix elementwiseop

* fix client example

* fix addmultiply op

* fix comments and fun name

Co-authored-by: letaoqin <letaoqin@amd.com>
This commit is contained in:
ltqin
2023-01-19 01:53:56 +08:00
committed by GitHub
parent 00ff30af8c
commit d66421fe34
20 changed files with 2338 additions and 0 deletions

View File

@@ -172,6 +172,42 @@ struct AddAdd
}
};
// C = A * B
// E = (C + D0) x D1
struct AddMultiply
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
const half_t& c,
const half_t& d0,
const half_t& d1) const
{
const half_t y = (c + d0) * d1;
e = y;
}
template <>
__host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
const float& c,
const half_t& d0,
const half_t& d1) const
{
const half_t y = (type_convert<half_t>(c) + d0) * d1;
e = y;
}
template <>
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
const float& c,
const half_t& d0,
const half_t& d1) const
{
const float y = (c + d0) * d1;
e = y;
}
};
// C = A * B
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu