mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
Faster Q4_K_R4 and Q5_K_R4 on AVX2/Zen4 (#182)
* Slightly faster AVX2 implementation for q4_k_r4
* Even better AVX2 implementation for q4_k_r4
We now arrive at PP-512 = 328 t/s for LLaMA-3.1-8B on a
Ryzen-5975WX CPU, up from 291 t/s when I last measured
on 3c5f8722.
With FA and Q8_0 K-cache we get to 339.5 t/s.
* Fix llama-bench labels that I broke with #181
* Faster AVX2 implementation for q5_k_r4
We arrive at 302 t/s for LLaMA-3.1-8B on a Ryzen-5975WX CPU,
up from 273 t/s.
* Use AVX2 implementation of q4_k_r4 and q5_k_r4 also on Zen4
After the changes I made to the AVX2 code, it ends up being slightly faster
than what I had for Zen4.
* Minor tweak
* Cleanup
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -756,7 +756,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
continue;
|
||||
}
|
||||
cmd_params_instance instance = {
|
||||
/* .test_kind = */ TEST_KIND_PP,
|
||||
/* .test_kind = */ TEST_KIND_TG,
|
||||
/* .model = */ m,
|
||||
/* .n_prompt = */ 0,
|
||||
/* .n_gen = */ n_gen,
|
||||
@@ -784,7 +784,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
continue;
|
||||
}
|
||||
cmd_params_instance instance = {
|
||||
/* .test_kind = */ TEST_KIND_PP,
|
||||
/* .test_kind = */ TEST_KIND_PG,
|
||||
/* .model = */ m,
|
||||
/* .n_prompt = */ n_pg.first,
|
||||
/* .n_gen = */ n_pg.second,
|
||||
|
||||
@@ -4430,17 +4430,47 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
inline void process_min_r4_b32(int ibl, __m256 m4, __m256i mins, const Q8<nrc_y, block_q8_K>& q8, __m256 * acc) {
|
||||
auto mins_l = _mm256_castsi256_si128(mins);
|
||||
auto mins_h = _mm256_extracti128_si256(mins, 1);
|
||||
auto aux1 = _mm_unpacklo_epi32(mins_l, mins_h);
|
||||
auto aux2 = _mm_unpackhi_epi32(mins_l, mins_h);
|
||||
auto ic1 = _mm256_cvtepi8_epi32(aux1);
|
||||
auto ic2 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux1, 0xee));
|
||||
auto ic3 = _mm256_cvtepi8_epi32(aux2);
|
||||
auto ic4 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux2, 0xee));
|
||||
if constexpr (nrc_y == 1) {
|
||||
auto bs = _mm256_loadu_ps((const float *)q8.y[0][ibl].bsums);
|
||||
auto sumf = _mm256_mul_ps(_mm256_cvtepi32_ps(ic1), _mm256_shuffle_ps(bs, bs, 0x00));
|
||||
sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic2), _mm256_shuffle_ps(bs, bs, 0x55), sumf);
|
||||
sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic3), _mm256_shuffle_ps(bs, bs, 0xaa), sumf);
|
||||
sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic4), _mm256_shuffle_ps(bs, bs, 0xff), sumf);
|
||||
acc[0] = _mm256_fmadd_ps(m4, sumf, acc[0]);
|
||||
} else {
|
||||
auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic1));
|
||||
auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic2));
|
||||
auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic3));
|
||||
auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic4));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
|
||||
acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
|
||||
acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
|
||||
acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
|
||||
acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%4 == 0);
|
||||
Q8<nrc_y, block_q8_K> q8(info);
|
||||
auto mf = _mm256_set1_epi8(0xf);
|
||||
auto m3 = _mm256_set1_epi8(0x30);
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
auto m1 = _mm256_set1_epi16(1);
|
||||
#endif
|
||||
int nbl = n / QK_K;
|
||||
union { __m256i vec; uint32_t val[8]; } hd;
|
||||
__m256 acc[nrc_y] = {};
|
||||
__m256i isum[nrc_y] = {};
|
||||
__m256i qx[4];
|
||||
for (int ix = 0; ix < nrc_x; ix += 4) {
|
||||
const block_q4_k_r4 * iq4 = (const block_q4_k_r4 *)((const char *)vx + (ix+0)*bx);
|
||||
@@ -4448,31 +4478,20 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
|
||||
auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d));
|
||||
auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
|
||||
auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
|
||||
if constexpr (nrc_y == 1) {
|
||||
d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
|
||||
}
|
||||
auto lbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l);
|
||||
auto hbits128 = _mm_loadu_si128((const __m128i *)iq4[ibl].scales_h);
|
||||
auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
|
||||
hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m3));
|
||||
auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m3));
|
||||
auto shuffle = _mm256_set1_epi64x(0x0000000400000000);
|
||||
auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
|
||||
shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
|
||||
auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
|
||||
shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
|
||||
auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
|
||||
shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
|
||||
auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
|
||||
acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
|
||||
acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
|
||||
acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
|
||||
acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
|
||||
}
|
||||
process_min_r4_b32(ibl, m4, mins, q8, acc);
|
||||
for (int ib = 0; ib < QK_K/32; ++ib) {
|
||||
auto scales_d = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]))));
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]));
|
||||
#else
|
||||
auto aux = _mm_set1_epi32(hd.val[ib]);
|
||||
aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux));
|
||||
auto scales_d = MM256_SET_M128I(aux, aux);
|
||||
#endif
|
||||
auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
|
||||
auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
|
||||
qx[0] = _mm256_and_si256(bits1, mf);
|
||||
@@ -4487,21 +4506,20 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
|
||||
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
|
||||
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
|
||||
#else
|
||||
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
|
||||
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
|
||||
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
|
||||
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
|
||||
auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
|
||||
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales_d, _mm256_add_epi16(sumi1, sumi2)));
|
||||
#endif
|
||||
if constexpr (nrc_y == 1) {
|
||||
acc[iy] = _mm256_fmadd_ps(scales_d, _mm256_cvtepi32_ps(sumi), acc[iy]);
|
||||
} else {
|
||||
float d8 = q8.scale(iy, ibl);
|
||||
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales_d, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
|
||||
isum[iy] = _mm256_setzero_si256();
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
|
||||
@@ -4511,113 +4529,17 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
template <int nrc_y>
|
||||
static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
//mul_mat_q4_k_r4_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
|
||||
if constexpr (nrc_y == 1){
|
||||
mul_mat_q4_k_r4_q8_k_avx2<1>(n, vx, bx, info, nrc_x);
|
||||
} else {
|
||||
GGML_ASSERT(nrc_x%8 == 0);
|
||||
Q8<nrc_y, block_q8_K> q8(info);
|
||||
auto mf = _mm512_set1_epi8(0xf);
|
||||
int nbl = n / QK_K;
|
||||
using helper_t = union { __m512i vec; uint32_t val[16]; };
|
||||
helper_t hd, hm;
|
||||
__m512 acc[nrc_y] = {};
|
||||
__m512i isum[nrc_y] = {};
|
||||
__m512i qx[4];
|
||||
for (int ix = 0; ix < nrc_x; ix += 8) {
|
||||
const block_q4_k_r4 * iq4l = (const block_q4_k_r4 *)((const char *)vx + (ix+0)*bx);
|
||||
const block_q4_k_r4 * iq4h = (const block_q4_k_r4 *)((const char *)vx + (ix+4)*bx);
|
||||
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
|
||||
auto d1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l[ibl].d));
|
||||
auto d2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h[ibl].d));
|
||||
auto dl = _mm256_castps256_ps128(d1);
|
||||
auto ml = _mm256_extractf128_ps(d1, 1);
|
||||
auto dh = _mm256_castps256_ps128(d2);
|
||||
auto mh = _mm256_extractf128_ps(d2, 1);
|
||||
auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1);
|
||||
auto m4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(ml, ml)), _mm256_set_m128(mh, mh), 1);
|
||||
m4 = _mm512_mul_ps(m4, _mm512_set1_ps(-0.5f));
|
||||
auto slbits_l = _mm256_loadu_si256((const __m256i *)iq4l[ibl].scales_l);
|
||||
auto shbits_l = _mm256_loadu_si256((const __m256i *)iq4h[ibl].scales_l);
|
||||
auto slb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_l), shbits_l, 1);
|
||||
auto sld = _mm512_and_si512(slb, mf);
|
||||
auto slm = _mm512_and_si512(_mm512_srli_epi16(slb, 4), mf);
|
||||
auto slbits_h = _mm_loadu_si128((const __m128i *)iq4l[ibl].scales_h);
|
||||
auto shbits_h = _mm_loadu_si128((const __m128i *)iq4h[ibl].scales_h);
|
||||
auto slbits_h2 = MM256_SET_M128I(_mm_srli_epi16(slbits_h, 4), slbits_h);
|
||||
auto shbits_h2 = MM256_SET_M128I(_mm_srli_epi16(shbits_h, 4), shbits_h);
|
||||
auto shb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_h2), shbits_h2, 1);
|
||||
auto shd = _mm512_and_si512(_mm512_slli_epi16(shb, 4), _mm512_set1_epi8(0x30));
|
||||
auto shm = _mm512_and_si512(_mm512_slli_epi16(shb, 2), _mm512_set1_epi8(0x30));
|
||||
hd.vec = _mm512_or_si512(sld, shd);
|
||||
hm.vec = _mm512_or_si512(slm, shm);
|
||||
for (int ib = 0; ib < QK_K/32; ++ib) {
|
||||
auto scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+0]));
|
||||
auto scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+8]));
|
||||
auto iscales = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1);
|
||||
scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+0]));
|
||||
scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+8]));
|
||||
auto iscales_m = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1);
|
||||
auto scales_m = _mm512_mul_ps(m4, _mm512_cvtepi32_ps(iscales_m));
|
||||
auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+0)),
|
||||
_mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+0), 1);
|
||||
auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+1)),
|
||||
_mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+1), 1);
|
||||
qx[0] = _mm512_and_si512(bits1, mf);
|
||||
qx[1] = _mm512_and_si512(bits2, mf);
|
||||
qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), mf);
|
||||
qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), mf);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
|
||||
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
|
||||
auto sumi = _mm512_setzero_si512();
|
||||
sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
|
||||
sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
|
||||
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
|
||||
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
|
||||
isum[iy] = _mm512_add_epi32(isum[iy], _mm512_mullo_epi32(iscales, sumi));
|
||||
float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
|
||||
acc[iy] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[iy]);
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
acc[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl))), _mm512_cvtepi32_ps(isum[iy]), acc[iy]);
|
||||
isum[iy] = _mm512_setzero_si512();
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 0), _mm512_extractf32x4_ps(acc[iy], 1));
|
||||
auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 2), _mm512_extractf32x4_ps(acc[iy], 3));
|
||||
info.store(ix+0, iy, sum1);
|
||||
info.store(ix+4, iy, sum2);
|
||||
acc[iy] = _mm512_setzero_ps();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
template <int nrc_y>
|
||||
static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
mul_mat_q4_k_r4_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%4 == 0);
|
||||
Q8<nrc_y, block_q8_K> q8(info);
|
||||
auto mf = _mm256_set1_epi8(0xf);
|
||||
auto m10 = _mm256_set1_epi8(0x10);
|
||||
auto m30 = _mm256_set1_epi8(0x30);
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
auto m1 = _mm256_set1_epi16(1);
|
||||
#endif
|
||||
int nbl = n / QK_K;
|
||||
union { __m256i vec; uint32_t val[8]; } hd;
|
||||
__m256 acc[nrc_y] = {};
|
||||
__m256i isum[nrc_y] = {};
|
||||
__m256i qx[4];
|
||||
for (int ix = 0; ix < nrc_x; ix += 4) {
|
||||
const block_q5_k_r4 * iq5 = (const block_q5_k_r4 *)((const char *)vx + (ix+0)*bx);
|
||||
@@ -4625,31 +4547,20 @@ static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
|
||||
auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5[ibl].d));
|
||||
auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
|
||||
auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
|
||||
if constexpr (nrc_y == 1) {
|
||||
d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
|
||||
}
|
||||
auto lbits = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales_l);
|
||||
auto hbits128 = _mm_loadu_si128((const __m128i *)iq5[ibl].scales_h);
|
||||
auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
|
||||
hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m30));
|
||||
auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m30));
|
||||
auto shuffle = _mm256_set1_epi64x(0x0000000400000000);
|
||||
auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
|
||||
shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
|
||||
auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
|
||||
shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
|
||||
auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
|
||||
shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
|
||||
auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
|
||||
acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
|
||||
acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
|
||||
acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
|
||||
acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
|
||||
}
|
||||
process_min_r4_b32(ibl, m4, mins, q8, acc);
|
||||
for (int ib = 0; ib < QK_K/32; ++ib) {
|
||||
auto scales_d = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]))));
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]));
|
||||
#else
|
||||
auto aux = _mm_set1_epi32(hd.val[ib]);
|
||||
aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux));
|
||||
auto scales_d = MM256_SET_M128I(aux, aux);
|
||||
#endif
|
||||
auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0);
|
||||
auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
|
||||
auto hbits128 = _mm_loadu_si128((const __m128i*)iq5[ibl].qh + ib);
|
||||
@@ -4666,21 +4577,22 @@ static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
|
||||
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
|
||||
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
|
||||
#else
|
||||
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
|
||||
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
|
||||
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
|
||||
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
|
||||
auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
|
||||
// To avoid overflow, we can only add up to 4 q5 x q8 products.
|
||||
auto sumi = _mm256_add_epi32(_mm256_madd_epi16(scales_d, sumi1), _mm256_madd_epi16(scales_d, sumi2));
|
||||
isum[iy] = _mm256_add_epi32(isum[iy], sumi);
|
||||
#endif
|
||||
if constexpr (nrc_y == 1) {
|
||||
acc[iy] = _mm256_fmadd_ps(scales_d, _mm256_cvtepi32_ps(sumi), acc[iy]);
|
||||
} else {
|
||||
float d8 = q8.scale(iy, ibl);
|
||||
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales_d, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
|
||||
isum[iy] = _mm256_setzero_si256();
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
|
||||
@@ -4690,105 +4602,6 @@ static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
template <int nrc_y>
|
||||
static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
if constexpr (nrc_y == 1){
|
||||
mul_mat_q5_k_r4_q8_k_avx2<1>(n, vx, bx, info, nrc_x);
|
||||
} else {
|
||||
GGML_ASSERT(nrc_x%8 == 0);
|
||||
Q8<nrc_y, block_q8_K> q8(info);
|
||||
auto mf = _mm512_set1_epi8(0xf);
|
||||
auto m10 = _mm512_set1_epi8(0x10);
|
||||
int nbl = n / QK_K;
|
||||
using helper_t = union { __m512i vec; uint32_t val[16]; };
|
||||
helper_t hd, hm;
|
||||
__m512 acc[nrc_y] = {};
|
||||
__m512i isum[nrc_y] = {};
|
||||
__m512i qx[4];
|
||||
for (int ix = 0; ix < nrc_x; ix += 8) {
|
||||
const block_q5_k_r4 * iq5l = (const block_q5_k_r4 *)((const char *)vx + (ix+0)*bx);
|
||||
const block_q5_k_r4 * iq5h = (const block_q5_k_r4 *)((const char *)vx + (ix+4)*bx);
|
||||
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
|
||||
auto d1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5l[ibl].d));
|
||||
auto d2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5h[ibl].d));
|
||||
auto dl = _mm256_castps256_ps128(d1);
|
||||
auto ml = _mm256_extractf128_ps(d1, 1);
|
||||
auto dh = _mm256_castps256_ps128(d2);
|
||||
auto mh = _mm256_extractf128_ps(d2, 1);
|
||||
auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1);
|
||||
auto m4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(ml, ml)), _mm256_set_m128(mh, mh), 1);
|
||||
m4 = _mm512_mul_ps(m4, _mm512_set1_ps(-0.5f));
|
||||
auto slbits_l = _mm256_loadu_si256((const __m256i *)iq5l[ibl].scales_l);
|
||||
auto shbits_l = _mm256_loadu_si256((const __m256i *)iq5h[ibl].scales_l);
|
||||
auto slb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_l), shbits_l, 1);
|
||||
auto sld = _mm512_and_si512(slb, mf);
|
||||
auto slm = _mm512_and_si512(_mm512_srli_epi16(slb, 4), mf);
|
||||
auto slbits_h = _mm_loadu_si128((const __m128i *)iq5l[ibl].scales_h);
|
||||
auto shbits_h = _mm_loadu_si128((const __m128i *)iq5h[ibl].scales_h);
|
||||
auto slbits_h2 = MM256_SET_M128I(_mm_srli_epi16(slbits_h, 4), slbits_h);
|
||||
auto shbits_h2 = MM256_SET_M128I(_mm_srli_epi16(shbits_h, 4), shbits_h);
|
||||
auto shb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_h2), shbits_h2, 1);
|
||||
auto shd = _mm512_and_si512(_mm512_slli_epi16(shb, 4), _mm512_set1_epi8(0x30));
|
||||
auto shm = _mm512_and_si512(_mm512_slli_epi16(shb, 2), _mm512_set1_epi8(0x30));
|
||||
hd.vec = _mm512_or_si512(sld, shd);
|
||||
hm.vec = _mm512_or_si512(slm, shm);
|
||||
for (int ib = 0; ib < QK_K/32; ++ib) {
|
||||
auto scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+0]));
|
||||
auto scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+8]));
|
||||
auto iscales = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1);
|
||||
scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+0]));
|
||||
scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+8]));
|
||||
auto iscales_m = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1);
|
||||
auto scales_m = _mm512_mul_ps(m4, _mm512_cvtepi32_ps(iscales_m));
|
||||
auto lbits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[ibl].qs+2*ib+0)),
|
||||
_mm256_loadu_si256((const __m256i *)iq5h[ibl].qs+2*ib+0), 1);
|
||||
auto lbits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[ibl].qs+2*ib+1)),
|
||||
_mm256_loadu_si256((const __m256i *)iq5h[ibl].qs+2*ib+1), 1);
|
||||
auto hbits1 = _mm_loadu_si128((const __m128i*)iq5l[ibl].qh+ib);
|
||||
auto hbits2 = _mm_loadu_si128((const __m128i*)iq5h[ibl].qh+ib);
|
||||
auto hbl = MM256_SET_M128I(hbits1, _mm_slli_epi16(hbits1, 4));
|
||||
auto hbh = MM256_SET_M128I(hbits2, _mm_slli_epi16(hbits2, 4));
|
||||
auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbl), hbh, 1);
|
||||
qx[0] = _mm512_or_si512(_mm512_and_si512(lbits1, mf), _mm512_and_si512(m10, hbits));
|
||||
qx[1] = _mm512_or_si512(_mm512_and_si512(lbits2, mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 2)));
|
||||
qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(lbits1, 4), mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 1)));
|
||||
qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(lbits2, 4), mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 3)));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
|
||||
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
|
||||
auto sumi = _mm512_setzero_si512();
|
||||
sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
|
||||
sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
|
||||
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
|
||||
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
|
||||
isum[iy] = _mm512_add_epi32(isum[iy], _mm512_mullo_epi32(iscales, sumi));
|
||||
float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
|
||||
acc[iy] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[iy]);
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
acc[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl))), _mm512_cvtepi32_ps(isum[iy]), acc[iy]);
|
||||
isum[iy] = _mm512_setzero_si512();
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 0), _mm512_extractf32x4_ps(acc[iy], 1));
|
||||
auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 2), _mm512_extractf32x4_ps(acc[iy], 3));
|
||||
info.store(ix+0, iy, sum1);
|
||||
info.store(ix+4, iy, sum2);
|
||||
acc[iy] = _mm512_setzero_ps();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
template <int nrc_y>
|
||||
static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
mul_mat_q5_k_r4_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%4 == 0);
|
||||
|
||||
Reference in New Issue
Block a user