mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-28 17:14:17 +00:00
Also apply to iq2_tn
This commit is contained in:
@@ -1198,8 +1198,8 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D
|
||||
sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[3], q8q);
|
||||
// The scale is supposed to be per per tensor, so we can use the same scale
|
||||
auto vd = _mm512_set1_ps(d*q8.scale(iy, i));
|
||||
accd[2*iy+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
|
||||
accd[2*iy+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
|
||||
accd[iy+ 0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[iy+ 0]);
|
||||
accd[iy+nrc_y] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[iy+nrc_y]);
|
||||
// Leaving this here just in case ternary models start using per row scales
|
||||
//accd[2*iy+0] = _mm512_fmadd_ps(_mm512_set1_ps(deq1.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
|
||||
//accd[2*iy+1] = _mm512_fmadd_ps(_mm512_set1_ps(deq2.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
|
||||
@@ -1207,9 +1207,21 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D
|
||||
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix+0, iy, _mm512_reduce_add_ps(accd[2*iy+0]));
|
||||
info.store(ix+1, iy, _mm512_reduce_add_ps(accd[2*iy+1]));
|
||||
if constexpr (nrc_y == 8) {
|
||||
__m256 sums[8];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
sums[iy] = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
|
||||
}
|
||||
store_8(ix+0, sums, info);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
sums[iy] = _mm256_add_ps(_mm512_castps512_ps256(accd[iy+nrc_y]), _mm512_extractf32x8_ps(accd[iy+nrc_y], 1));
|
||||
}
|
||||
store_8(ix+1, sums, info);
|
||||
} else {
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix+0, iy, _mm512_reduce_add_ps(accd[iy+ 0]));
|
||||
info.store(ix+1, iy, _mm512_reduce_add_ps(accd[iy+nrc_y]));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user