From 4213ab1cb36d0f8fcfcae3156d03c8173293a511 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 14 Nov 2024 11:55:55 +0200 Subject: [PATCH] iq2_kt: SOTA We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 8.9627 PPL(LLaMA-2-7B, 4096) = 6.3825 Quantization is faster too: ~200 seconds for LLaMA-3.1-8B on Ryzen-5975WX. --- ggml/src/iqk/iqk_quantize.cpp | 63 ++++++++++++++++++++++++++++++----- 1 file changed, 54 insertions(+), 9 deletions(-) diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 8e6d5e76..442b5b99 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -3707,7 +3707,8 @@ const QuantizerIQ2KT& iq2kt_quantizer() { return *quantizer; } -void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_weights) { +void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_weights, + float * qtmp) { constexpr float kSigmaScale = 2.0f; using Q = QuantizerIQ2KT; @@ -3718,7 +3719,7 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f block_iq2_kt * y = (block_iq2_kt *)(dptr + 1); - int best_idx[Q::kNg]; + int best_idx[2*Q::kNg]; auto& quantizer = iq2kt_quantizer(); @@ -3745,9 +3746,20 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f } quantizer.find_best_match( amax/96.f, xb, weight, best_idx); auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx); - quantizer.find_best_match(-amax/96.f, xb, weight, best_idx); - auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx); - scales[ib] = score_p > score_m ? 
dp : dm; + quantizer.find_best_match(-amax/96.f, xb, weight, best_idx + Q::kNg); + auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx + Q::kNg); + + auto idx = best_idx; + if (score_p > score_m) scales[ib] = dp; + else { + scales[ib] = dm; idx += Q::kNg; + } + auto qt = qtmp + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + for (int ig = 0; ig < Q::kNg; ++ig) { + auto q = quantizer.values() + idx[ig]*Q::kGroupSize; + for (int j = 0; j < Q::kGroupSize; ++j) qt[j] = q[j]; + qt += Q::kGroupSize; + } float abs_scale = std::abs(scales[ib]); if (abs_scale > amax_scale) { @@ -3758,7 +3770,39 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f } + if (!max_scale) { + *dptr = 0; + return; + } + float d = max_scale/iq4k_values[0]; + float best = 0; + for (int itry = -5; itry <= 5; ++itry) { + float id = (itry + iq4k_values[0])/max_scale; + float sumqx = 0, sumq2 = 0; + for (int ibl = 0; ibl < nblock; ++ibl) { + const float * xb = x + ibl*Q::kSuperBlockSize; + const float * qb = qtmp + ibl*Q::kSuperBlockSize; + const float * wb = all_weights + ibl*Q::kSuperBlockSize; + auto scales = all_scales + ibl*Q::kNblock; + for (int ib = 0; ib < Q::kNblock; ++ib) { + int ls = best_index_iq4nl(iq4k_values, id*scales[ib]); + float dl = iq4k_values[ls]; + for (int j = 0; j < Q::kBlockSize; ++j) { + float q = dl*qb[j]; + sumqx += wb[j]*xb[j]*q; + sumq2 += wb[j]*q*q; + } + xb += Q::kBlockSize; + wb += Q::kBlockSize; + qb += Q::kBlockSize; + } + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d*sumqx; + } + } + float id = d ? 
1/d : 0.f; for (int ibl = 0; ibl < nblock; ++ibl) { auto scales = all_scales + ibl*Q::kNblock; @@ -3769,12 +3813,12 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f } } + *dptr = d; if (!d) return; - d *= 1.05f; - *dptr = d; + //d *= 1.05f; - for (int iloop = 0; iloop < 2; ++iloop) { + for (int iloop = 0; iloop < 1; ++iloop) { float sumqx = 0, sumq2 = 0; for (int ibl = 0; ibl < nblock; ++ibl) { @@ -3856,9 +3900,10 @@ size_t quantize_iq2_kt(const float * src, void * dst, int64_t nrows, int64_t n_p auto row_size = ggml_row_size(GGML_TYPE_IQ2_KT, n_per_row); std::vector<float> scales(n_per_row/QuantizerIQ2KT::kBlockSize); std::vector<float> weights(n_per_row); + std::vector<float> xtmp(n_per_row); char * qrow = (char *)dst; for (int64_t row = 0; row < nrows; ++row) { - quantize_row_iq2_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data()); + quantize_row_iq2_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data(), xtmp.data()); src += n_per_row; qrow += row_size; }