Faster Q4_K_R4 and Q5_K_R4 on AVX2/Zen4 (#182)

* Slightly faster AVX2 implementation for q4_k_r4

* Even better AVX2 implementation for q4_k_r4

We now arrive at PP-512 = 328 t/s for LLaMA-3.1-8B on a
Ryzen-5975WX CPU, up from 291 t/s when I last measured
on 3c5f8722.
With FA and Q8_0 K-cache we get to 339.5 t/s.

* Fix llama-bench labels that I broke with #181

* Faster AVX2 implementation for q5_k_q4

We arrive at 302 t/s for LLaMA-3.1-8B on a Ryzen-5975WX CPU,
up from 273 t/s.

* Use AVX2 implementation of q4_k_r4 and q5_k_r4 also on Zen4

After the changes I made to AVX2, it ends up being slightly faster
compared to what I had for Zen4.

* Minor tweak

* Cleanup

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-01-30 09:28:53 +02:00
committed by GitHub
parent 5bbe93c0c4
commit f7a4a0fd42
2 changed files with 68 additions and 255 deletions

View File

@@ -756,7 +756,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
continue;
}
cmd_params_instance instance = {
/* .test_kind = */ TEST_KIND_PP,
/* .test_kind = */ TEST_KIND_TG,
/* .model = */ m,
/* .n_prompt = */ 0,
/* .n_gen = */ n_gen,
@@ -784,7 +784,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
continue;
}
cmd_params_instance instance = {
/* .test_kind = */ TEST_KIND_PP,
/* .test_kind = */ TEST_KIND_PG,
/* .model = */ m,
/* .n_prompt = */ n_pg.first,
/* .n_gen = */ n_pg.second,

View File

@@ -4430,17 +4430,47 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
}
template <int nrc_y>
static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
inline void process_min_r4_b32(int ibl, __m256 m4, __m256i mins, const Q8<nrc_y, block_q8_K>& q8, __m256 * acc) {
auto mins_l = _mm256_castsi256_si128(mins);
auto mins_h = _mm256_extracti128_si256(mins, 1);
auto aux1 = _mm_unpacklo_epi32(mins_l, mins_h);
auto aux2 = _mm_unpackhi_epi32(mins_l, mins_h);
auto ic1 = _mm256_cvtepi8_epi32(aux1);
auto ic2 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux1, 0xee));
auto ic3 = _mm256_cvtepi8_epi32(aux2);
auto ic4 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux2, 0xee));
if constexpr (nrc_y == 1) {
auto bs = _mm256_loadu_ps((const float *)q8.y[0][ibl].bsums);
auto sumf = _mm256_mul_ps(_mm256_cvtepi32_ps(ic1), _mm256_shuffle_ps(bs, bs, 0x00));
sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic2), _mm256_shuffle_ps(bs, bs, 0x55), sumf);
sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic3), _mm256_shuffle_ps(bs, bs, 0xaa), sumf);
sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic4), _mm256_shuffle_ps(bs, bs, 0xff), sumf);
acc[0] = _mm256_fmadd_ps(m4, sumf, acc[0]);
} else {
auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic1));
auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic2));
auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic3));
auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic4));
for (int iy = 0; iy < nrc_y; ++iy) {
auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
}
}
}
template <int nrc_y>
static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto mf = _mm256_set1_epi8(0xf);
auto m3 = _mm256_set1_epi8(0x30);
#ifndef HAVE_FANCY_SIMD
auto m1 = _mm256_set1_epi16(1);
#endif
int nbl = n / QK_K;
union { __m256i vec; uint32_t val[8]; } hd;
__m256 acc[nrc_y] = {};
__m256i isum[nrc_y] = {};
__m256i qx[4];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q4_k_r4 * iq4 = (const block_q4_k_r4 *)((const char *)vx + (ix+0)*bx);
@@ -4448,31 +4478,20 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d));
auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
if constexpr (nrc_y == 1) {
d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
}
auto lbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l);
auto hbits128 = _mm_loadu_si128((const __m128i *)iq4[ibl].scales_h);
auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m3));
auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m3));
auto shuffle = _mm256_set1_epi64x(0x0000000400000000);
auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
for (int iy = 0; iy < nrc_y; ++iy) {
auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
}
process_min_r4_b32(ibl, m4, mins, q8, acc);
for (int ib = 0; ib < QK_K/32; ++ib) {
auto scales_d = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]))));
#ifdef HAVE_FANCY_SIMD
auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]));
#else
auto aux = _mm_set1_epi32(hd.val[ib]);
aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux));
auto scales_d = MM256_SET_M128I(aux, aux);
#endif
auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
qx[0] = _mm256_and_si256(bits1, mf);
@@ -4487,21 +4506,20 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales_d, _mm256_add_epi16(sumi1, sumi2)));
#endif
if constexpr (nrc_y == 1) {
acc[iy] = _mm256_fmadd_ps(scales_d, _mm256_cvtepi32_ps(sumi), acc[iy]);
} else {
float d8 = q8.scale(iy, ibl);
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales_d, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
isum[iy] = _mm256_setzero_si256();
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
@@ -4511,113 +4529,17 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
}
}
#ifdef HAVE_FANCY_SIMD
template <int nrc_y>
static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
//mul_mat_q4_k_r4_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
if constexpr (nrc_y == 1){
mul_mat_q4_k_r4_q8_k_avx2<1>(n, vx, bx, info, nrc_x);
} else {
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto mf = _mm512_set1_epi8(0xf);
int nbl = n / QK_K;
using helper_t = union { __m512i vec; uint32_t val[16]; };
helper_t hd, hm;
__m512 acc[nrc_y] = {};
__m512i isum[nrc_y] = {};
__m512i qx[4];
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_q4_k_r4 * iq4l = (const block_q4_k_r4 *)((const char *)vx + (ix+0)*bx);
const block_q4_k_r4 * iq4h = (const block_q4_k_r4 *)((const char *)vx + (ix+4)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto d1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l[ibl].d));
auto d2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h[ibl].d));
auto dl = _mm256_castps256_ps128(d1);
auto ml = _mm256_extractf128_ps(d1, 1);
auto dh = _mm256_castps256_ps128(d2);
auto mh = _mm256_extractf128_ps(d2, 1);
auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1);
auto m4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(ml, ml)), _mm256_set_m128(mh, mh), 1);
m4 = _mm512_mul_ps(m4, _mm512_set1_ps(-0.5f));
auto slbits_l = _mm256_loadu_si256((const __m256i *)iq4l[ibl].scales_l);
auto shbits_l = _mm256_loadu_si256((const __m256i *)iq4h[ibl].scales_l);
auto slb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_l), shbits_l, 1);
auto sld = _mm512_and_si512(slb, mf);
auto slm = _mm512_and_si512(_mm512_srli_epi16(slb, 4), mf);
auto slbits_h = _mm_loadu_si128((const __m128i *)iq4l[ibl].scales_h);
auto shbits_h = _mm_loadu_si128((const __m128i *)iq4h[ibl].scales_h);
auto slbits_h2 = MM256_SET_M128I(_mm_srli_epi16(slbits_h, 4), slbits_h);
auto shbits_h2 = MM256_SET_M128I(_mm_srli_epi16(shbits_h, 4), shbits_h);
auto shb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_h2), shbits_h2, 1);
auto shd = _mm512_and_si512(_mm512_slli_epi16(shb, 4), _mm512_set1_epi8(0x30));
auto shm = _mm512_and_si512(_mm512_slli_epi16(shb, 2), _mm512_set1_epi8(0x30));
hd.vec = _mm512_or_si512(sld, shd);
hm.vec = _mm512_or_si512(slm, shm);
for (int ib = 0; ib < QK_K/32; ++ib) {
auto scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+0]));
auto scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+8]));
auto iscales = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1);
scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+0]));
scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+8]));
auto iscales_m = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1);
auto scales_m = _mm512_mul_ps(m4, _mm512_cvtepi32_ps(iscales_m));
auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+0)),
_mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+0), 1);
auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+1)),
_mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+1), 1);
qx[0] = _mm512_and_si512(bits1, mf);
qx[1] = _mm512_and_si512(bits2, mf);
qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), mf);
qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), mf);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
auto sumi = _mm512_setzero_si512();
sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
isum[iy] = _mm512_add_epi32(isum[iy], _mm512_mullo_epi32(iscales, sumi));
float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
acc[iy] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl))), _mm512_cvtepi32_ps(isum[iy]), acc[iy]);
isum[iy] = _mm512_setzero_si512();
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 0), _mm512_extractf32x4_ps(acc[iy], 1));
auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 2), _mm512_extractf32x4_ps(acc[iy], 3));
info.store(ix+0, iy, sum1);
info.store(ix+4, iy, sum2);
acc[iy] = _mm512_setzero_ps();
}
}
}
}
#else
template <int nrc_y>
static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
mul_mat_q4_k_r4_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
}
#endif
template <int nrc_y>
static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto mf = _mm256_set1_epi8(0xf);
auto m10 = _mm256_set1_epi8(0x10);
auto m30 = _mm256_set1_epi8(0x30);
#ifndef HAVE_FANCY_SIMD
auto m1 = _mm256_set1_epi16(1);
#endif
int nbl = n / QK_K;
union { __m256i vec; uint32_t val[8]; } hd;
__m256 acc[nrc_y] = {};
__m256i isum[nrc_y] = {};
__m256i qx[4];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q5_k_r4 * iq5 = (const block_q5_k_r4 *)((const char *)vx + (ix+0)*bx);
@@ -4625,31 +4547,20 @@ static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5[ibl].d));
auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
if constexpr (nrc_y == 1) {
d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
}
auto lbits = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales_l);
auto hbits128 = _mm_loadu_si128((const __m128i *)iq5[ibl].scales_h);
auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m30));
auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m30));
auto shuffle = _mm256_set1_epi64x(0x0000000400000000);
auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
for (int iy = 0; iy < nrc_y; ++iy) {
auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
}
process_min_r4_b32(ibl, m4, mins, q8, acc);
for (int ib = 0; ib < QK_K/32; ++ib) {
auto scales_d = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]))));
#ifdef HAVE_FANCY_SIMD
auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]));
#else
auto aux = _mm_set1_epi32(hd.val[ib]);
aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux));
auto scales_d = MM256_SET_M128I(aux, aux);
#endif
auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0);
auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
auto hbits128 = _mm_loadu_si128((const __m128i*)iq5[ibl].qh + ib);
@@ -4666,21 +4577,22 @@ static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
// To avoid overflow, we can only add up to 4 q5 x q8 products.
auto sumi = _mm256_add_epi32(_mm256_madd_epi16(scales_d, sumi1), _mm256_madd_epi16(scales_d, sumi2));
isum[iy] = _mm256_add_epi32(isum[iy], sumi);
#endif
if constexpr (nrc_y == 1) {
acc[iy] = _mm256_fmadd_ps(scales_d, _mm256_cvtepi32_ps(sumi), acc[iy]);
} else {
float d8 = q8.scale(iy, ibl);
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales_d, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
isum[iy] = _mm256_setzero_si256();
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
@@ -4690,105 +4602,6 @@ static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
}
}
#ifdef HAVE_FANCY_SIMD
template <int nrc_y>
static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
if constexpr (nrc_y == 1){
mul_mat_q5_k_r4_q8_k_avx2<1>(n, vx, bx, info, nrc_x);
} else {
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto mf = _mm512_set1_epi8(0xf);
auto m10 = _mm512_set1_epi8(0x10);
int nbl = n / QK_K;
using helper_t = union { __m512i vec; uint32_t val[16]; };
helper_t hd, hm;
__m512 acc[nrc_y] = {};
__m512i isum[nrc_y] = {};
__m512i qx[4];
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_q5_k_r4 * iq5l = (const block_q5_k_r4 *)((const char *)vx + (ix+0)*bx);
const block_q5_k_r4 * iq5h = (const block_q5_k_r4 *)((const char *)vx + (ix+4)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto d1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5l[ibl].d));
auto d2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5h[ibl].d));
auto dl = _mm256_castps256_ps128(d1);
auto ml = _mm256_extractf128_ps(d1, 1);
auto dh = _mm256_castps256_ps128(d2);
auto mh = _mm256_extractf128_ps(d2, 1);
auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1);
auto m4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(ml, ml)), _mm256_set_m128(mh, mh), 1);
m4 = _mm512_mul_ps(m4, _mm512_set1_ps(-0.5f));
auto slbits_l = _mm256_loadu_si256((const __m256i *)iq5l[ibl].scales_l);
auto shbits_l = _mm256_loadu_si256((const __m256i *)iq5h[ibl].scales_l);
auto slb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_l), shbits_l, 1);
auto sld = _mm512_and_si512(slb, mf);
auto slm = _mm512_and_si512(_mm512_srli_epi16(slb, 4), mf);
auto slbits_h = _mm_loadu_si128((const __m128i *)iq5l[ibl].scales_h);
auto shbits_h = _mm_loadu_si128((const __m128i *)iq5h[ibl].scales_h);
auto slbits_h2 = MM256_SET_M128I(_mm_srli_epi16(slbits_h, 4), slbits_h);
auto shbits_h2 = MM256_SET_M128I(_mm_srli_epi16(shbits_h, 4), shbits_h);
auto shb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_h2), shbits_h2, 1);
auto shd = _mm512_and_si512(_mm512_slli_epi16(shb, 4), _mm512_set1_epi8(0x30));
auto shm = _mm512_and_si512(_mm512_slli_epi16(shb, 2), _mm512_set1_epi8(0x30));
hd.vec = _mm512_or_si512(sld, shd);
hm.vec = _mm512_or_si512(slm, shm);
for (int ib = 0; ib < QK_K/32; ++ib) {
auto scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+0]));
auto scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+8]));
auto iscales = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1);
scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+0]));
scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+8]));
auto iscales_m = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1);
auto scales_m = _mm512_mul_ps(m4, _mm512_cvtepi32_ps(iscales_m));
auto lbits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[ibl].qs+2*ib+0)),
_mm256_loadu_si256((const __m256i *)iq5h[ibl].qs+2*ib+0), 1);
auto lbits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[ibl].qs+2*ib+1)),
_mm256_loadu_si256((const __m256i *)iq5h[ibl].qs+2*ib+1), 1);
auto hbits1 = _mm_loadu_si128((const __m128i*)iq5l[ibl].qh+ib);
auto hbits2 = _mm_loadu_si128((const __m128i*)iq5h[ibl].qh+ib);
auto hbl = MM256_SET_M128I(hbits1, _mm_slli_epi16(hbits1, 4));
auto hbh = MM256_SET_M128I(hbits2, _mm_slli_epi16(hbits2, 4));
auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbl), hbh, 1);
qx[0] = _mm512_or_si512(_mm512_and_si512(lbits1, mf), _mm512_and_si512(m10, hbits));
qx[1] = _mm512_or_si512(_mm512_and_si512(lbits2, mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 2)));
qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(lbits1, 4), mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 1)));
qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(lbits2, 4), mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 3)));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
auto sumi = _mm512_setzero_si512();
sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
isum[iy] = _mm512_add_epi32(isum[iy], _mm512_mullo_epi32(iscales, sumi));
float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
acc[iy] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl))), _mm512_cvtepi32_ps(isum[iy]), acc[iy]);
isum[iy] = _mm512_setzero_si512();
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 0), _mm512_extractf32x4_ps(acc[iy], 1));
auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 2), _mm512_extractf32x4_ps(acc[iy], 3));
info.store(ix+0, iy, sum1);
info.store(ix+4, iy, sum2);
acc[iy] = _mm512_setzero_ps();
}
}
}
}
#else
template <int nrc_y>
static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
mul_mat_q5_k_r4_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
}
#endif
template <int nrc_y>
static void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);