mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 23:54:10 +00:00
softcap, tanh: avoid NaNs for large arguments (AVX2, AVX512)
I have not encountered this in practice, but this change guards against it just to be sure. It covers AVX512 and AVX2; a guard for ARM_NEON is still needed.
This commit is contained in:
@@ -2596,7 +2596,9 @@ inline static __m512 ggml_v_silu(__m512 x) {
|
||||
inline static __m512 ggml_v_tanh(__m512 x) {
|
||||
const __m512 one = _mm512_set1_ps(1.0f);
|
||||
const __m512 exp_two_x = ggml_v_expf(_mm512_mul_ps(x, _mm512_set1_ps(2.f)));
|
||||
return _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
|
||||
const __mmask16 mask = _mm512_cmp_ps_mask(x, _mm512_set1_ps(10.f), _CMP_GT_OQ);
|
||||
const __m512 res = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
|
||||
return _mm512_mask_blend_ps(mask, res, one);
|
||||
}
|
||||
|
||||
inline static __m512 ggml_v_softcap(__m512 x, __m512 s_before, __m512 s_after) {
|
||||
@@ -2611,8 +2613,10 @@ inline static __m512 ggml_v_gelu(__m512 x, __m512 c1, __m512 c2) {
|
||||
__m512 arg = _mm512_fmadd_ps(x, _mm512_mul_ps(c1, x), one);
|
||||
//__m512 arg = _mm512_add_ps(one, _mm512_mul_ps(_mm512_mul_ps(x, x), c1));
|
||||
arg = _mm512_mul_ps(arg, _mm512_mul_ps(c2, x));
|
||||
__m512 exp_arg = ggml_v_expf(arg);
|
||||
return _mm512_mul_ps(x, _mm512_div_ps(exp_arg, _mm512_add_ps(exp_arg, one)));
|
||||
const __mmask16 mask = _mm512_cmp_ps_mask(arg, _mm512_set1_ps(30.f), _CMP_GT_OQ);
|
||||
const __m512 exp_arg = ggml_v_expf(arg);
|
||||
const __m512 ratio = _mm512_div_ps(exp_arg, _mm512_add_ps(exp_arg, one));
|
||||
return _mm512_mul_ps(x, _mm512_mask_blend_ps(mask, ratio, one));
|
||||
}
|
||||
|
||||
#elif defined(__AVX2__) && defined(__FMA__)
|
||||
@@ -2673,14 +2677,17 @@ inline static __m256 ggml_v_silu(__m256 x) {
|
||||
inline static __m256 ggml_v_tanh(__m256 x) {
|
||||
const __m256 one = _mm256_set1_ps(1.0f);
|
||||
const __m256 exp_two_x = ggml_v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f)));
|
||||
return _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
|
||||
const __m256 res = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
|
||||
const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ);
|
||||
return _mm256_or_ps(_mm256_and_ps(mask, one), _mm256_andnot_ps(mask, res));
|
||||
}
|
||||
|
||||
inline static __m256 ggml_v_softcap(__m256 x, float s_before, float s_after) {
|
||||
const __m256 one = _mm256_set1_ps(1.0f);
|
||||
const __m256 exp_two_x = ggml_v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f*s_before)));
|
||||
const __m256 th = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
|
||||
return _mm256_mul_ps(th, _mm256_set1_ps(s_after));
|
||||
return _mm256_mul_ps(_mm256_set1_ps(s_after), ggml_v_tanh(_mm256_mul_ps(x, _mm256_set1_ps(s_before))));
|
||||
//const __m256 one = _mm256_set1_ps(1.0f);
|
||||
//const __m256 exp_two_x = ggml_v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f*s_before)));
|
||||
//const __m256 th = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
|
||||
//return _mm256_mul_ps(th, _mm256_set1_ps(s_after));
|
||||
}
|
||||
|
||||
inline static __m256 ggml_v_gelu(__m256 x, __m256 c1, __m256 c2) {
|
||||
|
||||
Reference in New Issue
Block a user