diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index fc0b9190..641015c7 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2596,7 +2596,9 @@ inline static __m512 ggml_v_silu(__m512 x) { inline static __m512 ggml_v_tanh(__m512 x) { const __m512 one = _mm512_set1_ps(1.0f); const __m512 exp_two_x = ggml_v_expf(_mm512_mul_ps(x, _mm512_set1_ps(2.f))); - return _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one)); + const __mmask16 mask = _mm512_cmp_ps_mask(x, _mm512_set1_ps(10.f), _CMP_GT_OQ); + const __m512 res = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one)); + return _mm512_mask_blend_ps(mask, res, one); } inline static __m512 ggml_v_softcap(__m512 x, __m512 s_before, __m512 s_after) { @@ -2611,8 +2613,10 @@ inline static __m512 ggml_v_gelu(__m512 x, __m512 c1, __m512 c2) { __m512 arg = _mm512_fmadd_ps(x, _mm512_mul_ps(c1, x), one); //__m512 arg = _mm512_add_ps(one, _mm512_mul_ps(_mm512_mul_ps(x, x), c1)); arg = _mm512_mul_ps(arg, _mm512_mul_ps(c2, x)); - __m512 exp_arg = ggml_v_expf(arg); - return _mm512_mul_ps(x, _mm512_div_ps(exp_arg, _mm512_add_ps(exp_arg, one))); + const __mmask16 mask = _mm512_cmp_ps_mask(arg, _mm512_set1_ps(30.f), _CMP_GT_OQ); + const __m512 exp_arg = ggml_v_expf(arg); + const __m512 ratio = _mm512_div_ps(exp_arg, _mm512_add_ps(exp_arg, one)); + return _mm512_mul_ps(x, _mm512_mask_blend_ps(mask, ratio, one)); } #elif defined(__AVX2__) && defined(__FMA__) @@ -2673,14 +2677,17 @@ inline static __m256 ggml_v_silu(__m256 x) { inline static __m256 ggml_v_tanh(__m256 x) { const __m256 one = _mm256_set1_ps(1.0f); const __m256 exp_two_x = ggml_v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f))); - return _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one)); + const __m256 res = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one)); + const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ); + return _mm256_or_ps(_mm256_and_ps(mask, one), _mm256_andnot_ps(mask, res)); } inline static __m256 ggml_v_softcap(__m256 x, float s_before, float s_after) { - const __m256 one = _mm256_set1_ps(1.0f); - const __m256 exp_two_x = ggml_v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f*s_before))); - const __m256 th = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one)); - return _mm256_mul_ps(th, _mm256_set1_ps(s_after)); + return _mm256_mul_ps(_mm256_set1_ps(s_after), ggml_v_tanh(_mm256_mul_ps(x, _mm256_set1_ps(s_before)))); + //const __m256 one = _mm256_set1_ps(1.0f); + //const __m256 exp_two_x = ggml_v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f*s_before))); + //const __m256 th = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one)); + //return _mm256_mul_ps(th, _mm256_set1_ps(s_after)); } inline static __m256 ggml_v_gelu(__m256 x, __m256 c1, __m256 c2) {