From b1956cd12270ace3055cf4bafc184a362f8ef70f Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 11 Jul 2025 15:50:47 +0200 Subject: [PATCH] iq2_kl: WIP NEON The compiler started crashing!!! --- ggml/src/iqk/iqk_gemm_iqk_quants.cpp | 72 ++++++++++++++++++++++++++++ ggml/src/iqk/iqk_mul_mat.cpp | 1 + 2 files changed, 73 insertions(+) diff --git a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp index fbd22b68..c009a230 100644 --- a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp +++ b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp @@ -3718,6 +3718,75 @@ struct DequantizerIQ2KS final : public BaseDequantizer }; +struct DequantizerIQ2KL final : public BaseDequantizer { + DequantizerIQ2KL(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), shifts(load_shift()) { load_values(values); } + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x2_t new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) { + uint32_t aux32; std::memcpy(&aux32, x[i].scales_l, 4); + auto scl = vand_u8(vdup_n_u8(0xf), vreinterpret_u8_u32(uint32x2_t{aux32, aux32 >> 4})); + auto sch = vandq_u16(vshlq_u16(vdupq_n_u16(x[i].scales_h), shifts), vdupq_n_u16(0x30)); + auto scales16 = vsubq_s16(vreinterpretq_s16_u16(vorrq_u16(sch, vmovl_u8(scl))), vdupq_n_s16(32)); + int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))}; + return scales; + } + inline void prepare(int i, int j) { + hbits = j == 0 ? vld1q_u8(x[i].qh) : vshrq_n_u8(hbits, 4); + auto lbits = vld1q_u8_x2(x[i].qs+32*j); + bits.b1.val[0] = vandq_u8(lbits.val[0], vdupq_n_u8(0xf)); + bits.b1.val[2] = vshrq_n_u8(lbits.val[0], 4); + bits.b2.val[0] = vandq_u8(lbits.val[1], vdupq_n_u8(0xf)); + bits.b2.val[2] = vshrq_n_u8(lbits.val[1], 4); + bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(m10, vshlq_n_u8(hbits, 4))); + bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(m10, vshlq_n_u8(hbits, 3))); + bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(m10, vshlq_n_u8(hbits, 2))); + bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(m10, vshlq_n_u8(hbits, 1))); + + auto t1 = vqtbl2q_s8(values[0], bits.b1.val[0]); + auto t2 = vqtbl2q_s8(values[1], bits.b1.val[0]); + bits.b1.val[0] = vzip1q_s8(t1, t2); + bits.b1.val[1] = vzip2q_s8(t1, t2); + t1 = vqtbl2q_s8(values[0], bits.b1.val[2]); + t2 = vqtbl2q_s8(values[1], bits.b1.val[2]); + bits.b1.val[2] = vzip1q_s8(t1, t2); + bits.b1.val[3] = vzip2q_s8(t1, t2); + + t1 = vqtbl2q_s8(values[0], bits.b2.val[0]); + t2 = vqtbl2q_s8(values[1], bits.b2.val[0]); + bits.b2.val[0] = vzip1q_s8(t1, t2); + bits.b2.val[1] = vzip2q_s8(t1, t2); + t1 = vqtbl2q_s8(values[0], bits.b2.val[2]); + t2 = vqtbl2q_s8(values[1], bits.b2.val[2]); + bits.b2.val[2] = vzip1q_s8(t1, t2); + bits.b2.val[3] = vzip2q_s8(t1, t2); + + hbits = vshrq_n_u8(hbits, 4); + } + static inline int16x8_t load_shift() { + static const int16_t k_shift[8] = {4, 2, 0, -2, -4, -6, -8, -10}; + return vld1q_s16(k_shift); + } + static inline void load_values(int8x16x2_t * values) { + static const int8_t k_values[64] = { + -63, -63, -40, -40, -40, -40, -23, -23, -23, -23, -23, -10, -10, -10, -10, 1, 1, 1, 1, 1, 13, 13, 13, 13, 13, 28, 28, 28, 28, 28, 47, 47, + -23, 13, -63, -10, 13, 47, -40, -23, 1, 13, 28, -63, 1, 13, 47, -23, -10, 1, 13, 28, -40, -23, -10, 1, 13, -63, -23, 1, 28, 47, -23, 13, + }; + values[0] = vld1q_s8_x2(k_values+ 0); + values[1] = vld1q_s8_x2(k_values+32); + } + + Q2bits bits; + //struct { uint8x16x4_t b1, b2; } bits; + uint8x16_t hbits; + const int16x8_t shifts; + const uint8x16_t m10 = vdupq_n_u8(0x10); + int8x16x2_t values[2]; + +}; + template void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); @@ -4959,6 +5028,9 @@ bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array