Adding forgotten file

This commit is contained in:
Iwan Kawrakow
2025-05-19 13:42:20 +03:00
parent 630279cb54
commit fbfe79e2fe

207
ggml/src/iqk/iqk_utils.h Normal file
View File

@@ -0,0 +1,207 @@
#pragma once
#include "iqk_config.h"
#if defined IQK_IMPLEMENT
#include "ggml-impl.h"
#if defined(__ARM_NEON) && defined(__aarch64__)
// copy-pasted from Justine Tunney's contribution to llama.cpp
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
static inline float32x4_t v_expf(float32x4_t x) {
const float32x4_t r = vdupq_n_f32(0x1.8p23f);
const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
const float32x4_t n = vsubq_f32(z, r);
const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
vdupq_n_f32(0x1.7f7d1cp-20f));
const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
const float32x4_t u = vmulq_f32(b, b);
const float32x4_t j = vfmaq_f32(
vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
return vfmaq_f32(k, j, k);
const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
}
static inline float16x8_t v_expf(float16x8_t x) {
auto val1 = v_expf(vcvt_f32_f16(vget_low_f16(x)));
auto val2 = v_expf(vcvt_f32_f16(vget_high_f16(x)));
return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
}
static inline float32x4_t v_tanh(float32x4_t x) {
const float32x4_t one = vdupq_n_f32(1.0f);
const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f));
const float32x4_t exp_two_x = v_expf(two_x);
const uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f));
const float32x4_t res = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask)));
//return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
}
//inline float32x4_t v_tanh(float16x8_t x) {
// auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x)));
// auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x)));
// return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
//}
static inline float32x4_t v_silu(float32x4_t x) {
const float32x4_t one = vdupq_n_f32(1.0f);
const float32x4_t zero = vdupq_n_f32(0.0f);
const float32x4_t neg_x = vsubq_f32(zero, x);
const float32x4_t exp_neg_x = v_expf(neg_x);
const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
return vdivq_f32(x, one_plus_exp_neg_x);
}
static inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) {
const float32x4_t one = vdupq_n_f32(1.0f);
float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x));
arg = vmulq_f32(arg, vmulq_f32(x, c2));
float32x4_t exp_arg = v_expf(arg);
float32x4_t gelu = vmulq_f32(x, vdivq_f32(exp_arg, vaddq_f32(exp_arg, one)));
uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f));
return vbslq_f32(mask, x, gelu);
}
#endif // __ARN_NEON
#if defined(__AVX512F__) && defined(__AVX512DQ__)
// copy-pasted from Justine Tunney's contribution to llama.cpp
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
static inline __m512 v_expf(__m512 x) {
const __m512 r = _mm512_set1_ps(0x1.8p23f);
const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
const __m512 n = _mm512_sub_ps(z, r);
const __m512 b =
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
const __mmask16 d =
_mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
const __m512 u = _mm512_mul_ps(b, b);
const __m512 j = _mm512_fmadd_ps(
_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
_mm512_set1_ps(0x1.573e2ep-5f)),
u,
_mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
_mm512_set1_ps(0x1.fffdb6p-2f))),
u,
_mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
const __m512 res = _mm512_scalef_ps(j, n);
if (_mm512_kortestz(d, d))
return res;
const __m512 zero = _mm512_setzero_ps();
const __m512 alt = _mm512_mask_blend_ps(
_mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
return _mm512_mask_blend_ps(d, res, alt);
}
static inline __m512 v_tanh(__m512 x) {
const __m512 one = _mm512_set1_ps(1.0f);
const __m512 exp_two_x = v_expf(_mm512_mul_ps(x, _mm512_set1_ps(2.f)));
const __mmask16 mask = _mm512_cmp_ps_mask(x, _mm512_set1_ps(10.f), _CMP_GT_OQ);
const __m512 res = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
return _mm512_mask_blend_ps(mask, res, one);
}
static inline __m512 v_gelu(__m512 x, __m512 c1, __m512 c2) {
const __m512 one = _mm512_set1_ps(1.0f);
__m512 arg = _mm512_fmadd_ps(x, _mm512_mul_ps(c1, x), one);
//__m512 arg = _mm512_add_ps(one, _mm512_mul_ps(_mm512_mul_ps(x, x), c1));
arg = _mm512_mul_ps(arg, _mm512_mul_ps(c2, x));
const __mmask16 mask = _mm512_cmp_ps_mask(arg, _mm512_set1_ps(30.f), _CMP_GT_OQ);
const __m512 exp_arg = v_expf(arg);
const __m512 ratio = _mm512_div_ps(exp_arg, _mm512_add_ps(exp_arg, one));
return _mm512_mul_ps(x, _mm512_mask_blend_ps(mask, ratio, one));
}
static inline __m512 v_silu(__m512 x) {
const __m512 one = _mm512_set1_ps(1);
const __m512 zero = _mm512_setzero_ps();
const __m512 neg_x = _mm512_sub_ps(zero, x);
const __m512 exp_neg_x = v_expf(neg_x);
const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
return _mm512_div_ps(x, one_plus_exp_neg_x);
}
#endif // __AVX512__
#if defined(__AVX2__) && defined(__FMA__)
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
static inline __m256 v_expf(__m256 x) {
const __m256 r = _mm256_set1_ps(0x1.8p23f);
const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
const __m256 n = _mm256_sub_ps(z, r);
const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
_mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
const __m256 k = _mm256_castsi256_ps(
_mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
const __m256i c = _mm256_castps_si256(
_mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
_mm256_set1_ps(126), _CMP_GT_OQ));
const __m256 u = _mm256_mul_ps(b, b);
const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
_mm256_set1_ps(0x1.573e2ep-5f)), u,
_mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
_mm256_set1_ps(0x1.fffdb6p-2f))),
u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
return _mm256_fmadd_ps(j, k, k);
const __m256i g = _mm256_and_si256(
_mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
_mm256_set1_epi32(0x82000000u));
const __m256 s1 =
_mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
const __m256i d = _mm256_castps_si256(
_mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
_mm256_set1_ps(192), _CMP_GT_OQ));
return _mm256_or_ps(
_mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
_mm256_andnot_ps(
_mm256_castsi256_ps(d),
_mm256_or_ps(
_mm256_and_ps(_mm256_castsi256_ps(c),
_mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
_mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
}
static inline __m256 v_tanh(__m256 x) {
const __m256 one = _mm256_set1_ps(1.0f);
const __m256 exp_two_x = v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f)));
const __m256 res = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ);
return _mm256_or_ps(_mm256_and_ps(mask, one), _mm256_andnot_ps(mask, res));
}
static inline __m256 v_gelu(__m256 x, __m256 c1, __m256 c2) {
const __m256 one = _mm256_set1_ps(1.0f);
const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ);
__m256 arg = _mm256_add_ps(one, _mm256_mul_ps(_mm256_mul_ps(x, x), c1));
arg = _mm256_mul_ps(arg, _mm256_mul_ps(x, c2));
__m256 exp_arg = v_expf(arg);
__m256 gelu = _mm256_mul_ps(x, _mm256_div_ps(exp_arg, _mm256_add_ps(exp_arg, one)));
return _mm256_or_ps(_mm256_and_ps(mask, x), _mm256_andnot_ps(mask, gelu));
}
static inline __m256 v_silu(__m256 x) {
const __m256 one = _mm256_set1_ps(1);
const __m256 zero = _mm256_setzero_ps();
const __m256 neg_x = _mm256_sub_ps(zero, x);
const __m256 exp_neg_x = v_expf(neg_x);
const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
return _mm256_div_ps(x, one_plus_exp_neg_x);
}
#endif // __AVX2__
#endif // IQK_IMPLEMENT