From baa9ed4a5e59c267c8faa3eca8605a1792f970df Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 18 Dec 2024 19:14:40 +0100 Subject: [PATCH] Minor --- ggml/src/iqk/iqk_mul_mat.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 88a1e606..f4576319 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -9411,12 +9411,12 @@ void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& i int16x8x4_t iscales; int32x4x4_t scales; float32x4_t acc[nrc_y] = {}; + int32x4_t isum[nrc_y] = {}; for (int ix = 0; ix < nrc_x; ix += 4) { auto dptr = (const float *)((const char *)vx + ix*bx); auto d4 = vld1q_f32(dptr); const block_iq4_ks_r4 * iq4 = (const block_iq4_ks_r4 *)(dptr + 4); for (int ibl = 0; ibl < nbl; ++ibl) { - // TODO: shifts auto sas = vld1q_u8_x2(iq4[ibl].scales); auto scale = vandq_u8(sas.val[0], vdupq_n_u8(254)); iscales.val[0] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127)); @@ -9424,9 +9424,10 @@ void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& i scale = vandq_u8(sas.val[1], vdupq_n_u8(254)); iscales.val[2] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127)); iscales.val[3] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127)); + // Adding the block shifts costs us ~9% in performance drop. + // Is there a better way? sas.val[0] = vshlq_n_u8(vandq_u8(sas.val[0], vdupq_n_u8(1)), 2); sas.val[1] = vshlq_n_u8(vandq_u8(sas.val[1], vdupq_n_u8(1)), 2); - int32x4_t isum[nrc_y] = {}; { auto s16_1 = vmulq_s16(iscales.val[0], vmovl_u8(vget_low_u8 (sas.val[0]))); auto s16_2 = vmulq_s16(iscales.val[1], vmovl_u8(vget_high_u8(sas.val[0]))); @@ -9464,6 +9465,7 @@ void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& i } for (int iy = 0; iy < nrc_y; ++iy) { acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.scale(iy, ibl)), vcvtq_f32_s32(isum[iy])); + isum[iy] = vdupq_n_s32(0); } } for (int iy = 0; iy < nrc_y; ++iy) {