diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 0b459466..f1c45598 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -277,11 +277,20 @@ static inline float hsum_float_4(__m128 x) { static inline float hsum_float_8(__m256 x) { return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1))); } +static __m256 hsum_float_8x8(__m256 * accm) { /* Reduces eight 8-float accumulators together: lane i of the returned vector is the horizontal sum of accm[i]. Overwrites accm[0..3] as scratch. */ + for (int i = 0; i < 4; ++i) { /* Stage 1: fold each vector's upper 128-bit half onto its lower half; the new accm[i] holds the 4 partial sums of vector i in its low half and those of vector i+4 in its high half. NOTE(review): _mm256_set_m128 is reportedly absent from older GCC headers; confirm against the supported toolchains. */ + accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)), + _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1))); + } + for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2])); /* Stage 2: interleave-and-add pairs (0,2) and (1,3), leaving 2 partial sums per source vector. */ + return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1])); /* Stage 3: the last interleave-and-add completes each sum and orders the lanes as vectors 0..7. */ +} #endif static void analyze_x(const char * name, int nrows, int n_per_row, const float * values, float& tot_mse, float& tot_elements) { constexpr int kNumVal = 1 << 12; constexpr int kBlockSize = 8; + static_assert(kNumVal%8 == 0); /* The AVX2 search loop advances j in steps of 8 and reads entries j..j+7, so kNumVal must be a multiple of 8. NOTE(review): a message-less static_assert needs C++17; confirm the build standard. */ auto codes = make_values(kNumVal, kBlockSize); std::vector sumq2(kNumVal); for (int j = 0; j < kNumVal; ++j) { @@ -307,6 +316,8 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float * int last = std::min(first + chunk, nrows); #ifdef __AVX2__ __m256 vx[kBlockSize/8]; + __m256 sqx[8]; + float sx[8]; /* sqx: one accumulator per entry in the current group of 8; sx: their reduced dot products. */ #endif for (int row = first; row < last; ++row) { auto xr = values + row*n_per_row; @@ -315,16 +326,19 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float * auto xb = xr + kBlockSize*ib; #ifdef __AVX2__ for (int l = 0; l < kBlockSize/8; ++l) vx[l] = _mm256_loadu_ps(xb+8*l); - for (int j = 0; j < kNumVal; ++j) { - if (!sumq2[j]) continue; - auto sx = _mm256_setzero_ps(); - for (int l = 0; l < kBlockSize/8; ++l) { - auto qv = _mm256_loadu_ps(codes.data() 
+ kBlockSize*j + 8*l); - sx = _mm256_fmadd_ps(vx[l], qv, sx); + for (int j = 0; j < kNumVal; j += 8) { /* Handle 8 consecutive entries of codes per step: accumulate 8 dot products with the loaded block, then reduce them with a single hsum_float_8x8 call. */ + for (int i = 0; i < 8; ++i) { + sqx[i] = _mm256_setzero_ps(); + for (int l = 0; l < kBlockSize/8; ++l) { + auto qv = _mm256_loadu_ps(codes.data() + kBlockSize*(j+i) + 8*l); + sqx[i] = _mm256_fmadd_ps(vx[l], qv, sqx[i]); + } } - float sumqx = hsum_float_8(sx); - if (sumqx*sumqx > best*sumq2[j]) { - d = sumqx/sumq2[j]; best = d*sumqx; jbest = j; + _mm256_storeu_ps(sx, hsum_float_8x8(sqx)); /* sx[i] is now the dot product of the block with entry j+i. */ + for (int i = 0; i < 8; ++i) { + if (sumq2[j+i] > 0 && sx[i]*sx[i] > best*sumq2[j+i]) { /* Skip entries whose sumq2 is not positive (this also avoids dividing by zero); otherwise take entry j+i when sx^2/sumq2 exceeds the current best. */ + d = sx[i]/sumq2[j+i]; best = d*sx[i]; jbest = j+i; + } + } } } #else