Mirror of https://github.com/ikawrakow/ik_llama.cpp.git
Synced 2026-02-25 15:44:10 +00:00
iq1_tn: improve Zen4
PP-512 goes to 485 t/s, up from 352. With FA we get 545 t/s, up from 380. TG-128 @ 1 thread goes to 12.4 t/s, up from 10.4. However, we seem to have a bottleneck somewhere, as TG saturates at 8 threads.
This commit is contained in:
@@ -2122,8 +2122,8 @@ struct DequantizerIQ1BN {
|
||||
auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[2]), mult[2]), m3);
|
||||
auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[3]), mult[3]), m3);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
v1 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val1, bmask, val2), m1_8);
|
||||
v2 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val3, bmask, val4), m1_8);
|
||||
v1 = _mm256_permutex2var_epi8(val1, bmask, val2);
|
||||
v2 = _mm256_permutex2var_epi8(val3, bmask, val4);
|
||||
#else
|
||||
v1 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216);
|
||||
v2 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216);
|
||||
@@ -2163,12 +2163,8 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
|
||||
deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
|
||||
deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
|
||||
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
||||
auto dot1 = _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0]);
|
||||
auto dot2 = _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]);
|
||||
auto dot3 = _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2]);
|
||||
auto dot4 = _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3]);
|
||||
acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, deq.m1_8, dot1), deq.m1_8, dot2);
|
||||
acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, deq.m1_8, dot3), deq.m1_8, dot4);
|
||||
acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, val[0], q8.load_quants(0, i, 0)), val[1], q8.load_quants(0, i, 1));
|
||||
acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, val[2], q8.load_quants(0, i, 2)), val[3], q8.load_quants(0, i, 3));
|
||||
#else
|
||||
auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)),
|
||||
_mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1)));
|
||||
@@ -2191,12 +2187,11 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
||||
auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]);
|
||||
auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]);
|
||||
auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]);
|
||||
auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]);
|
||||
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
|
||||
accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
|
||||
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
|
||||
val[0], q8.load_quants(iy, i, 0)),
|
||||
val[1], q8.load_quants(iy, i, 1)),
|
||||
val[2], q8.load_quants(iy, i, 2)),
|
||||
val[3], q8.load_quants(iy, i, 3));
|
||||
#else
|
||||
auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)),
|
||||
_mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1)));
|
||||
@@ -2213,9 +2208,8 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
|
||||
deq.prepare_iq1bn_quants(x + i, val[0], val[1]);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
||||
auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
|
||||
auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);
|
||||
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], deq.m1_8, dot1), deq.m1_8, dot2);
|
||||
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
|
||||
val[0], q8.load_quants(iy, i/2, 0)), val[1], q8.load_quants(iy, i/2, 1));
|
||||
#else
|
||||
auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)),
|
||||
_mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1))));
|
||||
|
||||
Reference in New Issue
Block a user