This commit is contained in:
Iwan Kawrakow
2024-11-05 13:47:38 +02:00
parent c578478911
commit afe9db7143

View File

@@ -277,11 +277,20 @@ static inline float hsum_float_4(__m128 x) {
static inline float hsum_float_8(__m256 x) { static inline float hsum_float_8(__m256 x) {
return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1))); return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));
} }
static __m256 hsum_float_8x8(__m256 * accm) {
for (int i = 0; i < 4; ++i) {
accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
_mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
}
for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
}
#endif #endif
static void analyze_x(const char * name, int nrows, int n_per_row, const float * values, float& tot_mse, float& tot_elements) { static void analyze_x(const char * name, int nrows, int n_per_row, const float * values, float& tot_mse, float& tot_elements) {
constexpr int kNumVal = 1 << 12; constexpr int kNumVal = 1 << 12;
constexpr int kBlockSize = 8; constexpr int kBlockSize = 8;
static_assert(kNumVal%8 == 0);
auto codes = make_values(kNumVal, kBlockSize); auto codes = make_values(kNumVal, kBlockSize);
std::vector<float> sumq2(kNumVal); std::vector<float> sumq2(kNumVal);
for (int j = 0; j < kNumVal; ++j) { for (int j = 0; j < kNumVal; ++j) {
@@ -307,6 +316,8 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
int last = std::min(first + chunk, nrows); int last = std::min(first + chunk, nrows);
#ifdef __AVX2__ #ifdef __AVX2__
__m256 vx[kBlockSize/8]; __m256 vx[kBlockSize/8];
__m256 sqx[8];
float sx[8];
#endif #endif
for (int row = first; row < last; ++row) { for (int row = first; row < last; ++row) {
auto xr = values + row*n_per_row; auto xr = values + row*n_per_row;
@@ -315,16 +326,19 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
auto xb = xr + kBlockSize*ib; auto xb = xr + kBlockSize*ib;
#ifdef __AVX2__ #ifdef __AVX2__
for (int l = 0; l < kBlockSize/8; ++l) vx[l] = _mm256_loadu_ps(xb+8*l); for (int l = 0; l < kBlockSize/8; ++l) vx[l] = _mm256_loadu_ps(xb+8*l);
for (int j = 0; j < kNumVal; ++j) { for (int j = 0; j < kNumVal; j += 8) {
if (!sumq2[j]) continue; for (int i = 0; i < 8; ++i) {
auto sx = _mm256_setzero_ps(); sqx[i] = _mm256_setzero_ps();
for (int l = 0; l < kBlockSize/8; ++l) { for (int l = 0; l < kBlockSize/8; ++l) {
auto qv = _mm256_loadu_ps(codes.data() + kBlockSize*j + 8*l); auto qv = _mm256_loadu_ps(codes.data() + kBlockSize*(j+i) + 8*l);
sx = _mm256_fmadd_ps(vx[l], qv, sx); sqx[i] = _mm256_fmadd_ps(vx[l], qv, sqx[i]);
}
} }
float sumqx = hsum_float_8(sx); _mm256_storeu_ps(sx, hsum_float_8x8(sqx));
if (sumqx*sumqx > best*sumq2[j]) { for (int i = 0; i < 8; ++i) {
d = sumqx/sumq2[j]; best = d*sumqx; jbest = j; if (sumq2[j+i] > 0 && sx[i]*sx[i] > best*sumq2[j+i]) {
d = sx[i]/sumq2[j+i]; best = d*sx[i]; jbest = j+i;
}
} }
} }
#else #else