From 6dec39627ccb2120ae816a5e503bfbcb9251c367 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 4 Dec 2024 15:05:55 +0100 Subject: [PATCH] DRY --- ggml/src/iqk/iqk_mul_mat.cpp | 62 ++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 34 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 5befef63..faa4cab7 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -7253,6 +7253,30 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn } } +IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16x2_t& y) { + auto sumi = vdupq_n_s32(0); + sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0); + sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0); + sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1); + sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1); + sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2); + sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2); + sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3); + sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3); + return sumi; +} + +IQK_ALWAYS_INLINE void prepare_iq4_nl_quants(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) { + qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19 + qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7 + qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23 + qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11 + qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27 + qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15 + qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31 +} + template void mul_mat_iq4_nl_x4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); @@ -7269,25 +7293,10 @@ void mul_mat_iq4_nl_x4_q8_0(int n, const void * vx, size_t bx, const DataInfo& i for (int k = 0; k < 4; ++k) { auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d)); auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs); - qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows - qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19 - qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7 - qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23 - qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11 - qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27 - qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15 - qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31 + prepare_iq4_nl_quants(values, m4, bits, qx); for (int iy = 0; iy < nrc_y; ++iy) { auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k); - auto sumi = vdupq_n_s32(0); - sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0); - sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0); - sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1); - sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1); - sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2); - sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2); - sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3); - sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3); + auto sumi = interleaved_dotq(qx, y); auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]))); acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi)); } @@ -7323,25 +7332,10 @@ void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& i auto sl32 = vmovl_s16(vget_low_s16(sl16)); auto scales = vmulq_f32(d4, vcvtq_f32_s32(sl32)); auto bits = vld1q_u8_x4(iq4[ibl].qs + 64*ib); - qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows - qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19 - qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7 - qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23 - qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11 - qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27 - qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15 - qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31 + prepare_iq4_nl_quants(values, m4, bits, qx); for (int iy = 0; iy < nrc_y; ++iy) { auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+32*ib); - auto sumi = vdupq_n_s32(0); - sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0); - sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0); - sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1); - sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1); - sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2); - sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2); - sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3); - sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3); + auto sumi = interleaved_dotq(qx, y); auto d4d8 = vmulq_f32(scales, vdupq_n_f32(q8.scale(iy, ibl))); acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi)); }