diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index ed628813..4baa11c1 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -3170,6 +3170,7 @@ public: inline void find_best_match(float d, const float * xb, const float * weight, int * best_idx) const; inline void find_best_match(const float * xb, const float * weight, int * best_idx) const; inline std::pair find_best_scale(const float * xb, const float * weight, const int * best_idx) const; + inline float find_best_inverse_scale(const float * xb, const float * weight, const int * best_idx) const; static inline void set_values(uint32_t i, float * result, float scale) { constexpr uint32_t ka = 89226354; @@ -3262,6 +3263,39 @@ std::pair QuantizerIQKT 0 ? std::make_pair(sumqx/sumq2, sumqx*sumqx/sumq2) : std::make_pair(0.f, 0.f); } +template +float QuantizerIQKT::find_best_inverse_scale( + const float * xb, const float * weight, const int * best_idx) const { + float sumqx = 0, sumx2 = 0; +#ifdef __AVX2__ + auto vqx = _mm256_setzero_ps(); + auto vx2 = _mm256_setzero_ps(); + for (int l = 0; l < kBlockSize; l += 8) { + auto vx = _mm256_loadu_ps(xb+l); + auto vw = _mm256_loadu_ps(weight+l); + auto vq = kGroupSize == 8 ? _mm256_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize]) : + _mm256_set_m128(_mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+1]), + _mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+0])); + auto vxw = _mm256_mul_ps(vx, vw); + vx2 = _mm256_fmadd_ps(vxw, vx, vx2); + vqx = _mm256_fmadd_ps(vxw, vq, vqx); + } + sumqx = hsum_float_8(vqx); + sumx2 = hsum_float_8(vx2); +#else + for (int l = 0; l < kNg; ++l) { + auto xl = xb + kGroupSize*l; + auto wl = weight + kGroupSize*l; + auto ql = m_values.data() + kGroupSize*best_idx[l]; + for (int k = 0; k < kGroupSize; ++k) { + sumqx += wl[k]*ql[k]*xl[k]; + sumx2 += wl[k]*xl[k]*xl[k]; + } + } +#endif + return sumx2 > 0 ? sumqx/sumx2 : 0.f; +} + template void QuantizerIQKT::find_best_match(const float * xb, const float * weight, int * best_idx) const { int ncluster = m_clusters.size()/kGroupSize; @@ -3483,7 +3517,8 @@ void QuantizerIQKT::find_best_ma } else { __m256 sqx[4]; const __m256i add_idx = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); - const __m256 sign_bit = _mm256_set1_ps(-0.f); + //const __m256 sign_bit = _mm256_set1_ps(-0.f); + const __m256 sign_bit = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffff)); float sx[8]; int index[8]; auto vid_p = _mm256_set1_ps(id); @@ -3505,7 +3540,8 @@ void QuantizerIQKT::find_best_ma auto vdiff = _mm256_sub_ps(vq, vx_p); //vdiff = _mm256_mul_ps(vdiff, vdiff); //sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff)); - vdiff = _mm256_andnot_ps(sign_bit, vdiff); + //vdiff = _mm256_andnot_ps(sign_bit, vdiff); + vdiff = _mm256_and_ps(sign_bit, vdiff); sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff))); } auto score = hsum_float_4x8(sqx); @@ -3534,7 +3570,8 @@ void QuantizerIQKT::find_best_ma auto vdiff = _mm256_sub_ps(vq, vx_p); //vdiff = _mm256_mul_ps(vdiff, vdiff); //sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff)); - vdiff = _mm256_andnot_ps(sign_bit, vdiff); + //vdiff = _mm256_andnot_ps(sign_bit, vdiff); + vdiff = _mm256_and_ps(sign_bit, vdiff); sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff))); } auto score = hsum_float_4x8(sqx); @@ -4168,7 +4205,7 @@ void vec_dot_iq3_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx namespace{ -using QuantizerIQ4KT = QuantizerIQKT<64, 4, 16, 128>; +using QuantizerIQ4KT = QuantizerIQKT<64, 4, 16, 512>; const QuantizerIQ4KT& iq4kt_quantizer() { static std::mutex mutex; @@ -4238,21 +4275,46 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f ql += Q::kNg; continue; } - //float scale_0 = 127.f*amax/amax_row; - //float scale_0 = std::max(64.f, 127.f*amax/amax_row); + float best = 0; float scale_0 = std::max(92.f, 127.f*amax/amax_row); - //float scale_0 = row_scale; - quantizer.find_best_match( amax/scale_0, xb, weight, best_idx); - auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx); - quantizer.find_best_match(-amax/scale_0, xb, weight, best_idx+Q::kNg); - auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx+Q::kNg); - if (score_p > score_m) { - scales[ib] = dp; - for (int j = 0; j < Q::kNg; ++j) ql[j] = best_idx[j]; - } else { - scales[ib] = dm; - for (int j = 0; j < Q::kNg; ++j) ql[j] = best_idx[j+Q::kNg]; + //float scale_0 = 96.f; + for (int itry = -2; itry <= 2; ++itry) { + quantizer.find_best_match( amax/(8.f*itry + scale_0), xb, weight, best_idx); + auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx); + if (score_p > best) { + best = score_p; scales[ib] = dp; + for (int j = 0; j < Q::kNg; ++j) ql[j] = best_idx[j]; + } + quantizer.find_best_match(-amax/(8.f*itry + scale_0), xb, weight, best_idx); + auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx); + if (score_m > best) { + best = score_m; scales[ib] = dm; + for (int j = 0; j < Q::kNg; ++j) ql[j] = best_idx[j]; + } } + //for (int j = 0; j < Q::kNg; ++j) best_idx[j] = ql[j]; + //auto inv_scale = quantizer.find_best_inverse_scale(xb, weight, best_idx); + //if (inv_scale) { + // quantizer.find_best_match(1/inv_scale, xb, weight, best_idx); + // auto [d, score] = quantizer.find_best_scale(xb, weight, best_idx); + // if (score > best) { + // if (score > 1.02f*best) printf("New best match: %g vs %g, score is %g vs %g\n", d, scales[ib], score, best); + // scales[ib] = d; + // for (int j = 0; j < Q::kNg; ++j) ql[j] = best_idx[j]; + // } + //} + ////float scale_0 = row_scale; + //quantizer.find_best_match( amax/scale_0, xb, weight, best_idx); + //auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx); + //quantizer.find_best_match(-amax/scale_0, xb, weight, best_idx+Q::kNg); + //auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx+Q::kNg); + //if (score_p > score_m) { + // scales[ib] = dp; + // for (int j = 0; j < Q::kNg; ++j) ql[j] = best_idx[j]; + //} else { + // scales[ib] = dm; + // for (int j = 0; j < Q::kNg; ++j) ql[j] = best_idx[j+Q::kNg]; + //} //float mse = 0; //for (int j = 0; j < Q::kNg; ++j) { @@ -4317,6 +4379,9 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f int ls = y[ibl].scales[ib]; float dl = d*ls; quantizer.find_best_match(dl, xb, weight, best_idx); + float dnew = quantizer.find_best_scale(xb, weight, best_idx).first; + ls = std::max(-128, std::min(127, nearest_int(dnew/d))); + y[ibl].scales[ib] = ls; for (int j = 0; j < Q::kNg; ++j) { qs[j] = best_idx[j];