From c247793798d2487faeed5eb1162562e98c353bc5 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 6 Dec 2024 10:26:02 +0200 Subject: [PATCH] iq2_bn_r4: better AVX2 As we don't have enough vector registers on AVX2, it is better to do two passes per row needing only half of the accumulator registers that way. With this, we now beat iq2_bn PP also on AVX2 by a small margin. --- ggml/src/iqk/iqk_mul_mat.cpp | 60 ++++++++++++++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 3 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index c8721a8b..b6ff7ab7 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -2126,8 +2126,63 @@ static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, con auto m3 = _mm256_set1_epi8(0x3); auto m1 = _mm256_set1_epi16(1); int nb = n / QK_IQ1BN; - __m256i acc[2*nrc_y] = {}; __m256i qx[4]; + if constexpr (nrc_y > 4) { + __m256i acc[nrc_y] = {}; + __m128 sum4[nrc_y]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const float * dptr = (const float *)((const char *)vx + ix*bx); + auto dl = _mm_loadu_ps(dptr); + const uint8_t * iq2l = (const uint8_t *)(dptr + 4); + for (int ib = 0; ib < nb; ++ib) { + auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0); + qx[0] = _mm256_and_si256(bits, m3); + qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3); + qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3); + qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = q8.load_quants(iy, 2*ib+0); + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2))); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto dy = q8.scale(iy); + auto sumf1 = _mm256_cvtepi32_ps(acc[iy]); + auto s4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00))); + s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), s4); + sum4[iy] = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), s4); + acc[iy] = _mm256_setzero_si256(); + } + for (int ib = 0; ib < nb; ++ib) { + auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1); + qx[0] = _mm256_and_si256(bits, m3); + qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3); + qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3); + qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = q8.load_quants(iy, 2*ib+1); + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2))); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto dy = q8.scale(iy); + auto sumf1 = _mm256_cvtepi32_ps(acc[iy]); + auto s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4[iy]); + s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), s4); + info.store(ix, iy, s4); + acc[iy] = _mm256_setzero_si256(); + } + } + } else { + __m256i acc[2*nrc_y] = {}; for (int ix = 0; ix < nrc_x; ix += 4) { const float * dptr = (const float *)((const char *)vx + ix*bx); auto dl = _mm_loadu_ps(dptr); @@ -2164,8 +2219,6 @@ static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, con auto dy = q8.scale(iy); auto sumf1 = _mm256_cvtepi32_ps(acc[2*iy+0]); auto sumf2 = _mm256_cvtepi32_ps(acc[2*iy+1]); - //auto sumf1 = _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, acc[2*iy+0])); - //auto sumf2 = _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, acc[2*iy+1])); auto sum4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00))); sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4); sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4); @@ -2175,6 +2228,7 @@ static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, con acc[2*iy+0] = acc[2*iy+1] = _mm256_setzero_si256(); } } + } } #ifdef HAVE_FANCY_SIMD