mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
iq4_kss: another small quantization improvement
This commit is contained in:
@@ -2910,7 +2910,6 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
|
||||
const int8_t * shifted_values = values + 16;
|
||||
|
||||
uint16_t vps[block_size/2], vms[block_size/2], vs[block_size/2];
|
||||
//uint8_t Lp[block_size], Lm[block_size], L[block_size];
|
||||
float xv[4], wv[4];
|
||||
|
||||
float amax_scale = 0;
|
||||
@@ -2949,11 +2948,6 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
|
||||
float id = (itry + values[0])/max;
|
||||
float sumqx_p = 0, sumq2_p = 0;
|
||||
float sumqx_m = 0, sumq2_m = 0;
|
||||
//for (int j = 0; j < block_size; ++j) {
|
||||
// float al = id*xb[j];
|
||||
// Lp[j] = best_index_iq4nl(values, al);
|
||||
// Lm[j] = best_index_iq4nl(values, -al);
|
||||
//}
|
||||
float this_d = 1/id;
|
||||
for (int k = 0; k < block_size/4; ++k) {
|
||||
xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2];
|
||||
@@ -2977,24 +2971,7 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
|
||||
}
|
||||
vps[k] = vp;
|
||||
vms[k] = vm;
|
||||
//uint16_t vp = Lp[2*k+0] | (Lp[2*k+0+block_size/2] << 4) | (Lp[2*k+1] << 8) | (Lp[2*k+1+block_size/2] << 12);
|
||||
//uint16_t vm = Lm[2*k+0] | (Lm[2*k+0+block_size/2] << 4) | (Lm[2*k+1] << 8) | (Lm[2*k+1+block_size/2] << 12);
|
||||
//auto vps = prune_iq4ks(vp, values, xv, wv, this_d);
|
||||
//auto vms = prune_iq4ks(vm, values, xv, wv, -this_d);
|
||||
//Lp[2*k+0] = (vps >> 0) & 0xf; Lp[2*k+0+block_size/2] = (vps >> 4) & 0xf;
|
||||
//Lp[2*k+1] = (vps >> 8) & 0xf; Lp[2*k+1+block_size/2] = (vps >> 12) & 0xf;
|
||||
//Lm[2*k+0] = (vms >> 0) & 0xf; Lm[2*k+0+block_size/2] = (vms >> 4) & 0xf;
|
||||
//Lm[2*k+1] = (vms >> 8) & 0xf; Lm[2*k+1+block_size/2] = (vms >> 12) & 0xf;
|
||||
}
|
||||
//for (int j = 0; j < block_size; ++j) {
|
||||
// float w = weight[j];
|
||||
// float q = values[Lp[j]];
|
||||
// sumqx_p += w*q*xb[j];
|
||||
// sumq2_p += w*q*q;
|
||||
// q = values[Lm[j]];
|
||||
// sumqx_m += w*q*xb[j];
|
||||
// sumq2_m += w*q*q;
|
||||
//}
|
||||
bool copy_p = false, copy_m = false;
|
||||
if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
|
||||
d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false; copy_p = true;
|
||||
@@ -3047,58 +3024,7 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
|
||||
} else if (copy_p) {
|
||||
std::memcpy(vs, vps, block_size);
|
||||
}
|
||||
//for (int j = 0; j < block_size; ++j) {
|
||||
// float al = id*xb[j];
|
||||
// Lp[j] = best_index_iq4nl(shifted_values, al);
|
||||
// Lm[j] = best_index_iq4nl(shifted_values, -al);
|
||||
//}
|
||||
//for (int k = 0; k < block_size/4; ++k) {
|
||||
// xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2];
|
||||
// wv[0] = weight[2*k+0]; wv[1] = weight[2*k+0+block_size/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+block_size/2];
|
||||
// uint16_t vp = Lp[2*k+0] | (Lp[2*k+0+block_size/2] << 4) | (Lp[2*k+1] << 8) | (Lp[2*k+1+block_size/2] << 12);
|
||||
// uint16_t vm = Lm[2*k+0] | (Lm[2*k+0+block_size/2] << 4) | (Lm[2*k+1] << 8) | (Lm[2*k+1+block_size/2] << 12);
|
||||
// auto vps = prune_iq4ks(vp, shifted_values, xv, wv, this_d);
|
||||
// auto vms = prune_iq4ks(vm, shifted_values, xv, wv, -this_d);
|
||||
// Lp[2*k+0] = (vps >> 0) & 0xf; Lp[2*k+0+block_size/2] = (vps >> 4) & 0xf;
|
||||
// Lp[2*k+1] = (vps >> 8) & 0xf; Lp[2*k+1+block_size/2] = (vps >> 12) & 0xf;
|
||||
// Lm[2*k+0] = (vms >> 0) & 0xf; Lm[2*k+0+block_size/2] = (vms >> 4) & 0xf;
|
||||
// Lm[2*k+1] = (vms >> 8) & 0xf; Lm[2*k+1+block_size/2] = (vms >> 12) & 0xf;
|
||||
//}
|
||||
//for (int j = 0; j < block_size; ++j) {
|
||||
// float w = weight[j];
|
||||
// float q = shifted_values[Lp[j]];
|
||||
// sumqx_p += w*q*xb[j];
|
||||
// sumq2_p += w*q*q;
|
||||
// q = shifted_values[Lm[j]];
|
||||
// sumqx_m += w*q*xb[j];
|
||||
// sumq2_m += w*q*q;
|
||||
//}
|
||||
//copy_p = copy_m = false;
|
||||
//if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
|
||||
// d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true; copy_p = true;
|
||||
//}
|
||||
//if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
|
||||
// d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; copy_m = true;
|
||||
//}
|
||||
//if (copy_m) {
|
||||
// std::memcpy(L, Lm, block_size);
|
||||
//} else if (copy_p) {
|
||||
// std::memcpy(L, Lp, block_size);
|
||||
//}
|
||||
}
|
||||
for (int k = 0; k < block_size/8; ++k) {
|
||||
auto v1 = table[vs[2*k+0] & 0x7fff];
|
||||
auto v2 = table[vs[2*k+1] & 0x7fff];
|
||||
y[ibl].qs[(block_size/8)*ib + k] = v1 | (v2 << 15);
|
||||
}
|
||||
//for (int k = 0; k < block_size/8; ++k) {
|
||||
// uint16_t v1 = Lp[4*k+0] | (Lp[4*k+0+block_size/2] << 4) | (Lp[4*k+1] << 8) | (Lp[4*k+1+block_size/2] << 12);
|
||||
// uint16_t v2 = Lp[4*k+2] | (Lp[4*k+2+block_size/2] << 4) | (Lp[4*k+3] << 8) | (Lp[4*k+3+block_size/2] << 12);
|
||||
// v1 = table[v1 & 0x7fff];
|
||||
// v2 = table[v2 & 0x7fff];
|
||||
// y[ibl].qs[(block_size/8)*ib + k] = v1 | (v2 << 15);
|
||||
//}
|
||||
if (is_shifted) y[ibl].qs[(block_size/8)*ib] |= (1u << 30);
|
||||
scales[ib] = d;
|
||||
amax_scale = std::max(amax_scale, std::abs(d));
|
||||
}
|
||||
@@ -3107,16 +3033,79 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
|
||||
*dptr = d;
|
||||
if (!d) return;
|
||||
float id = 1/d;
|
||||
float sumqx = 0, sumq2 = 0;
|
||||
for (int ibl = 0; ibl < n_per_row/super_block_size; ++ibl) {
|
||||
auto scales = all_scales + (super_block_size/block_size)*ibl;
|
||||
const float * xbl = x + ibl*super_block_size;
|
||||
float sigma2 = 0;
|
||||
for (int j = 0; j < super_block_size; ++j) sigma2 += xbl[j]*xbl[j];
|
||||
sigma2 *= 2.f/super_block_size;
|
||||
for (int ib = 0; ib < super_block_size/block_size; ++ib) {
|
||||
const float * xb = xbl + ib*block_size;
|
||||
if (quant_weights) {
|
||||
const float * qw = quant_weights + ibl*super_block_size + ib*block_size;
|
||||
for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
|
||||
} else {
|
||||
for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
|
||||
}
|
||||
int l = nearest_int(0.5f*(id*scales[ib]+127.f));
|
||||
l = std::max(0, std::min(127, l)) << 1;
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
y[ibl].qs[4*ib+k] |= (((l >> 2*k) & 3) << 30);
|
||||
l = (std::max(0, std::min(127, l)) << 1) - 127;
|
||||
if (l) {
|
||||
float dl = d*l;
|
||||
float idl = 1/dl;
|
||||
float mse_p = 0, mse_m = 0;
|
||||
for (int k = 0; k < block_size/4; ++k) {
|
||||
xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2];
|
||||
wv[0] = weight[2*k+0]; wv[1] = weight[2*k+0+block_size/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+block_size/2];
|
||||
uint16_t vp = 0, vm = 0;
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
float al = idl*xv[j];
|
||||
vp |= (best_index_iq4nl( values, al) << 4*j);
|
||||
vm |= (best_index_iq4nl(shifted_values, al) << 4*j);
|
||||
}
|
||||
vp = prune_iq4ks(vp, values, xv, wv, dl);
|
||||
vm = prune_iq4ks(vm, shifted_values, xv, wv, dl);
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
float w = wv[j];
|
||||
float q = values[(vp >> 4*j) & 0xf];
|
||||
mse_p += w*(xv[j] - dl*q)*(xv[j] - dl*q);
|
||||
q = shifted_values[(vm >> 4*j) & 0xf];
|
||||
mse_m += w*(xv[j] - dl*q)*(xv[j] - dl*q);
|
||||
}
|
||||
vps[k] = vp;
|
||||
vms[k] = vm;
|
||||
}
|
||||
const uint16_t * v = vps;
|
||||
const int8_t * block_values = values;
|
||||
if (mse_m < mse_p) {
|
||||
v = vms;
|
||||
block_values = values + 16;
|
||||
}
|
||||
for (int k = 0; k < block_size/4; ++k) {
|
||||
xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2];
|
||||
wv[0] = weight[2*k+0]; wv[1] = weight[2*k+0+block_size/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+block_size/2];
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
float q = block_values[(v[k] >> 4*j) & 0xf] * l;
|
||||
sumqx += wv[j]*q*xv[j];
|
||||
sumq2 += wv[j]*q*q;
|
||||
}
|
||||
}
|
||||
l += 127;
|
||||
if (mse_m < mse_p) l |= 1;
|
||||
for (int k = 0; k < block_size/8; ++k) {
|
||||
auto v1 = table[v[2*k+0] & 0x7fff];
|
||||
auto v2 = table[v[2*k+1] & 0x7fff];
|
||||
y[ibl].qs[(block_size/8)*ib + k] = v1 | (v2 << 15) | (((l >> 2*k) & 3) << 30);
|
||||
}
|
||||
} else {
|
||||
l += 127;
|
||||
for (int k = 0; k < block_size/8; ++k) {
|
||||
y[ibl].qs[(block_size/8)*ib + k] |= (((l >> 2*k) & 3) << 30);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (sumq2 > 0) *dptr = sumqx/sumq2;
|
||||
}
|
||||
|
||||
void prune_iq4ks_to_iq4kss(int n_per_row, const uint16_t * table, const char * cx, const float * x, char *cy,
|
||||
|
||||
Reference in New Issue
Block a user