diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index a0a08fa4..3764ddc6 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6245,47 +6245,41 @@ template static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%8 == 0); GGML_ASSERT(n%32 == 0); -#ifndef HAVE_FANCY_SIMD + __m256i qx[2]; + __m256i acc[nrc_y] = {}; + float dy[nrc_y]; +#ifdef HAVE_FANCY_SIMD + int32_t sy[nrc_y]; +#else + __m256i sx[2]; auto m1 = _mm256_set1_epi16(1); #endif - __m256i qx[4]; - __m256i sx[4]; - __m256i acc[nrc_y] = {}; - float dy[nrc_y]; const int8_t * q8y[nrc_y]; for (int iy = 0; iy < nrc_y; ++iy) { auto dptr = (const float *)info.src1_row(iy); dy[iy] = dptr[0]; +#ifdef HAVE_FANCY_SIMD + auto iptr = (const int32_t *)(dptr+1); + sy[iy] = -127*iptr[0]; +#endif q8y[iy] = (const int8_t *)(dptr + 2); } for (int ix = 0; ix < nrc_x; ++ix) { auto dx = (const float *)((const char *)vx + ix*bx); auto q8x = (const int8_t *)(dx + 2); - for (int i = 0; i < n/128; ++i) { - for (int j = 0; j < 4; ++j) { - qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 4*i + j); - sx[j] = _mm256_sign_epi8(qx[j], qx[j]); - } - for (int iy = 0; iy < nrc_y; ++iy) { - for (int j = 0; j < 4; ++j) { -#ifdef HAVE_FANCY_SIMD - acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 4*i + j), qx[j])); -#else - auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 4*i + j), qx[j])); - acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot)); -#endif - } - } - } - for (int i = 2*(n/128); i < n/64; ++i) { + for (int i = 0; i < n/64; ++i) { for (int j = 0; j < 2; ++j) { +#ifdef HAVE_FANCY_SIMD + qx[j] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + 2*i + j), _mm256_set1_epi8(127)); +#else qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j); sx[j] = _mm256_sign_epi8(qx[j], qx[j]); +#endif } for (int iy = 0; iy < nrc_y; ++iy) { for (int j = 0; j < 2; ++j) { #ifdef HAVE_FANCY_SIMD - acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j])); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j)); #else auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j])); acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot)); @@ -6294,11 +6288,15 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI } } if (int i = 2*(n/64); i < n/32) { +#ifdef HAVE_FANCY_SIMD + qx[0] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + i), _mm256_set1_epi8(127)); +#else qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i); sx[0] = _mm256_sign_epi8(qx[0], qx[0]); +#endif for (int iy = 0; iy < nrc_y; ++iy) { #ifdef HAVE_FANCY_SIMD - acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0])); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_loadu_si256((const __m256i *)q8y[iy] + i)); #else auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0])); acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot)); @@ -6307,7 +6305,11 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI } for (int iy = 0; iy < nrc_y; ++iy) { auto sumi = hsum_i32_8(acc[iy]); - info.store(ix, iy, dx[0]*dy[2*iy+0]*sumi); +#ifdef HAVE_FANCY_SIMD + info.store(ix, iy, dx[0]*dy[iy]*(sumi+sy[iy])); +#else + info.store(ix, iy, dx[0]*dy[iy]*sumi); +#endif acc[iy] = _mm256_setzero_si256(); } }