Playing with horizontal sums - matrix times vector

This commit is contained in:
Iwan Kawrakow
2024-09-17 09:14:19 +03:00
parent 9790b502e6
commit 94cdadd559

View File

@@ -1286,6 +1286,9 @@ static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const
__m512i scales[2*k_nx];
__m256 sums[8];
int ks = 0;
for (int ix = 0; ix < nrc_x; ++ix) {
auto accd = _mm512_setzero_ps();
@@ -1319,11 +1322,20 @@ static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const
}
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
info.store(ix, 0, _mm512_reduce_add_ps(accd));
sums[ks++] = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
//info.store(ix, 0, _mm512_reduce_add_ps(accd));
} else {
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
sums[ks++] = _mm256_add_ps(accm, sum256);
//info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
}
if (ks == 8) {
_mm256_storeu_ps(info.dst_row(0) + ix - 7, hsum_float_8x8(sums));
ks = 0;
}
}
if (ks > 0) {
for (int ix = 0; ix < ks; ++ix) info.store(ix, 0, hsum_float_8(sums[ix]));
}
}