Testing Trellis quantization

Using 12 bits per 8 weights I get a better rmse than
iq2_xxs. I still need to see how quantizing the group-of-8
scales will affect accuracy. By AVX2 SIMDifying the search
for the best code, LLaMA-3.1-8B gets quantized in 130 seconds
on the Ryzen-7950X CPU - sluggish but still acceptable.
This commit is contained in:
Iwan Kawrakow
2024-11-05 14:11:14 +02:00
parent afe9db7143
commit f21dd3fb15

View File

@@ -292,18 +292,18 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
constexpr int kBlockSize = 8;
static_assert(kNumVal%8 == 0);
auto codes = make_values(kNumVal, kBlockSize);
std::vector<float> sumq2(kNumVal);
std::vector<float> sumq2i(kNumVal);
for (int j = 0; j < kNumVal; ++j) {
auto data = codes.data() + kBlockSize*j;
float sum = 0; for (int k = 0; k < kBlockSize; ++k) sum += data[k]*data[k];
sumq2[j] = sum;
sumq2i[j] = sum > 0 ? 1/sum : 0.f;;
}
int nthread = std::max(1, int(std::thread::hardware_concurrency()/2));
int chunk = (nrows + 8*nthread - 1)/(8*nthread);
std::mutex mutex;
int counter = 0;
float mse = 0;
auto compute = [&mutex, &counter, &mse, &codes, &sumq2, values, nrows, n_per_row, chunk] () {
auto compute = [&mutex, &counter, &mse, &codes, &sumq2i, values, nrows, n_per_row, chunk] () {
float lmse = 0;
while (true) {
std::unique_lock<std::mutex> lock(mutex);
@@ -317,7 +317,9 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
#ifdef __AVX2__
__m256 vx[kBlockSize/8];
__m256 sqx[8];
__m256i add_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
float sx[8];
int index[8];
#endif
for (int row = first; row < last; ++row) {
auto xr = values + row*n_per_row;
@@ -326,7 +328,10 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
auto xb = xr + kBlockSize*ib;
#ifdef __AVX2__
for (int l = 0; l < kBlockSize/8; ++l) vx[l] = _mm256_loadu_ps(xb+8*l);
auto vbest = _mm256_set1_ps(0.f);
auto best_index = _mm256_set1_epi32(-1);
for (int j = 0; j < kNumVal; j += 8) {
auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx);
for (int i = 0; i < 8; ++i) {
sqx[i] = _mm256_setzero_ps();
for (int l = 0; l < kBlockSize/8; ++l) {
@@ -334,25 +339,40 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
sqx[i] = _mm256_fmadd_ps(vx[l], qv, sqx[i]);
}
}
_mm256_storeu_ps(sx, hsum_float_8x8(sqx));
for (int i = 0; i < 8; ++i) {
if (sumq2[j+i] > 0 && sx[i]*sx[i] > best*sumq2[j+i]) {
d = sx[i]/sumq2[j+i]; best = d*sx[i]; jbest = j+i;
}
}
auto sumqx = hsum_float_8x8(sqx);
auto score = _mm256_mul_ps(_mm256_mul_ps(sumqx, sumqx), _mm256_loadu_ps(sumq2i.data() + j));
auto mask = _mm256_cmp_ps(score, vbest, _CMP_GT_OQ);
best_index = _mm256_or_si256(_mm256_and_si256(idx, _mm256_castps_si256(mask)), _mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
vbest = _mm256_max_ps(vbest, score);
//_mm256_storeu_ps(sx, hsum_float_8x8(sqx));
//for (int i = 0; i < 8; ++i) {
// if (sx[i]*sx[i]*sumq2i[j+i] > best) {
// d = sx[i]*sumq2i[j+i]; best = d*sx[i]; jbest = j+i;
// }
//}
}
_mm256_store_ps(sx, vbest);
_mm256_store_si256((__m256i *)index, best_index);
best = sx[0]; jbest = index[0];
for (int j = 1; j < 8; ++j) {
if (sx[j] > best) { best = sx[j]; jbest = index[j]; }
}
auto qv = codes.data() + kBlockSize*jbest;
float sumqx = 0;
for (int k = 0; k < 8; ++k) sumqx += xb[k]*qv[k];
d = sumqx*sumq2i[jbest];
#else
for (int j = 0; j < kNumVal; ++j) {
if (!sumq2[j]) continue;
if (!sumq2i[j]) continue;
auto qv = codes.data() + kBlockSize*j;
float sumqx = 0;
for (int k = 0; k < kBlockSize; ++k) sumqx += qv[k]*xb[k];
if (sumqx*sumqx > best*sumq2[j]) {
d = sumqx/sumq2[j]; best = d*sumqx; jbest = j;
if (sumqx*sumqx*sumq2i[j] > best]) {
d = sumqx*sumq2i[j]; best = d*sumqx; jbest = j;
}
}
#endif
auto qv = codes.data() + kBlockSize*jbest;
#endif
for (int k = 0; k < kBlockSize; ++k) {
float diff = xb[k] - d*qv[k];
lmse += diff*diff;