From 4102aa998c2bc55423e9957ad70bee662e6f0b1c Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 8 Jun 2025 09:17:32 +0300 Subject: [PATCH] New iq4_kt: slightly faster NEON --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 48 ++++++++++++++++++------------ 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 4e06b1b8..0a17fe83 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -1330,7 +1330,7 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& union { uint32x4x2_t vec; uint32_t val[8]; } o_helper; - constexpr int k_acc = nrc_y; + constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y; float32x4_t accd[k_acc]; @@ -1339,11 +1339,11 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& y[iy] = (const block_q8_0_x4 *)info.src1_row(iy); } - uint32_t values[64]; - int8x16x2_t xv[4]; + uint32_t values[16]; + int8x16x2_t xv[8]; int32x4x4_t dot; - auto compute_dot = [&dot, &xv] (const int8_t * y) { + auto compute_dot = [&dot] (const int8_t * y, const int8x16x2_t * xv) { for (int k = 0; k < 4; ++k) { auto yv = vld1q_s8_x2(y + 32*k); dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]); @@ -1379,27 +1379,37 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& for (int j = 0; j < 4; ++j) { const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); - values[8*ib+2*j+ 0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; - values[8*ib+2*j+ 1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; - values[8*ib+2*j+32] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; - values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; + values[2*j+0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; + values[2*j+1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; + values[2*j+8] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; + values[2*j+9] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; } + xv[ib+0] = trellis.next32(values+0); + xv[ib+4] = trellis.next32(values+8); } - for (int i128 = 0; i128 < 2; ++i128) { - for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k); - for (int iy = 0; iy < nrc_y; ++iy) { - const block_q8_0_x4& yb = y[iy][2*i+i128]; - auto dy = vmulq_f32(scales.val[i128], vcvt_f32_f16(vld1_f16((const float16_t *)yb.d))); - //auto dy = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vld1_u16((const uint16_t *)yb.d)), 16)); - //dy = vmulq_f32(scales.val[i128], dy); - auto sumi = compute_dot(yb.qs); - accd[iy] = vfmaq_f32(accd[iy], dy, vcvtq_f32_s32(sumi)); + for (int iy = 0; iy < nrc_y; ++iy) { + const block_q8_0_x4& ybl = y[iy][2*i+0]; + const block_q8_0_x4& ybh = y[iy][2*i+1]; + auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d))); + auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d))); + auto sumil = compute_dot(ybl.qs, xv+0); + auto sumih = compute_dot(ybh.qs, xv+4); + if constexpr (nrc_y == 1) { + accd[2*iy+0] = vfmaq_f32(accd[2*iy+0], dyl, vcvtq_f32_s32(sumil)); + accd[2*iy+1] = vfmaq_f32(accd[2*iy+1], dyh, vcvtq_f32_s32(sumih)); + } else { + accd[iy] = vfmaq_f32(accd[iy], dyl, vcvtq_f32_s32(sumil)); + accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih)); } } } - for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, vaddvq_f32(accd[iy])); + if constexpr (nrc_y == 1) { + info.store(ix, 0, vaddvq_f32(vaddq_f32(accd[0], accd[1]))); + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(accd[iy])); + } } } }