iq4_kss: slightly better quantization

This commit is contained in:
Iwan Kawrakow
2024-10-15 13:49:51 +03:00
parent b159b2b113
commit b68c2cb0e0

View File

@@ -2893,10 +2893,236 @@ uint16_t prune_iq4ks(uint16_t v, const int8_t * values, const float * x, const f
q4[jbest] = bestq;
return (q4[0] | (q4[1] << 4) | (q4[2] << 8) | (q4[3] << 12));
}
static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
float * all_scales, float * weight,
const int8_t * values,
const float * quant_weights,
const uint16_t * table,
const int ntry) {
constexpr int super_block_size = 256;
constexpr int block_size = 32;
float * dptr = (float *)cy;
*dptr = 0;
block_iq4_kss * y = (block_iq4_kss *)(dptr + 1);
const int8_t * shifted_values = values + 16;
uint16_t vps[block_size/2], vms[block_size/2], vs[block_size/2];
//uint8_t Lp[block_size], Lm[block_size], L[block_size];
float xv[4], wv[4];
float amax_scale = 0;
for (int ibl = 0; ibl < n_per_row/super_block_size; ++ibl) {
memset(&y[ibl], 0, sizeof(block_iq4_kss));
const float * xbl = x + ibl*super_block_size;
auto scales = all_scales + ibl*(super_block_size/block_size);
float sigma2 = 0;
for (int j = 0; j < super_block_size; ++j) sigma2 += xbl[j]*xbl[j];
sigma2 *= 2.f/super_block_size;
for (int ib = 0; ib < super_block_size/block_size; ++ib) {
const float * xb = xbl + ib*block_size;
if (quant_weights) {
const float * qw = quant_weights + ibl*super_block_size + ib*block_size;
for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
} else {
for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
}
float amax = 0, max = 0;
for (int j = 0; j < block_size; ++j) {
float ax = fabsf(xb[j]);
if (ax > amax) {
amax = ax; max = xb[j];
}
}
if (!amax) {
scales[ib] = 0;
continue;
}
float best = 0;
bool is_shifted = false;
float d = -max/iq4k_values[0];
std::memset(vs, 0, block_size);
for (int itry = -ntry; itry <= ntry; ++itry) {
float id = (itry + values[0])/max;
float sumqx_p = 0, sumq2_p = 0;
float sumqx_m = 0, sumq2_m = 0;
//for (int j = 0; j < block_size; ++j) {
// float al = id*xb[j];
// Lp[j] = best_index_iq4nl(values, al);
// Lm[j] = best_index_iq4nl(values, -al);
//}
float this_d = 1/id;
for (int k = 0; k < block_size/4; ++k) {
xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2];
wv[0] = weight[2*k+0]; wv[1] = weight[2*k+0+block_size/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+block_size/2];
uint16_t vp = 0, vm = 0;
for (int j = 0; j < 4; ++j) {
float al = id*xv[j];
vp |= (best_index_iq4nl(values, al) << 4*j);
vm |= (best_index_iq4nl(values, -al) << 4*j);
}
vp = prune_iq4ks(vp, values, xv, wv, this_d);
vm = prune_iq4ks(vm, values, xv, wv, this_d);
for (int j = 0; j < 4; ++j) {
float w = wv[j];
float q = values[(vp >> 4*j) & 0xf];
sumqx_p += w*q*xv[j];
sumq2_p += w*q*q;
q = values[(vm >> 4*j) & 0xf];
sumqx_m += w*q*xv[j];
sumq2_m += w*q*q;
}
vps[k] = vp;
vms[k] = vm;
//uint16_t vp = Lp[2*k+0] | (Lp[2*k+0+block_size/2] << 4) | (Lp[2*k+1] << 8) | (Lp[2*k+1+block_size/2] << 12);
//uint16_t vm = Lm[2*k+0] | (Lm[2*k+0+block_size/2] << 4) | (Lm[2*k+1] << 8) | (Lm[2*k+1+block_size/2] << 12);
//auto vps = prune_iq4ks(vp, values, xv, wv, this_d);
//auto vms = prune_iq4ks(vm, values, xv, wv, -this_d);
//Lp[2*k+0] = (vps >> 0) & 0xf; Lp[2*k+0+block_size/2] = (vps >> 4) & 0xf;
//Lp[2*k+1] = (vps >> 8) & 0xf; Lp[2*k+1+block_size/2] = (vps >> 12) & 0xf;
//Lm[2*k+0] = (vms >> 0) & 0xf; Lm[2*k+0+block_size/2] = (vms >> 4) & 0xf;
//Lm[2*k+1] = (vms >> 8) & 0xf; Lm[2*k+1+block_size/2] = (vms >> 12) & 0xf;
}
//for (int j = 0; j < block_size; ++j) {
// float w = weight[j];
// float q = values[Lp[j]];
// sumqx_p += w*q*xb[j];
// sumq2_p += w*q*q;
// q = values[Lm[j]];
// sumqx_m += w*q*xb[j];
// sumq2_m += w*q*q;
//}
bool copy_p = false, copy_m = false;
if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false; copy_p = true;
}
if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false; copy_m = true;
}
if (copy_m) {
std::memcpy(vs, vms, block_size);
} else if (copy_p) {
std::memcpy(vs, vps, block_size);
}
id = (itry + shifted_values[0])/max;
this_d = 1/id;
sumqx_p = sumq2_p = 0;
sumqx_m = sumq2_m = 0;
for (int k = 0; k < block_size/4; ++k) {
xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2];
wv[0] = weight[2*k+0]; wv[1] = weight[2*k+0+block_size/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+block_size/2];
uint16_t vp = 0, vm = 0;
for (int j = 0; j < 4; ++j) {
float al = id*xv[j];
vp |= (best_index_iq4nl(shifted_values, al) << 4*j);
vm |= (best_index_iq4nl(shifted_values, -al) << 4*j);
}
vp = prune_iq4ks(vp, shifted_values, xv, wv, this_d);
vm = prune_iq4ks(vm, shifted_values, xv, wv, this_d);
for (int j = 0; j < 4; ++j) {
float w = wv[j];
float q = shifted_values[(vp >> 4*j) & 0xf];
sumqx_p += w*q*xv[j];
sumq2_p += w*q*q;
q = shifted_values[(vm >> 4*j) & 0xf];
sumqx_m += w*q*xv[j];
sumq2_m += w*q*q;
}
vps[k] = vp;
vms[k] = vm;
}
copy_p = copy_m = false;
if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true; copy_p = true;
}
if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; copy_m = true;
}
if (copy_m) {
std::memcpy(vs, vms, block_size);
} else if (copy_p) {
std::memcpy(vs, vps, block_size);
}
//for (int j = 0; j < block_size; ++j) {
// float al = id*xb[j];
// Lp[j] = best_index_iq4nl(shifted_values, al);
// Lm[j] = best_index_iq4nl(shifted_values, -al);
//}
//for (int k = 0; k < block_size/4; ++k) {
// xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2];
// wv[0] = weight[2*k+0]; wv[1] = weight[2*k+0+block_size/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+block_size/2];
// uint16_t vp = Lp[2*k+0] | (Lp[2*k+0+block_size/2] << 4) | (Lp[2*k+1] << 8) | (Lp[2*k+1+block_size/2] << 12);
// uint16_t vm = Lm[2*k+0] | (Lm[2*k+0+block_size/2] << 4) | (Lm[2*k+1] << 8) | (Lm[2*k+1+block_size/2] << 12);
// auto vps = prune_iq4ks(vp, shifted_values, xv, wv, this_d);
// auto vms = prune_iq4ks(vm, shifted_values, xv, wv, -this_d);
// Lp[2*k+0] = (vps >> 0) & 0xf; Lp[2*k+0+block_size/2] = (vps >> 4) & 0xf;
// Lp[2*k+1] = (vps >> 8) & 0xf; Lp[2*k+1+block_size/2] = (vps >> 12) & 0xf;
// Lm[2*k+0] = (vms >> 0) & 0xf; Lm[2*k+0+block_size/2] = (vms >> 4) & 0xf;
// Lm[2*k+1] = (vms >> 8) & 0xf; Lm[2*k+1+block_size/2] = (vms >> 12) & 0xf;
//}
//for (int j = 0; j < block_size; ++j) {
// float w = weight[j];
// float q = shifted_values[Lp[j]];
// sumqx_p += w*q*xb[j];
// sumq2_p += w*q*q;
// q = shifted_values[Lm[j]];
// sumqx_m += w*q*xb[j];
// sumq2_m += w*q*q;
//}
//copy_p = copy_m = false;
//if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
// d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true; copy_p = true;
//}
//if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
// d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; copy_m = true;
//}
//if (copy_m) {
// std::memcpy(L, Lm, block_size);
//} else if (copy_p) {
// std::memcpy(L, Lp, block_size);
//}
}
for (int k = 0; k < block_size/8; ++k) {
auto v1 = table[vs[2*k+0] & 0x7fff];
auto v2 = table[vs[2*k+1] & 0x7fff];
y[ibl].qs[(block_size/8)*ib + k] = v1 | (v2 << 15);
}
//for (int k = 0; k < block_size/8; ++k) {
// uint16_t v1 = Lp[4*k+0] | (Lp[4*k+0+block_size/2] << 4) | (Lp[4*k+1] << 8) | (Lp[4*k+1+block_size/2] << 12);
// uint16_t v2 = Lp[4*k+2] | (Lp[4*k+2+block_size/2] << 4) | (Lp[4*k+3] << 8) | (Lp[4*k+3+block_size/2] << 12);
// v1 = table[v1 & 0x7fff];
// v2 = table[v2 & 0x7fff];
// y[ibl].qs[(block_size/8)*ib + k] = v1 | (v2 << 15);
//}
if (is_shifted) y[ibl].qs[(block_size/8)*ib] |= (1u << 30);
scales[ib] = d;
amax_scale = std::max(amax_scale, std::abs(d));
}
}
float d = amax_scale/127;
*dptr = d;
if (!d) return;
float id = 1/d;
for (int ibl = 0; ibl < n_per_row/super_block_size; ++ibl) {
auto scales = all_scales + (super_block_size/block_size)*ibl;
for (int ib = 0; ib < super_block_size/block_size; ++ib) {
int l = nearest_int(0.5f*(id*scales[ib]+127.f));
l = std::max(0, std::min(127, l)) << 1;
for (int k = 0; k < 4; ++k) {
y[ibl].qs[4*ib+k] |= (((l >> 2*k) & 3) << 30);
}
}
}
}
void prune_iq4ks_to_iq4kss(int n_per_row, const uint16_t * table, const char * cx, const float * x, char *cy,
const float * quant_weights, float * weight, float * all_scales) {
constexpr int kBlockSize = 32;
float xv[4], w[4];
float xv[4], wv[4];
uint16_t vps[kBlockSize/4];
const float * dptr_ks = (const float *)cx;
const float d_ks = *dptr_ks;
@@ -2925,14 +3151,14 @@ void prune_iq4ks_to_iq4kss(int n_per_row, const uint16_t * table, const char * c
float dl = d_ks * ((iq4ks[ibl].scales[ib] & 254) - 127);
float sumqx = 0, sumq2 = 0;
for (int k = 0; k < kBlockSize/4; ++k) {
xv[0] = xb[2*k+0]; xv[1] = xb[2*k+kBlockSize/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+kBlockSize/2];
w[0] = weight[2*k+0]; w[1] = weight[2*k+kBlockSize/2]; w[2] = weight[2*k+1]; w[3] = weight[2*k+1+kBlockSize/2];
auto vp = prune_iq4ks(q4[k], values, xv, w, dl);
xv[0] = xb[2*k+0]; xv[1] = xb[2*k+kBlockSize/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+kBlockSize/2];
wv[0] = weight[2*k+0]; wv[1] = weight[2*k+kBlockSize/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+kBlockSize/2];
auto vp = prune_iq4ks(q4[k], values, xv, wv, dl);
vps[k] = table[vp & 0x7fff];
for (int j = 0; j < 4; ++j) {
float q = values[(vp >> 4*j) & 0xf];
sumqx += w[j]*q*xv[j];
sumq2 += w[j]*q*q;
sumqx += wv[j]*q*xv[j];
sumq2 += wv[j]*q*q;
}
}
for (int k = 0; k < kBlockSize/8; ++k) {
@@ -2974,8 +3200,7 @@ size_t quantize_iq4_kss(const float * src, void * dst, int64_t nrows, int64_t n_
auto qrow = (char *)dst;
auto table = scramble_table();
for (int row = 0; row < nrows; ++row) {
quantize_row_iq4_k_impl_bs128(QK_K, kBlockSize, n_per_row, src, work.data(), all_scales.data(), weight, iq4k_values, imatrix, 7);
prune_iq4ks_to_iq4kss(n_per_row, table, work.data(), src, qrow, imatrix, weight, all_scales.data());
quantize_row_iq4_kss_impl(n_per_row, src, qrow, all_scales.data(), weight, iq4k_values, imatrix, table, 7);
src += n_per_row;
qrow += row_size;
}