mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
Testing Trellis quantization
Using 12 bits per 8 weights I get a better rmse than iq2_xxs. I still need to see how quantizing the group-of-8 scales will affect accuracy. By AVX2 SIMDifying the search for the best code, LLaMA-3.1-8B gets quantized in 130 seconds on the Ryzen-7950X CPU - sluggish but still acceptable.
This commit is contained in:
@@ -292,18 +292,18 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
|
||||
constexpr int kBlockSize = 8;
|
||||
static_assert(kNumVal%8 == 0);
|
||||
auto codes = make_values(kNumVal, kBlockSize);
|
||||
std::vector<float> sumq2(kNumVal);
|
||||
std::vector<float> sumq2i(kNumVal);
|
||||
for (int j = 0; j < kNumVal; ++j) {
|
||||
auto data = codes.data() + kBlockSize*j;
|
||||
float sum = 0; for (int k = 0; k < kBlockSize; ++k) sum += data[k]*data[k];
|
||||
sumq2[j] = sum;
|
||||
sumq2i[j] = sum > 0 ? 1/sum : 0.f;;
|
||||
}
|
||||
int nthread = std::max(1, int(std::thread::hardware_concurrency()/2));
|
||||
int chunk = (nrows + 8*nthread - 1)/(8*nthread);
|
||||
std::mutex mutex;
|
||||
int counter = 0;
|
||||
float mse = 0;
|
||||
auto compute = [&mutex, &counter, &mse, &codes, &sumq2, values, nrows, n_per_row, chunk] () {
|
||||
auto compute = [&mutex, &counter, &mse, &codes, &sumq2i, values, nrows, n_per_row, chunk] () {
|
||||
float lmse = 0;
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
@@ -317,7 +317,9 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
|
||||
#ifdef __AVX2__
|
||||
__m256 vx[kBlockSize/8];
|
||||
__m256 sqx[8];
|
||||
__m256i add_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
|
||||
float sx[8];
|
||||
int index[8];
|
||||
#endif
|
||||
for (int row = first; row < last; ++row) {
|
||||
auto xr = values + row*n_per_row;
|
||||
@@ -326,7 +328,10 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
|
||||
auto xb = xr + kBlockSize*ib;
|
||||
#ifdef __AVX2__
|
||||
for (int l = 0; l < kBlockSize/8; ++l) vx[l] = _mm256_loadu_ps(xb+8*l);
|
||||
auto vbest = _mm256_set1_ps(0.f);
|
||||
auto best_index = _mm256_set1_epi32(-1);
|
||||
for (int j = 0; j < kNumVal; j += 8) {
|
||||
auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
sqx[i] = _mm256_setzero_ps();
|
||||
for (int l = 0; l < kBlockSize/8; ++l) {
|
||||
@@ -334,25 +339,40 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
|
||||
sqx[i] = _mm256_fmadd_ps(vx[l], qv, sqx[i]);
|
||||
}
|
||||
}
|
||||
_mm256_storeu_ps(sx, hsum_float_8x8(sqx));
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
if (sumq2[j+i] > 0 && sx[i]*sx[i] > best*sumq2[j+i]) {
|
||||
d = sx[i]/sumq2[j+i]; best = d*sx[i]; jbest = j+i;
|
||||
}
|
||||
}
|
||||
auto sumqx = hsum_float_8x8(sqx);
|
||||
auto score = _mm256_mul_ps(_mm256_mul_ps(sumqx, sumqx), _mm256_loadu_ps(sumq2i.data() + j));
|
||||
auto mask = _mm256_cmp_ps(score, vbest, _CMP_GT_OQ);
|
||||
best_index = _mm256_or_si256(_mm256_and_si256(idx, _mm256_castps_si256(mask)), _mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
|
||||
vbest = _mm256_max_ps(vbest, score);
|
||||
//_mm256_storeu_ps(sx, hsum_float_8x8(sqx));
|
||||
//for (int i = 0; i < 8; ++i) {
|
||||
// if (sx[i]*sx[i]*sumq2i[j+i] > best) {
|
||||
// d = sx[i]*sumq2i[j+i]; best = d*sx[i]; jbest = j+i;
|
||||
// }
|
||||
//}
|
||||
}
|
||||
_mm256_store_ps(sx, vbest);
|
||||
_mm256_store_si256((__m256i *)index, best_index);
|
||||
best = sx[0]; jbest = index[0];
|
||||
for (int j = 1; j < 8; ++j) {
|
||||
if (sx[j] > best) { best = sx[j]; jbest = index[j]; }
|
||||
}
|
||||
auto qv = codes.data() + kBlockSize*jbest;
|
||||
float sumqx = 0;
|
||||
for (int k = 0; k < 8; ++k) sumqx += xb[k]*qv[k];
|
||||
d = sumqx*sumq2i[jbest];
|
||||
#else
|
||||
for (int j = 0; j < kNumVal; ++j) {
|
||||
if (!sumq2[j]) continue;
|
||||
if (!sumq2i[j]) continue;
|
||||
auto qv = codes.data() + kBlockSize*j;
|
||||
float sumqx = 0;
|
||||
for (int k = 0; k < kBlockSize; ++k) sumqx += qv[k]*xb[k];
|
||||
if (sumqx*sumqx > best*sumq2[j]) {
|
||||
d = sumqx/sumq2[j]; best = d*sumqx; jbest = j;
|
||||
if (sumqx*sumqx*sumq2i[j] > best]) {
|
||||
d = sumqx*sumq2i[j]; best = d*sumqx; jbest = j;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
auto qv = codes.data() + kBlockSize*jbest;
|
||||
#endif
|
||||
for (int k = 0; k < kBlockSize; ++k) {
|
||||
float diff = xb[k] - d*qv[k];
|
||||
lmse += diff*diff;
|
||||
|
||||
Reference in New Issue
Block a user