ik_llama.cpp (mirror of https://github.com/ikawrakow/ik_llama.cpp.git)
Make q4_0_r4 work with tensor row sizes that are not a multiple of 128
... on Zen4. Also fix the q8_0 K-cache for head sizes that are not a multiple of 128.
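In outline: the q4_0_r4 kernel previously asserted that each row holds a multiple of 4 blocks (4 blocks x 32 quants = 128 values). The diff below drops that assert, factors the group setup and the dot product into `prepare`/`dot` lambdas, and finishes odd-sized rows with a per-block remainder loop. A hedged sketch of the resulting loop structure (hypothetical helper names, not repo code):

    // Process full 4-block (128-value) groups with the packed fast path,
    // then mop up the remaining blocks one at a time.
    void mul_mat_row(int n) {
        const int QK = 32;            // quants per block (QK4_NL / QK8_0 in ggml)
        int nb = n / QK;              // blocks in the row
        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
            process_group_of_4(ib4);  // hypothetical: the x4-packed fast path
        }
        for (int ib = 4*(nb/4); ib < nb; ++ib) {
            process_single_block(ib); // hypothetical: rows whose size is not
        }                             // a multiple of 128 end up here
    }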
@@ -2717,9 +2717,38 @@ static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
     Q8<nrc_y, block_q8_1_x4> q8(info);
     auto m4 = _mm512_set1_epi8(0xf);
     int nb = n / QK4_NL;
-    GGML_ASSERT(nb%4 == 0);
     __m512 acc[2*nrc_y] = {};
     __m512i qx[8];
+    auto prepare = [&qx, &m4] (const block_iq4_nl_r8& iq4l, const block_iq4_nl_r8& iq4h) {
+        auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l.d));
+        auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h.d));
+        auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+        for (int j = 0; j < 4; ++j) {
+            auto bits = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+j)),
+                                           _mm256_loadu_si256((const __m256i *)iq4h.qs+j), 1);
+            qx[j+0] = _mm512_and_si512(bits, m4);
+            qx[j+4] = _mm512_and_si512(_mm512_srli_epi16(bits, 4), m4);
+        }
+        return scales;
+    };
+    auto dot = [&qx] (const int8_t * qy) {
+        auto y4l = _mm_loadu_si128((const __m128i*)qy+0);
+        auto y4h = _mm_loadu_si128((const __m128i*)qy+1);
+        auto y8l = MM256_SET_M128I(y4l, y4l);
+        auto y8h = MM256_SET_M128I(y4h, y4h);
+        auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
+        auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
+        auto sumi = _mm512_setzero_si512();
+        sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
+        return sumi;
+    };
     float d8[8*nrc_y];
     for (int ix = 0; ix < nrc_x; ix += 16) {
         const block_iq4_nl_r8 * iq4l = (const block_iq4_nl_r8 *)((const char *)vx + (ix+0)*bx);
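The `prepare` lambda added above packs two rows' worth of 4-bit quants into 512-bit registers and splits each byte into its two nibbles. A scalar analogue of the per-byte step (illustration only, not repo code); note that `_mm512_srli_epi16(bits, 4)` shifts 16-bit lanes, so the subsequent AND with `m4` is what discards the bits dragged across byte boundaries:

    #include <cstdint>
    #include <cstdio>

    int main() {
        uint8_t packed = 0xA3;             // two 4-bit quants in one byte
        uint8_t lo = packed & 0xF;         // low nibble  -> qx[j+0] lanes: 0x3
        uint8_t hi = (packed >> 4) & 0xF;  // high nibble -> qx[j+4] lanes: 0xA
        printf("lo=%u hi=%u\n", lo, hi);   // prints lo=3 hi=10
    }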
@@ -2729,47 +2758,25 @@ static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
                 _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)));
             }
             for (int k = 0; k < 4; ++k) {
-                auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l[4*ib4+k].d));
-                auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h[4*ib4+k].d));
-                auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
-                auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0)),
-                                                _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0), 1);
-                auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1)),
-                                                _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1), 1);
-                auto bits3 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+2)),
-                                                _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+2), 1);
-                auto bits4 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+3)),
-                                                _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+3), 1);
-                qx[0] = _mm512_and_si512(bits1, m4);
-                qx[1] = _mm512_and_si512(bits2, m4);
-                qx[2] = _mm512_and_si512(bits3, m4);
-                qx[3] = _mm512_and_si512(bits4, m4);
-                qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4);
-                qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4);
-                qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits3, 4), m4);
-                qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits4, 4), m4);
+                auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]);
                 for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto y4l = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0);
-                    auto y4h = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1);
-                    auto y8l = MM256_SET_M128I(y4l, y4l);
-                    auto y8h = MM256_SET_M128I(y4h, y4h);
-                    auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
-                    auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
-                    auto sumi = _mm512_setzero_si512();
-                    sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
-                    sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
-                    sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
-                    sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
-                    sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
-                    sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
-                    sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
-                    sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
+                    auto sumi = dot(q8.y[iy][ib4].qs+32*k);
                     auto dy = _mm512_set1_ps(d8[8*iy+k]);
                     acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
                     acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
                 }
             }
         }
+        for (int ib = 4*(nb/4); ib < nb; ++ib) {
+            auto scales = prepare(iq4l[ib], iq4h[ib]);
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto qy = (const block_q8_1 *)q8.y[iy];
+                auto sumi = dot(qy[ib].qs);
+                auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d));
+                acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+                acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]);
+            }
+        }
         for (int iy = 0; iy < nrc_y; ++iy) {
             auto sum = _mm512_fmadd_ps(_mm512_set1_ps(-8.f), acc[2*iy+1], acc[2*iy+0]);
             acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
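The new remainder loop reads the activations through a plain `block_q8_1` pointer while the main loop consumes `block_q8_1_x4` groups. A minimal layout sketch of why the cast can line up, assuming ggml-style definitions (and that the quantization side stores the tail blocks in plain, non-x4 layout):

    #include <cstdint>

    typedef uint16_t ggml_half;

    // Assumed shapes mirroring ggml's q8_1 types: a plain block carries one
    // fp16 delta d and one fp16 quant-sum s; an x4 group stores 4 deltas
    // followed by 4 sums in d[8], then 4x32 quants -- which is why the kernel
    // loads 8 halves into d8 and uses d8[k] as the delta and d8[k+4] as the sum.
    struct block_q8_1    { ggml_half d, s; int8_t qs[32];  }; //  36 bytes
    struct block_q8_1_x4 { ggml_half d[8]; int8_t qs[128]; }; // 144 bytes

    static_assert(sizeof(block_q8_1_x4) == 4 * sizeof(block_q8_1),
                  "the tail cast (const block_q8_1 *)q8.y[iy] relies on this");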
@@ -3107,7 +3114,7 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
             auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[ib].qs, qy[ib].qs, qx);
             auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
             acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]);
-            acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(qy[ib].s), acc[1]);
+            acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[1]);
         }
     }
     info.store(ix, 0, _mm256_fmadd_ps(_mm256_set1_ps(-127.f), acc[1], acc[0]));
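The one-line fix above matters because `block_q8_1` stores `d` and `s` as fp16, and `ggml_half` is a 16-bit integer typedef: passing `qy[ib].s` straight to `_mm256_set1_ps` converts the raw bit pattern as an integer instead of decoding the half-precision value. A scalar illustration under that assumption:

    #include <cstdint>
    #include <cstdio>

    typedef uint16_t ggml_half;

    int main() {
        ggml_half s = 0x3C00;    // the fp16 bit pattern of 1.0
        float wrong = (float)s;  // implicit int->float: 15360.0f
        float right = 1.0f;      // what GGML_FP16_TO_FP32(s) decodes to
        printf("wrong=%g right=%g\n", wrong, right);
    }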
@@ -3140,22 +3147,20 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
                 }
             }
         }
-        if (4*(nb/4) < nb) {
-            for (int ib = 4*(nb/4); ib < nb; ++ib) {
-                auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[ib].d));
-                auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[ib].d));
-                auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
-                for (int j = 0; j < 8; ++j) {
-                    qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[ib].qs+j)),
-                                               _mm256_loadu_si256((const __m256i *)q8h[ib].qs+j), 1);
-                }
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto qy = (const block_q8_1 *)q8.y[iy];
-                    auto sumi = qx_r8_q8_dot_product(qx, qy[ib].qs);
-                    auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d));
-                    acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
-                    acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]);
-                }
-            }
-        }
+        for (int ib = 4*(nb/4); ib < nb; ++ib) {
+            auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[ib].d));
+            auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[ib].d));
+            auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+            for (int j = 0; j < 8; ++j) {
+                qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[ib].qs+j)),
+                                           _mm256_loadu_si256((const __m256i *)q8h[ib].qs+j), 1);
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto qy = (const block_q8_1 *)q8.y[iy];
+                auto sumi = qx_r8_q8_dot_product(qx, qy[ib].qs);
+                auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d));
+                acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+                acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]);
+            }
+        }
         for (int iy = 0; iy < nrc_y; ++iy) {
@@ -13217,10 +13222,10 @@ struct HelperQ80R4 : public BaseHelper<step> {
         m2 = _mm256_unpacklo_epi64(t2, t3);
         m3 = _mm256_unpackhi_epi64(t2, t3);
 #ifdef HAVE_FANCY_SIMD
-        m0 = _mm256_xor_si256(m0, _mm256_set1_epi8(-128));
-        m1 = _mm256_xor_si256(m1, _mm256_set1_epi8(-128));
-        m2 = _mm256_xor_si256(m2, _mm256_set1_epi8(-128));
-        m3 = _mm256_xor_si256(m3, _mm256_set1_epi8(-128));
+        m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+        m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+        m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+        m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
 #endif
         _mm256_storeu_si256((__m256i *)y[ib].qs + 0, m0);
         _mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1);
@@ -13239,10 +13244,10 @@ struct HelperQ80R4 : public BaseHelper<step> {
         m2 = _mm256_unpacklo_epi64(t2, t3);
         m3 = _mm256_unpackhi_epi64(t2, t3);
 #ifdef HAVE_FANCY_SIMD
-        m0 = _mm256_xor_si256(m0, _mm256_set1_epi8(-128));
-        m1 = _mm256_xor_si256(m1, _mm256_set1_epi8(-128));
-        m2 = _mm256_xor_si256(m2, _mm256_set1_epi8(-128));
-        m3 = _mm256_xor_si256(m3, _mm256_set1_epi8(-128));
+        m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+        m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+        m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+        m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
 #endif
         _mm256_storeu_si256((__m256i *)y[ib].qs + 4, m0);
         _mm256_storeu_si256((__m256i *)y[ib].qs + 5, m1);
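The two hunks above swap the signed-to-unsigned conversion in HelperQ80R4 from `xor 0x80` (offset +128) to `add 127`, matching the `-127.f` bias correction the r4/r8 kernels apply (see the `info.store` line in `mul_mat_q8_0_r4_q8_1` above). A scalar check of the two mappings, as a sketch:

    #include <cassert>
    #include <cstdint>

    int main() {
        // q8_0 quantizes with d = amax/127, so quants stay in [-127, 127]
        for (int v = -127; v <= 127; ++v) {
            int8_t x = (int8_t)v;
            uint8_t via_xor = (uint8_t)x ^ 0x80;   // flips the sign bit: x + 128
            uint8_t via_add = (uint8_t)(x + 127);  // x + 127, never wraps here
            assert(via_xor == v + 128);
            assert(via_add == v + 127);
        }
        // With offset o, dot(x+o, y) = dot(x, y) + o*sum(y); the kernel
        // subtracts o*sum(y) afterwards, so o must be 127 to match -127.f.
        return 0;
    }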
@@ -14079,16 +14084,11 @@ struct FlashQKfp32 {
 #ifdef __aarch64__
     MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
 #else
-    if constexpr (D >= 128) {
 #ifdef HAVE_FANCY_SIMD
-        MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q8_0_1_Unpacker, nq);
+    MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q8_0_1_Unpacker, nq);
 #else
-        MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
+    MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
 #endif
-    } else {
-        // This does not actually work until we fix K-cache to be quantized to Q8_0_x4 only if D%128 == 0
-        MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
-    }
 #endif
 }
 else if constexpr (std::is_same_v<KHelper, HelperQ80R4<D, k_step>>) {