From 9790b502e61a4e9110bf2bee8bc7b7a13c0ec064 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 17 Sep 2024 08:35:28 +0300 Subject: [PATCH] Playing with horizontal sums --- ggml/src/iqk/iqk_mul_mat.cpp | 104 +++++++++++++++++++++++++++++------ 1 file changed, 87 insertions(+), 17 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 7543d895..5d02773a 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -260,6 +260,19 @@ inline float hmax_float_8(__m256 x) { max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4)); return _mm_cvtss_f32(max4); } +IQK_ALWAYS_INLINE __m256 hsum_float_8x8(__m256 * accm) { + for (int i = 0; i < 4; ++i) { + accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)), + _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1))); + } + for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2])); + return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1])); +} +IQK_ALWAYS_INLINE void store_8(int ix, __m256 * accm, const DataInfo& info) { + union { __m256 vec; float val[8]; } h; + h.vec = hsum_float_8x8(accm); + for (int iy = 0; iy < 8; ++iy) info.store(ix, iy, h.val[iy]); +} #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) @@ -1128,9 +1141,17 @@ static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const Da } - for (int iy = 0; iy < nrc_y; ++iy) { - auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1)); - info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256))); + if constexpr (nrc_y == 8) { + for (int iy = 0; iy < nrc_y; ++iy) { + accm[iy] = _mm256_add_ps(accm[iy], _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1))); + } + store_8(ix, accm, info); + } + else { + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1)); + info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256))); + } } } @@ -1230,9 +1251,18 @@ static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const D } - for (int iy = 0; iy < nrc_y; ++iy) { - auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1)); - info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256))); + if constexpr (nrc_y == 8) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1)); + accm[iy] = _mm256_add_ps(accm[iy], sum256); + } + store_8(ix, accm, info); + } + else { + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1)); + info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256))); + } } } @@ -1833,8 +1863,12 @@ IQK_NOINLINE void mul_mat_iq2tn_q8_K(int n, const void * vx, size_t bx, const Da } - for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, hsum_float_8(accd[iy])); + if constexpr (nrc_y == 8) { + store_8(ix, accd, info); + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } } } @@ -1877,10 +1911,13 @@ static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf } - for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, hsum_float_8(accd[iy])); + if constexpr (nrc_y == 8) { + store_8(ix, accd, info); + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } } - } } @@ -1926,8 +1963,12 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf } - for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, hsum_float_8(accd[iy])); + if constexpr (nrc_y == 8) { + store_8(ix, accd, info); + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } } } @@ -2094,8 +2135,12 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data } } - for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, hsum_float_8(accd[iy])); + if constexpr (nrc_y == 8) { + store_8(ix, accd, info); + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } } } } @@ -2999,10 +3044,17 @@ struct ScaleHelperQ_1 { } }; -struct MinusType0 { +template struct MinusType0 { inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); } inline float compute(float d, int) const { return d; } inline float result(__m256 acc, int) const { return hsum_float_8(acc); } + //inline void store(int ix, __m256 * acc, const DataInfo& info) { + // if constexpr (nrc_y == 8) { + // store_8(ix, acc, info); + // } else { + // for (int iy = 0; iy < nrc_y; ++iy) info.store(ix, iy, hsum_float_8(acc[iy])); + // } + //} }; template struct MinusType1 { @@ -3022,6 +3074,23 @@ template struct MinusType1 { const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1)); return hsum_float_4(_mm_add_ps(sum, accm[iy])); } + //inline void store(int ix, const __m256 * acc, const DataInfo& info) { + // for (int iy = 0; iy < nrc_y; ++iy) { + // accm[iy] = _mm_add_ps(accm[iy], _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1))); + // } + // if constexpr (nrc_y >= 4) { + // union { __m128 vec; float val[4]; } h; + // for (int i = 0; i < nrc_y/4; ++i) { + // accm[4*i+0] = _mm_add_ps(_mm_unpacklo_ps(accm[4*i+0], accm[4*i+2]), _mm_unpackhi_ps(accm[4*i+0], accm[4*i+2])); + // accm[4*i+1] = _mm_add_ps(_mm_unpacklo_ps(accm[4*i+1], accm[4*i+3]), _mm_unpackhi_ps(accm[4*i+1], accm[4*i+3])); + // h.vec = _mm_add_ps(_mm_unpacklo_ps(accm[4*i+0], accm[4*i+1]), _mm_unpackhi_ps(accm[4*i+0], accm[4*i+1])); + // for (int j = 0; j < 4; ++j) info.store(ix, 4*i+j, h.val[j]); + // } + // for (int iy = 4*(nrc_y/4); iy < nrc_y; ++iy) info.store(ix, iy, hsum_float_4(accm[iy])); + // } else { + // for (int iy = 0; iy < nrc_y; ++iy) info.store(ix, iy, hsum_float_4(accm[iy])); + // } + //} }; template struct AccumT { @@ -3054,6 +3123,7 @@ template struct AccumT { } } } + //accm.store(ix, acc, info); for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, accm.result(acc[iy], iy)); } @@ -3061,7 +3131,7 @@ template struct AccumT { }; template -using AccumType0 = AccumT; +using AccumType0 = AccumT, nrc_y, is_multiple_of_4>; template using AccumType1 = AccumT, nrc_y, is_multiple_of_4>;