mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-20 22:49:31 +00:00
Refactor iqk: Factor out GEMM for q8_K_R8, q8_KV
This commit is contained in:
@@ -717,60 +717,930 @@ static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
|
||||
|
||||
#endif
|
||||
|
||||
// AVX2 GEMM kernel: IQ4_XS quants repacked 8 rows at a time (iq4_xs_r8)
// times Q8_K activations. Processes nrc_y activation rows (template
// parameter) against groups of 8 weight rows; one __m256 accumulator lane
// per weight row.
template <int nrc_y>
static void mul_mat_iq4_xs_r8_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%8 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = _mm256_set1_epi8(0xf);     // low-nibble mask
    auto m30 = _mm256_set1_epi8(0x30);   // mask for the 2 high scale bits (shifted into position)
    auto m32 = _mm256_set1_epi8(32);     // scale bias: stored scales are unsigned, true scale = s - 32
#ifndef HAVE_FANCY_SIMD
    // Duplicates each 16-bit scale pair so _mm256_madd_epi16 sees per-32-bit-lane scales.
    auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
    // First 16 entries of iq4k_values serve as the 4-bit -> int8 codebook here
    // (presumably identical to the iq4nl table used below — TODO confirm).
    auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
    auto values = MM256_SET_M128I(values128, values128);
#else
    auto values = load_iq4nl_values_256();
#endif
    int nbl = n / QK_K;
    // Lets us build the per-block scales with SIMD (vec) and read them back
    // 8 bytes (one 32-weight sub-block, 8 rows) at a time (val).
    using helper_t = union { __m256i vec[2]; uint64_t val[8]; };
    helper_t h;
    __m256 acc[nrc_y] = {};
    __m256i qx[4];
    for (int ix = 0; ix < nrc_x; ix += 8) {
        const block_iq4_xs_r8 * iq4 = (const block_iq4_xs_r8 *)((const char *)vx + (ix+0)*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
            // 8 fp16 row scales -> fp32
            auto d4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d));
            // Reassemble the 6-bit block scales: 4 low bits from scales_l,
            // 2 high bits from scales_h, then subtract the bias of 32.
            auto slbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l);
            auto sl1 = _mm256_and_si256(slbits, m4);
            auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4);
            auto shbits = _mm_loadu_si128((const __m128i*)iq4[ibl].scales_h);
            auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
            h.vec[0] = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(_mm256_slli_epi16(sh, 4), m30)), m32);
            h.vec[1] = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(sh, m30)), m32);
            __m256i isum[nrc_y] = {};
            for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
                // dpbusd treats qx as unsigned; the shifted codebook adds a
                // constant 128 per weight, compensated here via -128 * bsum.
                auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi64x(h.val[ib]));
                auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
                auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-128.f));
                for (int iy = 0; iy < nrc_y; ++iy) {
                    // NOTE(review): bsums reinterpreted as float — assumes the
                    // r8 repack stores per-sub-block sums as fp32.
                    float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
                    acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
                }
#else
                auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(h.val[ib])), s_shuffle);
#endif
                // First 16 columns of the 32-weight sub-block (8 rows interleaved).
                auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+0);
                auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+1);
                qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits1));
                qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits1, 4)));
                qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits2));
                qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits2, 4)));
#ifndef HAVE_FANCY_SIMD
                // maddubs needs an unsigned first operand: use |qx| and move
                // the sign onto the activations below.
                auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
                auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
                auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
                auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
#endif
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+2*ib+0);
                    auto y = MM256_SET_M128I(y128, y128);
#ifdef HAVE_FANCY_SIMD
                    auto sumi = _mm256_setzero_si256();
                    sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
#else
                    auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
                    auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
                    auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
                    auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
                    auto sumi = _mm256_add_epi32(_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)),
                                                 _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
                    isum[iy] = _mm256_add_epi32(isum[iy], sumi);
#endif
                }
                // Second 16 columns of the sub-block; same pattern as above.
                bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+2);
                bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+3);
                qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits1));
                qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits1, 4)));
                qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits2));
                qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits2, 4)));
#ifndef HAVE_FANCY_SIMD
                s1 = _mm256_sign_epi8(qx[0], qx[0]);
                s2 = _mm256_sign_epi8(qx[1], qx[1]);
                s3 = _mm256_sign_epi8(qx[2], qx[2]);
                s4 = _mm256_sign_epi8(qx[3], qx[3]);
#endif
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+2*ib+1);
                    auto y = MM256_SET_M128I(y128, y128);
#ifdef HAVE_FANCY_SIMD
                    auto sumi = _mm256_setzero_si256();
                    sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
#else
                    auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
                    auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
                    auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
                    auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
                    auto sumi = _mm256_add_epi32(_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)),
                                                 _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
                    isum[iy] = _mm256_add_epi32(isum[iy], sumi);
#endif
                }
            }
            // Fold the integer block sums into the float accumulators,
            // scaled by row scale * activation scale.
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, acc[iy]);
            acc[iy] = _mm256_setzero_ps();
        }
    }
}
|
||||
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq4_xs_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
mul_mat_iq4_xs_r8_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
|
||||
return;
|
||||
if constexpr (nrc_y == 1){
|
||||
mul_mat_iq4_xs_r8_q8_k_avx2<1>(n, vx, bx, info, nrc_x);
|
||||
} else {
|
||||
GGML_ASSERT(nrc_x%8 == 0);
|
||||
Q8<nrc_y, block_q8_K> q8(info);
|
||||
auto m4 = _mm512_set1_epi8(0xf);
|
||||
auto values = load_iq4nl_values_512();
|
||||
int nbl = n / QK_K;
|
||||
using helper_t = union { __m512i vec; uint32_t val[16]; };
|
||||
helper_t h;
|
||||
__m512 acc[nrc_y] = {};
|
||||
__m512i isum[nrc_y] = {};
|
||||
__m512i qx[4];
|
||||
for (int ix = 0; ix < nrc_x; ix += 8) {
|
||||
const block_iq4_xs_r8 * iq4l = (const block_iq4_xs_r8 *)((const char *)vx + (ix+0)*bx);
|
||||
const block_iq4_xs_r8 * iq4h = (const block_iq4_xs_r8 *)((const char *)vx + (ix+4)*bx);
|
||||
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
|
||||
auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[ibl].d));
|
||||
auto dh = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[ibl].d));
|
||||
auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1);
|
||||
auto d4x64 = _mm512_mul_ps(d4, _mm512_set1_ps(-64.f));
|
||||
auto slbits_l = _mm_loadu_si128((const __m128i *)iq4l[ibl].scales_l);
|
||||
auto shbits_l = _mm_loadu_si128((const __m128i *)iq4h[ibl].scales_l);
|
||||
auto sl_l = MM256_SET_M128I(_mm_srli_epi16(slbits_l, 4), slbits_l);
|
||||
auto sh_l = MM256_SET_M128I(_mm_srli_epi16(shbits_l, 4), shbits_l);
|
||||
auto slb = _mm512_and_si512(_mm512_inserti32x8(_mm512_castsi256_si512(sl_l), sh_l, 1), m4);
|
||||
auto aux64 = (const uint64_t *)iq4l[ibl].scales_h;
|
||||
auto slbits_h = _mm_set_epi64x(aux64[0] >> 2, aux64[0]);
|
||||
aux64 = (const uint64_t *)iq4h[ibl].scales_h;
|
||||
auto shbits_h = _mm_set_epi64x(aux64[0] >> 2, aux64[0]);
|
||||
auto sl_h = MM256_SET_M128I(slbits_h, _mm_slli_epi16(slbits_h, 4));
|
||||
auto sh_h = MM256_SET_M128I(shbits_h, _mm_slli_epi16(shbits_h, 4));
|
||||
auto shb = _mm512_and_si512(_mm512_inserti32x8(_mm512_castsi256_si512(sl_h), sh_h, 1), _mm512_set1_epi8(0x30));
|
||||
h.vec = _mm512_sub_epi8(_mm512_or_si512(slb, shb), _mm512_set1_epi8(32));
|
||||
for (int ib = 0; ib < QK_K/32; ++ib) {
|
||||
auto iscales = _mm512_cvtepi8_epi32(_mm_blend_epi32(_mm_set1_epi32(h.val[ib+0]), _mm_set1_epi32(h.val[ib+8]), 0x0c));
|
||||
auto scales = _mm512_cvtepi32_ps(iscales);
|
||||
auto scales_m = _mm512_mul_ps(scales, d4x64);
|
||||
auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+0)),
|
||||
_mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+0), 1);
|
||||
auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+1)),
|
||||
_mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+1), 1);
|
||||
qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4));
|
||||
qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4));
|
||||
qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4));
|
||||
qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
|
||||
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
|
||||
auto sumi = _mm512_setzero_si512();
|
||||
sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
|
||||
sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
|
||||
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
|
||||
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
|
||||
isum[iy] = _mm512_add_epi32(isum[iy], _mm512_mullo_epi32(iscales, sumi));
|
||||
float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
|
||||
acc[iy] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[iy]);
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
acc[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl))), _mm512_cvtepi32_ps(isum[iy]), acc[iy]);
|
||||
isum[iy] = _mm512_setzero_si512();
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 0), _mm512_extractf32x4_ps(acc[iy], 1));
|
||||
auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 2), _mm512_extractf32x4_ps(acc[iy], 3));
|
||||
info.store(ix+0, iy, sum1);
|
||||
info.store(ix+4, iy, sum2);
|
||||
acc[iy] = _mm512_setzero_ps();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq4_xs_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
mul_mat_iq4_xs_r8_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
|
||||
}
|
||||
#endif
|
||||
|
||||
// AVX2/AVX-512 GEMM kernel: Q2_K quants repacked 4 rows at a time (q2_k_r4)
// times Q8_K activations. Q2_K stores per-sub-block scale AND min packed in
// one byte (min in the high nibble, scale in the low nibble); the min
// contribution is folded in first via the activation block sums (bsums).
template <int nrc_y>
static void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto mxf = _mm256_set1_epi8(0xf);   // nibble mask
    auto m03 = _mm256_set1_epi8(0x03);  // 2-bit quant mask
    // Interleave pattern used to rearrange the 16-bit widened mins/scales
    // into the block order expected by the madd below.
    static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
    auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
#ifdef HAVE_FANCY_SIMD
    __m256i isum[nrc_y] = {};
#else
    auto m1 = _mm256_set1_epi16(1);
#endif
    int nbl = n / QK_K;
    __m256 acc[nrc_y] = {};
    __m256i qx[4];
    int8_t scales[64];
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_q2_k_r4 * iq2 = (const block_q2_k_r4 *)((const char *)vx + (ix+0)*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
            // d holds 4 row scales followed by 4 row mins (fp16).
            auto dm = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq2[ibl].d));
            auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dm), _mm256_castps256_ps128(dm));
            auto m4 = _mm256_set_m128(_mm256_extractf128_ps(dm, 1), _mm256_extractf128_ps(dm, 1));
            m4 = _mm256_mul_ps(m4, _mm256_set1_ps(-1.f));  // mins are subtracted
            auto all_scales1 = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales+0);
            auto all_scales2 = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales+1);
            // High nibbles = per-sub-block mins.
            auto scales1 = _mm256_and_si256(_mm256_srli_epi16(all_scales1, 4), mxf);
            auto scales2 = _mm256_and_si256(_mm256_srli_epi16(all_scales2, 4), mxf);
            {
                // Apply the min contribution: sum(min_b * bsum_b) per row.
                auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
                auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
                auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
                auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
                auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
                auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
                auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
                auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto bsums = q8.load_bsums(iy, ibl);
                    auto sumi = _mm256_setzero_si256();
#ifdef HAVE_FANCY_SIMD
                    sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
                    sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
                    sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
                    sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
                    auto d8 = _mm256_set1_ps(q8.scale(iy, ibl));
                    acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(m4, d8), _mm256_cvtepi32_ps(sumi), acc[iy]);
#else
                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
                    auto d8 = _mm256_set1_ps(q8.scale(iy, ibl));
                    acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(m4, d8), _mm256_cvtepi32_ps(sumi), acc[iy]);
                    if constexpr (nrc_y == 1) {
                        // Single activation row: pre-multiply d4 by the
                        // activation scale so the inner loop can skip it.
                        d4 = _mm256_mul_ps(d4, d8);
                    }
#endif
                }
            }
            // Low nibbles = per-sub-block scales; spill to memory for
            // 8-bytes-at-a-time access in the quant loop.
            all_scales1 = _mm256_and_si256(all_scales1, mxf);
            all_scales2 = _mm256_and_si256(all_scales2, mxf);
            _mm256_storeu_si256((__m256i *)scales+0, all_scales1);
            _mm256_storeu_si256((__m256i *)scales+1, all_scales2);
            for (int ib = 0; ib < QK_K/32; ++ib) {
                auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 8*ib)));
#ifndef HAVE_FANCY_SIMD
                auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
#endif
                // Unpack 2-bit quants: four values per byte.
                auto lb = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs+ib);
                qx[0] = _mm256_and_si256(lb, m03);
                qx[1] = _mm256_and_si256(_mm256_srli_epi16(lb, 2), m03);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(lb, 4), m03);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(lb, 6), m03);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
                    auto sumi = _mm256_setzero_si256();
                    sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
#else
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    // Quants are in 0...3, so we can add up all of them as int16_t without overflowing
                    auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
                    if constexpr (nrc_y == 1) {
                        acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
                    } else {
                        acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
                    }
#endif
                }
            }
#ifdef HAVE_FANCY_SIMD
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
                acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
                isum[iy] = _mm256_setzero_si256();
            }
#endif
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            // The two 128-bit halves hold duplicated 4-row results; fold them.
            auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
            acc[iy] = _mm256_setzero_ps();
            info.store(ix+0, iy, sum);
        }
    }
}
|
||||
|
||||
// AVX2/AVX-512 GEMM kernel: Q3_K quants repacked 4 rows at a time (q3_k_r4)
// times Q8_K activations. Quants are reassembled into 0..7 (2 low bits from
// qs plus one high bit from qh) and the -4 offset is applied analytically
// through the activation block sums rather than per quant.
template <int nrc_y>
static void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = _mm256_set1_epi8(0xf);    // nibble mask
    auto m30 = _mm256_set1_epi8(0x30);  // 2 high scale bits
    auto m32 = _mm256_set1_epi8(32);    // 6-bit scale bias
    auto m03 = _mm256_set1_epi8(0x03);  // 2 low quant bits
    auto m04 = _mm256_set1_epi8(0x04);  // high quant bit, in bit position 2
    // Interleave pattern rearranging widened scales into block order.
    static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
    auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
#ifdef HAVE_FANCY_SIMD
    // No initializer needed: isum[iy] is overwritten (not accumulated) in the
    // bsums pass before the quant loop adds to it.
    __m256i isum[nrc_y];
#else
    auto m1 = _mm256_set1_epi16(1);
#endif
    int nbl = n / QK_K;
    __m256 acc[nrc_y] = {};
    __m256i qx[4];
    int8_t scales[64];
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_q3_k_r4 * iq3 = (const block_q3_k_r4 *)((const char *)vx + (ix+0)*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
            auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d));
            auto d4 = _mm256_set_m128(dl, dl);
#ifndef HAVE_FANCY_SIMD
            if constexpr (nrc_y == 1) {
                // Single activation row: fold its scale into d4 once.
                d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
            }
#endif
            // Rebuild signed 6-bit sub-block scales (4 low + 2 high bits, -32 bias).
            auto slb = _mm256_loadu_si256((const __m256i *)iq3[ibl].scales_l);
            auto shbits = _mm_loadu_si128((const __m128i *)iq3[ibl].scales_h);
            auto shb = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
            auto scales1 = _mm256_sub_epi8(_mm256_or_si256(_mm256_and_si256(slb, m4), _mm256_and_si256(_mm256_slli_epi16(shb, 4), m30)), m32);
            auto scales2 = _mm256_sub_epi8(_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(slb, 4), m4), _mm256_and_si256(shb, m30)), m32);
            _mm256_storeu_si256((__m256i *)scales+0, scales1);
            _mm256_storeu_si256((__m256i *)scales+1, scales2);
            {
                // Offset correction: quants are used as 0..7 but represent
                // q-4, so subtract 4 * scale_b * bsum_b per sub-block.
#ifndef HAVE_FANCY_SIMD
                auto min = _mm256_mul_ps(d4, _mm256_set1_ps(-4.f));
#endif
                auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
                auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
                auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
                auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
                auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
                auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
                auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
                auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
#ifdef HAVE_FANCY_SIMD
                // Bake the -4 offset into the integer scales so the result can
                // seed the integer accumulator directly.
                s1 = _mm256_mullo_epi16(s1, _mm256_set1_epi16(-4));
                s2 = _mm256_mullo_epi16(s2, _mm256_set1_epi16(-4));
                s3 = _mm256_mullo_epi16(s3, _mm256_set1_epi16(-4));
                s4 = _mm256_mullo_epi16(s4, _mm256_set1_epi16(-4));
#endif
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto bsums = q8.load_bsums(iy, ibl);
                    auto sumi = _mm256_setzero_si256();
#ifdef HAVE_FANCY_SIMD
                    sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
                    sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
                    sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
                    sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
                    isum[iy] = sumi;  // first write — seeds the accumulator
#else
                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
                    if constexpr (nrc_y == 1) {
                        acc[iy] = _mm256_fmadd_ps(min, _mm256_cvtepi32_ps(sumi), acc[iy]);
                    } else {
                        acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(min, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
                    }
#endif
                }
            }
            for (int ib = 0; ib < QK_K/32; ++ib) {
                auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 8*ib)));
#ifndef HAVE_FANCY_SIMD
                auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
#endif
                // Assemble 3-bit quants: 2 bits from qs, 1 bit from qh.
                auto lb = _mm256_loadu_si256((const __m256i *)iq3[ibl].qs+ib);
                auto hbits = _mm_loadu_si128((const __m128i *)iq3[ibl].qh+ib);
                auto hb = MM256_SET_M128I(hbits, _mm_slli_epi16(hbits, 4));
                qx[0] = _mm256_or_si256(_mm256_and_si256(lb, m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 2)));
                qx[1] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 2), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 3)));
                qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 4), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 4)));
                qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 6), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 5)));
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
                    auto sumi = _mm256_setzero_si256();
                    sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
#else
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    // Quants are in 0...7, so we can add up all of them as int16_t without overflowing
                    auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
                    if constexpr (nrc_y == 1) {
                        acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
                    } else {
                        acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
                    }
#endif

                }
            }
#ifdef HAVE_FANCY_SIMD
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
                acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
            }
#endif
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            // Fold the duplicated 128-bit halves into the 4-row result.
            auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
            acc[iy] = _mm256_setzero_ps();
            info.store(ix+0, iy, sum);
        }
    }
}
|
||||
|
||||
// Shared helper for the *_r4 kernels with 32-weight sub-blocks: accumulates
// the per-sub-block "min" contribution, sum_b(min_b * bsum_b), into acc.
//   ibl  - index of the current 256-weight super-block
//   m4   - per-row min scale, already negated by the caller (mins subtract)
//   mins - 32 packed int8 mins (16 sub-blocks x 4 interleaved rows)
//   q8   - activation rows; bsums are read as fp32 (presumably the r4 repack
//          stores per-sub-block sums as float — TODO confirm)
//   acc  - nrc_y float accumulators, updated in place
template <int nrc_y>
inline void process_min_r4_b32(int ibl, __m256 m4, __m256i mins, const Q8<nrc_y, block_q8_K>& q8, __m256 * acc) {
    // Regroup the interleaved int8 mins into four __m256i of 8 int32 each,
    // one per pair of sub-blocks, matching the bsums broadcast order below.
    auto mins_l = _mm256_castsi256_si128(mins);
    auto mins_h = _mm256_extracti128_si256(mins, 1);
    auto aux1 = _mm_unpacklo_epi32(mins_l, mins_h);
    auto aux2 = _mm_unpackhi_epi32(mins_l, mins_h);
    auto ic1 = _mm256_cvtepi8_epi32(aux1);
    auto ic2 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux1, 0xee));
    auto ic3 = _mm256_cvtepi8_epi32(aux2);
    auto ic4 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux2, 0xee));
    if constexpr (nrc_y == 1) {
        // Single activation row: defer the m4 multiply to one final fmadd.
        auto bs = _mm256_loadu_ps((const float *)q8.y[0][ibl].bsums);
        auto sumf = _mm256_mul_ps(_mm256_cvtepi32_ps(ic1), _mm256_shuffle_ps(bs, bs, 0x00));
        sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic2), _mm256_shuffle_ps(bs, bs, 0x55), sumf);
        sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic3), _mm256_shuffle_ps(bs, bs, 0xaa), sumf);
        sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic4), _mm256_shuffle_ps(bs, bs, 0xff), sumf);
        acc[0] = _mm256_fmadd_ps(m4, sumf, acc[0]);
    } else {
        // Multiple rows: pre-scale the min columns once, reuse for every row.
        auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic1));
        auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic2));
        auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic3));
        auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic4));
        for (int iy = 0; iy < nrc_y; ++iy) {
            auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
            acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
            acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
            acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
            acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
        }
    }
}
|
||||
|
||||
// AVX2/AVX-512 GEMM kernel: Q4_K quants repacked 4 rows at a time (q4_k_r4)
// times Q8_K activations. Per-sub-block 6-bit scales and mins are split
// across scales_l (low nibbles) and scales_h (2 high bits); the min
// contribution is handled by process_min_r4_b32.
template <int nrc_y>
static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto mf = _mm256_set1_epi8(0xf);   // nibble mask
    auto m3 = _mm256_set1_epi8(0x30);  // 2 high scale bits (despite the name, the mask is 0x30)
    int nbl = n / QK_K;
    // Build scales with SIMD (vec), read back 4 bytes per sub-block (val).
    union { __m256i vec; uint32_t val[8]; } hd;
    __m256 acc[nrc_y] = {};
    __m256i isum[nrc_y] = {};
    __m256i qx[4];
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_q4_k_r4 * iq4 = (const block_q4_k_r4 *)((const char *)vx + (ix+0)*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
            // d holds 4 row scales followed by 4 row mins (fp16).
            auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d));
            auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
            // Mins are subtracted, hence the negation.
            auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
            auto lbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l);
            auto hbits128 = _mm_loadu_si128((const __m128i *)iq4[ibl].scales_h);
            auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
            // Scales from the low nibbles + high bits; mins from the high nibbles.
            hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m3));
            auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m3));
            process_min_r4_b32(ibl, m4, mins, q8, acc);
            for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
                auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]));
#else
                // Widen scales to int16, duplicating each so madd pairs align.
                auto aux = _mm_set1_epi32(hd.val[ib]);
                aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux));
                auto scales_d = MM256_SET_M128I(aux, aux);
#endif
                // Unpack 4-bit quants for the 32-weight sub-block (4 rows).
                auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
                auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
                qx[0] = _mm256_and_si256(bits1, mf);
                qx[1] = _mm256_and_si256(bits2, mf);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits1, 4), mf);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits2, 4), mf);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
                    auto sumi = _mm256_setzero_si256();
                    sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
#else
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales_d, _mm256_add_epi16(sumi1, sumi2)));
#endif
                }
            }
            // Fold integer block sums into floats: row scale * activation scale.
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
                isum[iy] = _mm256_setzero_si256();
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            // Fold the duplicated 128-bit halves into the 4-row result.
            auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
            acc[iy] = _mm256_setzero_ps();
            info.store(ix+0, iy, sum);
        }
    }
}
|
||||
|
||||
// GEMM for 4 interleaved rows of Q5_K ("q5_k_r4") against nrc_y rows of q8_K activations.
// n     : row length in elements (multiple of QK_K = 256)
// vx/bx : base pointer and byte stride of the quantized left-hand rows
// info  : supplies the q8_K right-hand rows and the destination store
// nrc_x : number of left-hand rows (must be a multiple of 4)
template <int nrc_y>
static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto mf  = _mm256_set1_epi8(0xf);   // low-nibble mask for quants and scales
    auto m10 = _mm256_set1_epi8(0x10);  // mask selecting the 5th (high) quant bit
    auto m30 = _mm256_set1_epi8(0x30);  // mask for the top 2 bits of the 6-bit scales/mins
    int nbl = n / QK_K;                 // number of 256-wide super-blocks per row
    // hd.vec holds the unpacked 6-bit block scales; hd.val[ib] picks out the 4 scales
    // (one per interleaved row) belonging to 32-wide sub-block ib.
    union { __m256i vec; uint32_t val[8]; } hd;
    __m256 acc[nrc_y] = {};
    __m256i isum[nrc_y] = {};
    __m256i qx[4];
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_q5_k_r4 * iq5 = (const block_q5_k_r4 *)((const char *)vx + (ix+0)*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
            // d stores 4 fp16 super-block scales followed by 4 fp16 super-block mins
            // (one of each per interleaved row) — presumably; layout defined elsewhere.
            auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5[ibl].d));
            auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
            // Negated mins: Q5_K subtracts min * (q8 block sum) from the dot product.
            auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
            // Unpack the 6-bit block scales/mins: low 4 bits in scales_l, top 2 bits in scales_h.
            auto lbits = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales_l);
            auto hbits128 = _mm_loadu_si128((const __m128i *)iq5[ibl].scales_h);
            auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
            hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m30));
            auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m30));
            // Fold min * bsum contributions for this super-block into acc.
            process_min_r4_b32(ibl, m4, mins, q8, acc);
            for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
                auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]));
#else
                // Widen the 4 block scales to duplicated 16-bit lanes for _mm256_madd_epi16.
                auto aux = _mm_set1_epi32(hd.val[ib]);
                aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux));
                auto scales_d = MM256_SET_M128I(aux, aux);
#endif
                auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0);
                auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
                auto hbits128 = _mm_loadu_si128((const __m128i*)iq5[ibl].qh + ib);
                auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
                // Reassemble the 5-bit quants: 4 low bits from qs nibbles, 5th bit from qh.
                qx[0] = _mm256_or_si256(_mm256_and_si256(lbits1, mf), _mm256_and_si256(m10, hbits));
                qx[1] = _mm256_or_si256(_mm256_and_si256(lbits2, mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 2)));
                qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 1)));
                qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 3)));
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
                    auto sumi = _mm256_setzero_si256();
                    // Broadcast each 32-bit group of y and multiply-accumulate against the
                    // corresponding interleaved quant lanes (VNNI: u8 x s8 -> s32).
                    sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
#else
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    // To avoid overflow, we can only add up to 4 q5 x q8 products.
                    auto sumi = _mm256_add_epi32(_mm256_madd_epi16(scales_d, sumi1), _mm256_madd_epi16(scales_d, sumi2));
                    isum[iy] = _mm256_add_epi32(isum[iy], sumi);
#endif
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                // Scale the integer dot products by d4 * (q8 super-block scale) and accumulate.
                acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
                isum[iy] = _mm256_setzero_si256();
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            // Both 128-bit halves hold contributions for the same 4 rows; add and store.
            auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
            acc[iy] = _mm256_setzero_ps();
            info.store(ix+0, iy, sum);
        }
    }
}
|
||||
|
||||
// GEMM for 4 interleaved rows of Q6_K ("q6_k_r4") against nrc_y rows of q8_K activations.
// Q6_K quants are stored biased by +32; the bias is removed via the block sums
// (bsums) rather than per-quant subtraction, which keeps the inner loop unsigned.
// n     : row length in elements (multiple of QK_K = 256)
// vx/bx : base pointer and byte stride of the quantized left-hand rows
// nrc_x : number of left-hand rows (must be a multiple of 4)
template <int nrc_y>
static void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = _mm256_set1_epi8(0xf);    // low-nibble mask (bits 0..3 of each quant)
    auto m3 = _mm256_set1_epi8(0x30);   // mask for the top 2 bits (taken from qh)
    // Byte shuffle that interleaves the sign-extended 16-bit scales into the
    // per-row order expected below (same pattern in both 128-bit lanes).
    static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
    auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
#ifdef HAVE_FANCY_SIMD
    // isum[iy] is (re)initialized in the bsums block below before first use.
    __m256i isum[nrc_y];
#else
    auto m1 = _mm256_set1_epi16(1);
#endif
    int nbl = n / QK_K;                 // number of 256-wide super-blocks per row
    __m256 acc[nrc_y] = {};
    __m256i qx[4];
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_q6_k_r4 * iq6 = (const block_q6_k_r4 *)((const char *)vx + (ix+0)*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
            // 4 fp16 super-block scales, one per interleaved row, duplicated to both halves.
            auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6[ibl].d));
            auto d4 = _mm256_set_m128(dl, dl);
#ifndef HAVE_FANCY_SIMD
            if constexpr (nrc_y == 1) {
                // Single-row case: fold the q8 scale into d4 once instead of per block.
                d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
            }
#endif
            {
#ifndef HAVE_FANCY_SIMD
                // -32 bias correction factor (quants are stored as q+32).
                auto min = _mm256_mul_ps(d4, _mm256_set1_ps(-32.f));
#endif
                auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+0)), shuff); // blocks 0, 1, 2, 3 for each row
                auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+1)), shuff); // blocks 4, 5, 6, 7 for each row
                auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+2)), shuff); // blocks 8, 9, 10, 11 for each row
                auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+3)), shuff); // blocks 12, 13, 14, 15 for each row
                auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
                auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
                auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
                auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
#ifdef HAVE_FANCY_SIMD
                // Pre-multiply scales by -32 so the bsums dot product directly yields
                // the integer bias correction, merged into isum below.
                s1 = _mm256_mullo_epi16(s1, _mm256_set1_epi16(-32));
                s2 = _mm256_mullo_epi16(s2, _mm256_set1_epi16(-32));
                s3 = _mm256_mullo_epi16(s3, _mm256_set1_epi16(-32));
                s4 = _mm256_mullo_epi16(s4, _mm256_set1_epi16(-32));
#endif
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto bsums = q8.load_bsums(iy, ibl);
                    auto sumi = _mm256_setzero_si256();
#ifdef HAVE_FANCY_SIMD
                    sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
                    sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
                    sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
                    sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
                    isum[iy] = sumi;
#else
                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
                    if constexpr (nrc_y == 1) {
                        acc[iy] = _mm256_fmadd_ps(min, _mm256_cvtepi32_ps(sumi), acc[iy]);
                    } else {
                        acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(min, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
                    }
#endif
                }
            }
            const uint32_t * scales = (const uint32_t *)iq6[ibl].scales;
            for (int ib = 0; ib < QK_K/32; ++ib) {
                // 4 signed 8-bit block scales (one per interleaved row) for sub-block ib.
                auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 2*ib)));
#ifndef HAVE_FANCY_SIMD
                // NOTE: this local intentionally shadows the `scales` pointer above.
                auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
#endif
                auto lbits1 = _mm256_loadu_si256((const __m256i *)iq6[ibl].ql+2*ib+0);
                auto lbits2 = _mm256_loadu_si256((const __m256i *)iq6[ibl].ql+2*ib+1);
                auto hbits = _mm256_loadu_si256((const __m256i *)iq6[ibl].qh+ib);
                // Reassemble 6-bit quants: low 4 bits from ql nibbles, top 2 bits from qh.
                qx[0] = _mm256_or_si256(_mm256_and_si256(lbits1, m4), _mm256_and_si256(m3, _mm256_slli_epi16(hbits, 4)));
                qx[1] = _mm256_or_si256(_mm256_and_si256(lbits2, m4), _mm256_and_si256(m3, hbits));
                qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4), _mm256_and_si256(m3, _mm256_slli_epi16(hbits, 2)));
                qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4), _mm256_and_si256(m3, _mm256_srli_epi16(hbits, 2)));
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
                    auto sumi = _mm256_setzero_si256();
                    sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
                    sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
#else
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    // Quants are in 0...63, so we can add at most 4 as int16_t to be sure of no int16_t overflow
                    auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
                    if constexpr (nrc_y == 1) {
                        acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
                    } else {
                        acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
                    }
#endif
                }
            }
#ifdef HAVE_FANCY_SIMD
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
                acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
            }
#endif
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            // Both 128-bit halves hold contributions for the same 4 rows; add and store.
            auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
            acc[iy] = _mm256_setzero_ps();
            info.store(ix+0, iy, sum);
        }
    }
}
|
||||
|
||||
template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4XS>) {
|
||||
funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>;
|
||||
funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>;
|
||||
funcs[2] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 3>;
|
||||
funcs[3] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 4>;
|
||||
funcs[4] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 5>;
|
||||
funcs[5] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 6>;
|
||||
funcs[6] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 7>;
|
||||
funcs[7] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 8>;
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_iqX_k_q8_K_AVX512, Dequantizer, funcs)
|
||||
} else {
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_AVX512, Dequantizer, funcs)
|
||||
funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>;
|
||||
funcs[1] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 2>;
|
||||
funcs[2] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 3>;
|
||||
funcs[3] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 4>;
|
||||
funcs[4] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 5>;
|
||||
funcs[5] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 6>;
|
||||
funcs[6] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 7>;
|
||||
funcs[7] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 8>;
|
||||
}
|
||||
#else
|
||||
if constexpr (std::is_same_v<Dequantizer, DequantizerQ2K> ||
|
||||
std::is_same_v<Dequantizer, DequantizerQ3K> ||
|
||||
std::is_same_v<Dequantizer, DequantizerQ6K>) {
|
||||
funcs[0] = mul_mat_qY_K_q8_K_T<Dequantizer, 1>;
|
||||
funcs[1] = mul_mat_qY_K_q8_K_T<Dequantizer, 2>;
|
||||
funcs[2] = mul_mat_qY_K_q8_K_T<Dequantizer, 3>;
|
||||
funcs[3] = mul_mat_qY_K_q8_K_T<Dequantizer, 4>;
|
||||
funcs[4] = mul_mat_qY_K_q8_K_T<Dequantizer, 5>;
|
||||
funcs[5] = mul_mat_qY_K_q8_K_T<Dequantizer, 6>;
|
||||
funcs[6] = mul_mat_qY_K_q8_K_T<Dequantizer, 7>;
|
||||
funcs[7] = mul_mat_qY_K_q8_K_T<Dequantizer, 8>;
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qY_K_q8_K_T, Dequantizer, funcs)
|
||||
} else {
|
||||
funcs[0] = mul_mat_qX_K_q8_K_T<Dequantizer, 1>;
|
||||
funcs[1] = mul_mat_qX_K_q8_K_T<Dequantizer, 2>;
|
||||
funcs[2] = mul_mat_qX_K_q8_K_T<Dequantizer, 3>;
|
||||
funcs[3] = mul_mat_qX_K_q8_K_T<Dequantizer, 4>;
|
||||
funcs[4] = mul_mat_qX_K_q8_K_T<Dequantizer, 5>;
|
||||
funcs[5] = mul_mat_qX_K_q8_K_T<Dequantizer, 6>;
|
||||
funcs[6] = mul_mat_qX_K_q8_K_T<Dequantizer, 7>;
|
||||
funcs[7] = mul_mat_qX_K_q8_K_T<Dequantizer, 8>;
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, Dequantizer, funcs)
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// NOTE: HAVE_FANCY_SIMD should only be defined when both __AVX512VNNI__ and __AVX512VL__ are defined.
|
||||
// GEMM for 8 interleaved rows of q8_k_r8 against nrc_y rows of q8_K activations.
// On the VNNI path the signed x-quants are biased by +127 so _mm256_dpbusd_epi32
// (unsigned x signed) can be used; the bias is removed afterwards via the
// activation row sums. Without VNNI the usual sign/abs trick feeds
// _mm256_maddubs_epi16 instead.
// nrc_x must be a multiple of 8 (8 rows are interleaved per block_q8_k_r8).
template <int nrc_y>
static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%8 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
#ifndef HAVE_FANCY_SIMD
    auto m1 = _mm256_set1_epi16(1);  // for pairwise widening via _mm256_madd_epi16
#endif
    int nbl = n / QK_K;              // number of 256-wide super-blocks per row
    __m256 acc[nrc_y] = {};
    __m256i isum[nrc_y] = {};
    __m256i qx[4];
    for (int ix = 0; ix < nrc_x; ix += 8) {
        const block_q8_k_r8 * iq8 = (const block_q8_k_r8 *)((const char *)vx + (ix+0)*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
            // 8 fp16 scales, one per interleaved row.
            auto d4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ibl].d));
            for (int ib = 0; ib < QK_K/16; ++ib) {
                qx[0] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+0);
                qx[1] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+1);
                qx[2] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+2);
                qx[3] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+3);
#ifndef HAVE_FANCY_SIMD
                // abs(x); the sign is transferred onto y below so maddubs sees u8 x s8.
                auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
                auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
                auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
                auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
#else
                // Bias signed quants into unsigned range for dpbusd; corrected via bsums.
                qx[0] = _mm256_add_epi8(qx[0], _mm256_set1_epi8(127));
                qx[1] = _mm256_add_epi8(qx[1], _mm256_set1_epi8(127));
                qx[2] = _mm256_add_epi8(qx[2], _mm256_set1_epi8(127));
                qx[3] = _mm256_add_epi8(qx[3], _mm256_set1_epi8(127));
#endif
                for (int iy = 0; iy < nrc_y; ++iy) {
                    // 16 activation quants, duplicated into both 128-bit halves.
                    auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+ib);
                    auto y = MM256_SET_M128I(y128, y128);
#ifdef HAVE_FANCY_SIMD
                    isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
                    isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
                    isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
                    isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
#else
                    auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])));
                    auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])));
                    auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])));
                    auto sumi4 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(sumi1, sumi2));
                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(sumi3, sumi4));
#endif
                }
            }
#ifdef HAVE_FANCY_SIMD
            // Correction factor for the +127 bias added to the x quants above.
            auto m4 = _mm256_mul_ps(d4, _mm256_set1_ps(-128.f));
#endif
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
                acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
#ifdef HAVE_FANCY_SIMD
                // bsums reinterpreted as float: for this y-type the block-sum slot
                // apparently stores a float total — TODO confirm against Q8_KR8 layout.
                auto bsums = (const float *)q8.y[iy][ibl].bsums;
                acc[iy] = _mm256_fmadd_ps(m4, _mm256_set1_ps(bsums[0]), acc[iy]);
#endif
                isum[iy] = _mm256_setzero_si256();
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            // 8 results per iy (one per interleaved row).
            info.store(ix, iy, acc[iy]);
            acc[iy] = _mm256_setzero_ps();
        }
    }
}
|
||||
|
||||
// GEMM for q8_KV x q8_KV: both sides are plain rows of int8 quants with a single
// float scale per row stored in a small header in front of the quants.
// Four x-rows are processed at a time; their quants are transposed on the fly
// (unpacklo/unpackhi) so each __m256i holds 8 consecutive values of each row.
// On the VNNI path x is biased by +127 and corrected with sy = -127 * sum(y).
// nrc_x must be a multiple of 4; n a multiple of 32.
template <int nrc_y>
static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    GGML_ASSERT(n%32 == 0);
    __m256i qx[4];
#ifndef HAVE_FANCY_SIMD
    __m256i sx[4];                   // abs(x) for the maddubs sign trick
    auto m1 = _mm256_set1_epi16(1);
#endif
    __m256i acc[nrc_y] = {};         // integer accumulators, one per y row
    float dy[nrc_y];                 // per-row activation scales
#ifdef HAVE_FANCY_SIMD
    int32_t sy[nrc_y];               // bias correction: -127 * sum of y quants
#endif
    const int8_t * q8y[nrc_y];
    for (int iy = 0; iy < nrc_y; ++iy) {
        // Row header: dptr[0] = float scale, dptr[1] = int32 sum of the row's
        // quants (used for the +127 bias correction); quants follow at dptr + 2.
        auto dptr = (const float *)info.src1_row(iy);
        dy[iy] = dptr[0];
#ifdef HAVE_FANCY_SIMD
        auto iptr = (const int32_t *)(dptr + 1);
        sy[iy] = -127*iptr[0];
#endif
        q8y[iy] = (const int8_t *)(dptr + 2);
    }
    const int8_t * q8x[4];
    float dx[4];
    for (int ix = 0; ix < nrc_x; ix += 4) {
        for (int kx = 0; kx < 4; ++kx) {
            auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
            dx[kx] = dptr[0];
            q8x[kx] = (const int8_t *)(dptr + 2);
        }
        for (int i = 0; i < n/32; ++i) {
            for (int kx = 0; kx < 4; ++kx) qx[kx] = _mm256_loadu_si256((const __m256i *)q8x[kx] + i);
            // 4x4 transpose of 32-bit groups so each vector interleaves the 4 rows.
            auto t0 = _mm256_unpacklo_epi32(qx[0], qx[1]);
            auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]);
            auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]);
            auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]);
#ifdef HAVE_FANCY_SIMD
            // Bias into unsigned range for dpbusd; undone via sy at the end.
            qx[0] = _mm256_add_epi8(_mm256_unpacklo_epi64(t0, t1), _mm256_set1_epi8(127));
            qx[1] = _mm256_add_epi8(_mm256_unpackhi_epi64(t0, t1), _mm256_set1_epi8(127));
            qx[2] = _mm256_add_epi8(_mm256_unpacklo_epi64(t2, t3), _mm256_set1_epi8(127));
            qx[3] = _mm256_add_epi8(_mm256_unpackhi_epi64(t2, t3), _mm256_set1_epi8(127));
#else
            qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
            qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
            qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
            qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
#endif
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
#ifdef HAVE_FANCY_SIMD
                acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
                acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
                acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
                acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
#else
                // maddubs needs u8 x s8: use |x| and move x's sign onto y.
                auto dot1 = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
                auto dot2 = _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
                auto dot3 = _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
                auto dot4 = _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
                auto dot12 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot1), _mm256_madd_epi16(m1, dot2));
                auto dot34 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot3), _mm256_madd_epi16(m1, dot4));
                acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(dot12, dot34));
#endif
            }
        }
        auto scales_x = _mm_loadu_ps(dx);
        for (int iy = 0; iy < nrc_y; ++iy) {
            // Reduce the two 128-bit halves: 4 integer dot products, one per x row.
            auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1));
#ifdef HAVE_FANCY_SIMD
            // Remove the +127 bias contribution.
            sumi = _mm_add_epi32(sumi, _mm_set1_epi32(sy[iy]));
#endif
            auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy]));
            info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi)));
            acc[iy] = _mm256_setzero_si256();
        }
    }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels) {
|
||||
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
|
||||
|
||||
if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) {
|
||||
auto etypeA = ggml_type(typeA);
|
||||
auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32
|
||||
: etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8
|
||||
: etypeA == GGML_TYPE_Q8_KV ? GGML_TYPE_Q8_KV
|
||||
: GGML_TYPE_Q8_K;
|
||||
|
||||
if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
|
||||
return false;
|
||||
}
|
||||
|
||||
func16 = nullptr;
|
||||
|
||||
switch (typeA) {
|
||||
case GGML_TYPE_Q2_K:
|
||||
set_functions<DequantizerQ2K>(kernels);
|
||||
@@ -790,6 +1660,36 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
set_functions<DequantizerIQ4XS>(kernels);
|
||||
break;
|
||||
case GGML_TYPE_Q2_K_R4:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q2_k_r4_q8_k, kernels)
|
||||
break;
|
||||
case GGML_TYPE_Q3_K_R4:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q3_k_r4_q8_k, kernels)
|
||||
break;
|
||||
case GGML_TYPE_Q4_K_R4:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q4_k_r4_q8_k, kernels)
|
||||
break;
|
||||
case GGML_TYPE_Q5_K_R4:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q5_k_r4_q8_k, kernels)
|
||||
break;
|
||||
case GGML_TYPE_Q6_K_R4:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q6_k_r4_q8_k, kernels)
|
||||
break;
|
||||
case GGML_TYPE_IQ4_XS_R8:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_xs_r8_q8_k_avx2, kernels)
|
||||
break;
|
||||
case GGML_TYPE_Q8_K_R8:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_k_r8_q8_k, kernels)
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
func16 = mul_mat_q8_k_r8_q8_k<16>;
|
||||
#endif
|
||||
break;
|
||||
case GGML_TYPE_Q8_KV:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_q8_KV, kernels)
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
func16 = mul_mat_q8_KV_q8_KV<16>;
|
||||
#endif
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
#include <array>
|
||||
|
||||
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels);
|
||||
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
|
||||
|
||||
#endif
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user