Also apply to iq2_tn

This commit is contained in:
Iwan Kawrakow
2024-09-17 09:46:21 +03:00
parent 94cdadd559
commit 07b5d73837

View File

@@ -1198,8 +1198,8 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D
sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[3], q8q);
// The scale is supposed to be per per tensor, so we can use the same scale
auto vd = _mm512_set1_ps(d*q8.scale(iy, i));
accd[2*iy+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
accd[2*iy+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
accd[iy+ 0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[iy+ 0]);
accd[iy+nrc_y] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[iy+nrc_y]);
// Leaving this here just in case ternary models start using per row scales
//accd[2*iy+0] = _mm512_fmadd_ps(_mm512_set1_ps(deq1.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
//accd[2*iy+1] = _mm512_fmadd_ps(_mm512_set1_ps(deq2.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
@@ -1207,9 +1207,21 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix+0, iy, _mm512_reduce_add_ps(accd[2*iy+0]));
info.store(ix+1, iy, _mm512_reduce_add_ps(accd[2*iy+1]));
if constexpr (nrc_y == 8) {
__m256 sums[8];
for (int iy = 0; iy < nrc_y; ++iy) {
sums[iy] = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
}
store_8(ix+0, sums, info);
for (int iy = 0; iy < nrc_y; ++iy) {
sums[iy] = _mm256_add_ps(_mm512_castps512_ps256(accd[iy+nrc_y]), _mm512_extractf32x8_ps(accd[iy+nrc_y], 1));
}
store_8(ix+1, sums, info);
} else {
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix+0, iy, _mm512_reduce_add_ps(accd[iy+ 0]));
info.store(ix+1, iy, _mm512_reduce_add_ps(accd[iy+nrc_y]));
}
}
}