mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
iq2_bn_r4: better AVX2
As we don't have enough vector registers on AVX2, it is better to do two passes per row; that way we need only half of the accumulator registers. With this, we now beat iq2_bn PP also on AVX2 by a small margin.
This commit is contained in:
@@ -2126,8 +2126,63 @@ static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, con
|
||||
auto m3 = _mm256_set1_epi8(0x3);
|
||||
auto m1 = _mm256_set1_epi16(1);
|
||||
int nb = n / QK_IQ1BN;
|
||||
__m256i acc[2*nrc_y] = {};
|
||||
__m256i qx[4];
|
||||
if constexpr (nrc_y > 4) {
|
||||
__m256i acc[nrc_y] = {};
|
||||
__m128 sum4[nrc_y];
|
||||
for (int ix = 0; ix < nrc_x; ix += 4) {
|
||||
const float * dptr = (const float *)((const char *)vx + ix*bx);
|
||||
auto dl = _mm_loadu_ps(dptr);
|
||||
const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
|
||||
qx[0] = _mm256_and_si256(bits, m3);
|
||||
qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
|
||||
qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
|
||||
qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y = q8.load_quants(iy, 2*ib+0);
|
||||
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
|
||||
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
|
||||
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
|
||||
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
|
||||
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto dy = q8.scale(iy);
|
||||
auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
|
||||
auto s4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
|
||||
s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), s4);
|
||||
sum4[iy] = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), s4);
|
||||
acc[iy] = _mm256_setzero_si256();
|
||||
}
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
|
||||
qx[0] = _mm256_and_si256(bits, m3);
|
||||
qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
|
||||
qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
|
||||
qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y = q8.load_quants(iy, 2*ib+1);
|
||||
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
|
||||
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
|
||||
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
|
||||
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
|
||||
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto dy = q8.scale(iy);
|
||||
auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
|
||||
auto s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4[iy]);
|
||||
s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), s4);
|
||||
info.store(ix, iy, s4);
|
||||
acc[iy] = _mm256_setzero_si256();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
__m256i acc[2*nrc_y] = {};
|
||||
for (int ix = 0; ix < nrc_x; ix += 4) {
|
||||
const float * dptr = (const float *)((const char *)vx + ix*bx);
|
||||
auto dl = _mm_loadu_ps(dptr);
|
||||
@@ -2164,8 +2219,6 @@ static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, con
|
||||
auto dy = q8.scale(iy);
|
||||
auto sumf1 = _mm256_cvtepi32_ps(acc[2*iy+0]);
|
||||
auto sumf2 = _mm256_cvtepi32_ps(acc[2*iy+1]);
|
||||
//auto sumf1 = _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, acc[2*iy+0]));
|
||||
//auto sumf2 = _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, acc[2*iy+1]));
|
||||
auto sum4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
|
||||
sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
|
||||
sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
|
||||
@@ -2175,6 +2228,7 @@ static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, con
|
||||
acc[2*iy+0] = acc[2*iy+1] = _mm256_setzero_si256();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
|
||||
Reference in New Issue
Block a user