iq4_kss: another small quantization improvement

This commit is contained in:
Iwan Kawrakow
2024-10-15 14:34:11 +03:00
parent b68c2cb0e0
commit bb0e3f957e

View File

@@ -2910,7 +2910,6 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
const int8_t * shifted_values = values + 16;
uint16_t vps[block_size/2], vms[block_size/2], vs[block_size/2];
//uint8_t Lp[block_size], Lm[block_size], L[block_size];
float xv[4], wv[4];
float amax_scale = 0;
@@ -2949,11 +2948,6 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
float id = (itry + values[0])/max;
float sumqx_p = 0, sumq2_p = 0;
float sumqx_m = 0, sumq2_m = 0;
//for (int j = 0; j < block_size; ++j) {
// float al = id*xb[j];
// Lp[j] = best_index_iq4nl(values, al);
// Lm[j] = best_index_iq4nl(values, -al);
//}
float this_d = 1/id;
for (int k = 0; k < block_size/4; ++k) {
xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2];
@@ -2977,24 +2971,7 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
}
vps[k] = vp;
vms[k] = vm;
//uint16_t vp = Lp[2*k+0] | (Lp[2*k+0+block_size/2] << 4) | (Lp[2*k+1] << 8) | (Lp[2*k+1+block_size/2] << 12);
//uint16_t vm = Lm[2*k+0] | (Lm[2*k+0+block_size/2] << 4) | (Lm[2*k+1] << 8) | (Lm[2*k+1+block_size/2] << 12);
//auto vps = prune_iq4ks(vp, values, xv, wv, this_d);
//auto vms = prune_iq4ks(vm, values, xv, wv, -this_d);
//Lp[2*k+0] = (vps >> 0) & 0xf; Lp[2*k+0+block_size/2] = (vps >> 4) & 0xf;
//Lp[2*k+1] = (vps >> 8) & 0xf; Lp[2*k+1+block_size/2] = (vps >> 12) & 0xf;
//Lm[2*k+0] = (vms >> 0) & 0xf; Lm[2*k+0+block_size/2] = (vms >> 4) & 0xf;
//Lm[2*k+1] = (vms >> 8) & 0xf; Lm[2*k+1+block_size/2] = (vms >> 12) & 0xf;
}
//for (int j = 0; j < block_size; ++j) {
// float w = weight[j];
// float q = values[Lp[j]];
// sumqx_p += w*q*xb[j];
// sumq2_p += w*q*q;
// q = values[Lm[j]];
// sumqx_m += w*q*xb[j];
// sumq2_m += w*q*q;
//}
bool copy_p = false, copy_m = false;
if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false; copy_p = true;
@@ -3047,58 +3024,7 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
} else if (copy_p) {
std::memcpy(vs, vps, block_size);
}
//for (int j = 0; j < block_size; ++j) {
// float al = id*xb[j];
// Lp[j] = best_index_iq4nl(shifted_values, al);
// Lm[j] = best_index_iq4nl(shifted_values, -al);
//}
//for (int k = 0; k < block_size/4; ++k) {
// xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2];
// wv[0] = weight[2*k+0]; wv[1] = weight[2*k+0+block_size/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+block_size/2];
// uint16_t vp = Lp[2*k+0] | (Lp[2*k+0+block_size/2] << 4) | (Lp[2*k+1] << 8) | (Lp[2*k+1+block_size/2] << 12);
// uint16_t vm = Lm[2*k+0] | (Lm[2*k+0+block_size/2] << 4) | (Lm[2*k+1] << 8) | (Lm[2*k+1+block_size/2] << 12);
// auto vps = prune_iq4ks(vp, shifted_values, xv, wv, this_d);
// auto vms = prune_iq4ks(vm, shifted_values, xv, wv, -this_d);
// Lp[2*k+0] = (vps >> 0) & 0xf; Lp[2*k+0+block_size/2] = (vps >> 4) & 0xf;
// Lp[2*k+1] = (vps >> 8) & 0xf; Lp[2*k+1+block_size/2] = (vps >> 12) & 0xf;
// Lm[2*k+0] = (vms >> 0) & 0xf; Lm[2*k+0+block_size/2] = (vms >> 4) & 0xf;
// Lm[2*k+1] = (vms >> 8) & 0xf; Lm[2*k+1+block_size/2] = (vms >> 12) & 0xf;
//}
//for (int j = 0; j < block_size; ++j) {
// float w = weight[j];
// float q = shifted_values[Lp[j]];
// sumqx_p += w*q*xb[j];
// sumq2_p += w*q*q;
// q = shifted_values[Lm[j]];
// sumqx_m += w*q*xb[j];
// sumq2_m += w*q*q;
//}
//copy_p = copy_m = false;
//if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
// d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true; copy_p = true;
//}
//if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
// d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; copy_m = true;
//}
//if (copy_m) {
// std::memcpy(L, Lm, block_size);
//} else if (copy_p) {
// std::memcpy(L, Lp, block_size);
//}
}
for (int k = 0; k < block_size/8; ++k) {
auto v1 = table[vs[2*k+0] & 0x7fff];
auto v2 = table[vs[2*k+1] & 0x7fff];
y[ibl].qs[(block_size/8)*ib + k] = v1 | (v2 << 15);
}
//for (int k = 0; k < block_size/8; ++k) {
// uint16_t v1 = Lp[4*k+0] | (Lp[4*k+0+block_size/2] << 4) | (Lp[4*k+1] << 8) | (Lp[4*k+1+block_size/2] << 12);
// uint16_t v2 = Lp[4*k+2] | (Lp[4*k+2+block_size/2] << 4) | (Lp[4*k+3] << 8) | (Lp[4*k+3+block_size/2] << 12);
// v1 = table[v1 & 0x7fff];
// v2 = table[v2 & 0x7fff];
// y[ibl].qs[(block_size/8)*ib + k] = v1 | (v2 << 15);
//}
if (is_shifted) y[ibl].qs[(block_size/8)*ib] |= (1u << 30);
scales[ib] = d;
amax_scale = std::max(amax_scale, std::abs(d));
}
@@ -3107,16 +3033,79 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
*dptr = d;
if (!d) return;
float id = 1/d;
float sumqx = 0, sumq2 = 0;
for (int ibl = 0; ibl < n_per_row/super_block_size; ++ibl) {
auto scales = all_scales + (super_block_size/block_size)*ibl;
const float * xbl = x + ibl*super_block_size;
float sigma2 = 0;
for (int j = 0; j < super_block_size; ++j) sigma2 += xbl[j]*xbl[j];
sigma2 *= 2.f/super_block_size;
for (int ib = 0; ib < super_block_size/block_size; ++ib) {
const float * xb = xbl + ib*block_size;
if (quant_weights) {
const float * qw = quant_weights + ibl*super_block_size + ib*block_size;
for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
} else {
for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
}
int l = nearest_int(0.5f*(id*scales[ib]+127.f));
l = std::max(0, std::min(127, l)) << 1;
for (int k = 0; k < 4; ++k) {
y[ibl].qs[4*ib+k] |= (((l >> 2*k) & 3) << 30);
l = (std::max(0, std::min(127, l)) << 1) - 127;
if (l) {
float dl = d*l;
float idl = 1/dl;
float mse_p = 0, mse_m = 0;
for (int k = 0; k < block_size/4; ++k) {
xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2];
wv[0] = weight[2*k+0]; wv[1] = weight[2*k+0+block_size/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+block_size/2];
uint16_t vp = 0, vm = 0;
for (int j = 0; j < 4; ++j) {
float al = idl*xv[j];
vp |= (best_index_iq4nl( values, al) << 4*j);
vm |= (best_index_iq4nl(shifted_values, al) << 4*j);
}
vp = prune_iq4ks(vp, values, xv, wv, dl);
vm = prune_iq4ks(vm, shifted_values, xv, wv, dl);
for (int j = 0; j < 4; ++j) {
float w = wv[j];
float q = values[(vp >> 4*j) & 0xf];
mse_p += w*(xv[j] - dl*q)*(xv[j] - dl*q);
q = shifted_values[(vm >> 4*j) & 0xf];
mse_m += w*(xv[j] - dl*q)*(xv[j] - dl*q);
}
vps[k] = vp;
vms[k] = vm;
}
const uint16_t * v = vps;
const int8_t * block_values = values;
if (mse_m < mse_p) {
v = vms;
block_values = values + 16;
}
for (int k = 0; k < block_size/4; ++k) {
xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2];
wv[0] = weight[2*k+0]; wv[1] = weight[2*k+0+block_size/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+block_size/2];
for (int j = 0; j < 4; ++j) {
float q = block_values[(v[k] >> 4*j) & 0xf] * l;
sumqx += wv[j]*q*xv[j];
sumq2 += wv[j]*q*q;
}
}
l += 127;
if (mse_m < mse_p) l |= 1;
for (int k = 0; k < block_size/8; ++k) {
auto v1 = table[v[2*k+0] & 0x7fff];
auto v2 = table[v[2*k+1] & 0x7fff];
y[ibl].qs[(block_size/8)*ib + k] = v1 | (v2 << 15) | (((l >> 2*k) & 3) << 30);
}
} else {
l += 127;
for (int k = 0; k < block_size/8; ++k) {
y[ibl].qs[(block_size/8)*ib + k] |= (((l >> 2*k) & 3) << 30);
}
}
}
}
if (sumq2 > 0) *dptr = sumqx/sumq2;
}
void prune_iq4ks_to_iq4kss(int n_per_row, const uint16_t * table, const char * cx, const float * x, char *cy,