Another iq3k improvement

This commit is contained in:
Iwan Kawrakow
2024-11-25 10:11:02 +02:00
parent 55db84400a
commit 85d1011f52

View File

@@ -1226,69 +1226,41 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c
}
float d = ntry > 0 ? -max/iq3nl_values[0] : max/iq3nl_values[0];
float id = 1/d;
float sumqx_p = 0, sumq2_p = 0;
float sumqx_m = 0, sumq2_m = 0;
for (int j = 0; j < 16; ++j) {
float w = weight[j];
float al = id*xb[j];
int l = best_index_iq3nl(iq3nl_values, al);
float q = iq3nl_values[l];
sumqx_p += w*q*xb[j];
sumq2_p += w*q*q;
l = best_index_iq3nl(iq3nl_values, -al);
q = iq3nl_values[l];
sumqx_m += w*q*xb[j];
sumq2_m += w*q*q;
}
d = sumqx_p/sumq2_p;
float best = d*sumqx_p;
if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
d = sumqx_m/sumq2_m; best = d*sumqx_m;
}
float best = 0;
auto check_one = [&best, &d, xb, weight] (float id, const int8_t * values) {
float sumqx_p = 0, sumq2_p = 0;
float sumqx_m = 0, sumq2_m = 0;
for (int j = 0; j < 16; ++j) {
float w = weight[j];
float al = id*xb[j];
int l = best_index_iq3nl(values, al);
float q = values[l];
sumqx_p += w*q*xb[j];
sumq2_p += w*q*q;
l = best_index_iq3nl(values, -al);
q = values[l];
sumqx_m += w*q*xb[j];
sumq2_m += w*q*q;
}
bool result = false;
if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
d = sumqx_p/sumq2_p; best = d*sumqx_p; result = true;
}
if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
d = sumqx_m/sumq2_m; best = d*sumqx_m; result = true;
}
return result;
};
bool is_shifted = false;
check_one(id, iq3nl_values);
if (check_one(id, shifted_values)) is_shifted = true;
for (int itry = -ntry; itry <= ntry; ++itry) {
id = (itry + iq3nl_values[0])/max;
sumqx_p = sumq2_p = 0;
sumqx_m = sumq2_m = 0;
for (int j = 0; j < 16; ++j) {
float w = weight[j];
float al = id*xb[j];
int l = best_index_iq3nl(iq3nl_values, al);
float q = iq3nl_values[l];
sumqx_p += w*q*xb[j];
sumq2_p += w*q*q;
l = best_index_iq3nl(iq3nl_values, -al);
q = iq3nl_values[l];
sumqx_m += w*q*xb[j];
sumq2_m += w*q*q;
}
if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false;
}
if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false;
}
id = (itry + shifted_values[0])/max;
sumqx_p = sumq2_p = 0;
sumqx_m = sumq2_m = 0;
for (int j = 0; j < 16; ++j) {
float w = weight[j];
float al = id*xb[j];
int l = best_index_iq3nl(shifted_values, al);
float q = shifted_values[l];
sumqx_p += w*q*xb[j];
sumq2_p += w*q*q;
l = best_index_iq3nl(shifted_values, -al);
q = shifted_values[l];
sumqx_m += w*q*xb[j];
sumq2_m += w*q*q;
}
if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true;
}
if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true;
}
if (check_one((itry + iq3nl_values[0])/max, iq3nl_values )) is_shifted = false;
if (check_one((itry + shifted_values[0])/max, shifted_values)) is_shifted = true;
if (check_one((itry + iq3nl_values[7])/max, iq3nl_values )) is_shifted = false;
if (check_one((itry + shifted_values[7])/max, shifted_values)) is_shifted = true;
if (check_one((itry + iq3nl_values[1])/max, iq3nl_values )) is_shifted = false;
if (check_one((itry + shifted_values[1])/max, shifted_values)) is_shifted = true;
}
if (d) {
const int8_t * block_values = is_shifted ? shifted_values : iq3nl_values;