iqk_soft_max

With this, on top of ggml_mul_mat_ext, we hit PP-512 = 209 t/s (iq1_bn) and
PP-512 = 246 t/s (iq2_bn) on the M2 Max CPU.
On the Ryzen-7950X we are at PP-512 = 447 t/s (iq1_bn, 32 threads)
and PP-512 = 530 t/s (iq2_bn, 16 threads).
Iwan Kawrakow
2024-07-22 16:34:42 +02:00
parent 412bc31c75
commit 86d94862ae
3 changed files with 308 additions and 1 deletion

ggml.c

@@ -13920,6 +13920,8 @@ static void ggml_compute_forward_soft_max_f32(
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+    const size_t data_size = use_f16 ? sizeof(ggml_fp16_t) : sizeof(float);
     for (int i1 = ir0; i1 < ir1; i1++) {
         // ALiBi
         const uint32_t h = (i1/ne01)%ne02; // head
@@ -13929,6 +13931,10 @@ static void ggml_compute_forward_soft_max_f32(
         float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
         // broadcast the mask across rows
+        const char * mask = src1 ? (const char *) src1->data + (i1%ne01)*ne00*data_size : NULL;
+        if (iqk_soft_max(nc, sp, dp, wp, mask, scale, slope, use_f16)) {
+            continue;
+        }
         ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
         float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
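
When iqk_soft_max returns true the row has been fully processed and the generic code below it is skipped; on platforms without a vectorized path the stub at the end of this commit returns false and ggml falls back to the existing scalar loop. What the fast path computes per row is the usual (optionally ALiBi-biased) softmax; a minimal scalar sketch of the same three-pass scheme (an illustration, not code from this commit):

#include <math.h>

// Sketch of the per-row computation: dp = softmax(scale*sp + slope*bias).
static void soft_max_ref(int n, const float * sp, float * dp,
                         const float * bias, float scale, float slope) {
    float max = -INFINITY;
    for (int i = 0; i < n; ++i) {            // pass 1: scale + bias, track the max
        dp[i] = scale*sp[i] + (bias ? slope*bias[i] : 0.0f);
        if (dp[i] > max) max = dp[i];
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {            // pass 2: exponentiate, accumulate the sum
        dp[i] = expf(dp[i] - max);
        sum += dp[i];
    }
    const float norm = 1.0f/sum;
    for (int i = 0; i < n; ++i) {            // pass 3: normalize
        dp[i] *= norm;
    }
}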

iqk_mul_mat.cpp

@@ -4233,7 +4233,7 @@ template <int nrc> struct QF16 final : public QF16Base {
         for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const __fp16 *)(cx + iy*bx);
     }
     IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
-    IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4(y[iy] + k_step*i); }
+    IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4(y[iy] + 4*i); }
     const __fp16 * y[nrc_y];
 };
@@ -4656,3 +4656,303 @@ bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const
}
#endif
namespace {
// copied from ggml.c, which in turn was
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
#if defined(__AVX512F__) && defined(__AVX512DQ__)
inline __m512 v_expf(__m512 x) {
const __m512 r = _mm512_set1_ps(0x1.8p23f);
const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
const __m512 n = _mm512_sub_ps(z, r);
const __m512 b =
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
const __mmask16 d =
_mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
const __m512 u = _mm512_mul_ps(b, b);
const __m512 j = _mm512_fmadd_ps(
_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
_mm512_set1_ps(0x1.573e2ep-5f)),
u,
_mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
_mm512_set1_ps(0x1.fffdb6p-2f))),
u,
_mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
const __m512 res = _mm512_scalef_ps(j, n);
if (_mm512_kortestz(d, d))
return res;
const __m512 zero = _mm512_setzero_ps();
const __m512 alt = _mm512_mask_blend_ps(
_mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
return _mm512_mask_blend_ps(d, res, alt);
}
#endif
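
This AVX512 version and the AVX2/NEON variants below all implement the same scheme: range-reduce via e^x = 2^n * e^b with n = round(x*log2(e)), approximate e^b with a short polynomial, then scale by 2^n. A scalar rendering of the AVX512 form (a sketch only: it drops the overflow/underflow special-casing, uses no fused multiply-adds, and the AVX2/NEON variants fold the final +1 into the 2^n scaling instead):

#include <math.h>

// Scalar sketch of v_expf's core; not a drop-in replacement.
static inline float v_expf_scalar(float x) {
    const float r = 0x1.8p23f;               // shifter: adding it rounds to an integer
    const float z = x*0x1.715476p+0f + r;    // x * log2(e), rounded via the shifter
    const float n = z - r;                   // n = round(x * log2(e))
    // b = x - n*ln(2), with ln(2) split into high and low parts for accuracy
    const float b = x - n*0x1.62e4p-1f - n*0x1.7f7d1cp-20f;
    const float u = b*b;
    const float j = ((0x1.0e4020p-7f*b + 0x1.573e2ep-5f)*u    // polynomial for e^b
                  +  (0x1.555e66p-3f*b + 0x1.fffdb6p-2f))*u
                  +  (0x1.ffffecp-1f*b + 1.0f);
    return ldexpf(j, (int)n);                // e^x = 2^n * e^b
}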
#if defined(__AVX2__) && defined(__FMA__)
inline __m256 v_expf(__m256 x) {
const __m256 r = _mm256_set1_ps(0x1.8p23f);
const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
const __m256 n = _mm256_sub_ps(z, r);
const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
_mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
const __m256 k = _mm256_castsi256_ps(
_mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
const __m256i c = _mm256_castps_si256(
_mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
_mm256_set1_ps(126), _CMP_GT_OQ));
const __m256 u = _mm256_mul_ps(b, b);
const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
_mm256_set1_ps(0x1.573e2ep-5f)), u,
_mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
_mm256_set1_ps(0x1.fffdb6p-2f))),
u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
return _mm256_fmadd_ps(j, k, k);
const __m256i g = _mm256_and_si256(
_mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
_mm256_set1_epi32(0x82000000u));
const __m256 s1 =
_mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
const __m256i d = _mm256_castps_si256(
_mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
_mm256_set1_ps(192), _CMP_GT_OQ));
return _mm256_or_ps(
_mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
_mm256_andnot_ps(
_mm256_castsi256_ps(d),
_mm256_or_ps(
_mm256_and_ps(_mm256_castsi256_ps(c),
_mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
_mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
}
#endif
#ifdef __ARM_NEON
inline float32x4_t v_expf(float32x4_t x) {
const float32x4_t r = vdupq_n_f32(0x1.8p23f);
const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
const float32x4_t n = vsubq_f32(z, r);
const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
vdupq_n_f32(0x1.7f7d1cp-20f));
const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
const float32x4_t u = vmulq_f32(b, b);
const float32x4_t j = vfmaq_f32(
vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
return vfmaq_f32(k, j, k);
const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
}
#endif
#ifdef __AVX2__
inline float reduce_max(__m256 vmax) {
__m128 max4 = _mm_max_ps(_mm256_extractf128_ps(vmax, 1), _mm256_castps256_ps128(vmax));
max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
return _mm_cvtss_f32(max4);
}
#endif
}
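
The AVX2 soft-max below also calls hsum_float_8, which this commit does not add; it is the horizontal-sum counterpart of reduce_max that already exists elsewhere in ggml. For reference, a sketch of the usual implementation (assumed to match ggml's own definition):

#include <immintrin.h>

// Horizontal sum of the 8 floats in a __m256 (reference sketch).
static inline float hsum_float_8(const __m256 x) {
    __m128 res = _mm256_extractf128_ps(x, 1);
    res = _mm_add_ps(res, _mm256_castps256_ps128(x)); // fold 8 lanes -> 4
    res = _mm_add_ps(res, _mm_movehl_ps(res, res));   // fold 4 -> 2
    res = _mm_add_ss(res, _mm_movehdup_ps(res));      // fold 2 -> 1
    return _mm_cvtss_f32(res);
}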
#ifdef __ARM_NEON
bool iqk_soft_max(int n, const float * sp, float * dp, float * wp, const char * bias, float scale, float slope, bool bias_is_f16) {
(void)wp;
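    // wp (the scratch buffer used by the generic ggml path) is not needed here.
    // Three passes over the row, 4 floats at a time plus a scalar tail:
    //   1) dp[i] = scale*sp[i] (+ slope*bias[i]), tracking the running max
    //   2) dp[i] = exp(dp[i] - max), accumulating the sum
    //   3) dp[i] *= 1/sum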
const float32x4_t vscale = vdupq_n_f32(scale);
const float * xb = sp;
float max = -INFINITY;
if (bias) {
if (!bias_is_f16) {
const float * zb = (const float *)bias;
const float32x4_t vslope = vdupq_n_f32(slope);
float32x4_t vmax = vdupq_n_f32(-INFINITY);
int i = 0;
for (; i + 3 < n; i += 4) {
float32x4_t vx = vld1q_f32(xb + i);
float32x4_t vz = vld1q_f32(zb + i);
float32x4_t vy = vfmaq_f32(vmulq_f32(vslope, vz), vscale, vx);
vst1q_f32(dp + i, vy);
vmax = vmaxq_f32(vmax, vy);
}
max = vmaxvq_f32(vmax);
for (; i < n; ++i) {
dp[i] = scale*xb[i] + slope*zb[i];
max = MAX(max, dp[i]);
}
} else {
const __fp16 * zb = (const __fp16 *)bias;
float32x4_t vslope = vdupq_n_f32(slope);
float32x4_t vmax = vdupq_n_f32(-INFINITY);
int i = 0;
for (; i + 3 < n; i += 4) {
float32x4_t vx = vld1q_f32(xb + i);
float32x4_t vz = vcvt_f32_f16(vld1_f16(zb + i));
float32x4_t vy = vfmaq_f32(vmulq_f32(vslope, vz), vscale, vx);
vst1q_f32(dp + i, vy);
vmax = vmaxq_f32(vmax, vy);
}
max = vmaxvq_f32(vmax);
for (; i < n; ++i) {
dp[i] = scale*xb[i] + slope*GGML_FP16_TO_FP32(zb[i]);
max = MAX(max, dp[i]);
}
}
} else {
int i = 0;
float32x4_t vmax = vdupq_n_f32(-INFINITY);
for (; i + 3 < n; i += 4) {
float32x4_t vx = vld1q_f32(xb + i);
float32x4_t vy = vmulq_f32(vscale, vx);
vst1q_f32(dp + i, vy);
vmax = vmaxq_f32(vmax, vy);
}
max = vmaxvq_f32(vmax);
for (; i < n; ++i) {
dp[i] = scale*xb[i];
max = MAX(max, dp[i]);
}
}
float32x4_t vsum = vdupq_n_f32(0.f);
float32x4_t vmax = vdupq_n_f32(-max);
int i = 0;
for (; i + 3 < n; i += 4) {
float32x4_t v = vld1q_f32(dp + i);
v = v_expf(vaddq_f32(v, vmax));
vst1q_f32(dp + i, v);
vsum = vaddq_f32(vsum, v);
}
float sum = vaddvq_f32(vsum);
for (; i < n; ++i) {
float v = expf(dp[i] - max);
dp[i] = v;
sum += v;
}
float norm = 1.f/sum;
float32x4_t vnorm = vdupq_n_f32(norm);
i = 0;
for (; i + 3 < n; i += 4) {
float32x4_t v = vld1q_f32(dp + i);
v = vmulq_f32(v, vnorm);
vst1q_f32(dp + i, v);
}
for (; i < n; ++i) {
dp[i] *= norm;
}
return true;
}
#elif defined __AVX2__ && defined __FMA__
bool iqk_soft_max(int n, const float * sp, float * dp, float * wp, const char * bias, float scale, float slope, bool bias_is_f16) {
(void)wp;
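    // Same three-pass scheme as the NEON version above, 8 floats at a time;
    // the exp pass additionally uses a 16-wide AVX512 loop when available.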
const __m256 vscale = _mm256_set1_ps(scale);
const float * xb = sp;
float max = -INFINITY;
if (bias) {
const __m256 vslope = _mm256_set1_ps(slope);
__m256 vmax = _mm256_set1_ps(-INFINITY);
if (!bias_is_f16) {
const float * zb = (const float *)bias;
int i = 0;
for (; i + 7 < n; i += 8) {
__m256 vx = _mm256_loadu_ps(xb + i);
__m256 vz = _mm256_loadu_ps(zb + i);
__m256 vy = _mm256_fmadd_ps(vscale, vx, _mm256_mul_ps(vslope, vz));
_mm256_storeu_ps(dp + i, vy);
vmax = _mm256_max_ps(vmax, vy);
}
max = reduce_max(vmax);
for (; i < n; ++i) {
dp[i] = scale*xb[i] + slope*zb[i];
max = MAX(max, dp[i]);
}
} else {
const ggml_fp16_t * zb = (const ggml_fp16_t *)bias;
int i = 0;
for (; i + 7 < n; i += 8) {
__m256 vx = _mm256_loadu_ps(xb + i);
__m256 vz = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(zb + i)));
__m256 vy = _mm256_fmadd_ps(vscale, vx, _mm256_mul_ps(vslope, vz));
_mm256_storeu_ps(dp + i, vy);
vmax = _mm256_max_ps(vmax, vy);
}
max = reduce_max(vmax);
for (; i < n; ++i) {
dp[i] = scale*xb[i] + slope*GGML_FP16_TO_FP32(zb[i]);
max = MAX(max, dp[i]);
}
}
} else {
int i = 0;
__m256 vmax = _mm256_set1_ps(-INFINITY);
for (; i + 7 < n; i += 8) {
__m256 vx = _mm256_loadu_ps(xb + i);
__m256 vy = _mm256_mul_ps(vscale, vx);
_mm256_storeu_ps(dp + i, vy);
vmax = _mm256_max_ps(vmax, vy);
}
max = reduce_max(vmax);
for (; i < n; ++i) {
dp[i] = scale*xb[i];
max = MAX(max, dp[i]);
}
}
__m256 vsum = _mm256_setzero_ps();
__m256 vmax = _mm256_set1_ps(-max);
int i = 0;
#if defined __AVX512F__ && defined __AVX512DQ__
__m512 vsum512 = _mm512_setzero_ps();
__m512 vmax512 = _mm512_set1_ps(-max);
for (; i + 15 < n; i += 16) {
auto v = _mm512_loadu_ps(dp + i);
v = v_expf(_mm512_add_ps(v, vmax512));
_mm512_storeu_ps(dp + i, v);
vsum512 = _mm512_add_ps(vsum512, v);
}
vsum = _mm256_add_ps(_mm512_castps512_ps256(vsum512), _mm512_extractf32x8_ps(vsum512, 1));
#endif
for (; i + 7 < n; i += 8) {
__m256 v = _mm256_loadu_ps(dp + i);
v = v_expf(_mm256_add_ps(v, vmax));
_mm256_storeu_ps(dp + i, v);
vsum = _mm256_add_ps(vsum, v);
}
float sum = hsum_float_8(vsum);
for (; i < n; ++i) {
float v = expf(dp[i] - max);
dp[i] = v;
sum += v;
}
float norm = 1.f/sum;
__m256 vnorm = _mm256_set1_ps(norm);
i = 0;
for (; i + 7 < n; i += 8) {
__m256 v = _mm256_loadu_ps(dp + i);
v = _mm256_mul_ps(v, vnorm);
_mm256_storeu_ps(dp + i, v);
}
for (; i < n; ++i) {
dp[i] *= norm;
}
return true;
}
#else
bool iqk_soft_max(int, const float *, float *, float *, const char *, float, float, bool) {
return false;
}
#endif

iqk_mul_mat.h

@@ -22,6 +22,7 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
                      int typeB, const void * B, long strideB,
                      float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
+bool iqk_soft_max(int nc, const float * sp, float * dp, float * wp, const char * bias, float scale, float slope, bool bias_is_f16);
 #ifdef __cplusplus
 }