iq2_kt: quantize / dequantize

I now see that I was comparing apples to oranges:
iq2_xxs was using a weight of sigma^2/4 + x^2, while
the Trellis approach wasn't (weight = 1). Once I use the same weight,
iq2_kt is actually slightly worse than iq2_xxs in terms
of rmse, so does not look promising at this point.
Also, once each group of 8 Trellis values no longer has a
constant sum(q^2) that we can precompute, quantization
becomes significantly slower (476 seconds for LLaMA-3.1-8B).
This commit is contained in:
Iwan Kawrakow
2024-11-05 18:50:08 +02:00
parent f1df1b7e15
commit a4f1ac8da4
6 changed files with 281 additions and 2 deletions

View File

@@ -257,6 +257,24 @@ static inline int nearest_int(float fval) {
return (i & 0x007fffff) - 0x00400000;
}
//static void fast_ht(int n, float * values) {
// constexpr float ksqrt2 = 0.707106781f;
// float scale = 1;
// int h = 1;
// while (h < n) {
// for (int i = 0; i < n; i += 2*h) {
// for (int j = i; j < i + h; ++j) {
// float x = values[j], y = values[j + h];
// values[j+0] = x + y;
// values[j+h] = x - y;
// }
// }
// h *= 2;
// scale *= ksqrt2;
// }
// for (int i = 0; i < n; ++i) values[i] *= scale;
//}
static const int8_t scale_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
//static std::vector<float> make_values(int nval, int n_per_val) {
@@ -374,6 +392,7 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
float lmse = 0, lmse_q = 0;
std::vector<float> scales(n_per_row/kBlockSize);
std::vector<int> best_idx(n_per_row/kBlockSize);
//float xtmp[kBlockSize];
while (true) {
std::unique_lock<std::mutex> lock(mutex);
int first = counter; counter += chunk;
@@ -395,8 +414,13 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
for (int ib = 0; ib < n_per_row/kBlockSize; ++ib) {
float best = 0, d = 0; int jbest = -1;
auto xb = xr + kBlockSize*ib;
//std::memcpy(xtmp, xb, kBlockSize*sizeof(float));
//fast_ht(kBlockSize, xtmp);
#ifdef __AVX2__
for (int l = 0; l < kBlockSize/8; ++l) vx[l] = _mm256_loadu_ps(xb+8*l);
for (int l = 0; l < kBlockSize/8; ++l) {
//vx[l] = _mm256_loadu_ps(xtmp+8*l);
vx[l] = _mm256_loadu_ps(xb+8*l);
}
auto vbest = _mm256_set1_ps(0.f);
auto best_index = _mm256_set1_epi32(-1);
for (int j = 0; j < kNumVal; j += 8) {
@@ -422,7 +446,8 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
}
auto qv = codes.data() + kBlockSize*jbest;
float sumqx = 0;
for (int k = 0; k < 8; ++k) sumqx += xb[k]*qv[k];
for (int k = 0; k < kBlockSize; ++k) sumqx += xb[k]*qv[k];
//for (int k = 0; k < kBlockSize; ++k) sumqx += xtmp[k]*qv[k];
d = sumqx*sumq2i[jbest];
#else
for (int j = 0; j < kNumVal; ++j) {
@@ -440,6 +465,7 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
best_idx[ib] = jbest;
for (int k = 0; k < kBlockSize; ++k) {
float diff = xb[k] - d*qv[k];
//float diff = xtmp[k] - d*qv[k];
lmse += diff*diff;
}
}
@@ -458,9 +484,12 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
int ls = best_index_scale(scale_values, id*scales[ib]);
float dl = d * scale_values[ls];
auto xb = xr + kBlockSize*ib;
//std::memcpy(xtmp, xb, kBlockSize*sizeof(float));
//fast_ht(kBlockSize, xtmp);
auto qv = codes.data() + kBlockSize*best_idx[ib];
for (int k = 0; k < kBlockSize; ++k) {
float diff = xb[k] - dl*qv[k];
//float diff = xtmp[k] - dl*qv[k];
lmse_q += diff*diff;
}
}