From d89c88e8df227eb8a852c3677ab2ec00f3e7792d Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 28 Jul 2024 08:36:20 +0200 Subject: [PATCH] iq4_k: NEON implementation For LLaMA-3.1-8B we get PP-512 = 60.7 t/s, TG-128 = 25.0 t/s on the M2-Max. TG is on par with q4_K_S, PP is ~10% slower. --- ggml/src/iqk/iqk_mul_mat.cpp | 75 ++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 1b6ad44c..1fe0af74 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -3463,6 +3463,78 @@ struct DequantizerQ2K final : public BaseDequantizer { // ============================= i-quants +inline int32x4x4_t make_wider_8(const int8x16_t& scales8) { + int16x8x2_t scales16{vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8))}; + return make_wider(scales16); +} + +struct Scale16Extra { + template + static inline int32x4x4_t new_block(int i, float d, uint16_t extra, uint8_t val, + const uint8_t * scales_l, const uint8_t * scales_h, const Q8& q8, float32x4_t * acc) { + uint8x8_t aux = vld1_u8(scales_l); + uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf)); + const uint32_t * aux32 = (const uint32_t *)scales_h; + uint32x4_t sch_32 = {aux32[0] << 4, aux32[0] << 2, aux32[0], aux32[0] >> 2}; + uint8x16_t sch8 = vandq_u8(vreinterpretq_u8_u32(sch_32), vdupq_n_u8(0x30)); + int8x16_t scales8 = vorrq_u8(scl8, vqtbl1q_u8(sch8, vreinterpretq_u8_u32(hshuff))); + scales8 = vaddq_s8(vqtbl1q_s8(scales8, vreinterpretq_u8_u32(hshuff)), vdupq_n_s8(-32)); + return new_block(i, d, extra, val, scales8, q8, acc); + } + inline static uint8x16_t get_extra(uint16_t extra) { + uint8x16_t e8 = vreinterpretq_u8_u16(vdupq_n_u16(extra)); + e8 = vceqq_u8(vandq_u8(e8, emask), emask); + return vqtbl1q_u8(e8, eshuff); + } + template + static inline int32x4x4_t new_block(int i, float d, uint16_t extra, uint8_t val, + const int8x16_t& scales8, const Q8& q8, float32x4_t * acc) { + uint8x16_t e8 = vreinterpretq_u8_u16(vdupq_n_u16(extra)); + e8 = vceqq_u8(vandq_u8(e8, emask), emask); + e8 = vqtbl1q_u8(vandq_u8(e8, vdupq_n_u8(val)), eshuff); + int16x8x2_t extra16 = {vmull_s8(vget_low_s8 (e8), vget_low_s8 (scales8)), + vmull_s8(vget_high_s8(e8), vget_high_s8(scales8))}; + accum_mins_16(extra16, q8, acc, i, d); + return make_wider_8(scales8); + } + + constexpr static uint32x4_t hshuff = {0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06}; + constexpr static uint32x4_t emask = {0x02020101, 0x08080404, 0x20201010, 0x80804040}; + constexpr static uint32x4_t eshuff = {0x06040200, 0x0e0c0a08, 0x07050301, 0x0f0d0b09}; +}; + +// Note: on ARM_NEON we cannot use the values shifted into the uint8_t range because +// the ARM_NEON only has vdotq_s32 or vdotq_u32, where both operands need to +// be signed or unsigned. As the Q8_K quants are signed, we need to have the +// iq4_s quants also signed. We can only use unsigned values in k-quants +// because they are all within the valid int8_t range. +struct DequantizerIQ4K final : public BaseDequantizer { + DequantizerIQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8(iq4k_values)) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + inline void new_row(int ix) { x = (const block_iq4_k *)((const char *)vx + bx*ix); } + + template + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + return Scale16Extra::new_block(i, d, x[i].extra, 4, x[i].scales_l, x[i].scales_h, q8, acc); + } + inline void prepare(int i, int j) { + bits.prepare16(x[i].qs+64*j); + for (int k = 0; k < 4; ++k) { + bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]); + bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]); + } + } + + Q4bits bits; + const int16x8_t values; + + float d; +}; + struct DequantizerIQ4XS final : public BaseDequantizer { static int8x16_t load_values() { @@ -4789,6 +4861,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { case GGML_TYPE_IQ4_XS: MulMat::set_functions(m); break; + case GGML_TYPE_IQ4_K: + MulMat::set_functions(m); + break; case GGML_TYPE_IQ2_XXS: MulMat::set_functions(m); break;