Mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-03-06 20:10:08 +00:00
bitnet(scale in a separate tensor): more CPU improvements
It seems it is enough to have 4 scales per row for the Q8 activations. I get PPL = 8.5470 with this, slightly higher than the 8.5430 we get with 1 scale per 128 activations, but still OK, I think. With this, we get the following performance (t/s):

| System     | quant | PP-512        | TG-128       | quant | PP-512        | TG-128       |
|------------|-------|---------------|--------------|-------|---------------|--------------|
| M2 Max     | iq2bn | 229.02 ± 0.37 | 78.75 ± 0.61 | iq1bn | 146.67 ± 2.85 | 33.12 ± 0.03 |
| Ryzen-7950 | iq2bn | 379.36 ± 1.03 | 49.08 ± 0.18 | iq1bn | 247.12 ± 1.53 | 32.80 ± 0.02 |
| Ryzen-5975 | iq2bn | 465.28 ± 0.57 | 39.17 ± 0.02 | iq1bn | 325.86 ± 0.46 | 26.60 ± 0.10 |
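For orientation, the row layout this commit settles on: each Q8 activation row stores 4 float scales followed by the int8 quants, with scale i covering lanes 4*i..4*i+3 of every group of 16 values. A minimal standalone sketch of the quantization step (illustrative names, not the exact repo code; the real version is quantize_row_q8_K64_reference below):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Sketch: quantize one row of k activations (k % 16 == 0) into
// 4 float scales (d) followed by k int8 quants (q).
void quantize_row_4scales(const float * x, float * d, int8_t * q, int k) {
    float amax[4] = {0.f, 0.f, 0.f, 0.f};
    for (int j = 0; j < k; j += 16) {
        for (int i = 0; i < 4; ++i) {
            for (int l = 0; l < 4; ++l) {
                amax[i] = std::max(amax[i], std::fabs(x[j+4*i+l]));
            }
        }
    }
    float id[4];
    for (int i = 0; i < 4; ++i) {
        d[i]  = amax[i]/127;              // stored scale
        id[i] = d[i] > 0 ? 1/d[i] : 0.f;  // inverse scale for rounding
    }
    for (int j = 0; j < k; j += 16) {
        for (int i = 0; i < 4; ++i) {
            for (int l = 0; l < 4; ++l) {
                q[j+4*i+l] = (int8_t)std::lround(id[i]*x[j+4*i+l]);
            }
        }
    }
}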
@@ -355,8 +355,8 @@ void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
 }
 
 void quantize_row_q8_K64_reference(const float * x, block_q8_K64 * y, int64_t k) {
-    assert(k % 64 == 0);
-    const int64_t nb = k / 64;
+    //assert(k % 64 == 0);
+    //const int64_t nb = k / 64;
 
     // Check if a row-wise scale works. It almost does, PPL is only ~0.02 higher
     //float amax = 0;
@@ -374,50 +374,24 @@ void quantize_row_q8_K64_reference(const float * x, block_q8_K64 * y, int64_t k)
     //    x += 64;
     //}
 
-    block_q8_K128 * yp = (block_q8_K128 *)y;
-    for (int i = 0; i < nb/2; i++) {
-        float max = 0;
-        float amax = 0;
-        for (int j = 0; j < 128; ++j) {
-            float ax = fabsf(x[j]);
-            if (ax > amax) {
-                amax = ax; max = x[j];
-            }
-        }
-        if (!amax) {
-            yp[i].d = 0;
-            memset(yp[i].qs, 0, 128);
-            x += 128;
-            continue;
-        }
-        const float iscale = -127.f/max;
-        for (int j = 0; j < 128; ++j) {
-            int v = nearest_int(iscale*x[j]);
-            yp[i].qs[j] = MIN(127, v);
-        }
-        yp[i].d = 1/iscale;
-        x += 128;
-    }
-    int i = 2*(nb/2);
-    if (i < nb) {
-        float max = 0;
-        float amax = 0;
-        for (int j = 0; j < 64; ++j) {
-            float ax = fabsf(x[j]);
-            if (ax > amax) {
-                amax = ax; max = x[j];
-            }
-        }
-        if (!amax) {
-            yp[i/2].d = 0;
-            memset(yp[i/2].qs, 0, 64);
-        } else {
-            const float iscale = -127.f/max;
-            for (int j = 0; j < 64; ++j) {
-                int v = nearest_int(iscale*x[j]);
-                yp[i/2].qs[j] = MIN(127, v);
-            }
-            yp[i/2].d = 1/iscale;
-        }
-    }
+    float aux[4] = {0.f, 0.f, 0.f, 0.f};
+    for (int j = 0; j < k; j += 16) {
+        for (int i = 0; i < 4; ++i) {
+            for (int l = 0; l < 4; ++l) {
+                float ax = fabsf(x[j+4*i+l]);
+                aux[i] = std::max(aux[i], ax);
+            }
+        }
+    }
+    float * dptr = (float *)y;
+    for (int i = 0; i < 4; ++i) {
+        dptr[i] = aux[i]/127;
+        aux[i] = dptr[i] > 0 ? 1/dptr[i] : 0.f;
+    }
+    auto qs = (int8_t *)(dptr + 4);
+    for (int j = 0; j < k; j += 16) {
+        for (int i = 0; i < 4; ++i) {
+            for (int l = 0; l < 4; ++l) qs[j+4*i+l] = nearest_int(aux[i]*x[j+4*i+l]);
+        }
+    }
 }
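At dot-product time the four scales are read back from the head of the row and applied once, after all integer accumulation; this is what the SIMD kernels in iqk_mul_mat.cpp below implement. A scalar sketch under the same layout assumptions (hypothetical helper; weights taken as int8 values already unpacked from the ternary encoding):

// Sketch: dot product of one unpacked weight row (int8) with a
// 4-scales-per-row quantized activation row: d[0..3] scales, q quants.
float dot_4scales(const int8_t * w, const float * d, const int8_t * q, int k) {
    int32_t sumi[4] = {0, 0, 0, 0};
    for (int j = 0; j < k; j += 16) {
        for (int i = 0; i < 4; ++i) {
            for (int l = 0; l < 4; ++l) sumi[i] += w[j+4*i+l] * q[j+4*i+l];
        }
    }
    // one float multiply per scale at the very end,
    // instead of one per 64/128-value block
    return d[0]*sumi[0] + d[1]*sumi[1] + d[2]*sumi[2] + d[3]*sumi[3];
}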
iqk_mul_mat.cpp (276 changed lines)
@@ -256,6 +256,13 @@ inline float hsum_float_4(__m128 x) {
 inline float hsum_float_8(__m256 x) {
     return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));
 }
+inline int hsum_i32_8(const __m256i a) {
+    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
 
 #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
 
@@ -1318,12 +1325,19 @@ template <int nrc> struct Q8_K64 {
 
     constexpr static int nrc_y = nrc;
 
-    Q8_K64(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_K128 *)info.src1_row(iy); }
+    Q8_K64(const DataInfo& info) {
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            const float * dptr = (const float *)info.src1_row(iy);
+            std::memcpy(d + 4*iy, dptr, 4*sizeof(float));
+            y[iy] = (const int8_t *)(dptr + 4);
+        }
+    }
 
-    inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); }
-    inline float scale(int iy, int i) const { return y[iy][i].d; }
+    inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy] + 4*i + j); }
+    inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 4*iy); }
 
-    const block_q8_K128 * y[nrc_y];
+    float d[4*nrc_y];
+    const int8_t * y[nrc_y];
 };
 
 struct DequantizerIQ1BN {
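Because scale(iy) now hands back all four row scales as one __m128, the kernels that follow can keep pure integer accumulators and defer the float multiply to the final store: the sum over blocks of d*dot_b equals d times the summed dots when d is fixed for the whole row. A sketch of the closing reduction pattern (standalone helper for illustration; the actual code below folds with hsum_float_4):

#include <immintrin.h>

// Sketch: fold a __m256i of int32 partial sums to 4 lanes, apply the 4 row
// scales once, then horizontally add. The lane-to-scale alignment follows
// the quant layout described above (an assumption here, not repo code).
static inline float reduce_with_row_scales(__m256i acc, __m128 scales) {
    __m128i sumi = _mm_add_epi32(_mm256_castsi256_si128(acc),
                                 _mm256_extractf128_si256(acc, 1));
    __m128  sumf = _mm_mul_ps(scales, _mm_cvtepi32_ps(sumi));
    sumf = _mm_add_ps(sumf, _mm_movehl_ps(sumf, sumf));   // add upper 2 lanes
    sumf = _mm_add_ss(sumf, _mm_movehdup_ps(sumf));       // add lane 1
    return _mm_cvtss_f32(sumf);
}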
@@ -1333,13 +1347,8 @@ struct DequantizerIQ1BN {
     const __m256i shuff3 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
     const __m256i shuff4 = _mm256_set_epi64x(0x0707070707070707, 0x0606060606060606, 0x0505050505050505, 0x0404040404040404);
     const __m256i mask1 = _mm256_set1_epi64x(0x8040201008040201);
-    //__m256i signs[2];
 
     IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) {
-        //auto all_signs = _mm256_set1_epi8(extra);
-        //all_signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(all_signs, mask1), mask1), m1_8);
-        //signs[0] = _mm256_shuffle_epi8(all_signs, shuff3);
-        //signs[1] = _mm256_shuffle_epi8(all_signs, shuff4);
 
         auto aux1 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)],
                                       iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)]);
@@ -1350,8 +1359,6 @@ struct DequantizerIQ1BN {
                              _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff1), mask1), mask1));
         v2 = _mm256_sub_epi8(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1),
                              _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1));
-        //v1 = _mm256_sign_epi8(v1, signs[0]);
-        //v2 = _mm256_sign_epi8(v2, signs[1]);
 
         auto all_signs = _mm256_set1_epi8(extra);
         all_signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(all_signs, mask1), mask1), m1_8);
@@ -1365,7 +1372,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
     const int nb = n / QK_IQ1BN;
     Q8_K64<nrc_y> q8(info);
     DequantizerIQ1BN deq;
-    __m256 accd[nrc_y];
+    __m256i accd[nrc_y];
     __m256i val[4];
 
 #if !(defined __AVX512VNNI__ && defined __AVX512VL__)
@@ -1378,31 +1385,56 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
 
         x = (const block_iq1_bn *)((const char *)vx + ix*bx);
 
-        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
-
-        for (int i = 0; i < nb/2; ++i) {
-
-            deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, val[0], val[1]);
-            deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, val[2], val[3]);
-
-            for (int iy = 0; iy < nrc_y; ++iy) {
+        if constexpr (nrc_y == 1) {
+            __m256i acc1 = _mm256_setzero_si256(), acc2 = _mm256_setzero_si256();
+            for (int i = 0; i < nb/2; ++i) {
+                deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, val[0], val[1]);
+                deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, val[2], val[3]);
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-                auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]);
-                auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]);
-                auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]);
-                auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]);
-                auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
-                                _mm256_setzero_si256(), deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
-                accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot), accd[iy]);
+                auto dot1 = _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0]);
+                auto dot2 = _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]);
+                auto dot3 = _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2]);
+                auto dot4 = _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3]);
+                acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, deq.m1_8, dot1), deq.m1_8, dot2);
+                acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, deq.m1_8, dot3), deq.m1_8, dot4);
 #else
-                auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0])),
-                                             _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1])));
-                auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2])),
-                                             _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3])));
-                dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
-                accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot1), accd[iy]);
+                auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
+                                             _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1])));
+                auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])),
+                                             _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3])));
+                acc1 = _mm256_add_epi32(acc1, _mm256_madd_epi16(m1_16, dot1));
+                acc2 = _mm256_add_epi32(acc2, _mm256_madd_epi16(m1_16, dot2));
 #endif
             }
+            accd[0] = _mm256_add_epi32(acc1, acc2);
         }
+        else {
+
+            for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();
+
+            for (int i = 0; i < nb/2; ++i) {
+
+                deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, val[0], val[1]);
+                deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, val[2], val[3]);
+
+                for (int iy = 0; iy < nrc_y; ++iy) {
+#if defined __AVX512VNNI__ && defined __AVX512VL__
+                    auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]);
+                    auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]);
+                    auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]);
+                    auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]);
+                    accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
+                                    accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
+#else
+                    auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0])),
+                                                 _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1])));
+                    auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2])),
+                                                 _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3])));
+                    dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
+                    accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
+#endif
+                }
+            }
+        }
         int i = 2*(nb/2);
         if (i < nb) {
@@ -1411,17 +1443,20 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
             auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
             auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-            auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.m1_8, dot1), deq.m1_8, dot2);
+            accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], deq.m1_8, dot1), deq.m1_8, dot2);
 #else
             auto dot = _mm256_madd_epi16(m1_16,
                     _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)));
+            accd[iy] = _mm256_add_epi32(dot, accd[iy]);
 #endif
-            accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i/2)), _mm256_cvtepi32_ps(dot), accd[iy]);
         }
     }
 
     for (int iy = 0; iy < nrc_y; ++iy) {
-        info.store(ix, iy, hsum_float_8(accd[iy]));
+        auto vd = q8.scale(iy);
+        auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
+        auto sumf = _mm_mul_ps(vd, _mm_cvtepi32_ps(sumi));
+        info.store(ix, iy, hsum_float_4(sumf));
     }
 
 }
@@ -1431,7 +1466,7 @@ template <int nrc_y>
 IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     const int nb = n / QK_IQ1BN;
     Q8_K64<nrc_y> q8(info);
-    __m256 accd[nrc_y];
+    __m256i accd[nrc_y];
 
     const auto m1_8 = _mm256_set1_epi8(1);
     const auto mask2 = _mm256_set1_epi8(3);
@@ -1458,14 +1493,13 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
             auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, 0, 2), v3);
             auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, 0, 3), v4);
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-            auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
+            accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
                             _mm256_setzero_si256(), m1_8, dot1), m1_8, dot2), m1_8, dot3), m1_8, dot4);
 #else
-            auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
+            accd[iy] = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
                             _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)),
                             _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot3), _mm256_maddubs_epi16(m1_8, dot4))));
 #endif
-            accd[iy] = _mm256_mul_ps(_mm256_set1_ps(q8.scale(iy, 0)), _mm256_cvtepi32_ps(dot));
         }
     }
 
@@ -1484,14 +1518,14 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
             auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), v3);
             auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), v4);
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-            auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
-                            _mm256_setzero_si256(), m1_8, dot1), m1_8, dot2), m1_8, dot3), m1_8, dot4);
+            accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
+                            accd[iy], m1_8, dot1), m1_8, dot2), m1_8, dot3), m1_8, dot4);
 #else
             auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
                             _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)),
                             _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot3), _mm256_maddubs_epi16(m1_8, dot4))));
+            accd[iy] = _mm256_add_epi32(dot, accd[iy]);
 #endif
-            accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot), accd[iy]);
         }
     }
     int i = 2*(nb/2);
@@ -1504,18 +1538,20 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
             auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), v1);
             auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), v2);
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-            dot1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2);
+            accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], m1_8, dot1), m1_8, dot2);
 #else
             dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)));
+            accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
 #endif
-            accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i/2)), _mm256_cvtepi32_ps(dot1), accd[iy]);
         }
     }
 
     for (int iy = 0; iy < nrc_y; ++iy) {
-        info.store(ix, iy, hsum_float_8(accd[iy]));
+        auto vd = q8.scale(iy);
+        auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
+        auto sumf = _mm_mul_ps(vd, _mm_cvtepi32_ps(sumi));
+        info.store(ix, iy, hsum_float_4(sumf));
     }
 
 }
 }
@@ -4149,13 +4185,20 @@ template <int nrc> struct Q8_K64 {
 
     constexpr static int nrc_y = nrc;
 
-    Q8_K64(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_K128 *)info.src1_row(iy); }
+    Q8_K64(const DataInfo& info) {
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            auto dptr = (const float *)info.src1_row(iy);
+            std::memcpy(d + 4*iy, dptr, 4*sizeof(float));
+            y[iy] = (const int8_t *)(dptr + 4);
+        }
+    }
 
-    inline int8x16x4_t load_quants64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); }
-    inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); }
-    inline float scale(int iy, int i) const { return y[iy][i].d; }
+    inline int8x16x4_t load_quants64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy] + 128*i + 64*j); }
+    inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy] + 128*i + 32*j); }
+    inline float32x4_t scale(int iy) const { return vld1q_f32(d + 4*iy); }
 
-    const block_q8_K128 * y[nrc_y];
+    float d[4*nrc_y];
+    const int8_t * y[nrc_y];
 };
 
 struct DequantizerIQ1BN {
@@ -4204,8 +4247,8 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
     Q8_K64<nrc_y> q8(info);
     DequantizerIQ1BN deq;
 
-    float32x4_t accd[nrc_y];
-    int8x16x4_t v1, v2;
+    int32x4_t accd[nrc_y];
+    int8x16x4_t v1, v2;
 
     const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);
 
@@ -4213,35 +4256,37 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
 
         x = (const block_iq1_bn *)((const char *)vx + ix*bx);
 
-        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_f32(0.f);
-
-        for (int i = 0; i < nb/2; ++i) {
+        if constexpr (nrc_y == 1) {
+            int32x4_t acc[4] = {};
+            for (int i = 0; i < nb/2; ++i) {
+                deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, v1);
+                auto q = q8.load_quants64(0, i, 0);
+                for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
+                deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, v1);
+                q = q8.load_quants64(0, i, 1);
+                for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
+            }
+            accd[0] = vaddq_s32(vaddq_s32(acc[0], acc[1]), vaddq_s32(acc[2], acc[3]));
+        }
+        else {
 
-            deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, v1);
-            deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, v2);
+            for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0);
+
+            for (int i = 0; i < nb/2; ++i) {
+
+                deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, v1);
+                deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, v2);
 
-            int32x4_t sumi1 = vdupq_n_s32(0);
-            int32x4_t sumi2 = vdupq_n_s32(0);
-            if constexpr (nrc_y == 1) {
-                auto q1 = q8.load_quants64(0, i, 0);
-                auto q2 = q8.load_quants64(0, i, 1);
-                for (int j = 0; j < 4; ++j) {
-                    sumi1 = ggml_vdotq_s32(sumi1, q1.val[j], v1.val[j]);
-                    sumi2 = ggml_vdotq_s32(sumi2, q2.val[j], v2.val[j]);
-                }
-                accd[0] = vfmaq_f32(accd[0], vdupq_n_f32(q8.scale(0, i)), vcvtq_f32_s32(vaddq_s32(sumi1, sumi2)));
-            } else {
                 for (int iy = 0; iy < nrc_y; ++iy) {
-                    int32x4_t sumi = vdupq_n_s32(0);
                     auto q = q8.load_quants(iy, i, 0);
-                    sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+                    accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
                     q = q8.load_quants(iy, i, 1);
-                    sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]);
+                    accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
                     q = q8.load_quants(iy, i, 2);
-                    sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[0]), q.val[1], v2.val[1]);
+                    accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[0]), q.val[1], v2.val[1]);
                     q = q8.load_quants(iy, i, 3);
-                    sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[2]), q.val[1], v2.val[3]);
-                    accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i)), vcvtq_f32_s32(sumi));
+                    accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[2]), q.val[1], v2.val[3]);
                 }
             }
         }
@@ -4250,25 +4295,21 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
             deq.prepare_iq1bn_quants(x[i].extra, x[i].ql, x[i].qh, v1);
             if constexpr (nrc_y == 1) {
                 auto q = q8.load_quants(0, i/2, 0);
-                int32x4_t sumi = vdupq_n_s32(0);
                 for (int j = 0; j < 4; ++j) {
-                    sumi = ggml_vdotq_s32(sumi, q.val[j], v1.val[j]);
+                    accd[0] = ggml_vdotq_s32(accd[0], q.val[j], v1.val[j]);
                 }
-                accd[0] = vfmaq_f32(accd[0], vdupq_n_f32(q8.scale(0, i/2)), vcvtq_f32_s32(sumi));
             } else {
                 for (int iy = 0; iy < nrc_y; ++iy) {
-                    int32x4_t sumi = vdupq_n_s32(0);
                     auto q = q8.load_quants(iy, i/2, 0);
-                    sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+                    accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
                     q = q8.load_quants(iy, i/2, 1);
-                    sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]);
-                    accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i/2)), vcvtq_f32_s32(sumi));
+                    accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
                 }
             }
         }
 
         for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, vaddvq_f32(accd[iy]));
+            info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
         }
 
     }
@@ -4280,8 +4321,7 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
 
     Q8_K64<nrc_y> q8(info);
 
-    float32x4_t accd[nrc_y];
-    int8x16x4_t v1, v2;
+    int32x4_t accd[nrc_y];
 
     const auto m1 = vdupq_n_u8(1);
     const auto mask2 = vdupq_n_s8(3);
@@ -4290,36 +4330,10 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
 
         const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx);
 
-        {
-            auto q2bits = vld1q_u8(x[0].qs);
-            v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
-            v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
-            v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
-            v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
-            q2bits = vld1q_u8(x[1].qs);
-            v2.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
-            v2.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
-            v2.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
-            v2.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                int32x4_t sumi = vdupq_n_s32(0);
-                auto q = q8.load_quants(iy, 0, 0);
-                sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]);
-                q = q8.load_quants(iy, 0, 1);
-                sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]);
-                q = q8.load_quants(iy, 0, 2);
-                sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[0]), q.val[1], v2.val[1]);
-                q = q8.load_quants(iy, 0, 3);
-                sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[2]), q.val[1], v2.val[3]);
-                accd[iy] = vmulq_f32(vdupq_n_f32(q8.scale(iy, 0)), vcvtq_f32_s32(sumi));
-            }
-
-        }
-
-        for (int i = 1; i < nb/2; ++i) {
-            auto sumi1 = vdupq_n_s32(0);
-            auto sumi2 = vdupq_n_s32(0);
+        if constexpr (nrc_y == 1) {
+            int8x16x4_t v1;
+            int32x4_t acc[4] = {};
+            for (int i = 0; i < nb/2; ++i) {
             for (int j = 0; j < 2; ++j) {
                 auto q = q8.load_quants64(0, i, j);
                 auto q2bits = vld1q_u8(x[2*i+j].qs);
@@ -4327,13 +4341,17 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
                 v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
                 v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
                 v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
-                sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(sumi1, q.val[0], v1.val[0]), q.val[1], v1.val[1]);
-                sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(sumi2, q.val[2], v1.val[2]), q.val[3], v1.val[3]);
+                acc[0] = ggml_vdotq_s32(acc[0], q.val[0], v1.val[0]);
+                acc[1] = ggml_vdotq_s32(acc[1], q.val[1], v1.val[1]);
+                acc[2] = ggml_vdotq_s32(acc[2], q.val[2], v1.val[2]);
+                acc[3] = ggml_vdotq_s32(acc[3], q.val[3], v1.val[3]);
             }
-            accd[0] = vfmaq_f32(accd[0], vdupq_n_f32(q8.scale(0, i)), vcvtq_f32_s32(vaddq_s32(sumi1, sumi2)));
             }
+            accd[0] = vaddq_s32(vaddq_s32(acc[0], acc[1]), vaddq_s32(acc[2], acc[3]));
         } else {
-            for (int i = 1; i < nb/2; ++i) {
+            int8x16x4_t v1, v2;
+            for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0);
+            for (int i = 0; i < nb/2; ++i) {
                 auto q2bits = vld1q_u8(x[2*i+0].qs);
                 v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
                 v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
@@ -4345,40 +4363,36 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
                 v2.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
                 v2.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
                 for (int iy = 0; iy < nrc_y; ++iy) {
-                    int32x4_t sumi = vdupq_n_s32(0);
                     auto q = q8.load_quants(iy, i, 0);
-                    sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+                    accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
                     q = q8.load_quants(iy, i, 1);
-                    sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]);
+                    accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
                     q = q8.load_quants(iy, i, 2);
-                    sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[0]), q.val[1], v2.val[1]);
+                    accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[0]), q.val[1], v2.val[1]);
                     q = q8.load_quants(iy, i, 3);
-                    sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[2]), q.val[1], v2.val[3]);
-                    accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i)), vcvtq_f32_s32(sumi));
+                    accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[2]), q.val[1], v2.val[3]);
                 }
             }
         }
         int i = 2*(nb/2);
         if (i < nb) {
-            auto q2bits = vld1q_u8(x[2*i+0].qs);
+            auto q2bits = vld1q_u8(x[i].qs);
+            int8x16x4_t v1;
             v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
             v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
             v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
             v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
             for (int iy = 0; iy < nrc_y; ++iy) {
-                int32x4_t sumi = vdupq_n_s32(0);
                 auto q = q8.load_quants(iy, i/2, 0);
-                sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+                accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
                 q = q8.load_quants(iy, i/2, 1);
-                sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]);
-                accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i/2)), vcvtq_f32_s32(sumi));
+                accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
             }
         }
 
         for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, vaddvq_f32(accd[iy]));
+            info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
         }
 
     }
 }