iq1_tn: improve Zen4

PP-512 goes to 485 t/s, up from 352 t/s. With FA we get 545 t/s, up from 380 t/s.
TG-128 @ 1 thread goes to 12.4 t/s, up from 10.4 t/s.
However, there still seems to be a bottleneck somewhere, as
TG performance saturates at 8 threads.
This commit is contained in:
Iwan Kawrakow
2024-09-09 09:02:33 +03:00
parent 45db1385ef
commit 41c8200d08

View File

@@ -2122,8 +2122,8 @@ struct DequantizerIQ1BN {
auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[2]), mult[2]), m3);
auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[3]), mult[3]), m3);
#ifdef HAVE_FANCY_SIMD
v1 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val1, bmask, val2), m1_8);
v2 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val3, bmask, val4), m1_8);
v1 = _mm256_permutex2var_epi8(val1, bmask, val2);
v2 = _mm256_permutex2var_epi8(val3, bmask, val4);
#else
v1 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216);
v2 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216);
@@ -2163,12 +2163,8 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
#if defined __AVX512VNNI__ && defined __AVX512VL__
auto dot1 = _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0]);
auto dot2 = _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]);
auto dot3 = _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2]);
auto dot4 = _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3]);
acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, deq.m1_8, dot1), deq.m1_8, dot2);
acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, deq.m1_8, dot3), deq.m1_8, dot4);
acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, val[0], q8.load_quants(0, i, 0)), val[1], q8.load_quants(0, i, 1));
acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, val[2], q8.load_quants(0, i, 2)), val[3], q8.load_quants(0, i, 3));
#else
auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)),
_mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1)));
@@ -2191,12 +2187,11 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
for (int iy = 0; iy < nrc_y; ++iy) {
#if defined __AVX512VNNI__ && defined __AVX512VL__
auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]);
auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]);
auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]);
auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]);
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
val[0], q8.load_quants(iy, i, 0)),
val[1], q8.load_quants(iy, i, 1)),
val[2], q8.load_quants(iy, i, 2)),
val[3], q8.load_quants(iy, i, 3));
#else
auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)),
_mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1)));
@@ -2213,9 +2208,8 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
deq.prepare_iq1bn_quants(x + i, val[0], val[1]);
for (int iy = 0; iy < nrc_y; ++iy) {
#if defined __AVX512VNNI__ && defined __AVX512VL__
auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], deq.m1_8, dot1), deq.m1_8, dot2);
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
val[0], q8.load_quants(iy, i/2, 0)), val[1], q8.load_quants(iy, i/2, 1));
#else
auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)),
_mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1))));