mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-08 07:20:12 +00:00
SVD POC: sprinkle some AVX512
This commit is contained in:
@@ -209,25 +209,53 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns
|
||||
std::memset(f.data(), 0, f.size()*sizeof(float));
|
||||
float sum_f = 0;
|
||||
#ifdef __AVX2__
|
||||
auto vnorm = _mm256_set1_ps(norm);
|
||||
__m256 sums[8] = {};
|
||||
for (int part = 0; part < n_per_row/64; ++part) {
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
__m256 vg = _mm256_set1_ps(g[row]);
|
||||
auto qr = q + row*n_per_row + 64*part;
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
auto vq = _mm256_loadu_ps(qr + 8*k);
|
||||
sums[k] = _mm256_fmadd_ps(vg, vq, sums[k]);
|
||||
int part = 0;
|
||||
#ifdef __AVX512F__
|
||||
{
|
||||
auto vnorm = _mm512_set1_ps(norm);
|
||||
__m512 sums[16] = {};
|
||||
for (; part < n_per_row/256; ++part) {
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
__m512 vg = _mm512_set1_ps(g[row]);
|
||||
auto qr = q + row*n_per_row + 256*part;
|
||||
for (int k = 0; k < 16; ++k) {
|
||||
auto vq = _mm512_loadu_ps(qr + 16*k);
|
||||
sums[k] = _mm512_fmadd_ps(vg, vq, sums[k]);
|
||||
}
|
||||
}
|
||||
__m512 tot = _mm512_setzero_ps();
|
||||
for (int k = 0; k < 16; ++k) {
|
||||
sums[k] = _mm512_mul_ps(vnorm, sums[k]);
|
||||
_mm512_storeu_ps(f.data() + 256*part + 16*k, sums[k]);
|
||||
tot = _mm512_fmadd_ps(sums[k], sums[k], tot);
|
||||
sums[k] = _mm512_setzero_ps();
|
||||
}
|
||||
sum_f += _mm512_reduce_add_ps(tot);
|
||||
}
|
||||
__m256 tot = _mm256_setzero_ps();
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
sums[k] = _mm256_mul_ps(vnorm, sums[k]);
|
||||
_mm256_storeu_ps(f.data() + 64*part + 8*k, sums[k]);
|
||||
tot = _mm256_fmadd_ps(sums[k], sums[k], tot);
|
||||
sums[k] = _mm256_setzero_ps();
|
||||
part = 4*(n_per_row/256);
|
||||
}
|
||||
#endif
|
||||
if (part < n_per_row/64) {
|
||||
auto vnorm = _mm256_set1_ps(norm);
|
||||
__m256 sums[8] = {};
|
||||
for (; part < n_per_row/64; ++part) {
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
__m256 vg = _mm256_set1_ps(g[row]);
|
||||
auto qr = q + row*n_per_row + 64*part;
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
auto vq = _mm256_loadu_ps(qr + 8*k);
|
||||
sums[k] = _mm256_fmadd_ps(vg, vq, sums[k]);
|
||||
}
|
||||
}
|
||||
__m256 tot = _mm256_setzero_ps();
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
sums[k] = _mm256_mul_ps(vnorm, sums[k]);
|
||||
_mm256_storeu_ps(f.data() + 64*part + 8*k, sums[k]);
|
||||
tot = _mm256_fmadd_ps(sums[k], sums[k], tot);
|
||||
sums[k] = _mm256_setzero_ps();
|
||||
}
|
||||
sum_f += hsum_float_8(tot);
|
||||
}
|
||||
sum_f += hsum_float_8(tot);
|
||||
}
|
||||
#else
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
|
||||
Reference in New Issue
Block a user