diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index f1c45598..3166c6dc 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -292,18 +292,18 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float * constexpr int kBlockSize = 8; static_assert(kNumVal%8 == 0); auto codes = make_values(kNumVal, kBlockSize); - std::vector sumq2(kNumVal); + std::vector sumq2i(kNumVal); for (int j = 0; j < kNumVal; ++j) { auto data = codes.data() + kBlockSize*j; float sum = 0; for (int k = 0; k < kBlockSize; ++k) sum += data[k]*data[k]; - sumq2[j] = sum; + sumq2i[j] = sum > 0 ? 1/sum : 0.f;; } int nthread = std::max(1, int(std::thread::hardware_concurrency()/2)); int chunk = (nrows + 8*nthread - 1)/(8*nthread); std::mutex mutex; int counter = 0; float mse = 0; - auto compute = [&mutex, &counter, &mse, &codes, &sumq2, values, nrows, n_per_row, chunk] () { + auto compute = [&mutex, &counter, &mse, &codes, &sumq2i, values, nrows, n_per_row, chunk] () { float lmse = 0; while (true) { std::unique_lock lock(mutex); @@ -317,7 +317,9 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float * #ifdef __AVX2__ __m256 vx[kBlockSize/8]; __m256 sqx[8]; + __m256i add_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); float sx[8]; + int index[8]; #endif for (int row = first; row < last; ++row) { auto xr = values + row*n_per_row; @@ -326,7 +328,10 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float * auto xb = xr + kBlockSize*ib; #ifdef __AVX2__ for (int l = 0; l < kBlockSize/8; ++l) vx[l] = _mm256_loadu_ps(xb+8*l); + auto vbest = _mm256_set1_ps(0.f); + auto best_index = _mm256_set1_epi32(-1); for (int j = 0; j < kNumVal; j += 8) { + auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx); for (int i = 0; i < 8; ++i) { sqx[i] = _mm256_setzero_ps(); for (int l = 0; l < kBlockSize/8; ++l) { @@ -334,25 +339,40 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float * sqx[i] = _mm256_fmadd_ps(vx[l], qv, sqx[i]); } } - _mm256_storeu_ps(sx, hsum_float_8x8(sqx)); - for (int i = 0; i < 8; ++i) { - if (sumq2[j+i] > 0 && sx[i]*sx[i] > best*sumq2[j+i]) { - d = sx[i]/sumq2[j+i]; best = d*sx[i]; jbest = j+i; - } - } + auto sumqx = hsum_float_8x8(sqx); + auto score = _mm256_mul_ps(_mm256_mul_ps(sumqx, sumqx), _mm256_loadu_ps(sumq2i.data() + j)); + auto mask = _mm256_cmp_ps(score, vbest, _CMP_GT_OQ); + best_index = _mm256_or_si256(_mm256_and_si256(idx, _mm256_castps_si256(mask)), _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + vbest = _mm256_max_ps(vbest, score); + //_mm256_storeu_ps(sx, hsum_float_8x8(sqx)); + //for (int i = 0; i < 8; ++i) { + // if (sx[i]*sx[i]*sumq2i[j+i] > best) { + // d = sx[i]*sumq2i[j+i]; best = d*sx[i]; jbest = j+i; + // } + //} } + _mm256_store_ps(sx, vbest); + _mm256_store_si256((__m256i *)index, best_index); + best = sx[0]; jbest = index[0]; + for (int j = 1; j < 8; ++j) { + if (sx[j] > best) { best = sx[j]; jbest = index[j]; } + } + auto qv = codes.data() + kBlockSize*jbest; + float sumqx = 0; + for (int k = 0; k < 8; ++k) sumqx += xb[k]*qv[k]; + d = sumqx*sumq2i[jbest]; #else for (int j = 0; j < kNumVal; ++j) { - if (!sumq2[j]) continue; + if (!sumq2i[j]) continue; auto qv = codes.data() + kBlockSize*j; float sumqx = 0; for (int k = 0; k < kBlockSize; ++k) sumqx += qv[k]*xb[k]; - if (sumqx*sumqx > best*sumq2[j]) { - d = sumqx/sumq2[j]; best = d*sumqx; jbest = j; + if (sumqx*sumqx*sumq2i[j] > best]) { + d = sumqx*sumq2i[j]; best = d*sumqx; jbest = j; } } -#endif auto qv = codes.data() + kBlockSize*jbest; +#endif for (int k = 0; k < kBlockSize; ++k) { float diff = xb[k] - d*qv[k]; lmse += diff*diff;