iq2_kt: SOTA

We arrive at the following perplexity (PPL) values:
PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.1642
PPL(LLaMA-2-7B,            4096) = 6.3920
This commit is contained in:
Iwan Kawrakow
2024-11-13 07:27:15 +02:00
parent de7fe92833
commit 200a19f18f

View File

@@ -3454,6 +3454,7 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
float sx[8];
int index[8];
auto vid = _mm256_set1_ps(id);
auto add8 = _mm256_set1_epi32(8);
for (int l = 0; l < kNg; ++l) {
auto xl = xb + 8*l;
auto wl = weight + 8*l;
@@ -3462,6 +3463,7 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
auto vbest = _mm256_set1_ps(INFINITY);
auto best_index = _mm256_set1_epi32(-1);
float best = INFINITY; int jbest = -1;
auto idx = add_idx;
for (int j = 0; j < ncluster; j += 8) {
auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx);
for (int i = 0; i < 8; ++i) {
@@ -3474,6 +3476,7 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
_mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
vbest = _mm256_min_ps(vbest, score);
idx = _mm256_add_epi32(idx, add8);
}
_mm256_store_ps(sx, vbest);
_mm256_store_si256((__m256i *)index, best_index);
@@ -3487,8 +3490,7 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
vbest = _mm256_set1_ps(INFINITY);
best_index = _mm256_set1_epi32(-1);
best = INFINITY; jbest = -1;
auto idx = add_idx;
auto add8 = _mm256_set1_epi32(8);
idx = add_idx;
for (int j = 0; j < int(points.size()); j += 8) {
//auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx);
for (int i = 0; i < 8; ++i) {