From 215bea5c6ab04a51ea38b827493688468b482fbf Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 13 Nov 2024 16:36:08 +0200 Subject: [PATCH] iq3_kt: small improvements and faster quantization --- ggml/src/iqk/iqk_quantize.cpp | 149 ++++++++++++++++++++++++---------- 1 file changed, 106 insertions(+), 43 deletions(-) diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index e8ae64b7..8e6d5e76 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -3186,6 +3186,8 @@ public: } } + static inline int bin4(float x) { return x < -24.f ? 0 : x < 0.0f ? 1 : x < 24.f ? 2 : 3; } + static inline void set_weights(float sigma2_scale, int nblock, const float * x, const float * imatrix, float * row_weights) { for (int ibl = 0; ibl < nblock; ++ibl) { @@ -3393,52 +3395,60 @@ void QuantizerIQKT::find_best_match(float d, c auto xl = xb + 4*l; auto wl = weight + 4*l; auto vx4 = _mm_loadu_ps(xl); - auto vx_p = _mm256_mul_ps(vid_p, _mm256_set_m128(vx4, vx4)); + auto vx = _mm256_mul_ps(vid_p, _mm256_set_m128(vx4, vx4)); auto vw4 = _mm_loadu_ps(wl); auto vw = _mm256_set_m128(vw4, vw4); - auto vbest = _mm256_set1_ps(INFINITY); - auto best_index = _mm256_set1_epi32(-1); - float best = INFINITY; int jbest = -1; - auto idx = add_idx; - for (int j = 0; j < ncluster; j += 8) { - for (int i = 0; i < 4; ++i) { - auto vq = _mm256_loadu_ps(m_clusters.data() + kGroupSize*(j+2*i)); - auto vdiff = _mm256_sub_ps(vq, vx_p); - //vdiff = _mm256_mul_ps(vdiff, vdiff); - //sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff)); - //vdiff = _mm256_andnot_ps(sign_bit, vdiff); - vdiff = _mm256_and_ps(sign_bit, vdiff); - sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff))); + int jbest = -1; + if (ncluster == 256) { + _mm256_storeu_ps(sx, vx); + uint8_t u = 0; + for (int k = 0; k < 4; ++k) u |= (bin4(sx[k]) << 2*k); + jbest = u; + } else { + auto vbest = _mm256_set1_ps(INFINITY); + auto best_index = _mm256_set1_epi32(-1); + float best = INFINITY; + auto idx = add_idx; + for (int j = 0; j < ncluster; j += 8) { + for (int i = 0; i < 4; ++i) { + auto vq = _mm256_loadu_ps(m_clusters.data() + kGroupSize*(j+2*i)); + auto vdiff = _mm256_sub_ps(vq, vx); + //vdiff = _mm256_mul_ps(vdiff, vdiff); + //sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff)); + //vdiff = _mm256_andnot_ps(sign_bit, vdiff); + vdiff = _mm256_and_ps(sign_bit, vdiff); + sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff))); + } + auto score = hsum_float_4x8(sqx); + auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); + best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), + _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + vbest = _mm256_min_ps(vbest, score); + idx = _mm256_add_epi32(idx, add8); + } + _mm256_store_ps(sx, vbest); + _mm256_store_si256((__m256i *)index, best_index); + for (int i = 0; i < 8; ++i) { + if (sx[i] < best) { best = sx[i]; jbest = index[i]; } } - auto score = hsum_float_4x8(sqx); - auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); - best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), - _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); - vbest = _mm256_min_ps(vbest, score); - idx = _mm256_add_epi32(idx, add8); - } - _mm256_store_ps(sx, vbest); - _mm256_store_si256((__m256i *)index, best_index); - for (int i = 0; i < 8; ++i) { - if (sx[i] < best) { best = sx[i]; jbest = index[i]; } } auto& points = m_in_cluster[jbest]; auto& values = m_c_values[jbest]; GGML_ASSERT(!points.empty() && points.size()%8 == 0); int jbest_cluster = jbest; - vbest = _mm256_set1_ps(INFINITY); - best_index = _mm256_set1_epi32(-1); - best = INFINITY; jbest = -1; - idx = add_idx; + auto vbest = _mm256_set1_ps(INFINITY); + auto best_index = _mm256_set1_epi32(-1); + float best = INFINITY; jbest = -1; + auto idx = add_idx; for (int j = 0; j < int(points.size()); j += 8) { for (int i = 0; i < 4; ++i) { auto vq = _mm256_loadu_ps(values.data() + kGroupSize*(j+2*i)); - auto vdiff = _mm256_sub_ps(vq, vx_p); + auto vdiff = _mm256_sub_ps(vq, vx); //vdiff = _mm256_mul_ps(vdiff, vdiff); - //sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff)); + sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff)); //vdiff = _mm256_andnot_ps(sign_bit, vdiff); - vdiff = _mm256_and_ps(sign_bit, vdiff); - sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff))); + //vdiff = _mm256_and_ps(sign_bit, vdiff); + //sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff))); } auto score = hsum_float_4x8(sqx); auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); @@ -3574,7 +3584,7 @@ std::vector QuantizerIQKT::cluster_poin std::vector sump(ncluster*ndim); std::vector counts(ncluster); std::vector result(ncluster*ndim); - if (group_size == 8 && ncluster == 256) { + if (ndim == 8 && ncluster == 256) { std::memset(sump.data(), 0, sump.size()*sizeof(float)); std::memset(counts.data(), 0, counts.size()*sizeof(int)); for (int ip = 0; ip < npoint; ++ip) { @@ -3593,6 +3603,52 @@ std::vector QuantizerIQKT::cluster_poin } return result; } + else if (ndim == 4 && ncluster == 256) { + std::memset(sump.data(), 0, sump.size()*sizeof(float)); + std::memset(counts.data(), 0, counts.size()*sizeof(int)); + //printf("%s: simple with group size %d\n", __func__, group_size); + //printf("%s: midpoints = %g, %g, %g, %g\n", __func__, mid[0], mid[1], mid[2], mid[3]); + for (int ip = 0; ip < npoint; ++ip) { + auto vp = points.data() + ndim*ip; + uint8_t u = 0; + for (int k = 0; k < ndim; ++k) u |= (bin4(vp[k]) << 2*k); + ++counts[u]; + for (int k = 0; k < ndim; ++k) sump[ndim*u + k] += vp[k]; + } + int nzero = 0; + for (int ic = 0; ic < ncluster; ++ic) { + if (!counts[ic]) { + ++nzero; + printf("%s: Oops. Cluster %d has no points: ", __func__, ic); + for (int k = 0; k < ndim; ++k) { + int l = (ic >> 2*k) & 3; + printf(" %d", l); + } + printf("\n"); + //GGML_ABORT("fatal error"); + } else { + for (int k = 0; k < ndim; ++k) result[ic*ndim + k] = sump[ic*ndim + k]/counts[ic]; + } + } + if (nzero > 0) printf("%s: %d out of %d clusters dir not have any points\n", __func__, nzero, ncluster); + //counts.resize(ndim*ncluster); + //auto fcounts = (float *)counts.data(); + //std::memset(fcounts, 0, counts.size()*sizeof(float)); + //for (int ip = 0; ip < npoint; ++ip) { + // auto vp = points.data() + ndim*ip; + // uint8_t u = 0; + // for (int k = 0; k < ndim; ++k) u |= (bin4(vp[k]) << 2*k); + // for (int k = 0; k < ndim; ++k) { + // float w = std::abs(vp[k]); + // sump[ndim*u + k] += w*vp[k]; + // fcounts[ndim*u + k] += w; + // } + //} + //for (int ic = 0; ic < ncluster; ++ic) { + // for (int k = 0; k < ndim; ++k) result[ic*ndim + k] = fcounts[ic*ndim + k] > 0 ? sump[ic*ndim + k]/fcounts[ic*ndim + k] : 0.f; + //} + return result; + } std::mt19937 rndm(1234); float scale = 1.f/4294967296.f; for (int i = 0; i < ncluster; ++i) { @@ -3859,7 +3915,8 @@ const QuantizerIQ3KT& iq3kt_quantizer() { static std::mutex mutex; std::lock_guard lock(mutex); static std::unique_ptr quantizer; - if (!quantizer) quantizer = std::make_unique(64, 5); + if (!quantizer) quantizer = std::make_unique(256, 16); + //if (!quantizer) quantizer = std::make_unique(64, 5); return *quantizer; } @@ -3921,6 +3978,14 @@ void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const f } scales[ib] = 0; if (!amax) continue; + + //quantizer.find_best_match( amax/96.f, xb, weight, best_idx); + //auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx); + //quantizer.find_best_match(-amax/96.f, xb, weight, best_idx); + //auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx); + //scales[ib] = score_p > score_m ? dp : dm; + + //float scale_0 = std::max(80.f, 127.f*amax/amax_row); float scale_0 = std::max(80.f, 127.f*amax/amax_row); float best = 0; for (int itry = -3; itry <= 3; ++itry) { @@ -3938,11 +4003,11 @@ void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const f } } - for (int j = 0; j < Q::kNg; ++j) { - int jj = ib*Q::kNg + j; - y[ibl].ql[jj] = best_idx[j] & 255; - y[ibl].qh[jj%(kNumGroups/2)] |= ((best_idx[j] >> 8) << 4*(jj/(kNumGroups/2))); - } + //for (int j = 0; j < Q::kNg; ++j) { + // int jj = ib*Q::kNg + j; + // y[ibl].ql[jj] = best_idx[j] & 255; + // y[ibl].qh[jj%(kNumGroups/2)] |= ((best_idx[j] >> 8) << 4*(jj/(kNumGroups/2))); + //} float abs_scale = std::abs(scales[ib]); if (abs_scale > amax_scale) { @@ -3964,13 +4029,10 @@ void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const f } } - //d *= 1.05f; *dptr = d; for (int iloop = 0; iloop < 1; ++iloop) { - //d *= 1.05f; - float sumqx = 0, sumq2 = 0; for (int ibl = 0; ibl < nblock; ++ibl) { @@ -4010,6 +4072,7 @@ void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const f if (sumq2 > 0) { d = sumqx/sumq2; *dptr = d; + if (!d) break; } else { break; }