Refactor iqk: Factor out GEMM for q8_K_R8, q8_KV

Iwan Kawrakow
2025-05-18 14:02:07 +03:00
parent 6cd3609a85
commit f501200d42
3 changed files with 943 additions and 1026 deletions


@@ -717,60 +717,930 @@ static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
#endif
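// GEMM kernel: iq4_xs_r8 (IQ4_XS with 8 interleaved rows) times q8_K activations, AVX2 flavor.
// Each ibl iteration below handles one super-block of QK_K = 256 values per row.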
template <int nrc_y>
static void mul_mat_iq4_xs_r8_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto m4 = _mm256_set1_epi8(0xf);
auto m30 = _mm256_set1_epi8(0x30);
auto m32 = _mm256_set1_epi8(32);
#ifndef HAVE_FANCY_SIMD
auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
auto values = MM256_SET_M128I(values128, values128);
#else
auto values = load_iq4nl_values_256();
#endif
int nbl = n / QK_K;
using helper_t = union { __m256i vec[2]; uint64_t val[8]; };
helper_t h;
__m256 acc[nrc_y] = {};
__m256i qx[4];
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_iq4_xs_r8 * iq4 = (const block_iq4_xs_r8 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto d4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d));
auto slbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l);
auto sl1 = _mm256_and_si256(slbits, m4);
auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4);
auto shbits = _mm_loadu_si128((const __m128i*)iq4[ibl].scales_h);
auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
h.vec[0] = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(_mm256_slli_epi16(sh, 4), m30)), m32);
h.vec[1] = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(sh, m30)), m32);
__m256i isum[nrc_y] = {};
for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
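// VNNI path: the lookup values are used as the unsigned operand of _mm256_dpbusd_epi32 (presumably
// stored biased by +128), so the -128.f * scale term below, multiplied by the per-block activation
// sums (bsums), removes the resulting constant offset.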
auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi64x(h.val[ib]));
auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-128.f));
for (int iy = 0; iy < nrc_y; ++iy) {
float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
}
#else
auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(h.val[ib])), s_shuffle);
#endif
auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+0);
auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+1);
qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits1));
qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits1, 4)));
qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits2));
qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits2, 4)));
#ifndef HAVE_FANCY_SIMD
auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+2*ib+0);
auto y = MM256_SET_M128I(y128, y128);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
#else
auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
auto sumi = _mm256_add_epi32(_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)),
_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
isum[iy] = _mm256_add_epi32(isum[iy], sumi);
#endif
}
bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+2);
bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+3);
qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits1));
qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits1, 4)));
qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits2));
qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits2, 4)));
#ifndef HAVE_FANCY_SIMD
s1 = _mm256_sign_epi8(qx[0], qx[0]);
s2 = _mm256_sign_epi8(qx[1], qx[1]);
s3 = _mm256_sign_epi8(qx[2], qx[2]);
s4 = _mm256_sign_epi8(qx[3], qx[3]);
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+2*ib+1);
auto y = MM256_SET_M128I(y128, y128);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
#else
auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
auto sumi = _mm256_add_epi32(_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)),
_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
isum[iy] = _mm256_add_epi32(isum[iy], sumi);
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = _mm256_setzero_ps();
}
}
}
#ifdef HAVE_FANCY_SIMD
template <int nrc_y>
static void mul_mat_iq4_xs_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
mul_mat_iq4_xs_r8_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
return;
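// Note: the call + return above short-circuit this function, so the AVX512 variant below is
// currently dead code (kept, presumably, for reference).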
if constexpr (nrc_y == 1){
mul_mat_iq4_xs_r8_q8_k_avx2<1>(n, vx, bx, info, nrc_x);
} else {
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto m4 = _mm512_set1_epi8(0xf);
auto values = load_iq4nl_values_512();
int nbl = n / QK_K;
using helper_t = union { __m512i vec; uint32_t val[16]; };
helper_t h;
__m512 acc[nrc_y] = {};
__m512i isum[nrc_y] = {};
__m512i qx[4];
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_iq4_xs_r8 * iq4l = (const block_iq4_xs_r8 *)((const char *)vx + (ix+0)*bx);
const block_iq4_xs_r8 * iq4h = (const block_iq4_xs_r8 *)((const char *)vx + (ix+4)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[ibl].d));
auto dh = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[ibl].d));
auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1);
auto d4x64 = _mm512_mul_ps(d4, _mm512_set1_ps(-64.f));
auto slbits_l = _mm_loadu_si128((const __m128i *)iq4l[ibl].scales_l);
auto shbits_l = _mm_loadu_si128((const __m128i *)iq4h[ibl].scales_l);
auto sl_l = MM256_SET_M128I(_mm_srli_epi16(slbits_l, 4), slbits_l);
auto sh_l = MM256_SET_M128I(_mm_srli_epi16(shbits_l, 4), shbits_l);
auto slb = _mm512_and_si512(_mm512_inserti32x8(_mm512_castsi256_si512(sl_l), sh_l, 1), m4);
auto aux64 = (const uint64_t *)iq4l[ibl].scales_h;
auto slbits_h = _mm_set_epi64x(aux64[0] >> 2, aux64[0]);
aux64 = (const uint64_t *)iq4h[ibl].scales_h;
auto shbits_h = _mm_set_epi64x(aux64[0] >> 2, aux64[0]);
auto sl_h = MM256_SET_M128I(slbits_h, _mm_slli_epi16(slbits_h, 4));
auto sh_h = MM256_SET_M128I(shbits_h, _mm_slli_epi16(shbits_h, 4));
auto shb = _mm512_and_si512(_mm512_inserti32x8(_mm512_castsi256_si512(sl_h), sh_h, 1), _mm512_set1_epi8(0x30));
h.vec = _mm512_sub_epi8(_mm512_or_si512(slb, shb), _mm512_set1_epi8(32));
for (int ib = 0; ib < QK_K/32; ++ib) {
auto iscales = _mm512_cvtepi8_epi32(_mm_blend_epi32(_mm_set1_epi32(h.val[ib+0]), _mm_set1_epi32(h.val[ib+8]), 0x0c));
auto scales = _mm512_cvtepi32_ps(iscales);
auto scales_m = _mm512_mul_ps(scales, d4x64);
auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+0)),
_mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+0), 1);
auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+1)),
_mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+1), 1);
qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4));
qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4));
qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4));
qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
auto sumi = _mm512_setzero_si512();
sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
isum[iy] = _mm512_add_epi32(isum[iy], _mm512_mullo_epi32(iscales, sumi));
float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
acc[iy] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl))), _mm512_cvtepi32_ps(isum[iy]), acc[iy]);
isum[iy] = _mm512_setzero_si512();
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 0), _mm512_extractf32x4_ps(acc[iy], 1));
auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 2), _mm512_extractf32x4_ps(acc[iy], 3));
info.store(ix+0, iy, sum1);
info.store(ix+4, iy, sum2);
acc[iy] = _mm512_setzero_ps();
}
}
}
}
#else
template <int nrc_y>
static void mul_mat_iq4_xs_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
mul_mat_iq4_xs_r8_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
}
#endif
template <int nrc_y>
static void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto mxf = _mm256_set1_epi8(0xf);
auto m03 = _mm256_set1_epi8(0x03);
static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
#ifdef HAVE_FANCY_SIMD
__m256i isum[nrc_y] = {};
#else
auto m1 = _mm256_set1_epi16(1);
#endif
int nbl = n / QK_K;
__m256 acc[nrc_y] = {};
__m256i qx[4];
int8_t scales[64];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q2_k_r4 * iq2 = (const block_q2_k_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto dm = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq2[ibl].d));
auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dm), _mm256_castps256_ps128(dm));
auto m4 = _mm256_set_m128(_mm256_extractf128_ps(dm, 1), _mm256_extractf128_ps(dm, 1));
m4 = _mm256_mul_ps(m4, _mm256_set1_ps(-1.f));
auto all_scales1 = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales+0);
auto all_scales2 = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales+1);
auto scales1 = _mm256_and_si256(_mm256_srli_epi16(all_scales1, 4), mxf);
auto scales2 = _mm256_and_si256(_mm256_srli_epi16(all_scales2, 4), mxf);
{
auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
for (int iy = 0; iy < nrc_y; ++iy) {
auto bsums = q8.load_bsums(iy, ibl);
auto sumi = _mm256_setzero_si256();
#ifdef HAVE_FANCY_SIMD
sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
auto d8 = _mm256_set1_ps(q8.scale(iy, ibl));
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(m4, d8), _mm256_cvtepi32_ps(sumi), acc[iy]);
#else
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
auto d8 = _mm256_set1_ps(q8.scale(iy, ibl));
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(m4, d8), _mm256_cvtepi32_ps(sumi), acc[iy]);
if constexpr (nrc_y == 1) {
d4 = _mm256_mul_ps(d4, d8);
}
#endif
}
}
all_scales1 = _mm256_and_si256(all_scales1, mxf);
all_scales2 = _mm256_and_si256(all_scales2, mxf);
_mm256_storeu_si256((__m256i *)scales+0, all_scales1);
_mm256_storeu_si256((__m256i *)scales+1, all_scales2);
for (int ib = 0; ib < QK_K/32; ++ib) {
auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 8*ib)));
#ifndef HAVE_FANCY_SIMD
auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
#endif
auto lb = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs+ib);
qx[0] = _mm256_and_si256(lb, m03);
qx[1] = _mm256_and_si256(_mm256_srli_epi16(lb, 2), m03);
qx[2] = _mm256_and_si256(_mm256_srli_epi16(lb, 4), m03);
qx[3] = _mm256_and_si256(_mm256_srli_epi16(lb, 6), m03);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
// Quants are in 0...3, so we can add up all of them as int16_t without overflowing
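// (worst case per 16-bit lane after the add: 4 products x 3 x 128 = 1536, far below 32767)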
auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
if constexpr (nrc_y == 1) {
acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
} else {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
}
#endif
}
}
#ifdef HAVE_FANCY_SIMD
for (int iy = 0; iy < nrc_y; ++iy) {
auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
isum[iy] = _mm256_setzero_si256();
}
#endif
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
acc[iy] = _mm256_setzero_ps();
info.store(ix+0, iy, sum);
}
}
}
template <int nrc_y>
static void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto m4 = _mm256_set1_epi8(0xf);
auto m30 = _mm256_set1_epi8(0x30);
auto m32 = _mm256_set1_epi8(32);
auto m03 = _mm256_set1_epi8(0x03);
auto m04 = _mm256_set1_epi8(0x04);
static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
#ifdef HAVE_FANCY_SIMD
__m256i isum[nrc_y];
#else
auto m1 = _mm256_set1_epi16(1);
#endif
int nbl = n / QK_K;
__m256 acc[nrc_y] = {};
__m256i qx[4];
int8_t scales[64];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q3_k_r4 * iq3 = (const block_q3_k_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d));
auto d4 = _mm256_set_m128(dl, dl);
#ifndef HAVE_FANCY_SIMD
if constexpr (nrc_y == 1) {
d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
}
#endif
auto slb = _mm256_loadu_si256((const __m256i *)iq3[ibl].scales_l);
auto shbits = _mm_loadu_si128((const __m128i *)iq3[ibl].scales_h);
auto shb = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
auto scales1 = _mm256_sub_epi8(_mm256_or_si256(_mm256_and_si256(slb, m4), _mm256_and_si256(_mm256_slli_epi16(shb, 4), m30)), m32);
auto scales2 = _mm256_sub_epi8(_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(slb, 4), m4), _mm256_and_si256(shb, m30)), m32);
_mm256_storeu_si256((__m256i *)scales+0, scales1);
_mm256_storeu_si256((__m256i *)scales+1, scales2);
{
#ifndef HAVE_FANCY_SIMD
auto min = _mm256_mul_ps(d4, _mm256_set1_ps(-4.f));
#endif
auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
#ifdef HAVE_FANCY_SIMD
s1 = _mm256_mullo_epi16(s1, _mm256_set1_epi16(-4));
s2 = _mm256_mullo_epi16(s2, _mm256_set1_epi16(-4));
s3 = _mm256_mullo_epi16(s3, _mm256_set1_epi16(-4));
s4 = _mm256_mullo_epi16(s4, _mm256_set1_epi16(-4));
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto bsums = q8.load_bsums(iy, ibl);
auto sumi = _mm256_setzero_si256();
#ifdef HAVE_FANCY_SIMD
sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
isum[iy] = sumi;
#else
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
if constexpr (nrc_y == 1) {
acc[iy] = _mm256_fmadd_ps(min, _mm256_cvtepi32_ps(sumi), acc[iy]);
} else {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(min, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
}
#endif
}
}
for (int ib = 0; ib < QK_K/32; ++ib) {
auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 8*ib)));
#ifndef HAVE_FANCY_SIMD
auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
#endif
auto lb = _mm256_loadu_si256((const __m256i *)iq3[ibl].qs+ib);
auto hbits = _mm_loadu_si128((const __m128i *)iq3[ibl].qh+ib);
auto hb = MM256_SET_M128I(hbits, _mm_slli_epi16(hbits, 4));
qx[0] = _mm256_or_si256(_mm256_and_si256(lb, m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 2)));
qx[1] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 2), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 3)));
qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 4), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 4)));
qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 6), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 5)));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
// Quants are in 0...7, so we can add up all of them as int16_t without overflowing
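// (worst case per 16-bit lane after the add: 4 products x 7 x 128 = 3584, far below 32767)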
auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
if constexpr (nrc_y == 1) {
acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
} else {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
}
#endif
}
}
#ifdef HAVE_FANCY_SIMD
for (int iy = 0; iy < nrc_y; ++iy) {
auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
}
#endif
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
acc[iy] = _mm256_setzero_ps();
info.store(ix+0, iy, sum);
}
}
}
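// Helper that folds in the per-row block minimums: the packed 8-bit mins of the 4 interleaved rows
// are widened to int32 and multiplied by the per-32 activation block sums, which in this activation
// layout (apparently Q8_K32) are stored as floats in bsums, hence the float loads below; m4 carries
// the negated min scales.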
template <int nrc_y>
inline void process_min_r4_b32(int ibl, __m256 m4, __m256i mins, const Q8<nrc_y, block_q8_K>& q8, __m256 * acc) {
auto mins_l = _mm256_castsi256_si128(mins);
auto mins_h = _mm256_extracti128_si256(mins, 1);
auto aux1 = _mm_unpacklo_epi32(mins_l, mins_h);
auto aux2 = _mm_unpackhi_epi32(mins_l, mins_h);
auto ic1 = _mm256_cvtepi8_epi32(aux1);
auto ic2 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux1, 0xee));
auto ic3 = _mm256_cvtepi8_epi32(aux2);
auto ic4 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux2, 0xee));
if constexpr (nrc_y == 1) {
auto bs = _mm256_loadu_ps((const float *)q8.y[0][ibl].bsums);
auto sumf = _mm256_mul_ps(_mm256_cvtepi32_ps(ic1), _mm256_shuffle_ps(bs, bs, 0x00));
sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic2), _mm256_shuffle_ps(bs, bs, 0x55), sumf);
sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic3), _mm256_shuffle_ps(bs, bs, 0xaa), sumf);
sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic4), _mm256_shuffle_ps(bs, bs, 0xff), sumf);
acc[0] = _mm256_fmadd_ps(m4, sumf, acc[0]);
} else {
auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic1));
auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic2));
auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic3));
auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic4));
for (int iy = 0; iy < nrc_y; ++iy) {
auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
}
}
}
template <int nrc_y>
static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto mf = _mm256_set1_epi8(0xf);
auto m3 = _mm256_set1_epi8(0x30);
int nbl = n / QK_K;
union { __m256i vec; uint32_t val[8]; } hd;
__m256 acc[nrc_y] = {};
__m256i isum[nrc_y] = {};
__m256i qx[4];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q4_k_r4 * iq4 = (const block_q4_k_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d));
auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
auto lbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l);
auto hbits128 = _mm_loadu_si128((const __m128i *)iq4[ibl].scales_h);
auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m3));
auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m3));
process_min_r4_b32(ibl, m4, mins, q8, acc);
for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]));
#else
auto aux = _mm_set1_epi32(hd.val[ib]);
aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux));
auto scales_d = MM256_SET_M128I(aux, aux);
#endif
auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
qx[0] = _mm256_and_si256(bits1, mf);
qx[1] = _mm256_and_si256(bits2, mf);
qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits1, 4), mf);
qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits2, 4), mf);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales_d, _mm256_add_epi16(sumi1, sumi2)));
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
isum[iy] = _mm256_setzero_si256();
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
acc[iy] = _mm256_setzero_ps();
info.store(ix+0, iy, sum);
}
}
}
template <int nrc_y>
static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto mf = _mm256_set1_epi8(0xf);
auto m10 = _mm256_set1_epi8(0x10);
auto m30 = _mm256_set1_epi8(0x30);
int nbl = n / QK_K;
union { __m256i vec; uint32_t val[8]; } hd;
__m256 acc[nrc_y] = {};
__m256i isum[nrc_y] = {};
__m256i qx[4];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q5_k_r4 * iq5 = (const block_q5_k_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5[ibl].d));
auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
auto lbits = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales_l);
auto hbits128 = _mm_loadu_si128((const __m128i *)iq5[ibl].scales_h);
auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m30));
auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m30));
process_min_r4_b32(ibl, m4, mins, q8, acc);
for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]));
#else
auto aux = _mm_set1_epi32(hd.val[ib]);
aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux));
auto scales_d = MM256_SET_M128I(aux, aux);
#endif
auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0);
auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
auto hbits128 = _mm_loadu_si128((const __m128i*)iq5[ibl].qh + ib);
auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
qx[0] = _mm256_or_si256(_mm256_and_si256(lbits1, mf), _mm256_and_si256(m10, hbits));
qx[1] = _mm256_or_si256(_mm256_and_si256(lbits2, mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 2)));
qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 1)));
qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 3)));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
// To avoid overflow, we can only add up to 4 q5 x q8 products.
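// (4 x 31 x 128 = 15872 fits comfortably in int16_t; 8 products would already get close to 32767)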
auto sumi = _mm256_add_epi32(_mm256_madd_epi16(scales_d, sumi1), _mm256_madd_epi16(scales_d, sumi2));
isum[iy] = _mm256_add_epi32(isum[iy], sumi);
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
isum[iy] = _mm256_setzero_si256();
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
acc[iy] = _mm256_setzero_ps();
info.store(ix+0, iy, sum);
}
}
}
template <int nrc_y>
static void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto m4 = _mm256_set1_epi8(0xf);
auto m3 = _mm256_set1_epi8(0x30);
static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
#ifdef HAVE_FANCY_SIMD
__m256i isum[nrc_y];
#else
auto m1 = _mm256_set1_epi16(1);
#endif
int nbl = n / QK_K;
__m256 acc[nrc_y] = {};
__m256i qx[4];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q6_k_r4 * iq6 = (const block_q6_k_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6[ibl].d));
auto d4 = _mm256_set_m128(dl, dl);
#ifndef HAVE_FANCY_SIMD
if constexpr (nrc_y == 1) {
d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
}
#endif
{
#ifndef HAVE_FANCY_SIMD
auto min = _mm256_mul_ps(d4, _mm256_set1_ps(-32.f));
#endif
auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+0)), shuff); // blocks 0, 1, 2, 3 for each row
auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+1)), shuff); // blocks 4, 5, 6, 7 for each row
auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+2)), shuff); // blocks 8, 9, 10, 11 for each row
auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+3)), shuff); // blocks 12, 13, 14, 15 for each row
auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
#ifdef HAVE_FANCY_SIMD
s1 = _mm256_mullo_epi16(s1, _mm256_set1_epi16(-32));
s2 = _mm256_mullo_epi16(s2, _mm256_set1_epi16(-32));
s3 = _mm256_mullo_epi16(s3, _mm256_set1_epi16(-32));
s4 = _mm256_mullo_epi16(s4, _mm256_set1_epi16(-32));
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto bsums = q8.load_bsums(iy, ibl);
auto sumi = _mm256_setzero_si256();
#ifdef HAVE_FANCY_SIMD
sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
isum[iy] = sumi;
#else
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
if constexpr (nrc_y == 1) {
acc[iy] = _mm256_fmadd_ps(min, _mm256_cvtepi32_ps(sumi), acc[iy]);
} else {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(min, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
}
#endif
}
}
const uint32_t * scales = (const uint32_t *)iq6[ibl].scales;
for (int ib = 0; ib < QK_K/32; ++ib) {
auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 2*ib)));
#ifndef HAVE_FANCY_SIMD
auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
#endif
auto lbits1 = _mm256_loadu_si256((const __m256i *)iq6[ibl].ql+2*ib+0);
auto lbits2 = _mm256_loadu_si256((const __m256i *)iq6[ibl].ql+2*ib+1);
auto hbits = _mm256_loadu_si256((const __m256i *)iq6[ibl].qh+ib);
qx[0] = _mm256_or_si256(_mm256_and_si256(lbits1, m4), _mm256_and_si256(m3, _mm256_slli_epi16(hbits, 4)));
qx[1] = _mm256_or_si256(_mm256_and_si256(lbits2, m4), _mm256_and_si256(m3, hbits));
qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4), _mm256_and_si256(m3, _mm256_slli_epi16(hbits, 2)));
qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4), _mm256_and_si256(m3, _mm256_srli_epi16(hbits, 2)));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
// Quants are in 0...63, so we can add at most 4 as int16_t to be sure of no int16_t overflow
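// (worst case 4 x 63 x 127 = 32004, just under the int16_t limit)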
auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
if constexpr (nrc_y == 1) {
acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
} else {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
}
#endif
}
}
#ifdef HAVE_FANCY_SIMD
for (int iy = 0; iy < nrc_y; ++iy) {
auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
}
#endif
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
acc[iy] = _mm256_setzero_ps();
info.store(ix+0, iy, sum);
}
}
}
template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
#ifdef HAVE_FANCY_SIMD
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4XS>) {
funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>;
funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>;
funcs[2] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 3>;
funcs[3] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 4>;
funcs[4] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 5>;
funcs[5] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 6>;
funcs[6] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 7>;
funcs[7] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 8>;
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_iqX_k_q8_K_AVX512, Dequantizer, funcs)
} else {
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_AVX512, Dequantizer, funcs)
funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>;
funcs[1] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 2>;
funcs[2] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 3>;
funcs[3] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 4>;
funcs[4] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 5>;
funcs[5] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 6>;
funcs[6] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 7>;
funcs[7] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 8>;
}
#else
if constexpr (std::is_same_v<Dequantizer, DequantizerQ2K> ||
std::is_same_v<Dequantizer, DequantizerQ3K> ||
std::is_same_v<Dequantizer, DequantizerQ6K>) {
funcs[0] = mul_mat_qY_K_q8_K_T<Dequantizer, 1>;
funcs[1] = mul_mat_qY_K_q8_K_T<Dequantizer, 2>;
funcs[2] = mul_mat_qY_K_q8_K_T<Dequantizer, 3>;
funcs[3] = mul_mat_qY_K_q8_K_T<Dequantizer, 4>;
funcs[4] = mul_mat_qY_K_q8_K_T<Dequantizer, 5>;
funcs[5] = mul_mat_qY_K_q8_K_T<Dequantizer, 6>;
funcs[6] = mul_mat_qY_K_q8_K_T<Dequantizer, 7>;
funcs[7] = mul_mat_qY_K_q8_K_T<Dequantizer, 8>;
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qY_K_q8_K_T, Dequantizer, funcs)
} else {
funcs[0] = mul_mat_qX_K_q8_K_T<Dequantizer, 1>;
funcs[1] = mul_mat_qX_K_q8_K_T<Dequantizer, 2>;
funcs[2] = mul_mat_qX_K_q8_K_T<Dequantizer, 3>;
funcs[3] = mul_mat_qX_K_q8_K_T<Dequantizer, 4>;
funcs[4] = mul_mat_qX_K_q8_K_T<Dequantizer, 5>;
funcs[5] = mul_mat_qX_K_q8_K_T<Dequantizer, 6>;
funcs[6] = mul_mat_qX_K_q8_K_T<Dequantizer, 7>;
funcs[7] = mul_mat_qX_K_q8_K_T<Dequantizer, 8>;
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, Dequantizer, funcs)
}
#endif
}
// HAVE_FANCY_SIMD should only be defined as #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
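// GEMM kernel: q8_k_r8 (8 interleaved Q8_K rows, one fp16 scale per row and 256-block) times q8_K activations.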
template <int nrc_y>
static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_K> q8(info);
#ifndef HAVE_FANCY_SIMD
auto m1 = _mm256_set1_epi16(1);
#endif
int nbl = n / QK_K;
__m256 acc[nrc_y] = {};
__m256i isum[nrc_y] = {};
__m256i qx[4];
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_q8_k_r8 * iq8 = (const block_q8_k_r8 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto d4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ibl].d));
for (int ib = 0; ib < QK_K/16; ++ib) {
qx[0] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+0);
qx[1] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+1);
qx[2] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+2);
qx[3] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+3);
#ifndef HAVE_FANCY_SIMD
auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
#else
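// VNNI path: _mm256_dpbusd_epi32 takes an unsigned first operand, so shift the signed quants up by
// 127; the constant shift is compensated further down using the activation block sums.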
qx[0] = _mm256_add_epi8(qx[0], _mm256_set1_epi8(127));
qx[1] = _mm256_add_epi8(qx[1], _mm256_set1_epi8(127));
qx[2] = _mm256_add_epi8(qx[2], _mm256_set1_epi8(127));
qx[3] = _mm256_add_epi8(qx[3], _mm256_set1_epi8(127));
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+ib);
auto y = MM256_SET_M128I(y128, y128);
#ifdef HAVE_FANCY_SIMD
isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
#else
auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])));
auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])));
auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])));
auto sumi4 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(sumi1, sumi2));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(sumi3, sumi4));
#endif
}
}
#ifdef HAVE_FANCY_SIMD
auto m4 = _mm256_mul_ps(d4, _mm256_set1_ps(-128.f));
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
#ifdef HAVE_FANCY_SIMD
auto bsums = (const float *)q8.y[iy][ibl].bsums;
acc[iy] = _mm256_fmadd_ps(m4, _mm256_set1_ps(bsums[0]), acc[iy]);
#endif
isum[iy] = _mm256_setzero_si256();
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = _mm256_setzero_ps();
}
}
}
template <int nrc_y>
static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
GGML_ASSERT(n%32 == 0);
__m256i qx[4];
#ifndef HAVE_FANCY_SIMD
__m256i sx[4];
auto m1 = _mm256_set1_epi16(1);
#endif
__m256i acc[nrc_y] = {};
float dy[nrc_y];
#ifdef HAVE_FANCY_SIMD
int32_t sy[nrc_y];
#endif
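// q8_KV rows are laid out as: a float scale at dptr[0], then (read only on the VNNI path) an int32
// at dptr[1] that apparently holds the sum of the row's quants and is used for the +127 bias
// correction, followed by the int8 quants starting at dptr + 2.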
const int8_t * q8y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
auto dptr = (const float *)info.src1_row(iy);
dy[iy] = dptr[0];
#ifdef HAVE_FANCY_SIMD
auto iptr = (const int32_t *)(dptr + 1);
sy[iy] = -127*iptr[0];
#endif
q8y[iy] = (const int8_t *)(dptr + 2);
}
const int8_t * q8x[4];
float dx[4];
for (int ix = 0; ix < nrc_x; ix += 4) {
for (int kx = 0; kx < 4; ++kx) {
auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
dx[kx] = dptr[0];
q8x[kx] = (const int8_t *)(dptr + 2);
}
for (int i = 0; i < n/32; ++i) {
for (int kx = 0; kx < 4; ++kx) qx[kx] = _mm256_loadu_si256((const __m256i *)q8x[kx] + i);
auto t0 = _mm256_unpacklo_epi32(qx[0], qx[1]);
auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]);
auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]);
auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]);
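// Transpose 4-byte groups of the four x rows so that each qx register holds matching 4-byte chunks
// of all four rows side by side, the interleaved layout expected by the dot products below.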
#ifdef HAVE_FANCY_SIMD
qx[0] = _mm256_add_epi8(_mm256_unpacklo_epi64(t0, t1), _mm256_set1_epi8(127));
qx[1] = _mm256_add_epi8(_mm256_unpackhi_epi64(t0, t1), _mm256_set1_epi8(127));
qx[2] = _mm256_add_epi8(_mm256_unpacklo_epi64(t2, t3), _mm256_set1_epi8(127));
qx[3] = _mm256_add_epi8(_mm256_unpackhi_epi64(t2, t3), _mm256_set1_epi8(127));
#else
qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
#ifdef HAVE_FANCY_SIMD
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
#else
auto dot1 = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
auto dot2 = _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
auto dot3 = _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
auto dot4 = _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
auto dot12 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot1), _mm256_madd_epi16(m1, dot2));
auto dot34 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot3), _mm256_madd_epi16(m1, dot4));
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(dot12, dot34));
#endif
}
}
auto scales_x = _mm_loadu_ps(dx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1));
#ifdef HAVE_FANCY_SIMD
sumi = _mm_add_epi32(sumi, _mm_set1_epi32(sy[iy]));
#endif
auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy]));
info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi)));
acc[iy] = _mm256_setzero_si256();
}
}
}
} // namespace
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels) {
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) {
auto etypeA = ggml_type(typeA);
auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32
: etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8
: etypeA == GGML_TYPE_Q8_KV ? GGML_TYPE_Q8_KV
: GGML_TYPE_Q8_K;
if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
return false;
}
func16 = nullptr;
switch (typeA) {
case GGML_TYPE_Q2_K:
set_functions<DequantizerQ2K>(kernels);
@@ -790,6 +1660,36 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
case GGML_TYPE_IQ4_XS:
set_functions<DequantizerIQ4XS>(kernels);
break;
case GGML_TYPE_Q2_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q2_k_r4_q8_k, kernels)
break;
case GGML_TYPE_Q3_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q3_k_r4_q8_k, kernels)
break;
case GGML_TYPE_Q4_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q4_k_r4_q8_k, kernels)
break;
case GGML_TYPE_Q5_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q5_k_r4_q8_k, kernels)
break;
case GGML_TYPE_Q6_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q6_k_r4_q8_k, kernels)
break;
case GGML_TYPE_IQ4_XS_R8:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_xs_r8_q8_k_avx2, kernels)
break;
case GGML_TYPE_Q8_K_R8:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_k_r8_q8_k, kernels)
#ifdef HAVE_FANCY_SIMD
func16 = mul_mat_q8_k_r8_q8_k<16>;
#endif
break;
case GGML_TYPE_Q8_KV:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_q8_KV, kernels)
#ifdef HAVE_FANCY_SIMD
func16 = mul_mat_q8_KV_q8_KV<16>;
#endif
break;
default:
return false;
}


@@ -6,6 +6,6 @@
#include <array>
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels);
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
#endif

File diff suppressed because it is too large.