From a4ffe2e69e86c97b0d854ce2bafcac71483e3f71 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 17 Feb 2025 12:50:44 +0200 Subject: [PATCH] q8_KV: AVX2 gemm/gemv We get 254 t/s for L3-8B vs 194 t/s for q8_0 without rtr. --- ggml/src/iqk/iqk_mul_mat.cpp | 57 ++++++++++++++++++++++++++++-------- 1 file changed, 45 insertions(+), 12 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index aec76b6a..48355e32 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6175,6 +6175,9 @@ template static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%8 == 0); GGML_ASSERT(n%32 == 0); +#ifndef HAVE_FANCY_SIMD + auto m1 = _mm256_set1_epi16(1); +#endif __m256i qx[4]; __m256i sx[4]; __m256i acc[nrc_y] = {}; @@ -6195,7 +6198,12 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI } for (int iy = 0; iy < nrc_y; ++iy) { for (int j = 0; j < 4; ++j) { +#ifdef HAVE_FANCY_SIMD acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 4*i + j), qx[j])); +#else + auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 4*i + j), qx[j])); + acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot)); +#endif } } } @@ -6206,7 +6214,12 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI } for (int iy = 0; iy < nrc_y; ++iy) { for (int j = 0; j < 2; ++j) { +#ifdef HAVE_FANCY_SIMD acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j])); +#else + auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j])); + acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot)); +#endif } } } @@ -6214,7 +6227,12 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i); sx[0] = _mm256_sign_epi8(qx[0], qx[0]); for (int iy = 0; iy < nrc_y; ++iy) { +#ifdef HAVE_FANCY_SIMD acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0])); +#else + auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0])); + acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot)); +#endif } } for (int iy = 0; iy < nrc_y; ++iy) { @@ -6230,16 +6248,23 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf GGML_ASSERT(nrc_x%8 == 0); GGML_ASSERT(n%32 == 0); __m256i qx[4]; - //__m256i sx[4]; +#ifndef HAVE_FANCY_SIMD + __m256i sx[4]; + auto m1 = _mm256_set1_epi16(1); +#endif __m256i acc[nrc_y] = {}; float dy[nrc_y]; +#ifdef HAVE_FANCY_SIMD int32_t sy[nrc_y]; +#endif const int8_t * q8y[nrc_y]; for (int iy = 0; iy < nrc_y; ++iy) { auto dptr = (const float *)info.src1_row(iy); dy[iy] = dptr[0]; +#ifdef HAVE_FANCY_SIMD auto iptr = (const int32_t *)(dptr + 1); sy[iy] = -127*iptr[0]; +#endif q8y[iy] = (const int8_t *)(dptr + 2); } const int8_t * q8x[4]; @@ -6256,35 +6281,43 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]); auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]); auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]); - //qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]); - //qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]); - //qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]); - //qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]); +#ifdef HAVE_FANCY_SIMD qx[0] = _mm256_add_epi8(_mm256_unpacklo_epi64(t0, t1), _mm256_set1_epi8(127)); qx[1] = _mm256_add_epi8(_mm256_unpackhi_epi64(t0, t1), _mm256_set1_epi8(127)); qx[2] = _mm256_add_epi8(_mm256_unpacklo_epi64(t2, t3), _mm256_set1_epi8(127)); qx[3] = _mm256_add_epi8(_mm256_unpackhi_epi64(t2, t3), _mm256_set1_epi8(127)); +#else + qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]); + qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]); + qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]); + qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]); +#endif for (int iy = 0; iy < nrc_y; ++iy) { auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i); - //acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); - //acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); - //acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); - //acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); +#ifdef HAVE_FANCY_SIMD acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00)); acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55)); acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa)); acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff)); +#else + auto dot1 = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); + auto dot2 = _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); + auto dot3 = _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); + auto dot4 = _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); + auto dot12 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot1), _mm256_madd_epi16(m1, dot2)); + auto dot34 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot3), _mm256_madd_epi16(m1, dot4)); + acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(dot12, dot34)); +#endif } } auto scales_x = _mm_loadu_ps(dx); for (int iy = 0; iy < nrc_y; ++iy) { auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1)); +#ifdef HAVE_FANCY_SIMD sumi = _mm_add_epi32(sumi, _mm_set1_epi32(sy[iy])); +#endif auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy])); info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi))); - //auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[2*iy+0])); - //auto minus = _mm_mul_ps(scales_x, _mm_set1_ps(dy[2*iy+1])); - //info.store(ix, iy, _mm_fmadd_ps(scale, _mm_cvtepi32_ps(sumi), minus)); acc[iy] = _mm256_setzero_si256(); } }