From 56a6ee26bbe1b6d4dd38e367f54ced5a433b922f Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 5 Feb 2025 10:23:34 +0200 Subject: [PATCH] iq1_s_r4: slightly faster AVX2/Zen4 gemm/gemv --- ggml/src/iqk/iqk_mul_mat.cpp | 39 ++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 828d9270..8ea613c4 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -3273,17 +3273,22 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI auto m1 = _mm256_set1_epi16(1); auto ms = _mm_set1_epi16(-32768); float d8[8*nrc_y]; + union { __m256i vec; uint16_t val[16]; } helper; + struct aux_iq1_s_r4 { + uint8_t qs[16]; + uint64_t qh; + }; for (int ix= 0; ix < nrc_x; ix += 4) { auto dptr = (const ggml_half *)((const char *)vx + ix*bx); auto d1 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)dptr)); - auto x = (const block_iq1_s_r4 *)(dptr + 4); + auto x = (const aux_iq1_s_r4 *)(dptr + 4); for (int ib = 0; ib < nb/4; ++ib) { for (int iy = 0; iy < nrc_y; ++iy) { _mm256_storeu_ps(d8 + 8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib].d))); } for (int k = 0; k < 4; ++k) { - const uint64_t * s64 = (const uint64_t *)x[4*ib+k].qh; - auto sas = _mm_set1_epi64x(s64[0]); + auto idxh = _mm256_set1_epi64x(x[4*ib+k].qh); + auto sas = _mm256_castsi256_si128(idxh); auto scales4 = _mm_and_si128(_mm_srli_epi16(sas, 12), _mm_set1_epi16(7)); scales4 = _mm_or_si128(_mm_slli_epi16(scales4, 1), _mm_set1_epi16(1)); auto signs = _mm_or_si128(_mm_cmpeq_epi16(_mm_and_si128(sas, ms), ms), _mm256_castsi256_si128(m1)); @@ -3293,22 +3298,18 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI auto delta = _mm256_set_m128(delta4, delta4); scales4 = _mm_unpacklo_epi16(scales4, scales4); // 0,0, 1,1, 2,2, 3,3 auto scales = MM256_SET_M128I(scales4, scales4); - qx[0] = _mm256_set_epi64x(iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)]); - qx[1] = _mm256_set_epi64x(iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)]); - qx[2] = _mm256_set_epi64x(iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)]); - qx[3] = _mm256_set_epi64x(iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)]); + auto idxl = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)x[4*ib+k].qs)); + idxh = _mm256_sllv_epi64(idxh, _mm256_set_epi64x(0, 2, 5, 8)); + idxh = _mm256_srlv_epi64(idxh, _mm256_set_epi64x(1, 0, 0, 0)); + helper.vec = _mm256_or_si256(idxl, _mm256_and_si256(_mm256_set1_epi16(0x0700), idxh)); + qx[0] = _mm256_set_epi64x(iq1s_grid_us[helper.val[ 9]], iq1s_grid_us[helper.val[ 8]], + iq1s_grid_us[helper.val[ 1]], iq1s_grid_us[helper.val[ 0]]); + qx[1] = _mm256_set_epi64x(iq1s_grid_us[helper.val[13]], iq1s_grid_us[helper.val[12]], + iq1s_grid_us[helper.val[ 5]], iq1s_grid_us[helper.val[ 4]]); + qx[2] = _mm256_set_epi64x(iq1s_grid_us[helper.val[11]], iq1s_grid_us[helper.val[10]], + iq1s_grid_us[helper.val[ 3]], iq1s_grid_us[helper.val[ 2]]); + qx[3] = _mm256_set_epi64x(iq1s_grid_us[helper.val[15]], iq1s_grid_us[helper.val[14]], + iq1s_grid_us[helper.val[ 7]], iq1s_grid_us[helper.val[ 6]]); for (int iy = 0; iy < nrc_y; ++iy) { auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ib].qs + k); #ifdef HAVE_FANCY_SIMD