diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 2a378358..33af4a6e 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -2078,15 +2078,16 @@ template struct Q8_K64 { Q8_K64(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) { const float * dptr = (const float *)info.src1_row(iy); - std::memcpy(d + 4*iy, dptr, 4*sizeof(float)); - y[iy] = (const int8_t *)(dptr + 4); + std::memcpy(d + 8*iy, dptr, 8*sizeof(float)); + y[iy] = (const int8_t *)(dptr + 8); } } inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy] + 4*i + j); } - inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 4*iy); } + inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 8*iy); } + inline __m128 minus(int iy) const { return _mm_loadu_ps(d + 8*iy + 4); } - float d[4*nrc_y]; + float d[8*nrc_y]; const int8_t * y[nrc_y]; }; @@ -2124,8 +2125,8 @@ struct DequantizerIQ1BN { v1 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val1, bmask, val2), m1_8); v2 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val3, bmask, val4), m1_8); #else - v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8); - v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8); + v1 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216); + v2 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216); #endif } @@ -2169,10 +2170,10 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, deq.m1_8, dot1), deq.m1_8, dot2); acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, deq.m1_8, dot3), deq.m1_8, dot4); #else - auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])), - _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]))); - auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])), - _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3]))); + auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)), + _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1))); + auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)), + _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3))); acc1 = _mm256_add_epi32(acc1, _mm256_madd_epi16(m1_16, dot1)); acc2 = _mm256_add_epi32(acc2, _mm256_madd_epi16(m1_16, dot2)); #endif @@ -2197,10 +2198,10 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32( accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4); #else - auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0])), - _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]))); - auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2])), - _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]))); + auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)), + _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1))); + auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)), + _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3))); dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2)); accd[iy] = _mm256_add_epi32(dot1, accd[iy]); #endif @@ -2211,13 +2212,13 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const if (i < nb) { deq.prepare_iq1bn_quants(x + i, val[0], val[1]); for (int iy = 0; iy < nrc_y; ++iy) { +#if defined __AVX512VNNI__ && defined __AVX512VL__ auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]); auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]); -#if defined __AVX512VNNI__ && defined __AVX512VL__ accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], deq.m1_8, dot1), deq.m1_8, dot2); #else - auto dot = _mm256_madd_epi16(m1_16, - _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2))); + auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)), + _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1)))); accd[iy] = _mm256_add_epi32(dot, accd[iy]); #endif } @@ -2226,7 +2227,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const for (int iy = 0; iy < nrc_y; ++iy) { auto vd = q8.scale(iy); auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1)); - auto sumf = _mm_mul_ps(vd, _mm_cvtepi32_ps(sumi)); + auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy)); if constexpr (is_iq1_tn) { info.store(ix, iy, scale*hsum_float_4(sumf)); } else { diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 7ca1759d..9b39a490 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -437,6 +437,9 @@ void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) { vid[i] = _mm_set1_ps(id); } __m128i q[4]; + __m128i sums = _mm_setzero_si128(); + __m128i m1_8 = _mm_set1_epi8(1); + __m128i m1_16 = _mm_set1_epi16(1); for (int j = 0; j < k; j += 16) { for (int i = 0; i < 4; ++i) { auto val = _mm_loadu_ps(x + j + 4*i); @@ -446,9 +449,13 @@ void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) { auto q1 = _mm_packs_epi32(q[0], q[1]); auto q2 = _mm_packs_epi32(q[2], q[3]); auto qi = _mm_packs_epi16(q1, q2); + auto aux = _mm_maddubs_epi16(m1_8, qi); + sums = _mm_add_epi32(sums, _mm_madd_epi16(m1_16, aux)); _mm_storeu_si128((__m128i *)qs, qi); qs += 16; } + auto minus = _mm_mul_ps(_mm_loadu_ps(dptr), _mm_cvtepi32_ps(sums)); + _mm_storeu_ps(dptr + 4, minus); #else float aux[4] = {0.f, 0.f, 0.f, 0.f}; for (int j = 0; j < k; j += 16) {