iq3_kt WIP: slowly improving

PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.8322, which is
starting to be competitive/slightly better than other quants.
This commit is contained in:
Iwan Kawrakow
2024-11-08 18:39:23 +02:00
parent 435eb9bdd3
commit f1fb59b44b

View File

@@ -3145,6 +3145,10 @@ __m256 hsum_float_8x8(__m256 * accm) {
for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
}
__m256 hsum_float_4x8(__m256 * accm) {
for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
}
#endif
template <int block_size, int group_size, int num_bits, int num_clusters>
class QuantizerIQKT {
@@ -3470,64 +3474,87 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
best_idx[l] = jbest;
}
} else {
__m128 sqx[4];
const __m128i add_idx = _mm_set_epi32(3, 2, 1, 0);
const __m128 sign_bit = _mm_set1_ps(-0.f);
float sx[4];
int index[4];
auto vid = _mm_set1_ps(id);
__m256 sqx[4];
const __m256i add_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
const __m256 sign_bit = _mm256_set1_ps(-0.f);
float sx[8];
int index[8];
auto vid_p = _mm256_set1_ps(id);
auto vid_m = _mm256_set1_ps(-id);
for (int l = 0; l < kNg; ++l) {
auto xl = xb + 4*l;
auto wl = weight + 4*l;
auto vx = _mm_mul_ps(vid, _mm_loadu_ps(xl));
auto vw = _mm_loadu_ps(wl);
auto vbest = _mm_set1_ps(INFINITY);
auto best_index = _mm_set1_epi32(-1);
auto vx4 = _mm_loadu_ps(xl);
auto vx_p = _mm256_mul_ps(vid_p, _mm256_set_m128(vx4, vx4));
//auto vx_m = _mm256_mul_ps(vid_m, _mm256_set_m128(vx4, vx4));
auto vw4 = _mm_loadu_ps(wl);
auto vw = _mm256_set_m128(vw4, vw4);
auto vbest = _mm256_set1_ps(INFINITY);
auto best_index = _mm256_set1_epi32(-1);
float best = INFINITY; int jbest = -1;
for (int j = 0; j < ncluster; j += 4) {
auto idx = _mm_add_epi32(_mm_set1_epi32(j), add_idx);
for (int j = 0; j < ncluster; j += 8) {
auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx);
for (int i = 0; i < 4; ++i) {
auto vq = _mm_loadu_ps(m_clusters.data() + kGroupSize*(j+i));
auto vdiff = _mm_sub_ps(vq, vx);
auto vq = _mm256_set_m128(_mm_loadu_ps(m_clusters.data() + kGroupSize*(j+i+4)), _mm_loadu_ps(m_clusters.data() + kGroupSize*(j+i)));
auto vdiff = _mm256_sub_ps(vq, vx_p);
//sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, vdiff));
vdiff = _mm_andnot_ps(sign_bit, vdiff);
sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, _mm_mul_ps(vdiff, vdiff)));
vdiff = _mm256_andnot_ps(sign_bit, vdiff);
sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff)));
//vdiff = _mm256_sub_ps(vq, vx_m);
////sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, vdiff));
//vdiff = _mm256_andnot_ps(sign_bit, vdiff);
//sqx[i+4] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff)));
}
auto score = hsum_float_4x4(sqx);
auto mask = _mm_cmp_ps(score, vbest, _CMP_LT_OQ);
best_index = _mm_or_si128(_mm_and_si128(_mm_castps_si128(mask), idx),
_mm_andnot_si128(_mm_castps_si128(mask), best_index));
vbest = _mm_min_ps(vbest, score);
auto score = hsum_float_4x8(sqx);
auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ);
best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
_mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
vbest = _mm256_min_ps(vbest, score);
//score = hsum_float_4x8(sqx+4);
//mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ);
//best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
// _mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
//vbest = _mm256_min_ps(vbest, score);
}
_mm_store_ps(sx, vbest);
_mm_store_si128((__m128i *)index, best_index);
for (int i = 0; i < 4; ++i) {
_mm256_store_ps(sx, vbest);
_mm256_store_si256((__m256i *)index, best_index);
for (int i = 0; i < 8; ++i) {
if (sx[i] < best) { best = sx[i]; jbest = index[i]; }
}
auto& points = m_in_cluster[jbest];
GGML_ASSERT(!points.empty() && points.size()%4 == 0);
GGML_ASSERT(!points.empty() && points.size()%8 == 0);
int jbest_cluster = jbest;
vbest = _mm_set1_ps(INFINITY);
best_index = _mm_set1_epi32(-1);
vbest = _mm256_set1_ps(INFINITY);
best_index = _mm256_set1_epi32(-1);
best = INFINITY; jbest = -1;
for (int j = 0; j < int(points.size()); j += 4) {
auto idx = _mm_loadu_si128((const __m128i*)(points.data() + j));
for (int j = 0; j < int(points.size()); j += 8) {
auto idx = _mm256_loadu_si256((const __m256i*)(points.data() + j));
for (int i = 0; i < 4; ++i) {
auto vq = _mm_loadu_ps(m_values.data() + kGroupSize*points[j+i]);
auto vdiff = _mm_sub_ps(vq, vx);
auto vq = _mm256_set_m128(_mm_loadu_ps(m_values.data() + kGroupSize*points[j+i+4]),
_mm_loadu_ps(m_values.data() + kGroupSize*points[j+i+0]));
auto vdiff = _mm256_sub_ps(vq, vx_p);
//sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, vdiff));
vdiff = _mm_andnot_ps(sign_bit, vdiff);
sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, _mm_mul_ps(vdiff, vdiff)));
vdiff = _mm256_andnot_ps(sign_bit, vdiff);
sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff)));
//vdiff = _mm256_sub_ps(vq, vx_m);
////sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, vdiff));
//vdiff = _mm256_andnot_ps(sign_bit, vdiff);
//sqx[i+4] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff)));
}
auto score = hsum_float_4x4(sqx);
auto mask = _mm_cmp_ps(score, vbest, _CMP_LT_OQ);
best_index = _mm_or_si128(_mm_and_si128(_mm_castps_si128(mask), idx),
_mm_andnot_si128(_mm_castps_si128(mask), best_index));
vbest = _mm_min_ps(vbest, score);
auto score = hsum_float_4x8(sqx);
auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ);
best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
_mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
vbest = _mm256_min_ps(vbest, score);
//score = hsum_float_4x8(sqx+4);
//mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ);
//best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
// _mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
//vbest = _mm256_min_ps(vbest, score);
}
_mm_store_ps(sx, vbest);
_mm_store_si128((__m128i *)index, best_index);
for (int i = 0; i < 4; ++i) {
_mm256_store_ps(sx, vbest);
_mm256_store_si256((__m256i *)index, best_index);
for (int i = 0; i < 8; ++i) {
if (sx[i] < best) { best = sx[i]; jbest = index[i]; }
}
if (jbest < 0) {
@@ -3962,12 +3989,19 @@ void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const f
scales[ib] = 0;
if (!amax) continue;
float best = 0;
for (int itry = -5; itry <= 5; ++itry) {
//for (int itry = -5; itry <= 5; ++itry) {
for (int itry = -3; itry <= 3; ++itry) {
quantizer.find_best_match(amax/(96.f + 4.f*itry), xb, weight, best_idx);
auto [d, score] = quantizer.find_best_scale(xb, weight, best_idx);
if (score > best) {
best = score;
scales[ib] = d;
auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx);
if (score_p > best) {
best = score_p;
scales[ib] = dp;
}
quantizer.find_best_match(-amax/(96.f + 4.f*itry), xb, weight, best_idx);
auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx);
if (score_m > best) {
best = score_m;
scales[ib] = dm;
}
}
//float d = amax/96.f;