support swiglu activation and use rcpf to accelerate silu

This commit is contained in:
Feng Shijie
2025-08-26 12:32:29 +00:00
parent d05eed931d
commit 65b702454c
8 changed files with 376 additions and 350 deletions

View File

@@ -863,7 +863,14 @@ struct Silu
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = x * (one / (one + ck_tile::exp(-x)));
if constexpr(std::is_same_v<T, float>)
{
y = x * __builtin_amdgcn_rcpf(one + ck_tile::exp(-x));
}
else
{
y = x * (one / (one + ck_tile::exp(-x)));
}
};
template <>
@@ -1218,7 +1225,7 @@ struct Swish
struct SoftRelu
{
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
SoftRelu(float alpha = 1.f) : alpha_(alpha) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1237,7 +1244,7 @@ struct SoftRelu
struct Power
{
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
: alpha_(alpha), beta_(beta), gamma_(gamma){};
: alpha_(alpha), beta_(beta), gamma_(gamma) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1259,7 +1266,7 @@ struct Power
struct ClippedRelu
{
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1278,7 +1285,7 @@ struct ClippedRelu
struct LeakyRelu
{
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
LeakyRelu(float alpha = 0.01f) : alpha_(alpha) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1295,7 +1302,7 @@ struct LeakyRelu
struct Elu
{
Elu(float alpha = 1.f) : alpha_(alpha){};
Elu(float alpha = 1.f) : alpha_(alpha) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1312,7 +1319,7 @@ struct Elu
struct Logistic
{
Logistic(float alpha = 1.f) : alpha_(alpha){};
Logistic(float alpha = 1.f) : alpha_(alpha) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const