Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-25 15:44:10 +00:00)
iq1_tn: improve AVX2
PP-512 goes to 533 t/s, up from 455 t/s. TG-128 @ 2 threads goes to 16.6 t/s, up from 14.2 t/s. However, we seem to have a bottleneck somewhere, as TG saturates at 8 threads.
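Judging from the diff, the AVX2 (non-VNNI) gain comes from dropping the per-element sign fix-up in the inner loop: the ternary values are kept in their stored form u = x + 1 in {0, 1, 2}, fed straight into the unsigned-by-signed multiply-add _mm256_maddubs_epi16 against the Q8 activations, and the +1 offset is removed once per row via sum(u*q) = sum(x*q) + sum(q). The correction term d*sum(q) is precomputed when the activations are quantized (the Q8_K64 row now carries 8 floats instead of 4, exposed through the new minus() accessor) and folded into the final reduction with _mm_fmsub_ps. Below is a minimal scalar sketch of that identity, not the actual SIMD kernel; Q8Row, quantize_row and dot_ternary are illustrative names, not symbols from the code base.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical container for one quantized activation row: int8 quants, the scale d,
// and the precomputed correction d * sum(q) that the kernel subtracts at the end.
struct Q8Row { std::vector<int8_t> q; float d; float minus; };

static Q8Row quantize_row(const std::vector<float>& x, float d) {
    Q8Row r; r.d = d; r.q.resize(x.size());
    int32_t sum = 0;
    for (size_t i = 0; i < x.size(); ++i) {
        r.q[i] = (int8_t)std::lround(x[i] / d);
        sum += r.q[i];
    }
    r.minus = d * (float)sum;   // corresponds to the new "minus" floats stored after the scales
    return r;
}

// Ternary weights arrive as u = x + 1 in {0,1,2}; no sign fix-up is needed in the inner loop.
static float dot_ternary(const std::vector<uint8_t>& u, const Q8Row& r) {
    int32_t acc = 0;
    for (size_t i = 0; i < u.size(); ++i)
        acc += (int32_t)u[i] * (int32_t)r.q[i];   // unsigned * signed, like maddubs
    return r.d * (float)acc - r.minus;            // like _mm_fmsub_ps(d, acc, minus)
}

int main() {
    std::vector<float>   act = {0.5f, -1.25f, 2.0f, 0.75f};
    std::vector<uint8_t> w   = {2, 0, 1, 2};      // ternary weights {+1, -1, 0, +1} stored as x + 1
    Q8Row r = quantize_row(act, 0.05f);
    std::printf("dot = %g\n", dot_ternary(w, r)); // equals sum_i x_i * d * q_i
    return 0;
}

The diff below implements the same algebra in SIMD: the quantizer stores d*sum(q) next to the scales, and the matmul kernel turns the final _mm_mul_ps into an _mm_fmsub_ps while replacing _mm256_sign_epi8 + maddubs with a direct maddubs on the unpacked values.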
@@ -2078,15 +2078,16 @@ template <int nrc> struct Q8_K64 {
     Q8_K64(const DataInfo& info) {
         for (int iy = 0; iy < nrc_y; ++iy) {
             const float * dptr = (const float *)info.src1_row(iy);
-            std::memcpy(d + 4*iy, dptr, 4*sizeof(float));
-            y[iy] = (const int8_t *)(dptr + 4);
+            std::memcpy(d + 8*iy, dptr, 8*sizeof(float));
+            y[iy] = (const int8_t *)(dptr + 8);
         }
     }

     inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy] + 4*i + j); }
-    inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 4*iy); }
+    inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 8*iy); }
+    inline __m128 minus(int iy) const { return _mm_loadu_ps(d + 8*iy + 4); }

-    float d[4*nrc_y];
+    float d[8*nrc_y];
     const int8_t * y[nrc_y];
 };
@@ -2124,8 +2125,8 @@ struct DequantizerIQ1BN {
         v1 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val1, bmask, val2), m1_8);
         v2 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val3, bmask, val4), m1_8);
 #else
-        v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8);
-        v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8);
+        v1 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216);
+        v2 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216);
 #endif
     }
@@ -2169,10 +2170,10 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
             acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, deq.m1_8, dot1), deq.m1_8, dot2);
             acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, deq.m1_8, dot3), deq.m1_8, dot4);
 #else
-            auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
-                                         _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1])));
-            auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])),
-                                         _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3])));
+            auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)),
+                                         _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1)));
+            auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)),
+                                         _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3)));
             acc1 = _mm256_add_epi32(acc1, _mm256_madd_epi16(m1_16, dot1));
             acc2 = _mm256_add_epi32(acc2, _mm256_madd_epi16(m1_16, dot2));
 #endif
@@ -2197,10 +2198,10 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
                 accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
                                         accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
 #else
-                auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0])),
-                                             _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1])));
-                auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2])),
-                                             _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3])));
+                auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)),
+                                             _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1)));
+                auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)),
+                                             _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3)));
                 dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
                 accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
 #endif
@@ -2211,13 +2212,13 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
     if (i < nb) {
         deq.prepare_iq1bn_quants(x + i, val[0], val[1]);
         for (int iy = 0; iy < nrc_y; ++iy) {
+#if defined __AVX512VNNI__ && defined __AVX512VL__
             auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
             auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);
-#if defined __AVX512VNNI__ && defined __AVX512VL__
             accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], deq.m1_8, dot1), deq.m1_8, dot2);
 #else
-            auto dot = _mm256_madd_epi16(m1_16,
-                    _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)));
+            auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)),
+                                                                 _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1))));
             accd[iy] = _mm256_add_epi32(dot, accd[iy]);
 #endif
         }
@@ -2226,7 +2227,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
     for (int iy = 0; iy < nrc_y; ++iy) {
         auto vd = q8.scale(iy);
         auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
-        auto sumf = _mm_mul_ps(vd, _mm_cvtepi32_ps(sumi));
+        auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
         if constexpr (is_iq1_tn) {
             info.store(ix, iy, scale*hsum_float_4(sumf));
         } else {
@@ -437,6 +437,9 @@ void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
         vid[i] = _mm_set1_ps(id);
     }
     __m128i q[4];
+    __m128i sums = _mm_setzero_si128();
+    __m128i m1_8 = _mm_set1_epi8(1);
+    __m128i m1_16 = _mm_set1_epi16(1);
     for (int j = 0; j < k; j += 16) {
         for (int i = 0; i < 4; ++i) {
             auto val = _mm_loadu_ps(x + j + 4*i);
@@ -446,9 +449,13 @@ void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
         auto q1 = _mm_packs_epi32(q[0], q[1]);
         auto q2 = _mm_packs_epi32(q[2], q[3]);
         auto qi = _mm_packs_epi16(q1, q2);
+        auto aux = _mm_maddubs_epi16(m1_8, qi);
+        sums = _mm_add_epi32(sums, _mm_madd_epi16(m1_16, aux));
         _mm_storeu_si128((__m128i *)qs, qi);
         qs += 16;
     }
+    auto minus = _mm_mul_ps(_mm_loadu_ps(dptr), _mm_cvtepi32_ps(sums));
+    _mm_storeu_ps(dptr + 4, minus);
 #else
     float aux[4] = {0.f, 0.f, 0.f, 0.f};
     for (int j = 0; j < k; j += 16) {