Adding forgotten gelu, relu, silu on ARM

This commit is contained in:
Iwan Kawrakow
2025-02-23 13:08:32 +02:00
parent a72cd964b0
commit cf7b98db88

View File

@@ -14800,6 +14800,45 @@ inline float32x4_t v_tanh(float16x8_t x) {
auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x)));
return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
}
// Lane-wise x / (1 + exp(-x)) over four packed floats.
inline float32x4_t v_silu(float32x4_t x) {
    // Negation is spelled as 0 - x before handing the value to v_expf.
    const float32x4_t e = v_expf(vsubq_f32(vdupq_n_f32(0.0f), x));
    const float32x4_t denom = vaddq_f32(vdupq_n_f32(1.0f), e);
    return vdivq_f32(x, denom);
}
// Lane-wise x * e / (e + 1), where e = exp((1 + c1*x*x) * (x*c2)).
// Lanes with x > 10 bypass the formula and return x unchanged
// (presumably to sidestep exp overflow for large inputs -- depends on c1/c2 supplied by the caller).
inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) {
    const float32x4_t ones = vdupq_n_f32(1.0f);
    // 1 + c1 * x^2 via a fused multiply-add.
    const float32x4_t poly = vfmaq_f32(ones, c1, vmulq_f32(x, x));
    const float32x4_t e = v_expf(vmulq_f32(poly, vmulq_f32(x, c2)));
    const float32x4_t ratio = vdivq_f32(e, vaddq_f32(e, ones));
    const float32x4_t approx = vmulq_f32(x, ratio);
    // All-ones in lanes where x exceeds 10; those lanes select x itself.
    const uint32x4_t is_large = vcgtq_f32(x, vdupq_n_f32(10.f));
    return vbslq_f32(is_large, x, approx);
}
void MulMat::gelu(int n, const float * x, float * y) {
constexpr float GELU_COEF_A = 0.044715f;
constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
int i = 0;
auto c1 = vdupq_n_f32(GELU_COEF_A);
auto c2 = vdupq_n_f32(2.f*SQRT_2_OVER_PI);
for (; i + 3 < n; i += 4) {
vst1q_f32(y + i, v_gelu(vld1q_f32(x + i), c1, c2));
}
for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i])));
}
void MulMat::silu(int n, const float * x, float * y) {
int i = 0;
for (; i + 3 < n; i += 4) vst1q_f32(y + i, v_silu(vld1q_f32(x + i)));
for (; i < n; ++i) y[i] = x[i]/(1.0f + expf(-x[i]));
}
// ReLU over n floats: y[i] is x[i] when x[i] > 0, otherwise 0.
// Any input that fails the `> 0` comparison (including NaN) produces 0.
void MulMat::relu(int n, const float * x, float * y) {
    for (int i = 0; i < n; ++i) {
        const float v = x[i];
        if (v > 0) {
            y[i] = v;
        } else {
            y[i] = 0;
        }
    }
}
#endif
#if defined(__AVX512F__) && defined(__AVX512DQ__)