From 931e4615df80c19b30913b25c0333edab7b94d76 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 14 Aug 2024 20:26:40 +0300 Subject: [PATCH] SVD POC: multi-threading --- examples/quantize-stats/quantize-stats.cpp | 344 +++++++++++++++------ 1 file changed, 241 insertions(+), 103 deletions(-) diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 7848d9ae..86b6fc52 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -184,9 +184,129 @@ static inline float hsum_float_8(__m256 x) { } #endif -static void try_svd(int n_per_row, int nrows, const float * b, float * q, int nsvd, int nsvd_iter) { +static void f_helper(int nrows, int stride, float norm, const float * g, const float * q, float * f, float& sum_f) { +#ifdef __AVX2__ + auto vnorm = _mm256_set1_ps(norm); + __m256 sums[8] = {}; + for (int row = 0; row < nrows; ++row) { + __m256 vg = _mm256_set1_ps(g[row]); + auto qr = q + row*stride; + for (int k = 0; k < 8; ++k) { + auto vq = _mm256_loadu_ps(qr + 8*k); + sums[k] = _mm256_fmadd_ps(vg, vq, sums[k]); + } + } + __m256 tot = _mm256_setzero_ps(); + for (int k = 0; k < 8; ++k) { + sums[k] = _mm256_mul_ps(vnorm, sums[k]); + _mm256_storeu_ps(f + 8*k, sums[k]); + tot = _mm256_fmadd_ps(sums[k], sums[k], tot); + sums[k] = _mm256_setzero_ps(); + } + sum_f += hsum_float_8(tot); +#else + std::memset(f, 0, 64*sizeof(float)); + for (int row = 0; row < nrows; ++row) { + auto qr = q + row*stride; + for (int k = 0; k < 64; ++k) { + f[k] += qr[k]*g[row]; + } + } + float s = 0; + for (int k = 0; k < 64; ++k) { + f[k] *= norm; + s += f[k]*f[k]; + } + sum_f += s; +#endif +} + +static void g_helper(int n_per_row, const float * qr, const float * f, float norm, float& g, float& mse) { + float sum_g = 0; + float sum = 0; +#ifdef __AVX2__ + __m256 vsum = _mm256_setzero_ps(); + for (int j = 0; j < n_per_row; j += 8) { + auto vq = _mm256_loadu_ps(qr + j); + auto vf = _mm256_loadu_ps(f + j); + vsum = _mm256_fmadd_ps(vq, vf, vsum); + } + sum = hsum_float_8(vsum); +#else + for (int j = 0; j < n_per_row; ++j) sum += qr[j]*f[j]; +#endif + g = sum * norm; +#ifdef __AVX2__ + auto vg = _mm256_set1_ps(g); + auto vmse = _mm256_setzero_ps(); + for (int j = 0; j < n_per_row; j += 8) { + auto vq = _mm256_loadu_ps(qr + j); + auto vf = _mm256_loadu_ps(f + j); + auto vdiff = _mm256_sub_ps(vq, _mm256_mul_ps(vg, vf)); + vmse = _mm256_fmadd_ps(vdiff, vdiff, vmse); + } + mse += hsum_float_8(vmse); +#else + for (int j = 0; j < n_per_row; ++j) { + float diff = qr[j] - g*f[j]; + mse += diff*diff; + } +#endif +} + +static void do_svd_iteration(int n_per_row, int nrows, const float * q, float * f, float * g, float& f_norm, float& mse, + std::vector& workers, std::vector& work) { + GGML_ASSERT(n_per_row % 64 == 0); + GGML_ASSERT(nrows % 16 == 0); + GGML_ASSERT(!workers.empty()); + + if (work.size() < 2*workers.size()) work.resize(2*workers.size()); + int nblock = n_per_row/64; + + auto compute_f = [&] (int ith) { + float sum_f = 0; + for (int i = ith; i < nblock; i += workers.size()) { + f_helper(nrows, n_per_row, f_norm, g, q + 64*i, f + 64*i, sum_f); + } + work[ith] = sum_f; + }; + for (int i = 0; i < int(workers.size())-1; ++i) workers[i] = std::thread(compute_f, i); + compute_f(workers.size()-1); + for (int i = 0; i < int(workers.size())-1; ++i) workers[i].join(); + + float sum_f = 0; for (int i = 0; i < int(workers.size()); ++i) sum_f += work[i]; + float g_norm = 1/sum_f; + + nblock = nrows/16; + auto compute_g = [&] (int ith) { + float sum_g = 0, mse = 0; + for (int i = ith; i < nblock; i += workers.size()) { + for (int j = 0; j < 16; ++j) { + g_helper(n_per_row, q + (16*i + j)*n_per_row, f, g_norm, g[16*i + j], mse); + sum_g += g[16*i + j]*g[16*i + j]; + } + } + work[2*ith+0] = sum_g; + work[2*ith+1] = mse; + }; + for (int i = 0; i < int(workers.size())-1; ++i) workers[i] = std::thread(compute_g, i); + compute_g(workers.size()-1); + for (int i = 0; i < int(workers.size())-1; ++i) workers[i].join(); + + float sum_g = 0; mse = 0; + for (int i = 0; i < int(workers.size()); ++i) { + sum_g += work[2*i+0]; + mse += work[2*i+1]; + } + + f_norm = 1/sum_g; + +} + +static void try_svd(int n_per_row, int nrows, const float * b, float * q, int nsvd, int nsvd_iter, int verbosity = 1) { constexpr int kNiter = 10; if (nsvd_iter < 1) nsvd_iter = kNiter; + if (nsvd > nrows) nsvd = nrows; auto tim1 = std::chrono::steady_clock::now(); int nelem = n_per_row*nrows; double mse = 0; @@ -199,111 +319,121 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns q[j] = b[j] - q[j]; mse += q[j]*q[j]; } - printf("===================== %s(%d x %d, %d, %d): rmse = %g\n", __func__, n_per_row, nrows, nsvd, use_avx2, sqrt(mse/nelem)); + int nthread = std::max(1, int(std::thread::hardware_concurrency()/2)); + std::vector workers(nthread); + std::vector work; + if (verbosity > 0) printf("===================== %s(%d x %d, %d, %d): rmse = %g\n", __func__, n_per_row, nrows, nsvd, use_avx2, sqrt(mse/nelem)); float mse_old = mse; std::vector f(n_per_row), g(nrows, 1); for (int isvd = 0; isvd < nsvd; ++isvd) { - printf("--- isvd = %d\n", isvd); + if (verbosity > 1) printf("--- isvd = %d\n", isvd); float norm = 1.f/nrows; for (int iter = 0; iter < nsvd_iter; ++iter) { - std::memset(f.data(), 0, f.size()*sizeof(float)); - float sum_f = 0; -#ifdef __AVX2__ - int part = 0; -#ifdef __AVX512F__ - { - auto vnorm = _mm512_set1_ps(norm); - __m512 sums[16] = {}; - for (; part < n_per_row/256; ++part) { - for (int row = 0; row < nrows; ++row) { - __m512 vg = _mm512_set1_ps(g[row]); - auto qr = q + row*n_per_row + 256*part; - for (int k = 0; k < 16; ++k) { - auto vq = _mm512_loadu_ps(qr + 16*k); - sums[k] = _mm512_fmadd_ps(vg, vq, sums[k]); - } - } - __m512 tot = _mm512_setzero_ps(); - for (int k = 0; k < 16; ++k) { - sums[k] = _mm512_mul_ps(vnorm, sums[k]); - _mm512_storeu_ps(f.data() + 256*part + 16*k, sums[k]); - tot = _mm512_fmadd_ps(sums[k], sums[k], tot); - sums[k] = _mm512_setzero_ps(); - } - sum_f += _mm512_reduce_add_ps(tot); - } - part = 4*(n_per_row/256); - } -#endif - if (part < n_per_row/64) { - auto vnorm = _mm256_set1_ps(norm); - __m256 sums[8] = {}; - for (; part < n_per_row/64; ++part) { - for (int row = 0; row < nrows; ++row) { - __m256 vg = _mm256_set1_ps(g[row]); - auto qr = q + row*n_per_row + 64*part; - for (int k = 0; k < 8; ++k) { - auto vq = _mm256_loadu_ps(qr + 8*k); - sums[k] = _mm256_fmadd_ps(vg, vq, sums[k]); - } - } - __m256 tot = _mm256_setzero_ps(); - for (int k = 0; k < 8; ++k) { - sums[k] = _mm256_mul_ps(vnorm, sums[k]); - _mm256_storeu_ps(f.data() + 64*part + 8*k, sums[k]); - tot = _mm256_fmadd_ps(sums[k], sums[k], tot); - sums[k] = _mm256_setzero_ps(); - } - sum_f += hsum_float_8(tot); - } - } -#else - for (int row = 0; row < nrows; ++row) { - auto qr = q + row*n_per_row; - for (int j = 0; j < n_per_row; ++j) f[j] += g[row]*qr[j]; - } - for (int j = 0; j < n_per_row; ++j) { f[j] *= norm; sum_f += f[j]*f[j]; } -#endif - mse = 0; - float sum_g = 0; - for (int row = 0; row < nrows; ++row) { - auto qr = q + row*n_per_row; - float sum = 0; -#ifdef __AVX2__ - __m256 vsum = _mm256_setzero_ps(); - for (int j = 0; j < n_per_row; j += 8) { - auto vq = _mm256_loadu_ps(qr + j); - auto vf = _mm256_loadu_ps(f.data() + j); - vsum = _mm256_fmadd_ps(vq, vf, vsum); - } - sum = hsum_float_8(vsum); -#else - for (int j = 0; j < n_per_row; ++j) sum += qr[j]*f[j]; -#endif - g[row] = sum/sum_f; - sum_g += g[row]*g[row]; -#ifdef __AVX2__ - auto vg = _mm256_set1_ps(g[row]); - auto vmse = _mm256_setzero_ps(); - for (int j = 0; j < n_per_row; j += 8) { - auto vq = _mm256_loadu_ps(qr + j); - auto vf = _mm256_loadu_ps(f.data() + j); - auto vdiff = _mm256_sub_ps(vq, _mm256_mul_ps(vg, vf)); - vmse = _mm256_fmadd_ps(vdiff, vdiff, vmse); - } - mse += hsum_float_8(vmse); -#else - for (int j = 0; j < n_per_row; ++j) { - float diff = qr[j] - g[row]*f[j]; - mse += diff*diff; - } -#endif - } - printf(" after %d iterations: %g\n", iter+1, sqrt(mse/nelem)); - if (mse_old/mse - 1 < 1e-6f) break; - mse_old = mse; - norm = 1.f/sum_g; + float this_mse = 0; + do_svd_iteration(n_per_row, nrows, q, f.data(), g.data(), norm, this_mse, workers, work); + if (verbosity > 1) printf(" after %d iterations: %g\n", iter+1, sqrt(this_mse/nelem)); + if (mse_old/this_mse - 1 < 1e-6f) break; + mse_old = this_mse; } +// for (int iter = 0; iter < nsvd_iter; ++iter) { +// std::memset(f.data(), 0, f.size()*sizeof(float)); +// float sum_f = 0; +//#ifdef __AVX2__ +// int part = 0; +//#ifdef __AVX512F__ +// { +// auto vnorm = _mm512_set1_ps(norm); +// __m512 sums[16] = {}; +// for (; part < n_per_row/256; ++part) { +// for (int row = 0; row < nrows; ++row) { +// __m512 vg = _mm512_set1_ps(g[row]); +// auto qr = q + row*n_per_row + 256*part; +// for (int k = 0; k < 16; ++k) { +// auto vq = _mm512_loadu_ps(qr + 16*k); +// sums[k] = _mm512_fmadd_ps(vg, vq, sums[k]); +// } +// } +// __m512 tot = _mm512_setzero_ps(); +// for (int k = 0; k < 16; ++k) { +// sums[k] = _mm512_mul_ps(vnorm, sums[k]); +// _mm512_storeu_ps(f.data() + 256*part + 16*k, sums[k]); +// tot = _mm512_fmadd_ps(sums[k], sums[k], tot); +// sums[k] = _mm512_setzero_ps(); +// } +// sum_f += _mm512_reduce_add_ps(tot); +// } +// part = 4*(n_per_row/256); +// } +//#endif +// if (part < n_per_row/64) { +// auto vnorm = _mm256_set1_ps(norm); +// __m256 sums[8] = {}; +// for (; part < n_per_row/64; ++part) { +// for (int row = 0; row < nrows; ++row) { +// __m256 vg = _mm256_set1_ps(g[row]); +// auto qr = q + row*n_per_row + 64*part; +// for (int k = 0; k < 8; ++k) { +// auto vq = _mm256_loadu_ps(qr + 8*k); +// sums[k] = _mm256_fmadd_ps(vg, vq, sums[k]); +// } +// } +// __m256 tot = _mm256_setzero_ps(); +// for (int k = 0; k < 8; ++k) { +// sums[k] = _mm256_mul_ps(vnorm, sums[k]); +// _mm256_storeu_ps(f.data() + 64*part + 8*k, sums[k]); +// tot = _mm256_fmadd_ps(sums[k], sums[k], tot); +// sums[k] = _mm256_setzero_ps(); +// } +// sum_f += hsum_float_8(tot); +// } +// } +//#else +// for (int row = 0; row < nrows; ++row) { +// auto qr = q + row*n_per_row; +// for (int j = 0; j < n_per_row; ++j) f[j] += g[row]*qr[j]; +// } +// for (int j = 0; j < n_per_row; ++j) { f[j] *= norm; sum_f += f[j]*f[j]; } +//#endif +// mse = 0; +// float sum_g = 0; +// for (int row = 0; row < nrows; ++row) { +// auto qr = q + row*n_per_row; +// float sum = 0; +//#ifdef __AVX2__ +// __m256 vsum = _mm256_setzero_ps(); +// for (int j = 0; j < n_per_row; j += 8) { +// auto vq = _mm256_loadu_ps(qr + j); +// auto vf = _mm256_loadu_ps(f.data() + j); +// vsum = _mm256_fmadd_ps(vq, vf, vsum); +// } +// sum = hsum_float_8(vsum); +//#else +// for (int j = 0; j < n_per_row; ++j) sum += qr[j]*f[j]; +//#endif +// g[row] = sum/sum_f; +// sum_g += g[row]*g[row]; +//#ifdef __AVX2__ +// auto vg = _mm256_set1_ps(g[row]); +// auto vmse = _mm256_setzero_ps(); +// for (int j = 0; j < n_per_row; j += 8) { +// auto vq = _mm256_loadu_ps(qr + j); +// auto vf = _mm256_loadu_ps(f.data() + j); +// auto vdiff = _mm256_sub_ps(vq, _mm256_mul_ps(vg, vf)); +// vmse = _mm256_fmadd_ps(vdiff, vdiff, vmse); +// } +// mse += hsum_float_8(vmse); +//#else +// for (int j = 0; j < n_per_row; ++j) { +// float diff = qr[j] - g[row]*f[j]; +// mse += diff*diff; +// } +//#endif +// } +// printf(" after %d iterations: %g\n", iter+1, sqrt(mse/nelem)); +// if (mse_old/mse - 1 < 1e-6f) break; +// mse_old = mse; +// norm = 1.f/sum_g; +// } for (int row = 0; row < nrows; ++row) { auto qr = q + row*n_per_row; for (int j = 0; j < n_per_row; ++j) { @@ -313,7 +443,7 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns } } auto tim2 = std::chrono::steady_clock::now(); - printf("%s: finished in %g s\n", __func__, 1e-3*std::chrono::duration_cast(tim2-tim1).count()); + if (verbosity > 0) printf("%s: finished in %g s. Final rmse = %g\n", __func__, 1e-3*std::chrono::duration_cast(tim2-tim1).count(), sqrt(mse_old/nelem)); } @@ -321,7 +451,7 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns static void test_roundtrip_on_layer( std::string & name, bool print_layer_stats, const ggml_type_traits_t & qfns, bool use_reference, const ggml_tensor * layer, std::vector & input_scratch, std::vector & quantized_scratch, - std::vector & output_scratch, error_stats & total_error, int nsvd, int nsvd_iter, int max_thread = 0) { + std::vector & output_scratch, error_stats & total_error, int nsvd, int nsvd_iter, int verbosity, int max_thread = 0) { assert(tensor_is_contiguous(layer)); error_stats layer_error {}; uint64_t nelements = ggml_nelements(layer); @@ -374,7 +504,7 @@ static void test_roundtrip_on_layer( } if (nsvd > 0 && layer->ne[0] > 1 && layer->ne[1] > 1 && layer->ne[2] == 1 && layer->ne[3] == 1) { - try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd, nsvd_iter); + try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd, nsvd_iter, verbosity); } } @@ -388,6 +518,7 @@ int main(int argc, char ** argv) { int max_thread = 0; int nsvd = 0; int nsvd_iter = 0; + int verbosity = 1; bool invalid_param = false; std::string arg; for (int i = 1; i < argc; i++) { @@ -416,6 +547,12 @@ int main(int argc, char ** argv) { break; } nsvd_iter = atoi(argv[i]); + } else if (arg == "-sv" || arg == "--svd-verbosity") { + if (++i >= argc) { + invalid_param = true; + break; + } + verbosity = atoi(argv[i]); } else if (arg == "-m" || arg == "--model") { if (++i >= argc) { invalid_param = true; @@ -581,6 +718,7 @@ int main(int argc, char ** argv) { global_stats, nsvd, nsvd_iter, + verbosity, max_thread ); }