mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 23:24:13 +00:00
Adding forgotten gelu, relu, silu on ARM
This commit is contained in:
@@ -14800,6 +14800,45 @@ inline float32x4_t v_tanh(float16x8_t x) {
|
||||
auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x)));
|
||||
return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
|
||||
}
|
||||
inline float32x4_t v_silu(float32x4_t x) {
|
||||
const float32x4_t one = vdupq_n_f32(1.0f);
|
||||
const float32x4_t zero = vdupq_n_f32(0.0f);
|
||||
const float32x4_t neg_x = vsubq_f32(zero, x);
|
||||
const float32x4_t exp_neg_x = v_expf(neg_x);
|
||||
const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
|
||||
return vdivq_f32(x, one_plus_exp_neg_x);
|
||||
}
|
||||
inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) {
|
||||
const float32x4_t one = vdupq_n_f32(1.0f);
|
||||
float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x));
|
||||
arg = vmulq_f32(arg, vmulq_f32(x, c2));
|
||||
float32x4_t exp_arg = v_expf(arg);
|
||||
float32x4_t gelu = vmulq_f32(x, vdivq_f32(exp_arg, vaddq_f32(exp_arg, one)));
|
||||
uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f));
|
||||
return vbslq_f32(mask, x, gelu);
|
||||
}
|
||||
|
||||
void MulMat::gelu(int n, const float * x, float * y) {
|
||||
constexpr float GELU_COEF_A = 0.044715f;
|
||||
constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
int i = 0;
|
||||
auto c1 = vdupq_n_f32(GELU_COEF_A);
|
||||
auto c2 = vdupq_n_f32(2.f*SQRT_2_OVER_PI);
|
||||
for (; i + 3 < n; i += 4) {
|
||||
vst1q_f32(y + i, v_gelu(vld1q_f32(x + i), c1, c2));
|
||||
}
|
||||
for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i])));
|
||||
}
|
||||
|
||||
void MulMat::silu(int n, const float * x, float * y) {
|
||||
int i = 0;
|
||||
for (; i + 3 < n; i += 4) vst1q_f32(y + i, v_silu(vld1q_f32(x + i)));
|
||||
for (; i < n; ++i) y[i] = x[i]/(1.0f + expf(-x[i]));
|
||||
}
|
||||
|
||||
void MulMat::relu(int n, const float * x, float * y) {
|
||||
for (int j = 0; j < n; ++j) y[j] = x[j] > 0 ? x[j] : 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(__AVX512F__) && defined(__AVX512DQ__)
|
||||
|
||||
Reference in New Issue
Block a user