iq2_tn: AVX2 PP improvement

We now get PP-512 = 490.73 t/s for TriLM-3.9B on the Ryzen-5975WX.
We have PP-512 = 636.61 t/s for Bintnet-3B quantized with iq2_bn.
Bintnet-3B is actually 3.4B, TriLM-3.9B is 3.99B, so we would
expect 3.43/3.99 * 636 = 546 t/s, so it seems we still have something
that is not quite optimal in iq2_tn.
This commit is contained in:
Iwan Kawrakow
2024-08-06 12:34:44 +03:00
parent 2cc6338670
commit 9780ac4591

View File

@@ -1510,18 +1510,13 @@ IQK_NOINLINE void mul_mat_iq2tn_q8_K(int n, const void * vx, size_t bx, const Da
deq.new_row(ix);
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
__m256i sumi[nrc_y];
//deq.new_block(i, q8, sumi);
deq.new_block(i);
deq.prepare(i, 0);
for (int iy = 0; iy < nrc_y; ++iy) {
//sumi[iy] = _mm256_sub_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[0], q8.load_quants(iy, i, 0)),
// _mm256_maddubs_epi16(deq.bits.values[1], q8.load_quants(iy, i, 1))), sumi[iy]);
sumi[iy] = _mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[0], q8.load_quants(iy, i, 0)),
_mm256_maddubs_epi16(deq.bits.values[1], q8.load_quants(iy, i, 1)));
sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[2], q8.load_quants(iy, i, 2)),
@@ -1535,8 +1530,14 @@ IQK_NOINLINE void mul_mat_iq2tn_q8_K(int n, const void * vx, size_t bx, const Da
_mm256_maddubs_epi16(deq.bits.values[3], q8.load_quants(iy, i, 7))), sumi[iy]);
sumi[iy] = _mm256_sub_epi16(sumi[iy], q8.load_bsums(iy, i));
}
for (int iy = 0; iy < nrc_y; ++iy) {
accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy])), accd[iy]);
if (i > 0) {
for (int iy = 0; iy < nrc_y; ++iy) {
accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy])), accd[iy]);
}
} else {
for (int iy = 0; iy < nrc_y; ++iy) {
accd[iy] = _mm256_mul_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy])));
}
}
}
@@ -2040,7 +2041,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
_mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)),
_mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot3), _mm256_maddubs_epi16(deq.m1_8, dot4))));
accd[iy] = _mm256_add_epi32(dot, accd[iy]);
accd[iy] = i > 0 ? _mm256_add_epi32(dot, accd[iy]) : dot;
#endif
}
}
@@ -3275,7 +3276,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[2] = mul_mat_iq2tn_q8_K<3>;
mm.funcs[3] = mul_mat_iq2tn_q8_K<4>;
mm.funcs[4] = mul_mat_iq2tn_q8_K<5>;
mm.funcs[5] = mul_mat_iq2tn_q8_K<6>;
//mm.funcs[5] = mul_mat_iq2tn_q8_K<6>;
//mm.funcs[6] = mul_mat_iq2tn_q8_K<7>;
//mm.funcs[7] = mul_mat_iq2tn_q8_K<8>;
#endif