From 9dd452d4cb8fe9f0925d1fd9384f8ed31df5c7fd Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 16 May 2025 10:15:17 +0300 Subject: [PATCH] Fix iq5_ks on NEON --- ggml/src/iqk/iqk_mul_mat.cpp | 39 +++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index d4705c3e..7d7ae798 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -11207,7 +11207,8 @@ struct DequantizerIQ4KS final : public BaseDequantizer { }; struct DequantizerIQ5KS final : public BaseDequantizer { - DequantizerIQ5KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq5nl_values)) {} + DequantizerIQ5KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), + values(vld1q_s8_x4(iq5nl_values)) {} constexpr static int num_blocks() { return 8; } constexpr static bool should_scale_quants() { return false; } @@ -11216,7 +11217,11 @@ struct DequantizerIQ5KS final : public BaseDequantizer { inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { (void)q8; (void)acc; - auto scales16 = vaddq_s16(vreinterpretq_s16_u16(vandq_u16(vmovl_u8(vld1_u8(x[i].scales)), mask)), m127); + auto sas8 = vld1_u8(x[i].scales); + auto scales16 = vaddq_s16(vreinterpretq_s16_u16(vandq_u16(vmovl_u8(sas8), mask)), m127); + hbits = vld1q_u8_x2(x[i].qh); + sas = vcombine_u8(sas8, sas8); + sas = vshlq_n_u8(vandq_u8(sas, vdupq_n_u8(1)), 5); int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))}; return scales; } @@ -11226,27 +11231,29 @@ struct DequantizerIQ5KS final : public BaseDequantizer { if (j == 1) { for (int k = 0; k < 2; ++k) hbits.val[k] = vshrq_n_u8(hbits.val[k], 4); } - bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm)); - bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm)); - bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 3), hm)); - bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 3), hm)); - bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm)); - bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm)); - bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hm)); - bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hm)); - for (int k = 0; k < 4; ++k) { - bits.b1.val[k] = vqtbl2q_s8(values, bits.b1.val[k]); - bits.b2.val[k] = vqtbl2q_s8(values, bits.b2.val[k]); - } + auto shift = vdupq_n_u8((x[i].scales[4*j+0] & 1) << 5); + bits.b1.val[0] = vaddq_u8(shift, vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm))); + bits.b1.val[1] = vaddq_u8(shift, vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm))); + shift = vdupq_n_u8((x[i].scales[4*j+1] & 1) << 5); + bits.b1.val[2] = vaddq_u8(shift, vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 3), hm))); + bits.b1.val[3] = vaddq_u8(shift, vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 3), hm))); + for (int k = 0; k < 4; ++k) bits.b1.val[k] = vqtbl4q_s8(values, bits.b1.val[k]); + shift = vdupq_n_u8((x[i].scales[4*j+2] & 1) << 5); + bits.b2.val[0] = vaddq_u8(shift, vorrq_u8(bits.b2.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm))); + bits.b2.val[1] = vaddq_u8(shift, vorrq_u8(bits.b2.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm))); + shift = vdupq_n_u8((x[i].scales[4*j+3] & 1) << 5); + bits.b2.val[2] = vaddq_u8(shift, vorrq_u8(bits.b2.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hm))); + bits.b2.val[3] = vaddq_u8(shift, vorrq_u8(bits.b2.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hm))); + for (int k = 0; k < 4; ++k) bits.b2.val[k] = vqtbl4q_s8(values, bits.b2.val[k]); } Q4bits bits; - const int8x16x2_t values; - const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06}); + const int8x16x4_t values; const uint8x16_t hm = vdupq_n_u8(0x10); const uint16x8_t mask = vdupq_n_u16(254); const int16x8_t m127 = vdupq_n_s16(-127); uint8x16x2_t hbits; + uint8x16_t sas; };