From b8966277c0e3ecc8f1cceefd2ff423921ec282fa Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 30 Jan 2025 18:29:04 +0200 Subject: [PATCH] Make q5,6_0_r4, iq4_nl_e4 work with row size that are not a multiple of 128 also on NEON. --- ggml/src/iqk/iqk_mul_mat.cpp | 35 ++++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index e8385212..f633229d 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -12215,7 +12215,6 @@ void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, Q8 q8(info); Dequantizer deq(vx, bx); int nb = n / QK4_NL; - GGML_ASSERT(nb%4 == 0); int8x16_t qx[8]; float d8[4*nrc_y]; float32x4_t acc[nrc_y] = {}; @@ -12226,7 +12225,7 @@ void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d))); } for (int k = 0; k < 4; ++k) { - auto scales = deq.prepare(ib4, k, qx); + auto scales = deq.prepare(4*ib4+k, qx); for (int iy = 0; iy < nrc_y; ++iy) { auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k); auto sumi = interleaved_dotq(qx, y); @@ -12235,6 +12234,16 @@ void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = deq.prepare(ib, qx); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_0 *)q8.y[iy]; + auto y = vld1q_s8_x2(qy[ib].qs); + auto sumi = interleaved_dotq(qx, y); + auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d))); + acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi)); + } + } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, deq.result(acc[iy])); acc[iy] = vdupq_n_f32(0.f); @@ -12292,9 +12301,9 @@ void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, struct IQ4_NL_R4_Dequantizer { IQ4_NL_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx), values(vld1q_s8(iq4k_values)) {} inline void new_row(int ix) { iq4 = (const block_iq4_nl_r4 *)(cx + ix*bx); } - inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const { - auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d)); - auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs); + inline float32x4_t prepare(int ib, int8x16_t * qx) const { + auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ib].d)); + auto bits = vld1q_u8_x4(iq4[ib].qs); prepare_iq4_nl_quants(values, m4, bits, qx); return scales; } @@ -12370,10 +12379,10 @@ struct Q4_0_R8_Dequantizer { struct Q5_0_R4_Dequantizer { Q5_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {} inline void new_row(int ix) { iq5 = (const block_q5_0_r4 *)(cx + ix*bx); } - inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const { - auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[4*ib4+k].d)); - auto lbits = vld1q_u8_x4(iq5[4*ib4+k].qs); - auto hbits = vld1q_u8(iq5[4*ib4+k].qh); + inline float32x4_t prepare(int ib, int8x16_t * qx) const { + auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ib].d)); + auto lbits = vld1q_u8_x4(iq5[ib].qs); + auto hbits = vld1q_u8(iq5[ib].qh); qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16); // 0...3 qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16); // 16..19 qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16); // 4...7 @@ -12399,10 +12408,10 @@ struct Q5_0_R4_Dequantizer { struct Q6_0_R4_Dequantizer { Q6_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {} inline void new_row(int ix) { iq6 = (const block_q6_0_r4 *)(cx + ix*bx); } - inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const { - auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[4*ib4+k].d)); - auto lbits = vld1q_u8_x4(iq6[4*ib4+k].qs); - auto hbits = vld1q_u8_x2(iq6[4*ib4+k].qh); + inline float32x4_t prepare(int ib, int8x16_t * qx) const { + auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[ib].d)); + auto lbits = vld1q_u8_x4(iq6[ib].qs); + auto hbits = vld1q_u8_x2(iq6[ib].qh); qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 4), m6), m32); // 0...3 qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 4), m6), m32); // 16..19 qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 2), m6), m32); // 4...7