iq2_bn: improve performance on NEON

We now get TG-128 = 100 t/s for Bitnet-3B-1.58b!
This commit is contained in:
Iwan Kawrakow
2024-09-09 06:46:25 +02:00
parent 3487e68cc0
commit c85d7bef55

View File

@@ -6006,10 +6006,10 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
for (int j = 0; j < 2; ++j) {
auto q = q8.load_quants64(0, i, j);
auto q2bits = vld1q_u8(x[2*i+j].qs);
- v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
- v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
- v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
- v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+ v1.val[0] = vandq_s8(q2bits, mask2);
+ v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+ v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+ v1.val[3] = vshrq_n_u8(q2bits, 6);
acc[0] = ggml_vdotq_s32(acc[0], q.val[0], v1.val[0]);
acc[1] = ggml_vdotq_s32(acc[1], q.val[1], v1.val[1]);
acc[2] = ggml_vdotq_s32(acc[2], q.val[2], v1.val[2]);
@@ -6022,15 +6022,15 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0);
for (int i = 0; i < nb/2; ++i) {
auto q2bits = vld1q_u8(x[2*i+0].qs);
- v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
- v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
- v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
- v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+ v1.val[0] = vandq_s8(q2bits, mask2);
+ v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+ v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+ v1.val[3] = vshrq_n_u8(q2bits, 6);
q2bits = vld1q_u8(x[2*i+1].qs);
- v2.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
- v2.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
- v2.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
- v2.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+ v2.val[0] = vandq_s8(q2bits, mask2);
+ v2.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+ v2.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+ v2.val[3] = vshrq_n_u8(q2bits, 6);
for (int iy = 0; iy < nrc_y; ++iy) {
auto q = q8.load_quants(iy, i, 0);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
@@ -6047,10 +6047,10 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
if (i < nb) {
auto q2bits = vld1q_u8(x[i].qs);
int8x16x4_t v1;
- v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
- v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
- v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
- v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+ v1.val[0] = vandq_s8(q2bits, mask2);
+ v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+ v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+ v1.val[3] = vshrq_n_u8(q2bits, 6);
for (int iy = 0; iy < nrc_y; ++iy) {
auto q = q8.load_quants(iy, i/2, 0);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
@@ -6060,7 +6060,7 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
}
for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
+ info.store(ix, iy, -vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
}
}
}