This commit is contained in:
Iwan Kawrakow
2025-05-23 16:26:03 +03:00
parent f015390efa
commit 858f2a55a5

View File

@@ -549,11 +549,15 @@ static void analyze_x_v2(const char * name, int nrows, int n_per_row, const floa
std::mutex mutex; std::mutex mutex;
int counter = 0; int counter = 0;
float mse = 0, mse_q = 0; float mse = 0, mse_q = 0;
auto compute = [&mutex, &counter, &mse, &mse_q, values, nrows, n_per_row, chunk, block_size = kBlockSize] () { auto compute = [&mutex, &counter, &mse, &mse_q, values, nrows, n_per_row, chunk] () {
constexpr int kNumVal = 1 << 15;
constexpr int kBlockSize = 32;
constexpr int kGroupSize = 8;
constexpr int kNg = kBlockSize/kGroupSize;
double lmse = 0, lmse_q = 0; double lmse = 0, lmse_q = 0;
std::vector<float> scales(n_per_row/block_size); std::vector<float> scales(n_per_row/kBlockSize);
std::vector<int> best_idx(n_per_row/kGroupSize); std::vector<int> best_idx(n_per_row/kGroupSize);
std::vector<float> weight(block_size, 1.f); std::vector<float> weight(kBlockSize, 1.f);
int ncluster = clusters.size() / kGroupSize; int ncluster = clusters.size() / kGroupSize;
while (true) { while (true) {
std::unique_lock<std::mutex> lock(mutex); std::unique_lock<std::mutex> lock(mutex);
@@ -575,9 +579,10 @@ static void analyze_x_v2(const char * name, int nrows, int n_per_row, const floa
float sigma2 = 0; float sigma2 = 0;
for (int j = 0; j < n_per_row; ++j) sigma2 += xr[j]*xr[j]; for (int j = 0; j < n_per_row; ++j) sigma2 += xr[j]*xr[j];
sigma2 /= n_per_row; sigma2 /= n_per_row;
for (int ib = 0; ib < n_per_row/block_size; ++ib) { for (int ib = 0; ib < n_per_row/kBlockSize; ++ib) {
auto xb = xr + block_size*ib; auto xb = xr + kBlockSize*ib;
float d = find_best_scale(block_size, xb, weight.data(), iq4k_values, 5); //for (int i = 0; i < kBlockSize; ++i) weight[i] = 0.25f*sigma2 + xb[i]*xb[i];
float d = find_best_scale(kBlockSize, xb, weight.data(), iq4k_values, 5);
float id = d ? 1/d : 0.f; float id = d ? 1/d : 0.f;
#ifdef __AVX2__ #ifdef __AVX2__
auto vid = _mm256_set1_ps(id); auto vid = _mm256_set1_ps(id);
@@ -662,7 +667,7 @@ static void analyze_x_v2(const char * name, int nrows, int n_per_row, const floa
} }
float amax_scale = std::abs(scales[0]); float amax_scale = std::abs(scales[0]);
float max_scale = scales[0]; float max_scale = scales[0];
for (int ib = 1; ib < n_per_row/block_size; ++ib) { for (int ib = 1; ib < n_per_row/kBlockSize; ++ib) {
float ax = std::abs(scales[ib]); float ax = std::abs(scales[ib]);
if (ax > amax_scale) { if (ax > amax_scale) {
amax_scale = ax; amax_scale = ax;
@@ -671,10 +676,10 @@ static void analyze_x_v2(const char * name, int nrows, int n_per_row, const floa
} }
float d = max_scale/scale_values[0]; float d = max_scale/scale_values[0];
float id = d ? 1/d : 0.f; float id = d ? 1/d : 0.f;
for (int ib = 0; ib < n_per_row/block_size; ++ib) { for (int ib = 0; ib < n_per_row/kBlockSize; ++ib) {
int ls = best_index_scale(scale_values, id*scales[ib]); int ls = best_index_scale(scale_values, id*scales[ib]);
float dl = d * scale_values[ls]; float dl = d * scale_values[ls];
auto xb = xr + block_size*ib; auto xb = xr + kBlockSize*ib;
for (int l = 0; l < kNg; ++l) { for (int l = 0; l < kNg; ++l) {
auto q = codes.data() + kGroupSize*best_idx[ib*kNg+l]; auto q = codes.data() + kGroupSize*best_idx[ib*kNg+l];
for (int k = 0; k < kGroupSize; ++k) { for (int k = 0; k < kGroupSize; ++k) {
@@ -688,9 +693,8 @@ static void analyze_x_v2(const char * name, int nrows, int n_per_row, const floa
} }
} }
}; };
std::vector<std::thread> workers(nthread-1); std::vector<std::thread> workers(nthread);
for (auto& w : workers) w = std::thread(compute); for (auto& w : workers) w = std::thread(compute);
compute();
for (auto& w : workers) w.join(); for (auto& w : workers) w.join();
tot_mse += mse; tot_mse += mse;
tot_mse_q += mse_q; tot_mse_q += mse_q;
@@ -716,15 +720,12 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
std::mutex mutex; std::mutex mutex;
int counter = 0; int counter = 0;
float mse = 0, mse_q = 0; float mse = 0, mse_q = 0;
#ifdef __AVX2__ auto compute = [&mutex, &counter, &mse, &mse_q, &codes, &sumq2i, values, nrows, n_per_row, chunk] () {
__m256 vx[kBlockSize/8]; constexpr int kBlockSize = 8;
auto compute = [&mutex, &counter, &mse, &mse_q, &codes, &sumq2i, values, nrows, n_per_row, chunk, block_size = kBlockSize, &vx] () { constexpr int kNumVal = 1 << 12;
#else
auto compute = [&mutex, &counter, &mse, &mse_q, &codes, &sumq2i, values, nrows, n_per_row, chunk, block_size = kBlockSize] () {
#endif
float lmse = 0, lmse_q = 0; float lmse = 0, lmse_q = 0;
std::vector<float> scales(n_per_row/block_size); std::vector<float> scales(n_per_row/kBlockSize);
std::vector<int> best_idx(n_per_row/block_size); std::vector<int> best_idx(n_per_row/kBlockSize);
while (true) { while (true) {
std::unique_lock<std::mutex> lock(mutex); std::unique_lock<std::mutex> lock(mutex);
int first = counter; counter += chunk; int first = counter; counter += chunk;
@@ -735,6 +736,7 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
lock.unlock(); lock.unlock();
int last = std::min(first + chunk, nrows); int last = std::min(first + chunk, nrows);
#ifdef __AVX2__ #ifdef __AVX2__
__m256 vx[kBlockSize/8];
__m256 sqx[8]; __m256 sqx[8];
__m256i add_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); __m256i add_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
float sx[8]; float sx[8];
@@ -742,11 +744,11 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
#endif #endif
for (int row = first; row < last; ++row) { for (int row = first; row < last; ++row) {
auto xr = values + row*n_per_row; auto xr = values + row*n_per_row;
for (int ib = 0; ib < n_per_row/block_size; ++ib) { for (int ib = 0; ib < n_per_row/kBlockSize; ++ib) {
float best = 0, d = 0; int jbest = -1; float best = 0, d = 0; int jbest = -1;
auto xb = xr + block_size*ib; auto xb = xr + kBlockSize*ib;
#ifdef __AVX2__ #ifdef __AVX2__
for (int l = 0; l < block_size/8; ++l) { for (int l = 0; l < kBlockSize/8; ++l) {
vx[l] = _mm256_loadu_ps(xb+8*l); vx[l] = _mm256_loadu_ps(xb+8*l);
} }
auto vbest = _mm256_set1_ps(0.f); auto vbest = _mm256_set1_ps(0.f);
@@ -755,8 +757,8 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx); auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx);
for (int i = 0; i < 8; ++i) { for (int i = 0; i < 8; ++i) {
sqx[i] = _mm256_setzero_ps(); sqx[i] = _mm256_setzero_ps();
for (int l = 0; l < block_size/8; ++l) { for (int l = 0; l < kBlockSize/8; ++l) {
auto qv = _mm256_loadu_ps(codes.data() + block_size*(j+i) + 8*l); auto qv = _mm256_loadu_ps(codes.data() + kBlockSize*(j+i) + 8*l);
sqx[i] = _mm256_fmadd_ps(vx[l], qv, sqx[i]); sqx[i] = _mm256_fmadd_ps(vx[l], qv, sqx[i]);
} }
} }
@@ -772,32 +774,32 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
for (int j = 1; j < 8; ++j) { for (int j = 1; j < 8; ++j) {
if (sx[j] > best) { best = sx[j]; jbest = index[j]; } if (sx[j] > best) { best = sx[j]; jbest = index[j]; }
} }
auto qv = codes.data() + block_size*jbest; auto qv = codes.data() + kBlockSize*jbest;
float sumqx = 0; float sumqx = 0;
for (int k = 0; k < block_size; ++k) sumqx += xb[k]*qv[k]; for (int k = 0; k < kBlockSize; ++k) sumqx += xb[k]*qv[k];
d = sumqx*sumq2i[jbest]; d = sumqx*sumq2i[jbest];
#else #else
for (int j = 0; j < kNumVal; ++j) { for (int j = 0; j < kNumVal; ++j) {
if (!sumq2i[j]) continue; if (!sumq2i[j]) continue;
auto qv = codes.data() + block_size*j; auto qv = codes.data() + kBlockSize*j;
float sumqx = 0; float sumqx = 0;
for (int k = 0; k < block_size; ++k) sumqx += qv[k]*xb[k]; for (int k = 0; k < kBlockSize; ++k) sumqx += qv[k]*xb[k];
if (sumqx*sumqx*sumq2i[j] > best) { if (sumqx*sumqx*sumq2i[j] > best) {
d = sumqx*sumq2i[j]; best = d*sumqx; jbest = j; d = sumqx*sumq2i[j]; best = d*sumqx; jbest = j;
} }
} }
auto qv = codes.data() + block_size*jbest; auto qv = codes.data() + kBlockSize*jbest;
#endif #endif
scales[ib] = d; scales[ib] = d;
best_idx[ib] = jbest; best_idx[ib] = jbest;
for (int k = 0; k < block_size; ++k) { for (int k = 0; k < kBlockSize; ++k) {
float diff = xb[k] - d*qv[k]; float diff = xb[k] - d*qv[k];
lmse += diff*diff; lmse += diff*diff;
} }
} }
float amax_scale = std::abs(scales[0]); float amax_scale = std::abs(scales[0]);
float max_scale = scales[0]; float max_scale = scales[0];
for (int ib = 1; ib < n_per_row/block_size; ++ib) { for (int ib = 1; ib < n_per_row/kBlockSize; ++ib) {
float ax = std::abs(scales[ib]); float ax = std::abs(scales[ib]);
if (ax > amax_scale) { if (ax > amax_scale) {
amax_scale = ax; amax_scale = ax;
@@ -806,12 +808,12 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
} }
float d = max_scale/scale_values[0]; float d = max_scale/scale_values[0];
float id = d ? 1/d : 0.f; float id = d ? 1/d : 0.f;
for (int ib = 0; ib < n_per_row/block_size; ++ib) { for (int ib = 0; ib < n_per_row/kBlockSize; ++ib) {
int ls = best_index_scale(scale_values, id*scales[ib]); int ls = best_index_scale(scale_values, id*scales[ib]);
float dl = d * scale_values[ls]; float dl = d * scale_values[ls];
auto xb = xr + block_size*ib; auto xb = xr + kBlockSize*ib;
auto qv = codes.data() + block_size*best_idx[ib]; auto qv = codes.data() + kBlockSize*best_idx[ib];
for (int k = 0; k < block_size; ++k) { for (int k = 0; k < kBlockSize; ++k) {
float diff = xb[k] - dl*qv[k]; float diff = xb[k] - dl*qv[k];
lmse_q += diff*diff; lmse_q += diff*diff;
} }
@@ -819,9 +821,8 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
} }
} }
}; };
std::vector<std::thread> workers(nthread-1); std::vector<std::thread> workers(nthread);
for (auto& w : workers) w = std::thread(compute); for (auto& w : workers) w = std::thread(compute);
compute();
for (auto& w : workers) w.join(); for (auto& w : workers) w.join();
tot_mse += mse; tot_mse += mse;
tot_mse_q += mse_q; tot_mse_q += mse_q;