Fix broken matrix x vector product on Zen4

This commit is contained in:
Iwan Kawrakow
2024-12-08 16:23:41 +02:00
parent 5de1cf4885
commit 4dc97b187b

View File

@@ -2922,10 +2922,11 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const
auto m4 = _mm256_set1_epi8(0xf);
#ifndef HAVE_FANCY_SIMD
auto m1 = _mm256_set1_epi16(1);
#endif
auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
auto values = MM256_SET_M128I(values128, values128);
//auto values = load_iq4nl_values_256();
#else
auto values = load_iq4nl_values_256();
#endif
int nbl = n / QK_K;
using helper_t = union { __m256i vec; uint32_t val[8]; };
helper_t h;
@@ -2969,7 +2970,6 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
float d8 = q8.scale(iy, ibl);
//float m8 = d8 * (q8.y[iy][ibl].bsums[2*ib+0] + q8.y[iy][ibl].bsums[2*ib+1]);
float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]);
acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
@@ -3049,7 +3049,6 @@ static void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
float d8 = q8.scale(iy, ibl);
//float m8 = d8 * (q8.y[iy][ibl].bsums[2*ib+0] + q8.y[iy][ibl].bsums[2*ib+1]);
float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, _mm512_set1_ps(d8)), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[2*iy+1]);