From 86d94862ae066dbbe63f1773ffd7d94f17c7fa46 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Mon, 22 Jul 2024 16:34:42 +0200
Subject: [PATCH] iqk_soft_max

With this and ggml_mul_mat_ext, we hit PP-512 = 209 t/s (iq1_bn) and
PP-512 = 246 t/s (iq2_bn) on the M2 Max CPU. On the Ryzen-7950X we are
at PP-512 = 447 t/s (iq1_bn, 32 threads) and PP-512 = 530 t/s (iq2_bn,
16 threads).
---
 ggml.c          |   6 +
 iqk_mul_mat.cpp | 302 +++++++++++++++++++++++++++++++++++++++++++++++-
 iqk_mul_mat.h   |   1 +
 3 files changed, 308 insertions(+), 1 deletion(-)

diff --git a/ggml.c b/ggml.c
index 776f1d43..a57d7844 100644
--- a/ggml.c
+++ b/ggml.c
@@ -13920,6 +13920,8 @@ static void ggml_compute_forward_soft_max_f32(

     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);

+    const size_t data_size = use_f16 ? sizeof(ggml_fp16_t) : sizeof(float);
+
     for (int i1 = ir0; i1 < ir1; i1++) {
         // ALiBi
         const uint32_t h = (i1/ne01)%ne02; // head
@@ -13929,6 +13931,10 @@ static void ggml_compute_forward_soft_max_f32(
         float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);

         // broadcast the mask across rows
+        const char * mask = src1 ? (const char *) src1->data + (i1%ne01)*ne00*data_size : NULL;
+        if (iqk_soft_max(nc, sp, dp, wp, mask, scale, slope, use_f16)) {
+            continue;
+        }
         ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
         float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;

diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index b8ca8c84..aed3ecca 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -4233,7 +4233,7 @@ template <int nrc_y> struct QF16 final : public QF16Base {
         for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const __fp16 *)(cx + iy*bx);
     }
     IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
-    IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4(y[iy] + k_step*i); }
+    IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4(y[iy] + 4*i); }
     const __fp16 * y[nrc_y];
 };

@@ -4656,3 +4656,303 @@ bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const
 }

 #endif
+
+namespace {
+// copied from ggml.c, which in turn was
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
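+//
+// All three variants below use the same range reduction: with
+// n = round(x/ln(2)) and b = x - n*ln(2), exp(x) = 2^n * exp(b).
+// n comes from the round-to-nearest trick (add, then subtract, 0x1.8p23f),
+// ln(2) is split into a high part (0x1.62e4p-1f) and a low part
+// (0x1.7f7d1cp-20f) so that b stays accurate across two fnmadd steps,
+// exp(b) is approximated by a small polynomial in b, and the 2^n scaling
+// is applied through the float exponent bits (_mm512_scalef_ps on AVX512).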
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+inline __m512 v_expf(__m512 x) {
+  const __m512 r = _mm512_set1_ps(0x1.8p23f);
+  const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
+  const __m512 n = _mm512_sub_ps(z, r);
+  const __m512 b =
+      _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
+                       _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
+  const __mmask16 d =
+      _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
+  const __m512 u = _mm512_mul_ps(b, b);
+  const __m512 j = _mm512_fmadd_ps(
+      _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
+                                      _mm512_set1_ps(0x1.573e2ep-5f)),
+                      u,
+                      _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
+                                      _mm512_set1_ps(0x1.fffdb6p-2f))),
+      u,
+      _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
+  const __m512 res = _mm512_scalef_ps(j, n);
+  if (_mm512_kortestz(d, d))
+    return res;
+  const __m512 zero = _mm512_setzero_ps();
+  const __m512 alt = _mm512_mask_blend_ps(
+      _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
+  return _mm512_mask_blend_ps(d, res, alt);
+}
+#endif
+#if defined(__AVX2__) && defined(__FMA__)
+inline __m256 v_expf(__m256 x) {
+  const __m256 r = _mm256_set1_ps(0x1.8p23f);
+  const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
+  const __m256 n = _mm256_sub_ps(z, r);
+  const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
+                                    _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
+  const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
+  const __m256 k = _mm256_castsi256_ps(
+      _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
+  const __m256i c = _mm256_castps_si256(
+      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
+                    _mm256_set1_ps(126), _CMP_GT_OQ));
+  const __m256 u = _mm256_mul_ps(b, b);
+  const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
+                                                                   _mm256_set1_ps(0x1.573e2ep-5f)), u,
+                                                   _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
+                                                                   _mm256_set1_ps(0x1.fffdb6p-2f))),
+                                   u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
+  if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
+    return _mm256_fmadd_ps(j, k, k);
+  const __m256i g = _mm256_and_si256(
+      _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
+      _mm256_set1_epi32(0x82000000u));
+  const __m256 s1 =
+      _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
+  const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
+  const __m256i d = _mm256_castps_si256(
+      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
+                    _mm256_set1_ps(192), _CMP_GT_OQ));
+  return _mm256_or_ps(
+      _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
+      _mm256_andnot_ps(
+          _mm256_castsi256_ps(d),
+          _mm256_or_ps(
+              _mm256_and_ps(_mm256_castsi256_ps(c),
+                            _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
+              _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
+}
+#endif
+#ifdef __ARM_NEON
+inline float32x4_t v_expf(float32x4_t x) {
+    const float32x4_t r = vdupq_n_f32(0x1.8p23f);
+    const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
+    const float32x4_t n = vsubq_f32(z, r);
+    const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
+                                    vdupq_n_f32(0x1.7f7d1cp-20f));
+    const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
+    const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
+    const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
+    const float32x4_t u = vmulq_f32(b, b);
+    const float32x4_t j = vfmaq_f32(
+        vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
+        vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
+                  vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
+    if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
+        return vfmaq_f32(k, j, k);
+    const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
+    const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
+    const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
+    return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
+                     vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
+}
+#endif
+#ifdef __AVX2__
+inline float reduce_max(__m256 vmax) {
+    __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(vmax, 1), _mm256_castps256_ps128(vmax));
+    max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
+    max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
+    return _mm_cvtss_f32(max4);
+}
+#endif
+}
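+
+// iqk_soft_max(): fused scale + (ALiBi/mask) bias + softmax over one row of
+// n floats, called from ggml_compute_forward_soft_max_f32(). In scalar form
+// it computes (bias entries are f32 or f16, selected by bias_is_f16):
+//
+//     dp[i]  = scale*sp[i] + slope*bias[i]   // bias term only when bias != NULL
+//     dp[i]  = exp(dp[i] - max_j dp[j])      // subtract the row max for stability
+//     dp[i] /= sum_j dp[j]
+//
+// Returns true when a SIMD path handled the row; on a false return the
+// generic code in ggml.c takes over.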
+#ifdef __ARM_NEON
+bool iqk_soft_max(int n, const float * sp, float * dp, float * wp, const char * bias, float scale, float slope, bool bias_is_f16) {
+    (void)wp;
+    const float32x4_t vscale = vdupq_n_f32(scale);
+    const float * xb = sp;
+    float max = -INFINITY;
+    if (bias) {
+        if (!bias_is_f16) {
+            const float * zb = (const float *)bias;
+            const float32x4_t vslope = vdupq_n_f32(slope);
+            float32x4_t vmax = vdupq_n_f32(-INFINITY);
+            int i = 0;
+            for (; i + 3 < n; i += 4) {
+                float32x4_t vx = vld1q_f32(xb + i);
+                float32x4_t vz = vld1q_f32(zb + i);
+                float32x4_t vy = vfmaq_f32(vmulq_f32(vslope, vz), vscale, vx);
+                vst1q_f32(dp + i, vy);
+                vmax = vmaxq_f32(vmax, vy);
+            }
+            max = vmaxvq_f32(vmax);
+            for (; i < n; ++i) {
+                dp[i] = scale*xb[i] + slope*zb[i];
+                max = MAX(max, dp[i]);
+            }
+        } else {
+            const __fp16 * zb = (const __fp16 *)bias;
+            float32x4_t vslope = vdupq_n_f32(slope);
+            float32x4_t vmax = vdupq_n_f32(-INFINITY);
+            int i = 0;
+            for (; i + 3 < n; i += 4) {
+                float32x4_t vx = vld1q_f32(xb + i);
+                float32x4_t vz = vcvt_f32_f16(vld1_f16(zb + i));
+                float32x4_t vy = vfmaq_f32(vmulq_f32(vslope, vz), vscale, vx);
+                vst1q_f32(dp + i, vy);
+                vmax = vmaxq_f32(vmax, vy);
+            }
+            max = vmaxvq_f32(vmax);
+            for (; i < n; ++i) {
+                dp[i] = scale*xb[i] + slope*GGML_FP16_TO_FP32(zb[i]);
+                max = MAX(max, dp[i]);
+            }
+        }
+    } else {
+        int i = 0;
+        float32x4_t vmax = vdupq_n_f32(-INFINITY);
+        for (; i + 3 < n; i += 4) {
+            float32x4_t vx = vld1q_f32(xb + i);
+            float32x4_t vy = vmulq_f32(vscale, vx);
+            vst1q_f32(dp + i, vy);
+            vmax = vmaxq_f32(vmax, vy);
+        }
+        max = vmaxvq_f32(vmax);
+        for (; i < n; ++i) {
+            dp[i] = scale*xb[i];
+            max = MAX(max, dp[i]);
+        }
+    }
+
+    float32x4_t vsum = vdupq_n_f32(0.f);
+    float32x4_t vmax = vdupq_n_f32(-max);
+    int i = 0;
+    for (; i + 3 < n; i += 4) {
+        float32x4_t v = vld1q_f32(dp + i);
+        v = v_expf(vaddq_f32(v, vmax));
+        vst1q_f32(dp + i, v);
+        vsum = vaddq_f32(vsum, v);
+    }
+    float sum = vaddvq_f32(vsum);
+    for (; i < n; ++i) {
+        float v = expf(dp[i] - max);
+        dp[i] = v;
+        sum += v;
+    }
+    float norm = 1.f/sum;
+    float32x4_t vnorm = vdupq_n_f32(norm);
+    i = 0;
+    for (; i + 3 < n; i += 4) {
+        float32x4_t v = vld1q_f32(dp + i);
+        v = vmulq_f32(v, vnorm);
+        vst1q_f32(dp + i, v);
+    }
+    for (; i < n; ++i) {
+        dp[i] *= norm;
+    }
+    return true;
+}
+
+#elif defined __AVX2__ && defined __FMA__
+
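+// The x86 variant mirrors the NEON one, 8 floats per step, with two extras:
+// when AVX512 is available the exp pass first runs 16-wide and folds its
+// accumulator into the 8-wide sum, and the horizontal max/sum reductions go
+// through reduce_max() and hsum_float_8().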
+bool iqk_soft_max(int n, const float * sp, float * dp, float * wp, const char * bias, float scale, float slope, bool bias_is_f16) {
+    (void)wp;
+    const __m256 vscale = _mm256_set1_ps(scale);
+    const float * xb = sp;
+    float max = -INFINITY;
+    if (bias) {
+        const __m256 vslope = _mm256_set1_ps(slope);
+        __m256 vmax = _mm256_set1_ps(-INFINITY);
+        if (!bias_is_f16) {
+            const float * zb = (const float *)bias;
+            int i = 0;
+            for (; i + 7 < n; i += 8) {
+                __m256 vx = _mm256_loadu_ps(xb + i);
+                __m256 vz = _mm256_loadu_ps(zb + i);
+                __m256 vy = _mm256_fmadd_ps(vscale, vx, _mm256_mul_ps(vslope, vz));
+                _mm256_storeu_ps(dp + i, vy);
+                vmax = _mm256_max_ps(vmax, vy);
+            }
+            max = reduce_max(vmax);
+            for (; i < n; ++i) {
+                dp[i] = scale*xb[i] + slope*zb[i];
+                max = MAX(max, dp[i]);
+            }
+        } else {
+            const ggml_fp16_t * zb = (const ggml_fp16_t *)bias;
+            int i = 0;
+            for (; i + 7 < n; i += 8) {
+                __m256 vx = _mm256_loadu_ps(xb + i);
+                __m256 vz = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(zb + i)));
+                __m256 vy = _mm256_fmadd_ps(vscale, vx, _mm256_mul_ps(vslope, vz));
+                _mm256_storeu_ps(dp + i, vy);
+                vmax = _mm256_max_ps(vmax, vy);
+            }
+            max = reduce_max(vmax);
+            for (; i < n; ++i) {
+                dp[i] = scale*xb[i] + slope*GGML_FP16_TO_FP32(zb[i]);
+                max = MAX(max, dp[i]);
+            }
+        }
+    } else {
+        int i = 0;
+        __m256 vmax = _mm256_set1_ps(-INFINITY);
+        for (; i + 7 < n; i += 8) {
+            __m256 vx = _mm256_loadu_ps(xb + i);
+            __m256 vy = _mm256_mul_ps(vscale, vx);
+            _mm256_storeu_ps(dp + i, vy);
+            vmax = _mm256_max_ps(vmax, vy);
+        }
+        max = reduce_max(vmax);
+        for (; i < n; ++i) {
+            dp[i] = scale*xb[i];
+            max = MAX(max, dp[i]);
+        }
+    }
+
+    __m256 vsum = _mm256_setzero_ps();
+    __m256 vmax = _mm256_set1_ps(-max);
+    int i = 0;
+#if defined __AVX512F__ && defined __AVX512DQ__
+    __m512 vsum512 = _mm512_setzero_ps();
+    __m512 vmax512 = _mm512_set1_ps(-max);
+    for (; i + 15 < n; i += 16) {
+        auto v = _mm512_loadu_ps(dp + i);
+        v = v_expf(_mm512_add_ps(v, vmax512));
+        _mm512_storeu_ps(dp + i, v);
+        vsum512 = _mm512_add_ps(vsum512, v);
+    }
+    vsum = _mm256_add_ps(_mm512_castps512_ps256(vsum512), _mm512_extractf32x8_ps(vsum512, 1));
+#endif
+    for (; i + 7 < n; i += 8) {
+        __m256 v = _mm256_loadu_ps(dp + i);
+        v = v_expf(_mm256_add_ps(v, vmax));
+        _mm256_storeu_ps(dp + i, v);
+        vsum = _mm256_add_ps(vsum, v);
+    }
+    float sum = hsum_float_8(vsum);
+    for (; i < n; ++i) {
+        float v = expf(dp[i] - max);
+        dp[i] = v;
+        sum += v;
+    }
+    float norm = 1.f/sum;
+    __m256 vnorm = _mm256_set1_ps(norm);
+    i = 0;
+    for (; i + 7 < n; i += 8) {
+        __m256 v = _mm256_loadu_ps(dp + i);
+        v = _mm256_mul_ps(v, vnorm);
+        _mm256_storeu_ps(dp + i, v);
+    }
+    for (; i < n; ++i) {
+        dp[i] *= norm;
+    }
+    return true;
+}
+
+#else
+
+bool iqk_soft_max(int, const float *, float *, float *, const char *, float, float, bool) {
+    return false;
+}
+
+#endif
diff --git a/iqk_mul_mat.h b/iqk_mul_mat.h
index cb28808d..8daa28cc 100644
--- a/iqk_mul_mat.h
+++ b/iqk_mul_mat.h
@@ -22,6 +22,7 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeB, const void * B, long strideB,
                      float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
+bool iqk_soft_max(int nc, const float * sp, float * dp, float * wp, const char * bias, float scale, float slope, bool bias_is_f16);

 #ifdef __cplusplus
 }
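
For a quick sanity check of the row-level math, the scalar sketch below reproduces what iqk_soft_max() computes. It is illustrative only and not part of the patch; soft_max_row_ref is a hypothetical helper written for this note, and the harness fakes a single 8-entry row with no mask.

#include <math.h>
#include <stdio.h>

/* Hypothetical scalar model of one iqk_soft_max() row (for exposition only):
 * y = softmax(scale*x + slope*mask). */
static void soft_max_row_ref(int n, const float * x, float * y,
                             const float * mask, float scale, float slope) {
    float max = -INFINITY;
    for (int i = 0; i < n; ++i) {
        y[i] = scale*x[i] + (mask ? slope*mask[i] : 0.0f);
        if (y[i] > max) max = y[i];
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        y[i] = expf(y[i] - max);   /* subtract the row max for stability */
        sum += y[i];
    }
    float norm = 1.0f/sum;
    for (int i = 0; i < n; ++i) y[i] *= norm;   /* probabilities sum to 1 */
}

int main(void) {
    float x[8] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f};
    float y[8];
    soft_max_row_ref(8, x, y, NULL, 0.125f, 0.0f);
    float s = 0.f;
    for (int i = 0; i < 8; ++i) { printf("%.6f ", y[i]); s += y[i]; }
    printf("\nsum = %.6f\n", s);
    return 0;
}

Built with any C99 compiler (link with -lm), the printed probabilities sum to 1 and should agree with the SIMD paths to within the v_expf error bound quoted above.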