// Patch overview (review notes). This preamble sits ahead of the first
// "diff --git" header and is not part of any hunk.
//
// NOTE(review): as stored here the patch's line breaks appear to have been
// collapsed into spaces (all hunks sit on two physical lines); the original
// line structure would have to be restored before the patch can be applied.
//
// Files touched: ggml/src/iqk/iqk_mul_mat.cpp and ggml/src/iqk/iqk_utils.h.
//
// 1. iqk_mul_mat.cpp -- MulMat::swiglu_oai (hunk -1000,6 +1000,31)
//    * All 25 added lines begin with "//", i.e. they are commented-out code.
//      The pre-existing scalar loop is unchanged hunk context and stays the
//      only active path:
//          xi   = std::min(x[i], k_swiglu_oai_limit)          // 7.f
//          y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha)) // 1.702f
//    * The commented-out draft has three stages sharing one index i: an
//      AVX512F+AVX512DQ loop over 16 floats per step, an AVX2+FMA loop over
//      8 floats per step, then a scalar tail for the remaining elements.
//    * The draft builds alpha as set1(-k_swiglu_oai_alpha), i.e. the caller
//      is expected to pass the NEGATED constant to v_silu_oai.
//    * The end of swiglu_oai lies past the hunk's trailing context.
//    * NOTE(review): within the hunks shown here, v_silu_oai / v_clamp_max
//      are referenced only from this commented-out code.
//
// 2. iqk_utils.h -- v_silu_oai(x, alpha), added for float32x4_t (ARM NEON
//    intrinsics), __m512 and __m256
//    * Each overload returns x / (1 + v_expf(alpha * x)), mirroring the
//      neighbouring v_silu but with the multiplier supplied by the caller
//      instead of a fixed negation.
//    * v_expf is defined outside these hunks -- presumably a lane-wise
//      exponential; confirm against its definition.
//    * This equals the scalar expression above only when alpha is
//      -k_swiglu_oai_alpha and x has already been clamped.
//
// 3. iqk_utils.h -- v_clamp_max(x, max), added for __m512 and __m256
//    * Lane-wise: result is max where x > max, otherwise x.
//    * __m512: compare to a mask register, then _mm512_mask_blend_ps.
//      __m256: full-width compare mask combined with and / andnot / or.
//    * The predicate is _CMP_GT_OQ (ordered), so a lane where x is NaN
//      compares false and x is passed through unchanged.
//    * No float32x4_t overload of v_clamp_max is added in the hunks shown.
//
// 4. iqk_utils.h -- v_silu(__m256)
//    * The "-" / "+" pair for the `zero` line reads the same here --
//      presumably a whitespace-only change; verify against the original.
//
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index eb0ca056..041ad165 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -1000,6 +1000,31 @@ constexpr float k_swiglu_oai_alpha = 1.702f; constexpr float k_swiglu_oai_limit = 7.f; void MulMat::swiglu_oai(int n, const float * x, float * y) { +// int i = 0; +//#if defined __AVX512F__ && defined __AVX512DQ__ +// { +// auto max = _mm512_set1_ps(k_swiglu_oai_limit); +// auto alpha = _mm512_set1_ps(-k_swiglu_oai_alpha); +// for (; i + 15 < n; i += 16) { +// auto xc = v_clamp_max(_mm512_loadu_ps(x + i), max); +// _mm512_storeu_ps(y + i, v_silu_oai(xc, alpha)); +// } +// } +//#endif +//#if defined __AVX2__ && defined __FMA__ +// if (i + 7 < n) { +// auto max = _mm256_set1_ps(k_swiglu_oai_limit); +// auto alpha = _mm256_set1_ps(-k_swiglu_oai_alpha); +// for (; i + 7 < n; i += 8) { +// auto xc = v_clamp_max(_mm256_loadu_ps(x + i), max); +// _mm256_storeu_ps(y + i, v_silu_oai(xc, alpha)); +// } +// } +//#endif +// for (; i < n; ++i) { +// auto xi = std::min(x[i], k_swiglu_oai_limit); +// y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha)); +// } for (int i = 0; i < n; ++i) { auto xi = std::min(x[i], k_swiglu_oai_limit); y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha)); diff --git a/ggml/src/iqk/iqk_utils.h b/ggml/src/iqk/iqk_utils.h index 194bf9b8..435ae4dd 100644 --- a/ggml/src/iqk/iqk_utils.h +++ b/ggml/src/iqk/iqk_utils.h @@ -61,6 +61,13 @@ static inline float32x4_t v_silu(float32x4_t x) { const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); return vdivq_f32(x, one_plus_exp_neg_x); } +static inline float32x4_t v_silu_oai(float32x4_t x, float32x4_t alpha) { + const float32x4_t one = vdupq_n_f32(1.0f); + const float32x4_t neg_x = vmulq_f32(alpha, x); + const float32x4_t exp_neg_x = v_expf(neg_x); + const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); + return vdivq_f32(x, one_plus_exp_neg_x); +} static inline float32x4_t v_gelu(float32x4_t x, 
float32x4_t c1, float32x4_t c2) { const float32x4_t one = vdupq_n_f32(1.0f); float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x)); @@ -131,6 +138,17 @@ static inline __m512 v_silu(__m512 x) { const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); return _mm512_div_ps(x, one_plus_exp_neg_x); } +static inline __m512 v_silu_oai(__m512 x, __m512 alpha) { + const __m512 one = _mm512_set1_ps(1); + const __m512 neg_x = _mm512_mul_ps(alpha, x); + const __m512 exp_neg_x = v_expf(neg_x); + const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); + return _mm512_div_ps(x, one_plus_exp_neg_x); +} +static inline __m512 v_clamp_max(__m512 x, __m512 max) { + auto mask = _mm512_cmp_ps_mask(x, max, _CMP_GT_OQ); + return _mm512_mask_blend_ps(mask, x, max); +} #endif // __AVX512__ #if defined(__AVX2__) && defined(__FMA__) @@ -195,12 +213,23 @@ static inline __m256 v_gelu(__m256 x, __m256 c1, __m256 c2) { } static inline __m256 v_silu(__m256 x) { const __m256 one = _mm256_set1_ps(1); - const __m256 zero = _mm256_setzero_ps(); + const __m256 zero = _mm256_setzero_ps(); const __m256 neg_x = _mm256_sub_ps(zero, x); const __m256 exp_neg_x = v_expf(neg_x); const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); return _mm256_div_ps(x, one_plus_exp_neg_x); } +static inline __m256 v_silu_oai(__m256 x, __m256 alpha) { + const __m256 one = _mm256_set1_ps(1); + const __m256 neg_x = _mm256_mul_ps(alpha, x); + const __m256 exp_neg_x = v_expf(neg_x); + const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); + return _mm256_div_ps(x, one_plus_exp_neg_x); +} +static inline __m256 v_clamp_max(__m256 x, __m256 max) { + auto mask = _mm256_cmp_ps(x, max, _CMP_GT_OQ); + return _mm256_or_ps(_mm256_and_ps(mask, max), _mm256_andnot_ps(mask, x)); +} #endif // __AVX2__