iq1_s_r4: slightly faster AVX2/Zen4 gemm/gemv

This commit is contained in:
Iwan Kawrakow
2025-02-05 10:23:34 +02:00
parent 0467c16a7f
commit 56a6ee26bb

View File

@@ -3273,17 +3273,22 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI
auto m1 = _mm256_set1_epi16(1);
auto ms = _mm_set1_epi16(-32768);
float d8[8*nrc_y];
union { __m256i vec; uint16_t val[16]; } helper;
struct aux_iq1_s_r4 {
uint8_t qs[16];
uint64_t qh;
};
for (int ix= 0; ix < nrc_x; ix += 4) {
auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
auto d1 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)dptr));
auto x = (const block_iq1_s_r4 *)(dptr + 4);
auto x = (const aux_iq1_s_r4 *)(dptr + 4);
for (int ib = 0; ib < nb/4; ++ib) {
for (int iy = 0; iy < nrc_y; ++iy) {
_mm256_storeu_ps(d8 + 8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib].d)));
}
for (int k = 0; k < 4; ++k) {
const uint64_t * s64 = (const uint64_t *)x[4*ib+k].qh;
auto sas = _mm_set1_epi64x(s64[0]);
auto idxh = _mm256_set1_epi64x(x[4*ib+k].qh);
auto sas = _mm256_castsi256_si128(idxh);
auto scales4 = _mm_and_si128(_mm_srli_epi16(sas, 12), _mm_set1_epi16(7));
scales4 = _mm_or_si128(_mm_slli_epi16(scales4, 1), _mm_set1_epi16(1));
auto signs = _mm_or_si128(_mm_cmpeq_epi16(_mm_and_si128(sas, ms), ms), _mm256_castsi256_si128(m1));
@@ -3293,22 +3298,18 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI
auto delta = _mm256_set_m128(delta4, delta4);
scales4 = _mm_unpacklo_epi16(scales4, scales4); // 0,0, 1,1, 2,2, 3,3
auto scales = MM256_SET_M128I(scales4, scales4);
qx[0] = _mm256_set_epi64x(iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)]);
qx[1] = _mm256_set_epi64x(iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)]);
qx[2] = _mm256_set_epi64x(iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)]);
qx[3] = _mm256_set_epi64x(iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)]);
auto idxl = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)x[4*ib+k].qs));
idxh = _mm256_sllv_epi64(idxh, _mm256_set_epi64x(0, 2, 5, 8));
idxh = _mm256_srlv_epi64(idxh, _mm256_set_epi64x(1, 0, 0, 0));
helper.vec = _mm256_or_si256(idxl, _mm256_and_si256(_mm256_set1_epi16(0x0700), idxh));
qx[0] = _mm256_set_epi64x(iq1s_grid_us[helper.val[ 9]], iq1s_grid_us[helper.val[ 8]],
iq1s_grid_us[helper.val[ 1]], iq1s_grid_us[helper.val[ 0]]);
qx[1] = _mm256_set_epi64x(iq1s_grid_us[helper.val[13]], iq1s_grid_us[helper.val[12]],
iq1s_grid_us[helper.val[ 5]], iq1s_grid_us[helper.val[ 4]]);
qx[2] = _mm256_set_epi64x(iq1s_grid_us[helper.val[11]], iq1s_grid_us[helper.val[10]],
iq1s_grid_us[helper.val[ 3]], iq1s_grid_us[helper.val[ 2]]);
qx[3] = _mm256_set_epi64x(iq1s_grid_us[helper.val[15]], iq1s_grid_us[helper.val[14]],
iq1s_grid_us[helper.val[ 7]], iq1s_grid_us[helper.val[ 6]]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ib].qs + k);
#ifdef HAVE_FANCY_SIMD