This commit is contained in:
Iwan Kawrakow
2024-11-05 13:47:38 +02:00
parent c578478911
commit afe9db7143

View File

@@ -277,11 +277,20 @@ static inline float hsum_float_4(__m128 x) {
static inline float hsum_float_8(__m256 x) {
return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));
}
static __m256 hsum_float_8x8(__m256 * accm) {
for (int i = 0; i < 4; ++i) {
accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
_mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
}
for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
}
#endif
static void analyze_x(const char * name, int nrows, int n_per_row, const float * values, float& tot_mse, float& tot_elements) {
constexpr int kNumVal = 1 << 12;
constexpr int kBlockSize = 8;
static_assert(kNumVal%8 == 0);
auto codes = make_values(kNumVal, kBlockSize);
std::vector<float> sumq2(kNumVal);
for (int j = 0; j < kNumVal; ++j) {
@@ -307,6 +316,8 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
int last = std::min(first + chunk, nrows);
#ifdef __AVX2__
__m256 vx[kBlockSize/8];
__m256 sqx[8];
float sx[8];
#endif
for (int row = first; row < last; ++row) {
auto xr = values + row*n_per_row;
@@ -315,16 +326,19 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
auto xb = xr + kBlockSize*ib;
#ifdef __AVX2__
for (int l = 0; l < kBlockSize/8; ++l) vx[l] = _mm256_loadu_ps(xb+8*l);
for (int j = 0; j < kNumVal; ++j) {
if (!sumq2[j]) continue;
auto sx = _mm256_setzero_ps();
for (int l = 0; l < kBlockSize/8; ++l) {
auto qv = _mm256_loadu_ps(codes.data() + kBlockSize*j + 8*l);
sx = _mm256_fmadd_ps(vx[l], qv, sx);
for (int j = 0; j < kNumVal; j += 8) {
for (int i = 0; i < 8; ++i) {
sqx[i] = _mm256_setzero_ps();
for (int l = 0; l < kBlockSize/8; ++l) {
auto qv = _mm256_loadu_ps(codes.data() + kBlockSize*(j+i) + 8*l);
sqx[i] = _mm256_fmadd_ps(vx[l], qv, sqx[i]);
}
}
float sumqx = hsum_float_8(sx);
if (sumqx*sumqx > best*sumq2[j]) {
d = sumqx/sumq2[j]; best = d*sumqx; jbest = j;
_mm256_storeu_ps(sx, hsum_float_8x8(sqx));
for (int i = 0; i < 8; ++i) {
if (sumq2[j+i] > 0 && sx[i]*sx[i] > best*sumq2[j+i]) {
d = sx[i]/sumq2[j+i]; best = d*sx[i]; jbest = j+i;
}
}
}
#else