From e528505fc805fe7ac609eddaac0ed60108142e8c Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 6 Aug 2024 06:19:59 +0200 Subject: [PATCH] iq2_tn: NEON For TriLM-3.9B running on the M2-Max we get PP-512 = 193.5 t/s, TG-128 = 75.5 t/s. This is in line with what we have for iq2_bn ant 3.3B Bitnet. --- ggml/src/iqk/iqk_mul_mat.cpp | 211 +++++++++++++++++++++++++---------- 1 file changed, 154 insertions(+), 57 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index ffe27db3..3510cbaf 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -3876,62 +3876,6 @@ struct DequantizerQ2K final : public BaseDequantizer { float d; }; -struct DequantizerIQ2TN final : public BaseDequantizer { - DequantizerIQ2TN(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} - - constexpr static int num_blocks() { return 16; } - constexpr static bool should_scale_quants() { return true; } - - template - inline void process_scales(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) { - d = GGML_FP16_TO_FP32(x[i].d); - } - - template - inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { - process_scales(i, q8, acc); - return { vdupq_n_s32(1), vdupq_n_s32(1), vdupq_n_s32(1), vdupq_n_s32(1) }; - } - - template - inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) { - for (int iy = 0; iy < Q8::nrc_y; ++iy) { - auto q8b_1 = q8.load_quants(iy, i, 4*j+0); - sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]), - vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]); - - auto q8b_2 = q8.load_quants(iy, i, 4*j+1); - sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]), - vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]); - - auto q8b_3 = q8.load_quants(iy, i, 4*j+2); - sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]), - vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]); - - auto q8b_4 = q8.load_quants(iy, i, 4*j+3); - sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]), - vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]); - } - } - - inline void prepare(int i, int j) { - bits.prepare(x[i].qs+32*j); - auto m1 = vdupq_n_s8(1); - bits.b1.val[0] = vsubq_s8(bits.b1.val[0], m1); - bits.b1.val[1] = vsubq_s8(bits.b1.val[1], m1); - bits.b1.val[2] = vsubq_s8(bits.b1.val[2], m1); - bits.b1.val[3] = vsubq_s8(bits.b1.val[3], m1); - bits.b2.val[0] = vsubq_s8(bits.b2.val[0], m1); - bits.b2.val[1] = vsubq_s8(bits.b2.val[1], m1); - bits.b2.val[2] = vsubq_s8(bits.b2.val[2], m1); - bits.b2.val[3] = vsubq_s8(bits.b2.val[3], m1); - } - - Q2bits bits; - - float d; -}; - // ============================= i-quants inline int32x4x4_t make_wider_8(const int8x16_t& scales8) { @@ -4460,6 +4404,151 @@ struct DequantizerIQ3S final : public BaseDequantizer { }; +struct DequantizerIQ2TN final : public BaseDequantizer { + DequantizerIQ2TN(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return true; } + + //template + //inline void process_scales(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) { + // d = GGML_FP16_TO_FP32(x[i].d); + //} + + inline void new_block(int i) { + d = GGML_FP16_TO_FP32(x[i].d); + } + + template + inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8b_1 = q8.load_quants(iy, i, 4*j+0); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]), + vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]); + + auto q8b_2 = q8.load_quants(iy, i, 4*j+1); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]), + vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]); + + auto q8b_3 = q8.load_quants(iy, i, 4*j+2); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]), + vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]); + + auto q8b_4 = q8.load_quants(iy, i, 4*j+3); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]), + vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]); + } + } + template + inline void compute1(const Q8& q8, int i, int j, int32x4_t * sumi) { + auto q8b_1 = q8.load_quants(0, i, 4*j+0); + sumi[0] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[0], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]), + vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]); + + auto q8b_2 = q8.load_quants(0, i, 4*j+1); + sumi[1] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[1], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]), + vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]); + + q8b_1 = q8.load_quants(0, i, 4*j+2); + sumi[0] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[0], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_1.val[0]), + vreinterpretq_s8_u8(bits.b2.val[1]), q8b_1.val[1]); + + q8b_2 = q8.load_quants(0, i, 4*j+3); + sumi[1] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[1], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_2.val[0]), + vreinterpretq_s8_u8(bits.b2.val[3]), q8b_2.val[1]); + } + + IQK_ALWAYS_INLINE void prepare(int i, int j) { + bits.prepare(x[i].qs+32*j); + auto m1 = vdupq_n_s8(1); + for (int k = 0; k < 4; ++k) { + bits.b1.val[k] = vsubq_s8(bits.b1.val[k], m1); + bits.b2.val[k] = vsubq_s8(bits.b2.val[k], m1); + } + } + + Q2bits bits; + + float d; +}; + +template +void mul_mat_iq2tn_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8 q8(info); + + DequantizerIQ2TN deq(vx, bx, nrc_y); + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + + float32x4_t acc[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); + + for (int i = 0; i < nb; ++i) { + + int32x4_t sumi[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0); + + //deq.process_scales(i, q8, acc); + deq.new_block(i); + deq.prepare(i, 0); + deq.compute(q8, i, 0, sumi); + deq.prepare(i, 1); + deq.compute(q8, i, 1, sumi); + + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(acc[iy])); + } + } +} +void mul_mat_iq2tn_K_q8_K_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8<1, block_q8_K> q8(info); + + DequantizerIQ2TN deq(vx, bx, 1); + + auto m1 = vdup_n_s16(-1); + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + float32x4_t acc[2] = {}; + + for (int i = 0; i < nb; ++i) { + + int32x4_t sumi[2] = {}; + deq.new_block(i); + auto bsums = q8.load_bsums(0, i); + bsums.val[0] = vaddq_s32(bsums.val[0], bsums.val[1]); + sumi[0] = vmlal_s16(sumi[0], vget_low_s16 (bsums.val[0]), m1); + sumi[1] = vmlal_s16(sumi[1], vget_high_s16(bsums.val[0]), m1); + deq.bits.prepare(deq.x[i].qs); + deq.compute1(q8, i, 0, sumi); + deq.bits.prepare(deq.x[i].qs+32); + deq.compute1(q8, i, 1, sumi); + + auto vd = vdupq_n_f32(deq.d*q8.scale(0, i)); + acc[0] = vmlaq_f32(acc[0], vcvtq_f32_s32(sumi[0]), vd); + acc[1] = vmlaq_f32(acc[1], vcvtq_f32_s32(sumi[1]), vd); + + } + + acc[0] = vaddq_f32(acc[0], acc[1]); + info.store(ix, 0, vaddvq_f32(acc[0])); + } +} + template void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { @@ -5450,7 +5539,15 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { MulMat::set_functions(m); break; case GGML_TYPE_IQ2_TN: - MulMat::set_functions(m); + //MulMat::set_functions(m); + m.funcs[0] = mul_mat_iq2tn_K_q8_K_1; + m.funcs[1] = mul_mat_iq2tn_K_q8_K_T<2>; + m.funcs[2] = mul_mat_iq2tn_K_q8_K_T<3>; + m.funcs[3] = mul_mat_iq2tn_K_q8_K_T<4>; + m.funcs[4] = mul_mat_iq2tn_K_q8_K_T<5>; + m.funcs[5] = mul_mat_iq2tn_K_q8_K_T<6>; + m.funcs[6] = mul_mat_iq2tn_K_q8_K_T<7>; + m.funcs[7] = mul_mat_iq2tn_K_q8_K_T<8>; break; case GGML_TYPE_Q3_K: MulMat::set_functions(m);