From c85d7bef557200bb43f37bc9564386789f83ce95 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Mon, 9 Sep 2024 06:46:25 +0200
Subject: [PATCH] iq2_bn: improve performance on NEON

We now get TG-128 = 100 t/s for Bitnet-3B-1.58b!

The gain comes from removing the vsubq_s8 with m1 after every unpack:
the 2-bit quants now stay in {0...3} in the hot loops, and the constant
-1 offset is folded into the final result through q8.minus(iy), with
vfmsq_f32 computing minus - scale*acc and the leading negation turning
this into scale*acc - minus. This saves four vector subtractions per
16-byte block of quants.
---
 ggml/src/iqk/iqk_mul_mat.cpp | 34 +++++++++++++++++-----------------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index fce86efd..2a378358 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -6006,10 +6006,10 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
                 for (int j = 0; j < 2; ++j) {
                     auto q = q8.load_quants64(0, i, j);
                     auto q2bits = vld1q_u8(x[2*i+j].qs);
-                    v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
-                    v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
-                    v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
-                    v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+                    v1.val[0] = vandq_s8(q2bits, mask2);
+                    v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+                    v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+                    v1.val[3] = vshrq_n_u8(q2bits, 6);
                     acc[0] = ggml_vdotq_s32(acc[0], q.val[0], v1.val[0]);
                     acc[1] = ggml_vdotq_s32(acc[1], q.val[1], v1.val[1]);
                     acc[2] = ggml_vdotq_s32(acc[2], q.val[2], v1.val[2]);
@@ -6022,15 +6022,15 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
             for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0);
             for (int i = 0; i < nb/2; ++i) {
                 auto q2bits = vld1q_u8(x[2*i+0].qs);
-                v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
-                v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
-                v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
-                v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+                v1.val[0] = vandq_s8(q2bits, mask2);
+                v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+                v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+                v1.val[3] = vshrq_n_u8(q2bits, 6);
                 q2bits = vld1q_u8(x[2*i+1].qs);
-                v2.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
-                v2.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
-                v2.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
-                v2.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+                v2.val[0] = vandq_s8(q2bits, mask2);
+                v2.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+                v2.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+                v2.val[3] = vshrq_n_u8(q2bits, 6);
                 for (int iy = 0; iy < nrc_y; ++iy) {
                     auto q = q8.load_quants(iy, i, 0);
                     accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
@@ -6047,10 +6047,10 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
             if (i < nb) {
                 auto q2bits = vld1q_u8(x[i].qs);
                 int8x16x4_t v1;
-                v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
-                v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
-                v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
-                v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+                v1.val[0] = vandq_s8(q2bits, mask2);
+                v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+                v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+                v1.val[3] = vshrq_n_u8(q2bits, 6);
                 for (int iy = 0; iy < nrc_y; ++iy) {
                     auto q = q8.load_quants(iy, i/2, 0);
                     accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
@@ -6060,7 +6060,7 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
         }

         for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
+            info.store(ix, iy, -vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
         }
     }
 }
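
[Editor's note, not part of the patch] Why dropping the per-block vsubq_s8 is
safe: the old loops computed sum_k q_k*(v_k - 1) with v_k in {0...3}; the new
loops compute sum_k q_k*v_k, and since vfmsq_f32(a, b, c) = a - b*c the new
store line evaluates to scale*acc - minus per lane. The two agree provided
q8.minus(iy) holds the q8 scale times the row sums of the q8 quants; that
helper is not shown in these hunks, so this reading is an assumption based on
its name. A minimal scalar sketch of the identity under that assumption:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Toy data: 2-bit weight codes in {0..3} and int8 activations.
        uint8_t v[8] = {0, 1, 2, 3, 3, 2, 1, 0};
        int8_t  q[8] = {5, -3, 7, 1, -2, 4, -6, 8};
        float   d    = 0.125f; // q8 scale

        // Old path: subtract 1 from every weight inside the hot loop.
        int acc_old = 0;
        for (int k = 0; k < 8; ++k) acc_old += q[k] * (v[k] - 1);

        // New path: raw weights in the loop, one correction at the end.
        int acc_new = 0, sum_q = 0;
        for (int k = 0; k < 8; ++k) { acc_new += q[k] * v[k]; sum_q += q[k]; }

        float minus = d * sum_q;              // assumed q8.minus() value
        float old_r = d * acc_old;
        float new_r = -(minus - d * acc_new); // mirrors -vaddvq(vfmsq(...))
        printf("old = %g, new = %g\n", old_r, new_r); // prints equal values
        return 0;
    }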