diff --git a/ggml/src/iqk/iqk_common.h b/ggml/src/iqk/iqk_common.h index 4701fa74..620bd7f9 100644 --- a/ggml/src/iqk/iqk_common.h +++ b/ggml/src/iqk/iqk_common.h @@ -526,6 +526,7 @@ struct Q4Bits { #endif #else +// ------------------------------------ __aarch64__ -------------------------------------------------- template struct Q8 { @@ -547,6 +548,214 @@ template struct Q8 { const block_q8 * y[nrc_y]; }; +template +struct BaseDequantizer { + BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {} + inline void new_row(int ix) { + if constexpr (has_row_scale) { + if constexpr (scale_is_f16) { + const ggml_half * dptr = (const ggml_half *)((const char *)vx + ix*bx); + d = GGML_FP16_TO_FP32(*dptr); + x = (const block_q *)(dptr + 1); + } else { + const float * dptr = (const float *)((const char *)vx + ix*bx); + d = *dptr; + x = (const block_q *)(dptr + 1); + } + } else { + x = (const block_q *)((const char *)vx + ix*bx); + } + } + const void * vx; + const block_q * x; + const size_t bx; + const int nrc; + float d; +}; + +struct Q4bits { + const uint8x16_t m4b = vdupq_n_u8(0xf); + uint8x16x4_t b1, b2; + inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const { + b.val[0] = vandq_u8(val[0], m4b); + b.val[2] = vshrq_n_u8(val[0], 4); + b.val[1] = vandq_u8(val[1], m4b); + b.val[3] = vshrq_n_u8(val[1], 4); + } + inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const { + b.val[0] = vandq_u8(val[0], m4b); + b.val[1] = vshrq_n_u8(val[0], 4); + b.val[2] = vandq_u8(val[1], m4b); + b.val[3] = vshrq_n_u8(val[1], 4); + } + inline void prepare(const uint8_t * qs) { + auto q4bits = vld1q_u8_x2(qs); + prepare4(b1, q4bits.val); + q4bits = vld1q_u8_x2(qs+32); + prepare4(b2, q4bits.val); + } + inline void prepare_v2(const uint8_t * qs) { + auto q4bits = vld1q_u8_x4(qs); + prepare4(b1, q4bits.val+0); + prepare4(b2, q4bits.val+2); + } + inline void prepare64(const uint8_t * qs) { + auto q4bits = vld1q_u8_x4(qs); + 
b1.val[0] = vandq_u8(q4bits.val[0], m4b); + b1.val[1] = vandq_u8(q4bits.val[1], m4b); + b1.val[2] = vandq_u8(q4bits.val[2], m4b); + b1.val[3] = vandq_u8(q4bits.val[3], m4b); + b2.val[0] = vshrq_n_u8(q4bits.val[0], 4); + b2.val[1] = vshrq_n_u8(q4bits.val[1], 4); + b2.val[2] = vshrq_n_u8(q4bits.val[2], 4); + b2.val[3] = vshrq_n_u8(q4bits.val[3], 4); + } + inline void prepare16(const uint8_t * qs) { + auto q4bits = vld1q_u8_x2(qs); + prepare4_16(b1, q4bits.val); + q4bits = vld1q_u8_x2(qs+32); + prepare4_16(b2, q4bits.val); + } + inline void prepare16_v2(const uint8_t * qs) { + auto q4bits = vld1q_u8_x4(qs); + prepare4_16(b1, q4bits.val+0); + prepare4_16(b2, q4bits.val+2); + } +}; + +struct Q2bits { + const uint8x16_t m4b = vdupq_n_u8(0x03); + uint8x16x4_t b1, b2; + inline void prepare(const uint8_t * qs) { + auto q2bits = vld1q_u8_x2(qs); + b1.val[0] = vandq_u8(q2bits.val[0], m4b); + b1.val[1] = vandq_u8(q2bits.val[1], m4b); + + q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2); + q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2); + b1.val[2] = vandq_u8(q2bits.val[0], m4b); + b1.val[3] = vandq_u8(q2bits.val[1], m4b); + + q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2); + q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2); + b2.val[0] = vandq_u8(q2bits.val[0], m4b); + b2.val[1] = vandq_u8(q2bits.val[1], m4b); + + q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2); + q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2); + b2.val[2] = vandq_u8(q2bits.val[0], m4b); + b2.val[3] = vandq_u8(q2bits.val[1], m4b); + } +}; + +template +static inline void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8, + const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) { + auto mzero = vdupq_n_s32(0); + auto q8b_1 = q8.load_quants(iy, i, 4*j+0); + auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]), + vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1]); // block 1 + auto q8b_2 = q8.load_quants(iy, i, 4*j+1); + auto p2 = 
ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]), + vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1]); // block 2 + auto p12 = vpaddq_s32(p1, p2); + + auto q8b_3 = q8.load_quants(iy, i, 4*j+2); + auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]), + vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1]); // block 3 + auto q8b_4 = q8.load_quants(iy, i, 4*j+3); + auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]), + vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1]); // block 4 + auto p34 = vpaddq_s32(p3, p4); + + auto pall = vpaddq_s32(p12, p34); + sumi = vmlaq_s32(sumi, scales.val[j], pall); +} + +template +static inline void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8, + const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) { + + auto mzero = vdupq_n_s32(0); + auto q8b_1 = q8.load_quants(iy, i, 4*j+0); + auto p1 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]), + ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1])); // blocks 0, 0, 1, 1, + auto q8b_2 = q8.load_quants(iy, i, 4*j+1); + auto p2 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]), + ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1])); // blocks 2, 2, 3, 3, + auto p12 = vpaddq_s32(p1, p2); // blocks 0, 1, 2, 3 + sumi = vmlaq_s32(sumi, scales.val[2*j+0], p12); + + auto q8b_3 = q8.load_quants(iy, i, 4*j+2); + auto p3 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]), + ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1])); // block 4, 4, 5, 5, + auto q8b_4 = q8.load_quants(iy, i, 4*j+3); + auto p4 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]), + ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1])); // block 6, 6, 7, 7, + auto p34 = 
vpaddq_s32(p3, p4); // blocks 4, 5, 6, 7 + sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34); +} + +template +static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8 q8(info); + + Dequantizer deq(vx, bx, nrc_y); + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + + float32x4_t acc[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); + + for (int i = 0; i < nb; ++i) { + + int32x4_t sumi[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0); + + if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) { + deq.process_scales(i, q8, acc); + deq.prepare(i, 0); + deq.compute(q8, i, 0, sumi); + deq.prepare(i, 1); + deq.compute(q8, i, 1, sumi); + } else { + if constexpr (Dequantizer::num_blocks() == 8) { + auto scales = deq.new_block(i, q8, acc); + deq.prepare(i, 0); + for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); + deq.prepare(i, 1); + for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); + } + else if constexpr (Dequantizer::num_blocks() == 16) { + auto scales = deq.new_block(i, q8, acc); + deq.prepare(i, 0); + for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); + deq.prepare(i, 1); + for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); + } + else { + GGML_ASSERT(false); + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(acc[iy])); + } + } +} + + + #endif #endif diff --git a/ggml/src/iqk/iqk_gemm_kquants.cpp b/ggml/src/iqk/iqk_gemm_kquants.cpp index b7ae38f4..c69fa7cf 100644 --- 
a/ggml/src/iqk/iqk_gemm_kquants.cpp +++ b/ggml/src/iqk/iqk_gemm_kquants.cpp @@ -1781,6 +1781,397 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array +inline void accum_mins_8(const int16x8_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8s = q8.load_bsums8(iy, i); + int32x4_t b1 = vmull_s16(vget_low_s16(mins), vget_low_s16(q8s)); + int32x4_t b2 = vmull_s16(vget_high_s16(mins), vget_high_s16(q8s)); + float32x4_t prod = vcvtq_f32_s32(vaddq_s32(b1, b2)); + acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i))); + } +} +template +inline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8s = q8.load_bsums(iy, i); + int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 (q8s.val[0])); + int32x4_t b2 = vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0])); + int32x4_t b3 = vmull_s16(vget_low_s16 (mins.val[1]), vget_low_s16 (q8s.val[1])); + int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1])); + float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4))); + acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i))); + } +} + +struct Scales8 { + uint32_t utmp[4]; + const uint8_t * sc8 = (const uint8_t *)utmp; + template + inline int32x4x2_t process_scales_mins(const Qx& x, const Q8& q8, int i, float32x4_t * acc) { + make_q4_scales(x.scales, utmp); + int16x8_t mins = vmovl_s8(vld1_s8((const int8_t *)sc8 + 8)); + accum_mins_8(mins, q8, acc, i, -GGML_FP16_TO_FP32(x.dmin)); + + uint8x8_t scales8 = vld1_u8(sc8); + uint16x8_t scales16 = vmovl_u8(scales8); + int32x4x2_t scales = {vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales16))), + vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales16)))}; + return scales; + } +}; + +struct DequantizerQ4K final : public BaseDequantizer { + DequantizerQ4K(const void * vx, 
size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + return s8.process_scales_mins(x[i], q8, i, acc); + } + inline void prepare(int i, int j) { + if (nrc == 1) bits.prepare_v2(x[i].qs+64*j); + else bits.prepare(x[i].qs+64*j); + } + + Q4bits bits; + Scales8 s8; + +}; + +struct HighBit5 { + const uint8x16_t mhb = vdupq_n_u8(0x10); + uint8x16x2_t bits; + inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) { + b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 4), mhb)); + b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 4), mhb)); + b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 3), mhb)); + b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 3), mhb)); + + b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb)); + b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb)); + b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb)); + b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb)); + + if (do_shift) { + bits.val[0] = vshrq_n_u8(bits.val[0], 4); + bits.val[1] = vshrq_n_u8(bits.val[1], 4); + } + } +}; + +struct HighBit3 { + const uint8x16_t mhb = vdupq_n_u8(0x04); + uint8x16x2_t bits; + inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) { + b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb)); + b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb)); + b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb)); + b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb)); + + b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(bits.val[0], mhb)); + b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(bits.val[1], 
mhb)); + b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshrq_n_u8(bits.val[0], 1), mhb)); + b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshrq_n_u8(bits.val[1], 1), mhb)); + + if (do_shift) { + bits.val[0] = vshrq_n_u8(bits.val[0], 4); + bits.val[1] = vshrq_n_u8(bits.val[1], 4); + } + } +}; + +struct DequantizerQ5K final : public BaseDequantizer { + DequantizerQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + h.bits = vld1q_u8_x2(x[i].qh); + return s8.process_scales_mins(x[i], q8, i, acc); + } + inline void prepare(int i, int j) { + if (nrc == 1) bits.prepare_v2(x[i].qs+64*j); + else bits.prepare(x[i].qs+64*j); + h.apply(bits.b1, bits.b2, j == 0); + } + + Q4bits bits; + HighBit5 h; + Scales8 s8; + + uint8x16x2_t hbits; + +}; + +inline int32x4x4_t make_wider(const int16x8x2_t& scales16) { + int32x4x4_t scales = { + vmovl_s16(vget_low_s16 (scales16.val[0])), + vmovl_s16(vget_high_s16(scales16.val[0])), + vmovl_s16(vget_low_s16 (scales16.val[1])), + vmovl_s16(vget_high_s16(scales16.val[1])), + }; + return scales; +} + +template +inline int32x4x4_t process_scales_mins_16(const int8x16_t& scales8, const Q8& q8, float32x4_t * acc, int i, float c) { + int16x8x2_t scales16; + scales16.val[0] = vmovl_s8(vget_low_s8(scales8)); + scales16.val[1] = vmovl_s8(vget_high_s8(scales8)); + accum_mins_16(scales16, q8, acc, i, c); + return make_wider(scales16); +} + +struct DequantizerQ6K final : public BaseDequantizer { + DequantizerQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = 
GGML_FP16_TO_FP32(x[i].d); + return process_scales_mins_16(vld1q_s8(x[i].scales), q8, acc, i, -32.f*d); + } + inline void prepare(int i, int j) { + + auto hbits = vld1q_u8_x2(x[i].qh + 32*j); + + bits.prepare64(x[i].ql+64*j); + bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), mhb)); + bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), mhb)); + bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), mhb)); + bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), mhb)); + + bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], mhb)); + bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], mhb)); + bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), mhb)); + bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), mhb)); + + } + + Q4bits bits; + + const uint8x16_t mhb = vdupq_n_u8(0x30); + +}; + +struct DequantizerQ3K final : public BaseDequantizer { + DequantizerQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + h.bits = vld1q_u8_x2(x[i].hmask); + mask = vdupq_n_u8(0x01); + const uint16_t * sc16 = (const uint16_t *)x[i].scales; + uint32_t aux0 = sc16[0] | (sc16[1] << 16); + uint32_t aux1 = sc16[2] | (sc16[3] << 16); + uint32_t aux2 = sc16[4] | (sc16[5] << 16); + aux32[0] = (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030); + aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030); + aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030); + aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030); + auto scales8 = vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32)); + if (nrc > 1) { + return 
process_scales_mins_16(scales8, q8, acc, i, -4.f*d); + } + int16x8x2_t scales16; + scales16.val[0] = vmovl_s8(vget_low_s8(scales8)); + scales16.val[1] = vmovl_s8(vget_high_s8(scales8)); + return make_wider(scales16); + } + + inline void prepare(int i, int j) { + bits.prepare(x[i].qs+32*j); + if (nrc > 1) { + h.apply(bits.b1, bits.b2, j == 0); + } else { + auto minus4 = vdupq_n_u8(0xfc); + auto zero = vdupq_n_u8(0); + bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); + bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); + mask = vshlq_n_u8(mask, 1); + bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); + bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); + mask = vshlq_n_u8(mask, 1); + bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); + bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); + mask = vshlq_n_u8(mask, 1); + bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); + bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); + mask = vshlq_n_u8(mask, 1); + } + } + + uint32_t aux32[4]; + + Q2bits bits; + + uint8x16_t mask; + HighBit3 h; + +}; + +struct DequantizerQ2K final : public BaseDequantizer { + DequantizerQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return true; } + + template + inline void process_scales(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + auto scales_and_mins = vld1q_u8(x[i].scales); + auto mins8 = vreinterpretq_s8_u8(vshrq_n_u8(scales_and_mins, 
4)); + int16x8x2_t scales16; + scales16.val[0] = vmovl_s8(vget_low_s8(mins8)); + scales16.val[1] = vmovl_s8(vget_high_s8(mins8)); + accum_mins_16(scales16, q8, acc, i, -GGML_FP16_TO_FP32(x[i].dmin)); + + scales8 = vandq_u8(scales_and_mins, vdupq_n_u8(0xf)); + } + + template + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + process_scales(i, q8, acc); + int16x8x2_t scales16; + scales16.val[0] = vmovl_s8(vget_low_s8(vreinterpretq_s8_u8(scales8))); + scales16.val[1] = vmovl_s8(vget_high_s8(vreinterpretq_s8_u8(scales8))); + return make_wider(scales16); + } + + template + inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) { + auto m1 = vdupq_n_u8(1); + auto shuffle = vdupq_n_u8(8*j); + bits.b1.val[0] = vmulq_u8(bits.b1.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b1.val[1] = vmulq_u8(bits.b1.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b1.val[2] = vmulq_u8(bits.b1.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b1.val[3] = vmulq_u8(bits.b1.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b2.val[0] = vmulq_u8(bits.b2.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b2.val[1] = vmulq_u8(bits.b2.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b2.val[2] = vmulq_u8(bits.b2.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b2.val[3] = vmulq_u8(bits.b2.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8b_1 = q8.load_quants(iy, i, 4*j+0); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]), + vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]); + + auto q8b_2 = q8.load_quants(iy, i, 4*j+1); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]), + 
vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]); + + auto q8b_3 = q8.load_quants(iy, i, 4*j+2); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]), + vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]); + + auto q8b_4 = q8.load_quants(iy, i, 4*j+3); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]), + vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]); + } + } + + inline void prepare(int i, int j) { + bits.prepare(x[i].qs+32*j); + } + + uint32_t aux32[4]; + + uint8x16_t scales8; + + Q2bits bits; + +}; + +} + +bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array& kernels, [[maybe_unused]] mul_mat_t& func16) { + + auto etypeA = ggml_type(typeA); + auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32 + : etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8 + : etypeA == GGML_TYPE_Q8_KV || etypeA == GGML_TYPE_Q8_KV_R8 ? 
GGML_TYPE_Q8_KV + : GGML_TYPE_Q8_K; + + if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) { + return false; + } + + func16 = nullptr; + + switch (typeA) { + case GGML_TYPE_Q2_K: + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ2K, kernels) + break; + case GGML_TYPE_Q3_K: + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ3K, kernels) + break; + case GGML_TYPE_Q4_K: + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ4K, kernels) + break; + case GGML_TYPE_Q5_K: + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ5K, kernels) + break; + case GGML_TYPE_Q6_K: + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ6K, kernels) + break; +// case GGML_TYPE_IQ4_XS: +// set_functions(kernels); +// break; +// case GGML_TYPE_Q2_K_R4: +// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q2_k_r4_q8_k, kernels) +// break; +// case GGML_TYPE_Q3_K_R4: +// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q3_k_r4_q8_k, kernels) +// break; +// case GGML_TYPE_Q4_K_R4: +// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q4_k_r4_q8_k, kernels) +// break; +// case GGML_TYPE_Q5_K_R4: +// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q5_k_r4_q8_k, kernels) +// break; +// case GGML_TYPE_Q6_K_R4: +// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q6_k_r4_q8_k, kernels) +// break; +// case GGML_TYPE_IQ4_XS_R8: +// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_xs_r8_q8_k_avx2, kernels) +// break; +// case GGML_TYPE_Q8_K_R8: +// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_k_r8_q8_k, kernels) +//#ifdef HAVE_FANCY_SIMD +// func16 = mul_mat_q8_k_r8_q8_k<16>; +//#endif +// break; +// case GGML_TYPE_Q8_KV: +// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_q8_KV, kernels) +//#ifdef HAVE_FANCY_SIMD +// func16 = mul_mat_q8_KV_q8_KV<16>; +//#endif +// break; +// case GGML_TYPE_Q8_KV_R8: +// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_r8_q8_KV, kernels); +// break; + default: + return false; + } + + return true; + +} + #endif #endif diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index f2d785e9..e271f17d 
100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -873,195 +873,6 @@ struct Scales8 { } }; -struct Q4bits { - const uint8x16_t m4b = vdupq_n_u8(0xf); - uint8x16x4_t b1, b2; - inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const { - b.val[0] = vandq_u8(val[0], m4b); - b.val[2] = vshrq_n_u8(val[0], 4); - b.val[1] = vandq_u8(val[1], m4b); - b.val[3] = vshrq_n_u8(val[1], 4); - } - inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const { - b.val[0] = vandq_u8(val[0], m4b); - b.val[1] = vshrq_n_u8(val[0], 4); - b.val[2] = vandq_u8(val[1], m4b); - b.val[3] = vshrq_n_u8(val[1], 4); - } - inline void prepare(const uint8_t * qs) { - auto q4bits = vld1q_u8_x2(qs); - prepare4(b1, q4bits.val); - q4bits = vld1q_u8_x2(qs+32); - prepare4(b2, q4bits.val); - } - inline void prepare_v2(const uint8_t * qs) { - auto q4bits = vld1q_u8_x4(qs); - prepare4(b1, q4bits.val+0); - prepare4(b2, q4bits.val+2); - } - inline void prepare64(const uint8_t * qs) { - auto q4bits = vld1q_u8_x4(qs); - b1.val[0] = vandq_u8(q4bits.val[0], m4b); - b1.val[1] = vandq_u8(q4bits.val[1], m4b); - b1.val[2] = vandq_u8(q4bits.val[2], m4b); - b1.val[3] = vandq_u8(q4bits.val[3], m4b); - b2.val[0] = vshrq_n_u8(q4bits.val[0], 4); - b2.val[1] = vshrq_n_u8(q4bits.val[1], 4); - b2.val[2] = vshrq_n_u8(q4bits.val[2], 4); - b2.val[3] = vshrq_n_u8(q4bits.val[3], 4); - } - inline void prepare16(const uint8_t * qs) { - auto q4bits = vld1q_u8_x2(qs); - prepare4_16(b1, q4bits.val); - q4bits = vld1q_u8_x2(qs+32); - prepare4_16(b2, q4bits.val); - } - inline void prepare16_v2(const uint8_t * qs) { - auto q4bits = vld1q_u8_x4(qs); - prepare4_16(b1, q4bits.val+0); - prepare4_16(b2, q4bits.val+2); - } -}; - -struct Q2bits { - const uint8x16_t m4b = vdupq_n_u8(0x03); - uint8x16x4_t b1, b2; - inline void prepare(const uint8_t * qs) { - auto q2bits = vld1q_u8_x2(qs); - b1.val[0] = vandq_u8(q2bits.val[0], m4b); - b1.val[1] = vandq_u8(q2bits.val[1], m4b); - - q2bits.val[0] = 
vshrq_n_u8(q2bits.val[0], 2); - q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2); - b1.val[2] = vandq_u8(q2bits.val[0], m4b); - b1.val[3] = vandq_u8(q2bits.val[1], m4b); - - q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2); - q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2); - b2.val[0] = vandq_u8(q2bits.val[0], m4b); - b2.val[1] = vandq_u8(q2bits.val[1], m4b); - - q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2); - q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2); - b2.val[2] = vandq_u8(q2bits.val[0], m4b); - b2.val[3] = vandq_u8(q2bits.val[1], m4b); - } -}; - -template -struct BaseDequantizer { - BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {} - inline void new_row(int ix) { - if constexpr (has_row_scale) { - if constexpr (scale_is_f16) { - const ggml_half * dptr = (const ggml_half *)((const char *)vx + ix*bx); - d = GGML_FP16_TO_FP32(*dptr); - x = (const block_q *)(dptr + 1); - } else { - const float * dptr = (const float *)((const char *)vx + ix*bx); - d = *dptr; - x = (const block_q *)(dptr + 1); - } - } else { - x = (const block_q *)((const char *)vx + ix*bx); - } - } - const void * vx; - const block_q * x; - const size_t bx; - const int nrc; - float d; -}; - -struct DequantizerQ4K final : public BaseDequantizer { - DequantizerQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} - - constexpr static int num_blocks() { return 8; } - constexpr static bool should_scale_quants() { return false; } - - template - inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { - d = GGML_FP16_TO_FP32(x[i].d); - return s8.process_scales_mins(x[i], q8, i, acc); - } - inline void prepare(int i, int j) { - if (nrc == 1) bits.prepare_v2(x[i].qs+64*j); - else bits.prepare(x[i].qs+64*j); - } - - Q4bits bits; - Scales8 s8; - -}; - -struct HighBit5 { - const uint8x16_t mhb = vdupq_n_u8(0x10); - uint8x16x2_t bits; - inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) { - b1.val[0] = vorrq_u8(b1.val[0], 
vandq_u8(vshlq_n_u8(bits.val[0], 4), mhb)); - b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 4), mhb)); - b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 3), mhb)); - b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 3), mhb)); - - b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb)); - b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb)); - b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb)); - b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb)); - - if (do_shift) { - bits.val[0] = vshrq_n_u8(bits.val[0], 4); - bits.val[1] = vshrq_n_u8(bits.val[1], 4); - } - } -}; - -struct HighBit3 { - const uint8x16_t mhb = vdupq_n_u8(0x04); - uint8x16x2_t bits; - inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) { - b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb)); - b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb)); - b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb)); - b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb)); - - b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(bits.val[0], mhb)); - b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(bits.val[1], mhb)); - b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshrq_n_u8(bits.val[0], 1), mhb)); - b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshrq_n_u8(bits.val[1], 1), mhb)); - - if (do_shift) { - bits.val[0] = vshrq_n_u8(bits.val[0], 4); - bits.val[1] = vshrq_n_u8(bits.val[1], 4); - } - } -}; - -struct DequantizerQ5K final : public BaseDequantizer { - DequantizerQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} - - constexpr static int num_blocks() { return 8; } - constexpr static bool should_scale_quants() { return false; } - - template - inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { - d = GGML_FP16_TO_FP32(x[i].d); - h.bits = vld1q_u8_x2(x[i].qh); 
- return s8.process_scales_mins(x[i], q8, i, acc); - } - inline void prepare(int i, int j) { - if (nrc == 1) bits.prepare_v2(x[i].qs+64*j); - else bits.prepare(x[i].qs+64*j); - h.apply(bits.b1, bits.b2, j == 0); - } - - Q4bits bits; - HighBit5 h; - Scales8 s8; - - uint8x16x2_t hbits; - -}; - inline int32x4x4_t make_wider(const int16x8x2_t& scales16) { int32x4x4_t scales = { vmovl_s16(vget_low_s16 (scales16.val[0])), @@ -1081,171 +892,6 @@ inline int32x4x4_t process_scales_mins_16(const int8x16_t& scales8, const Q8& q8 return make_wider(scales16); } -struct DequantizerQ6K final : public BaseDequantizer { - DequantizerQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} - - constexpr static int num_blocks() { return 16; } - constexpr static bool should_scale_quants() { return false; } - - template - inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { - d = GGML_FP16_TO_FP32(x[i].d); - return process_scales_mins_16(vld1q_s8(x[i].scales), q8, acc, i, -32.f*d); - } - inline void prepare(int i, int j) { - - auto hbits = vld1q_u8_x2(x[i].qh + 32*j); - - bits.prepare64(x[i].ql+64*j); - bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), mhb)); - bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), mhb)); - bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), mhb)); - bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), mhb)); - - bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], mhb)); - bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], mhb)); - bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), mhb)); - bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), mhb)); - - } - - Q4bits bits; - - const uint8x16_t mhb = vdupq_n_u8(0x30); - -}; - -struct DequantizerQ3K final : public BaseDequantizer { - DequantizerQ3K(const void * vx, size_t bx, int 
nrc) : BaseDequantizer(vx, bx, nrc) {} - - constexpr static int num_blocks() { return 16; } - constexpr static bool should_scale_quants() { return false; } - - template - inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { - d = GGML_FP16_TO_FP32(x[i].d); - h.bits = vld1q_u8_x2(x[i].hmask); - mask = vdupq_n_u8(0x01); - const uint16_t * sc16 = (const uint16_t *)x[i].scales; - uint32_t aux0 = sc16[0] | (sc16[1] << 16); - uint32_t aux1 = sc16[2] | (sc16[3] << 16); - uint32_t aux2 = sc16[4] | (sc16[5] << 16); - aux32[0] = (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030); - aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030); - aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030); - aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030); - auto scales8 = vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32)); - if (nrc > 1) { - return process_scales_mins_16(scales8, q8, acc, i, -4.f*d); - } - int16x8x2_t scales16; - scales16.val[0] = vmovl_s8(vget_low_s8(scales8)); - scales16.val[1] = vmovl_s8(vget_high_s8(scales8)); - return make_wider(scales16); - } - - inline void prepare(int i, int j) { - bits.prepare(x[i].qs+32*j); - if (nrc > 1) { - h.apply(bits.b1, bits.b2, j == 0); - } else { - auto minus4 = vdupq_n_u8(0xfc); - auto zero = vdupq_n_u8(0); - bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); - bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); - mask = vshlq_n_u8(mask, 1); - bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); - bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); - mask = vshlq_n_u8(mask, 1); - bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); - bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(minus4, 
vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); - mask = vshlq_n_u8(mask, 1); - bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); - bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); - mask = vshlq_n_u8(mask, 1); - } - } - - uint32_t aux32[4]; - - Q2bits bits; - - uint8x16_t mask; - HighBit3 h; - -}; - -struct DequantizerQ2K final : public BaseDequantizer { - DequantizerQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} - - constexpr static int num_blocks() { return 16; } - constexpr static bool should_scale_quants() { return true; } - - template - inline void process_scales(int i, const Q8& q8, float32x4_t * acc) { - d = GGML_FP16_TO_FP32(x[i].d); - auto scales_and_mins = vld1q_u8(x[i].scales); - auto mins8 = vreinterpretq_s8_u8(vshrq_n_u8(scales_and_mins, 4)); - int16x8x2_t scales16; - scales16.val[0] = vmovl_s8(vget_low_s8(mins8)); - scales16.val[1] = vmovl_s8(vget_high_s8(mins8)); - accum_mins_16(scales16, q8, acc, i, -GGML_FP16_TO_FP32(x[i].dmin)); - - scales8 = vandq_u8(scales_and_mins, vdupq_n_u8(0xf)); - } - - template - inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { - process_scales(i, q8, acc); - int16x8x2_t scales16; - scales16.val[0] = vmovl_s8(vget_low_s8(vreinterpretq_s8_u8(scales8))); - scales16.val[1] = vmovl_s8(vget_high_s8(vreinterpretq_s8_u8(scales8))); - return make_wider(scales16); - } - - template - inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) { - auto m1 = vdupq_n_u8(1); - auto shuffle = vdupq_n_u8(8*j); - bits.b1.val[0] = vmulq_u8(bits.b1.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); - bits.b1.val[1] = vmulq_u8(bits.b1.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); - bits.b1.val[2] = vmulq_u8(bits.b1.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); - bits.b1.val[3] = 
vmulq_u8(bits.b1.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); - bits.b2.val[0] = vmulq_u8(bits.b2.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); - bits.b2.val[1] = vmulq_u8(bits.b2.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); - bits.b2.val[2] = vmulq_u8(bits.b2.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); - bits.b2.val[3] = vmulq_u8(bits.b2.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); - for (int iy = 0; iy < Q8::nrc_y; ++iy) { - auto q8b_1 = q8.load_quants(iy, i, 4*j+0); - sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]), - vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]); - - auto q8b_2 = q8.load_quants(iy, i, 4*j+1); - sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]), - vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]); - - auto q8b_3 = q8.load_quants(iy, i, 4*j+2); - sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]), - vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]); - - auto q8b_4 = q8.load_quants(iy, i, 4*j+3); - sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]), - vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]); - } - } - - inline void prepare(int i, int j) { - bits.prepare(x[i].qs+32*j); - } - - uint32_t aux32[4]; - - uint8x16_t scales8; - - Q2bits bits; - -}; - // ============================= i-quants inline int32x4x4_t make_wider_8(const int8x16_t& scales8) { @@ -1969,64 +1615,6 @@ struct DequantizerIQ3S final : public BaseDequantizer { }; -template -void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - assert(n % QK_K == 0); - const int nb = n / QK_K; - - Q8 q8(info); - - Dequantizer deq(vx, bx, nrc_y); - - for (int ix = 0; ix < nrc_x; ++ix) { - - deq.new_row(ix); - 
- float32x4_t acc[nrc_y]; - for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); - - for (int i = 0; i < nb; ++i) { - - int32x4_t sumi[nrc_y]; - for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0); - - if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) { - deq.process_scales(i, q8, acc); - deq.prepare(i, 0); - deq.compute(q8, i, 0, sumi); - deq.prepare(i, 1); - deq.compute(q8, i, 1, sumi); - } else { - if constexpr (Dequantizer::num_blocks() == 8) { - auto scales = deq.new_block(i, q8, acc); - deq.prepare(i, 0); - for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); - deq.prepare(i, 1); - for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); - } - else if constexpr (Dequantizer::num_blocks() == 16) { - auto scales = deq.new_block(i, q8, acc); - deq.prepare(i, 0); - for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); - deq.prepare(i, 1); - for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); - } - else { - GGML_ASSERT(false); - } - } - - for (int iy = 0; iy < nrc_y; ++iy) { - acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); - } - } - - for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, vaddvq_f32(acc[iy])); - } - } -} - // =========================================== Legacy quants template @@ -5095,20 +4683,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { switch (typeA) { case GGML_TYPE_Q2_K: - MulMat::set_functions(m); - break; case GGML_TYPE_Q3_K: - MulMat::set_functions(m); - break; case GGML_TYPE_Q4_K: - MulMat::set_functions(m); - break; case GGML_TYPE_Q5_K: - MulMat::set_functions(m); - break; case GGML_TYPE_Q6_K: - MulMat::set_functions(m); - break; + return iqk_set_kernels_kquants(ne00, typeA, typeB, m.funcs, 
m.func16); case GGML_TYPE_IQ4_XS: MulMat::set_functions(m); break;