From cf7b98db8800e493bf5df301bc73e8403f8fc628 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sun, 23 Feb 2025 13:08:32 +0200
Subject: [PATCH] Adding forgotten gelu, relu, silu on ARM

---
 ggml/src/iqk/iqk_mul_mat.cpp | 61 ++++++++++++++++++++++++++++++++++++
 1 file changed, 61 insertions(+)

diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 384a0e78..0f7cd1e5 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -14800,6 +14800,67 @@ inline float32x4_t v_tanh(float16x8_t x) {
     auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x)));
     return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
 }
+// SiLU on four float lanes: x / (1 + exp(-x)).
+inline float32x4_t v_silu(float32x4_t x) {
+    // -x is formed as 0 - x, then fed to v_expf to build the denominator.
+    const float32x4_t minus_x = vsubq_f32(vdupq_n_f32(0.0f), x);
+    const float32x4_t denom   = vaddq_f32(vdupq_n_f32(1.0f), v_expf(minus_x));
+    return vdivq_f32(x, denom);
+}
+// GELU on four float lanes, computed as x * e / (e + 1) where
+// e = exp((1 + c1*x*x) * (x*c2)); c1 and c2 are supplied pre-broadcast.
+inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) {
+    const float32x4_t ones   = vdupq_n_f32(1.0f);
+    const float32x4_t poly   = vfmaq_f32(ones, c1, vmulq_f32(x, x));
+    const float32x4_t e      = v_expf(vmulq_f32(poly, vmulq_f32(x, c2)));
+    const float32x4_t smooth = vmulq_f32(x, vdivq_f32(e, vaddq_f32(e, ones)));
+    // Lanes with x > 10 pass x through untouched (presumably to sidestep
+    // overflow of the exponential for large inputs).
+    const uint32x4_t is_big = vcgtq_f32(x, vdupq_n_f32(10.f));
+    return vbslq_f32(is_big, x, smooth);
+}
+
+// Writes GELU of x[0..n) into y[0..n): v_gelu handles groups of four,
+// the scalar tanhf form handles whatever is left over.
+void MulMat::gelu(int n, const float * x, float * y) {
+    constexpr float GELU_COEF_A    = 0.044715f;
+    constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+    const auto coef_a    = vdupq_n_f32(GELU_COEF_A);
+    const auto exp_scale = vdupq_n_f32(2.f*SQRT_2_OVER_PI);
+    int pos = 0;
+    while (pos + 3 < n) {
+        vst1q_f32(y + pos, v_gelu(vld1q_f32(x + pos), coef_a, exp_scale));
+        pos += 4;
+    }
+    while (pos < n) {
+        const float v = x[pos];
+        y[pos] = 0.5f*v*(1.0f + tanhf(SQRT_2_OVER_PI*v*(1.0f + GELU_COEF_A*v*v)));
+        ++pos;
+    }
+}
+
+// Writes SiLU of x[0..n) into y[0..n): v_silu handles groups of four,
+// the scalar expf form handles whatever is left over.
+void MulMat::silu(int n, const float * x, float * y) {
+    int pos = 0;
+    while (pos + 3 < n) {
+        vst1q_f32(y + pos, v_silu(vld1q_f32(x + pos)));
+        pos += 4;
+    }
+    while (pos < n) {
+        const float v = x[pos];
+        y[pos] = v/(1.0f + expf(-v));
+        ++pos;
+    }
+}
+
+// Writes ReLU of x[0..n) into y[0..n): positive values are copied, the rest become 0.
+void MulMat::relu(int n, const float * x, float * y) {
+    for (int k = 0; k < n; ++k) {
+        const float v = x[k];
+        y[k] = v > 0 ? v : 0.0f;
+    }
+}
 #endif
 
 #if defined(__AVX512F__) && defined(__AVX512DQ__)