mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-08 07:20:12 +00:00
SVD POC: simdify (AVX2)
This commit is contained in:
@@ -19,6 +19,10 @@
|
||||
#include <mutex>
|
||||
#include <chrono>
|
||||
|
||||
#ifdef __AVX2__
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
@@ -169,42 +173,107 @@ static void test_roundtrip_on_chunk(
|
||||
update_error_stats(chunk_size, input_scratch, output_scratch, stats);
|
||||
}
|
||||
|
||||
static void try_svd(int n_per_row, int nrows, const float * b, float * q, int nsvd) {
|
||||
#ifdef __AVX2__
|
||||
static inline float hsum_float_4(__m128 x) {
|
||||
x = _mm_add_ps(x, _mm_movehl_ps(x, x));
|
||||
x = _mm_add_ss(x, _mm_movehdup_ps(x));
|
||||
return _mm_cvtss_f32(x);
|
||||
}
|
||||
static inline float hsum_float_8(__m256 x) {
|
||||
return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));
|
||||
}
|
||||
#endif
|
||||
|
||||
static void try_svd(int n_per_row, int nrows, const float * b, float * q, int nsvd, int nsvd_iter) {
|
||||
constexpr int kNiter = 10;
|
||||
if (nsvd_iter < 1) nsvd_iter = kNiter;
|
||||
auto tim1 = std::chrono::steady_clock::now();
|
||||
int nelem = n_per_row*nrows;
|
||||
double mse = 0;
|
||||
bool use_avx2 = false;
|
||||
#ifdef __AVX2__
|
||||
GGML_ASSERT(n_per_row%64 == 0);
|
||||
use_avx2 = true;
|
||||
#endif
|
||||
for (int j = 0; j < nelem; ++j) {
|
||||
q[j] = b[j] - q[j];
|
||||
mse += q[j]*q[j];
|
||||
}
|
||||
printf("===================== %s(%d x %d, %d): rmse = %g\n", __func__, n_per_row, nrows, nsvd, sqrt(mse/nelem));
|
||||
printf("===================== %s(%d x %d, %d, %d): rmse = %g\n", __func__, n_per_row, nrows, nsvd, use_avx2, sqrt(mse/nelem));
|
||||
float mse_old = mse;
|
||||
std::vector<float> f(n_per_row), g(nrows, 1);
|
||||
for (int isvd = 0; isvd < nsvd; ++isvd) {
|
||||
printf("--- isvd = %d\n", isvd);
|
||||
float norm = 1.f/nrows;
|
||||
for (int iter = 0; iter < kNiter; ++iter) {
|
||||
for (int iter = 0; iter < nsvd_iter; ++iter) {
|
||||
std::memset(f.data(), 0, f.size()*sizeof(float));
|
||||
float sum_f = 0;
|
||||
#ifdef __AVX2__
|
||||
auto vnorm = _mm256_set1_ps(norm);
|
||||
__m256 sums[8] = {};
|
||||
for (int part = 0; part < n_per_row/64; ++part) {
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
__m256 vg = _mm256_set1_ps(g[row]);
|
||||
auto qr = q + row*n_per_row + 64*part;
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
auto vq = _mm256_loadu_ps(qr + 8*k);
|
||||
sums[k] = _mm256_fmadd_ps(vg, vq, sums[k]);
|
||||
}
|
||||
}
|
||||
__m256 tot = _mm256_setzero_ps();
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
sums[k] = _mm256_mul_ps(vnorm, sums[k]);
|
||||
_mm256_storeu_ps(f.data() + 64*part + 8*k, sums[k]);
|
||||
tot = _mm256_fmadd_ps(sums[k], sums[k], tot);
|
||||
sums[k] = _mm256_setzero_ps();
|
||||
}
|
||||
sum_f += hsum_float_8(tot);
|
||||
}
|
||||
#else
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
auto qr = q + row*n_per_row;
|
||||
for (int j = 0; j < n_per_row; ++j) f[j] += g[row]*qr[j];
|
||||
}
|
||||
float sum_f = 0;
|
||||
for (int j = 0; j < n_per_row; ++j) { f[j] *= norm; sum_f += f[j]*f[j]; }
|
||||
#endif
|
||||
mse = 0;
|
||||
float sum_g = 0;
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
auto qr = q + row*n_per_row;
|
||||
float sum = 0;
|
||||
#ifdef __AVX2__
|
||||
__m256 vsum = _mm256_setzero_ps();
|
||||
for (int j = 0; j < n_per_row; j += 8) {
|
||||
auto vq = _mm256_loadu_ps(qr + j);
|
||||
auto vf = _mm256_loadu_ps(f.data() + j);
|
||||
vsum = _mm256_fmadd_ps(vq, vf, vsum);
|
||||
}
|
||||
sum = hsum_float_8(vsum);
|
||||
#else
|
||||
for (int j = 0; j < n_per_row; ++j) sum += qr[j]*f[j];
|
||||
#endif
|
||||
g[row] = sum/sum_f;
|
||||
sum_g += g[row]*g[row];
|
||||
#ifdef __AVX2__
|
||||
auto vg = _mm256_set1_ps(g[row]);
|
||||
auto vmse = _mm256_setzero_ps();
|
||||
for (int j = 0; j < n_per_row; j += 8) {
|
||||
auto vq = _mm256_loadu_ps(qr + j);
|
||||
auto vf = _mm256_loadu_ps(f.data() + j);
|
||||
auto vdiff = _mm256_sub_ps(vq, _mm256_mul_ps(vg, vf));
|
||||
vmse = _mm256_fmadd_ps(vdiff, vdiff, vmse);
|
||||
}
|
||||
mse += hsum_float_8(vmse);
|
||||
#else
|
||||
for (int j = 0; j < n_per_row; ++j) {
|
||||
float diff = qr[j] - g[row]*f[j];
|
||||
mse += diff*diff;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
printf(" after %d iterations: %g\n", iter+1, sqrt(mse/nelem));
|
||||
if (mse_old/mse - 1 < 1e-6f) break;
|
||||
mse_old = mse;
|
||||
norm = 1.f/sum_g;
|
||||
}
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
@@ -224,7 +293,7 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns
|
||||
static void test_roundtrip_on_layer(
|
||||
std::string & name, bool print_layer_stats, const ggml_type_traits_t & qfns, bool use_reference,
|
||||
const ggml_tensor * layer, std::vector<float> & input_scratch, std::vector<char> & quantized_scratch,
|
||||
std::vector<float> & output_scratch, error_stats & total_error, int nsvd, int max_thread = 0) {
|
||||
std::vector<float> & output_scratch, error_stats & total_error, int nsvd, int nsvd_iter, int max_thread = 0) {
|
||||
assert(tensor_is_contiguous(layer));
|
||||
error_stats layer_error {};
|
||||
uint64_t nelements = ggml_nelements(layer);
|
||||
@@ -277,7 +346,7 @@ static void test_roundtrip_on_layer(
|
||||
}
|
||||
|
||||
if (nsvd > 0 && layer->ne[0] > 1 && layer->ne[1] > 1 && layer->ne[2] == 1 && layer->ne[3] == 1) {
|
||||
try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd);
|
||||
try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd, nsvd_iter);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -290,6 +359,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
int max_thread = 0;
|
||||
int nsvd = 0;
|
||||
int nsvd_iter = 0;
|
||||
bool invalid_param = false;
|
||||
std::string arg;
|
||||
for (int i = 1; i < argc; i++) {
|
||||
@@ -312,6 +382,12 @@ int main(int argc, char ** argv) {
|
||||
break;
|
||||
}
|
||||
nsvd = atoi(argv[i]);
|
||||
} else if (arg == "-ni" || arg == "--svd-iterations") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
nsvd_iter = atoi(argv[i]);
|
||||
} else if (arg == "-m" || arg == "--model") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
@@ -476,6 +552,7 @@ int main(int argc, char ** argv) {
|
||||
output_scratch,
|
||||
global_stats,
|
||||
nsvd,
|
||||
nsvd_iter,
|
||||
max_thread
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user