mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Add GemmAddSoftmaxGemm support for MSFT ORT (instances and client API) (#576)
* add instance for gemm bias softmax gemm * add client example * change CGridDesc_G_M_N to CGridDesc_G_M_O * add gridwise * change c grid name * device add d0s data * fix 08 client_example * add example 47_fused_attention * example output correct * add d0 to example * add d0 element op * rechange instance code * change Acc0ElementwiseOperation to C0DEElementwiseOperation * change example name * update instance for cdeelementwiseop * add bhalf_t ScaleAdd * add test * not surport geem1 bias * remove some ignore * fix test bug
This commit is contained in:
@@ -49,6 +49,14 @@ struct Add
|
||||
y = x0 + x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
|
||||
{
|
||||
const float x1_tmp = ck::type_convert<float>(x1);
|
||||
y = x0 + x1_tmp;
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
|
||||
@@ -67,6 +75,30 @@ struct Add
|
||||
};
|
||||
};
|
||||
|
||||
struct ScaleAdd
|
||||
{
|
||||
__host__ __device__ ScaleAdd(float scale) : scale_(scale) {}
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void
|
||||
operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
|
||||
{
|
||||
y = scale_ * x0 + ck::type_convert<float>(x1);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void
|
||||
operator()<float, float, bhalf_t>(float& y, const float& x0, const bhalf_t& x1) const
|
||||
{
|
||||
y = scale_ * x0 + ck::type_convert<float>(x1);
|
||||
};
|
||||
|
||||
float scale_;
|
||||
};
|
||||
|
||||
struct Subtract
|
||||
{
|
||||
template <typename T>
|
||||
|
||||
Reference in New Issue
Block a user