// Review notes for this patch. They are placed ahead of the first "diff --git"
// header so that no hunk content is altered.
//
// Purpose (from the visible hunks): the Q8_K64 row header grows from 4 to 8
// floats so that, next to the 4 per-lane scales, it also carries 4 per-lane
// values "scale * (sum of the quantized int8 values)". The IQ1_BN NEON kernel
// then uses quants without the m1 subtraction and removes the resulting
// offset once per row when storing the result.
//
// ggml/src/iqk/iqk_mul_mat.cpp
//   * Q8_K64: the constructor copies 8 floats per src1 row into d (was 4) and
//     points y[iy] at dptr + 8 (was dptr + 4). scale(iy) loads floats 0..3 of
//     the row's 8-float group; the new minus(iy) loads floats 4..7.
//   * DequantizerIQ1BN::prepare_iq1bn_quants_nosub (new): loads 16 bytes from
//     x and, for each of the 4 output vectors, does the table lookup with
//     shuff.val[k], multiplies by mult.val[k], then computes
//     (halving-add(val, val >> 1)) >> 6 and stores it as int8. Unlike the
//     context line just above it, no vsubq_s8(..., m1) is applied.
//   * mul_mat_iq1bn_q8_K64: every prepare_iq1bn_quants call becomes the
//     _nosub variant. The final store changes from sum(scale * acc) to
//     -sum(minus - scale * acc), i.e. sum(scale * acc - minus); in the
//     is_iq1_tn case this is additionally multiplied by `scale`.
//     (vfmsq_f32(a, b, c) yields a - b * c per the Arm intrinsics reference.)
//   * NOTE(review): presumably this relies on
//     sum(q8 * (v + m1)) = sum(q8 * v) + m1 * sum(q8) with m1 == 1 in the
//     dequantizer; m1's definition is outside these hunks - confirm.
//
// ggml/src/iqk/iqk_quantize.cpp (quantize_row_q8_K64_ref)
//   * Adds GGML_ASSERT(k >= 8*QK_IQ1BN) on entry; quants are now written
//     starting at dptr + 8 (was dptr + 4).
//   * __ARM_NEON branch: for every 16 quants, qsum is accumulated with
//     ggml_vdotq_s32(qsum, qi, m1) where m1 is an all-ones int8 vector; after
//     the loop vld1q_f32(dptr) * float(qsum) is stored at dptr + 4.
//   * Last hunk (the code ending at #endif): accumulates sum[i] over the
//     quants it writes and then sets dptr[4+i] = dptr[i]*sum[i].
//   * NOTE(review): no added lines are visible inside the
//     `#elif defined __AVX__` branch. Confirm whether that path must also
//     fill dptr[4..7]; the rest of that branch and its consumers are outside
//     the visible hunks.
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 1245f4a3..fce86efd 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -5848,16 +5848,17 @@ template struct Q8_K64 { Q8_K64(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) { auto dptr = (const float *)info.src1_row(iy); - std::memcpy(d + 4*iy, dptr, 4*sizeof(float)); - y[iy] = (const int8_t *)(dptr + 4); + std::memcpy(d + 8*iy, dptr, 8*sizeof(float)); + y[iy] = (const int8_t *)(dptr + 8); } } inline int8x16x4_t load_quants64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy] + 128*i + 64*j); } inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy] + 128*i + 32*j); } - inline float32x4_t scale(int iy) const { return vld1q_f32(d + 4*iy); } + inline float32x4_t scale(int iy) const { return vld1q_f32(d + 8*iy); } + inline float32x4_t minus(int iy) const { return vld1q_f32(d + 8*iy + 4); } - float d[4*nrc_y]; + float d[8*nrc_y]; const int8_t * y[nrc_y]; }; @@ -5889,6 +5890,14 @@ struct DequantizerIQ1BN { v.val[k] = vsubq_s8(vreinterpretq_s8_u8(val), m1); } } + + IQK_ALWAYS_INLINE void prepare_iq1bn_quants_nosub(const block_iq1_bn * x, int8x16x4_t& v) const { + auto data = vld1q_u8((const uint8_t *)x); + for (int k = 0; k < 4; ++k) { + auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]); + v.val[k] = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6)); + } + } }; template @@ -5916,10 +5925,10 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn if constexpr (nrc_y == 1) { int32x4_t acc[4] = {}; for (int i = 0; i < nb/2; ++i) { - deq.prepare_iq1bn_quants(x+2*i+0, v1); + deq.prepare_iq1bn_quants_nosub(x+2*i+0, v1); auto q = q8.load_quants64(0, i, 0); for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]); - deq.prepare_iq1bn_quants(x+2*i+1, v2); + deq.prepare_iq1bn_quants_nosub(x+2*i+1, v2); q = q8.load_quants64(0, i, 1); for (int j =
0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v2.val[j]); } @@ -5931,8 +5940,8 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn for (int i = 0; i < nb/2; ++i) { - deq.prepare_iq1bn_quants(x+2*i+0, v1); - deq.prepare_iq1bn_quants(x+2*i+1, v2); + deq.prepare_iq1bn_quants_nosub(x+2*i+0, v1); + deq.prepare_iq1bn_quants_nosub(x+2*i+1, v2); for (int iy = 0; iy < nrc_y; ++iy) { auto q = q8.load_quants(iy, i, 0); @@ -5948,7 +5957,7 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn } int i = 2*(nb/2); if (i < nb) { - deq.prepare_iq1bn_quants(x+i, v1); + deq.prepare_iq1bn_quants_nosub(x+i, v1); if constexpr (nrc_y == 1) { auto q = q8.load_quants(0, i/2, 0); for (int j = 0; j < 4; ++j) { @@ -5966,9 +5975,9 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn for (int iy = 0; iy < nrc_y; ++iy) { if constexpr (is_iq1_tn) { - info.store(ix, iy, scale * vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy])))); + info.store(ix, iy, -scale * vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy])))); } else { - info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy])))); + info.store(ix, iy, -vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy])))); } } diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 64815e92..7ca1759d 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -379,8 +379,10 @@ void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) { + GGML_ASSERT(k >= 8*QK_IQ1BN); + float * dptr = (float *)y; - auto qs = (int8_t *)(dptr + 4); + auto qs = (int8_t *)(dptr + 8); #ifdef __ARM_NEON static const uint8_t k_shuffle[16] = {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}; auto shuffle = vld1q_u8(k_shuffle); @@ -399,16 +401,22 @@ void
quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) { vid[i] = vdupq_n_f32(id); } int8x16x4_t q; + int32x4_t qsum = {}; + const int8x16_t m1 = vdupq_n_s8(1); for (int j = 0; j < k; j += 16) { for (int i = 0; i < 4; ++i) { auto val = vld1q_f32(x + j + 4*i); val = vmulq_f32(vid[i], val); - q.val[i] = vreinterpretq_s8_s32(vcvtnq_s32_f32(val)); + auto ival = vcvtnq_s32_f32(val); + q.val[i] = vreinterpretq_s8_s32(ival); } auto qi = vqtbl4q_s8(q, shuffle); + qsum = ggml_vdotq_s32(qsum, qi, m1); vst1q_s8(qs, qi); qs += 16; } + auto sumf = vmulq_f32(vld1q_f32(dptr), vcvtq_f32_s32(qsum)); + vst1q_f32(dptr + 4, sumf); #elif defined __AVX__ __m128 max[4] = {}; __m128 sign_bit = _mm_set1_ps(-0.f); @@ -455,11 +463,16 @@ void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) { dptr[i] = aux[i]/127; aux[i] = dptr[i] > 0 ? 1/dptr[i] : 0.f; } + int32_t sum[4] = {}; for (int j = 0; j < k; j += 16) { for (int i = 0; i < 4; ++i) { - for (int l = 0; l < 4; ++l) qs[j+4*i+l] = nearest_int(aux[i]*x[j+4*i+l]); + for (int l = 0; l < 4; ++l) { + qs[j+4*i+l] = nearest_int(aux[i]*x[j+4*i+l]); + sum[i] += qs[j+4*i+l]; + } } } + for (int i = 0; i < 4; ++i) dptr[4+i] = dptr[i]*sum[i]; #endif }