Refactor iqk: GEMM kernels are refactored on AVX2/AVX512

Iwan Kawrakow
2025-05-18 15:50:20 +03:00
parent 0d96f3bd37
commit c63a0af5b7
4 changed files with 846 additions and 1715 deletions


@@ -1308,38 +1308,734 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
#endif
template <int nrc_y>
inline void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff,
__m256i * isum, int16_t min) {
auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
if constexpr (nrc_y == 1) {
auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
auto sumi = _mm256_setzero_si256();
auto bsums = q8.load_bsums(0, ibl);
#ifdef HAVE_FANCY_SIMD
sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
#else
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
#endif
isum[0] = _mm256_mullo_epi32(sumi, _mm256_set1_epi32(min));
} else {
auto s1 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0))); // blocks 0, 1, 8, 9
auto s2 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1))); // blocks 2, 3, 10, 11
auto s3 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0))); // blocks 4, 5, 12, 13
auto s4 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1))); // blocks 6, 7, 14, 15
for (int iy = 0; iy < nrc_y; ++iy) {
auto bsums = q8.load_bsums(iy, ibl);
#ifdef HAVE_FANCY_SIMD
isum[iy] = _mm256_dpwssd_epi32(isum[iy], s1, _mm256_shuffle_epi32(bsums, 0x00));
isum[iy] = _mm256_dpwssd_epi32(isum[iy], s2, _mm256_shuffle_epi32(bsums, 0x55));
isum[iy] = _mm256_dpwssd_epi32(isum[iy], s3, _mm256_shuffle_epi32(bsums, 0xaa));
isum[iy] = _mm256_dpwssd_epi32(isum[iy], s4, _mm256_shuffle_epi32(bsums, 0xff));
#else
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
#endif
}
}
}
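// Illustrative scalar sketch of what iq234_k_accum_mins computes (the array
// layout is simplified and the helper name is hypothetical): for each of the
// 4 interleaved rows, the per-block scales are summed against the q8 block
// sums and multiplied by the constant quant minimum, so the main loop can run
// on offset (non-negative) quants and be corrected here.
static void iq234_k_accum_mins_ref(const int8_t scales[16][4],  // 16 blocks x 4 rows
                                   const int16_t bsums[16],     // q8_K block sums for one column
                                   int16_t min, int32_t row_sum[4]) {
    for (int row = 0; row < 4; ++row) {
        int32_t s = 0;
        for (int blk = 0; blk < 16; ++blk) s += (int32_t)scales[blk][row] * bsums[blk];
        row_sum[row] += s * min;   // the value the SIMD path folds into isum[]
    }
}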
template <int nrc_y>
inline void iq2345_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff,
__m256i extra, __m256i * isum, int8_t min, int8_t delta) {
auto mask = _mm256_set_epi64x(0x0808080808080808, 0x0404040404040404, 0x0202020202020202, 0x0101010101010101);
auto vdelta = _mm256_set1_epi8(delta);
auto vmin = _mm256_set1_epi8(min);
auto min1 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(extra, mask), mask)));
auto min2 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_srli_epi16(extra, 4), mask), mask)));
auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
auto m1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
auto m2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
auto m3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
auto m4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
auto s1 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 0), _mm256_extracti128_si256(m1, 0)),
MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0))); // blocks 0, 1, 8, 9
auto s2 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 1), _mm256_extracti128_si256(m1, 1)),
MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1))); // blocks 2, 3, 10, 11
auto s3 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 0), _mm256_extracti128_si256(m2, 0)),
MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0))); // blocks 4, 5, 12, 13
auto s4 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 1), _mm256_extracti128_si256(m2, 1)),
MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1))); // blocks 6, 7, 14, 15
for (int iy = 0; iy < nrc_y; ++iy) {
auto bsums = q8.load_bsums(iy, ibl);
#ifdef HAVE_FANCY_SIMD
isum[iy] = _mm256_dpwssd_epi32(isum[iy], s1, _mm256_shuffle_epi32(bsums, 0x00));
isum[iy] = _mm256_dpwssd_epi32(isum[iy], s2, _mm256_shuffle_epi32(bsums, 0x55));
isum[iy] = _mm256_dpwssd_epi32(isum[iy], s3, _mm256_shuffle_epi32(bsums, 0xaa));
isum[iy] = _mm256_dpwssd_epi32(isum[iy], s4, _mm256_shuffle_epi32(bsums, 0xff));
#else
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
#endif
}
}
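// iq2345_k_accum_mins differs from iq234_k_accum_mins only in the minimum it
// applies: blocks whose "extra" bit is set use min + delta. The SIMD code
// builds min1/min2 with a byte mask and a blend; a scalar sketch of that
// selection (hypothetical helper, illustrative only):
static inline int8_t iq2345_block_min(uint64_t extra_bits, int blk, int8_t min, int8_t delta) {
    return (extra_bits >> blk) & 1 ? (int8_t)(min + delta) : min;
}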
template <int nrc_y>
static void mul_mat_iq2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto m4 = _mm256_set1_epi8(0xf);
auto ms = _mm256_set1_epi8(4);
auto m03 = _mm256_set1_epi8(0x03);
auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000);
static const uint8_t kvalues_iq2nl[32] = {1, 19, 33, 49, 6, 24, 38, 54, 1, 19, 33, 49, 6, 24, 38, 54, 1, 19, 33, 49, 6, 24, 38, 54, 1, 19, 33, 49, 6, 24, 38, 54};
auto values = _mm256_loadu_si256((const __m256i*)kvalues_iq2nl);
static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
#ifndef HAVE_FANCY_SIMD
auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
#endif
int nbl = n / QK_K;
__m256 acc[nrc_y] = {};
__m256i qx[4];
uint64_t stored_scales[8];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_iq2_k_r4 * iq2 = (const block_iq2_k_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d));
auto d4 = _mm256_set_m128(dl, dl);
auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq2[ibl].extra);
auto slbits = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales);
auto i8scales1 = _mm256_add_epi8(_mm256_and_si256(slbits, m4), _mm256_set1_epi8(-8));
auto i8scales2 = _mm256_add_epi8(_mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4), _mm256_set1_epi8(-8));
_mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1);
_mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
__m256i isum[nrc_y] = {};
iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -32);
for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib)));
#else
auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle);
#endif
auto lb = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs+ib);
auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 2)); extra = _mm256_srli_epi16(extra, 1);
shift = _mm256_shuffle_epi8(shift, shift_shuffle);
qx[0] = _mm256_and_si256(lb, m03);
qx[1] = _mm256_and_si256(_mm256_srli_epi16(lb, 2), m03);
qx[2] = _mm256_and_si256(_mm256_srli_epi16(lb, 4), m03);
qx[3] = _mm256_and_si256(_mm256_srli_epi16(lb, 6), m03);
qx[0] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[0], shift));
qx[1] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[1], shift));
qx[2] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[2], shift));
qx[3] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[3], shift));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi));
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2)));
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
acc[iy] = _mm256_setzero_ps();
info.store(ix+0, iy, sum);
}
}
}
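// Scalar view of the 2-bit decode in the loop above (hypothetical helper):
// two bits select one of four non-linear values, and the per-block "extra"
// bit adds 4 to index the second half of the 8-entry table, which is what
// the kernel does by adding `shift` before the byte shuffle.
static inline uint8_t iq2_k_decode(uint8_t packed, int j, bool extra) {
    static const uint8_t kvalues[8] = {1, 19, 33, 49, 6, 24, 38, 54};
    return kvalues[((packed >> (2*j)) & 0x03) + (extra ? 4 : 0)];
}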
template <int nrc_y>
static void mul_mat_iq3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto m4 = _mm256_set1_epi8(0xf);
auto ms = _mm256_set1_epi8(8);
auto m03 = _mm256_set1_epi8(0x03);
auto m04 = _mm256_set1_epi8(0x04);
auto smask = _mm256_set_epi64x(0x0808080808080808, 0x0404040404040404, 0x0202020202020202, 0x0101010101010101);
auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000);
auto values128 = _mm_loadu_si128((const __m128i *)iq3nl_values);
auto values = MM256_SET_M128I(values128, values128);
values = _mm256_add_epi8(values, _mm256_set1_epi8(64));
static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
#ifndef HAVE_FANCY_SIMD
auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
#endif
int nbl = n / QK_K;
__m256 acc[nrc_y] = {};
__m256i qx[4];
uint64_t stored_scales[8];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_iq3_k_r4 * iq3 = (const block_iq3_k_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d));
auto d4 = _mm256_set_m128(dl, dl);
auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq3[ibl].extra);
auto slbits = _mm256_loadu_si256((const __m256i *)iq3[ibl].scales_l);
auto sl1 = _mm256_add_epi8(_mm256_slli_epi16(_mm256_and_si256(slbits, m4), 1), _mm256_set1_epi8(1));
auto sl2 = _mm256_add_epi8(_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4), 1), _mm256_set1_epi8(1));
auto sh = _mm256_set1_epi64x(((const uint64_t *)iq3[ibl].scales_h)[0]);
auto sh1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sh, smask), smask), _mm256_set1_epi8(1));
auto sh2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_srli_epi16(sh, 4), smask), smask), _mm256_set1_epi8(1));
auto i8scales1 = _mm256_sign_epi8(sl1, sh1);
auto i8scales2 = _mm256_sign_epi8(sl2, sh2);
_mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1);
_mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
__m256i isum[nrc_y] = {};
iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -64);
for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib)));
#else
auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle);
#endif
auto lb = _mm256_loadu_si256((const __m256i *)iq3[ibl].qs+ib);
auto hbits = _mm_loadu_si128((const __m128i *)iq3[ibl].qh+ib);
auto hb = MM256_SET_M128I(hbits, _mm_slli_epi16(hbits, 4));
auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 3)); extra = _mm256_srli_epi16(extra, 1);
shift = _mm256_shuffle_epi8(shift, shift_shuffle);
qx[0] = _mm256_or_si256(_mm256_and_si256(lb, m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 2)));
qx[1] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 2), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 3)));
qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 4), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 4)));
qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 6), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 5)));
qx[0] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[0], shift));
qx[1] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[1], shift));
qx[2] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[2], shift));
qx[3] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[3], shift));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi));
#else
auto sumi1 = _mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00));
auto sumi2 = _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55));
auto sumi3 = _mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa));
auto sumi4 = _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2)));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4)));
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
acc[iy] = _mm256_setzero_ps();
info.store(ix+0, iy, sum);
}
}
}
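// Scalar view of the 3-bit decode above (hypothetical helper, assuming
// iq3nl_values is the 16-entry non-linear table): two low bits from qs and
// one high bit from qh form the index, the "extra" bit selects the upper
// half of the table, and the +64 bias keeps the left operand of the
// unsigned multiply-add non-negative.
static inline uint8_t iq3_k_decode(uint8_t lo2, uint8_t hi1, bool extra, const int8_t * iq3nl_values) {
    int idx = (lo2 | (hi1 << 2)) + (extra ? 8 : 0);
    return (uint8_t)(iq3nl_values[idx] + 64);
}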
template <int nrc_y>
static void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto m4 = _mm256_set1_epi8(0xf);
auto m30 = _mm256_set1_epi8(0x30);
auto m32 = _mm256_set1_epi8(32);
auto ms = _mm256_set1_epi8(4);
auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000);
#ifdef HAVE_FANCY_SIMD
auto values = load_iq4nl_values_256();
static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
#else
auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
auto values = MM256_SET_M128I(values128, values128);
#endif
int nbl = n / QK_K;
__m256 acc[nrc_y] = {};
__m256i qx[4];
uint64_t stored_scales[8];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_iq4_k_r4 * iq4 = (const block_iq4_k_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[ibl].d));
auto d4 = _mm256_set_m128(dl, dl);
auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq4[ibl].extra);
auto slbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l);
auto sl1 = _mm256_and_si256(slbits, m4);
auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4);
auto shbits = _mm_loadu_si128((const __m128i*)iq4[ibl].scales_h);
auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
auto i8scales1 = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(m30, _mm256_slli_epi16(sh, 4))), m32);
auto i8scales2 = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(m30, sh)), m32);
_mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1);
_mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
__m256i isum[nrc_y] = {};
#ifdef HAVE_FANCY_SIMD
iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -128);
#endif
for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib)));
#else
auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle);
#endif
auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 2)); extra = _mm256_srli_epi16(extra, 1);
shift = _mm256_shuffle_epi8(shift, shift_shuffle);
qx[0] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4)));
qx[1] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4)));
qx[2] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4)));
qx[3] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4)));
#ifndef HAVE_FANCY_SIMD
auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi));
#else
auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2)));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4)));
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
acc[iy] = _mm256_setzero_ps();
info.store(ix+0, iy, sum);
}
}
}
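// For IQ4_K the "extra" bit works differently from the 2/3-bit kernels: the
// 4-bit index goes straight into the table and the shift adds a constant +4
// to the looked-up value afterwards. Scalar sketch (hypothetical helper,
// ignoring the +128 bias the AVX512 path bakes into its value table):
static inline int8_t iq4_k_decode(uint8_t nibble, bool extra, const int8_t * iq4k_values) {
    return (int8_t)(iq4k_values[nibble & 0xf] + (extra ? 4 : 0));
}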
static inline __m256i prepare_5bit_quants(const __m256i * values, __m256i ql, __m256i qh, __m256i mask) {
auto q5vl = _mm256_shuffle_epi8(values[0], ql);
auto q5vh = _mm256_shuffle_epi8(values[1], ql);
#ifdef HAVE_FANCY_SIMD
return _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(qh, mask), mask), q5vl, q5vh);
#else
return _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(qh, mask), mask));
#endif
}
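// Scalar equivalent of prepare_5bit_quants (illustrative): the 4 low bits
// index a 16-entry table and the tested high bit picks which of the two
// table halves applies, i.e. a 5-bit lookup done as two shuffles plus a
// blend. Call sites pass different masks (0x01/0x10/0x02/0x20) because each
// qh byte carries the high bits of four different nibble positions.
static inline uint8_t prepare_5bit_quants_ref(const uint8_t values[2][16], uint8_t ql, int high_bit) {
    return values[high_bit & 1][ql & 0xf];
}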
template <int nrc_y>
static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto m4 = _mm256_set1_epi8(0xf);
auto m30 = _mm256_set1_epi8(0x30);
auto m32 = _mm256_set1_epi8(32);
auto ms = _mm256_set1_epi8(2);
auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000);
__m256i values[2];
{
auto val1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0);
auto val2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1);
values[0] = MM256_SET_M128I(val1, val1);
values[1] = MM256_SET_M128I(val2, val2);
#ifdef HAVE_FANCY_SIMD
values[0] = _mm256_sub_epi8(values[0], _mm256_set1_epi8(-128));
values[1] = _mm256_sub_epi8(values[1], _mm256_set1_epi8(-128));
#endif
}
#ifdef HAVE_FANCY_SIMD
static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
#else
auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
#endif
int nbl = n / QK_K;
__m256 acc[nrc_y] = {};
__m256i qx[4];
uint64_t stored_scales[8];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_iq5_k_r4 * iq5 = (const block_iq5_k_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5[ibl].d));
auto d4 = _mm256_set_m128(dl, dl);
auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq5[ibl].extra);
auto slbits = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales_l);
auto sl1 = _mm256_and_si256(slbits, m4);
auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4);
auto shbits = _mm_loadu_si128((const __m128i*)iq5[ibl].scales_h);
auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
auto i8scales1 = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(m30, _mm256_slli_epi16(sh, 4))), m32);
auto i8scales2 = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(m30, sh)), m32);
_mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1);
_mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
__m256i isum[nrc_y] = {};
#ifdef HAVE_FANCY_SIMD
if constexpr (nrc_y == 1) {
iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -128);
} else {
iq2345_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, extra, isum, -128, 2);
}
#endif
for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib)));
#else
auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle);
#endif
auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0);
auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
auto hbits = _mm_loadu_si128((const __m128i *)iq5[ibl].qh+ib);
auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 2), hbits);
qx[0] = _mm256_and_si256(lbits1, m4);
qx[1] = _mm256_and_si256(lbits2, m4);
qx[2] = _mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4);
qx[3] = _mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4);
qx[0] = prepare_5bit_quants(values, qx[0], hb, _mm256_set1_epi8(0x01));
qx[1] = prepare_5bit_quants(values, qx[1], hb, _mm256_set1_epi8(0x10));
qx[2] = prepare_5bit_quants(values, qx[2], hb, _mm256_set1_epi8(0x02));
qx[3] = prepare_5bit_quants(values, qx[3], hb, _mm256_set1_epi8(0x20));
#ifdef HAVE_FANCY_SIMD
if constexpr (nrc_y == 1) {
auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
shift = _mm256_shuffle_epi8(shift, shift_shuffle);
qx[0] = _mm256_add_epi8(qx[0], shift);
qx[1] = _mm256_add_epi8(qx[1], shift);
qx[2] = _mm256_add_epi8(qx[2], shift);
qx[3] = _mm256_add_epi8(qx[3], shift);
}
#else
auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
shift = _mm256_shuffle_epi8(shift, shift_shuffle);
qx[0] = _mm256_add_epi8(qx[0], shift);
qx[1] = _mm256_add_epi8(qx[1], shift);
qx[2] = _mm256_add_epi8(qx[2], shift);
qx[3] = _mm256_add_epi8(qx[3], shift);
auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi));
#else
auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2)));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4)));
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
acc[iy] = _mm256_setzero_ps();
info.store(ix+0, iy, sum);
}
}
}
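// Why the AVX512 path above can skip the per-quant shift when nrc_y > 1:
// sum_j (q_j + s)*y_j == sum_j q_j*y_j + s*sum_j y_j, and s*sum(y) is exactly
// what iq2345_k_accum_mins adds through the block sums. A scalar check of
// that identity (illustrative):
static int32_t dot_with_shift(const int8_t * q, const int8_t * y, int n, int s) {
    int32_t dot = 0, ysum = 0;
    for (int j = 0; j < n; ++j) { dot += q[j]*y[j]; ysum += y[j]; }
    return dot + s*ysum;  // equals sum_j (q[j]+s)*y[j]
}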
template <int nrc_y>
static void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto m4 = _mm256_set1_epi8(0xf);
#ifndef HAVE_FANCY_SIMD
auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
auto values = MM256_SET_M128I(values128, values128);
#else
auto values = load_iq4nl_values_256();
#endif
int nbl = n / QK_K;
using helper_t = union { __m256i vec; uint32_t val[8]; };
#ifndef HAVE_FANCY_SIMD
helper_t h, h_shift;
#else
using helper512_t = union { __m512i vec; uint64_t val[8]; };
helper_t h;
helper512_t h_shift;
#endif
__m256 acc[nrc_y] = {};
__m256i isum[nrc_y] = {};
__m256i qx[4];
for (int ix = 0; ix < nrc_x; ix += 4) {
auto dptr = (const float *)((const char *)vx + (ix+0)*bx);
const block_iq4_ks_r4 * iq4 = (const block_iq4_ks_r4 *)(dptr + 4);
auto d4 = _mm_loadu_ps(dptr);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto scales = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales);
h.vec = _mm256_sub_epi8(_mm256_and_si256(scales, _mm256_set1_epi8(-2)), _mm256_set1_epi8(127));
#ifndef HAVE_FANCY_SIMD
h_shift.vec = _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 2);
{
__m256 v1 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[0])))),
_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[0])))));
__m256 v2 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[1])))),
_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[1])))));
__m256 v3 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[2])))),
_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[2])))));
__m256 v4 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[3])))),
_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[3])))));
for (int iy = 0; iy < nrc_y; ++iy) {
auto m8 = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
acc[iy] = _mm256_fmadd_ps(v1, _mm256_shuffle_ps(m8, m8, 0x00), acc[iy]);
acc[iy] = _mm256_fmadd_ps(v2, _mm256_shuffle_ps(m8, m8, 0x55), acc[iy]);
acc[iy] = _mm256_fmadd_ps(v3, _mm256_shuffle_ps(m8, m8, 0xaa), acc[iy]);
acc[iy] = _mm256_fmadd_ps(v4, _mm256_shuffle_ps(m8, m8, 0xff), acc[iy]);
}
}
#else
auto shift = _mm256_add_epi8(_mm256_set1_epi8(-64), _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 1));
h_shift.vec = _mm512_mullo_epi16(_mm512_cvtepi8_epi16(shift), _mm512_cvtepi8_epi16(h.vec));
#endif
for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib]));
auto ishifts = _mm256_cvtepi16_epi32(_mm_set1_epi64x(h_shift.val[ib]));
auto scales_m = _mm256_cvtepi32_ps(ishifts);
for (int iy = 0; iy < nrc_y; ++iy) {
float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
}
#endif
auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
#ifndef HAVE_FANCY_SIMD
auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi32(h.val[ib])), s_shuffle);
auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
#else
auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, ibl)), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
isum[iy] = _mm256_setzero_si256();
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
acc[iy] = _mm256_setzero_ps();
info.store(ix+0, iy, _mm_mul_ps(d4, sum));
}
}
}
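// The IQ4_KS scale byte packs two things: bit 0 is the per-block "extra"
// flag and the remaining 7 bits hold the scale, stored so that
// (byte & 0xfe) - 127 recovers the signed value used above. Scalar sketch
// (hypothetical helper):
static inline void iq4_ks_unpack_scale(uint8_t sc, int8_t * scale, int * extra) {
    *scale = (int8_t)((sc & 0xfe) - 127);
    *extra = sc & 1;
}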
template <int nrc_y>
static void mul_mat_iq5_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto m4 = _mm256_set1_epi8(0xf);
__m256i values[2];
{
auto val1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0);
auto val2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1);
values[0] = MM256_SET_M128I(val1, val1);
values[1] = MM256_SET_M128I(val2, val2);
#ifdef HAVE_FANCY_SIMD
values[0] = _mm256_sub_epi8(values[0], _mm256_set1_epi8(-128));
values[1] = _mm256_sub_epi8(values[1], _mm256_set1_epi8(-128));
#endif
}
int nbl = n / QK_K;
using helper_t = union { __m256i vec; uint32_t val[8]; };
#ifndef HAVE_FANCY_SIMD
helper_t h, h_shift;
auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
#else
using helper512_t = union { __m512i vec; uint64_t val[8]; };
helper_t h;
helper512_t h_shift;
#endif
__m256 acc[nrc_y] = {};
__m256i isum[nrc_y] = {};
__m256i qx[4];
for (int ix = 0; ix < nrc_x; ix += 4) {
auto dptr = (const float *)((const char *)vx + (ix+0)*bx);
const block_iq5_ks_r4 * iq5 = (const block_iq5_ks_r4 *)(dptr + 4);
auto d4 = _mm_loadu_ps(dptr);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto scales = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales);
h.vec = _mm256_sub_epi8(_mm256_and_si256(scales, _mm256_set1_epi8(-2)), _mm256_set1_epi8(127));
#ifndef HAVE_FANCY_SIMD
h_shift.vec = _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 1);
{
__m256 v1 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[0])))),
_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[0])))));
__m256 v2 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[1])))),
_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[1])))));
__m256 v3 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[2])))),
_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[2])))));
__m256 v4 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[3])))),
_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[3])))));
for (int iy = 0; iy < nrc_y; ++iy) {
auto m8 = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
acc[iy] = _mm256_fmadd_ps(v1, _mm256_shuffle_ps(m8, m8, 0x00), acc[iy]);
acc[iy] = _mm256_fmadd_ps(v2, _mm256_shuffle_ps(m8, m8, 0x55), acc[iy]);
acc[iy] = _mm256_fmadd_ps(v3, _mm256_shuffle_ps(m8, m8, 0xaa), acc[iy]);
acc[iy] = _mm256_fmadd_ps(v4, _mm256_shuffle_ps(m8, m8, 0xff), acc[iy]);
}
}
#else
auto shift = _mm256_add_epi8(_mm256_set1_epi8(-64), _mm256_and_si256(scales, _mm256_set1_epi8(1)));
h_shift.vec = _mm512_mullo_epi16(_mm512_cvtepi8_epi16(shift), _mm512_cvtepi8_epi16(h.vec));
#endif
for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib]));
auto ishifts = _mm256_cvtepi16_epi32(_mm_set1_epi64x(h_shift.val[ib]));
auto scales_m = _mm256_cvtepi32_ps(ishifts);
for (int iy = 0; iy < nrc_y; ++iy) {
float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
}
#endif
auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0);
auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
auto hbits = _mm_loadu_si128((const __m128i *)iq5[ibl].qh+ib);
auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 2), hbits);
qx[0] = _mm256_and_si256(lbits1, m4);
qx[1] = _mm256_and_si256(lbits2, m4);
qx[2] = _mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4);
qx[3] = _mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4);
qx[0] = prepare_5bit_quants(values, qx[0], hb, _mm256_set1_epi8(0x01));
qx[1] = prepare_5bit_quants(values, qx[1], hb, _mm256_set1_epi8(0x10));
qx[2] = prepare_5bit_quants(values, qx[2], hb, _mm256_set1_epi8(0x02));
qx[3] = prepare_5bit_quants(values, qx[3], hb, _mm256_set1_epi8(0x20));
#ifndef HAVE_FANCY_SIMD
auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi32(h.val[ib])), s_shuffle);
auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
#else
auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, ibl)), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
isum[iy] = _mm256_setzero_si256();
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
acc[iy] = _mm256_setzero_ps();
info.store(ix+0, iy, _mm_mul_ps(d4, sum));
}
}
}
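// Layout sketch of one interleaved row group in IQ4_KS_R4/IQ5_KS_R4
// (illustrative types, not the real structs): four f32 row scales come
// first, followed by the packed blocks, which is why both kernels step to
// `dptr + 4` before casting to the block type.
struct ks_r4_row_group_view {
    const float * d4;      // 4 per-row scales
    const void  * blocks;  // n/QK_K packed blocks follow the scales
};
static inline ks_r4_row_group_view ks_r4_view(const void * vx, size_t bx, int ix) {
    auto dptr = (const float *)((const char *)vx + ix*bx);
    return { dptr, (const void *)(dptr + 4) };
}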
template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
#ifdef HAVE_FANCY_SIMD
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2KS> ||
std::is_same_v<Dequantizer, DequantizerIQ4KS> ||
std::is_same_v<Dequantizer, DequantizerIQ5KS>) {
-funcs[0] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 1>;
-funcs[1] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 2>;
-funcs[2] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 3>;
-funcs[3] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 4>;
-funcs[4] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 5>;
-funcs[5] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 6>;
-funcs[6] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 7>;
-funcs[7] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 8>;
+IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_iqX_k_q8_K_AVX512_new, Dequantizer, funcs)
} else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2K>) {
+IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_AVX512, Dequantizer, funcs);
funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>;
-funcs[1] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 2>;
-funcs[2] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 3>;
-funcs[3] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 4>;
-funcs[4] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 5>;
-funcs[5] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 6>;
-funcs[6] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 7>;
-funcs[7] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 8>;
} else {
-funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>;
-funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>;
-funcs[2] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 3>;
-funcs[3] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 4>;
-funcs[4] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 5>;
-funcs[5] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 6>;
-funcs[6] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 7>;
-funcs[7] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 8>;
+IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_iqX_k_q8_K_AVX512, Dequantizer, funcs);
}
#else
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2K>||
@@ -1347,23 +2043,9 @@ template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX
std::is_same_v<Dequantizer, DequantizerIQ4K>||
std::is_same_v<Dequantizer, DequantizerIQ5K>||
std::is_same_v<Dequantizer, DequantizerIQ6K>) {
-funcs[0] = mul_mat_qY_K_q8_K_T<Dequantizer, 1>;
-funcs[1] = mul_mat_qY_K_q8_K_T<Dequantizer, 2>;
-funcs[2] = mul_mat_qY_K_q8_K_T<Dequantizer, 3>;
-funcs[3] = mul_mat_qY_K_q8_K_T<Dequantizer, 4>;
-funcs[4] = mul_mat_qY_K_q8_K_T<Dequantizer, 5>;
-funcs[5] = mul_mat_qY_K_q8_K_T<Dequantizer, 6>;
-funcs[6] = mul_mat_qY_K_q8_K_T<Dequantizer, 7>;
-funcs[7] = mul_mat_qY_K_q8_K_T<Dequantizer, 8>;
+IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qY_K_q8_K_T, Dequantizer, funcs);
} else {
-funcs[0] = mul_mat_qX_K_q8_K_T<Dequantizer, 1>;
-funcs[1] = mul_mat_qX_K_q8_K_T<Dequantizer, 2>;
-funcs[2] = mul_mat_qX_K_q8_K_T<Dequantizer, 3>;
-funcs[3] = mul_mat_qX_K_q8_K_T<Dequantizer, 4>;
-funcs[4] = mul_mat_qX_K_q8_K_T<Dequantizer, 5>;
-funcs[5] = mul_mat_qX_K_q8_K_T<Dequantizer, 6>;
-funcs[6] = mul_mat_qX_K_q8_K_T<Dequantizer, 7>;
-funcs[7] = mul_mat_qX_K_q8_K_T<Dequantizer, 8>;
+IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, Dequantizer, funcs);
}
#endif
@@ -1371,9 +2053,11 @@ template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX
} // namespace
-bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels) {
+bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
-if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) {
+auto etypeA = ggml_type(typeA);
+auto expected_type_B = etypeA == GGML_TYPE_IQ4_KS_R4 || etypeA == GGML_TYPE_IQ5_KS_R4 ? GGML_TYPE_Q8_K32 : GGML_TYPE_Q8_K;
+if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
return false;
}
@@ -1405,6 +2089,33 @@ bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_m
case GGML_TYPE_IQ6_K:
set_functions<DequantizerIQ6K>(kernels);
break;
case GGML_TYPE_IQ2_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_k_r4_q8_k, kernels);
break;
case GGML_TYPE_IQ3_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_k_r4_q8_k, kernels);
#ifdef HAVE_FANCY_SIMD
func16 = mul_mat_iq3_k_r4_q8_k<16>;
#endif
break;
case GGML_TYPE_IQ4_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_k_r4_q8_k, kernels);
func16 = mul_mat_iq4_k_r4_q8_k<16>;
break;
case GGML_TYPE_IQ4_KS_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_ks_r4_q8_k, kernels);
#ifndef HAVE_FANCY_SIMD
// For some reason Zen4 does not like this particular function
func16 = mul_mat_iq4_ks_r4_q8_k<16>;
#endif
break;
case GGML_TYPE_IQ5_KS_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq5_ks_r4_q8_k, kernels);
#ifndef HAVE_FANCY_SIMD
// For some reason Zen4 does not like this particular function
func16 = mul_mat_iq5_ks_r4_q8_k<16>;
#endif
break;
default:
return false;
}
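// Hypothetical call site (illustrative; assumes the iqk_mul_mat declarations
// are in scope, and ne00/ny/A_row_data/row_stride/info/nrc_x are placeholder
// names): the table is indexed so that kernels[ny-1] multiplies ny columns of
// the q8 activation at a time, for ny in 1..IQK_MAX_NY, while func16, when a
// type sets it, handles the 16-column case.
std::array<mul_mat_t, IQK_MAX_NY> kernels;
mul_mat_t func16 = nullptr;
if (iqk_set_kernels_iqk_quants(ne00, GGML_TYPE_IQ4_K_R4, GGML_TYPE_Q8_K, kernels, func16)) {
    mul_mat_t kernel = (ny == 16 && func16) ? func16 : kernels[ny - 1];
    kernel(ne00, A_row_data, row_stride, info, nrc_x);
}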


@@ -6,6 +6,6 @@
#include <array>
-bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels);
+bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
#endif


@@ -1625,6 +1625,81 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
}
}
// HAVE_FANCY_SIMD should only be defined as #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
template <int nrc_y>
static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%32 == 0);
GGML_ASSERT(nrc_x%8 == 0);
#ifndef HAVE_FANCY_SIMD
auto m1 = _mm256_set1_epi16(1);
#endif
int nb = n / 16;
__m256i acc[nrc_y] = {};
__m256i qx[4];
float dy[nrc_y];
#ifdef HAVE_FANCY_SIMD
float sy[nrc_y];
#endif
const int8_t * q8y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
auto dptr = (const float *)info.src1_row(iy);
dy[iy] = dptr[0];
#ifdef HAVE_FANCY_SIMD
auto iptr = (const int32_t *)(dptr + 1);
sy[iy] = -127*iptr[0];
#endif
q8y[iy] = (const int8_t *)(dptr + 2);
}
for (int ix = 0; ix < nrc_x; ix += 8) {
auto dptr = (const float *)((const char *)vx + ix*bx);
auto dx = _mm256_loadu_ps(dptr);
auto q8x = (const int8_t *)(dptr + 8);
for (int ib = 0; ib < nb; ++ib) { // Blocks of 16 for 8 interleaved rows
qx[0] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+0);
qx[1] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+1);
qx[2] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+2);
qx[3] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+3);
#ifndef HAVE_FANCY_SIMD
auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
#else
qx[0] = _mm256_add_epi8(qx[0], _mm256_set1_epi8(127));
qx[1] = _mm256_add_epi8(qx[1], _mm256_set1_epi8(127));
qx[2] = _mm256_add_epi8(qx[2], _mm256_set1_epi8(127));
qx[3] = _mm256_add_epi8(qx[3], _mm256_set1_epi8(127));
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto y128 = _mm_loadu_si128((const __m128i*)q8y[iy]+ib);
auto y = MM256_SET_M128I(y128, y128);
#ifdef HAVE_FANCY_SIMD
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
#else
auto sumi1 = _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
auto sumi2 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
auto sumi3 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
auto sumi4 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
auto sumi12 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
auto sumi34 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi3), _mm256_madd_epi16(m1, sumi4));
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(sumi12, sumi34));
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto scale = _mm256_mul_ps(dx, _mm256_set1_ps(dy[iy]));
#ifdef HAVE_FANCY_SIMD
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_set1_epi32(sy[iy]));
#endif
info.store(ix, iy, _mm256_mul_ps(scale, _mm256_cvtepi32_ps(acc[iy])));
acc[iy] = _mm256_setzero_si256();
}
}
}
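// Why the AVX512 path above adds 127 to the x quants (illustrative): dpbusd
// needs an unsigned left operand, so the kernel computes sum((q+127)*y) and
// subtracts 127*sum(y), which the loader precomputed as sy[iy] = -127*isum.
// A scalar check of the correction:
static int32_t dot_biased(const int8_t * q, const int8_t * y, int n) {
    int32_t biased = 0, ysum = 0;
    for (int j = 0; j < n; ++j) { biased += (q[j] + 127)*y[j]; ysum += y[j]; }
    return biased - 127*ysum;  // == sum_j q[j]*y[j]
}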
} // namespace
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
@@ -1632,7 +1707,7 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
auto etypeA = ggml_type(typeA);
auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32
: etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8
-: etypeA == GGML_TYPE_Q8_KV ? GGML_TYPE_Q8_KV
+: etypeA == GGML_TYPE_Q8_KV || etypeA == GGML_TYPE_Q8_KV_R8 ? GGML_TYPE_Q8_KV
: GGML_TYPE_Q8_K;
if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
@@ -1690,6 +1765,9 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
func16 = mul_mat_q8_KV_q8_KV<16>;
#endif
break;
case GGML_TYPE_Q8_KV_R8:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_r8_q8_KV, kernels);
break;
default:
return false;
}

File diff suppressed because it is too large