mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 23:24:13 +00:00
WIP - try larger blocks
With blocks of 32 and 16 bits per group of 8, the brute-force search becomes prohibitive in terms of CPU time (30+ minutes for 8B LLaMA after SIMDifying with AVX2). The trick is to group the points in clusters, find the nearest cluster, and only search within the cluster.
This commit is contained in:
@@ -22,6 +22,7 @@
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <array>
|
||||
#include <random>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
@@ -426,24 +427,180 @@ static float find_best_scale(int block_size, const float * xb, const float * wei
|
||||
return d;
|
||||
}
|
||||
|
||||
// Groups `points` into `ncluster` clusters with a k-means iteration and returns the cluster centers.
//
// points   : npoint vectors of `ndim` floats each, stored back to back (point i starts at i*ndim).
// ndim     : number of components per point; must be positive and divide points.size().
// ncluster : number of centers to produce; must be positive and at most npoint/2.
// niter    : maximum number of assignment/update iterations.
//
// Returns ncluster*ndim floats (center i starts at i*ndim), or an empty vector when the input is
// rejected (a "bad input" message is printed in that case). Progress is reported via printf.
//
// Notes:
//  - The initial centers are drawn uniformly inside the per-dimension bounding box of the points
//    using a fixed seed, so the result is deterministic for a given input.
//  - A cluster that receives no points in an iteration gets its center set to all zeros
//    (its normalization factor is 0), so callers may see clusters without any member points.
static std::vector<float> cluster_points(const std::vector<float>& points, int ndim, int ncluster, int niter) {
    // Non-positive ndim would divide by zero just below, and a non-positive ncluster would either
    // leave every point without a nearest center or request a negative-sized allocation.
    if (ndim <= 0 || ncluster <= 0 || points.size() % ndim != 0) {
        printf("%s: bad input\n", __func__); return {};
    }
    const int npoint = points.size() / ndim;
    if (npoint < 2*ncluster) {
        printf("%s: bad input\n", __func__); return {};
    }

    // Per-dimension (min, max) of the data and the mean squared norm (reported for reference only).
    std::vector<std::pair<float, float>> range(ndim, std::make_pair(INFINITY, -INFINITY));
    double Fo = 0;
    for (int i = 0; i < npoint; ++i) {
        auto v = points.data() + size_t(i)*ndim;
        for (int k = 0; k < ndim; ++k) {
            Fo += v[k]*v[k];
            range[k].first  = std::min(range[k].first, v[k]);
            range[k].second = std::max(range[k].second, v[k]);
        }
    }
    printf("%s (ndim = %d, npoint = %d): Fo = %g\n", __func__, ndim, npoint, Fo/points.size());

    // Seed the centers uniformly inside the bounding box; `scale` maps a 32-bit draw to [0, 1].
    std::mt19937 rndm(1234);
    float scale = 1.f/4294967296.f;
    std::vector<float> result(size_t(ncluster)*ndim);
    for (int i = 0; i < ncluster; ++i) {
        auto v = result.data() + size_t(i)*ndim;
        for (int k = 0; k < ndim; ++k) v[k] = range[k].first + (range[k].second - range[k].first)*scale*rndm();
    }

    std::vector<float> sump(size_t(ncluster)*ndim);  // per-cluster component sums of the assigned points
    std::vector<int> counts(ncluster);               // per-cluster number of assigned points
    std::vector<int> which_cluster(npoint, -1);      // current assignment of every point
    double Flast = Fo;
    for (int iter = 0; iter < niter; ++iter) {
        sump.assign(sump.size(), 0.f);
        counts.assign(counts.size(), 0);
        int nchanged = 0;
        double F = 0;  // sum of squared distances to the nearest center for this iteration

        // Assignment step: attach every point to its nearest center (first one wins on ties).
        for (int ip = 0; ip < npoint; ++ip) {
            auto vp = points.data() + size_t(ndim)*ip;
            float best = INFINITY; int ibest = -1;
            for (int ic = 0; ic < ncluster; ++ic) {
                auto vc = result.data() + size_t(ndim)*ic;
                float dist2 = 0;
                for (int k = 0; k < ndim; ++k) {
                    float d = vp[k] - vc[k]; dist2 += d*d;
                }
                if (dist2 < best) {
                    best = dist2; ibest = ic;
                }
            }
            // No center compared below INFINITY (e.g. non-finite data): nothing sensible to do.
            if (ibest < 0) { printf("Oops.\n"); exit(1); }
            F += best;
            if (which_cluster[ip] != ibest) ++nchanged;
            which_cluster[ip] = ibest;
            ++counts[ibest];
            auto vc = sump.data() + size_t(ndim)*ibest;
            for (int k = 0; k < ndim; ++k) vc[k] += vp[k];
        }

        // Assignments are stable, so the centers already are the means of their members.
        if (nchanged == 0) break;

        // Update step: move every center to the mean of its members (all zeros for an empty cluster).
        for (int ic = 0; ic < ncluster; ++ic) {
            float norm = counts[ic] > 0 ? 1.f/counts[ic] : 0.f;
            auto vc = sump.data() + size_t(ndim)*ic;
            auto r  = result.data() + size_t(ndim)*ic;
            for (int k = 0; k < ndim; ++k) r[k] = vc[k]*norm;
        }
        printf("%s(iteration %d): F = %g, nchanged = %d\n", __func__, iter+1, F/points.size(), nchanged);

        // Stop once the relative improvement of the objective becomes negligible.
        if (iter > 1 && Flast/F - 1 < 1e-6) break;
        Flast = F;
    }
    return result;
}
|
||||
|
||||
static void analyze_x_v2(const char * name, int nrows, int n_per_row, const float * values, float& tot_mse, float& tot_mse_q, float& tot_elements) {
|
||||
constexpr int kNumVal = 1 << 15;
|
||||
constexpr int kNumVal = 1 << 16;
|
||||
constexpr int kBlockSize = 32;
|
||||
constexpr int kGroupSize = 8;
|
||||
constexpr int kNg = kBlockSize/kGroupSize;
|
||||
constexpr int kSuperBlockSize = 256;
|
||||
static_assert(kNumVal%8 == 0);
|
||||
auto codes = make_values(kNumVal, kGroupSize, 31.75f);
|
||||
static std::vector<float> codes, clusters;
|
||||
static std::vector<std::vector<int>> p_in_cluster;
|
||||
if (codes.empty()) {
|
||||
codes = make_values(kNumVal, kGroupSize, 31.75f);
|
||||
clusters = cluster_points(codes, kGroupSize, kNumVal/1024, 200);
|
||||
if (clusters.empty()) { printf("Oops\n"); exit(1); }
|
||||
int ncluster = clusters.size()/kGroupSize;
|
||||
p_in_cluster.resize(ncluster);
|
||||
std::vector<int> which_cluster(4*kNumVal);
|
||||
GGML_ASSERT(ncluster%8 == 0);
|
||||
for (int ip = 0; ip < kNumVal; ++ip) {
|
||||
auto vp = codes.data() + ip*kGroupSize;
|
||||
float best[4] = {INFINITY, INFINITY, INFINITY, INFINITY};
|
||||
int ibest[4] = {-1, -1, -1, -1};
|
||||
for (int ic = 0; ic < ncluster; ++ic) {
|
||||
auto vc = clusters.data() + ic*kGroupSize;
|
||||
float dist2 = 0;
|
||||
for (int k = 0; k < kGroupSize; ++k) {
|
||||
float d = vp[k] - vc[k]; dist2 += d*d;
|
||||
}
|
||||
if (dist2 < best[0]) {
|
||||
best[3] = best[2]; ibest[3] = ibest[2];
|
||||
best[2] = best[1]; ibest[2] = ibest[1];
|
||||
best[1] = best[0]; ibest[1] = ibest[0];
|
||||
best[0] = dist2; ibest[0] = ic;
|
||||
}
|
||||
else if (dist2 < best[1]) {
|
||||
best[3] = best[2]; ibest[3] = ibest[2];
|
||||
best[2] = best[1]; ibest[2] = ibest[1];
|
||||
best[1] = dist2; ibest[1] = ic;
|
||||
}
|
||||
else if (dist2 < best[2]) {
|
||||
best[3] = best[2]; ibest[3] = ibest[2];
|
||||
best[2] = dist2; ibest[2] = ic;
|
||||
}
|
||||
else if (dist2 < best[3]) {
|
||||
best[3] = dist2; ibest[3] = ic;
|
||||
}
|
||||
}
|
||||
GGML_ASSERT(ibest[0] >= 0 && ibest[1] >= 0 && ibest[2] >= 0 && ibest[3] >= 0);
|
||||
p_in_cluster[ibest[0]].push_back(ip);
|
||||
p_in_cluster[ibest[1]].push_back(ip);
|
||||
p_in_cluster[ibest[2]].push_back(ip);
|
||||
p_in_cluster[ibest[3]].push_back(ip);
|
||||
std::memcpy(which_cluster.data() + 4*ip, ibest, 4*sizeof(int));
|
||||
}
|
||||
std::vector<std::pair<float, int>> extra;
|
||||
extra.reserve(kNumVal);
|
||||
for (int ic = 0; ic < ncluster; ++ic) {
|
||||
auto& points = p_in_cluster[ic];
|
||||
if (!points.empty() && points.size()%8 == 0) continue;
|
||||
extra.clear();
|
||||
auto vc = clusters.data() + ic*kGroupSize;
|
||||
for (int ip = 0; ip < kNumVal; ++ip) {
|
||||
if (which_cluster[4*ip] == ic || which_cluster[4*ip+1] == ic || which_cluster[4*ip+2] == ic || which_cluster[4*ip+3] == ic) continue;
|
||||
auto vp = codes.data() + ip*kGroupSize;
|
||||
float dist2 = 0;
|
||||
for (int k = 0; k < kGroupSize; ++k) {
|
||||
float d = vp[k] - vc[k]; dist2 += d*d;
|
||||
}
|
||||
extra.push_back(std::make_pair(dist2, ip));
|
||||
}
|
||||
std::sort(extra.begin(), extra.end());
|
||||
int nadd = 8*((points.size()+7)/8) - points.size();
|
||||
for (int i = 0; i < nadd; ++i) points.push_back(extra[i].second);
|
||||
GGML_ASSERT(points.size()%8 == 0);
|
||||
}
|
||||
auto min = p_in_cluster.front().size(), max = p_in_cluster.front().size();
|
||||
int nzero = 0;
|
||||
for (auto& points : p_in_cluster) {
|
||||
min = std::min(min, points.size());
|
||||
max = std::max(max, points.size());
|
||||
if (points.empty()) ++nzero;
|
||||
}
|
||||
printf("%s: prepared %d clusters\n", __func__, ncluster);
|
||||
printf(" min number of points in a cluster: %d\n", int(min));
|
||||
printf(" max number of points in a cluster: %d\n", int(max));
|
||||
if (nzero > 0) {
|
||||
printf(" there are %d empty clusters\n", nzero);
|
||||
for (auto& points : p_in_cluster) {
|
||||
if (!points.empty()) continue;
|
||||
points.reserve(kNumVal);
|
||||
for (int j = 0; j < kNumVal; ++j) points.push_back(j); // i.e., if we end iup picking an empty cluster, we just check all points
|
||||
}
|
||||
}
|
||||
}
|
||||
int nthread = std::max(1, int(std::thread::hardware_concurrency()/2));
|
||||
int chunk = (nrows + 8*nthread - 1)/(8*nthread);
|
||||
std::mutex mutex;
|
||||
int counter = 0;
|
||||
float mse = 0, mse_q = 0;
|
||||
auto compute = [&mutex, &counter, &mse, &mse_q, &codes, values, nrows, n_per_row, chunk] () {
|
||||
float lmse = 0, lmse_q = 0;
|
||||
auto compute = [&mutex, &counter, &mse, &mse_q, values, nrows, n_per_row, chunk] () {
|
||||
double lmse = 0, lmse_q = 0;
|
||||
std::vector<float> scales(n_per_row/kBlockSize);
|
||||
std::vector<int> best_idx(n_per_row/kGroupSize);
|
||||
std::vector<float> weight(kBlockSize, 1.f);
|
||||
int ncluster = clusters.size() / kGroupSize;
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
int first = counter; counter += chunk;
|
||||
@@ -464,17 +621,7 @@ static void analyze_x_v2(const char * name, int nrows, int n_per_row, const floa
|
||||
float sigma2 = 0;
|
||||
for (int j = 0; j < n_per_row; ++j) sigma2 += xr[j]*xr[j];
|
||||
sigma2 /= n_per_row;
|
||||
//int last_ibl = -1;
|
||||
for (int ib = 0; ib < n_per_row/kBlockSize; ++ib) {
|
||||
//int ibl = ib*kBlockSize/kSuperBlockSize;
|
||||
//if (ibl != last_ibl) {
|
||||
// auto xbl = xr + ibl*kSuperBlockSize;
|
||||
// int n = kSuperBlockSize*(ibl + 1) <= n_per_row ? kSuperBlockSize : n_per_row - ibl*kSuperBlockSize;
|
||||
// float sumx2 = 0;
|
||||
// for (int i = 0; i < n; ++i) sumx2 += xbl[i]*xbl[i];
|
||||
// sigma2 = sumx2/n;
|
||||
// last_ibl = ibl;
|
||||
//}
|
||||
auto xb = xr + kBlockSize*ib;
|
||||
for (int i = 0; i < kBlockSize; ++i) weight[i] = 0.25f*sigma2 + xb[i]*xb[i];
|
||||
float d = find_best_scale(kBlockSize, xb, weight.data(), iq4k_values, 5);
|
||||
@@ -482,15 +629,17 @@ static void analyze_x_v2(const char * name, int nrows, int n_per_row, const floa
|
||||
#ifdef __AVX2__
|
||||
auto vid = _mm256_set1_ps(id);
|
||||
for (int l = 0; l < kNg; ++l) {
|
||||
auto vx = _mm256_mul_ps(vid, _mm256_loadu_ps(xb+8*l));
|
||||
auto vw = _mm256_loadu_ps(weight.data() + 8*l);
|
||||
auto xl = xb + 8*l;
|
||||
auto wl = weight.data() + 8*l;
|
||||
auto vx = _mm256_mul_ps(vid, _mm256_loadu_ps(xl));
|
||||
auto vw = _mm256_loadu_ps(wl);
|
||||
auto vbest = _mm256_set1_ps(INFINITY);
|
||||
auto best_index = _mm256_set1_epi32(-1);
|
||||
float best = INFINITY; int jbest = -1;
|
||||
for (int j = 0; j < kNumVal; j += 8) {
|
||||
for (int j = 0; j < ncluster; j += 8) {
|
||||
auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
auto vq = _mm256_loadu_ps(codes.data() + kGroupSize*(j+i));
|
||||
auto vq = _mm256_loadu_ps(clusters.data() + kGroupSize*(j+i));
|
||||
auto vdiff = _mm256_sub_ps(vq, vx);
|
||||
sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff));
|
||||
}
|
||||
@@ -505,6 +654,72 @@ static void analyze_x_v2(const char * name, int nrows, int n_per_row, const floa
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
if (sx[i] < best) { best = sx[i]; jbest = index[i]; }
|
||||
}
|
||||
auto& points = p_in_cluster[jbest];
|
||||
if (points.empty()) {
|
||||
printf("Oops: empty cluster %d\n", jbest);
|
||||
auto vc = clusters.data() + kGroupSize*jbest;
|
||||
printf("Cluster:\n");
|
||||
for (int j = 0; j < kGroupSize; ++j) printf("%d %g %g\n", j, vc[j], xl[j]);
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
int jbest_cluster = jbest;
|
||||
vbest = _mm256_set1_ps(INFINITY);
|
||||
best_index = _mm256_set1_epi32(-1);
|
||||
best = INFINITY; jbest = -1;
|
||||
for (int j = 0; j < int(points.size()); j += 8) {
|
||||
auto idx = _mm256_loadu_si256((const __m256i*)(points.data() + j));
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
auto vq = _mm256_loadu_ps(codes.data() + kGroupSize*points[j+i]);
|
||||
auto vdiff = _mm256_sub_ps(vq, vx);
|
||||
sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff));
|
||||
}
|
||||
auto score = hsum_float_8x8(sqx);
|
||||
auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ);
|
||||
best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
|
||||
_mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
|
||||
vbest = _mm256_min_ps(vbest, score);
|
||||
}
|
||||
_mm256_store_ps(sx, vbest);
|
||||
_mm256_store_si256((__m256i *)index, best_index);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
if (sx[i] < best) { best = sx[i]; jbest = index[i]; }
|
||||
}
|
||||
//int jbest_cluster = jbest;
|
||||
//best = INFINITY; jbest = -1;
|
||||
//for (auto ip : points) {
|
||||
// auto vc = codes.data() + ip*kGroupSize;
|
||||
// float diff2 = 0;
|
||||
// for (int k = 0; k < kGroupSize; ++k) {
|
||||
// float delta = d*vc[k] - xl[k];
|
||||
// diff2 += wl[k]*delta*delta;
|
||||
// }
|
||||
// if (diff2 < best) {
|
||||
// best = diff2; jbest = ip;
|
||||
// }
|
||||
//}
|
||||
if (jbest < 0) {
|
||||
printf("Oops: jbest = %d for cluster %d with %d points\n", jbest, jbest_cluster, int(points.size()));
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
GGML_ASSERT(jbest >= 0);
|
||||
//for (int j = 0; j < kNumVal; j += 8) {
|
||||
// auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx);
|
||||
// for (int i = 0; i < 8; ++i) {
|
||||
// auto vq = _mm256_loadu_ps(codes.data() + kGroupSize*(j+i));
|
||||
// auto vdiff = _mm256_sub_ps(vq, vx);
|
||||
// sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff));
|
||||
// }
|
||||
// auto score = hsum_float_8x8(sqx);
|
||||
// auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ);
|
||||
// best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
|
||||
// _mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
|
||||
// vbest = _mm256_min_ps(vbest, score);
|
||||
//}
|
||||
//_mm256_store_ps(sx, vbest);
|
||||
//_mm256_store_si256((__m256i *)index, best_index);
|
||||
//for (int i = 0; i < 8; ++i) {
|
||||
// if (sx[i] < best) { best = sx[i]; jbest = index[i]; }
|
||||
//}
|
||||
best_idx[ib*kNg + l] = jbest;
|
||||
}
|
||||
auto vqx = _mm256_setzero_ps();
|
||||
|
||||
Reference in New Issue
Block a user