From b3dfe9984beb907f6d5ec68ae1071b1d505dd6e9 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 7 Nov 2024 08:38:20 +0200 Subject: [PATCH] iq2_kt - even better Re-quantize after determining block scales (at the epxense of much longer quantization time). --- ggml/src/iqk/iqk_quantize.cpp | 60 +++++++++++++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 3 deletions(-) diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 45f4d32b..e19a3801 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -3153,7 +3153,7 @@ public: constexpr static bool kVerbose = false; QuantizerIQ2KT(); - //const float * values() const { return m_values.data(); } + const float * values() const { return m_values.data(); } inline void find_best_match(float d, const float * xb, const float * weight, int * best_idx) const; inline float find_best_scale(const float * xb, const float * weight, const int * best_idx) const; @@ -3188,6 +3188,9 @@ QuantizerIQ2KT::QuantizerIQ2KT() { set_values(i, data, kScale); data += kGroupSize; } + // Make 128 clusters. + // Note: we get a slightly better result by using 64 clusters + // at the expense of almost doubling the quantization time. m_clusters = cluster_points(m_values, kNumVal/512, 200); GGML_ASSERT(!m_clusters.empty()); m_in_cluster = finalize_clusters(m_values, m_clusters); @@ -3447,6 +3450,8 @@ const QuantizerIQ2KT& iq2kt_quantizer() { } void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales) { + constexpr float kSigmaScale = 2.0f; + static_assert(QuantizerIQ2KT::kNumVal%8 == 0); float * dptr = (float *)vy; @@ -3471,7 +3476,7 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f const float * xbl = x + ibl*QuantizerIQ2KT::kSuperBlockSize; float sumx2 = 0; for (int j = 0; j < QuantizerIQ2KT::kSuperBlockSize; ++j) sumx2 += xbl[j]*xbl[j]; - const float sigma2 = 1.5f*sumx2/QuantizerIQ2KT::kSuperBlockSize; + const float sigma2 = kSigmaScale*sumx2/QuantizerIQ2KT::kSuperBlockSize; auto scales = all_scales + ibl*QuantizerIQ2KT::kNblock; @@ -3506,7 +3511,7 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f float d = max_scale/iq4k_values[0]; float id = d ? 1/d : 0.f; - *dptr = d; + //*dptr = d; for (int ibl = 0; ibl < nblock; ++ibl) { auto scales = all_scales + ibl*QuantizerIQ2KT::kNblock; for (int ib = 0; ib < QuantizerIQ2KT::kNblock/2; ++ib) { @@ -3516,6 +3521,55 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f } } + d *= 1.05f; + *dptr = d; + + for (int iloop = 0; iloop < 2; ++iloop) { + + float sumqx = 0, sumq2 = 0; + for (int ibl = 0; ibl < nblock; ++ibl) { + + auto qs = (uint16_t *)y[ibl].ql; + const float * xbl = x + ibl*QuantizerIQ2KT::kSuperBlockSize; + float sumx2 = 0; + for (int j = 0; j < QuantizerIQ2KT::kSuperBlockSize; ++j) sumx2 += xbl[j]*xbl[j]; + const float sigma2 = kSigmaScale*sumx2/QuantizerIQ2KT::kSuperBlockSize; + + for (int ib = 0; ib < QuantizerIQ2KT::kNblock; ++ib) { + const float * xb = xbl + QuantizerIQ2KT::kBlockSize*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QuantizerIQ2KT::kSuperBlockSize + ib*QuantizerIQ2KT::kBlockSize; + for (int j = 0; j < QuantizerIQ2KT::kBlockSize; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < QuantizerIQ2KT::kBlockSize; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; + } + int ls = iq4k_values[(y[ibl].scales[ib%(QuantizerIQ2KT::kNblock/2)] >> 4*(ib/(QuantizerIQ2KT::kNblock/2))) & 0xf]; + float dl = d*ls; + quantizer.find_best_match(dl, xb, weight, best_idx); + + for (int j = 0; j < QuantizerIQ2KT::kNg; ++j) { + qs[j] = best_idx[j]; + auto xl = xb + QuantizerIQ2KT::kGroupSize*j; + auto wl = weight + QuantizerIQ2KT::kGroupSize*j; + auto ql = quantizer.values() + best_idx[j]*QuantizerIQ2KT::kGroupSize; + for (int k = 0; k < QuantizerIQ2KT::kGroupSize; ++k) { + float q = ql[k]*ls; + sumqx += wl[k]*xl[k]*q; + sumq2 += wl[k]*q*q; + } + } + qs += QuantizerIQ2KT::kNg; + } + } + if (sumq2 > 0) { + d = sumqx/sumq2; + *dptr = d; + } else { + break; + } + + } + } }