diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 21fbe5e5..24cd057e 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -1226,69 +1226,41 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c } float d = ntry > 0 ? -max/iq3nl_values[0] : max/iq3nl_values[0]; float id = 1/d; - float sumqx_p = 0, sumq2_p = 0; - float sumqx_m = 0, sumq2_m = 0; - for (int j = 0; j < 16; ++j) { - float w = weight[j]; - float al = id*xb[j]; - int l = best_index_iq3nl(iq3nl_values, al); - float q = iq3nl_values[l]; - sumqx_p += w*q*xb[j]; - sumq2_p += w*q*q; - l = best_index_iq3nl(iq3nl_values, -al); - q = iq3nl_values[l]; - sumqx_m += w*q*xb[j]; - sumq2_m += w*q*q; - } - d = sumqx_p/sumq2_p; - float best = d*sumqx_p; - if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { - d = sumqx_m/sumq2_m; best = d*sumqx_m; - } + float best = 0; + auto check_one = [&best, &d, xb, weight] (float id, const int8_t * values) { + float sumqx_p = 0, sumq2_p = 0; + float sumqx_m = 0, sumq2_m = 0; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq3nl(values, al); + float q = values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq3nl(values, -al); + q = values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + bool result = false; + if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + d = sumqx_p/sumq2_p; best = d*sumqx_p; result = true; + } + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d*sumqx_m; result = true; + } + return result; + }; bool is_shifted = false; + check_one(id, iq3nl_values); + if (check_one(id, shifted_values)) is_shifted = true; for (int itry = -ntry; itry <= ntry; ++itry) { - id = (itry + iq3nl_values[0])/max; - sumqx_p = sumq2_p = 0; - sumqx_m = sumq2_m = 0; - for (int j = 0; j < 16; ++j) { - float w = weight[j]; - float al = id*xb[j]; - int l = best_index_iq3nl(iq3nl_values, al); - float q = iq3nl_values[l]; - sumqx_p += w*q*xb[j]; - sumq2_p += w*q*q; - l = best_index_iq3nl(iq3nl_values, -al); - q = iq3nl_values[l]; - sumqx_m += w*q*xb[j]; - sumq2_m += w*q*q; - } - if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { - d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false; - } - if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { - d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false; - } - id = (itry + shifted_values[0])/max; - sumqx_p = sumq2_p = 0; - sumqx_m = sumq2_m = 0; - for (int j = 0; j < 16; ++j) { - float w = weight[j]; - float al = id*xb[j]; - int l = best_index_iq3nl(shifted_values, al); - float q = shifted_values[l]; - sumqx_p += w*q*xb[j]; - sumq2_p += w*q*q; - l = best_index_iq3nl(shifted_values, -al); - q = shifted_values[l]; - sumqx_m += w*q*xb[j]; - sumq2_m += w*q*q; - } - if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { - d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true; - } - if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { - d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; - } + if (check_one((itry + iq3nl_values[0])/max, iq3nl_values )) is_shifted = false; + if (check_one((itry + shifted_values[0])/max, shifted_values)) is_shifted = true; + if (check_one((itry + iq3nl_values[7])/max, iq3nl_values )) is_shifted = false; + if (check_one((itry + shifted_values[7])/max, shifted_values)) is_shifted = true; + if (check_one((itry + iq3nl_values[1])/max, iq3nl_values )) is_shifted = false; + if (check_one((itry + shifted_values[1])/max, shifted_values)) is_shifted = true; } if (d) { const int8_t * block_values = is_shifted ? shifted_values : iq3nl_values;