mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
Also apply to iq2_tn
This commit is contained in:
@@ -1198,8 +1198,8 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D
|
|||||||
sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[3], q8q);
|
sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[3], q8q);
|
||||||
// The scale is supposed to be per tensor, so we can use the same scale
|
// The scale is supposed to be per tensor, so we can use the same scale
|
||||||
auto vd = _mm512_set1_ps(d*q8.scale(iy, i));
|
auto vd = _mm512_set1_ps(d*q8.scale(iy, i));
|
||||||
accd[2*iy+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
|
accd[iy+ 0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[iy+ 0]);
|
||||||
accd[2*iy+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
|
accd[iy+nrc_y] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[iy+nrc_y]);
|
||||||
// Leaving this here just in case ternary models start using per row scales
|
// Leaving this here just in case ternary models start using per row scales
|
||||||
//accd[2*iy+0] = _mm512_fmadd_ps(_mm512_set1_ps(deq1.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
|
//accd[2*iy+0] = _mm512_fmadd_ps(_mm512_set1_ps(deq1.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
|
||||||
//accd[2*iy+1] = _mm512_fmadd_ps(_mm512_set1_ps(deq2.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
|
//accd[2*iy+1] = _mm512_fmadd_ps(_mm512_set1_ps(deq2.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
|
||||||
@@ -1207,9 +1207,21 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
if constexpr (nrc_y == 8) {
|
||||||
info.store(ix+0, iy, _mm512_reduce_add_ps(accd[2*iy+0]));
|
__m256 sums[8];
|
||||||
info.store(ix+1, iy, _mm512_reduce_add_ps(accd[2*iy+1]));
|
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||||
|
sums[iy] = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
|
||||||
|
}
|
||||||
|
store_8(ix+0, sums, info);
|
||||||
|
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||||
|
sums[iy] = _mm256_add_ps(_mm512_castps512_ps256(accd[iy+nrc_y]), _mm512_extractf32x8_ps(accd[iy+nrc_y], 1));
|
||||||
|
}
|
||||||
|
store_8(ix+1, sums, info);
|
||||||
|
} else {
|
||||||
|
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||||
|
info.store(ix+0, iy, _mm512_reduce_add_ps(accd[iy+ 0]));
|
||||||
|
info.store(ix+1, iy, _mm512_reduce_add_ps(accd[iy+nrc_y]));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user