SVD POC: multi-threading

This commit is contained in:
Iwan Kawrakow
2024-08-14 20:26:40 +03:00
parent 903d389e0f
commit 931e4615df

View File

@@ -184,9 +184,129 @@ static inline float hsum_float_8(__m256 x) {
}
#endif
static void try_svd(int n_per_row, int nrows, const float * b, float * q, int nsvd, int nsvd_iter) {
static void f_helper(int nrows, int stride, float norm, const float * g, const float * q, float * f, float& sum_f) {
#ifdef __AVX2__
auto vnorm = _mm256_set1_ps(norm);
__m256 sums[8] = {};
for (int row = 0; row < nrows; ++row) {
__m256 vg = _mm256_set1_ps(g[row]);
auto qr = q + row*stride;
for (int k = 0; k < 8; ++k) {
auto vq = _mm256_loadu_ps(qr + 8*k);
sums[k] = _mm256_fmadd_ps(vg, vq, sums[k]);
}
}
__m256 tot = _mm256_setzero_ps();
for (int k = 0; k < 8; ++k) {
sums[k] = _mm256_mul_ps(vnorm, sums[k]);
_mm256_storeu_ps(f + 8*k, sums[k]);
tot = _mm256_fmadd_ps(sums[k], sums[k], tot);
sums[k] = _mm256_setzero_ps();
}
sum_f += hsum_float_8(tot);
#else
std::memset(f, 0, 64*sizeof(float));
for (int row = 0; row < nrows; ++row) {
auto qr = q + row*stride;
for (int k = 0; k < 64; ++k) {
f[k] += qr[k]*g[row];
}
}
float s = 0;
for (int k = 0; k < 64; ++k) {
f[k] *= norm;
s += f[k]*f[k];
}
sum_f += s;
#endif
}
static void g_helper(int n_per_row, const float * qr, const float * f, float norm, float& g, float& mse) {
float sum_g = 0;
float sum = 0;
#ifdef __AVX2__
__m256 vsum = _mm256_setzero_ps();
for (int j = 0; j < n_per_row; j += 8) {
auto vq = _mm256_loadu_ps(qr + j);
auto vf = _mm256_loadu_ps(f + j);
vsum = _mm256_fmadd_ps(vq, vf, vsum);
}
sum = hsum_float_8(vsum);
#else
for (int j = 0; j < n_per_row; ++j) sum += qr[j]*f[j];
#endif
g = sum * norm;
#ifdef __AVX2__
auto vg = _mm256_set1_ps(g);
auto vmse = _mm256_setzero_ps();
for (int j = 0; j < n_per_row; j += 8) {
auto vq = _mm256_loadu_ps(qr + j);
auto vf = _mm256_loadu_ps(f + j);
auto vdiff = _mm256_sub_ps(vq, _mm256_mul_ps(vg, vf));
vmse = _mm256_fmadd_ps(vdiff, vdiff, vmse);
}
mse += hsum_float_8(vmse);
#else
for (int j = 0; j < n_per_row; ++j) {
float diff = qr[j] - g*f[j];
mse += diff*diff;
}
#endif
}
static void do_svd_iteration(int n_per_row, int nrows, const float * q, float * f, float * g, float& f_norm, float& mse,
std::vector<std::thread>& workers, std::vector<float>& work) {
GGML_ASSERT(n_per_row % 64 == 0);
GGML_ASSERT(nrows % 16 == 0);
GGML_ASSERT(!workers.empty());
if (work.size() < 2*workers.size()) work.resize(2*workers.size());
int nblock = n_per_row/64;
auto compute_f = [&] (int ith) {
float sum_f = 0;
for (int i = ith; i < nblock; i += workers.size()) {
f_helper(nrows, n_per_row, f_norm, g, q + 64*i, f + 64*i, sum_f);
}
work[ith] = sum_f;
};
for (int i = 0; i < int(workers.size())-1; ++i) workers[i] = std::thread(compute_f, i);
compute_f(workers.size()-1);
for (int i = 0; i < int(workers.size())-1; ++i) workers[i].join();
float sum_f = 0; for (int i = 0; i < int(workers.size()); ++i) sum_f += work[i];
float g_norm = 1/sum_f;
nblock = nrows/16;
auto compute_g = [&] (int ith) {
float sum_g = 0, mse = 0;
for (int i = ith; i < nblock; i += workers.size()) {
for (int j = 0; j < 16; ++j) {
g_helper(n_per_row, q + (16*i + j)*n_per_row, f, g_norm, g[16*i + j], mse);
sum_g += g[16*i + j]*g[16*i + j];
}
}
work[2*ith+0] = sum_g;
work[2*ith+1] = mse;
};
for (int i = 0; i < int(workers.size())-1; ++i) workers[i] = std::thread(compute_g, i);
compute_g(workers.size()-1);
for (int i = 0; i < int(workers.size())-1; ++i) workers[i].join();
float sum_g = 0; mse = 0;
for (int i = 0; i < int(workers.size()); ++i) {
sum_g += work[2*i+0];
mse += work[2*i+1];
}
f_norm = 1/sum_g;
}
static void try_svd(int n_per_row, int nrows, const float * b, float * q, int nsvd, int nsvd_iter, int verbosity = 1) {
constexpr int kNiter = 10;
if (nsvd_iter < 1) nsvd_iter = kNiter;
if (nsvd > nrows) nsvd = nrows;
auto tim1 = std::chrono::steady_clock::now();
int nelem = n_per_row*nrows;
double mse = 0;
@@ -199,111 +319,121 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns
q[j] = b[j] - q[j];
mse += q[j]*q[j];
}
printf("===================== %s(%d x %d, %d, %d): rmse = %g\n", __func__, n_per_row, nrows, nsvd, use_avx2, sqrt(mse/nelem));
int nthread = std::max(1, int(std::thread::hardware_concurrency()/2));
std::vector<std::thread> workers(nthread);
std::vector<float> work;
if (verbosity > 0) printf("===================== %s(%d x %d, %d, %d): rmse = %g\n", __func__, n_per_row, nrows, nsvd, use_avx2, sqrt(mse/nelem));
float mse_old = mse;
std::vector<float> f(n_per_row), g(nrows, 1);
for (int isvd = 0; isvd < nsvd; ++isvd) {
printf("--- isvd = %d\n", isvd);
if (verbosity > 1) printf("--- isvd = %d\n", isvd);
float norm = 1.f/nrows;
for (int iter = 0; iter < nsvd_iter; ++iter) {
std::memset(f.data(), 0, f.size()*sizeof(float));
float sum_f = 0;
#ifdef __AVX2__
int part = 0;
#ifdef __AVX512F__
{
auto vnorm = _mm512_set1_ps(norm);
__m512 sums[16] = {};
for (; part < n_per_row/256; ++part) {
for (int row = 0; row < nrows; ++row) {
__m512 vg = _mm512_set1_ps(g[row]);
auto qr = q + row*n_per_row + 256*part;
for (int k = 0; k < 16; ++k) {
auto vq = _mm512_loadu_ps(qr + 16*k);
sums[k] = _mm512_fmadd_ps(vg, vq, sums[k]);
}
}
__m512 tot = _mm512_setzero_ps();
for (int k = 0; k < 16; ++k) {
sums[k] = _mm512_mul_ps(vnorm, sums[k]);
_mm512_storeu_ps(f.data() + 256*part + 16*k, sums[k]);
tot = _mm512_fmadd_ps(sums[k], sums[k], tot);
sums[k] = _mm512_setzero_ps();
}
sum_f += _mm512_reduce_add_ps(tot);
}
part = 4*(n_per_row/256);
}
#endif
if (part < n_per_row/64) {
auto vnorm = _mm256_set1_ps(norm);
__m256 sums[8] = {};
for (; part < n_per_row/64; ++part) {
for (int row = 0; row < nrows; ++row) {
__m256 vg = _mm256_set1_ps(g[row]);
auto qr = q + row*n_per_row + 64*part;
for (int k = 0; k < 8; ++k) {
auto vq = _mm256_loadu_ps(qr + 8*k);
sums[k] = _mm256_fmadd_ps(vg, vq, sums[k]);
}
}
__m256 tot = _mm256_setzero_ps();
for (int k = 0; k < 8; ++k) {
sums[k] = _mm256_mul_ps(vnorm, sums[k]);
_mm256_storeu_ps(f.data() + 64*part + 8*k, sums[k]);
tot = _mm256_fmadd_ps(sums[k], sums[k], tot);
sums[k] = _mm256_setzero_ps();
}
sum_f += hsum_float_8(tot);
}
}
#else
for (int row = 0; row < nrows; ++row) {
auto qr = q + row*n_per_row;
for (int j = 0; j < n_per_row; ++j) f[j] += g[row]*qr[j];
}
for (int j = 0; j < n_per_row; ++j) { f[j] *= norm; sum_f += f[j]*f[j]; }
#endif
mse = 0;
float sum_g = 0;
for (int row = 0; row < nrows; ++row) {
auto qr = q + row*n_per_row;
float sum = 0;
#ifdef __AVX2__
__m256 vsum = _mm256_setzero_ps();
for (int j = 0; j < n_per_row; j += 8) {
auto vq = _mm256_loadu_ps(qr + j);
auto vf = _mm256_loadu_ps(f.data() + j);
vsum = _mm256_fmadd_ps(vq, vf, vsum);
}
sum = hsum_float_8(vsum);
#else
for (int j = 0; j < n_per_row; ++j) sum += qr[j]*f[j];
#endif
g[row] = sum/sum_f;
sum_g += g[row]*g[row];
#ifdef __AVX2__
auto vg = _mm256_set1_ps(g[row]);
auto vmse = _mm256_setzero_ps();
for (int j = 0; j < n_per_row; j += 8) {
auto vq = _mm256_loadu_ps(qr + j);
auto vf = _mm256_loadu_ps(f.data() + j);
auto vdiff = _mm256_sub_ps(vq, _mm256_mul_ps(vg, vf));
vmse = _mm256_fmadd_ps(vdiff, vdiff, vmse);
}
mse += hsum_float_8(vmse);
#else
for (int j = 0; j < n_per_row; ++j) {
float diff = qr[j] - g[row]*f[j];
mse += diff*diff;
}
#endif
}
printf(" after %d iterations: %g\n", iter+1, sqrt(mse/nelem));
if (mse_old/mse - 1 < 1e-6f) break;
mse_old = mse;
norm = 1.f/sum_g;
float this_mse = 0;
do_svd_iteration(n_per_row, nrows, q, f.data(), g.data(), norm, this_mse, workers, work);
if (verbosity > 1) printf(" after %d iterations: %g\n", iter+1, sqrt(this_mse/nelem));
if (mse_old/this_mse - 1 < 1e-6f) break;
mse_old = this_mse;
}
// for (int iter = 0; iter < nsvd_iter; ++iter) {
// std::memset(f.data(), 0, f.size()*sizeof(float));
// float sum_f = 0;
//#ifdef __AVX2__
// int part = 0;
//#ifdef __AVX512F__
// {
// auto vnorm = _mm512_set1_ps(norm);
// __m512 sums[16] = {};
// for (; part < n_per_row/256; ++part) {
// for (int row = 0; row < nrows; ++row) {
// __m512 vg = _mm512_set1_ps(g[row]);
// auto qr = q + row*n_per_row + 256*part;
// for (int k = 0; k < 16; ++k) {
// auto vq = _mm512_loadu_ps(qr + 16*k);
// sums[k] = _mm512_fmadd_ps(vg, vq, sums[k]);
// }
// }
// __m512 tot = _mm512_setzero_ps();
// for (int k = 0; k < 16; ++k) {
// sums[k] = _mm512_mul_ps(vnorm, sums[k]);
// _mm512_storeu_ps(f.data() + 256*part + 16*k, sums[k]);
// tot = _mm512_fmadd_ps(sums[k], sums[k], tot);
// sums[k] = _mm512_setzero_ps();
// }
// sum_f += _mm512_reduce_add_ps(tot);
// }
// part = 4*(n_per_row/256);
// }
//#endif
// if (part < n_per_row/64) {
// auto vnorm = _mm256_set1_ps(norm);
// __m256 sums[8] = {};
// for (; part < n_per_row/64; ++part) {
// for (int row = 0; row < nrows; ++row) {
// __m256 vg = _mm256_set1_ps(g[row]);
// auto qr = q + row*n_per_row + 64*part;
// for (int k = 0; k < 8; ++k) {
// auto vq = _mm256_loadu_ps(qr + 8*k);
// sums[k] = _mm256_fmadd_ps(vg, vq, sums[k]);
// }
// }
// __m256 tot = _mm256_setzero_ps();
// for (int k = 0; k < 8; ++k) {
// sums[k] = _mm256_mul_ps(vnorm, sums[k]);
// _mm256_storeu_ps(f.data() + 64*part + 8*k, sums[k]);
// tot = _mm256_fmadd_ps(sums[k], sums[k], tot);
// sums[k] = _mm256_setzero_ps();
// }
// sum_f += hsum_float_8(tot);
// }
// }
//#else
// for (int row = 0; row < nrows; ++row) {
// auto qr = q + row*n_per_row;
// for (int j = 0; j < n_per_row; ++j) f[j] += g[row]*qr[j];
// }
// for (int j = 0; j < n_per_row; ++j) { f[j] *= norm; sum_f += f[j]*f[j]; }
//#endif
// mse = 0;
// float sum_g = 0;
// for (int row = 0; row < nrows; ++row) {
// auto qr = q + row*n_per_row;
// float sum = 0;
//#ifdef __AVX2__
// __m256 vsum = _mm256_setzero_ps();
// for (int j = 0; j < n_per_row; j += 8) {
// auto vq = _mm256_loadu_ps(qr + j);
// auto vf = _mm256_loadu_ps(f.data() + j);
// vsum = _mm256_fmadd_ps(vq, vf, vsum);
// }
// sum = hsum_float_8(vsum);
//#else
// for (int j = 0; j < n_per_row; ++j) sum += qr[j]*f[j];
//#endif
// g[row] = sum/sum_f;
// sum_g += g[row]*g[row];
//#ifdef __AVX2__
// auto vg = _mm256_set1_ps(g[row]);
// auto vmse = _mm256_setzero_ps();
// for (int j = 0; j < n_per_row; j += 8) {
// auto vq = _mm256_loadu_ps(qr + j);
// auto vf = _mm256_loadu_ps(f.data() + j);
// auto vdiff = _mm256_sub_ps(vq, _mm256_mul_ps(vg, vf));
// vmse = _mm256_fmadd_ps(vdiff, vdiff, vmse);
// }
// mse += hsum_float_8(vmse);
//#else
// for (int j = 0; j < n_per_row; ++j) {
// float diff = qr[j] - g[row]*f[j];
// mse += diff*diff;
// }
//#endif
// }
// printf(" after %d iterations: %g\n", iter+1, sqrt(mse/nelem));
// if (mse_old/mse - 1 < 1e-6f) break;
// mse_old = mse;
// norm = 1.f/sum_g;
// }
for (int row = 0; row < nrows; ++row) {
auto qr = q + row*n_per_row;
for (int j = 0; j < n_per_row; ++j) {
@@ -313,7 +443,7 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns
}
}
auto tim2 = std::chrono::steady_clock::now();
printf("%s: finished in %g s\n", __func__, 1e-3*std::chrono::duration_cast<std::chrono::milliseconds>(tim2-tim1).count());
if (verbosity > 0) printf("%s: finished in %g s. Final rmse = %g\n", __func__, 1e-3*std::chrono::duration_cast<std::chrono::milliseconds>(tim2-tim1).count(), sqrt(mse_old/nelem));
}
@@ -321,7 +451,7 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns
static void test_roundtrip_on_layer(
std::string & name, bool print_layer_stats, const ggml_type_traits_t & qfns, bool use_reference,
const ggml_tensor * layer, std::vector<float> & input_scratch, std::vector<char> & quantized_scratch,
std::vector<float> & output_scratch, error_stats & total_error, int nsvd, int nsvd_iter, int max_thread = 0) {
std::vector<float> & output_scratch, error_stats & total_error, int nsvd, int nsvd_iter, int verbosity, int max_thread = 0) {
assert(tensor_is_contiguous(layer));
error_stats layer_error {};
uint64_t nelements = ggml_nelements(layer);
@@ -374,7 +504,7 @@ static void test_roundtrip_on_layer(
}
if (nsvd > 0 && layer->ne[0] > 1 && layer->ne[1] > 1 && layer->ne[2] == 1 && layer->ne[3] == 1) {
try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd, nsvd_iter);
try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd, nsvd_iter, verbosity);
}
}
@@ -388,6 +518,7 @@ int main(int argc, char ** argv) {
int max_thread = 0;
int nsvd = 0;
int nsvd_iter = 0;
int verbosity = 1;
bool invalid_param = false;
std::string arg;
for (int i = 1; i < argc; i++) {
@@ -416,6 +547,12 @@ int main(int argc, char ** argv) {
break;
}
nsvd_iter = atoi(argv[i]);
} else if (arg == "-sv" || arg == "--svd-verbosity") {
if (++i >= argc) {
invalid_param = true;
break;
}
verbosity = atoi(argv[i]);
} else if (arg == "-m" || arg == "--model") {
if (++i >= argc) {
invalid_param = true;
@@ -581,6 +718,7 @@ int main(int argc, char ** argv) {
global_stats,
nsvd,
nsvd_iter,
verbosity,
max_thread
);
}