mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-04 13:30:47 +00:00
SVD POC: multi-threading
This commit is contained in:
@@ -184,9 +184,129 @@ static inline float hsum_float_8(__m256 x) {
|
||||
}
|
||||
#endif
|
||||
|
||||
static void try_svd(int n_per_row, int nrows, const float * b, float * q, int nsvd, int nsvd_iter) {
|
||||
static void f_helper(int nrows, int stride, float norm, const float * g, const float * q, float * f, float& sum_f) {
|
||||
#ifdef __AVX2__
|
||||
auto vnorm = _mm256_set1_ps(norm);
|
||||
__m256 sums[8] = {};
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
__m256 vg = _mm256_set1_ps(g[row]);
|
||||
auto qr = q + row*stride;
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
auto vq = _mm256_loadu_ps(qr + 8*k);
|
||||
sums[k] = _mm256_fmadd_ps(vg, vq, sums[k]);
|
||||
}
|
||||
}
|
||||
__m256 tot = _mm256_setzero_ps();
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
sums[k] = _mm256_mul_ps(vnorm, sums[k]);
|
||||
_mm256_storeu_ps(f + 8*k, sums[k]);
|
||||
tot = _mm256_fmadd_ps(sums[k], sums[k], tot);
|
||||
sums[k] = _mm256_setzero_ps();
|
||||
}
|
||||
sum_f += hsum_float_8(tot);
|
||||
#else
|
||||
std::memset(f, 0, 64*sizeof(float));
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
auto qr = q + row*stride;
|
||||
for (int k = 0; k < 64; ++k) {
|
||||
f[k] += qr[k]*g[row];
|
||||
}
|
||||
}
|
||||
float s = 0;
|
||||
for (int k = 0; k < 64; ++k) {
|
||||
f[k] *= norm;
|
||||
s += f[k]*f[k];
|
||||
}
|
||||
sum_f += s;
|
||||
#endif
|
||||
}
|
||||
|
||||
static void g_helper(int n_per_row, const float * qr, const float * f, float norm, float& g, float& mse) {
|
||||
float sum_g = 0;
|
||||
float sum = 0;
|
||||
#ifdef __AVX2__
|
||||
__m256 vsum = _mm256_setzero_ps();
|
||||
for (int j = 0; j < n_per_row; j += 8) {
|
||||
auto vq = _mm256_loadu_ps(qr + j);
|
||||
auto vf = _mm256_loadu_ps(f + j);
|
||||
vsum = _mm256_fmadd_ps(vq, vf, vsum);
|
||||
}
|
||||
sum = hsum_float_8(vsum);
|
||||
#else
|
||||
for (int j = 0; j < n_per_row; ++j) sum += qr[j]*f[j];
|
||||
#endif
|
||||
g = sum * norm;
|
||||
#ifdef __AVX2__
|
||||
auto vg = _mm256_set1_ps(g);
|
||||
auto vmse = _mm256_setzero_ps();
|
||||
for (int j = 0; j < n_per_row; j += 8) {
|
||||
auto vq = _mm256_loadu_ps(qr + j);
|
||||
auto vf = _mm256_loadu_ps(f + j);
|
||||
auto vdiff = _mm256_sub_ps(vq, _mm256_mul_ps(vg, vf));
|
||||
vmse = _mm256_fmadd_ps(vdiff, vdiff, vmse);
|
||||
}
|
||||
mse += hsum_float_8(vmse);
|
||||
#else
|
||||
for (int j = 0; j < n_per_row; ++j) {
|
||||
float diff = qr[j] - g*f[j];
|
||||
mse += diff*diff;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
static void do_svd_iteration(int n_per_row, int nrows, const float * q, float * f, float * g, float& f_norm, float& mse,
|
||||
std::vector<std::thread>& workers, std::vector<float>& work) {
|
||||
GGML_ASSERT(n_per_row % 64 == 0);
|
||||
GGML_ASSERT(nrows % 16 == 0);
|
||||
GGML_ASSERT(!workers.empty());
|
||||
|
||||
if (work.size() < 2*workers.size()) work.resize(2*workers.size());
|
||||
int nblock = n_per_row/64;
|
||||
|
||||
auto compute_f = [&] (int ith) {
|
||||
float sum_f = 0;
|
||||
for (int i = ith; i < nblock; i += workers.size()) {
|
||||
f_helper(nrows, n_per_row, f_norm, g, q + 64*i, f + 64*i, sum_f);
|
||||
}
|
||||
work[ith] = sum_f;
|
||||
};
|
||||
for (int i = 0; i < int(workers.size())-1; ++i) workers[i] = std::thread(compute_f, i);
|
||||
compute_f(workers.size()-1);
|
||||
for (int i = 0; i < int(workers.size())-1; ++i) workers[i].join();
|
||||
|
||||
float sum_f = 0; for (int i = 0; i < int(workers.size()); ++i) sum_f += work[i];
|
||||
float g_norm = 1/sum_f;
|
||||
|
||||
nblock = nrows/16;
|
||||
auto compute_g = [&] (int ith) {
|
||||
float sum_g = 0, mse = 0;
|
||||
for (int i = ith; i < nblock; i += workers.size()) {
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
g_helper(n_per_row, q + (16*i + j)*n_per_row, f, g_norm, g[16*i + j], mse);
|
||||
sum_g += g[16*i + j]*g[16*i + j];
|
||||
}
|
||||
}
|
||||
work[2*ith+0] = sum_g;
|
||||
work[2*ith+1] = mse;
|
||||
};
|
||||
for (int i = 0; i < int(workers.size())-1; ++i) workers[i] = std::thread(compute_g, i);
|
||||
compute_g(workers.size()-1);
|
||||
for (int i = 0; i < int(workers.size())-1; ++i) workers[i].join();
|
||||
|
||||
float sum_g = 0; mse = 0;
|
||||
for (int i = 0; i < int(workers.size()); ++i) {
|
||||
sum_g += work[2*i+0];
|
||||
mse += work[2*i+1];
|
||||
}
|
||||
|
||||
f_norm = 1/sum_g;
|
||||
|
||||
}
|
||||
|
||||
static void try_svd(int n_per_row, int nrows, const float * b, float * q, int nsvd, int nsvd_iter, int verbosity = 1) {
|
||||
constexpr int kNiter = 10;
|
||||
if (nsvd_iter < 1) nsvd_iter = kNiter;
|
||||
if (nsvd > nrows) nsvd = nrows;
|
||||
auto tim1 = std::chrono::steady_clock::now();
|
||||
int nelem = n_per_row*nrows;
|
||||
double mse = 0;
|
||||
@@ -199,111 +319,121 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns
|
||||
q[j] = b[j] - q[j];
|
||||
mse += q[j]*q[j];
|
||||
}
|
||||
printf("===================== %s(%d x %d, %d, %d): rmse = %g\n", __func__, n_per_row, nrows, nsvd, use_avx2, sqrt(mse/nelem));
|
||||
int nthread = std::max(1, int(std::thread::hardware_concurrency()/2));
|
||||
std::vector<std::thread> workers(nthread);
|
||||
std::vector<float> work;
|
||||
if (verbosity > 0) printf("===================== %s(%d x %d, %d, %d): rmse = %g\n", __func__, n_per_row, nrows, nsvd, use_avx2, sqrt(mse/nelem));
|
||||
float mse_old = mse;
|
||||
std::vector<float> f(n_per_row), g(nrows, 1);
|
||||
for (int isvd = 0; isvd < nsvd; ++isvd) {
|
||||
printf("--- isvd = %d\n", isvd);
|
||||
if (verbosity > 1) printf("--- isvd = %d\n", isvd);
|
||||
float norm = 1.f/nrows;
|
||||
for (int iter = 0; iter < nsvd_iter; ++iter) {
|
||||
std::memset(f.data(), 0, f.size()*sizeof(float));
|
||||
float sum_f = 0;
|
||||
#ifdef __AVX2__
|
||||
int part = 0;
|
||||
#ifdef __AVX512F__
|
||||
{
|
||||
auto vnorm = _mm512_set1_ps(norm);
|
||||
__m512 sums[16] = {};
|
||||
for (; part < n_per_row/256; ++part) {
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
__m512 vg = _mm512_set1_ps(g[row]);
|
||||
auto qr = q + row*n_per_row + 256*part;
|
||||
for (int k = 0; k < 16; ++k) {
|
||||
auto vq = _mm512_loadu_ps(qr + 16*k);
|
||||
sums[k] = _mm512_fmadd_ps(vg, vq, sums[k]);
|
||||
}
|
||||
}
|
||||
__m512 tot = _mm512_setzero_ps();
|
||||
for (int k = 0; k < 16; ++k) {
|
||||
sums[k] = _mm512_mul_ps(vnorm, sums[k]);
|
||||
_mm512_storeu_ps(f.data() + 256*part + 16*k, sums[k]);
|
||||
tot = _mm512_fmadd_ps(sums[k], sums[k], tot);
|
||||
sums[k] = _mm512_setzero_ps();
|
||||
}
|
||||
sum_f += _mm512_reduce_add_ps(tot);
|
||||
}
|
||||
part = 4*(n_per_row/256);
|
||||
}
|
||||
#endif
|
||||
if (part < n_per_row/64) {
|
||||
auto vnorm = _mm256_set1_ps(norm);
|
||||
__m256 sums[8] = {};
|
||||
for (; part < n_per_row/64; ++part) {
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
__m256 vg = _mm256_set1_ps(g[row]);
|
||||
auto qr = q + row*n_per_row + 64*part;
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
auto vq = _mm256_loadu_ps(qr + 8*k);
|
||||
sums[k] = _mm256_fmadd_ps(vg, vq, sums[k]);
|
||||
}
|
||||
}
|
||||
__m256 tot = _mm256_setzero_ps();
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
sums[k] = _mm256_mul_ps(vnorm, sums[k]);
|
||||
_mm256_storeu_ps(f.data() + 64*part + 8*k, sums[k]);
|
||||
tot = _mm256_fmadd_ps(sums[k], sums[k], tot);
|
||||
sums[k] = _mm256_setzero_ps();
|
||||
}
|
||||
sum_f += hsum_float_8(tot);
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
auto qr = q + row*n_per_row;
|
||||
for (int j = 0; j < n_per_row; ++j) f[j] += g[row]*qr[j];
|
||||
}
|
||||
for (int j = 0; j < n_per_row; ++j) { f[j] *= norm; sum_f += f[j]*f[j]; }
|
||||
#endif
|
||||
mse = 0;
|
||||
float sum_g = 0;
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
auto qr = q + row*n_per_row;
|
||||
float sum = 0;
|
||||
#ifdef __AVX2__
|
||||
__m256 vsum = _mm256_setzero_ps();
|
||||
for (int j = 0; j < n_per_row; j += 8) {
|
||||
auto vq = _mm256_loadu_ps(qr + j);
|
||||
auto vf = _mm256_loadu_ps(f.data() + j);
|
||||
vsum = _mm256_fmadd_ps(vq, vf, vsum);
|
||||
}
|
||||
sum = hsum_float_8(vsum);
|
||||
#else
|
||||
for (int j = 0; j < n_per_row; ++j) sum += qr[j]*f[j];
|
||||
#endif
|
||||
g[row] = sum/sum_f;
|
||||
sum_g += g[row]*g[row];
|
||||
#ifdef __AVX2__
|
||||
auto vg = _mm256_set1_ps(g[row]);
|
||||
auto vmse = _mm256_setzero_ps();
|
||||
for (int j = 0; j < n_per_row; j += 8) {
|
||||
auto vq = _mm256_loadu_ps(qr + j);
|
||||
auto vf = _mm256_loadu_ps(f.data() + j);
|
||||
auto vdiff = _mm256_sub_ps(vq, _mm256_mul_ps(vg, vf));
|
||||
vmse = _mm256_fmadd_ps(vdiff, vdiff, vmse);
|
||||
}
|
||||
mse += hsum_float_8(vmse);
|
||||
#else
|
||||
for (int j = 0; j < n_per_row; ++j) {
|
||||
float diff = qr[j] - g[row]*f[j];
|
||||
mse += diff*diff;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
printf(" after %d iterations: %g\n", iter+1, sqrt(mse/nelem));
|
||||
if (mse_old/mse - 1 < 1e-6f) break;
|
||||
mse_old = mse;
|
||||
norm = 1.f/sum_g;
|
||||
float this_mse = 0;
|
||||
do_svd_iteration(n_per_row, nrows, q, f.data(), g.data(), norm, this_mse, workers, work);
|
||||
if (verbosity > 1) printf(" after %d iterations: %g\n", iter+1, sqrt(this_mse/nelem));
|
||||
if (mse_old/this_mse - 1 < 1e-6f) break;
|
||||
mse_old = this_mse;
|
||||
}
|
||||
// for (int iter = 0; iter < nsvd_iter; ++iter) {
|
||||
// std::memset(f.data(), 0, f.size()*sizeof(float));
|
||||
// float sum_f = 0;
|
||||
//#ifdef __AVX2__
|
||||
// int part = 0;
|
||||
//#ifdef __AVX512F__
|
||||
// {
|
||||
// auto vnorm = _mm512_set1_ps(norm);
|
||||
// __m512 sums[16] = {};
|
||||
// for (; part < n_per_row/256; ++part) {
|
||||
// for (int row = 0; row < nrows; ++row) {
|
||||
// __m512 vg = _mm512_set1_ps(g[row]);
|
||||
// auto qr = q + row*n_per_row + 256*part;
|
||||
// for (int k = 0; k < 16; ++k) {
|
||||
// auto vq = _mm512_loadu_ps(qr + 16*k);
|
||||
// sums[k] = _mm512_fmadd_ps(vg, vq, sums[k]);
|
||||
// }
|
||||
// }
|
||||
// __m512 tot = _mm512_setzero_ps();
|
||||
// for (int k = 0; k < 16; ++k) {
|
||||
// sums[k] = _mm512_mul_ps(vnorm, sums[k]);
|
||||
// _mm512_storeu_ps(f.data() + 256*part + 16*k, sums[k]);
|
||||
// tot = _mm512_fmadd_ps(sums[k], sums[k], tot);
|
||||
// sums[k] = _mm512_setzero_ps();
|
||||
// }
|
||||
// sum_f += _mm512_reduce_add_ps(tot);
|
||||
// }
|
||||
// part = 4*(n_per_row/256);
|
||||
// }
|
||||
//#endif
|
||||
// if (part < n_per_row/64) {
|
||||
// auto vnorm = _mm256_set1_ps(norm);
|
||||
// __m256 sums[8] = {};
|
||||
// for (; part < n_per_row/64; ++part) {
|
||||
// for (int row = 0; row < nrows; ++row) {
|
||||
// __m256 vg = _mm256_set1_ps(g[row]);
|
||||
// auto qr = q + row*n_per_row + 64*part;
|
||||
// for (int k = 0; k < 8; ++k) {
|
||||
// auto vq = _mm256_loadu_ps(qr + 8*k);
|
||||
// sums[k] = _mm256_fmadd_ps(vg, vq, sums[k]);
|
||||
// }
|
||||
// }
|
||||
// __m256 tot = _mm256_setzero_ps();
|
||||
// for (int k = 0; k < 8; ++k) {
|
||||
// sums[k] = _mm256_mul_ps(vnorm, sums[k]);
|
||||
// _mm256_storeu_ps(f.data() + 64*part + 8*k, sums[k]);
|
||||
// tot = _mm256_fmadd_ps(sums[k], sums[k], tot);
|
||||
// sums[k] = _mm256_setzero_ps();
|
||||
// }
|
||||
// sum_f += hsum_float_8(tot);
|
||||
// }
|
||||
// }
|
||||
//#else
|
||||
// for (int row = 0; row < nrows; ++row) {
|
||||
// auto qr = q + row*n_per_row;
|
||||
// for (int j = 0; j < n_per_row; ++j) f[j] += g[row]*qr[j];
|
||||
// }
|
||||
// for (int j = 0; j < n_per_row; ++j) { f[j] *= norm; sum_f += f[j]*f[j]; }
|
||||
//#endif
|
||||
// mse = 0;
|
||||
// float sum_g = 0;
|
||||
// for (int row = 0; row < nrows; ++row) {
|
||||
// auto qr = q + row*n_per_row;
|
||||
// float sum = 0;
|
||||
//#ifdef __AVX2__
|
||||
// __m256 vsum = _mm256_setzero_ps();
|
||||
// for (int j = 0; j < n_per_row; j += 8) {
|
||||
// auto vq = _mm256_loadu_ps(qr + j);
|
||||
// auto vf = _mm256_loadu_ps(f.data() + j);
|
||||
// vsum = _mm256_fmadd_ps(vq, vf, vsum);
|
||||
// }
|
||||
// sum = hsum_float_8(vsum);
|
||||
//#else
|
||||
// for (int j = 0; j < n_per_row; ++j) sum += qr[j]*f[j];
|
||||
//#endif
|
||||
// g[row] = sum/sum_f;
|
||||
// sum_g += g[row]*g[row];
|
||||
//#ifdef __AVX2__
|
||||
// auto vg = _mm256_set1_ps(g[row]);
|
||||
// auto vmse = _mm256_setzero_ps();
|
||||
// for (int j = 0; j < n_per_row; j += 8) {
|
||||
// auto vq = _mm256_loadu_ps(qr + j);
|
||||
// auto vf = _mm256_loadu_ps(f.data() + j);
|
||||
// auto vdiff = _mm256_sub_ps(vq, _mm256_mul_ps(vg, vf));
|
||||
// vmse = _mm256_fmadd_ps(vdiff, vdiff, vmse);
|
||||
// }
|
||||
// mse += hsum_float_8(vmse);
|
||||
//#else
|
||||
// for (int j = 0; j < n_per_row; ++j) {
|
||||
// float diff = qr[j] - g[row]*f[j];
|
||||
// mse += diff*diff;
|
||||
// }
|
||||
//#endif
|
||||
// }
|
||||
// printf(" after %d iterations: %g\n", iter+1, sqrt(mse/nelem));
|
||||
// if (mse_old/mse - 1 < 1e-6f) break;
|
||||
// mse_old = mse;
|
||||
// norm = 1.f/sum_g;
|
||||
// }
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
auto qr = q + row*n_per_row;
|
||||
for (int j = 0; j < n_per_row; ++j) {
|
||||
@@ -313,7 +443,7 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns
|
||||
}
|
||||
}
|
||||
auto tim2 = std::chrono::steady_clock::now();
|
||||
printf("%s: finished in %g s\n", __func__, 1e-3*std::chrono::duration_cast<std::chrono::milliseconds>(tim2-tim1).count());
|
||||
if (verbosity > 0) printf("%s: finished in %g s. Final rmse = %g\n", __func__, 1e-3*std::chrono::duration_cast<std::chrono::milliseconds>(tim2-tim1).count(), sqrt(mse_old/nelem));
|
||||
}
|
||||
|
||||
|
||||
@@ -321,7 +451,7 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns
|
||||
static void test_roundtrip_on_layer(
|
||||
std::string & name, bool print_layer_stats, const ggml_type_traits_t & qfns, bool use_reference,
|
||||
const ggml_tensor * layer, std::vector<float> & input_scratch, std::vector<char> & quantized_scratch,
|
||||
std::vector<float> & output_scratch, error_stats & total_error, int nsvd, int nsvd_iter, int max_thread = 0) {
|
||||
std::vector<float> & output_scratch, error_stats & total_error, int nsvd, int nsvd_iter, int verbosity, int max_thread = 0) {
|
||||
assert(tensor_is_contiguous(layer));
|
||||
error_stats layer_error {};
|
||||
uint64_t nelements = ggml_nelements(layer);
|
||||
@@ -374,7 +504,7 @@ static void test_roundtrip_on_layer(
|
||||
}
|
||||
|
||||
if (nsvd > 0 && layer->ne[0] > 1 && layer->ne[1] > 1 && layer->ne[2] == 1 && layer->ne[3] == 1) {
|
||||
try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd, nsvd_iter);
|
||||
try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd, nsvd_iter, verbosity);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -388,6 +518,7 @@ int main(int argc, char ** argv) {
|
||||
int max_thread = 0;
|
||||
int nsvd = 0;
|
||||
int nsvd_iter = 0;
|
||||
int verbosity = 1;
|
||||
bool invalid_param = false;
|
||||
std::string arg;
|
||||
for (int i = 1; i < argc; i++) {
|
||||
@@ -416,6 +547,12 @@ int main(int argc, char ** argv) {
|
||||
break;
|
||||
}
|
||||
nsvd_iter = atoi(argv[i]);
|
||||
} else if (arg == "-sv" || arg == "--svd-verbosity") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
verbosity = atoi(argv[i]);
|
||||
} else if (arg == "-m" || arg == "--model") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
@@ -581,6 +718,7 @@ int main(int argc, char ** argv) {
|
||||
global_stats,
|
||||
nsvd,
|
||||
nsvd_iter,
|
||||
verbosity,
|
||||
max_thread
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user