From 766fa600c802dc8ec1cd4e50848047f7e2394d0a Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 6 Nov 2024 16:24:17 +0200 Subject: [PATCH] WIP - try larger blocks With blocks of 32 and 16 bits per group of 8 the brute force search becomes prohibitive in terms of CPU time (30+ minutes for 8B LLaMA after SIMDifying with AVX2). The trick is to group the points in clusters, find the nearest cluster, and only search within the cluster. --- examples/quantize-stats/quantize-stats.cpp | 251 +++++++++++++++++++-- 1 file changed, 233 insertions(+), 18 deletions(-) diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 28103350..312e3af5 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -426,24 +427,180 @@ static float find_best_scale(int block_size, const float * xb, const float * wei return d; } +static std::vector cluster_points(const std::vector& points, int ndim, int ncluster, int niter) { + if (points.size() % ndim != 0) { + printf("%s: bad input\n", __func__); return {}; + } + int npoint = points.size() / ndim; + if (npoint < 2*ncluster) { + printf("%s: bad input\n", __func__); return {}; + } + std::vector> range(ndim, std::make_pair(INFINITY, -INFINITY)); + double Fo = 0; + for (int i = 0; i < npoint; ++i) { + auto v = points.data() + i*ndim; + for (int k = 0; k < ndim; ++k) { + Fo += v[k]*v[k]; + range[k].first = std::min(range[k].first, v[k]); + range[k].second = std::max(range[k].second, v[k]); + } + } + printf("%s (ndim = %d, npoint = %d): Fo = %g\n", __func__, ndim, npoint, Fo/points.size()); + std::mt19937 rndm(1234); + float scale = 1.f/4294967296.f; + std::vector result(ncluster*ndim); + for (int i = 0; i < ncluster; ++i) { + auto v = result.data() + i*ndim; + for (int k = 0; k < ndim; ++k) v[k] = 
range[k].first + (range[k].second - range[k].first)*scale*rndm(); + } + std::vector sump(ncluster*ndim); + std::vector counts(ncluster); + std::vector which_cluster(npoint, -1); + double Flast = Fo; + for (int iter = 0; iter < niter; ++iter) { + std::memset(sump.data(), 0, sump.size()*sizeof(float)); + std::memset(counts.data(), 0, counts.size()*sizeof(int)); + int nchanged = 0; + double F = 0; + for (int ip = 0; ip < npoint; ++ip) { + auto vp = points.data() + ndim*ip; + float best = INFINITY; int ibest = -1; + for (int ic = 0; ic < ncluster; ++ic) { + auto vc = result.data() + ndim*ic; + float dist2 = 0; + for (int k = 0; k < ndim; ++k) { + float d = vp[k] - vc[k]; dist2 += d*d; + } + if (dist2 < best) { + best = dist2; ibest = ic; + } + } + if (ibest < 0) { printf("Oops.\n"); exit(1); } + F += best; + if (which_cluster[ip] != ibest) ++nchanged; + which_cluster[ip] = ibest; + ++counts[ibest]; + auto vc = sump.data() + ndim*ibest; + for (int k = 0; k < ndim; ++k) vc[k] += vp[k]; + } + if (nchanged == 0) break; + for (int ic = 0; ic < ncluster; ++ic) { + float norm = counts[ic] > 0 ? 
1.f/counts[ic] : 0.f; + auto vc = sump.data() + ndim*ic; + auto r = result.data() + ndim*ic; + for (int k = 0; k < ndim; ++k) r[k] = vc[k]*norm; + } + printf("%s(iteration %d): F = %g, nchanged = %d\n", __func__, iter+1, F/points.size(), nchanged); + if (iter > 1 && Flast/F - 1 < 1e-6) break; + Flast = F; + } + return result; +} + static void analyze_x_v2(const char * name, int nrows, int n_per_row, const float * values, float& tot_mse, float& tot_mse_q, float& tot_elements) { - constexpr int kNumVal = 1 << 15; + constexpr int kNumVal = 1 << 16; constexpr int kBlockSize = 32; constexpr int kGroupSize = 8; constexpr int kNg = kBlockSize/kGroupSize; constexpr int kSuperBlockSize = 256; static_assert(kNumVal%8 == 0); - auto codes = make_values(kNumVal, kGroupSize, 31.75f); + static std::vector codes, clusters; + static std::vector> p_in_cluster; + if (codes.empty()) { + codes = make_values(kNumVal, kGroupSize, 31.75f); + clusters = cluster_points(codes, kGroupSize, kNumVal/1024, 200); + if (clusters.empty()) { printf("Oops\n"); exit(1); } + int ncluster = clusters.size()/kGroupSize; + p_in_cluster.resize(ncluster); + std::vector which_cluster(4*kNumVal); + GGML_ASSERT(ncluster%8 == 0); + for (int ip = 0; ip < kNumVal; ++ip) { + auto vp = codes.data() + ip*kGroupSize; + float best[4] = {INFINITY, INFINITY, INFINITY, INFINITY}; + int ibest[4] = {-1, -1, -1, -1}; + for (int ic = 0; ic < ncluster; ++ic) { + auto vc = clusters.data() + ic*kGroupSize; + float dist2 = 0; + for (int k = 0; k < kGroupSize; ++k) { + float d = vp[k] - vc[k]; dist2 += d*d; + } + if (dist2 < best[0]) { + best[3] = best[2]; ibest[3] = ibest[2]; + best[2] = best[1]; ibest[2] = ibest[1]; + best[1] = best[0]; ibest[1] = ibest[0]; + best[0] = dist2; ibest[0] = ic; + } + else if (dist2 < best[1]) { + best[3] = best[2]; ibest[3] = ibest[2]; + best[2] = best[1]; ibest[2] = ibest[1]; + best[1] = dist2; ibest[1] = ic; + } + else if (dist2 < best[2]) { + best[3] = best[2]; ibest[3] = ibest[2]; + best[2] = 
dist2; ibest[2] = ic; + } + else if (dist2 < best[3]) { + best[3] = dist2; ibest[3] = ic; + } + } + GGML_ASSERT(ibest[0] >= 0 && ibest[1] >= 0 && ibest[2] >= 0 && ibest[3] >= 0); + p_in_cluster[ibest[0]].push_back(ip); + p_in_cluster[ibest[1]].push_back(ip); + p_in_cluster[ibest[2]].push_back(ip); + p_in_cluster[ibest[3]].push_back(ip); + std::memcpy(which_cluster.data() + 4*ip, ibest, 4*sizeof(int)); + } + std::vector> extra; + extra.reserve(kNumVal); + for (int ic = 0; ic < ncluster; ++ic) { + auto& points = p_in_cluster[ic]; + if (!points.empty() && points.size()%8 == 0) continue; + extra.clear(); + auto vc = clusters.data() + ic*kGroupSize; + for (int ip = 0; ip < kNumVal; ++ip) { + if (which_cluster[4*ip] == ic || which_cluster[4*ip+1] == ic || which_cluster[4*ip+2] == ic || which_cluster[4*ip+3] == ic) continue; + auto vp = codes.data() + ip*kGroupSize; + float dist2 = 0; + for (int k = 0; k < kGroupSize; ++k) { + float d = vp[k] - vc[k]; dist2 += d*d; + } + extra.push_back(std::make_pair(dist2, ip)); + } + std::sort(extra.begin(), extra.end()); + int nadd = 8*((points.size()+7)/8) - points.size(); + for (int i = 0; i < nadd; ++i) points.push_back(extra[i].second); + GGML_ASSERT(points.size()%8 == 0); + } + auto min = p_in_cluster.front().size(), max = p_in_cluster.front().size(); + int nzero = 0; + for (auto& points : p_in_cluster) { + min = std::min(min, points.size()); + max = std::max(max, points.size()); + if (points.empty()) ++nzero; + } + printf("%s: prepared %d clusters\n", __func__, ncluster); + printf(" min number of points in a cluster: %d\n", int(min)); + printf(" max number of points in a cluster: %d\n", int(max)); + if (nzero > 0) { + printf(" there are %d empty clusters\n", nzero); + for (auto& points : p_in_cluster) { + if (!points.empty()) continue; + points.reserve(kNumVal); + for (int j = 0; j < kNumVal; ++j) points.push_back(j); // i.e., if we end iup picking an empty cluster, we just check all points + } + } + } int nthread = std::max(1, 
int(std::thread::hardware_concurrency()/2)); int chunk = (nrows + 8*nthread - 1)/(8*nthread); std::mutex mutex; int counter = 0; float mse = 0, mse_q = 0; - auto compute = [&mutex, &counter, &mse, &mse_q, &codes, values, nrows, n_per_row, chunk] () { - float lmse = 0, lmse_q = 0; + auto compute = [&mutex, &counter, &mse, &mse_q, values, nrows, n_per_row, chunk] () { + double lmse = 0, lmse_q = 0; std::vector scales(n_per_row/kBlockSize); std::vector best_idx(n_per_row/kGroupSize); std::vector weight(kBlockSize, 1.f); + int ncluster = clusters.size() / kGroupSize; while (true) { std::unique_lock lock(mutex); int first = counter; counter += chunk; @@ -464,17 +621,7 @@ static void analyze_x_v2(const char * name, int nrows, int n_per_row, const floa float sigma2 = 0; for (int j = 0; j < n_per_row; ++j) sigma2 += xr[j]*xr[j]; sigma2 /= n_per_row; - //int last_ibl = -1; for (int ib = 0; ib < n_per_row/kBlockSize; ++ib) { - //int ibl = ib*kBlockSize/kSuperBlockSize; - //if (ibl != last_ibl) { - // auto xbl = xr + ibl*kSuperBlockSize; - // int n = kSuperBlockSize*(ibl + 1) <= n_per_row ? 
kSuperBlockSize : n_per_row - ibl*kSuperBlockSize; - // float sumx2 = 0; - // for (int i = 0; i < n; ++i) sumx2 += xbl[i]*xbl[i]; - // sigma2 = sumx2/n; - // last_ibl = ibl; - //} auto xb = xr + kBlockSize*ib; for (int i = 0; i < kBlockSize; ++i) weight[i] = 0.25f*sigma2 + xb[i]*xb[i]; float d = find_best_scale(kBlockSize, xb, weight.data(), iq4k_values, 5); @@ -482,15 +629,17 @@ static void analyze_x_v2(const char * name, int nrows, int n_per_row, const floa #ifdef __AVX2__ auto vid = _mm256_set1_ps(id); for (int l = 0; l < kNg; ++l) { - auto vx = _mm256_mul_ps(vid, _mm256_loadu_ps(xb+8*l)); - auto vw = _mm256_loadu_ps(weight.data() + 8*l); + auto xl = xb + 8*l; + auto wl = weight.data() + 8*l; + auto vx = _mm256_mul_ps(vid, _mm256_loadu_ps(xl)); + auto vw = _mm256_loadu_ps(wl); auto vbest = _mm256_set1_ps(INFINITY); auto best_index = _mm256_set1_epi32(-1); float best = INFINITY; int jbest = -1; - for (int j = 0; j < kNumVal; j += 8) { + for (int j = 0; j < ncluster; j += 8) { auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx); for (int i = 0; i < 8; ++i) { - auto vq = _mm256_loadu_ps(codes.data() + kGroupSize*(j+i)); + auto vq = _mm256_loadu_ps(clusters.data() + kGroupSize*(j+i)); auto vdiff = _mm256_sub_ps(vq, vx); sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff)); } @@ -505,6 +654,72 @@ static void analyze_x_v2(const char * name, int nrows, int n_per_row, const floa for (int i = 0; i < 8; ++i) { if (sx[i] < best) { best = sx[i]; jbest = index[i]; } } + auto& points = p_in_cluster[jbest]; + if (points.empty()) { + printf("Oops: empty cluster %d\n", jbest); + auto vc = clusters.data() + kGroupSize*jbest; + printf("Cluster:\n"); + for (int j = 0; j < kGroupSize; ++j) printf("%d %g %g\n", j, vc[j], xl[j]); + GGML_ASSERT(false); + } + int jbest_cluster = jbest; + vbest = _mm256_set1_ps(INFINITY); + best_index = _mm256_set1_epi32(-1); + best = INFINITY; jbest = -1; + for (int j = 0; j < int(points.size()); j += 8) { + auto idx = 
_mm256_loadu_si256((const __m256i*)(points.data() + j)); + for (int i = 0; i < 8; ++i) { + auto vq = _mm256_loadu_ps(codes.data() + kGroupSize*points[j+i]); + auto vdiff = _mm256_sub_ps(vq, vx); + sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff)); + } + auto score = hsum_float_8x8(sqx); + auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); + best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), + _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + vbest = _mm256_min_ps(vbest, score); + } + _mm256_store_ps(sx, vbest); + _mm256_store_si256((__m256i *)index, best_index); + for (int i = 0; i < 8; ++i) { + if (sx[i] < best) { best = sx[i]; jbest = index[i]; } + } + //int jbest_cluster = jbest; + //best = INFINITY; jbest = -1; + //for (auto ip : points) { + // auto vc = codes.data() + ip*kGroupSize; + // float diff2 = 0; + // for (int k = 0; k < kGroupSize; ++k) { + // float delta = d*vc[k] - xl[k]; + // diff2 += wl[k]*delta*delta; + // } + // if (diff2 < best) { + // best = diff2; jbest = ip; + // } + //} + if (jbest < 0) { + printf("Oops: jbest = %d for cluster %d with %d points\n", jbest, jbest_cluster, int(points.size())); + GGML_ASSERT(false); + } + GGML_ASSERT(jbest >= 0); + //for (int j = 0; j < kNumVal; j += 8) { + // auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx); + // for (int i = 0; i < 8; ++i) { + // auto vq = _mm256_loadu_ps(codes.data() + kGroupSize*(j+i)); + // auto vdiff = _mm256_sub_ps(vq, vx); + // sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff)); + // } + // auto score = hsum_float_8x8(sqx); + // auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); + // best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), + // _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + // vbest = _mm256_min_ps(vbest, score); + //} + //_mm256_store_ps(sx, vbest); + //_mm256_store_si256((__m256i *)index, best_index); + //for (int i = 0; i < 8; ++i) { + // if (sx[i] < best) { best = 
sx[i]; jbest = index[i]; } + //} best_idx[ib*kNg + l] = jbest; } auto vqx = _mm256_setzero_ps();