diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 5eea36c0..ffe27db3 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -1484,13 +1484,6 @@ struct DequantizerQ6K final : public BaseDequantizer { struct DequantizerIQ2TN final : public BaseDequantizer { DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} - //template - //inline void new_block(int i, const Q8& q8, __m256i * sumi) { - // d = GGML_FP16_TO_FP32(x[i].d); - // for (int iy = 0; iy < Q8::nrc_y; ++iy) { - // sumi[iy] = q8.load_bsums(iy, i); - // } - //} inline void new_block(int i) { d = GGML_FP16_TO_FP32(x[i].d); } @@ -3883,6 +3876,62 @@ struct DequantizerQ2K final : public BaseDequantizer { float d; }; +struct DequantizerIQ2TN final : public BaseDequantizer { + DequantizerIQ2TN(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return true; } + + template + inline void process_scales(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + } + + template + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + process_scales(i, q8, acc); + return { vdupq_n_s32(1), vdupq_n_s32(1), vdupq_n_s32(1), vdupq_n_s32(1) }; + } + + template + inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8b_1 = q8.load_quants(iy, i, 4*j+0); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]), + vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]); + + auto q8b_2 = q8.load_quants(iy, i, 4*j+1); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]), + vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]); + + auto q8b_3 = q8.load_quants(iy, i, 4*j+2); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]), + vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]); + + auto q8b_4 = q8.load_quants(iy, i, 4*j+3); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]), + vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]); + } + } + + inline void prepare(int i, int j) { + bits.prepare(x[i].qs+32*j); + auto m1 = vdupq_n_s8(1); + bits.b1.val[0] = vsubq_s8(bits.b1.val[0], m1); + bits.b1.val[1] = vsubq_s8(bits.b1.val[1], m1); + bits.b1.val[2] = vsubq_s8(bits.b1.val[2], m1); + bits.b1.val[3] = vsubq_s8(bits.b1.val[3], m1); + bits.b2.val[0] = vsubq_s8(bits.b2.val[0], m1); + bits.b2.val[1] = vsubq_s8(bits.b2.val[1], m1); + bits.b2.val[2] = vsubq_s8(bits.b2.val[2], m1); + bits.b2.val[3] = vsubq_s8(bits.b2.val[3], m1); + } + + Q2bits bits; + + float d; +}; + // ============================= i-quants inline int32x4x4_t make_wider_8(const int8x16_t& scales8) { @@ -5400,6 +5449,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { case GGML_TYPE_Q2_K: MulMat::set_functions(m); break; + case GGML_TYPE_IQ2_TN: + MulMat::set_functions(m); + break; case GGML_TYPE_Q3_K: MulMat::set_functions(m); break;