mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Add multiD Gemm client APIs (#534)
* start add example * fix config * fix showinfo bug * add an elementop * change to padding * add xdl example * change elementwiseop * add instance * add instance to profiler * change file name * fix device not support issue * add client example * fix client gemm_add_multiply name * change AddMultiply elementwiseop * fix elementwiseop * fix client example * fix addmultiply op * fix comments and function name Co-authored-by: letaoqin <letaoqin@amd.com>
This commit is contained in:
@@ -172,6 +172,42 @@ struct AddAdd
|
||||
}
|
||||
};
|
||||
|
||||
// C = A * B
|
||||
// E = (C + D0) x D1
|
||||
struct AddMultiply
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
|
||||
const half_t& c,
|
||||
const half_t& d0,
|
||||
const half_t& d1) const
|
||||
{
|
||||
const half_t y = (c + d0) * d1;
|
||||
e = y;
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
|
||||
const float& c,
|
||||
const half_t& d0,
|
||||
const half_t& d1) const
|
||||
{
|
||||
const half_t y = (type_convert<half_t>(c) + d0) * d1;
|
||||
e = y;
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
|
||||
const float& c,
|
||||
const half_t& d0,
|
||||
const half_t& d1) const
|
||||
{
|
||||
const float y = (c + d0) * d1;
|
||||
e = y;
|
||||
}
|
||||
};
|
||||
|
||||
// C = A * B
|
||||
// E = FastGelu(C + D0 + D1)
|
||||
struct AddAddFastGelu
|
||||
|
||||
Reference in New Issue
Block a user