diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 3764ddc6..e6a6f5ec 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6246,7 +6246,7 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI GGML_ASSERT(nrc_x%8 == 0); GGML_ASSERT(n%32 == 0); __m256i qx[2]; - __m256i acc[nrc_y] = {}; + __m256i acc[2*nrc_y] = {}; float dy[nrc_y]; #ifdef HAVE_FANCY_SIMD int32_t sy[nrc_y]; @@ -6279,10 +6279,10 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI for (int iy = 0; iy < nrc_y; ++iy) { for (int j = 0; j < 2; ++j) { #ifdef HAVE_FANCY_SIMD - acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j)); + acc[2*iy+j] = _mm256_dpbusd_epi32(acc[2*iy+j], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j)); #else auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j])); - acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot)); + acc[2*iy+j] = _mm256_add_epi32(acc[2*iy+j], _mm256_madd_epi16(m1, dot)); #endif } } @@ -6296,21 +6296,21 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI #endif for (int iy = 0; iy < nrc_y; ++iy) { #ifdef HAVE_FANCY_SIMD - acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_loadu_si256((const __m256i *)q8y[iy] + i)); + acc[2*iy] = _mm256_dpbusd_epi32(acc[2*iy], qx[0], _mm256_loadu_si256((const __m256i *)q8y[iy] + i)); #else auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0])); - acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot)); + acc[2*iy] = _mm256_add_epi32(acc[2*iy], _mm256_madd_epi16(m1, dot)); #endif } } for (int iy = 0; iy < nrc_y; ++iy) { - auto sumi = hsum_i32_8(acc[iy]); + auto sumi = hsum_i32_8(_mm256_add_epi32(acc[2*iy], acc[2*iy+1])); #ifdef HAVE_FANCY_SIMD info.store(ix, iy, dx[0]*dy[iy]*(sumi+sy[iy])); #else info.store(ix, iy, dx[0]*dy[iy]*sumi); #endif - acc[iy] = _mm256_setzero_si256(); + acc[2*iy] = acc[2*iy+1] = _mm256_setzero_si256(); } } }