mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 15:44:10 +00:00
Playing with horizontal sums - matrix times vector
This commit is contained in:
@@ -1286,6 +1286,9 @@ static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const
|
||||
|
||||
__m512i scales[2*k_nx];
|
||||
|
||||
__m256 sums[8];
|
||||
|
||||
int ks = 0;
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
|
||||
auto accd = _mm512_setzero_ps();
|
||||
@@ -1319,11 +1322,20 @@ static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
|
||||
info.store(ix, 0, _mm512_reduce_add_ps(accd));
|
||||
sums[ks++] = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
|
||||
//info.store(ix, 0, _mm512_reduce_add_ps(accd));
|
||||
} else {
|
||||
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
|
||||
info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
|
||||
sums[ks++] = _mm256_add_ps(accm, sum256);
|
||||
//info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
|
||||
}
|
||||
if (ks == 8) {
|
||||
_mm256_storeu_ps(info.dst_row(0) + ix - 7, hsum_float_8x8(sums));
|
||||
ks = 0;
|
||||
}
|
||||
}
|
||||
if (ks > 0) {
|
||||
for (int ix = 0; ix < ks; ++ix) info.store(ix, 0, hsum_float_8(sums[ix]));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user