SVD POC: simdify (AVX2)

This commit is contained in:
Iwan Kawrakow
2024-08-14 18:12:09 +03:00
parent 4dfc15f92f
commit 301bcd4d21
2 changed files with 89 additions and 6 deletions

View File

@@ -19,6 +19,10 @@
#include <mutex>
#include <chrono>
#ifdef __AVX2__
#include <immintrin.h>
#endif
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
@@ -169,42 +173,107 @@ static void test_roundtrip_on_chunk(
update_error_stats(chunk_size, input_scratch, output_scratch, stats);
}
static void try_svd(int n_per_row, int nrows, const float * b, float * q, int nsvd) {
#ifdef __AVX2__
static inline float hsum_float_4(__m128 x) {
x = _mm_add_ps(x, _mm_movehl_ps(x, x));
x = _mm_add_ss(x, _mm_movehdup_ps(x));
return _mm_cvtss_f32(x);
}
static inline float hsum_float_8(__m256 x) {
return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));
}
#endif
static void try_svd(int n_per_row, int nrows, const float * b, float * q, int nsvd, int nsvd_iter) {
constexpr int kNiter = 10;
if (nsvd_iter < 1) nsvd_iter = kNiter;
auto tim1 = std::chrono::steady_clock::now();
int nelem = n_per_row*nrows;
double mse = 0;
bool use_avx2 = false;
#ifdef __AVX2__
GGML_ASSERT(n_per_row%64 == 0);
use_avx2 = true;
#endif
for (int j = 0; j < nelem; ++j) {
q[j] = b[j] - q[j];
mse += q[j]*q[j];
}
printf("===================== %s(%d x %d, %d): rmse = %g\n", __func__, n_per_row, nrows, nsvd, sqrt(mse/nelem));
printf("===================== %s(%d x %d, %d, %d): rmse = %g\n", __func__, n_per_row, nrows, nsvd, use_avx2, sqrt(mse/nelem));
float mse_old = mse;
std::vector<float> f(n_per_row), g(nrows, 1);
for (int isvd = 0; isvd < nsvd; ++isvd) {
printf("--- isvd = %d\n", isvd);
float norm = 1.f/nrows;
for (int iter = 0; iter < kNiter; ++iter) {
for (int iter = 0; iter < nsvd_iter; ++iter) {
std::memset(f.data(), 0, f.size()*sizeof(float));
float sum_f = 0;
#ifdef __AVX2__
auto vnorm = _mm256_set1_ps(norm);
__m256 sums[8] = {};
for (int part = 0; part < n_per_row/64; ++part) {
for (int row = 0; row < nrows; ++row) {
__m256 vg = _mm256_set1_ps(g[row]);
auto qr = q + row*n_per_row + 64*part;
for (int k = 0; k < 8; ++k) {
auto vq = _mm256_loadu_ps(qr + 8*k);
sums[k] = _mm256_fmadd_ps(vg, vq, sums[k]);
}
}
__m256 tot = _mm256_setzero_ps();
for (int k = 0; k < 8; ++k) {
sums[k] = _mm256_mul_ps(vnorm, sums[k]);
_mm256_storeu_ps(f.data() + 64*part + 8*k, sums[k]);
tot = _mm256_fmadd_ps(sums[k], sums[k], tot);
sums[k] = _mm256_setzero_ps();
}
sum_f += hsum_float_8(tot);
}
#else
for (int row = 0; row < nrows; ++row) {
auto qr = q + row*n_per_row;
for (int j = 0; j < n_per_row; ++j) f[j] += g[row]*qr[j];
}
float sum_f = 0;
for (int j = 0; j < n_per_row; ++j) { f[j] *= norm; sum_f += f[j]*f[j]; }
#endif
mse = 0;
float sum_g = 0;
for (int row = 0; row < nrows; ++row) {
auto qr = q + row*n_per_row;
float sum = 0;
#ifdef __AVX2__
__m256 vsum = _mm256_setzero_ps();
for (int j = 0; j < n_per_row; j += 8) {
auto vq = _mm256_loadu_ps(qr + j);
auto vf = _mm256_loadu_ps(f.data() + j);
vsum = _mm256_fmadd_ps(vq, vf, vsum);
}
sum = hsum_float_8(vsum);
#else
for (int j = 0; j < n_per_row; ++j) sum += qr[j]*f[j];
#endif
g[row] = sum/sum_f;
sum_g += g[row]*g[row];
#ifdef __AVX2__
auto vg = _mm256_set1_ps(g[row]);
auto vmse = _mm256_setzero_ps();
for (int j = 0; j < n_per_row; j += 8) {
auto vq = _mm256_loadu_ps(qr + j);
auto vf = _mm256_loadu_ps(f.data() + j);
auto vdiff = _mm256_sub_ps(vq, _mm256_mul_ps(vg, vf));
vmse = _mm256_fmadd_ps(vdiff, vdiff, vmse);
}
mse += hsum_float_8(vmse);
#else
for (int j = 0; j < n_per_row; ++j) {
float diff = qr[j] - g[row]*f[j];
mse += diff*diff;
}
#endif
}
printf(" after %d iterations: %g\n", iter+1, sqrt(mse/nelem));
if (mse_old/mse - 1 < 1e-6f) break;
mse_old = mse;
norm = 1.f/sum_g;
}
for (int row = 0; row < nrows; ++row) {
@@ -224,7 +293,7 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns
static void test_roundtrip_on_layer(
std::string & name, bool print_layer_stats, const ggml_type_traits_t & qfns, bool use_reference,
const ggml_tensor * layer, std::vector<float> & input_scratch, std::vector<char> & quantized_scratch,
std::vector<float> & output_scratch, error_stats & total_error, int nsvd, int max_thread = 0) {
std::vector<float> & output_scratch, error_stats & total_error, int nsvd, int nsvd_iter, int max_thread = 0) {
assert(tensor_is_contiguous(layer));
error_stats layer_error {};
uint64_t nelements = ggml_nelements(layer);
@@ -277,7 +346,7 @@ static void test_roundtrip_on_layer(
}
if (nsvd > 0 && layer->ne[0] > 1 && layer->ne[1] > 1 && layer->ne[2] == 1 && layer->ne[3] == 1) {
try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd);
try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd, nsvd_iter);
}
}
@@ -290,6 +359,7 @@ int main(int argc, char ** argv) {
int max_thread = 0;
int nsvd = 0;
int nsvd_iter = 0;
bool invalid_param = false;
std::string arg;
for (int i = 1; i < argc; i++) {
@@ -312,6 +382,12 @@ int main(int argc, char ** argv) {
break;
}
nsvd = atoi(argv[i]);
} else if (arg == "-ni" || arg == "--svd-iterations") {
if (++i >= argc) {
invalid_param = true;
break;
}
nsvd_iter = atoi(argv[i]);
} else if (arg == "-m" || arg == "--model") {
if (++i >= argc) {
invalid_param = true;
@@ -476,6 +552,7 @@ int main(int argc, char ** argv) {
output_scratch,
global_stats,
nsvd,
nsvd_iter,
max_thread
);
}