diff --git a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp index c009a230..3df94088 100644 --- a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp +++ b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp @@ -3719,7 +3719,7 @@ struct DequantizerIQ2KS final : public BaseDequantizer }; struct DequantizerIQ2KL final : public BaseDequantizer { - DequantizerIQ2KL(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), shifts(load_shift()) { load_values(values); } + DequantizerIQ2KL(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), shuff(load_shuffle()), shifts(load_shift()) { load_values(values); } constexpr static int num_blocks() { return 8; } constexpr static bool should_scale_quants() { return false; } @@ -3733,37 +3733,53 @@ struct DequantizerIQ2KL final : public BaseDequantizer int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))}; return scales; } + inline void process_pair(uint8x16_t x, uint8x16_t * val) const { + uint8x16x2_t aux{ vqtbl2q_s8(values[0], x), vqtbl2q_s8(values[1], x) }; + val[0] = vqtbl2q_u8(aux, shuff.val[0]); + val[1] = vqtbl2q_u8(aux, shuff.val[1]); + } inline void prepare(int i, int j) { hbits = j == 0 ? vld1q_u8(x[i].qh) : vshrq_n_u8(hbits, 4); auto lbits = vld1q_u8_x2(x[i].qs+32*j); - bits.b1.val[0] = vandq_u8(lbits.val[0], vdupq_n_u8(0xf)); - bits.b1.val[2] = vshrq_n_u8(lbits.val[0], 4); - bits.b2.val[0] = vandq_u8(lbits.val[1], vdupq_n_u8(0xf)); - bits.b2.val[2] = vshrq_n_u8(lbits.val[1], 4); - bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(m10, vshlq_n_u8(hbits, 4))); - bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(m10, vshlq_n_u8(hbits, 3))); - bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(m10, vshlq_n_u8(hbits, 2))); - bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(m10, vshlq_n_u8(hbits, 1))); - auto t1 = vqtbl2q_s8(values[0], bits.b1.val[0]); - auto t2 = vqtbl2q_s8(values[1], bits.b1.val[0]); - bits.b1.val[0] = vzip1q_s8(t1, t2); - bits.b1.val[1] = vzip2q_s8(t1, t2); - t1 = vqtbl2q_s8(values[0], bits.b1.val[2]); - t2 = vqtbl2q_s8(values[1], bits.b1.val[2]); - bits.b1.val[2] = vzip1q_s8(t1, t2); - bits.b1.val[3] = vzip2q_s8(t1, t2); + uint8x16x4_t aux; + aux.val[0] = vorrq_u8(vandq_u8(m10, vshlq_n_u8(hbits, 4)), vandq_u8(lbits.val[0], vdupq_n_u8(0xf))); + aux.val[1] = vorrq_u8(vandq_u8(m10, vshlq_n_u8(hbits, 3)), vshrq_n_u8(lbits.val[0], 4)); + aux.val[2] = vorrq_u8(vandq_u8(m10, vshlq_n_u8(hbits, 2)), vandq_u8(lbits.val[1], vdupq_n_u8(0xf))); + aux.val[3] = vorrq_u8(vandq_u8(m10, vshlq_n_u8(hbits, 1)), vshrq_n_u8(lbits.val[1], 4)); - t1 = vqtbl2q_s8(values[0], bits.b2.val[0]); - t2 = vqtbl2q_s8(values[1], bits.b2.val[0]); - bits.b2.val[0] = vzip1q_s8(t1, t2); - bits.b2.val[1] = vzip2q_s8(t1, t2); - t1 = vqtbl2q_s8(values[0], bits.b2.val[2]); - t2 = vqtbl2q_s8(values[1], bits.b2.val[2]); - bits.b2.val[2] = vzip1q_s8(t1, t2); - bits.b2.val[3] = vzip2q_s8(t1, t2); + process_pair(aux.val[0], bits.b1.val+0); + process_pair(aux.val[1], bits.b1.val+2); + process_pair(aux.val[2], bits.b2.val+0); + process_pair(aux.val[3], bits.b2.val+2); - hbits = vshrq_n_u8(hbits, 4); + // The compiler crashes the moment I try to use vzip2q_u8!!! + //bits.b1.val[0] = vandq_u8(lbits.val[0], vdupq_n_u8(0xf)); + //bits.b1.val[2] = vshrq_n_u8(lbits.val[0], 4); + //bits.b2.val[0] = vandq_u8(lbits.val[1], vdupq_n_u8(0xf)); + //bits.b2.val[2] = vshrq_n_u8(lbits.val[1], 4); + //bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(m10, vshlq_n_u8(hbits, 4))); + //bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(m10, vshlq_n_u8(hbits, 3))); + //bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(m10, vshlq_n_u8(hbits, 2))); + //bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(m10, vshlq_n_u8(hbits, 1))); + + //auto t1 = vqtbl2q_s8(values[0], bits.b1.val[0]); + //auto t2 = vqtbl2q_s8(values[1], bits.b1.val[0]); + //bits.b1.val[0] = vzip1q_s8(t1, t2); + ////bits.b1.val[1] = vzip2q_u8(t1, t2); + //t1 = vqtbl2q_s8(values[0], bits.b1.val[2]); + //t2 = vqtbl2q_s8(values[1], bits.b1.val[2]); + //bits.b1.val[2] = vzip1q_s8(t1, t2); + ////bits.b1.val[3] = vzip2q_s8(t1, t2); + + //t1 = vqtbl2q_s8(values[0], bits.b2.val[0]); + //t2 = vqtbl2q_s8(values[1], bits.b2.val[0]); + //bits.b2.val[0] = vzip1q_s8(t1, t2); + ////bits.b2.val[1] = vzip2q_s8(t1, t2); + //t1 = vqtbl2q_s8(values[0], bits.b2.val[2]); + //t2 = vqtbl2q_s8(values[1], bits.b2.val[2]); + //bits.b2.val[2] = vzip1q_s8(t1, t2); + ////bits.b2.val[3] = vzip2q_s8(t1, t2); } static inline int16x8_t load_shift() { static const int16_t k_shift[8] = {4, 2, 0, -2, -4, -6, -8, -10}; @@ -3777,10 +3793,17 @@ struct DequantizerIQ2KL final : public BaseDequantizer values[0] = vld1q_s8_x2(k_values+ 0); values[1] = vld1q_s8_x2(k_values+32); } + static uint8x16x2_t load_shuffle() { + static const uint8_t k_shuff[32] = { + 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, + 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31 + }; + return vld1q_u8_x2(k_shuff); + } - Q2bits bits; - //struct { uint8x16x4_t b1, b2; } bits; + struct { uint8x16x4_t b1, b2; } bits; uint8x16_t hbits; + const uint8x16x2_t shuff; const int16x8_t shifts; const uint8x16_t m10 = vdupq_n_u8(0x10); int8x16x2_t values[2];