diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 7b499c3b..7848d9ae 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -209,25 +209,53 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns std::memset(f.data(), 0, f.size()*sizeof(float)); float sum_f = 0; #ifdef __AVX2__ - auto vnorm = _mm256_set1_ps(norm); - __m256 sums[8] = {}; - for (int part = 0; part < n_per_row/64; ++part) { - for (int row = 0; row < nrows; ++row) { - __m256 vg = _mm256_set1_ps(g[row]); - auto qr = q + row*n_per_row + 64*part; - for (int k = 0; k < 8; ++k) { - auto vq = _mm256_loadu_ps(qr + 8*k); - sums[k] = _mm256_fmadd_ps(vg, vq, sums[k]); + int part = 0; +#ifdef __AVX512F__ + { + auto vnorm = _mm512_set1_ps(norm); + __m512 sums[16] = {}; + for (; part < n_per_row/256; ++part) { + for (int row = 0; row < nrows; ++row) { + __m512 vg = _mm512_set1_ps(g[row]); + auto qr = q + row*n_per_row + 256*part; + for (int k = 0; k < 16; ++k) { + auto vq = _mm512_loadu_ps(qr + 16*k); + sums[k] = _mm512_fmadd_ps(vg, vq, sums[k]); + } } + __m512 tot = _mm512_setzero_ps(); + for (int k = 0; k < 16; ++k) { + sums[k] = _mm512_mul_ps(vnorm, sums[k]); + _mm512_storeu_ps(f.data() + 256*part + 16*k, sums[k]); + tot = _mm512_fmadd_ps(sums[k], sums[k], tot); + sums[k] = _mm512_setzero_ps(); + } + sum_f += _mm512_reduce_add_ps(tot); } - __m256 tot = _mm256_setzero_ps(); - for (int k = 0; k < 8; ++k) { - sums[k] = _mm256_mul_ps(vnorm, sums[k]); - _mm256_storeu_ps(f.data() + 64*part + 8*k, sums[k]); - tot = _mm256_fmadd_ps(sums[k], sums[k], tot); - sums[k] = _mm256_setzero_ps(); + part = 4*(n_per_row/256); + } +#endif + if (part < n_per_row/64) { + auto vnorm = _mm256_set1_ps(norm); + __m256 sums[8] = {}; + for (; part < n_per_row/64; ++part) { + for (int row = 0; row < nrows; ++row) { + __m256 vg = _mm256_set1_ps(g[row]); + auto qr = q + row*n_per_row + 64*part; + for (int k = 0; k < 8; ++k) { + auto vq = _mm256_loadu_ps(qr + 8*k); + sums[k] = _mm256_fmadd_ps(vg, vq, sums[k]); + } + } + __m256 tot = _mm256_setzero_ps(); + for (int k = 0; k < 8; ++k) { + sums[k] = _mm256_mul_ps(vnorm, sums[k]); + _mm256_storeu_ps(f.data() + 64*part + 8*k, sums[k]); + tot = _mm256_fmadd_ps(sums[k], sums[k], tot); + sums[k] = _mm256_setzero_ps(); + } + sum_f += hsum_float_8(tot); } - sum_f += hsum_float_8(tot); } #else for (int row = 0; row < nrows; ++row) {