diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 68d545d0..a744e2f5 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -5011,6 +5011,63 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
     const int16x8_t m127 = vdupq_n_s16(-127);
 };
 
+struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
+
+    DequantizerIQ4KSS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {}
+
+    constexpr static int num_blocks() { return 8; }
+    constexpr static bool should_scale_quants() { return false; }
+
+    template <typename Q8>
+    inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+        (void)q8;
+        (void)acc;
+        auto q4bits_1 = vld1q_u16_x4((const uint16_t *)x[i].qs);
+        q4bits_2 = vld1q_u16_x4((const uint16_t *)x[i].qs + 32);
+        for (int k = 0; k < 4; ++k) {
+            aux[k+0] = vaddvq_s16(vshlq_s16(vandq_u16(q4bits_1.val[k], m1), shift));
+            aux[k+4] = vaddvq_s16(vshlq_s16(vandq_u16(q4bits_2.val[k], m1), shift));
+            q4bits_1.val[k] = vandq_u16(q4bits_1.val[k], bmask);
+            q4bits_1.val[k] = veorq_u16(q4bits_1.val[k], vshrq_n_u16(q4bits_1.val[k], 1));
+            q4bits_2.val[k] = vandq_u16(q4bits_2.val[k], bmask);
+            q4bits_2.val[k] = veorq_u16(q4bits_2.val[k], vshrq_n_u16(q4bits_2.val[k], 1));
+        }
+        make_quants(q4bits_1, bits, aux);
+        auto scales16 = vld1q_s16(aux);
+        scales16 = vaddq_s16(vandq_s16(scales16, mask), m127);
+        int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
+        return scales;
+    }
+    inline void make_quants(uint16x8x4_t& q4bits, Q4bits& bits, const int16_t * aux) const {
+        bits.b1.val[0] = vqtbl1q_s8(values.val[aux[0] & 1], vandq_u8(q4bits.val[0], bits.m4b));
+        bits.b1.val[1] = vqtbl1q_s8(values.val[aux[0] & 1], vshrq_n_u8(q4bits.val[0], 4));
+        bits.b1.val[2] = vqtbl1q_s8(values.val[aux[1] & 1], vandq_u8(q4bits.val[1], bits.m4b));
+        bits.b1.val[3] = vqtbl1q_s8(values.val[aux[1] & 1], vshrq_n_u8(q4bits.val[1], 4));
+        bits.b2.val[0] = vqtbl1q_s8(values.val[aux[2] & 1], vandq_u8(q4bits.val[2], bits.m4b));
+        bits.b2.val[1] = vqtbl1q_s8(values.val[aux[2] & 1], vshrq_n_u8(q4bits.val[2], 4));
+        bits.b2.val[2] = vqtbl1q_s8(values.val[aux[3] & 1], vandq_u8(q4bits.val[3], bits.m4b));
+        bits.b2.val[3] = vqtbl1q_s8(values.val[aux[3] & 1], vshrq_n_u8(q4bits.val[3], 4));
+    }
+    inline void prepare([[maybe_unused]] int i, int j) {
+        if (j == 0) return;
+        make_quants(q4bits_2, bits, aux+4);
+    }
+    static int16x8_t load_shift() {
+        static const int16_t k_shift[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+        return vld1q_s16(k_shift);
+    }
+
+    Q4bits bits;
+    const int8x16x2_t values;
+    const uint16x8_t mask = vdupq_n_s16(254);
+    const uint16x8_t bmask = vdupq_n_u16(0xfffe);
+    const uint16x8_t m1 = vdupq_n_u16(1);
+    const int16x8_t shift = load_shift();
+    const int16x8_t m127 = vdupq_n_s16(-127);
+    uint16x8x4_t q4bits_2;
+    int16_t aux[8];
+};
+
 struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> {
 
     DequantizerIQ2KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
@@ -6782,6 +6839,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
         case GGML_TYPE_IQ4_KS:
             MulMat::set_functions<DequantizerIQ4KS>(m);
            break;
+        case GGML_TYPE_IQ4_KSS:
+            MulMat::set_functions<DequantizerIQ4KSS>(m);
+            break;
        case GGML_TYPE_IQ2_KS:
            MulMat::set_functions<DequantizerIQ2KS>(m);
            break;