diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 609c24a8..030e1da8 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -3309,15 +3309,24 @@ static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data
     auto smask = _mm256_set1_epi64x(0x8040201008040201);
     auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
     auto m4 = _mm256_set1_epi8(4);
-    auto m1 = _mm256_set1_epi16(1);
 #endif
+    __m256 acc[nrc_y] = {};
+#ifdef HAVE_FANCY_SIMD
     __m256i shuffles[2] = {
         _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100),
         _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
     };
-    auto s_shuffle = _mm_set_epi64x(0x0f0d0b0907050301, 0x0e0c0a0806040200);
-    __m256 acc[nrc_y] = {};
     __m256i isum[2*nrc_y] = {};
+#else
+    __m256i shuffles[4] = {
+        MM256_SET_M128I(_mm_set1_epi16(0x0302), _mm_set1_epi16(0x0100)),
+        MM256_SET_M128I(_mm_set1_epi16(0x0706), _mm_set1_epi16(0x0504)),
+        MM256_SET_M128I(_mm_set1_epi16(0x0b0a), _mm_set1_epi16(0x0908)),
+        MM256_SET_M128I(_mm_set1_epi16(0x0f0e), _mm_set1_epi16(0x0d0c)),
+    };
+    __m256i isum[nrc_y] = {};
+#endif
+    auto s_shuffle = _mm_set_epi64x(0x0f0d0b0907050301, 0x0e0c0a0806040200);
     __m256i qx[4];
     union { __m256i vec; uint16_t val[16]; } helper;
     for (int ix = 0; ix < nrc_x; ix += 4) {
@@ -3335,15 +3344,15 @@ static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data
                 qx[3] = _mm256_set_epi64x(iq2xs_grid[helper.val[15]], iq2xs_grid[helper.val[14]], iq2xs_grid[helper.val[13]], iq2xs_grid[helper.val[12]]);
                 auto signs16 = _mm256_srli_epi16(val, 9);
                 signs16 = _mm256_xor_si256(signs16, _mm256_slli_epi16(signs16, 1));
-                auto signs = _mm_or_si128(_mm256_castsi256_si128(signs16), _mm_slli_epi16(_mm256_extracti128_si256(signs16, 1), 8));
-                signs = _mm_shuffle_epi8(signs, s_shuffle);
+                auto signs128 = _mm_or_si128(_mm256_castsi256_si128(signs16), _mm_slli_epi16(_mm256_extracti128_si256(signs16, 1), 8));
+                signs128 = _mm_shuffle_epi8(signs128, s_shuffle);
                 auto scales = _mm_set1_epi32(s32[ib]);
                 scales = _mm_and_si128(_mm_unpacklo_epi8(scales, _mm_srli_epi16(scales, 4)), _mm_set1_epi8(0xf));
                 scales = _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi8(1));
                 auto scales16 = _mm256_cvtepi8_epi16(scales); // 0...7, 0...7
-                __m256i scs[2] = { _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]) };
 #ifdef HAVE_FANCY_SIMD
-                auto mask = (const __mmask32 *)&signs;
+                __m256i scs[2] = { _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]) };
+                auto mask = (const __mmask32 *)&signs128;
                 for (int iy = 0; iy < nrc_y; ++iy) {
                     auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
                     auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_mask_sub_epi8(y, mask[0], _mm256_setzero_si256(), y)); // blocks: 0,0,0,0, 1,1,1,1, row 0
@@ -3369,31 +3378,37 @@ static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data
                 auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
                 shuffle = _mm256_add_epi8(shuffle, m4);
                 auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+                __m256i scs[4] = {
+                    _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]),
+                    _mm256_shuffle_epi8(scales16, shuffles[2]), _mm256_shuffle_epi8(scales16, shuffles[3]),
+                };
                 for (int iy = 0; iy < nrc_y; ++iy) {
                     auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
-                    auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1)));
-                    auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2)));
-                    auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3)));
-                    auto sumi4 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4)));
+                    auto sumi1 = _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1))); // blocks 4x0, 4x1, row 0
+                    auto sumi2 = _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2))); // blocks 4x2, 4x3, row 1
+                    auto sumi3 = _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3))); // blocks 4x4, 4x5, row 2
+                    auto sumi4 = _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4))); // blocks 4x6, 4x7, row 3
                     auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
                     auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
                     auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
-                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales32, sumi));
+                    isum[iy] = _mm256_add_epi32(isum[iy], sumi);
                 }
 #endif
             }
             for (int iy = 0; iy < nrc_y; ++iy) {
+#ifdef HAVE_FANCY_SIMD
                 auto sumi = _mm256_hadd_epi32(isum[2*iy+0], isum[2*iy+1]);
                 acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
                 isum[2*iy+0] = isum[2*iy+1] = _mm256_setzero_si256();
-                //acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
-                //isum[iy] = _mm256_setzero_si256();
+#else
+                acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+                isum[iy] = _mm256_setzero_si256();
+#endif
             }
         }
         for (int iy = 0; iy < nrc_y; ++iy) {
             auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
             info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), sum));
-            //info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), _mm256_castps256_ps128(acc[iy])));
             acc[iy] = _mm256_setzero_ps();
         }
     }
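
Not part of the patch: the essence of the AVX2-path change is that the per-block scales, broadcast into 16-bit lanes through the new shuffles[4]/scs[4], are folded directly into _mm256_madd_epi16. That removes both the constant m1 = 1 multiplier and the trailing _mm256_mullo_epi32(scales32, sumi) from the inner loop, and lets the AVX2 path accumulate into nrc_y registers instead of 2*nrc_y. Below is a minimal standalone sketch of the trick; the names q, y, scale_lo, scale_hi and hsum_i32 are illustrative and do not come from iqk_mul_mat.cpp.

// Minimal sketch of fusing per-block scales into _mm256_madd_epi16.
// Build with: g++ -O2 -mavx2 scale_fuse.cpp
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Horizontal sum of the eight 32-bit lanes of an AVX2 register.
static int hsum_i32(__m256i v) {
    __m128i s = _mm_add_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1));
    s = _mm_add_epi32(s, _mm_shuffle_epi32(s, _MM_SHUFFLE(1, 0, 3, 2)));
    s = _mm_add_epi32(s, _mm_shuffle_epi32(s, _MM_SHUFFLE(2, 3, 0, 1)));
    return _mm_cvtsi128_si32(s);
}

int main() {
    alignas(32) uint8_t q[32];  // unsigned quants, as after a grid lookup
    alignas(32) int8_t  y[32];  // signed q8 activations
    for (int i = 0; i < 32; ++i) { q[i] = (7 * i) & 0x3f; y[i] = (int8_t)(i - 16); }
    const int16_t scale_lo = 3, scale_hi = 5;  // one scale per 16-byte sub-block

    auto qx  = _mm256_load_si256((const __m256i *)q);
    auto yv  = _mm256_load_si256((const __m256i *)y);
    auto p16 = _mm256_maddubs_epi16(qx, yv);   // 16 x int16 pairwise products

    // Removed approach: reduce with a constant 1, then scale the int32 sums.
    auto m1   = _mm256_set1_epi16(1);
    auto s32  = _mm256_set_m128i(_mm_set1_epi32(scale_hi), _mm_set1_epi32(scale_lo));
    auto slow = _mm256_mullo_epi32(s32, _mm256_madd_epi16(m1, p16));

    // Patched approach: broadcast the scales into 16-bit lanes so that
    // _mm256_madd_epi16 multiplies and reduces in one shot; the mullo is gone.
    auto s16  = _mm256_set_m128i(_mm_set1_epi16(scale_hi), _mm_set1_epi16(scale_lo));
    auto fast = _mm256_madd_epi16(s16, p16);

    printf("old: %d  new: %d\n", hsum_i32(slow), hsum_i32(fast));  // same value
    return 0;
}

The two paths agree because madd(s, p) sums s*p[2i] + s*p[2i+1] per 32-bit lane, i.e. exactly s times what madd(1, p) produces, and the products stay comfortably inside int32. In the patched kernel the scales are already small odd values (2*s + 1 with s < 16, per the decode above), so the fusion trades one _mm256_mullo_epi32 per 32-byte block per row for shuffles that are computed once per block outside the iy loop.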