iq3_k: slightly better quantization

Not much of a difference for most models, but this change
avoids what it looks like a catastrophic failure for DeepSeek-Lite
(PPL is now 7.041 vs 7.314 on main).
This commit is contained in:
Iwan Kawrakow
2025-03-29 09:12:45 +02:00
parent 4819257ce6
commit 56860314c6

View File

@@ -1555,12 +1555,13 @@ inline int best_index_iq3nl(const int8_t * values, float x) {
static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, const float * quant_weights) {
const int ntry = 5;
constexpr int ntry = 3;
block_iq3_k * y = (block_iq3_k *)vy;
float scales[QK_K/16];
float weight[16];
uint8_t L[16];
const int8_t * shifted_values = iq3nl_values + 8;
@@ -1620,7 +1621,7 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c
}
bool is_shifted = false;
for (int itry = -ntry; itry <= ntry; ++itry) {
id = (itry + iq3nl_values[0])/max;
id = (2*itry + iq3nl_values[0])/max;
sumqx_p = sumq2_p = 0;
sumqx_m = sumq2_m = 0;
for (int j = 0; j < 16; ++j) {
@@ -1641,7 +1642,7 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c
if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false;
}
id = (itry + shifted_values[0])/max;
id = (2*itry + shifted_values[0])/max;
sumqx_p = sumq2_p = 0;
sumqx_m = sumq2_m = 0;
for (int j = 0; j < 16; ++j) {
@@ -1663,20 +1664,55 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c
d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true;
}
}
if (d) {
const int8_t * block_values = is_shifted ? shifted_values : iq3nl_values;
float sumqx = 0, sumq2 = 0;
id = 1/d;
if (!d) {
scales[ib] = 0; continue;
}
const int8_t * block_values = is_shifted ? shifted_values : iq3nl_values;
float sumqx = 0, sumq2 = 0;
id = 1/d;
for (int j = 0; j < 16; ++j) {
float w = weight[j];
float al = id*xb[j];
int l = best_index_iq3nl(block_values, al);
L[j] = l;
float q = block_values[l];
sumqx += w*q*xb[j];
sumq2 += w*q*q;
}
if (sumq2 > 0) d = sumqx/sumq2;
float best_d = d;
for (int iter = 0; iter < 128; ++iter) {
float gmax = 0;
int best_j = -1, dir = 0;
for (int j = 0; j < 16; ++j) {
float w = weight[j];
float al = id*xb[j];
int l = best_index_iq3nl(block_values, al);
float q = block_values[l];
sumqx += w*q*xb[j];
sumq2 += w*q*q;
float g = d * w * (xb[j] - d*block_values[L[j]]);
if (g > 0 && L[j] < 7) {
if (g > gmax) {
gmax = g; best_j = j; dir = 1;
}
}
else if (g < 0 && L[j] > 0) {
if (-g > gmax) {
gmax = -g; best_j = j; dir = -1;
}
}
}
if (sumq2 > 0) d = sumqx/sumq2;
if (best_j < 0) break;
float w = weight[best_j];
sumqx += w*xb[best_j]*(block_values[L[best_j]+dir] - block_values[L[best_j]]);
sumq2 += w*(block_values[L[best_j]+dir]*block_values[L[best_j]+dir] - block_values[L[best_j]]*block_values[L[best_j]]);
L[best_j] += dir;
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
best_d = sumqx/sumq2; best = best_d*sumqx;
}
else if (iter > 8) break;
}
scales[ib] = d;
if (is_shifted) extra |= (1 << ib);