From 780929a6d0f614dbb81b1f8b2a89a162cfaf464b Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow <48489457+ikawrakow@users.noreply.github.com>
Date: Fri, 18 Oct 2024 08:02:22 +0200
Subject: [PATCH] iq4_knn: ARM_NEON

Pretty good performance - on M2-Max we get
PP-512(LLaMA-3.1-8B) = 89.5 t/s
TG-128(LLaMA-3.1-8B) = 27.65 t/s
---
 ggml/src/iqk/iqk_mul_mat.cpp | 68 ++++++++++++++++++++++++++++++++++++
 1 file changed, 68 insertions(+)

diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index a34eaca2..a6de85a3 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -5696,6 +5696,64 @@ void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info
     }
 }
 
+template <int nrc_y>
+void mul_mat_iq4_knn_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    assert(n % QK_K == 0);
+    const int nb = n / QK_K;
+
+    Q8<nrc_y, block_q8_K> q8(info);
+
+    const char * qrow = (const char *)vx;
+
+    auto ml = vdupq_n_u8(0xf);
+
+    int8x16_t q4[16];
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+
+        const float * dptr = (const float *)qrow;
+        float d = *dptr;
+        const int8_t * int_values = (const int8_t *)(dptr + 1);
+        auto values = vld1q_s8(int_values);
+        auto x = (const block_iq4_knn *)(int_values + 16);
+
+        float32x4_t acc[nrc_y];
+        for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);
+
+        for (int i = 0; i < nb; ++i) {
+
+            auto bits = vld1q_u8_x4(x[i].qs);
+            for (int k = 0; k < 4; ++k) {
+                q4[2*k+0] = vqtbl1q_s8(values, vandq_u8(bits.val[k], ml));
+                q4[2*k+1] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[k], 4));
+            }
+            bits = vld1q_u8_x4(x[i].qs+64);
+            for (int k = 0; k < 4; ++k) {
+                q4[2*k+8] = vqtbl1q_s8(values, vandq_u8(bits.val[k], ml));
+                q4[2*k+9] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[k], 4));
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto sumi1 = vdupq_n_s32(0);
+                auto sumi2 = vdupq_n_s32(0);
+                for (int k = 0; k < 4; ++k) {
+                    auto qy = q8.load_quants_64(iy, i, k);
+                    sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(sumi1, q4[4*k+0], qy.val[0]), q4[4*k+1], qy.val[1]);
+                    sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(sumi2, q4[4*k+2], qy.val[2]), q4[4*k+3], qy.val[3]);
+                }
+                acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(vaddq_s32(sumi1, sumi2)), vdupq_n_f32(d*q8.scale(iy, i)));
+                //acc[iy] = i > 0 ? vmlaq_f32(acc[iy], vcvtq_f32_s32(vaddq_s32(sumi1, sumi2)), vdupq_n_f32(d*q8.scale(iy, i))) :
+                //                  vmulq_f32(vcvtq_f32_s32(vaddq_s32(sumi1, sumi2)), vdupq_n_f32(d*q8.scale(iy, i)));
+            }
+        }
+
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, vaddvq_f32(acc[iy]));
+        }
+
+        qrow += bx;
+    }
+}
+
 // =========================================== Legacy quants
 
 template <typename Dequantizer>
@@ -6969,6 +7027,16 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
         case GGML_TYPE_IQ4_KSS:
             MulMat::set_functions<DequantizerIQ4KSS>(m);
             break;
+        case GGML_TYPE_IQ4_KNN:
+            m.funcs[0] = mul_mat_iq4_knn_q8_K<1>;
+            m.funcs[1] = mul_mat_iq4_knn_q8_K<2>;
+            m.funcs[2] = mul_mat_iq4_knn_q8_K<3>;
+            m.funcs[3] = mul_mat_iq4_knn_q8_K<4>;
+            m.funcs[4] = mul_mat_iq4_knn_q8_K<5>;
+            m.funcs[5] = mul_mat_iq4_knn_q8_K<6>;
+            m.funcs[6] = mul_mat_iq4_knn_q8_K<7>;
+            m.funcs[7] = mul_mat_iq4_knn_q8_K<8>;
+            break;
         case GGML_TYPE_IQ2_KS:
             MulMat::set_functions<DequantizerIQ2KS>(m);
             break;