mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-05-01 03:41:53 +00:00
iq2_bn: improve performance on Zen4
We now get PP-512 = 614 t/s, up from 542 t/s.
This commit is contained in:
@@ -2242,8 +2242,13 @@ struct DequantizeIQ2BN final : public BaseDequantizer<block_iq2_bn> {
|
|||||||
make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x31), val+2);
|
make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x31), val+2);
|
||||||
}
|
}
|
||||||
IQK_ALWAYS_INLINE void make2(__m256i q2_1, __m256i * val) const {
|
IQK_ALWAYS_INLINE void make2(__m256i q2_1, __m256i * val) const {
|
||||||
|
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
||||||
|
val[0] = _mm256_and_si256(q2_1, mask2);
|
||||||
|
val[1] = _mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2);
|
||||||
|
#else
|
||||||
val[0] = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
|
val[0] = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
|
||||||
val[1] = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask3), mf_8);
|
val[1] = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask3), mf_8);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
IQK_ALWAYS_INLINE void prepare2(int i, __m256i * val) const {
|
IQK_ALWAYS_INLINE void prepare2(int i, __m256i * val) const {
|
||||||
auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
|
auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
|
||||||
@@ -2276,10 +2281,10 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
|
|||||||
for (int i = 0; i < nb/2; ++i) {
|
for (int i = 0; i < nb/2; ++i) {
|
||||||
deq.prepare4(i, val);
|
deq.prepare4(i, val);
|
||||||
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
||||||
acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
|
acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], val[0], q8.load_quants(0, i, 0)),
|
||||||
deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]));
|
val[1], q8.load_quants(0, i, 1));
|
||||||
acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])),
|
acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], val[2], q8.load_quants(0, i, 2)),
|
||||||
deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3]));
|
val[3], q8.load_quants(0, i, 3));
|
||||||
#else
|
#else
|
||||||
auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
|
auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
|
||||||
_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1])));
|
_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1])));
|
||||||
@@ -2298,14 +2303,15 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
|
|||||||
for (int i = 0; i < nb/2; ++i) {
|
for (int i = 0; i < nb/2; ++i) {
|
||||||
deq.prepare4(i, val);
|
deq.prepare4(i, val);
|
||||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||||
|
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
||||||
|
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
|
||||||
|
val[0], q8.load_quants(iy, i, 0)), val[1], q8.load_quants(iy, i, 1)),
|
||||||
|
val[2], q8.load_quants(iy, i, 2)), val[3], q8.load_quants(iy, i, 3));
|
||||||
|
#else
|
||||||
auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]);
|
auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]);
|
||||||
auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]);
|
auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]);
|
||||||
auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]);
|
auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]);
|
||||||
auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]);
|
auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]);
|
||||||
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
|
||||||
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
|
|
||||||
accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
|
|
||||||
#else
|
|
||||||
auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
|
auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
|
||||||
_mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)),
|
_mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)),
|
||||||
_mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot3), _mm256_maddubs_epi16(deq.m1_8, dot4))));
|
_mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot3), _mm256_maddubs_epi16(deq.m1_8, dot4))));
|
||||||
@@ -2318,11 +2324,12 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
|
|||||||
if (i < nb) {
|
if (i < nb) {
|
||||||
deq.prepare2(i, val);
|
deq.prepare2(i, val);
|
||||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||||
|
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
||||||
|
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i/2, 0)),
|
||||||
|
val[1], q8.load_quants(iy, i/2, 1));
|
||||||
|
#else
|
||||||
auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
|
auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
|
||||||
auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);
|
auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);
|
||||||
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
|
||||||
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], deq.m1_8, dot1), deq.m1_8, dot2);
|
|
||||||
#else
|
|
||||||
dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)));
|
dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)));
|
||||||
accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
|
accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
|
||||||
#endif
|
#endif
|
||||||
@@ -2332,7 +2339,11 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
|
|||||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||||
auto vd = q8.scale(iy);
|
auto vd = q8.scale(iy);
|
||||||
auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
|
auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
|
||||||
|
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
||||||
|
auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
|
||||||
|
#else
|
||||||
auto sumf = _mm_mul_ps(vd, _mm_cvtepi32_ps(sumi));
|
auto sumf = _mm_mul_ps(vd, _mm_cvtepi32_ps(sumi));
|
||||||
|
#endif
|
||||||
info.store(ix, iy, hsum_float_4(sumf));
|
info.store(ix, iy, hsum_float_4(sumf));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user