This commit is contained in:
Iwan Kawrakow
2024-11-11 18:51:50 +02:00
parent 1d6ca83203
commit 21ee589996

View File

@@ -3170,6 +3170,7 @@ public:
// Selects, for the block xb with per-element weights `weight`, one codebook entry per
// group and writes the chosen indices to best_idx; d is the trial block scale supplied
// by the caller. NOTE(review): the definition is not visible here - confirm the exact
// matching criterion against the implementation.
inline void find_best_match(float d, const float * xb, const float * weight, int * best_idx) const;
// Overload without a caller-supplied scale; also writes its result to best_idx.
// NOTE(review): presumably derives its own starting scale - confirm against the definition.
inline void find_best_match(const float * xb, const float * weight, int * best_idx) const;
// Returns {scale, score} for the entries selected by best_idx. Callers keep the candidate
// with the largest score and store its scale. Returns {0, 0} when the fit's denominator
// is not positive.
inline std::pair<float, float> find_best_scale(const float * xb, const float * weight, const int * best_idx) const;
// Returns sum(w*q*x)/sum(w*x*x) over the block, where q are the m_values entries selected
// by best_idx, i.e. the factor s minimizing sum(w*(q - s*x)^2). Returns 0 when sum(w*x*x)
// is not positive.
inline float find_best_inverse_scale(const float * xb, const float * weight, const int * best_idx) const;
static inline void set_values(uint32_t i, float * result, float scale) {
constexpr uint32_t ka = 89226354;
@@ -3262,6 +3263,39 @@ std::pair<float, float> QuantizerIQKT<block_size, group_size, num_bits, num_clus
return sumq2 > 0 ? std::make_pair(sumqx/sumq2, sumqx*sumqx/sumq2) : std::make_pair(0.f, 0.f);
}
// Weighted least-squares fit of the factor s that maps the input onto the chosen
// codebook entries, i.e. the s minimizing sum_j w[j]*(q[j] - s*x[j])^2:
//
//     s = sum_j w[j]*q[j]*x[j] / sum_j w[j]*x[j]*x[j]
//
//   xb       - kBlockSize input values
//   weight   - kBlockSize per-element weights
//   best_idx - one index per group of kGroupSize elements; entry g selects the
//              kGroupSize floats starting at m_values[kGroupSize*best_idx[g]]
//
// Returns s, or 0 when the weighted energy of xb is not positive.
template <int block_size, int group_size, int num_bits, int num_clusters>
float QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_inverse_scale(
        const float * xb, const float * weight, const int * best_idx) const {
    auto values = m_values.data();
    float sum_wxq = 0;
    float sum_wxx = 0;
#ifdef __AVX2__
    __m256 acc_wxq = _mm256_setzero_ps();
    __m256 acc_wxx = _mm256_setzero_ps();
    for (int offset = 0; offset < kBlockSize; offset += 8) {
        const int group = offset/kGroupSize;
        __m256 q;
        if (kGroupSize == 8) {
            // One group fills the whole 8-lane register.
            q = _mm256_loadu_ps(values + kGroupSize*best_idx[group]);
        } else {
            // Two consecutive groups are stitched together: the first goes to the low lanes.
            const __m128 lo = _mm_loadu_ps(values + kGroupSize*best_idx[group + 0]);
            const __m128 hi = _mm_loadu_ps(values + kGroupSize*best_idx[group + 1]);
            q = _mm256_set_m128(hi, lo);
        }
        const __m256 x  = _mm256_loadu_ps(xb + offset);
        const __m256 w  = _mm256_loadu_ps(weight + offset);
        const __m256 wx = _mm256_mul_ps(x, w);
        acc_wxx = _mm256_fmadd_ps(wx, x, acc_wxx);
        acc_wxq = _mm256_fmadd_ps(wx, q, acc_wxq);
    }
    sum_wxq = hsum_float_8(acc_wxq);
    sum_wxx = hsum_float_8(acc_wxx);
#else
    for (int group = 0; group < kNg; ++group) {
        const int base = kGroupSize*group;
        auto q = values + kGroupSize*best_idx[group];
        for (int k = 0; k < kGroupSize; ++k) {
            const float w = weight[base + k];
            const float x = xb[base + k];
            sum_wxq += w*q[k]*x;
            sum_wxx += w*x*x;
        }
    }
#endif
    if (!(sum_wxx > 0)) {
        return 0.f;
    }
    return sum_wxq/sum_wxx;
}
template <int block_size, int group_size, int num_bits, int num_clusters>
void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_match(const float * xb, const float * weight, int * best_idx) const {
int ncluster = m_clusters.size()/kGroupSize;
@@ -3483,7 +3517,8 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
} else {
__m256 sqx[4];
const __m256i add_idx = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
const __m256 sign_bit = _mm256_set1_ps(-0.f);
//const __m256 sign_bit = _mm256_set1_ps(-0.f);
const __m256 sign_bit = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffff));
float sx[8];
int index[8];
auto vid_p = _mm256_set1_ps(id);
@@ -3505,7 +3540,8 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
auto vdiff = _mm256_sub_ps(vq, vx_p);
//vdiff = _mm256_mul_ps(vdiff, vdiff);
//sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff));
vdiff = _mm256_andnot_ps(sign_bit, vdiff);
//vdiff = _mm256_andnot_ps(sign_bit, vdiff);
vdiff = _mm256_and_ps(sign_bit, vdiff);
sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff)));
}
auto score = hsum_float_4x8(sqx);
@@ -3534,7 +3570,8 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
auto vdiff = _mm256_sub_ps(vq, vx_p);
//vdiff = _mm256_mul_ps(vdiff, vdiff);
//sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff));
vdiff = _mm256_andnot_ps(sign_bit, vdiff);
//vdiff = _mm256_andnot_ps(sign_bit, vdiff);
vdiff = _mm256_and_ps(sign_bit, vdiff);
sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff)));
}
auto score = hsum_float_4x8(sqx);
@@ -4168,7 +4205,7 @@ void vec_dot_iq3_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx
namespace{
using QuantizerIQ4KT = QuantizerIQKT<64, 4, 16, 128>;
using QuantizerIQ4KT = QuantizerIQKT<64, 4, 16, 512>;
const QuantizerIQ4KT& iq4kt_quantizer() {
static std::mutex mutex;
@@ -4238,21 +4275,46 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f
ql += Q::kNg;
continue;
}
//float scale_0 = 127.f*amax/amax_row;
//float scale_0 = std::max(64.f, 127.f*amax/amax_row);
float best = 0;
float scale_0 = std::max(92.f, 127.f*amax/amax_row);
//float scale_0 = row_scale;
quantizer.find_best_match( amax/scale_0, xb, weight, best_idx);
auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx);
quantizer.find_best_match(-amax/scale_0, xb, weight, best_idx+Q::kNg);
auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx+Q::kNg);
if (score_p > score_m) {
scales[ib] = dp;
for (int j = 0; j < Q::kNg; ++j) ql[j] = best_idx[j];
} else {
scales[ib] = dm;
for (int j = 0; j < Q::kNg; ++j) ql[j] = best_idx[j+Q::kNg];
//float scale_0 = 96.f;
for (int itry = -2; itry <= 2; ++itry) {
quantizer.find_best_match( amax/(8.f*itry + scale_0), xb, weight, best_idx);
auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx);
if (score_p > best) {
best = score_p; scales[ib] = dp;
for (int j = 0; j < Q::kNg; ++j) ql[j] = best_idx[j];
}
quantizer.find_best_match(-amax/(8.f*itry + scale_0), xb, weight, best_idx);
auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx);
if (score_m > best) {
best = score_m; scales[ib] = dm;
for (int j = 0; j < Q::kNg; ++j) ql[j] = best_idx[j];
}
}
//for (int j = 0; j < Q::kNg; ++j) best_idx[j] = ql[j];
//auto inv_scale = quantizer.find_best_inverse_scale(xb, weight, best_idx);
//if (inv_scale) {
// quantizer.find_best_match(1/inv_scale, xb, weight, best_idx);
// auto [d, score] = quantizer.find_best_scale(xb, weight, best_idx);
// if (score > best) {
// if (score > 1.02f*best) printf("New best match: %g vs %g, score is %g vs %g\n", d, scales[ib], score, best);
// scales[ib] = d;
// for (int j = 0; j < Q::kNg; ++j) ql[j] = best_idx[j];
// }
//}
////float scale_0 = row_scale;
//quantizer.find_best_match( amax/scale_0, xb, weight, best_idx);
//auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx);
//quantizer.find_best_match(-amax/scale_0, xb, weight, best_idx+Q::kNg);
//auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx+Q::kNg);
//if (score_p > score_m) {
// scales[ib] = dp;
// for (int j = 0; j < Q::kNg; ++j) ql[j] = best_idx[j];
//} else {
// scales[ib] = dm;
// for (int j = 0; j < Q::kNg; ++j) ql[j] = best_idx[j+Q::kNg];
//}
//float mse = 0;
//for (int j = 0; j < Q::kNg; ++j) {
@@ -4317,6 +4379,9 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f
int ls = y[ibl].scales[ib];
float dl = d*ls;
quantizer.find_best_match(dl, xb, weight, best_idx);
float dnew = quantizer.find_best_scale(xb, weight, best_idx).first;
ls = std::max(-128, std::min(127, nearest_int(dnew/d)));
y[ibl].scales[ib] = ls;
for (int j = 0; j < Q::kNg; ++j) {
qs[j] = best_idx[j];