Deepseek-Lite (#184)

* Quantization mixes tweaks

* Make iq4_nl_r4 work with row size that are not a multiple of 128

... on Zen4

* Make iq4_nl_r4 work with row size that are not a multiple of 128

... on AVX2

* Make iq4_nl_r4 work with row size that are not a multiple of 128

... on AVX2

* Make q6_0_w4 work with row size that are not a multiple of 128

... on Zen4

* Make q6_0_w4 work with row size that are not a multiple of 128

... on Zen4

* Make q5_0_r4 work with row size that are not a multiple of 128

... on Zen4 and AVX2

* Make q5,6_0_r4, iq4_nl_e4 work with row size that are not a multiple of 128

also on NEON.

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-01-30 18:36:24 +02:00
committed by GitHub
parent f7a4a0fd42
commit ba470ec1b4
2 changed files with 315 additions and 170 deletions

View File

@@ -2474,44 +2474,63 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data
auto m4 = _mm512_set1_epi8(0xf);
auto values = load_iq4nl_values_512();
int nb = n / QK4_NL;
GGML_ASSERT(nb%4 == 0);
__m512 acc[2*nrc_y] = {};
__m512i qx[4];
float d8[8*nrc_y];
auto prepare = [&qx, &m4, &values] (const block_iq4_nl_r4& iq4l, const block_iq4_nl_r4& iq4h) {
auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l.d));
auto scales1 = _mm256_set_m128(scales128, scales128);
scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h.d));
auto scales2 = _mm256_set_m128(scales128, scales128);
auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+0)),
_mm256_loadu_si256((const __m256i *)iq4h.qs+0), 1);
auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+1)),
_mm256_loadu_si256((const __m256i *)iq4h.qs+1), 1);
qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4));
qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4));
qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4));
qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4));
return scales;
};
auto dot = [&qx] (__m256i y8) {
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
auto sumi = _mm512_setzero_si512();
sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
return sumi;
};
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_iq4_nl_r4 * iq4l = (const block_iq4_nl_r4 *)((const char *)vx + (ix+0)*bx);
const block_iq4_nl_r4 * iq4h = (const block_iq4_nl_r4 *)((const char *)vx + (ix+4)*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int iy = 0; iy < nrc_y; ++iy) {
_mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)));
}
for (int k = 0; k < 4; ++k) {
auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[4*ib4+k].d));
auto scales1 = _mm256_set_m128(scales128, scales128);
scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[4*ib4+k].d));
auto scales2 = _mm256_set_m128(scales128, scales128);
auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-64.f));
auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0)),
_mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0), 1);
auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1)),
_mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1), 1);
qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4));
qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4));
qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4));
qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4));
auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
auto sumi = _mm512_setzero_si512();
sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]));
auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
auto dy = _mm512_set1_ps(d8[8*iy+k]);
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]);
acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = prepare(iq4l[ib], iq4h[ib]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_1 *)q8.y[iy];
auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d));
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]);
auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-64.f), acc[2*iy+1], acc[2*iy+0]);
acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
@@ -2530,37 +2549,57 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data
auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
auto values = MM256_SET_M128I(values128, values128);
int nb = n / QK4_NL;
GGML_ASSERT(nb%4 == 0);
__m256 acc[nrc_y] = {};
//__m256 acc[2*nrc_y] = {};
__m256i qs[4];
float d8[4*nrc_y];
auto prepare = [&qs, &values, &m4] (const block_iq4_nl_r4& iq4) {
auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4.d));
auto scales = _mm256_set_m128(scales128, scales128);
auto bits1 = _mm256_loadu_si256((const __m256i *)iq4.qs+0);
auto bits2 = _mm256_loadu_si256((const __m256i *)iq4.qs+1);
qs[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
qs[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
qs[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
qs[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
return scales;
};
auto dot = [&qs, &m1] (__m256i y) {
auto u1 = _mm256_sign_epi8(qs[0], qs[0]);
auto u2 = _mm256_sign_epi8(qs[1], qs[1]);
auto sumi1 = _mm256_add_epi32(
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(u1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qs[0]))),
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(u2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qs[1]))));
u1 = _mm256_sign_epi8(qs[2], qs[2]);
u2 = _mm256_sign_epi8(qs[3], qs[3]);
auto sumi2 = _mm256_add_epi32(
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(u1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qs[2]))),
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(u2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qs[3]))));
return _mm256_add_epi32(sumi1, sumi2);
};
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_iq4_nl_r4 * iq4 = (const block_iq4_nl_r4 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int iy = 0; iy < nrc_y; ++iy) {
_mm_storeu_ps(d8+4*iy, _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d)));
}
for (int k = 0; k < 4; ++k) {
auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[4*ib4+k].d));
auto scales = _mm256_set_m128(scales128, scales128);
auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+0);
auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+1);
auto q1 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
auto q2 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
auto q3 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
auto q4 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
auto s1 = _mm256_sign_epi8(q1, q1);
auto s2 = _mm256_sign_epi8(q2, q2);
auto s3 = _mm256_sign_epi8(q3, q3);
auto s4 = _mm256_sign_epi8(q4, q4);
auto scales = prepare(iq4[4*ib4+k]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
auto sumi1 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q1))),
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q2))));
auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q3))),
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q4))));
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), acc[iy]);
auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = prepare(iq4[ib]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_1 *)q8.y[iy];
auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
info.store(ix, iy, sum);
@@ -2797,43 +2836,73 @@ static void mul_mat_q5_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D
Q8<nrc_y, block_q8_1_x4> q8(info);
auto m4 = _mm256_set1_epi8(0xf);
auto m5 = _mm256_set1_epi8(0x10);
#ifndef HAVE_FANCY_SIMD
auto m1 = _mm256_set1_epi16(1);
#endif
auto mscale = _mm256_set_m128(_mm_set1_ps(-8.f), _mm_set1_ps(1.f));
int nb = n / QK5_0;
GGML_ASSERT(nb%4 == 0);
__m256 acc[nrc_y] = {};
__m256i qx[4];
float d8[8*nrc_y];
auto prepare = [&qx, &m4, &m5] (const block_q5_0_r4& iq5) {
auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5.d));
auto scales = _mm256_set_m128(scales128, scales128);
auto bits1 = _mm256_loadu_si256((const __m256i *)iq5.qs+0);
auto bits2 = _mm256_loadu_si256((const __m256i *)iq5.qs+1);
auto hbits = _mm_loadu_si128((const __m128i *)iq5.qh);
auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 1), hbits);
qx[0] = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 4), m5));
qx[1] = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 2), m5));
qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hb, m5));
qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hb, 2), m5));;
return scales;
};
#ifdef HAVE_FANCY_SIMD
auto dot = [&qx] (__m256i y) {
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
return sumi;
};
#else
auto dot = [&qx, &m1] (__m256i y) {
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
return sumi;
};
#endif
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q5_0_r4 * iq5 = (const block_q5_0_r4 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int iy = 0; iy < nrc_y; ++iy) {
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d));
_mm256_storeu_ps(d8 + 8*iy, scales);
_mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(mscale, scales));
}
for (int k = 0; k < 4; ++k) {
auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5[4*ib4+k].d));
auto scales = _mm256_set_m128(scales128, scales128);
auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-8.f));
auto bits1 = _mm256_loadu_si256((const __m256i *)iq5[4*ib4+k].qs+0);
auto bits2 = _mm256_loadu_si256((const __m256i *)iq5[4*ib4+k].qs+1);
auto hbits = _mm_loadu_si128((const __m128i *)iq5[4*ib4+k].qh);
auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 1), hbits);
auto q1 = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 4), m5));
auto q2 = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 2), m5));
auto q3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hb, m5));
auto q4 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hb, 2), m5));;
auto scales = prepare(iq5[4*ib4+k]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(q2, _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff)));
auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k]));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]);
acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]);
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = prepare(iq5[ib]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_1 *)q8.y[iy];
auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-8.f*GGML_FP16_TO_FP32(qy[ib].s)), acc[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
info.store(ix, iy, sum);
@@ -2853,50 +2922,68 @@ static void mul_mat_q5_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
auto m4 = _mm512_set1_epi8(0xf);
auto m5 = _mm512_set1_epi8(0x10);
int nb = n / QK5_0;
GGML_ASSERT(nb%4 == 0);
__m512 acc[2*nrc_y] = {};
__m512i qx[4];
float d8[8*nrc_y];
auto prepare = [&qx, &m4, &m5] (const block_q5_0_r4& iq5l, const block_q5_0_r4& iq5h) {
auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5l.d));
auto scales1 = _mm256_set_m128(scales128, scales128);
scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5h.d));
auto scales2 = _mm256_set_m128(scales128, scales128);
auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l.qs+0)),
_mm256_loadu_si256((const __m256i *)iq5h.qs+0), 1);
auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l.qs+1)),
_mm256_loadu_si256((const __m256i *)iq5h.qs+1), 1);
auto hbits1 = _mm_loadu_si128((const __m128i *)iq5l.qh);
auto hbits2 = _mm_loadu_si128((const __m128i *)iq5h.qh);
auto hb1 = MM256_SET_M128I(_mm_srli_epi16(hbits1, 1), hbits1);
auto hb2 = MM256_SET_M128I(_mm_srli_epi16(hbits2, 1), hbits2);
auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hb1), hb2, 1);
qx[0] = _mm512_or_si512(_mm512_and_si512(bits1, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 4), m5));
qx[1] = _mm512_or_si512(_mm512_and_si512(bits2, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 2), m5));
qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4), _mm512_and_si512(hb, m5));
qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4), _mm512_and_si512(_mm512_srli_epi16(hb, 2), m5));
return scales;
};
auto dot = [&qx] (__m256i y8) {
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
auto sumi = _mm512_setzero_si512();
sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
return sumi;
};
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_q5_0_r4 * iq5l = (const block_q5_0_r4 *)((const char *)vx + (ix+0)*bx);
const block_q5_0_r4 * iq5h = (const block_q5_0_r4 *)((const char *)vx + (ix+4)*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int iy = 0; iy < nrc_y; ++iy) {
_mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)));
}
for (int k = 0; k < 4; ++k) {
auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5l[4*ib4+k].d));
auto scales1 = _mm256_set_m128(scales128, scales128);
scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5h[4*ib4+k].d));
auto scales2 = _mm256_set_m128(scales128, scales128);
auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-8.f));
auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[4*ib4+k].qs+0)),
_mm256_loadu_si256((const __m256i *)iq5h[4*ib4+k].qs+0), 1);
auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[4*ib4+k].qs+1)),
_mm256_loadu_si256((const __m256i *)iq5h[4*ib4+k].qs+1), 1);
auto hbits1 = _mm_loadu_si128((const __m128i *)iq5l[4*ib4+k].qh);
auto hbits2 = _mm_loadu_si128((const __m128i *)iq5h[4*ib4+k].qh);
auto hb1 = MM256_SET_M128I(_mm_srli_epi16(hbits1, 1), hbits1);
auto hb2 = MM256_SET_M128I(_mm_srli_epi16(hbits2, 1), hbits2);
auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hb1), hb2, 1);
qx[0] = _mm512_or_si512(_mm512_and_si512(bits1, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 4), m5));
qx[1] = _mm512_or_si512(_mm512_and_si512(bits2, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 2), m5));
//qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 2), m5);
qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4), _mm512_and_si512(hb, m5));
qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4), _mm512_and_si512(_mm512_srli_epi16(hb, 2), m5));
auto scales = prepare(iq5l[4*ib4+k], iq5h[4*ib4+k]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
auto sumi = _mm512_setzero_si512();
sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]));
auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
auto dy = _mm512_set1_ps(d8[8*iy+k]);
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]);
acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = prepare(iq5l[ib], iq5h[ib]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_1 *)q8.y[iy];
auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d));
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]);
auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-8.f), acc[2*iy+1], acc[2*iy+0]);
acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
@@ -2919,51 +3006,72 @@ static void mul_mat_q6_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D
Q8<nrc_y, block_q8_1_x4> q8(info);
auto m4 = _mm256_set1_epi8(0xf);
auto m6 = _mm256_set1_epi8(0x30);
auto mscale = _mm256_set_m128(_mm_set1_ps(-16.f), _mm_set1_ps(1.f));
#ifndef HAVE_FANCY_SIMD
auto m1 = _mm256_set1_epi16(1);
#endif
int nb = n / QK6_0;
GGML_ASSERT(nb%4 == 0);
__m256 acc[nrc_y] = {};
float d8[8*nrc_y];
__m256i qx[4];
auto prepare = [&qx, &m4, &m6] (const block_q6_0_r4& iq6) {
auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6.d));
auto scales = _mm256_set_m128(scales128, scales128);
auto bits1 = _mm256_loadu_si256((const __m256i *)iq6.qs+0);
auto bits2 = _mm256_loadu_si256((const __m256i *)iq6.qs+1);
auto hbits = _mm256_loadu_si256((const __m256i *)iq6.qh);
qx[0] = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), m6));
qx[1] = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), m6));
qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hbits, m6));
qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m6));
return scales;
};
#ifdef HAVE_FANCY_SIMD
auto dot = [&qx] (__m256i y) {
auto sumi = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
return sumi;
};
#else
auto dot = [&qx, &m1] (__m256i y) {
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
return sumi;
};
#endif
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q6_0_r4 * iq6 = (const block_q6_0_r4 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int iy = 0; iy < nrc_y; ++iy) {
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d));
_mm256_storeu_ps(d8 + 8*iy, scales);
_mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(scales, mscale));
}
for (int k = 0; k < 4; ++k) {
auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6[4*ib4+k].d));
auto scales = _mm256_set_m128(scales128, scales128);
auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-16.f));
auto bits1 = _mm256_loadu_si256((const __m256i *)iq6[4*ib4+k].qs+0);
auto bits2 = _mm256_loadu_si256((const __m256i *)iq6[4*ib4+k].qs+1);
auto hbits = _mm256_loadu_si256((const __m256i *)iq6[4*ib4+k].qh);
auto q1 = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), m6));
auto q2 = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), m6));
auto q3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hbits, m6));
auto q4 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m6));
auto scales = prepare(iq6[4*ib4+k]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_dpbusd_epi32(_mm256_setzero_si256(), q1, _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, q2, _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, q3, _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, q4, _mm256_shuffle_epi32(y, 0xff));
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(q2, _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff)));
auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
#endif
auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k]));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]);
acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]);
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = prepare(iq6[ib]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_1 *)q8.y[iy];
auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-16.f*GGML_FP16_TO_FP32(qy[ib].s)), acc[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
info.store(ix, iy, sum);
@@ -2983,47 +3091,67 @@ static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
auto m4 = _mm512_set1_epi8(0xf);
auto m6 = _mm512_set1_epi8(0x30);
int nb = n / QK6_0;
GGML_ASSERT(nb%4 == 0);
__m512 acc[2*nrc_y] = {};
__m512i qx[4];
float d8[8*nrc_y];
auto prepare = [&qx, &m4, &m6] (const block_q6_0_r4& iq6l, const block_q6_0_r4& iq6h) {
auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6l.d));
auto scales1 = _mm256_set_m128(scales128, scales128);
scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6h.d));
auto scales2 = _mm256_set_m128(scales128, scales128);
auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l.qs+0)),
_mm256_loadu_si256((const __m256i *)iq6h.qs+0), 1);
auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l.qs+1)),
_mm256_loadu_si256((const __m256i *)iq6h.qs+1), 1);
auto hbits1 = _mm256_loadu_si256((const __m256i *)iq6l.qh);
auto hbits2 = _mm256_loadu_si256((const __m256i *)iq6h.qh);
auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hbits1), hbits2, 1);
qx[0] = _mm512_and_si512(bits1, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 4), m6);
qx[1] = _mm512_and_si512(bits2, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 2), m6);;
qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4) | _mm512_and_si512(hb, m6);
qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4) | _mm512_and_si512(_mm512_srli_epi16(hb, 2), m6);
return scales;
};
auto dot = [&qx] (__m256i y8) {
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
auto sumi = _mm512_setzero_si512();
sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
return sumi;
};
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_q6_0_r4 * iq6l = (const block_q6_0_r4 *)((const char *)vx + (ix+0)*bx);
const block_q6_0_r4 * iq6h = (const block_q6_0_r4 *)((const char *)vx + (ix+4)*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int iy = 0; iy < nrc_y; ++iy) {
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d));
_mm256_storeu_ps(d8 + 8*iy, scales);
}
for (int k = 0; k < 4; ++k) {
auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6l[4*ib4+k].d));
auto scales1 = _mm256_set_m128(scales128, scales128);
scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6h[4*ib4+k].d));
auto scales2 = _mm256_set_m128(scales128, scales128);
auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-16.f));
auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l[4*ib4+k].qs+0)),
_mm256_loadu_si256((const __m256i *)iq6h[4*ib4+k].qs+0), 1);
auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l[4*ib4+k].qs+1)),
_mm256_loadu_si256((const __m256i *)iq6h[4*ib4+k].qs+1), 1);
auto hbits1 = _mm256_loadu_si256((const __m256i *)iq6l[4*ib4+k].qh);
auto hbits2 = _mm256_loadu_si256((const __m256i *)iq6h[4*ib4+k].qh);
auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hbits1), hbits2, 1);
qx[0] = _mm512_and_si512(bits1, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 4), m6);
qx[1] = _mm512_and_si512(bits2, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 2), m6);;
qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4) | _mm512_and_si512(hb, m6);
qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4) | _mm512_and_si512(_mm512_srli_epi16(hb, 2), m6);
auto scales = prepare(iq6l[4*ib4+k], iq6h[4*ib4+k]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
auto sumi = _mm512_setzero_si512();
sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]));
auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
auto dy = _mm512_set1_ps(d8[8*iy+k]);
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]);
acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = prepare(iq6l[ib], iq6h[ib]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_1 *)q8.y[iy];
auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d));
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]);
auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-16.f), acc[2*iy+1], acc[2*iy+0]);
acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
@@ -12087,7 +12215,6 @@ void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info,
Q8<nrc_y, block_q8_0_x4> q8(info);
Dequantizer deq(vx, bx);
int nb = n / QK4_NL;
GGML_ASSERT(nb%4 == 0);
int8x16_t qx[8];
float d8[4*nrc_y];
float32x4_t acc[nrc_y] = {};
@@ -12098,7 +12225,7 @@ void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info,
vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
}
for (int k = 0; k < 4; ++k) {
auto scales = deq.prepare(ib4, k, qx);
auto scales = deq.prepare(4*ib4+k, qx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
auto sumi = interleaved_dotq(qx, y);
@@ -12107,6 +12234,16 @@ void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info,
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = deq.prepare(ib, qx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_0 *)q8.y[iy];
auto y = vld1q_s8_x2(qy[ib].qs);
auto sumi = interleaved_dotq(qx, y);
auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d)));
acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, deq.result(acc[iy]));
acc[iy] = vdupq_n_f32(0.f);
@@ -12164,9 +12301,9 @@ void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info,
struct IQ4_NL_R4_Dequantizer {
IQ4_NL_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx), values(vld1q_s8(iq4k_values)) {}
inline void new_row(int ix) { iq4 = (const block_iq4_nl_r4 *)(cx + ix*bx); }
inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
inline float32x4_t prepare(int ib, int8x16_t * qx) const {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ib].d));
auto bits = vld1q_u8_x4(iq4[ib].qs);
prepare_iq4_nl_quants(values, m4, bits, qx);
return scales;
}
@@ -12242,10 +12379,10 @@ struct Q4_0_R8_Dequantizer {
struct Q5_0_R4_Dequantizer {
Q5_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
inline void new_row(int ix) { iq5 = (const block_q5_0_r4 *)(cx + ix*bx); }
inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[4*ib4+k].d));
auto lbits = vld1q_u8_x4(iq5[4*ib4+k].qs);
auto hbits = vld1q_u8(iq5[4*ib4+k].qh);
inline float32x4_t prepare(int ib, int8x16_t * qx) const {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ib].d));
auto lbits = vld1q_u8_x4(iq5[ib].qs);
auto hbits = vld1q_u8(iq5[ib].qh);
qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16); // 0...3
qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16); // 16..19
qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16); // 4...7
@@ -12271,10 +12408,10 @@ struct Q5_0_R4_Dequantizer {
struct Q6_0_R4_Dequantizer {
Q6_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
inline void new_row(int ix) { iq6 = (const block_q6_0_r4 *)(cx + ix*bx); }
inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[4*ib4+k].d));
auto lbits = vld1q_u8_x4(iq6[4*ib4+k].qs);
auto hbits = vld1q_u8_x2(iq6[4*ib4+k].qh);
inline float32x4_t prepare(int ib, int8x16_t * qx) const {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[ib].d));
auto lbits = vld1q_u8_x4(iq6[ib].qs);
auto hbits = vld1q_u8_x2(iq6[ib].qh);
qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 4), m6), m32); // 0...3
qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 4), m6), m32); // 16..19
qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 2), m6), m32); // 4...7

View File

@@ -16075,7 +16075,10 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || is_iq2_m ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
++qs.i_attention_wv;
}
else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) {
else if (qs.model.hparams.n_expert >= 8 && name.find("attn_k") != std::string::npos) {
new_type = GGML_TYPE_Q4_K;
}
else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q") != std::string::npos) {
new_type = GGML_TYPE_Q4_K;
}
else if (name.find("attn_qkv.weight") != std::string::npos) {
@@ -16088,7 +16091,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
++qs.i_ffn_down;
}
else if (name.find("attn_output.weight") != std::string::npos) {
if (qs.model.hparams.n_expert == 8) {
if (qs.model.hparams.n_expert >= 4) {
new_type = GGML_TYPE_Q5_K;
} else {
if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_K;
@@ -16188,9 +16191,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (new_type == GGML_TYPE_Q5_K) new_type = GGML_TYPE_Q6_K;
}
++qs.i_attention_wv;
} else if (name.find("attn_k.weight") != std::string::npos) {
} else if (name.find("attn_k") != std::string::npos) {
if (qs.params->attn_k_type < GGML_TYPE_COUNT) new_type = qs.params->attn_k_type;
else if (qs.model.hparams.n_expert == 8) {
else if (qs.model.hparams.n_expert >= 8) {
// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
// TODO: explore better strategies
new_type = GGML_TYPE_Q8_0;
@@ -16201,8 +16204,13 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4) {
new_type = GGML_TYPE_IQ2_S;
}
} else if (name.find("attn_q.weight") != std::string::npos) {
} else if (name.find("attn_q") != std::string::npos) {
if (qs.params->attn_q_type < GGML_TYPE_COUNT) new_type = qs.params->attn_q_type;
else if (qs.model.hparams.n_expert >= 8) {
// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
// TODO: explore better strategies
new_type = GGML_TYPE_Q8_0;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
new_type = GGML_TYPE_IQ3_XXS;
}