diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index fffb3ab2..db83b841 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -4481,28 +4481,31 @@ void mul_mat_iq2tn_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& i Q8 q8(info); DequantizerIQ2TN deq(vx, bx, nrc_y); + float32x4_t acc[nrc_y]; for (int ix = 0; ix < nrc_x; ++ix) { deq.new_row(ix); - float32x4_t acc[nrc_y]; - for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); - for (int i = 0; i < nb; ++i) { int32x4_t sumi[nrc_y]; for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0); - //deq.process_scales(i, q8, acc); deq.new_block(i); deq.prepare(i, 0); deq.compute(q8, i, 0, sumi); deq.prepare(i, 1); deq.compute(q8, i, 1, sumi); - for (int iy = 0; iy < nrc_y; ++iy) { - acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); + if (i > 0) { + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); + } + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vmulq_f32(vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); + } } } @@ -4520,11 +4523,11 @@ void mul_mat_iq2tn_K_q8_K_1(int n, const void * vx, size_t bx, const DataInfo& i DequantizerIQ2TN deq(vx, bx, 1); auto m1 = vdup_n_s16(-1); + float32x4_t acc[2]; for (int ix = 0; ix < nrc_x; ++ix) { deq.new_row(ix); - float32x4_t acc[2] = {}; for (int i = 0; i < nb; ++i) { @@ -4540,8 +4543,13 @@ void mul_mat_iq2tn_K_q8_K_1(int n, const void * vx, size_t bx, const DataInfo& i deq.compute1(q8, i, 1, sumi); auto vd = vdupq_n_f32(deq.d*q8.scale(0, i)); - acc[0] = vmlaq_f32(acc[0], vcvtq_f32_s32(sumi[0]), vd); - acc[1] = vmlaq_f32(acc[1], vcvtq_f32_s32(sumi[1]), vd); + if (i > 0) { + acc[0] = vmlaq_f32(acc[0], vcvtq_f32_s32(sumi[0]), vd); + acc[1] = vmlaq_f32(acc[1], vcvtq_f32_s32(sumi[1]), vd); + } else { + acc[0] = vmulq_f32(vcvtq_f32_s32(sumi[0]), vd); + acc[1] = vmulq_f32(vcvtq_f32_s32(sumi[1]), vd); + } }