SVD POC: sprinkle some AVX512

This commit is contained in:
Iwan Kawrakow
2024-08-14 18:29:31 +03:00
parent 301bcd4d21
commit 903d389e0f

View File

@@ -209,25 +209,53 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns
std::memset(f.data(), 0, f.size()*sizeof(float));
float sum_f = 0;
#ifdef __AVX2__
auto vnorm = _mm256_set1_ps(norm);
__m256 sums[8] = {};
for (int part = 0; part < n_per_row/64; ++part) {
for (int row = 0; row < nrows; ++row) {
__m256 vg = _mm256_set1_ps(g[row]);
auto qr = q + row*n_per_row + 64*part;
for (int k = 0; k < 8; ++k) {
auto vq = _mm256_loadu_ps(qr + 8*k);
sums[k] = _mm256_fmadd_ps(vg, vq, sums[k]);
int part = 0;
#ifdef __AVX512F__
{
auto vnorm = _mm512_set1_ps(norm);
__m512 sums[16] = {};
for (; part < n_per_row/256; ++part) {
for (int row = 0; row < nrows; ++row) {
__m512 vg = _mm512_set1_ps(g[row]);
auto qr = q + row*n_per_row + 256*part;
for (int k = 0; k < 16; ++k) {
auto vq = _mm512_loadu_ps(qr + 16*k);
sums[k] = _mm512_fmadd_ps(vg, vq, sums[k]);
}
}
__m512 tot = _mm512_setzero_ps();
for (int k = 0; k < 16; ++k) {
sums[k] = _mm512_mul_ps(vnorm, sums[k]);
_mm512_storeu_ps(f.data() + 256*part + 16*k, sums[k]);
tot = _mm512_fmadd_ps(sums[k], sums[k], tot);
sums[k] = _mm512_setzero_ps();
}
sum_f += _mm512_reduce_add_ps(tot);
}
__m256 tot = _mm256_setzero_ps();
for (int k = 0; k < 8; ++k) {
sums[k] = _mm256_mul_ps(vnorm, sums[k]);
_mm256_storeu_ps(f.data() + 64*part + 8*k, sums[k]);
tot = _mm256_fmadd_ps(sums[k], sums[k], tot);
sums[k] = _mm256_setzero_ps();
part = 4*(n_per_row/256);
}
#endif
if (part < n_per_row/64) {
auto vnorm = _mm256_set1_ps(norm);
__m256 sums[8] = {};
for (; part < n_per_row/64; ++part) {
for (int row = 0; row < nrows; ++row) {
__m256 vg = _mm256_set1_ps(g[row]);
auto qr = q + row*n_per_row + 64*part;
for (int k = 0; k < 8; ++k) {
auto vq = _mm256_loadu_ps(qr + 8*k);
sums[k] = _mm256_fmadd_ps(vg, vq, sums[k]);
}
}
__m256 tot = _mm256_setzero_ps();
for (int k = 0; k < 8; ++k) {
sums[k] = _mm256_mul_ps(vnorm, sums[k]);
_mm256_storeu_ps(f.data() + 64*part + 8*k, sums[k]);
tot = _mm256_fmadd_ps(sums[k], sums[k], tot);
sums[k] = _mm256_setzero_ps();
}
sum_f += hsum_float_8(tot);
}
sum_f += hsum_float_8(tot);
}
#else
for (int row = 0; row < nrows; ++row) {