mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
support swiglu activaion and use rcpf to accelerate silu
This commit is contained in:
@@ -863,7 +863,14 @@ struct Silu
|
||||
std::is_same_v<T, int32_t>,
|
||||
"Data type is not supported by this operation!");
|
||||
constexpr T one = type_convert<T>(1);
|
||||
y = x * (one / (one + ck_tile::exp(-x)));
|
||||
if constexpr(std::is_same_v<T, float>)
|
||||
{
|
||||
y = x * __builtin_amdgcn_rcpf(one + ck_tile::exp(-x));
|
||||
}
|
||||
else
|
||||
{
|
||||
y = x * (one / (one + ck_tile::exp(-x)));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -1218,7 +1225,7 @@ struct Swish
|
||||
|
||||
struct SoftRelu
|
||||
{
|
||||
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
|
||||
SoftRelu(float alpha = 1.f) : alpha_(alpha) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
@@ -1237,7 +1244,7 @@ struct SoftRelu
|
||||
struct Power
|
||||
{
|
||||
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
|
||||
: alpha_(alpha), beta_(beta), gamma_(gamma){};
|
||||
: alpha_(alpha), beta_(beta), gamma_(gamma) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
@@ -1259,7 +1266,7 @@ struct Power
|
||||
|
||||
struct ClippedRelu
|
||||
{
|
||||
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
|
||||
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
@@ -1278,7 +1285,7 @@ struct ClippedRelu
|
||||
|
||||
struct LeakyRelu
|
||||
{
|
||||
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
|
||||
LeakyRelu(float alpha = 0.01f) : alpha_(alpha) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
@@ -1295,7 +1302,7 @@ struct LeakyRelu
|
||||
|
||||
struct Elu
|
||||
{
|
||||
Elu(float alpha = 1.f) : alpha_(alpha){};
|
||||
Elu(float alpha = 1.f) : alpha_(alpha) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
@@ -1312,7 +1319,7 @@ struct Elu
|
||||
|
||||
struct Logistic
|
||||
{
|
||||
Logistic(float alpha = 1.f) : alpha_(alpha){};
|
||||
Logistic(float alpha = 1.f) : alpha_(alpha) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
|
||||
Reference in New Issue
Block a user