From 23e90dc3258b86a51a29529b2c7ad45eb35843dc Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Wed, 29 Jan 2025 09:55:10 +0200
Subject: [PATCH] Make q4_0_r4 work with tensor row sizes that are not a
 multiple of 128

... on Zen4.

Also fix q8_0 K-cache for head sizes that are not multiple of 128.
---
 ggml/src/iqk/iqk_mul_mat.cpp | 134 +++++++++++++++++------------------
 1 file changed, 67 insertions(+), 67 deletions(-)

diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index bee3ce45..308d0dca 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -2717,9 +2717,38 @@ static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
     Q8 q8(info);
     auto m4 = _mm512_set1_epi8(0xf);
     int nb = n / QK4_NL;
-    GGML_ASSERT(nb%4 == 0);
     __m512 acc[2*nrc_y] = {};
     __m512i qx[8];
+    auto prepare = [&qx, &m4] (const block_iq4_nl_r8& iq4l, const block_iq4_nl_r8& iq4h) {
+        auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l.d));
+        auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h.d));
+        auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+        for (int j = 0; j < 4; ++j) {
+            auto bits = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+j)),
+                    _mm256_loadu_si256((const __m256i *)iq4h.qs+j), 1);
+            qx[j+0] = _mm512_and_si512(bits, m4);
+            qx[j+4] = _mm512_and_si512(_mm512_srli_epi16(bits, 4), m4);
+        }
+        return scales;
+    };
+    auto dot = [&qx] (const int8_t * qy) {
+        auto y4l = _mm_loadu_si128((const __m128i*)qy+0);
+        auto y4h = _mm_loadu_si128((const __m128i*)qy+1);
+        auto y8l = MM256_SET_M128I(y4l, y4l);
+        auto y8h = MM256_SET_M128I(y4h, y4h);
+        auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
+        auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
+        auto sumi = _mm512_setzero_si512();
+        sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
+        return sumi;
+    };
     float d8[8*nrc_y];
     for (int ix = 0; ix < nrc_x; ix += 16) {
         const block_iq4_nl_r8 * iq4l = (const block_iq4_nl_r8 *)((const char *)vx + (ix+0)*bx);
@@ -2729,47 +2758,25 @@ static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
                 _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)));
             }
             for (int k = 0; k < 4; ++k) {
-                auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l[4*ib4+k].d));
-                auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h[4*ib4+k].d));
-                auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
-                auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0)),
-                                                _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0), 1);
-                auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1)),
-                                                _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1), 1);
-                auto bits3 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+2)),
-                                                _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+2), 1);
-                auto bits4 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+3)),
-                                                _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+3), 1);
-                qx[0] = _mm512_and_si512(bits1, m4);
-                qx[1] = _mm512_and_si512(bits2, m4);
-                qx[2] = _mm512_and_si512(bits3, m4);
-                qx[3] = _mm512_and_si512(bits4, m4);
-                qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4);
-                qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4);
-                qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits3, 4), m4);
-                qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits4, 4), m4);
+                auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]);
                 for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto y4l = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0);
-                    auto y4h = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1);
-                    auto y8l = MM256_SET_M128I(y4l, y4l);
-                    auto y8h = MM256_SET_M128I(y4h, y4h);
-                    auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
-                    auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
-                    auto sumi = _mm512_setzero_si512();
-                    sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
-                    sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
-                    sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
-                    sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
-                    sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
-                    sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
-                    sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
-                    sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
+                    auto sumi = dot(q8.y[iy][ib4].qs+32*k);
                     auto dy = _mm512_set1_ps(d8[8*iy+k]);
                     acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
                     acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
                 }
             }
         }
+        for (int ib = 4*(nb/4); ib < nb; ++ib) {
+            auto scales = prepare(iq4l[ib], iq4h[ib]);
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto qy = (const block_q8_1 *)q8.y[iy];
+                auto sumi = dot(qy[ib].qs);
+                auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d));
+                acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+                acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]);
+            }
+        }
         for (int iy = 0; iy < nrc_y; ++iy) {
             auto sum = _mm512_fmadd_ps(_mm512_set1_ps(-8.f), acc[2*iy+1], acc[2*iy+0]);
             acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
@@ -3107,7 +3114,7 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
                 auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[ib].qs, qy[ib].qs, qx);
                 auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
                 acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]);
-                acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(qy[ib].s), acc[1]);
+                acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[1]);
             }
         }
         info.store(ix, 0, _mm256_fmadd_ps(_mm256_set1_ps(-127.f), acc[1], acc[0]));
@@ -3140,22 +3147,20 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
                 }
             }
         }
-        if (4*(nb/4) < nb) {
-            for (int ib = 4*(nb/4); ib < nb; ++ib) {
-                auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[ib].d));
-                auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[ib].d));
-                auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
-                for (int j = 0; j < 8; ++j) {
-                    qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[ib].qs+j)),
-                                               _mm256_loadu_si256((const __m256i *)q8h[ib].qs+j), 1);
-                }
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto qy = (const block_q8_1 *)q8.y[iy];
-                    auto sumi = qx_r8_q8_dot_product(qx, qy[ib].qs);
-                    auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d));
-                    acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
-                    acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]);
-                }
+        for (int ib = 4*(nb/4); ib < nb; ++ib) {
+            auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[ib].d));
+            auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[ib].d));
+            auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+            for (int j = 0; j < 8; ++j) {
+                qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[ib].qs+j)),
+                                           _mm256_loadu_si256((const __m256i *)q8h[ib].qs+j), 1);
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto qy = (const block_q8_1 *)q8.y[iy];
+                auto sumi = qx_r8_q8_dot_product(qx, qy[ib].qs);
+                auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d));
+                acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+                acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]);
             }
         }
         for (int iy = 0; iy < nrc_y; ++iy) {
@@ -13217,10 +13222,10 @@ struct HelperQ80R4 : public BaseHelper {
             m2 = _mm256_unpacklo_epi64(t2, t3);
             m3 = _mm256_unpackhi_epi64(t2, t3);
 #ifdef HAVE_FANCY_SIMD
-            m0 = _mm256_xor_si256(m0, _mm256_set1_epi8(-128));
-            m1 = _mm256_xor_si256(m1, _mm256_set1_epi8(-128));
-            m2 = _mm256_xor_si256(m2, _mm256_set1_epi8(-128));
-            m3 = _mm256_xor_si256(m3, _mm256_set1_epi8(-128));
+            m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+            m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+            m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+            m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
 #endif
             _mm256_storeu_si256((__m256i *)y[ib].qs + 0, m0);
             _mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1);
@@ -13239,10 +13244,10 @@ struct HelperQ80R4 : public BaseHelper {
             m2 = _mm256_unpacklo_epi64(t2, t3);
             m3 = _mm256_unpackhi_epi64(t2, t3);
 #ifdef HAVE_FANCY_SIMD
-            m0 = _mm256_xor_si256(m0, _mm256_set1_epi8(-128));
-            m1 = _mm256_xor_si256(m1, _mm256_set1_epi8(-128));
-            m2 = _mm256_xor_si256(m2, _mm256_set1_epi8(-128));
-            m3 = _mm256_xor_si256(m3, _mm256_set1_epi8(-128));
+            m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+            m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+            m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+            m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
 #endif
             _mm256_storeu_si256((__m256i *)y[ib].qs + 4, m0);
             _mm256_storeu_si256((__m256i *)y[ib].qs + 5, m1);
@@ -14079,16 +14084,11 @@ struct FlashQKfp32 {
 #ifdef __aarch64__
     MAKE_FUNCS(mul_mat_qX_0_q8_0= 128) {
 #ifdef HAVE_FANCY_SIMD
-    MAKE_FUNCS(mul_mat_qX_1_q8_1_T>) {
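
Note on the kernel hunks above: both matrix-multiplication kernels get the same treatment. The hard GGML_ASSERT(nb%4 == 0) is dropped, the unrolled main loop still consumes four blocks per iteration, and a tail loop for (int ib = 4*(nb/4); ib < nb; ++ib) handles the leftover blocks, so row sizes that are not a multiple of 128 (nb not a multiple of 4) now work; in the q4_0_r4 kernel the per-block work is factored into the new prepare/dot lambdas so both loops can share it. The HelperQ80R4 hunks additionally switch the signed-to-unsigned bias from an XOR with -128 to an add of 127, which lines up with the -127.f compensation visible in mul_mat_q8_0_r4_q8_1. The code below is only a minimal scalar sketch of that main-loop-plus-tail structure; sum_blocks and one_block are hypothetical names used for illustration and are not part of this patch or the repository.

#include <cstdio>
#include <numeric>
#include <vector>

// Sketch: sum `nb` fixed-size blocks. The main loop consumes 4 blocks per
// iteration (a stand-in for the unrolled SIMD path); the tail loop picks up
// the remaining nb%4 blocks, so nb no longer has to be a multiple of 4.
static float sum_blocks(const std::vector<float>& x, int block_size) {
    const int nb = (int)x.size() / block_size;
    // Shared per-block helper, in the spirit of the prepare()/dot() lambdas above.
    auto one_block = [&](int ib) {
        const float * b = x.data() + ib*block_size;
        return std::accumulate(b, b + block_size, 0.0f);
    };
    float acc = 0.0f;
    for (int ib4 = 0; ib4 < nb/4; ++ib4) {        // main loop: groups of 4 blocks
        for (int k = 0; k < 4; ++k) acc += one_block(4*ib4 + k);
    }
    for (int ib = 4*(nb/4); ib < nb; ++ib) {      // tail: the blocks a hard
        acc += one_block(ib);                     // GGML_ASSERT(nb%4 == 0) would reject
    }
    return acc;
}

int main() {
    std::vector<float> x(7*32, 1.0f);             // 7 blocks of 32 values -> nb%4 != 0
    std::printf("%g\n", sum_blocks(x, 32));       // prints 224
    return 0;
}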