Simdify swiglu_oai

Turning it off for now as performance becomes more variable,
so perhaps I'm running into thermal throttling more often
because of making the CPU work too hard.
This commit is contained in:
Iwan Kawrakow
2025-08-12 19:40:35 +03:00
parent 8bd983300c
commit 2ac615507f
2 changed files with 55 additions and 1 deletions

View File

@@ -1000,6 +1000,31 @@ constexpr float k_swiglu_oai_alpha = 1.702f;
constexpr float k_swiglu_oai_limit = 7.f;
void MulMat::swiglu_oai(int n, const float * x, float * y) {
// SwiGLU-OAI activation applied element-wise over n floats:
//   xc   = min(x[i], k_swiglu_oai_limit)            (clamp from above)
//   y[i] = xc / (1 + exp(-xc * k_swiglu_oai_alpha))
// The AVX512/AVX2 paths below are deliberately disabled for now — per the
// commit message, performance became more variable with them on (possibly
// thermal throttling) — leaving only the scalar loop active.
// int i = 0;
//#if defined __AVX512F__ && defined __AVX512DQ__
// {
// auto max = _mm512_set1_ps(k_swiglu_oai_limit);
// auto alpha = _mm512_set1_ps(-k_swiglu_oai_alpha);
// for (; i + 15 < n; i += 16) {
// auto xc = v_clamp_max(_mm512_loadu_ps(x + i), max);
// _mm512_storeu_ps(y + i, v_silu_oai(xc, alpha));
// }
// }
//#endif
//#if defined __AVX2__ && defined __FMA__
// if (i + 7 < n) {
// auto max = _mm256_set1_ps(k_swiglu_oai_limit);
// auto alpha = _mm256_set1_ps(-k_swiglu_oai_alpha);
// for (; i + 7 < n; i += 8) {
// auto xc = v_clamp_max(_mm256_loadu_ps(x + i), max);
// _mm256_storeu_ps(y + i, v_silu_oai(xc, alpha));
// }
// }
//#endif
// for (; i < n; ++i) {
// auto xi = std::min(x[i], k_swiglu_oai_limit);
// y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha));
// }
// Scalar path — currently the only active implementation.
for (int i = 0; i < n; ++i) {
auto xi = std::min(x[i], k_swiglu_oai_limit);
y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha));

View File

@@ -61,6 +61,13 @@ static inline float32x4_t v_silu(float32x4_t x) {
const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
return vdivq_f32(x, one_plus_exp_neg_x);
}
// SwiGLU-OAI weighting on 4 packed floats: x / (1 + v_expf(alpha * x)).
// Callers pass alpha already negated (see the -k_swiglu_oai_alpha splats
// in MulMat::swiglu_oai), so this evaluates x * sigmoid(alpha_pos * x).
// NOTE(review): assumes v_expf is a per-lane expf — defined elsewhere in
// this file; confirm against its definition.
static inline float32x4_t v_silu_oai(float32x4_t x, float32x4_t alpha) {
    const float32x4_t denom = vaddq_f32(vdupq_n_f32(1.0f), v_expf(vmulq_f32(alpha, x)));
    return vdivq_f32(x, denom);
}
static inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) {
const float32x4_t one = vdupq_n_f32(1.0f);
float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x));
@@ -131,6 +138,17 @@ static inline __m512 v_silu(__m512 x) {
const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
return _mm512_div_ps(x, one_plus_exp_neg_x);
}
// SwiGLU-OAI weighting on 16 packed floats: x / (1 + v_expf(alpha * x)).
// Callers pass alpha already negated (see the -k_swiglu_oai_alpha splats
// in MulMat::swiglu_oai).
static inline __m512 v_silu_oai(__m512 x, __m512 alpha) {
    const __m512 denom = _mm512_add_ps(_mm512_set1_ps(1.0f), v_expf(_mm512_mul_ps(alpha, x)));
    return _mm512_div_ps(x, denom);
}
// Per-lane clamp from above: lanes where x > max are replaced by max,
// all other lanes keep x. NaN lanes compare false under _CMP_GT_OQ and
// therefore pass through unchanged (unlike _mm512_min_ps, which would
// return max for a NaN lane — hence the explicit mask/blend).
static inline __m512 v_clamp_max(__m512 x, __m512 max) {
    const __mmask16 gt = _mm512_cmp_ps_mask(x, max, _CMP_GT_OQ);
    return _mm512_mask_blend_ps(gt, x, max);
}
#endif // __AVX512__
#if defined(__AVX2__) && defined(__FMA__)
@@ -195,12 +213,23 @@ static inline __m256 v_gelu(__m256 x, __m256 c1, __m256 c2) {
}
// Standard SiLU on 8 packed floats: x / (1 + v_expf(-x)), i.e. x * sigmoid(x).
// Fix: the `zero` declaration appeared twice in a row, which is a
// redefinition of a const local and does not compile; keep a single copy.
static inline __m256 v_silu(__m256 x) {
    const __m256 one = _mm256_set1_ps(1);
    const __m256 zero = _mm256_setzero_ps();
    const __m256 neg_x = _mm256_sub_ps(zero, x);
    const __m256 exp_neg_x = v_expf(neg_x);
    const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
    return _mm256_div_ps(x, one_plus_exp_neg_x);
}
// SwiGLU-OAI weighting on 8 packed floats: x / (1 + v_expf(alpha * x)).
// Callers pass alpha already negated (see the -k_swiglu_oai_alpha splats
// in MulMat::swiglu_oai).
static inline __m256 v_silu_oai(__m256 x, __m256 alpha) {
    const __m256 denom = _mm256_add_ps(_mm256_set1_ps(1.0f), v_expf(_mm256_mul_ps(alpha, x)));
    return _mm256_div_ps(x, denom);
}
// Per-lane clamp from above: lanes where x > max take max, others keep x.
// _mm256_blendv_ps selects per lane on the mask's sign bit; since the cmp
// result is all-ones or all-zeros per lane it is exactly equivalent to the
// and/andnot/or combine it replaces, in a single instruction. NaN lanes
// compare false under _CMP_GT_OQ and pass through unchanged, as before.
static inline __m256 v_clamp_max(__m256 x, __m256 max) {
    const __m256 gt = _mm256_cmp_ps(x, max, _CMP_GT_OQ);
    return _mm256_blendv_ps(x, max, gt);
}
#endif // __AVX2__