mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
Refactor iqk: GEMM kernels are refactored on AVX2/AVX512
This commit is contained in:
@@ -1308,38 +1308,734 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
|
||||
|
||||
#endif
|
||||
|
||||
// Accumulates the "mins" contribution of one super-block for the iq2/3/4_k_r4 kernels:
// for each of nrc_y right-hand-side rows, adds min * sum_over_subblocks(scale * bsum)
// into isum[iy], where bsum are the per-subblock sums of the Q8_K activations.
// i8scales1/i8scales2 carry the int8 sub-block scales of the 4 interleaved x-rows
// (sub-blocks 0..7 and 8..15 respectively); `shuff` rearranges the sign-extended
// 16-bit scales so that each 32-bit lane pairs the two scales matching the bsum pair
// selected by the _mm256_shuffle_epi32(bsums, ...) broadcasts below.
// NOTE(review): in the nrc_y == 1 branch isum[0] is overwritten, not accumulated —
// callers appear to pass a zero-initialized isum; confirm at call sites.
template <int nrc_y>
//IQK_ALWAYS_INLINE void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff,
inline void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff,
        __m256i * isum, int16_t min) {
    // Widen the int8 scales to int16 and interleave them for the dot products below.
    auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
    auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
    auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
    auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
    if constexpr (nrc_y == 1) {
        // Single y-row: accumulate scale*bsum first and multiply by `min` once at the
        // end (saves per-row multiplications compared to the generic branch).
        auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
        auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
        auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
        auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
        auto sumi = _mm256_setzero_si256();
        auto bsums = q8.load_bsums(0, ibl);
#ifdef HAVE_FANCY_SIMD
        // AVX-512 VNNI: fused 16-bit multiply + pairwise add into 32-bit lanes.
        sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
        sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
        sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
        sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
#else
        // Plain AVX2: madd produces the same pairwise 16x16->32 products.
        sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
        sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
        sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
        sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
#endif
        isum[0] = _mm256_mullo_epi32(sumi, _mm256_set1_epi32(min));

    } else {
        // Generic case: pre-multiply the scales by `min` (fits in 16 bits) so the
        // per-row loop only does the scale*bsum dot products.
        auto s1 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0))); // blocks 0, 1, 8, 9
        auto s2 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1))); // blocks 2, 3, 10, 11
        auto s3 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0))); // blocks 4, 5, 12, 13
        auto s4 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1))); // blocks 6, 7, 14, 15
        for (int iy = 0; iy < nrc_y; ++iy) {
            auto bsums = q8.load_bsums(iy, ibl);
#ifdef HAVE_FANCY_SIMD
            isum[iy] = _mm256_dpwssd_epi32(isum[iy], s1, _mm256_shuffle_epi32(bsums, 0x00));
            isum[iy] = _mm256_dpwssd_epi32(isum[iy], s2, _mm256_shuffle_epi32(bsums, 0x55));
            isum[iy] = _mm256_dpwssd_epi32(isum[iy], s3, _mm256_shuffle_epi32(bsums, 0xaa));
            isum[iy] = _mm256_dpwssd_epi32(isum[iy], s4, _mm256_shuffle_epi32(bsums, 0xff));
#else
            isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
            isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
            isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
            isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
#endif
        }
    }
}
|
||||
|
||||
// Variant of iq234_k_accum_mins for quant types whose codebook gets a per-sub-block
// shift selected by the `extra` bit mask: the effective minimum of a sub-block is
// `min` or `min + delta` depending on whether its bit in `extra` is set.
// The effective per-sub-block minimum is multiplied into the scale before the
// scale*bsum dot products, then accumulated into isum[iy] for every y-row.
template <int nrc_y>
inline void iq2345_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff,
        __m256i extra, __m256i * isum, int8_t min, int8_t delta) {
    // One bit per 8-byte group: selects the extra-shift flag of each sub-block.
    auto mask = _mm256_set_epi64x(0x0808080808080808, 0x0404040404040404, 0x0202020202020202, 0x0101010101010101);
    auto vdelta = _mm256_set1_epi8(delta);
    auto vmin = _mm256_set1_epi8(min);
    // min1/min2: per-sub-block minimum = min (+ delta where the extra bit is set),
    // for sub-blocks 0..7 (low nibble of extra) and 8..15 (high nibble).
    auto min1 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(extra, mask), mask)));
    auto min2 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_srli_epi16(extra, 4), mask), mask)));
    // Widen scales and mins to int16 and put them in matching lane order.
    auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
    auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
    auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
    auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
    auto m1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
    auto m2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
    auto m3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
    auto m4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
    // s1..s4 = (per-sub-block minimum) * scale, in the lane order of the bsum broadcasts below.
    auto s1 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 0), _mm256_extracti128_si256(m1, 0)),
                                 MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0))); // blocks 0, 1, 8, 9
    auto s2 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 1), _mm256_extracti128_si256(m1, 1)),
                                 MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1))); // blocks 2, 3, 10, 11
    auto s3 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 0), _mm256_extracti128_si256(m2, 0)),
                                 MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0))); // blocks 4, 5, 12, 13
    auto s4 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 1), _mm256_extracti128_si256(m2, 1)),
                                 MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1))); // blocks 6, 7, 14, 15
    for (int iy = 0; iy < nrc_y; ++iy) {
        auto bsums = q8.load_bsums(iy, ibl);
#ifdef HAVE_FANCY_SIMD
        isum[iy] = _mm256_dpwssd_epi32(isum[iy], s1, _mm256_shuffle_epi32(bsums, 0x00));
        isum[iy] = _mm256_dpwssd_epi32(isum[iy], s2, _mm256_shuffle_epi32(bsums, 0x55));
        isum[iy] = _mm256_dpwssd_epi32(isum[iy], s3, _mm256_shuffle_epi32(bsums, 0xaa));
        isum[iy] = _mm256_dpwssd_epi32(isum[iy], s4, _mm256_shuffle_epi32(bsums, 0xff));
#else
        isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
        isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
        isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
        isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
#endif
    }
}
|
||||
|
||||
// GEMM kernel: 4 row-interleaved IQ2_K blocks (block_iq2_k_r4) times nrc_y Q8_K
// activation rows. nrc_x is the number of x-rows (must be a multiple of 4 because
// the r4 layout interleaves 4 rows per block); n is the row length in elements.
// Results are written via info.store(), one __m128 (4 rows) per y-row.
template <int nrc_y>
static void mul_mat_iq2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = _mm256_set1_epi8(0xf);        // low-nibble mask for the 4-bit scales
    auto ms = _mm256_set1_epi8(4);          // extra-bit shift applied to the 2-bit quant index
    auto m03 = _mm256_set1_epi8(0x03);      // 2-bit quant mask
    // Broadcasts each extra-shift byte over its 8-byte group.
    auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000);
    // IQ2 non-linear codebook, biased so values are non-negative (usable with maddubs/dpbusd);
    // replicated 4x so _mm256_shuffle_epi8 can index it from every 8-byte group.
    static const uint8_t kvalues_iq2nl[32] = {1, 19, 33, 49, 6, 24, 38, 54, 1, 19, 33, 49, 6, 24, 38, 54, 1, 19, 33, 49, 6, 24, 38, 54, 1, 19, 33, 49, 6, 24, 38, 54};
    auto values = _mm256_loadu_si256((const __m256i*)kvalues_iq2nl);
    // Byte permutation used by iq234_k_accum_mins to interleave the 16-bit scales.
    static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
    auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
#ifndef HAVE_FANCY_SIMD
    // AVX2 path widens the 8 per-sub-block scales to 16 bits with this shuffle.
    auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
#endif
    int nbl = n / QK_K;
    __m256 acc[nrc_y] = {};
    __m256i qx[4];
    uint64_t stored_scales[8];               // int8 scales spilled so the ib loop can reload per sub-block
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq2_k_r4 * iq2 = (const block_iq2_k_r4 *)((const char *)vx + (ix+0)*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
            auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d));   // fp16 row scales -> fp32
            auto d4 = _mm256_set_m128(dl, dl);
            auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq2[ibl].extra);
            auto slbits = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales);
            // 4-bit scales stored with a +8 bias; recover signed values in [-8, 7].
            auto i8scales1 = _mm256_add_epi8(_mm256_and_si256(slbits, m4), _mm256_set1_epi8(-8));
            auto i8scales2 = _mm256_add_epi8(_mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4), _mm256_set1_epi8(-8));
            _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1);
            _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
            __m256i isum[nrc_y] = {};
            // Compensate for the +32 bias baked into the unsigned codebook (min = -32).
            iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -32);
            for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
                auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib)));
#else
                auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle);
#endif
                auto lb = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs+ib);
                // Extra bit selects the upper half of the replicated codebook (+4 on the index).
                auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 2)); extra = _mm256_srli_epi16(extra, 1);
                shift = _mm256_shuffle_epi8(shift, shift_shuffle);
                // Unpack the four 2-bit quant planes.
                qx[0] = _mm256_and_si256(lb, m03);
                qx[1] = _mm256_and_si256(_mm256_srli_epi16(lb, 2), m03);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(lb, 4), m03);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(lb, 6), m03);
                // Map quant indices (+ extra shift) through the non-linear codebook.
                qx[0] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[0], shift));
                qx[1] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[1], shift));
                qx[2] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[2], shift));
                qx[3] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[3], shift));
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
                    auto sumi = _mm256_setzero_si256();
                    sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi));
#else
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2)));
#endif
                }
            }
            // Scale the integer accumulators by row scale * activation scale and accumulate in fp32.
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
            }
        }
        // Fold the two 128-bit halves (each holds the same 4 rows) and store.
        for (int iy = 0; iy < nrc_y; ++iy) {
            auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
            acc[iy] = _mm256_setzero_ps();
            info.store(ix+0, iy, sum);
        }
    }
}
|
||||
|
||||
// GEMM kernel: 4 row-interleaved IQ3_K blocks (block_iq3_k_r4) times nrc_y Q8_K
// activation rows. Same overall structure as mul_mat_iq2_k_r4_q8_k, but with 3-bit
// quants (2 low bits in qs, 1 high bit in qh) and scales stored as 4-bit magnitude
// (scales_l) plus a sign bit (scales_h).
template <int nrc_y>
static void mul_mat_iq3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = _mm256_set1_epi8(0xf);        // low-nibble mask
    auto ms = _mm256_set1_epi8(8);          // extra-bit shift on the 3-bit quant index
    auto m03 = _mm256_set1_epi8(0x03);      // low 2 quant bits
    auto m04 = _mm256_set1_epi8(0x04);      // high (3rd) quant bit
    // One sign bit per 8-byte group of the scale vector.
    auto smask = _mm256_set_epi64x(0x0808080808080808, 0x0404040404040404, 0x0202020202020202, 0x0101010101010101);
    auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000);
    auto values128 = _mm_loadu_si128((const __m128i *)iq3nl_values);
    auto values = MM256_SET_M128I(values128, values128);
    // Bias the codebook by +64 so the quants are non-negative for maddubs/dpbusd;
    // compensated below via iq234_k_accum_mins(..., -64).
    values = _mm256_add_epi8(values, _mm256_set1_epi8(64));
    static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
    auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
#ifndef HAVE_FANCY_SIMD
    auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
#endif
    int nbl = n / QK_K;
    __m256 acc[nrc_y] = {};
    __m256i qx[4];
    uint64_t stored_scales[8];
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq3_k_r4 * iq3 = (const block_iq3_k_r4 *)((const char *)vx + (ix+0)*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
            auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d));
            auto d4 = _mm256_set_m128(dl, dl);
            auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq3[ibl].extra);
            auto slbits = _mm256_loadu_si256((const __m256i *)iq3[ibl].scales_l);
            // Scale magnitude = 2*nibble + 1 (odd values 1..31).
            auto sl1 = _mm256_add_epi8(_mm256_slli_epi16(_mm256_and_si256(slbits, m4), 1), _mm256_set1_epi8(1));
            auto sl2 = _mm256_add_epi8(_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4), 1), _mm256_set1_epi8(1));
            auto sh = _mm256_set1_epi64x(((const uint64_t *)iq3[ibl].scales_h)[0]);
            // sh1/sh2 are +1 or -1 depending on the sign bit in scales_h.
            auto sh1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sh, smask), smask), _mm256_set1_epi8(1));
            auto sh2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_srli_epi16(sh, 4), smask), smask), _mm256_set1_epi8(1));
            auto i8scales1 = _mm256_sign_epi8(sl1, sh1);
            auto i8scales2 = _mm256_sign_epi8(sl2, sh2);
            _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1);
            _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
            __m256i isum[nrc_y] = {};
            iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -64);
            for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
                auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib)));
#else
                auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle);
#endif
                auto lb = _mm256_loadu_si256((const __m256i *)iq3[ibl].qs+ib);
                auto hbits = _mm_loadu_si128((const __m128i *)iq3[ibl].qh+ib);
                // Low half keeps hbits shifted left by 4 so the same srli amounts work for both halves.
                auto hb = MM256_SET_M128I(hbits, _mm_slli_epi16(hbits, 4));
                auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 3)); extra = _mm256_srli_epi16(extra, 1);
                shift = _mm256_shuffle_epi8(shift, shift_shuffle);
                // Assemble the 3-bit quant index: 2 low bits from qs, bit 2 from qh.
                qx[0] = _mm256_or_si256(_mm256_and_si256(lb, m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 2)));
                qx[1] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 2), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 3)));
                qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 4), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 4)));
                qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 6), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 5)));
                qx[0] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[0], shift));
                qx[1] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[1], shift));
                qx[2] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[2], shift));
                qx[3] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[3], shift));
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
                    auto sumi = _mm256_setzero_si256();
                    sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi));
#else
                    auto sumi1 = _mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00));
                    auto sumi2 = _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55));
                    auto sumi3 = _mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa));
                    auto sumi4 = _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2)));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4)));
#endif
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
            acc[iy] = _mm256_setzero_ps();
            info.store(ix+0, iy, sum);
        }
    }
}
|
||||
|
||||
// GEMM kernel: 4 row-interleaved IQ4_K blocks (block_iq4_k_r4) times nrc_y Q8_K
// activation rows. Quants are full nibbles mapped through the IQ4 non-linear
// codebook; scales are 6 bits (4 low bits in scales_l, 2 high bits in scales_h),
// stored with a +32 bias.
template <int nrc_y>
static void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = _mm256_set1_epi8(0xf);
    auto m30 = _mm256_set1_epi8(0x30);      // high 2 scale bits after shifting into place
    auto m32 = _mm256_set1_epi8(32);        // scale bias
    auto ms = _mm256_set1_epi8(4);          // extra-bit delta added to the dequantized value
    auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000);
#ifdef HAVE_FANCY_SIMD
    // AVX-512 path: unsigned (biased) codebook + dpbusd, with the -128 bias
    // compensated through iq234_k_accum_mins below.
    auto values = load_iq4nl_values_256();
    static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
    auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
#else
    // AVX2 path: signed codebook; maddubs is fed abs(x) and a sign-adjusted y instead.
    auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
    auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
    auto values = MM256_SET_M128I(values128, values128);
#endif
    int nbl = n / QK_K;
    __m256 acc[nrc_y] = {};
    __m256i qx[4];
    uint64_t stored_scales[8];
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq4_k_r4 * iq4 = (const block_iq4_k_r4 *)((const char *)vx + (ix+0)*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
            auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[ibl].d));
            auto d4 = _mm256_set_m128(dl, dl);
            auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq4[ibl].extra);
            auto slbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l);
            auto sl1 = _mm256_and_si256(slbits, m4);
            auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4);
            auto shbits = _mm_loadu_si128((const __m128i*)iq4[ibl].scales_h);
            auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
            // Combine 4 low + 2 high bits, then remove the +32 bias.
            auto i8scales1 = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(m30, _mm256_slli_epi16(sh, 4))), m32);
            auto i8scales2 = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(m30, sh)), m32);
            _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1);
            _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
            __m256i isum[nrc_y] = {};
#ifdef HAVE_FANCY_SIMD
            // Only the AVX-512 path needs this: it uses a codebook biased by +128.
            iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -128);
#endif
            for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
                auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib)));
#else
                auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle);
#endif
                auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
                auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
                auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 2)); extra = _mm256_srli_epi16(extra, 1);
                shift = _mm256_shuffle_epi8(shift, shift_shuffle);
                // Dequantize nibbles through the codebook; the extra bit adds +4 to the value.
                qx[0] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4)));
                qx[1] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4)));
                qx[2] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4)));
                qx[3] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4)));
#ifndef HAVE_FANCY_SIMD
                // abs(qx) for maddubs; the sign is moved onto y below via _mm256_sign_epi8.
                auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
                auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
                auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
                auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
#endif
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
                    auto sumi = _mm256_setzero_si256();
                    sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi));
#else
                    auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
                    auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
                    auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
                    auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2)));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4)));
#endif
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
            acc[iy] = _mm256_setzero_ps();
            info.store(ix+0, iy, sum);
        }
    }
}
|
||||
|
||||
// Maps 4-bit quant indices `ql` to 5-bit non-linear codebook values: looks `ql` up in
// both codebook halves (values[0]/values[1]) and selects per byte with the high bit,
// i.e. picks values[1] where (qh & mask) == mask, values[0] elsewhere.
static inline __m256i prepare_5bit_quants(const __m256i * values, __m256i ql, __m256i qh, __m256i mask) {
    auto q5vl = _mm256_shuffle_epi8(values[0], ql);
    auto q5vh = _mm256_shuffle_epi8(values[1], ql);
#ifdef HAVE_FANCY_SIMD
    // AVX-512BW: blend via a k-mask from the byte compare.
    return _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(qh, mask), mask), q5vl, q5vh);
#else
    // AVX2: blendv keys off the MSB of the 0x00/0xFF compare result.
    return _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(qh, mask), mask));
#endif
}
|
||||
|
||||
// GEMM kernel: 4 row-interleaved IQ5_K blocks (block_iq5_k_r4) times nrc_y Q8_K
// activation rows. 5-bit quants: 4 low bits in qs, the 5th bit in qh selects the
// codebook half via prepare_5bit_quants(). The extra-bit shift (+2) is either
// applied to the quants per sub-block (nrc_y == 1 / AVX2) or folded into the mins
// pass via iq2345_k_accum_mins (AVX-512, nrc_y > 1).
template <int nrc_y>
static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = _mm256_set1_epi8(0xf);
    auto m30 = _mm256_set1_epi8(0x30);
    auto m32 = _mm256_set1_epi8(32);
    auto ms = _mm256_set1_epi8(2);          // extra-bit delta for IQ5
    auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000);
    // Two halves of the 32-entry IQ5 codebook, each replicated across lanes.
    __m256i values[2];
    {
        auto val1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0);
        auto val2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1);
        values[0] = MM256_SET_M128I(val1, val1);
        values[1] = MM256_SET_M128I(val2, val2);
#ifdef HAVE_FANCY_SIMD
        // Bias by +128 so dpbusd sees unsigned quants; compensated by the mins pass.
        values[0] = _mm256_sub_epi8(values[0], _mm256_set1_epi8(-128));
        values[1] = _mm256_sub_epi8(values[1], _mm256_set1_epi8(-128));
#endif
    }
#ifdef HAVE_FANCY_SIMD
    static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
    auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
#else
    auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
#endif
    int nbl = n / QK_K;
    __m256 acc[nrc_y] = {};
    __m256i qx[4];
    uint64_t stored_scales[8];
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq5_k_r4 * iq5 = (const block_iq5_k_r4 *)((const char *)vx + (ix+0)*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
            auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5[ibl].d));
            auto d4 = _mm256_set_m128(dl, dl);
            auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq5[ibl].extra);
            auto slbits = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales_l);
            auto sl1 = _mm256_and_si256(slbits, m4);
            auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4);
            auto shbits = _mm_loadu_si128((const __m128i*)iq5[ibl].scales_h);
            auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
            // 6-bit scales (4 low + 2 high bits), stored with a +32 bias.
            auto i8scales1 = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(m30, _mm256_slli_epi16(sh, 4))), m32);
            auto i8scales2 = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(m30, sh)), m32);
            _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1);
            _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
            __m256i isum[nrc_y] = {};
#ifdef HAVE_FANCY_SIMD
            if constexpr (nrc_y == 1) {
                // Plain -128 bias compensation; the extra shift is applied to the quants below.
                iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -128);
            } else {
                // Fold the extra shift (+2) into the mins pass so the quant loop stays branch-free.
                iq2345_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, extra, isum, -128, 2);
            }
#endif
            for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
                auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib)));
#else
                auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle);
#endif
                auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0);
                auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
                auto hbits = _mm_loadu_si128((const __m128i *)iq5[ibl].qh+ib);
                auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 2), hbits);
                qx[0] = _mm256_and_si256(lbits1, m4);
                qx[1] = _mm256_and_si256(lbits2, m4);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4);

                // Select codebook half per byte using the matching 5th-bit position in hb.
                qx[0] = prepare_5bit_quants(values, qx[0], hb, _mm256_set1_epi8(0x01));
                qx[1] = prepare_5bit_quants(values, qx[1], hb, _mm256_set1_epi8(0x10));
                qx[2] = prepare_5bit_quants(values, qx[2], hb, _mm256_set1_epi8(0x02));
                qx[3] = prepare_5bit_quants(values, qx[3], hb, _mm256_set1_epi8(0x20));
#ifdef HAVE_FANCY_SIMD
                if constexpr (nrc_y == 1) {
                    // nrc_y > 1 already accounted for the extra shift in iq2345_k_accum_mins.
                    auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
                    shift = _mm256_shuffle_epi8(shift, shift_shuffle);
                    qx[0] = _mm256_add_epi8(qx[0], shift);
                    qx[1] = _mm256_add_epi8(qx[1], shift);
                    qx[2] = _mm256_add_epi8(qx[2], shift);
                    qx[3] = _mm256_add_epi8(qx[3], shift);
                }
#else
                auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
                shift = _mm256_shuffle_epi8(shift, shift_shuffle);
                qx[0] = _mm256_add_epi8(qx[0], shift);
                qx[1] = _mm256_add_epi8(qx[1], shift);
                qx[2] = _mm256_add_epi8(qx[2], shift);
                qx[3] = _mm256_add_epi8(qx[3], shift);
                // abs(qx) for maddubs; the sign is moved onto y via _mm256_sign_epi8 below.
                auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
                auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
                auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
                auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
#endif
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
                    auto sumi = _mm256_setzero_si256();
                    sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi));
#else
                    auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
                    auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
                    auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
                    auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2)));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4)));
#endif
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
            acc[iy] = _mm256_setzero_ps();
            info.store(ix+0, iy, sum);
        }
    }
}
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%4 == 0);
|
||||
Q8<nrc_y, block_q8_K> q8(info);
|
||||
auto m4 = _mm256_set1_epi8(0xf);
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
|
||||
auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
|
||||
auto values = MM256_SET_M128I(values128, values128);
|
||||
#else
|
||||
auto values = load_iq4nl_values_256();
|
||||
#endif
|
||||
int nbl = n / QK_K;
|
||||
using helper_t = union { __m256i vec; uint32_t val[8]; };
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
helper_t h, h_shift;
|
||||
#else
|
||||
using helper512_t = union { __m512i vec; uint64_t val[8]; };
|
||||
helper_t h;
|
||||
helper512_t h_shift;
|
||||
#endif
|
||||
__m256 acc[nrc_y] = {};
|
||||
__m256i isum[nrc_y] = {};
|
||||
__m256i qx[4];
|
||||
for (int ix = 0; ix < nrc_x; ix += 4) {
|
||||
auto dptr = (const float *)((const char *)vx + (ix+0)*bx);
|
||||
const block_iq4_ks_r4 * iq4 = (const block_iq4_ks_r4 *)(dptr + 4);
|
||||
auto d4 = _mm_loadu_ps(dptr);
|
||||
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
|
||||
auto scales = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales);
|
||||
h.vec = _mm256_sub_epi8(_mm256_and_si256(scales, _mm256_set1_epi8(-2)), _mm256_set1_epi8(127));
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
h_shift.vec = _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 2);
|
||||
{
|
||||
__m256 v1 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[0])))),
|
||||
_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[0])))));
|
||||
__m256 v2 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[1])))),
|
||||
_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[1])))));
|
||||
__m256 v3 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[2])))),
|
||||
_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[2])))));
|
||||
__m256 v4 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[3])))),
|
||||
_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[3])))));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto m8 = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
|
||||
acc[iy] = _mm256_fmadd_ps(v1, _mm256_shuffle_ps(m8, m8, 0x00), acc[iy]);
|
||||
acc[iy] = _mm256_fmadd_ps(v2, _mm256_shuffle_ps(m8, m8, 0x55), acc[iy]);
|
||||
acc[iy] = _mm256_fmadd_ps(v3, _mm256_shuffle_ps(m8, m8, 0xaa), acc[iy]);
|
||||
acc[iy] = _mm256_fmadd_ps(v4, _mm256_shuffle_ps(m8, m8, 0xff), acc[iy]);
|
||||
}
|
||||
}
|
||||
#else
|
||||
auto shift = _mm256_add_epi8(_mm256_set1_epi8(-64), _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 1));
|
||||
h_shift.vec = _mm512_mullo_epi16(_mm512_cvtepi8_epi16(shift), _mm512_cvtepi8_epi16(h.vec));
|
||||
#endif
|
||||
for (int ib = 0; ib < QK_K/32; ++ib) {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib]));
|
||||
auto ishifts = _mm256_cvtepi16_epi32(_mm_set1_epi64x(h_shift.val[ib]));
|
||||
auto scales_m = _mm256_cvtepi32_ps(ishifts);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
|
||||
acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
|
||||
}
|
||||
#endif
|
||||
auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
|
||||
auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
|
||||
qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
|
||||
qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
|
||||
qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
|
||||
qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi32(h.val[ib])), s_shuffle);
|
||||
auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
|
||||
auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
|
||||
auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
|
||||
auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
|
||||
#endif
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
auto sumi = _mm256_setzero_si256();
|
||||
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
|
||||
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
|
||||
#else
|
||||
auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
|
||||
auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
|
||||
auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
|
||||
auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
|
||||
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)));
|
||||
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, ibl)), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
|
||||
isum[iy] = _mm256_setzero_si256();
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
|
||||
acc[iy] = _mm256_setzero_ps();
|
||||
info.store(ix+0, iy, _mm_mul_ps(d4, sum));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq5_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%4 == 0);
|
||||
Q8<nrc_y, block_q8_K> q8(info);
|
||||
auto m4 = _mm256_set1_epi8(0xf);
|
||||
__m256i values[2];
|
||||
{
|
||||
auto val1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0);
|
||||
auto val2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1);
|
||||
values[0] = MM256_SET_M128I(val1, val1);
|
||||
values[1] = MM256_SET_M128I(val2, val2);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
values[0] = _mm256_sub_epi8(values[0], _mm256_set1_epi8(-128));
|
||||
values[1] = _mm256_sub_epi8(values[1], _mm256_set1_epi8(-128));
|
||||
#endif
|
||||
}
|
||||
int nbl = n / QK_K;
|
||||
using helper_t = union { __m256i vec; uint32_t val[8]; };
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
helper_t h, h_shift;
|
||||
auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
|
||||
#else
|
||||
using helper512_t = union { __m512i vec; uint64_t val[8]; };
|
||||
helper_t h;
|
||||
helper512_t h_shift;
|
||||
#endif
|
||||
__m256 acc[nrc_y] = {};
|
||||
__m256i isum[nrc_y] = {};
|
||||
__m256i qx[4];
|
||||
for (int ix = 0; ix < nrc_x; ix += 4) {
|
||||
auto dptr = (const float *)((const char *)vx + (ix+0)*bx);
|
||||
const block_iq5_ks_r4 * iq5 = (const block_iq5_ks_r4 *)(dptr + 4);
|
||||
auto d4 = _mm_loadu_ps(dptr);
|
||||
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
|
||||
auto scales = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales);
|
||||
h.vec = _mm256_sub_epi8(_mm256_and_si256(scales, _mm256_set1_epi8(-2)), _mm256_set1_epi8(127));
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
h_shift.vec = _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 1);
|
||||
{
|
||||
__m256 v1 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[0])))),
|
||||
_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[0])))));
|
||||
__m256 v2 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[1])))),
|
||||
_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[1])))));
|
||||
__m256 v3 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[2])))),
|
||||
_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[2])))));
|
||||
__m256 v4 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[3])))),
|
||||
_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[3])))));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto m8 = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
|
||||
acc[iy] = _mm256_fmadd_ps(v1, _mm256_shuffle_ps(m8, m8, 0x00), acc[iy]);
|
||||
acc[iy] = _mm256_fmadd_ps(v2, _mm256_shuffle_ps(m8, m8, 0x55), acc[iy]);
|
||||
acc[iy] = _mm256_fmadd_ps(v3, _mm256_shuffle_ps(m8, m8, 0xaa), acc[iy]);
|
||||
acc[iy] = _mm256_fmadd_ps(v4, _mm256_shuffle_ps(m8, m8, 0xff), acc[iy]);
|
||||
}
|
||||
}
|
||||
#else
|
||||
auto shift = _mm256_add_epi8(_mm256_set1_epi8(-64), _mm256_and_si256(scales, _mm256_set1_epi8(1)));
|
||||
h_shift.vec = _mm512_mullo_epi16(_mm512_cvtepi8_epi16(shift), _mm512_cvtepi8_epi16(h.vec));
|
||||
#endif
|
||||
for (int ib = 0; ib < QK_K/32; ++ib) {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib]));
|
||||
auto ishifts = _mm256_cvtepi16_epi32(_mm_set1_epi64x(h_shift.val[ib]));
|
||||
auto scales_m = _mm256_cvtepi32_ps(ishifts);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
|
||||
acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
|
||||
}
|
||||
#endif
|
||||
auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0);
|
||||
auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
|
||||
auto hbits = _mm_loadu_si128((const __m128i *)iq5[ibl].qh+ib);
|
||||
auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 2), hbits);
|
||||
qx[0] = _mm256_and_si256(lbits1, m4);
|
||||
qx[1] = _mm256_and_si256(lbits2, m4);
|
||||
qx[2] = _mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4);
|
||||
qx[3] = _mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4);
|
||||
|
||||
qx[0] = prepare_5bit_quants(values, qx[0], hb, _mm256_set1_epi8(0x01));
|
||||
qx[1] = prepare_5bit_quants(values, qx[1], hb, _mm256_set1_epi8(0x10));
|
||||
qx[2] = prepare_5bit_quants(values, qx[2], hb, _mm256_set1_epi8(0x02));
|
||||
qx[3] = prepare_5bit_quants(values, qx[3], hb, _mm256_set1_epi8(0x20));
|
||||
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi32(h.val[ib])), s_shuffle);
|
||||
auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
|
||||
auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
|
||||
auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
|
||||
auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
|
||||
#endif
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
auto sumi = _mm256_setzero_si256();
|
||||
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
|
||||
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
|
||||
#else
|
||||
auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
|
||||
auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
|
||||
auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
|
||||
auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
|
||||
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)));
|
||||
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, ibl)), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
|
||||
isum[iy] = _mm256_setzero_si256();
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
|
||||
acc[iy] = _mm256_setzero_ps();
|
||||
info.store(ix+0, iy, _mm_mul_ps(d4, sum));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2KS> ||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ4KS> ||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ5KS>) {
|
||||
funcs[0] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 1>;
|
||||
funcs[1] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 2>;
|
||||
funcs[2] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 3>;
|
||||
funcs[3] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 4>;
|
||||
funcs[4] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 5>;
|
||||
funcs[5] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 6>;
|
||||
funcs[6] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 7>;
|
||||
funcs[7] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 8>;
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_iqX_k_q8_K_AVX512_new, Dequantizer, funcs)
|
||||
} else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2K>) {
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_AVX512, Dequantizer, funcs);
|
||||
funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>;
|
||||
funcs[1] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 2>;
|
||||
funcs[2] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 3>;
|
||||
funcs[3] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 4>;
|
||||
funcs[4] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 5>;
|
||||
funcs[5] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 6>;
|
||||
funcs[6] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 7>;
|
||||
funcs[7] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 8>;
|
||||
} else {
|
||||
funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>;
|
||||
funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>;
|
||||
funcs[2] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 3>;
|
||||
funcs[3] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 4>;
|
||||
funcs[4] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 5>;
|
||||
funcs[5] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 6>;
|
||||
funcs[6] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 7>;
|
||||
funcs[7] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 8>;
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_iqX_k_q8_K_AVX512, Dequantizer, funcs);
|
||||
}
|
||||
#else
|
||||
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2K>||
|
||||
@@ -1347,23 +2043,9 @@ template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX
|
||||
std::is_same_v<Dequantizer, DequantizerIQ4K>||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ5K>||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ6K>) {
|
||||
funcs[0] = mul_mat_qY_K_q8_K_T<Dequantizer, 1>;
|
||||
funcs[1] = mul_mat_qY_K_q8_K_T<Dequantizer, 2>;
|
||||
funcs[2] = mul_mat_qY_K_q8_K_T<Dequantizer, 3>;
|
||||
funcs[3] = mul_mat_qY_K_q8_K_T<Dequantizer, 4>;
|
||||
funcs[4] = mul_mat_qY_K_q8_K_T<Dequantizer, 5>;
|
||||
funcs[5] = mul_mat_qY_K_q8_K_T<Dequantizer, 6>;
|
||||
funcs[6] = mul_mat_qY_K_q8_K_T<Dequantizer, 7>;
|
||||
funcs[7] = mul_mat_qY_K_q8_K_T<Dequantizer, 8>;
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qY_K_q8_K_T, Dequantizer, funcs);
|
||||
} else {
|
||||
funcs[0] = mul_mat_qX_K_q8_K_T<Dequantizer, 1>;
|
||||
funcs[1] = mul_mat_qX_K_q8_K_T<Dequantizer, 2>;
|
||||
funcs[2] = mul_mat_qX_K_q8_K_T<Dequantizer, 3>;
|
||||
funcs[3] = mul_mat_qX_K_q8_K_T<Dequantizer, 4>;
|
||||
funcs[4] = mul_mat_qX_K_q8_K_T<Dequantizer, 5>;
|
||||
funcs[5] = mul_mat_qX_K_q8_K_T<Dequantizer, 6>;
|
||||
funcs[6] = mul_mat_qX_K_q8_K_T<Dequantizer, 7>;
|
||||
funcs[7] = mul_mat_qX_K_q8_K_T<Dequantizer, 8>;
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, Dequantizer, funcs);
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -1371,9 +2053,11 @@ template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX
|
||||
|
||||
} // namespace
|
||||
|
||||
bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels) {
|
||||
bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
|
||||
|
||||
if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) {
|
||||
auto etypeA = ggml_type(typeA);
|
||||
auto expected_type_B = etypeA == GGML_TYPE_IQ4_KS_R4 || etypeA == GGML_TYPE_IQ5_KS_R4 ? GGML_TYPE_Q8_K32 : GGML_TYPE_Q8_K;
|
||||
if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1405,6 +2089,33 @@ bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_m
|
||||
case GGML_TYPE_IQ6_K:
|
||||
set_functions<DequantizerIQ6K>(kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_K_R4:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_k_r4_q8_k, kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_K_R4:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_k_r4_q8_k, kernels);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
func16 = mul_mat_iq3_k_r4_q8_k<16>;
|
||||
#endif
|
||||
break;
|
||||
case GGML_TYPE_IQ4_K_R4:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_k_r4_q8_k, kernels);
|
||||
func16 = mul_mat_iq4_k_r4_q8_k<16>;
|
||||
break;
|
||||
case GGML_TYPE_IQ4_KS_R4:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_ks_r4_q8_k, kernels);
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
// For some reason Zen4 does not like this particular function
|
||||
func16 = mul_mat_iq4_ks_r4_q8_k<16>;
|
||||
#endif
|
||||
break;
|
||||
case GGML_TYPE_IQ5_KS_R4:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq5_ks_r4_q8_k, kernels);
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
// For some reason Zen4 does not like this particular function
|
||||
func16 = mul_mat_iq5_ks_r4_q8_k<16>;
|
||||
#endif
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
#include <array>
|
||||
|
||||
bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels);
|
||||
bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1625,6 +1625,81 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
|
||||
}
|
||||
}
|
||||
|
||||
// The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__)
|
||||
template <int nrc_y>
|
||||
static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(n%32 == 0);
|
||||
GGML_ASSERT(nrc_x%8 == 0);
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
auto m1 = _mm256_set1_epi16(1);
|
||||
#endif
|
||||
int nb = n / 16;
|
||||
__m256i acc[nrc_y] = {};
|
||||
__m256i qx[4];
|
||||
float dy[nrc_y];
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
float sy[nrc_y];
|
||||
#endif
|
||||
const int8_t * q8y[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto dptr = (const float *)info.src1_row(iy);
|
||||
dy[iy] = dptr[0];
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
auto iptr = (const int32_t *)(dptr + 1);
|
||||
sy[iy] = -127*iptr[0];
|
||||
#endif
|
||||
q8y[iy] = (const int8_t *)(dptr + 2);
|
||||
}
|
||||
for (int ix = 0; ix < nrc_x; ix += 8) {
|
||||
auto dptr = (const float *)((const char *)vx + ix*bx);
|
||||
auto dx = _mm256_loadu_ps(dptr);
|
||||
auto q8x = (const int8_t *)(dptr + 8);
|
||||
for (int ib = 0; ib < nb; ++ib) { // Blocks of 16 for 8 interleaved rows
|
||||
qx[0] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+0);
|
||||
qx[1] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+1);
|
||||
qx[2] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+2);
|
||||
qx[3] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+3);
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
|
||||
auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
|
||||
auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
|
||||
auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
|
||||
#else
|
||||
qx[0] = _mm256_add_epi8(qx[0], _mm256_set1_epi8(127));
|
||||
qx[1] = _mm256_add_epi8(qx[1], _mm256_set1_epi8(127));
|
||||
qx[2] = _mm256_add_epi8(qx[2], _mm256_set1_epi8(127));
|
||||
qx[3] = _mm256_add_epi8(qx[3], _mm256_set1_epi8(127));
|
||||
#endif
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y128 = _mm_loadu_si128((const __m128i*)q8y[iy]+ib);
|
||||
auto y = MM256_SET_M128I(y128, y128);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
|
||||
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
|
||||
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
|
||||
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
|
||||
#else
|
||||
auto sumi1 = _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
|
||||
auto sumi2 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
|
||||
auto sumi3 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
|
||||
auto sumi4 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
|
||||
auto sumi12 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
|
||||
auto sumi34 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi3), _mm256_madd_epi16(m1, sumi4));
|
||||
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(sumi12, sumi34));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto scale = _mm256_mul_ps(dx, _mm256_set1_ps(dy[iy]));
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_set1_epi32(sy[iy]));
|
||||
#endif
|
||||
info.store(ix, iy, _mm256_mul_ps(scale, _mm256_cvtepi32_ps(acc[iy])));
|
||||
acc[iy] = _mm256_setzero_si256();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
|
||||
@@ -1632,7 +1707,7 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
|
||||
auto etypeA = ggml_type(typeA);
|
||||
auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32
|
||||
: etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8
|
||||
: etypeA == GGML_TYPE_Q8_KV ? GGML_TYPE_Q8_KV
|
||||
: etypeA == GGML_TYPE_Q8_KV || etypeA == GGML_TYPE_Q8_KV_R8 ? GGML_TYPE_Q8_KV
|
||||
: GGML_TYPE_Q8_K;
|
||||
|
||||
if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
|
||||
@@ -1690,6 +1765,9 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
|
||||
func16 = mul_mat_q8_KV_q8_KV<16>;
|
||||
#endif
|
||||
break;
|
||||
case GGML_TYPE_Q8_KV_R8:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_r8_q8_KV, kernels);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user