diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index a96e559c..7b87f49c 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -1259,34 +1259,34 @@ void quantize_row_iq2_ks_impl(const float * x, void * vy, int n_per_row, const f float best = 0, d = 0; bool is_shifted = false; float sumqx, sumq2; - for (int i1 = 0; i1 < kMax_i1; ++i1) { - for (int i2 = i1; i2 < kBlockSize; ++i2) { - for (int i3 = std::max(i2, kMin_i3); i3 < kBlockSize; ++i3) { - sumqx = (sumx[i1] - sumx[ 0])*val[0] + (sumx[i2] - sumx[i1])*val[1] - + (sumx[i3] - sumx[i2])*val[2] + (sumx[kBlockSize] - sumx[i3])*val[3]; - sumq2 = (sumw[i1] - sumw[ 0])*val[0]*val[0] + (sumw[i2] - sumw[i1])*val[1]*val[1] - + (sumw[i3] - sumw[i2])*val[2]*val[2] + (sumw[kBlockSize] - sumw[i3])*val[3]*val[3]; + for (int i1 = 1; i1 < kMax_i1; ++i1) { + for (int i3 = std::max(i1, kMin_i3); i3 < kBlockSize; ++i3) { + float sumqx_1 = (sumx[i1] - sumx[0])*val[0] + (sumx[kBlockSize] - sumx[i3])*val[3]; + float sumq2_1 = (sumw[i1] - sumw[0])*val[0]*val[0] + (sumw[kBlockSize] - sumw[i3])*val[3]*val[3]; + float sumqx_2 = (sumx[i1] - sumx[0])*sval[0] + (sumx[kBlockSize] - sumx[i3])*sval[3]; + float sumq2_2 = (sumw[i1] - sumw[0])*sval[0]*sval[0] + (sumw[kBlockSize] - sumw[i3])*sval[3]*sval[3]; + float sumqx_3 = (sumx[i1] - sumx[0])*val[3] + (sumx[kBlockSize] - sumx[i3])*val[0]; + float sumq2_3 = (sumw[i1] - sumw[0])*val[3]*val[3] + (sumw[kBlockSize] - sumw[i3])*val[0]*val[0]; + float sumqx_4 = (sumx[i1] - sumx[0])*sval[3] + (sumx[kBlockSize] - sumx[i3])*sval[0]; + float sumq2_4 = (sumw[i1] - sumw[0])*sval[3]*sval[3] + (sumw[kBlockSize] - sumw[i3])*sval[0]*sval[0]; + for (int i2 = i1; i2 <= i3; ++i2) { + sumqx = sumqx_1 + (sumx[i2] - sumx[i1])*val[1] + (sumx[i3] - sumx[i2])*val[2]; + sumq2 = sumq2_1 + (sumw[i2] - sumw[i1])*val[1]*val[1] + (sumw[i3] - sumw[i2])*val[2]*val[2]; if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { d = sumqx/sumq2; best = d*sumqx; is_shifted = false; } - sumqx = (sumx[i1] - sumx[ 0])*sval[0] + (sumx[i2] - sumx[i1])*sval[1] - + (sumx[i3] - sumx[i2])*sval[2] + (sumx[kBlockSize] - sumx[i3])*sval[3]; - sumq2 = (sumw[i1] - sumw[ 0])*sval[0]*sval[0] + (sumw[i2] - sumw[i1])*sval[1]*sval[1] - + (sumw[i3] - sumw[i2])*sval[2]*sval[2] + (sumw[kBlockSize] - sumw[i3])*sval[3]*sval[3]; + sumqx = sumqx_2 + (sumx[i2] - sumx[i1])*sval[1] + (sumx[i3] - sumx[i2])*sval[2]; + sumq2 = sumq2_2 + (sumw[i2] - sumw[i1])*sval[1]*sval[1] + (sumw[i3] - sumw[i2])*sval[2]*sval[2]; if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { d = sumqx/sumq2; best = d*sumqx; is_shifted = true; } - sumqx = (sumx[i1] - sumx[ 0])*val[3] + (sumx[i2 ] - sumx[i1])*val[2] - + (sumx[i3] - sumx[i2])*val[1] + (sumx[kBlockSize] - sumx[i3])*val[0]; - sumq2 = (sumw[i1] - sumw[ 0])*val[3]*val[3] + (sumw[i2 ] - sumw[i1])*val[2]*val[2] - + (sumw[i3] - sumw[i2])*val[1]*val[1] + (sumw[kBlockSize] - sumw[i3])*val[0]*val[0]; + sumqx = sumqx_3 + (sumx[i2] - sumx[i1])*val[2] + (sumx[i3] - sumx[i2])*val[1]; + sumq2 = sumq2_3 + (sumw[i2] - sumw[i1])*val[2]*val[2] + (sumw[i3] - sumw[i2])*val[1]*val[1]; if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { d = sumqx/sumq2; best = d*sumqx; is_shifted = false; } - sumqx = (sumx[i1] - sumx[ 0])*sval[3] + (sumx[i2 ] - sumx[i1])*sval[2] - + (sumx[i3] - sumx[i2])*sval[1] + (sumx[kBlockSize] - sumx[i3])*sval[0]; - sumq2 = (sumw[i1] - sumw[ 0])*sval[3]*sval[3] + (sumw[i2 ] - sumw[i1])*sval[2]*sval[2] - + (sumw[i3] - sumw[i2])*sval[1]*sval[1] + (sumw[kBlockSize] - sumw[i3])*sval[0]*sval[0]; + sumqx = sumqx_4 + (sumx[i2] - sumx[i1])*sval[2] + (sumx[i3] - sumx[i2])*sval[1]; + sumq2 = sumq2_4 + (sumw[i2] - sumw[i1])*sval[2]*sval[2] + (sumw[i3] - sumw[i2])*sval[1]*sval[1]; if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { d = sumqx/sumq2; best = d*sumqx; is_shifted = true; }