Same for iq4_nl, iq4_xs

This commit is contained in:
Iwan Kawrakow
2025-03-28 07:20:56 +02:00
parent c8d47fab04
commit b9c25fe753

View File

@@ -14642,6 +14642,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
}
d = sumqx/sumq2;
float best = d*sumqx;
float best_sumqx = sumqx, best_sumq2 = sumq2;
for (int itry = -ntry; itry <= ntry; ++itry) {
id = (itry + values[0])/max;
sumqx = sumq2 = 0;
@@ -14655,8 +14656,67 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
}
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
d = sumqx/sumq2; best = d * sumqx;
best_sumqx = sumqx; best_sumq2 = sumq2;
for (int j = 0; j < block_size; ++j) {
float al = id*xb[j];
Lb[j] = best_index_iq4nl(values, al);
}
}
id = (itry + values[15])/max;
sumqx = sumq2 = 0;
for (int j = 0; j < block_size; ++j) {
float al = id*xb[j];
int l = best_index_iq4nl(values, al);
float q = values[l];
float w = weight[j];
sumqx += w*q*xb[j];
sumq2 += w*q*q;
}
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
d = sumqx/sumq2; best = d * sumqx;
best_sumqx = sumqx; best_sumq2 = sumq2;
for (int j = 0; j < block_size; ++j) {
float al = id*xb[j];
Lb[j] = best_index_iq4nl(values, al);
}
}
}
sumqx = best_sumqx; sumq2 = best_sumq2;
for (int iter = 0; iter < 32*block_size; ++iter) {
float min_step = INFINITY;
int best_j = -1; int dir = 0;
for (int j = 0; j < block_size; ++j) {
float w = weight[j];
float g = d * w * (xb[j] - d*values[Lb[j]]);
if (g > 0 && Lb[j] < 15) {
float step = (values[Lb[j]+1] - values[Lb[j]])/g;
if (step < min_step) {
min_step = step; best_j = j; dir = 1;
}
}
else if (g < 0 && Lb[j] > 0) {
float step = (values[Lb[j]-1] - values[Lb[j]])/g;
if (step < min_step) {
min_step = step; best_j = j; dir = -1;
}
}
}
if (best_j < 0) break;
float new_sumqx = sumqx, new_sumq2 = sumq2;
float w = weight[best_j];
new_sumqx += w*xb[best_j]*(values[Lb[best_j]+dir] - values[Lb[best_j]]);
new_sumq2 += w*(values[Lb[best_j]+dir]*values[Lb[best_j]+dir] - values[Lb[best_j]]*values[Lb[best_j]]);
if (new_sumq2 > 0 && new_sumqx*new_sumqx > best*new_sumq2) {
sumqx = new_sumqx; sumq2 = new_sumq2;
d = sumqx/sumq2; best = d*sumqx;
Lb[best_j] += dir;
}
else {
break;
}
}
scales[ib] = d;
float abs_d = fabsf(d);
if (abs_d > amax_scale) {