diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 641015c7..9b877bab 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2525,15 +2525,19 @@ inline static float32x4_t ggml_v_tanh(float32x4_t x) { const float32x4_t one = vdupq_n_f32(1.0f); const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f)); const float32x4_t exp_two_x = ggml_v_expf(two_x); - return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one)); + const uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f)); + const float32x4_t res = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one)); + return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask))); + //return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one)); } inline static float32x4_t ggml_v_softcap(float32x4_t x, float32x4_t s_before, float32x4_t s_after) { - const float32x4_t one = vdupq_n_f32(1.0f); - const float32x4_t two_x = vmulq_f32(x, s_before); - const float32x4_t exp_two_x = ggml_v_expf(two_x); - const float32x4_t th = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one)); - return vmulq_f32(th, s_after); + return vmulq_f32(s_after, ggml_v_tanh(vmulq_f32(x, s_before))); + //const float32x4_t one = vdupq_n_f32(1.0f); + //const float32x4_t two_x = vmulq_f32(x, s_before); + //const float32x4_t exp_two_x = ggml_v_expf(two_x); + //const float32x4_t th = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one)); + //return vmulq_f32(th, s_after); } @@ -2844,7 +2848,7 @@ static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s _mm_storeu_ps(x + i, ggml_v_softcap(_mm_loadu_ps(x + i), s_before, s_after)); } #elif defined(__ARM_NEON) && defined(__aarch64__) - float32x4_t vs_before = vdupq_n_f32(2.f*s_before); + float32x4_t vs_before = vdupq_n_f32(s_before); float32x4_t vs_after = vdupq_n_f32(s_after); for (; i + 3 < n; i += 4) { vst1q_f32(x + i, ggml_v_softcap(vld1q_f32(x + i), vs_before, vs_after));