diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index a1476e27..b5e3cba3 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -2242,13 +2242,8 @@ struct DequantizeIQ2BN final : public BaseDequantizer { make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x31), val+2); } IQK_ALWAYS_INLINE void make2(__m256i q2_1, __m256i * val) const { -#if defined __AVX512VNNI__ && defined __AVX512VL__ val[0] = _mm256_and_si256(q2_1, mask2); val[1] = _mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2); -#else - val[0] = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8); - val[1] = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask3), mf_8); -#endif } IQK_ALWAYS_INLINE void prepare2(int i, __m256i * val) const { auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[i].qs); @@ -2286,10 +2281,10 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], val[2], q8.load_quants(0, i, 2)), val[3], q8.load_quants(0, i, 3)); #else - auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])), - _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]))); - auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])), - _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3]))); + auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)), + _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1))); + auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)), + _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3))); acc[0] = _mm256_add_epi32(acc[0], _mm256_madd_epi16(m1_16, dot1)); acc[1] = _mm256_add_epi32(acc[1], _mm256_madd_epi16(m1_16, dot2)); #endif @@ -2308,14 +2303,12 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const val[0], q8.load_quants(iy, i, 0)), val[1], q8.load_quants(iy, i, 1)), val[2], q8.load_quants(iy, i, 2)), val[3], q8.load_quants(iy, i, 3)); #else - auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]); - auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]); - auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]); - auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]); auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16( - _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)), - _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot3), _mm256_maddubs_epi16(deq.m1_8, dot4)))); - accd[iy] = i > 0 ? _mm256_add_epi32(dot, accd[iy]) : dot; + _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)), + _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1))), + _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)), + _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3))))); + accd[iy] = _mm256_add_epi32(dot, accd[iy]); #endif } } @@ -2328,10 +2321,9 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i/2, 0)), val[1], q8.load_quants(iy, i/2, 1)); #else - auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]); - auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]); - dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2))); - accd[iy] = _mm256_add_epi32(dot1, accd[iy]); + auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)), + _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 0)))); + accd[iy] = _mm256_add_epi32(dot, accd[iy]); #endif } } @@ -2339,11 +2331,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const for (int iy = 0; iy < nrc_y; ++iy) { auto vd = q8.scale(iy); auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1)); -#if defined __AVX512VNNI__ && defined __AVX512VL__ auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy)); -#else - auto sumf = _mm_mul_ps(vd, _mm_cvtepi32_ps(sumi)); -#endif info.store(ix, iy, hsum_float_4(sumf)); } }