iq2_xs_r4: AVX2

This commit is contained in:
Iwan Kawrakow
2024-12-20 18:06:02 +02:00
parent 51e5a87382
commit d8a274a18f

View File

@@ -3309,15 +3309,24 @@ static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data
auto smask = _mm256_set1_epi64x(0x8040201008040201);
auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
auto m4 = _mm256_set1_epi8(4);
auto m1 = _mm256_set1_epi16(1);
#endif
__m256 acc[nrc_y] = {};
#ifdef HAVE_FANCY_SIMD
__m256i shuffles[2] = {
_mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100),
_mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
};
auto s_shuffle = _mm_set_epi64x(0x0f0d0b0907050301, 0x0e0c0a0806040200);
__m256 acc[nrc_y] = {};
__m256i isum[2*nrc_y] = {};
#else
__m256i shuffles[4] = {
MM256_SET_M128I(_mm_set1_epi16(0x0302), _mm_set1_epi16(0x0100)),
MM256_SET_M128I(_mm_set1_epi16(0x0706), _mm_set1_epi16(0x0504)),
MM256_SET_M128I(_mm_set1_epi16(0x0b0a), _mm_set1_epi16(0x0908)),
MM256_SET_M128I(_mm_set1_epi16(0x0f0e), _mm_set1_epi16(0x0d0c)),
};
__m256i isum[nrc_y] = {};
#endif
auto s_shuffle = _mm_set_epi64x(0x0f0d0b0907050301, 0x0e0c0a0806040200);
__m256i qx[4];
union { __m256i vec; uint16_t val[16]; } helper;
for (int ix = 0; ix < nrc_x; ix += 4) {
@@ -3335,15 +3344,15 @@ static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data
qx[3] = _mm256_set_epi64x(iq2xs_grid[helper.val[15]], iq2xs_grid[helper.val[14]], iq2xs_grid[helper.val[13]], iq2xs_grid[helper.val[12]]);
auto signs16 = _mm256_srli_epi16(val, 9);
signs16 = _mm256_xor_si256(signs16, _mm256_slli_epi16(signs16, 1));
auto signs = _mm_or_si128(_mm256_castsi256_si128(signs16), _mm_slli_epi16(_mm256_extracti128_si256(signs16, 1), 8));
signs = _mm_shuffle_epi8(signs, s_shuffle);
auto signs128 = _mm_or_si128(_mm256_castsi256_si128(signs16), _mm_slli_epi16(_mm256_extracti128_si256(signs16, 1), 8));
signs128 = _mm_shuffle_epi8(signs128, s_shuffle);
auto scales = _mm_set1_epi32(s32[ib]);
scales = _mm_and_si128(_mm_unpacklo_epi8(scales, _mm_srli_epi16(scales, 4)), _mm_set1_epi8(0xf));
scales = _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi8(1));
auto scales16 = _mm256_cvtepi8_epi16(scales); // 0...7, 0...7
__m256i scs[2] = { _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]) };
#ifdef HAVE_FANCY_SIMD
auto mask = (const __mmask32 *)&signs;
__m256i scs[2] = { _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]) };
auto mask = (const __mmask32 *)&signs128;
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_mask_sub_epi8(y, mask[0], _mm256_setzero_si256(), y)); // blocks: 0,0,0,0, 1,1,1,1, row 0
@@ -3369,31 +3378,37 @@ static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data
auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
shuffle = _mm256_add_epi8(shuffle, m4);
auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
__m256i scs[4] = {
_mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]),
_mm256_shuffle_epi8(scales16, shuffles[2]), _mm256_shuffle_epi8(scales16, shuffles[3]),
};
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1)));
auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2)));
auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3)));
auto sumi4 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4)));
auto sumi1 = _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1))); // blocks 4x0, 4x1, row 0
auto sumi2 = _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2))); // blocks 4x2, 4x3, row 1
auto sumi3 = _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3))); // blocks 4x4, 4x5, row 2
auto sumi4 = _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4))); // blocks 4x6, 4x7, row 3
auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales32, sumi));
isum[iy] = _mm256_add_epi32(isum[iy], sumi);
}
#endif
}
for (int iy = 0; iy < nrc_y; ++iy) {
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_hadd_epi32(isum[2*iy+0], isum[2*iy+1]);
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
isum[2*iy+0] = isum[2*iy+1] = _mm256_setzero_si256();
//acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
//isum[iy] = _mm256_setzero_si256();
#else
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
isum[iy] = _mm256_setzero_si256();
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), sum));
//info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), _mm256_castps256_ps128(acc[iy])));
acc[iy] = _mm256_setzero_ps();
}
}