From 2b8a231d87c94a1e983cf7dbd64ef410cfd24c3a Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Mon, 19 May 2025 07:51:28 +0300
Subject: [PATCH] Refactor iqk: factor out repacked legacy quants (NEON)

---
 ggml/src/iqk/iqk_common.h               |  51 +++
 ggml/src/iqk/iqk_gemm_legacy_quants.cpp | 302 +++++++++++++++--
 ggml/src/iqk/iqk_mul_mat.cpp            | 417 +-----------------------
 3 files changed, 339 insertions(+), 431 deletions(-)

diff --git a/ggml/src/iqk/iqk_common.h b/ggml/src/iqk/iqk_common.h
index 8b9db67d..6feeff1a 100644
--- a/ggml/src/iqk/iqk_common.h
+++ b/ggml/src/iqk/iqk_common.h
@@ -771,7 +771,58 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
     }
 }
 
+static IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16x2_t& y) {
+    auto sumi = vdupq_n_s32(0);
+    sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
+    sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
+    sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
+    sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
+    sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
+    sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
+    sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
+    sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
+    return sumi;
+}
+static IQK_ALWAYS_INLINE int32x4x2_t interleaved_dotq_b16(const int8x16_t * qx, const int8x16x2_t& y) {
+    int32x4x2_t sumi = { vdupq_n_s32(0), vdupq_n_s32(0) };
+    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[0], y.val[0], 0);
+    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[1], y.val[1], 0);
+    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[2], y.val[0], 1);
+    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[3], y.val[1], 1);
+    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[4], y.val[0], 2);
+    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[5], y.val[1], 2);
+    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[6], y.val[0], 3);
+    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[7], y.val[1], 3);
+    return sumi;
+}
+
+static IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16_t& y) {
+    auto sumi = vdupq_n_s32(0);
+    sumi = vdotq_laneq_s32(sumi, qx[0], y, 0);
+    sumi = vdotq_laneq_s32(sumi, qx[1], y, 1);
+    sumi = vdotq_laneq_s32(sumi, qx[2], y, 2);
+    sumi = vdotq_laneq_s32(sumi, qx[3], y, 3);
+    return sumi;
+}
+
+static IQK_ALWAYS_INLINE void prepare_iq4_nl_quants(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) {
+    qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4));   //  0...3 from the 4 rows
+    qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4));   // 16..19
+    qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4));   //  4...7
+    qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4));   // 20..23
+    qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4));  //  8..11
+    qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4));  // 24..27
+    qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4));  // 12..15
+    qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4));  // 28..31
+}
+
+static IQK_ALWAYS_INLINE void prepare_iq4_nl_quants_r8(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x2_t& bits, int8x16_t * qx) {
+    qx[0] = vqtbl1q_s8(values, vandq_u8(  bits.val[0], m4));
+    qx[1] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4));
+    qx[2] = vqtbl1q_s8(values, vandq_u8(  bits.val[1], m4));
+    qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4));
+}
 
 #endif
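For reference, the first interleaved_dotq above computes four row dot products at once: each qx[k] holds four consecutive columns of the four interleaved rows, and vdotq_laneq_s32 broadcasts the matching four activation bytes of y to all result lanes. A scalar model of the same computation (illustrative code, not part of the repository; the column order follows the comments in prepare_iq4_nl_quants):

    #include <cstdint>

    // Columns held by qx[k]: each qx[k] packs columns c..c+3 of rows 0..3,
    // with bytes 4*r..4*r+3 belonging to row r.
    static const int kCol[8] = {0, 16, 4, 20, 8, 24, 12, 28};

    // sumi[r] accumulates the dot product of row r with y[0..31]; each k
    // iteration corresponds to exactly one vdotq_laneq_s32 call above.
    static void interleaved_dotq_ref(const int8_t qx[8][16], const int8_t y[32], int32_t sumi[4]) {
        for (int r = 0; r < 4; ++r) sumi[r] = 0;
        for (int k = 0; k < 8; ++k) {
            for (int r = 0; r < 4; ++r) {
                for (int j = 0; j < 4; ++j) {
                    sumi[r] += qx[k][4*r + j] * y[kCol[k] + j];
                }
            }
        }
    }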
diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
index 40055e48..64ae0c2f 100644
--- a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
+++ b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
@@ -2310,6 +2310,273 @@ static void mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInf
     mul_mat_qX_Y_q8_Y(n, deq1, deq2, q8, info, nrc_x);
 }
 
+template <typename Dequantizer, int nrc_y>
+void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%4 == 0);
+    Q8<nrc_y, block_q8_0_x4> q8(info);
+    Dequantizer deq(vx, bx);
+    int nb = n / QK4_NL;
+    int8x16_t qx[8];
+    float d8[4*nrc_y];
+    float32x4_t acc[nrc_y] = {};
+    for (int ix = 0; ix < nrc_x; ix += 4) {
+        deq.new_row(ix);
+        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
+            }
+            for (int k = 0; k < 4; ++k) {
+                auto scales = deq.prepare(4*ib4+k, qx);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
+                    auto sumi = interleaved_dotq(qx, y);
+                    auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k]));
+                    acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
+                }
+            }
+        }
+        for (int ib = 4*(nb/4); ib < nb; ++ib) {
+            auto scales = deq.prepare(ib, qx);
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto qy = (const block_q8_0 *)q8.y[iy];
+                auto y = vld1q_s8_x2(qy[ib].qs);
+                auto sumi = interleaved_dotq(qx, y);
+                auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d)));
+                acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, deq.result(acc[iy]));
+            acc[iy] = vdupq_n_f32(0.f);
+        }
+    }
+}
+
+template <typename Dequantizer, int nrc_y>
+void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%8 == 0);
+    Q8<nrc_y, block_q8_0_x4> q8(info);
+    Dequantizer deq(vx, bx);
+    int nb = n / QK4_NL;
+    int8x16_t qx[16];
+    float d8[4*nrc_y];
+    float32x4_t acc[2*nrc_y] = {};
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        deq.new_row(ix);
+        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
+            }
+            for (int k = 0; k < 4; ++k) {
+                auto scales = deq.prepare(ib4, k, qx);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
+                    auto sumi1 = interleaved_dotq(qx+0, y);
+                    auto sumi2 = interleaved_dotq(qx+8, y);
+                    auto dy = vdupq_n_f32(d8[4*iy+k]);
+                    acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales.val[0], dy), vcvtq_f32_s32(sumi1));
+                    acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales.val[1], dy), vcvtq_f32_s32(sumi2));
+                }
+            }
+        }
+        for (int ib = 4*(nb/4); ib < nb; ++ib) {
+            auto scales = deq.prepare(ib, 0, qx);
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto qy = (const block_q8_0 *)q8.y[iy];
+                auto y = vld1q_s8_x2(qy[ib].qs);
+                auto sumi1 = interleaved_dotq(qx+0, y);
+                auto sumi2 = interleaved_dotq(qx+8, y);
+                auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
+                acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales.val[0], dy), vcvtq_f32_s32(sumi1));
+                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales.val[1], dy), vcvtq_f32_s32(sumi2));
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix+0, iy, deq.result(acc[2*iy+0]));
+            info.store(ix+4, iy, deq.result(acc[2*iy+1]));
+            acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
+        }
+    }
+}
+
+struct IQ4_NL_R4_Dequantizer {
+    IQ4_NL_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx), values(vld1q_s8(iq4k_values)) {}
+    inline void new_row(int ix) { iq4 = (const block_iq4_nl_r4 *)(cx + ix*bx); }
+    inline float32x4_t prepare(int ib, int8x16_t * qx) const {
+        auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ib].d));
+        auto bits = vld1q_u8_x4(iq4[ib].qs);
+        prepare_iq4_nl_quants(values, m4, bits, qx);
+        return scales;
+    }
+    inline float32x4_t result(float32x4_t acc) const {
+        return acc;
+    }
+
+    const char * cx;
+    const size_t bx;
+    const block_iq4_nl_r4 * iq4;
+    const uint8x16_t m4 = vdupq_n_u8(0x0f);
+    const int8x16_t values;
+};
+
+struct Q4_0_R8_Dequantizer {
+    Q4_0_R8_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
+    inline void new_row(int ix) { iq4 = (const block_iq4_nl_r8 *)(cx + ix*bx); }
+    inline float32x4x2_t prepare(int ib4, int k, int8x16_t * qx) const {
+        auto scales16 = vld1q_f16((const float16_t *)iq4[4*ib4+k].d);
+        float32x4x2_t scales = { vcvt_f32_f16(vget_low_f16(scales16)), vcvt_f32_f16(vget_high_f16(scales16)) };
+        for (int j = 0; j < 4; ++j) {
+            auto bits = vld1q_u8_x2(iq4[4*ib4+k].qs + 32*j);
+            bits.val[0] = veorq_u8(m88, bits.val[0]);
+            bits.val[1] = veorq_u8(m88, bits.val[1]);
+            qx[2*j+0] = vshlq_n_u8(bits.val[0], 4);
+            qx[2*j+1] = vandq_u8(bits.val[0], m4);
+            qx[2*j+8] = vshlq_n_u8(bits.val[1], 4);
+            qx[2*j+9] = vandq_u8(bits.val[1], m4);
+        }
+        return scales;
+    }
+    inline float32x4_t result(float32x4_t acc) const {
+        return vmulq_f32(norm, acc);
+    }
+
+    const char * cx;
+    const size_t bx;
+    const block_iq4_nl_r8 * iq4;
+    const uint8x16_t m4 = vdupq_n_u8(0xf0);
+    const uint8x16_t m88 = vdupq_n_u8(0x88);
+    const float32x4_t norm = vdupq_n_f32(1.f/16);
+};
+
+struct Q5_0_R4_Dequantizer {
+    Q5_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
+    inline void new_row(int ix) { iq5 = (const block_q5_0_r4 *)(cx + ix*bx); }
+    inline float32x4_t prepare(int ib, int8x16_t * qx) const {
+        auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ib].d));
+        auto lbits = vld1q_u8_x4(iq5[ib].qs);
+        auto hbits = vld1q_u8(iq5[ib].qh);
+        qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16);  //  0...3
+        qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16);  // 16..19
+        qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16);  //  4...7
+        qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits, 1), m5), m16);  // 20..23
+        qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits, m5), m16);                 //  8..11
+        qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(vshrq_n_u8(hbits, 1), m5), m16);  // 24..27
+        qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits, 2), m5), m16);  // 12..15
+        qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits, 3), m5), m16);  // 28..31
+        return scales;
+    }
+    inline float32x4_t result(float32x4_t acc) const {
+        return acc;
+    }
+
+    const char * cx;
+    const size_t bx;
+    const block_q5_0_r4 * iq5;
+    const uint8x16_t m4 = vdupq_n_u8(0x0f);
+    const uint8x16_t m5 = vdupq_n_u8(0x10);
+    const int8x16_t m16 = vdupq_n_s8(-16);
+};
+
+struct Q6_0_R4_Dequantizer {
+    Q6_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
+    inline void new_row(int ix) { iq6 = (const block_q6_0_r4 *)(cx + ix*bx); }
+    inline float32x4_t prepare(int ib, int8x16_t * qx) const {
+        auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[ib].d));
+        auto lbits = vld1q_u8_x4(iq6[ib].qs);
+        auto hbits = vld1q_u8_x2(iq6[ib].qh);
+        qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 4), m6), m32);  //  0...3
+        qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 4), m6), m32);  // 16..19
+        qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 2), m6), m32);  //  4...7
+        qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 2), m6), m32);  // 20..23
+        qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits.val[0], m6), m32);                 //  8..11
+        qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(hbits.val[1], m6), m32);                 // 24..27
+        qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits.val[0], 2), m6), m32);  // 12..15
+        qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits.val[1], 2), m6), m32);  // 28..31
+        return scales;
+    }
+    inline float32x4_t result(float32x4_t acc) const {
+        return acc;
+    }
+
+    const char * cx;
+    const size_t bx;
+    const block_q6_0_r4 * iq6;
+    const uint8x16_t m4 = vdupq_n_u8(0x0f);
+    const uint8x16_t m6 = vdupq_n_u8(0x30);
+    const int8x16_t m32 = vdupq_n_s8(-32);
+};
+
+inline void qx_0_q8_0_dot(const int8x16_t * qx, const int8_t * qy, int32x4_t& sumi1, int32x4_t& sumi2) {
+    auto y = vld1q_s8_x2(qy);
+    sumi1 = sumi2 = vdupq_n_s32(0);
+    sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0);
+    sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0);
+    sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1);
+    sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1);
+    sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2);
+    sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2);
+    sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3);
+    sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3);
+    sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0);
+    sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0);
+    sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1);
+    sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1);
+    sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2);
+    sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2);
+    sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3);
+    sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3);
+}
+
+template <int nrc_y>
+void mul_mat_q8_0_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%8 == 0);
+    Q8<nrc_y, block_q8_0_x4> q8(info);
+    int nb = n / QK8_0;
+    float32x4_t acc[2*nrc_y] = {};
+    int8x16_t qx[16];
+    float d8[4*nrc_y];
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
+        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
+            }
+            for (int k = 0; k < 4; ++k) {
+                auto scales16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d);
+                auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
+                auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
+                for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j);
+                int32x4_t sumi1, sumi2;
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    qx_0_q8_0_dot(qx, q8.y[iy][ib4].qs+32*k, sumi1, sumi2);
+                    auto dy = vdupq_n_f32(d8[4*iy+k]);
+                    acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
+                    acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
+                }
+            }
+        }
+        for (int ib = 4*(nb/4); ib < nb; ++ib) {
+            auto scales16 = vld1q_f16((const float16_t *)iq8[ib].d);
+            auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
+            auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
+            for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[ib].qs + 16*j);
+            int32x4_t sumi1, sumi2;
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto qy = (const block_q8_0 *)q8.y[iy];
+                qx_0_q8_0_dot(qx, qy[ib].qs, sumi1, sumi2);
+                auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
+                acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
+                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix+0, iy, acc[2*iy+0]);
+            info.store(ix+4, iy, acc[2*iy+1]);
+            acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
+        }
+    }
+}
+
 }
 
 bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
 
@@ -2344,29 +2611,26 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mu
-//    case GGML_TYPE_Q4_0_R8:
-//        IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q4_0_r8_q8_2, kernels)
-//#ifdef HAVE_FANCY_SIMD
-//        func16 = mul_mat_q4_0_r8_q8_2<16>;
-//#endif
-//        break;
-//    case GGML_TYPE_Q5_0_R4:
-//        IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q5_0_r4_q8_2, kernels)
-//        break;
-//    case GGML_TYPE_Q6_0_R4:
-//        IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q6_0_r4_q8_2, kernels)
-//        break;
-//    case GGML_TYPE_Q8_0_R8:
-//        IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_0_r8_q8_2, kernels)
-//        break;
-//    case GGML_TYPE_IQ4_NL_R4:
-//        IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_nl_r4_q8_2, kernels)
-//        break;
+        case GGML_TYPE_Q4_0_R8:
+            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qx_r8_q8_0, Q4_0_R8_Dequantizer, kernels);
+            break;
+        case GGML_TYPE_Q5_0_R4:
+            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qx_r4_q8_0, Q5_0_R4_Dequantizer, kernels);
+            break;
+        case GGML_TYPE_Q6_0_R4:
+            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qx_r4_q8_0, Q6_0_R4_Dequantizer, kernels);
+            break;
+        case GGML_TYPE_Q8_0_R8:
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_0_r8_q8_0, kernels);
+            break;
+        case GGML_TYPE_IQ4_NL_R4:
+            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer, kernels);
+            break;
         default:
             return false;
     }
 
-    return ggml_type(typeB) == expected_typeB;
+    return true;
 }
 
 #endif
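The Q4_0_R8_Dequantizer above avoids a table lookup through a sign trick: Q4_0 stores digits q in [0,15] with value d*(q-8), and XOR-ing each byte with 0x88 flips bit 3 of both nibbles, turning q into q-8 in 4-bit two's complement. Shifting the low nibble to the top (or masking the high nibble in place) then yields 16*(q-8) as an int8, and the 1/16 norm factor in result() cancels the extra 16 after the dot product. A small self-check of that identity (illustrative code, not from the repository):

    #include <cassert>
    #include <cstdint>

    // Scalar model of veorq_u8(m88, bits) followed by vshlq_n_u8(bits, 4)
    // and vandq_u8(bits, 0xf0).
    static int8_t lo_nibble_times16(uint8_t byte) {
        uint8_t fixed = byte ^ 0x88;     // q -> q^8 == q-8 (mod 16) in each nibble
        return (int8_t)(fixed << 4);     // low nibble placed on top: 16*(q-8)
    }
    static int8_t hi_nibble_times16(uint8_t byte) {
        uint8_t fixed = byte ^ 0x88;
        return (int8_t)(fixed & 0xf0);   // high nibble kept in place: 16*(q-8)
    }

    int main() {
        for (int qhi = 0; qhi < 16; ++qhi) {
            for (int qlo = 0; qlo < 16; ++qlo) {
                uint8_t byte = (uint8_t)((qhi << 4) | qlo);
                assert(lo_nibble_times16(byte) == 16*(qlo - 8));
                assert(hi_nibble_times16(byte) == 16*(qhi - 8));
            }
        }
        return 0;
    }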
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 163ac526..6aacfb78 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -938,59 +938,6 @@ template <int nrc_y> struct Q8_16 {
     const int8_t * y[nrc_y];
 };
 
-IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16x2_t& y) {
-    auto sumi = vdupq_n_s32(0);
-    sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
-    sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
-    sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
-    sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
-    sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
-    sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
-    sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
-    sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
-    return sumi;
-}
-
-IQK_ALWAYS_INLINE int32x4x2_t interleaved_dotq_b16(const int8x16_t * qx, const int8x16x2_t& y) {
-    int32x4x2_t sumi = { vdupq_n_s32(0), vdupq_n_s32(0) };
-    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[0], y.val[0], 0);
-    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[1], y.val[1], 0);
-    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[2], y.val[0], 1);
-    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[3], y.val[1], 1);
-    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[4], y.val[0], 2);
-    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[5], y.val[1], 2);
-    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[6], y.val[0], 3);
-    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[7], y.val[1], 3);
-    return sumi;
-}
-
-IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16_t& y) {
-    auto sumi = vdupq_n_s32(0);
-    sumi = vdotq_laneq_s32(sumi, qx[0], y, 0);
-    sumi = vdotq_laneq_s32(sumi, qx[1], y, 1);
-    sumi = vdotq_laneq_s32(sumi, qx[2], y, 2);
-    sumi = vdotq_laneq_s32(sumi, qx[3], y, 3);
-    return sumi;
-}
-
-IQK_ALWAYS_INLINE void prepare_iq4_nl_quants(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) {
-    qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4));   //  0...3 from the 4 rows
-    qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4));   // 16..19
-    qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4));   //  4...7
-    qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4));   // 20..23
-    qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4));  //  8..11
-    qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4));  // 24..27
-    qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4));  // 12..15
-    qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4));  // 28..31
-}
-
-IQK_ALWAYS_INLINE void prepare_iq4_nl_quants_r8(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x2_t& bits, int8x16_t * qx) {
-    qx[0] = vqtbl1q_s8(values, vandq_u8(  bits.val[0], m4));
-    qx[1] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4));
-    qx[2] = vqtbl1q_s8(values, vandq_u8(  bits.val[1], m4));
-    qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4));
-}
-
 template <int nrc_y>
 void mul_mat_iq4_xs_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%4 == 0);
@@ -2573,345 +2520,6 @@ void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& i
     }
 }
 
-void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    GGML_ASSERT(nrc_x%4 == 0);
-    Q8<1, block_q8_0_x4> q8(info);
-    auto m4 = vdupq_n_u8(0xf);
-    auto values = vld1q_s8(iq4k_values);
-    int nb = n / QK4_NL;
-    GGML_ASSERT(nb%4 == 0);
-    int8x16_t qx[8];
-    for (int ix = 0; ix < nrc_x; ix += 4) {
-        auto acc = vdupq_n_f32(0.f);
-        const block_iq4_nl_r4 * iq4 = (const block_iq4_nl_r4 *)((const char *)vx + ix*bx);
-        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
-            auto y1 = vld1q_s8_x4(q8.y[0][ib4].qs);
-            auto y2 = vld1q_s8_x4(q8.y[0][ib4].qs+64);
-            for (int k = 0; k < 4; ++k) {
-                auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
-                auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[0][ib4].d[k])));
-                auto sumi = vdupq_n_s32(0);
-                const auto yval = k < 2 ? y1.val + 2*k : y2.val + 2*(k-2);
-                auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
-                qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4));   //  0...3 from the 4 rows
-                qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4));   // 16..19
-                sumi = vdotq_laneq_s32(sumi, qx[0], yval[0], 0);
-                sumi = vdotq_laneq_s32(sumi, qx[1], yval[1], 0);
-                qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4));   //  4...7
-                qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4));   // 20..23
-                sumi = vdotq_laneq_s32(sumi, qx[2], yval[0], 1);
-                sumi = vdotq_laneq_s32(sumi, qx[3], yval[1], 1);
-                qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4));  //  8..11
-                qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4));  // 24..27
-                sumi = vdotq_laneq_s32(sumi, qx[4], yval[0], 2);
-                sumi = vdotq_laneq_s32(sumi, qx[5], yval[1], 2);
-                qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4));  // 12..15
-                qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4));  // 28..31
-                sumi = vdotq_laneq_s32(sumi, qx[6], yval[0], 3);
-                sumi = vdotq_laneq_s32(sumi, qx[7], yval[1], 3);
-                acc = vfmaq_f32(acc, d4d8, vcvtq_f32_s32(sumi));
-            }
-        }
-        info.store(ix, 0, acc);
-    }
-}
-
-template <typename Dequantizer, int nrc_y>
-void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    GGML_ASSERT(nrc_x%4 == 0);
-    Q8<nrc_y, block_q8_0_x4> q8(info);
-    Dequantizer deq(vx, bx);
-    int nb = n / QK4_NL;
-    int8x16_t qx[8];
-    float d8[4*nrc_y];
-    float32x4_t acc[nrc_y] = {};
-    for (int ix = 0; ix < nrc_x; ix += 4) {
-        deq.new_row(ix);
-        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
-            }
-            for (int k = 0; k < 4; ++k) {
-                auto scales = deq.prepare(4*ib4+k, qx);
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
-                    auto sumi = interleaved_dotq(qx, y);
-                    auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k]));
-                    acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
-                }
-            }
-        }
-        for (int ib = 4*(nb/4); ib < nb; ++ib) {
-            auto scales = deq.prepare(ib, qx);
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto qy = (const block_q8_0 *)q8.y[iy];
-                auto y = vld1q_s8_x2(qy[ib].qs);
-                auto sumi = interleaved_dotq(qx, y);
-                auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d)));
-                acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
-            }
-        }
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, deq.result(acc[iy]));
-            acc[iy] = vdupq_n_f32(0.f);
-        }
-    }
-}
-
-template <typename Dequantizer, int nrc_y>
-void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    GGML_ASSERT(nrc_x%8 == 0);
-    Q8<nrc_y, block_q8_0_x4> q8(info);
-    Dequantizer deq(vx, bx);
-    int nb = n / QK4_NL;
-    int8x16_t qx[16];
-    float d8[4*nrc_y];
-    float32x4_t acc[2*nrc_y] = {};
-    for (int ix = 0; ix < nrc_x; ix += 8) {
-        deq.new_row(ix);
-        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
-            }
-            for (int k = 0; k < 4; ++k) {
-                auto scales = deq.prepare(ib4, k, qx);
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
-                    auto sumi1 = interleaved_dotq(qx+0, y);
-                    auto sumi2 = interleaved_dotq(qx+8, y);
-                    auto dy = vdupq_n_f32(d8[4*iy+k]);
-                    acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales.val[0], dy), vcvtq_f32_s32(sumi1));
-                    acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales.val[1], dy), vcvtq_f32_s32(sumi2));
-                }
-            }
-        }
-        for (int ib = 4*(nb/4); ib < nb; ++ib) {
-            auto scales = deq.prepare(ib, 0, qx);
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto qy = (const block_q8_0 *)q8.y[iy];
-                auto y = vld1q_s8_x2(qy[ib].qs);
-                auto sumi1 = interleaved_dotq(qx+0, y);
-                auto sumi2 = interleaved_dotq(qx+8, y);
-                auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
-                acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales.val[0], dy), vcvtq_f32_s32(sumi1));
-                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales.val[1], dy), vcvtq_f32_s32(sumi2));
-            }
-        }
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix+0, iy, deq.result(acc[2*iy+0]));
-            info.store(ix+4, iy, deq.result(acc[2*iy+1]));
-            acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
-        }
-    }
-}
-
-struct IQ4_NL_R4_Dequantizer {
-    IQ4_NL_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx), values(vld1q_s8(iq4k_values)) {}
-    inline void new_row(int ix) { iq4 = (const block_iq4_nl_r4 *)(cx + ix*bx); }
-    inline float32x4_t prepare(int ib, int8x16_t * qx) const {
-        auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ib].d));
-        auto bits = vld1q_u8_x4(iq4[ib].qs);
-        prepare_iq4_nl_quants(values, m4, bits, qx);
-        return scales;
-    }
-    inline float32x4_t result(float32x4_t acc) const {
-        return acc;
-    }
-
-    const char * cx;
-    const size_t bx;
-    const block_iq4_nl_r4 * iq4;
-    const uint8x16_t m4 = vdupq_n_u8(0x0f);
-    const int8x16_t values;
-};
-
-struct Q4_0_R4_Dequantizer {
-    Q4_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
-    inline void new_row(int ix) { iq4 = (const block_iq4_nl_r4 *)(cx + ix*bx); }
-    inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const {
-        auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
-        auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
-        for (int j = 0; j < 4; ++j) bits.val[j] = veorq_u8(m88, bits.val[j]);
-        qx[0] = vshlq_n_u8(bits.val[0], 4);  //  0...3 from the 4 rows
-        qx[1] = vshlq_n_u8(bits.val[1], 4);  // 16..19
-        qx[2] = vshlq_n_u8(bits.val[2], 4);  //  4...7
-        qx[3] = vshlq_n_u8(bits.val[3], 4);  // 20..23
-        qx[4] = vandq_u8(bits.val[0], m4);   //  8..11
-        qx[5] = vandq_u8(bits.val[1], m4);   // 24..27
-        qx[6] = vandq_u8(bits.val[2], m4);   // 12..15
-        qx[7] = vandq_u8(bits.val[3], m4);   // 28..31
-        return scales;
-    }
-    inline float32x4_t result(float32x4_t acc) const {
-        return vmulq_f32(norm, acc);
-    }
-
-    const char * cx;
-    const size_t bx;
-    const block_iq4_nl_r4 * iq4;
-    const uint8x16_t m4 = vdupq_n_u8(0xf0);
-    const uint8x16_t m88 = vdupq_n_u8(0x88);
-    const float32x4_t norm = vdupq_n_f32(1.f/16);
-};
-
-struct Q4_0_R8_Dequantizer {
-    Q4_0_R8_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
-    inline void new_row(int ix) { iq4 = (const block_iq4_nl_r8 *)(cx + ix*bx); }
-    inline float32x4x2_t prepare(int ib4, int k, int8x16_t * qx) const {
-        auto scales16 = vld1q_f16((const float16_t *)iq4[4*ib4+k].d);
-        float32x4x2_t scales = { vcvt_f32_f16(vget_low_f16(scales16)), vcvt_f32_f16(vget_high_f16(scales16)) };
-        for (int j = 0; j < 4; ++j) {
-            auto bits = vld1q_u8_x2(iq4[4*ib4+k].qs + 32*j);
-            bits.val[0] = veorq_u8(m88, bits.val[0]);
-            bits.val[1] = veorq_u8(m88, bits.val[1]);
-            qx[2*j+0] = vshlq_n_u8(bits.val[0], 4);
-            qx[2*j+1] = vandq_u8(bits.val[0], m4);
-            qx[2*j+8] = vshlq_n_u8(bits.val[1], 4);
-            qx[2*j+9] = vandq_u8(bits.val[1], m4);
-        }
-        return scales;
-    }
-    inline float32x4_t result(float32x4_t acc) const {
-        return vmulq_f32(norm, acc);
-    }
-
-    const char * cx;
-    const size_t bx;
-    const block_iq4_nl_r8 * iq4;
-    const uint8x16_t m4 = vdupq_n_u8(0xf0);
-    const uint8x16_t m88 = vdupq_n_u8(0x88);
-    const float32x4_t norm = vdupq_n_f32(1.f/16);
-};
-
-struct Q5_0_R4_Dequantizer {
-    Q5_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
-    inline void new_row(int ix) { iq5 = (const block_q5_0_r4 *)(cx + ix*bx); }
-    inline float32x4_t prepare(int ib, int8x16_t * qx) const {
-        auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ib].d));
-        auto lbits = vld1q_u8_x4(iq5[ib].qs);
-        auto hbits = vld1q_u8(iq5[ib].qh);
-        qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16);  //  0...3
-        qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16);  // 16..19
-        qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16);  //  4...7
-        qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits, 1), m5), m16);  // 20..23
-        qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits, m5), m16);                 //  8..11
-        qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(vshrq_n_u8(hbits, 1), m5), m16);  // 24..27
-        qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits, 2), m5), m16);  // 12..15
-        qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits, 3), m5), m16);  // 28..31
-        return scales;
-    }
-    inline float32x4_t result(float32x4_t acc) const {
-        return acc;
-    }
-
-    const char * cx;
-    const size_t bx;
-    const block_q5_0_r4 * iq5;
-    const uint8x16_t m4 = vdupq_n_u8(0x0f);
-    const uint8x16_t m5 = vdupq_n_u8(0x10);
-    const int8x16_t m16 = vdupq_n_s8(-16);
-};
-
-struct Q6_0_R4_Dequantizer {
-    Q6_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
-    inline void new_row(int ix) { iq6 = (const block_q6_0_r4 *)(cx + ix*bx); }
-    inline float32x4_t prepare(int ib, int8x16_t * qx) const {
-        auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[ib].d));
-        auto lbits = vld1q_u8_x4(iq6[ib].qs);
-        auto hbits = vld1q_u8_x2(iq6[ib].qh);
-        qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 4), m6), m32);  //  0...3
-        qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 4), m6), m32);  // 16..19
-        qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 2), m6), m32);  //  4...7
-        qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 2), m6), m32);  // 20..23
-        qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits.val[0], m6), m32);                 //  8..11
-        qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(hbits.val[1], m6), m32);                 // 24..27
-        qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits.val[0], 2), m6), m32);  // 12..15
-        qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits.val[1], 2), m6), m32);  // 28..31
-        return scales;
-    }
-    inline float32x4_t result(float32x4_t acc) const {
-        return acc;
-    }
-
-    const char * cx;
-    const size_t bx;
-    const block_q6_0_r4 * iq6;
-    const uint8x16_t m4 = vdupq_n_u8(0x0f);
-    const uint8x16_t m6 = vdupq_n_u8(0x30);
-    const int8x16_t m32 = vdupq_n_s8(-32);
-};
-
-inline void qx_0_q8_0_dot(const int8x16_t * qx, const int8_t * qy, int32x4_t& sumi1, int32x4_t& sumi2) {
-    auto y = vld1q_s8_x2(qy);
-    sumi1 = sumi2 = vdupq_n_s32(0);
-    sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0);
-    sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0);
-    sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1);
-    sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1);
-    sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2);
-    sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2);
-    sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3);
-    sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3);
-    sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0);
-    sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0);
-    sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1);
-    sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1);
-    sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2);
-    sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2);
-    sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3);
-    sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3);
-}
-
-template <int nrc_y>
-void mul_mat_q8_0_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    GGML_ASSERT(nrc_x%8 == 0);
-    Q8<nrc_y, block_q8_0_x4> q8(info);
-    int nb = n / QK8_0;
-    float32x4_t acc[2*nrc_y] = {};
-    int8x16_t qx[16];
-    float d8[4*nrc_y];
-    for (int ix = 0; ix < nrc_x; ix += 8) {
-        const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
-        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
-            }
-            for (int k = 0; k < 4; ++k) {
-                auto scales16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d);
-                auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
-                auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
-                for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j);
-                int32x4_t sumi1, sumi2;
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    qx_0_q8_0_dot(qx, q8.y[iy][ib4].qs+32*k, sumi1, sumi2);
-                    auto dy = vdupq_n_f32(d8[4*iy+k]);
-                    acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
-                    acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
-                }
-            }
-        }
-        for (int ib = 4*(nb/4); ib < nb; ++ib) {
-            auto scales16 = vld1q_f16((const float16_t *)iq8[ib].d);
-            auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
-            auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
-            for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[ib].qs + 16*j);
-            int32x4_t sumi1, sumi2;
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto qy = (const block_q8_0 *)q8.y[iy];
-                qx_0_q8_0_dot(qx, qy[ib].qs, sumi1, sumi2);
-                auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
-                acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
-                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
-            }
-        }
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix+0, iy, acc[2*iy+0]);
-            info.store(ix+4, iy, acc[2*iy+1]);
-            acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
-        }
-    }
-}
-
 #define SET_MUL_MAT_FUNCTIONS_T(m, func, Dequantizer) \
     m.funcs[0] = func<Dequantizer, 1>;\
     m.funcs[1] = func<Dequantizer, 2>;\
@@ -2975,11 +2583,12 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
         case GGML_TYPE_Q6_0:
         case GGML_TYPE_Q8_0:
        case GGML_TYPE_IQ4_NL:
-            return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, m.funcs, m.func16);
+        case GGML_TYPE_Q4_0_R8:
+        case GGML_TYPE_Q5_0_R4:
+        case GGML_TYPE_Q6_0_R4:
+        case GGML_TYPE_Q8_0_R8:
         case GGML_TYPE_IQ4_NL_R4:
-            SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer);
-            expected_Btype = GGML_TYPE_Q8_0_X4;
-            break;
+            return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, m.funcs, m.func16);
         case GGML_TYPE_IQ4_XS_R8:
             SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_xs_r8_q8_k);
             expected_Btype = GGML_TYPE_Q8_K32;
@@ -3074,22 +2683,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
             SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq5_ks_r4_q8_k_neon);
             expected_Btype = GGML_TYPE_Q8_K;
             break;
-        case GGML_TYPE_Q4_0_R8:
-            SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r8_q8_0, Q4_0_R8_Dequantizer);
-            expected_Btype = GGML_TYPE_Q8_0_X4;
-            break;
-        case GGML_TYPE_Q5_0_R4:
-            SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q5_0_R4_Dequantizer);
-            expected_Btype = GGML_TYPE_Q8_0_X4;
-            break;
-        case GGML_TYPE_Q6_0_R4:
-            SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q6_0_R4_Dequantizer);
-            expected_Btype = GGML_TYPE_Q8_0_X4;
-            break;
-        case GGML_TYPE_Q8_0_R8:
-            SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_0_r8_q8_0);
-            expected_Btype = GGML_TYPE_Q8_0_X4;
-            break;
         default:
             return false;
     }
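For context on how these templates get registered: IQK_SET_MUL_MAT_FUNCTIONS_T and IQK_SET_MUL_MAT_FUNCTIONS live in iqk_common.h and are not shown in this patch; from their usage here they fill the kernel table with one instantiation per nrc_y, the number of right-hand-side columns processed per call. A toy model of that pattern, with assumed names (ToyDequantizer, toy_kernel, make_kernels are illustrative only):

    #include <array>
    #include <cstdio>
    #include <utility>

    struct ToyDequantizer {};

    // Stand-in for a kernel template such as mul_mat_qx_r4_q8_0<Dequantizer, nrc_y>.
    template <typename Dequantizer, int nrc_y>
    void toy_kernel(int n) { std::printf("nrc_y = %d, n = %d\n", nrc_y, n); }

    using kernel_t = void (*)(int);

    // Fill a table with instantiations for nrc_y = 1..N, the job the
    // registration macros perform for the real kernels.
    template <typename Dequantizer, int... I>
    std::array<kernel_t, sizeof...(I)> make_kernels(std::integer_sequence<int, I...>) {
        return { &toy_kernel<Dequantizer, I + 1>... };
    }

    int main() {
        auto kernels = make_kernels<ToyDequantizer>(std::make_integer_sequence<int, 8>{});
        kernels[2](512);   // dispatches the nrc_y == 3 instantiation
        return 0;
    }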