mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-04 19:10:03 +00:00
WIP softmax: ~3% gain on Zen4
This commit is contained in:
184
ggml/src/ggml.c
184
ggml/src/ggml.c
@@ -2278,6 +2278,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
|
||||
#endif
|
||||
}
|
||||
|
||||
// Element-wise multiply-add into a separate destination:
//     z[i] = y[i] + x[i]*v      for i in [0, n)
//
//   n - number of elements to process
//   z - destination, at least n floats
//   y - addend,      at least n floats (read only)
//   x - multiplicand, at least n floats (read only)
//   v - scalar applied to every element of x
//
// NOTE(review): x and y are restrict-qualified and read while z is written, so
// z must not overlap either of them even though z itself is not marked restrict.
inline static void ggml_vec_mad_set_f32(const int n, float * z, const float * restrict y, const float * restrict x, const float v) {
#if defined(GGML_SIMD)
    // largest multiple of GGML_F32_STEP that fits in n
    // (the mask trick assumes GGML_F32_STEP is a power of two)
    const int np = (n & ~(GGML_F32_STEP - 1));

    // v broadcast to every lane
    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);

    GGML_F32_VEC ax[GGML_F32_ARR];
    GGML_F32_VEC ay[GGML_F32_ARR];

    // each outer iteration handles GGML_F32_ARR vectors of GGML_F32_EPR floats
    for (int i = 0; i < np; i += GGML_F32_STEP) {
        for (int j = 0; j < GGML_F32_ARR; j++) {
            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
            // presumably ay + ax*vx, matching the scalar loops below
            ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);

            GGML_F32_VEC_STORE(z + i + j*GGML_F32_EPR, ay[j]);
        }
    }

    // leftovers
    for (int i = np; i < n; ++i) {
        z[i] = y[i] + x[i]*v;
    }
#else
    // scalar
    for (int i = 0; i < n; ++i) {
        z[i] = y[i] + x[i]*v;
    }
#endif
}
|
||||
|
||||
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
|
||||
#if defined(GGML_SIMD)
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
@@ -2588,6 +2619,27 @@ inline static __m512 ggml_v_expf(__m512 x) {
|
||||
_mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
|
||||
return _mm512_mask_blend_ps(d, res, alt);
|
||||
}
|
||||
inline static __m512 ggml_v_expf_fast(__m512 x) {
|
||||
const __m512 r = _mm512_set1_ps(0x1.8p23f);
|
||||
const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
|
||||
const __m512 n = _mm512_sub_ps(z, r);
|
||||
const __m512 b =
|
||||
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
|
||||
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
|
||||
const __mmask16 d =
|
||||
_mm512_cmp_ps_mask(n, _mm512_set1_ps(-192), _CMP_LE_OQ);
|
||||
const __m512 u = _mm512_mul_ps(b, b);
|
||||
const __m512 j = _mm512_fmadd_ps(
|
||||
_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
|
||||
_mm512_set1_ps(0x1.573e2ep-5f)),
|
||||
u,
|
||||
_mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
|
||||
_mm512_set1_ps(0x1.fffdb6p-2f))),
|
||||
u,
|
||||
_mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
|
||||
const __m512 res = _mm512_scalef_ps(j, n);
|
||||
return _mm512_mask_blend_ps(d, res, _mm512_setzero_ps());
|
||||
}
|
||||
|
||||
// computes silu x/(1+exp(-x)) in single precision vector
|
||||
inline static __m512 ggml_v_silu(__m512 x) {
|
||||
@@ -14452,6 +14504,31 @@ static void ggml_compute_forward_soft_max_f32(
|
||||
|
||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
||||
|
||||
//if (ith == 0) printf("%s: nc = %d, nr = %d, use_f16 = %d, max_bias = %g, src1 = %d\n", __func__, nc, nr, use_f16, max_bias, src1 ? 1 : 0);
|
||||
|
||||
//if (!use_f16 && max_bias <= 0) {
|
||||
// for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
// float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
||||
// float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
||||
// if (src1) {
|
||||
// const float * mp_f32 = (const float *)src1->data + (i1%ne01)*ne00;
|
||||
// ggml_vec_mad_set_f32(nc, wp, mp_f32, sp, scale);
|
||||
// } else {
|
||||
// ggml_vec_cpy_f32 (nc, wp, sp);
|
||||
// ggml_vec_scale_f32(nc, wp, scale);
|
||||
// }
|
||||
// float max = -INFINITY;
|
||||
// ggml_vec_max_f32(nc, &max, wp);
|
||||
|
||||
// ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
|
||||
// assert(sum > 0.0);
|
||||
|
||||
// sum = 1.0/sum;
|
||||
// ggml_vec_scale_f32(nc, dp, sum);
|
||||
// }
|
||||
// return;
|
||||
//}
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
// ALiBi
|
||||
const uint32_t h = (i1/ne01)%ne02; // head
|
||||
@@ -14460,6 +14537,113 @@ static void ggml_compute_forward_soft_max_f32(
|
||||
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
||||
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
||||
|
||||
if (src1 && !use_f16 && nc%16 == 0 && max_bias <= 0) {
|
||||
const float * mp_32 = (const float *)src1->data + (i1%ne01)*ne00;
|
||||
__m512 vscale = _mm512_set1_ps(scale);
|
||||
__m512 vmax = _mm512_fmadd_ps(vscale, _mm512_loadu_ps(sp), _mm512_loadu_ps(mp_32));
|
||||
_mm512_storeu_ps(dp, vmax);
|
||||
for (int j = 1; j < nc/16; ++j) {
|
||||
__m512 v = _mm512_fmadd_ps(vscale, _mm512_loadu_ps(sp + 16*j), _mm512_loadu_ps(mp_32 + 16*j));
|
||||
_mm512_storeu_ps(dp + 16*j, v);
|
||||
vmax = _mm512_max_ps(vmax, v);
|
||||
}
|
||||
float max = _mm512_reduce_max_ps(vmax);
|
||||
vmax = _mm512_set1_ps(-max);
|
||||
__m512 vsum = ggml_v_expf(_mm512_add_ps(_mm512_loadu_ps(dp), vmax));
|
||||
_mm512_storeu_ps(dp, vsum);
|
||||
for (int j = 1; j < nc/16; ++j) {
|
||||
__m512 v = ggml_v_expf_fast(_mm512_add_ps(_mm512_loadu_ps(dp + 16*j), vmax));
|
||||
_mm512_storeu_ps(dp + 16*j, v);
|
||||
vsum = _mm512_add_ps(vsum, v);
|
||||
}
|
||||
float sum = _mm512_reduce_add_ps(vsum);
|
||||
//float sum;
|
||||
//if (max < 16.f) {
|
||||
// __m512 vsum = ggml_v_expf(_mm512_loadu_ps(dp));
|
||||
// _mm512_storeu_ps(dp, vsum);
|
||||
// for (int j = 1; j < nc/16; ++j) {
|
||||
// __m512 v = ggml_v_expf_fast(_mm512_loadu_ps(dp + 16*j));
|
||||
// _mm512_storeu_ps(dp + 16*j, v);
|
||||
// vsum = _mm512_add_ps(vsum, v);
|
||||
// }
|
||||
// sum = _mm512_reduce_add_ps(vsum);
|
||||
//} else {
|
||||
// vmax = _mm512_set1_ps(-max);
|
||||
// __m512 vsum = ggml_v_expf(_mm512_add_ps(_mm512_loadu_ps(dp), vmax));
|
||||
// _mm512_storeu_ps(dp, vsum);
|
||||
// for (int j = 1; j < nc/16; ++j) {
|
||||
// __m512 v = ggml_v_expf_fast(_mm512_add_ps(_mm512_loadu_ps(dp + 16*j), vmax));
|
||||
// _mm512_storeu_ps(dp + 16*j, v);
|
||||
// vsum = _mm512_add_ps(vsum, v);
|
||||
// }
|
||||
// sum = _mm512_reduce_add_ps(vsum);
|
||||
//}
|
||||
////GGML_ASSERT(sum > 0);
|
||||
__m512 norm = _mm512_set1_ps(1/sum);
|
||||
for (int j = 0; j < nc/16; ++j) {
|
||||
__m512 v = _mm512_mul_ps(norm, _mm512_loadu_ps(dp + 16*j));
|
||||
_mm512_storeu_ps(dp + 16*j, v);
|
||||
}
|
||||
|
||||
//_mm512_storeu_ps(wp, vmax);
|
||||
//for (int j = 1; j < nc/16; ++j) {
|
||||
// __m512 v = _mm512_fmadd_ps(vscale, _mm512_loadu_ps(sp + 16*j), _mm512_loadu_ps(mp_32 + 16*j));
|
||||
// _mm512_storeu_ps(wp + 16*j, v);
|
||||
// vmax = _mm512_max_ps(vmax, v);
|
||||
//}
|
||||
//float max = _mm512_reduce_max_ps(vmax);
|
||||
////if (max == -INFINITY) printf("Oops: max is -infinity?\n");
|
||||
//vmax = _mm512_set1_ps(-max);
|
||||
//__m512 vsum = ggml_v_expf(_mm512_add_ps(_mm512_loadu_ps(wp), vmax));
|
||||
//_mm512_storeu_ps(wp, vsum);
|
||||
//for (int j = 1; j < nc/16; ++j) {
|
||||
// __m512 v = ggml_v_expf_fast(_mm512_add_ps(_mm512_loadu_ps(wp + 16*j), vmax));
|
||||
// _mm512_storeu_ps(wp + 16*j, v);
|
||||
// vsum = _mm512_add_ps(vsum, v);
|
||||
//}
|
||||
//float sum = _mm512_reduce_add_ps(vsum);
|
||||
////GGML_ASSERT(sum > 0);
|
||||
//__m512 norm = _mm512_set1_ps(1/sum);
|
||||
//for (int j = 0; j < nc/16; ++j) {
|
||||
// __m512 v = _mm512_mul_ps(norm, _mm512_loadu_ps(wp + 16*j));
|
||||
// _mm512_storeu_ps(dp + 16*j, v);
|
||||
//}
|
||||
|
||||
//if (mp_32[0] == -INFINITY) {
|
||||
// memset(sp, 0, nc*sizeof(float));
|
||||
// continue;
|
||||
//}
|
||||
//__m512 vscale = _mm512_set1_ps(scale);
|
||||
//__m512 vmax = _mm512_fmadd_ps(vscale, _mm512_loadu_ps(sp), _mm512_loadu_ps(mp_32));
|
||||
//_mm512_storeu_ps(wp, vmax);
|
||||
//int jj = 1;
|
||||
//for (; jj < nc/16; ++jj) {
|
||||
// if (mp_32[16*jj] == -INFINITY) break;
|
||||
// __m512 v = _mm512_fmadd_ps(vscale, _mm512_loadu_ps(sp + 16*jj), _mm512_loadu_ps(mp_32 + 16*jj));
|
||||
// _mm512_storeu_ps(wp + 16*jj, v);
|
||||
// vmax = _mm512_max_ps(vmax, v);
|
||||
//}
|
||||
//float max = _mm512_reduce_max_ps(vmax);
|
||||
//vmax = _mm512_set1_ps(-max);
|
||||
//__m512 vsum = ggml_v_expf(_mm512_add_ps(_mm512_loadu_ps(wp), vmax));
|
||||
//_mm512_storeu_ps(wp, vsum);
|
||||
//for (int j = 1; j < jj; ++j) {
|
||||
// __m512 v = ggml_v_expf(_mm512_add_ps(_mm512_loadu_ps(wp + 16*j), vmax));
|
||||
// _mm512_storeu_ps(wp + 16*j, v);
|
||||
// vsum = _mm512_add_ps(vsum, v);
|
||||
//}
|
||||
//float sum = _mm512_reduce_add_ps(vsum);
|
||||
//__m512 norm = _mm512_set1_ps(1/sum);
|
||||
//for (int j = 0; j < jj; ++j) {
|
||||
// __m512 v = _mm512_mul_ps(norm, _mm512_loadu_ps(wp + 16*j));
|
||||
// _mm512_storeu_ps(dp + 16*j, v);
|
||||
//}
|
||||
//if (jj < nc/16) {
|
||||
// memset(dp + 16*jj, 0, (nc - 16*jj)*sizeof(float));
|
||||
//}
|
||||
continue;
|
||||
}
|
||||
|
||||
// broadcast the mask across rows
|
||||
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
|
||||
float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
|
||||
|
||||
Reference in New Issue
Block a user