softcap, tanh: avoid NaNs for large arguments (AVX2, AVX512)

Not that I have encountered this in practice, but just to be sure.
This does it for AVX512 and AVX2, still need a guard for ARM_NEON.
This commit is contained in:
Iwan Kawrakow
2024-08-20 14:42:48 +03:00
parent d50f4f9439
commit ad456dc25b

View File

@@ -2596,7 +2596,9 @@ inline static __m512 ggml_v_silu(__m512 x) {
inline static __m512 ggml_v_tanh(__m512 x) {
const __m512 one = _mm512_set1_ps(1.0f);
const __m512 exp_two_x = ggml_v_expf(_mm512_mul_ps(x, _mm512_set1_ps(2.f)));
return _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
const __mmask16 mask = _mm512_cmp_ps_mask(x, _mm512_set1_ps(10.f), _CMP_GT_OQ);
const __m512 res = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
return _mm512_mask_blend_ps(mask, res, one);
}
inline static __m512 ggml_v_softcap(__m512 x, __m512 s_before, __m512 s_after) {
@@ -2611,8 +2613,10 @@ inline static __m512 ggml_v_gelu(__m512 x, __m512 c1, __m512 c2) {
__m512 arg = _mm512_fmadd_ps(x, _mm512_mul_ps(c1, x), one);
//__m512 arg = _mm512_add_ps(one, _mm512_mul_ps(_mm512_mul_ps(x, x), c1));
arg = _mm512_mul_ps(arg, _mm512_mul_ps(c2, x));
__m512 exp_arg = ggml_v_expf(arg);
return _mm512_mul_ps(x, _mm512_div_ps(exp_arg, _mm512_add_ps(exp_arg, one)));
const __mmask16 mask = _mm512_cmp_ps_mask(arg, _mm512_set1_ps(30.f), _CMP_GT_OQ);
const __m512 exp_arg = ggml_v_expf(arg);
const __m512 ratio = _mm512_div_ps(exp_arg, _mm512_add_ps(exp_arg, one));
return _mm512_mul_ps(x, _mm512_mask_blend_ps(mask, ratio, one));
}
#elif defined(__AVX2__) && defined(__FMA__)
@@ -2673,14 +2677,17 @@ inline static __m256 ggml_v_silu(__m256 x) {
inline static __m256 ggml_v_tanh(__m256 x) {
const __m256 one = _mm256_set1_ps(1.0f);
const __m256 exp_two_x = ggml_v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f)));
return _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
const __m256 res = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ);
return _mm256_or_ps(_mm256_and_ps(mask, one), _mm256_andnot_ps(mask, res));
}
inline static __m256 ggml_v_softcap(__m256 x, float s_before, float s_after) {
const __m256 one = _mm256_set1_ps(1.0f);
const __m256 exp_two_x = ggml_v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f*s_before)));
const __m256 th = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
return _mm256_mul_ps(th, _mm256_set1_ps(s_after));
return _mm256_mul_ps(_mm256_set1_ps(s_after), ggml_v_tanh(_mm256_mul_ps(x, _mm256_set1_ps(s_before))));
//const __m256 one = _mm256_set1_ps(1.0f);
//const __m256 exp_two_x = ggml_v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f*s_before)));
//const __m256 th = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
//return _mm256_mul_ps(th, _mm256_set1_ps(s_after));
}
inline static __m256 ggml_v_gelu(__m256 x, __m256 c1, __m256 c2) {