From 4dc97b187b36e4bb06ba4c2bf01db90a3d9f2738 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 8 Dec 2024 16:23:41 +0200 Subject: [PATCH] Fix broken matrix x vector product on Zen4 --- ggml/src/iqk/iqk_mul_mat.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 7e7deb52..a9adedfc 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -2922,10 +2922,11 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const auto m4 = _mm256_set1_epi8(0xf); #ifndef HAVE_FANCY_SIMD auto m1 = _mm256_set1_epi16(1); -#endif auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values); auto values = MM256_SET_M128I(values128, values128); - //auto values = load_iq4nl_values_256(); +#else + auto values = load_iq4nl_values_256(); +#endif int nbl = n / QK_K; using helper_t = union { __m256i vec; uint32_t val[8]; }; helper_t h; @@ -2969,7 +2970,6 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); float d8 = q8.scale(iy, ibl); - //float m8 = d8 * (q8.y[iy][ibl].bsums[2*ib+0] + q8.y[iy][ibl].bsums[2*ib+1]); float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib]; acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]); acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]); @@ -3049,7 +3049,6 @@ static void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); float d8 = q8.scale(iy, ibl); - //float m8 = d8 * (q8.y[iy][ibl].bsums[2*ib+0] + q8.y[iy][ibl].bsums[2*ib+1]); float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib]; acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, _mm512_set1_ps(d8)), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[2*iy+1]);