q8_KV: AVX2 gemm/gemv

We get 254 t/s for L3-8B vs 194 t/s for q8_0 without rtr.
This commit is contained in:
Iwan Kawrakow
2025-02-17 12:50:44 +02:00
parent 0d7885f081
commit a4ffe2e69e

View File

@@ -6175,6 +6175,9 @@ template <int nrc_y>
static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
GGML_ASSERT(n%32 == 0);
#ifndef HAVE_FANCY_SIMD
auto m1 = _mm256_set1_epi16(1);
#endif
__m256i qx[4];
__m256i sx[4];
__m256i acc[nrc_y] = {};
@@ -6195,7 +6198,12 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI
}
for (int iy = 0; iy < nrc_y; ++iy) {
for (int j = 0; j < 4; ++j) {
#ifdef HAVE_FANCY_SIMD
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 4*i + j), qx[j]));
#else
auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 4*i + j), qx[j]));
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot));
#endif
}
}
}
@@ -6206,7 +6214,12 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI
}
for (int iy = 0; iy < nrc_y; ++iy) {
for (int j = 0; j < 2; ++j) {
#ifdef HAVE_FANCY_SIMD
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
#else
auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot));
#endif
}
}
}
@@ -6214,7 +6227,12 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI
qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i);
sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
for (int iy = 0; iy < nrc_y; ++iy) {
#ifdef HAVE_FANCY_SIMD
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
#else
auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot));
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
@@ -6230,16 +6248,23 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
GGML_ASSERT(nrc_x%8 == 0);
GGML_ASSERT(n%32 == 0);
__m256i qx[4];
//__m256i sx[4];
#ifndef HAVE_FANCY_SIMD
__m256i sx[4];
auto m1 = _mm256_set1_epi16(1);
#endif
__m256i acc[nrc_y] = {};
float dy[nrc_y];
#ifdef HAVE_FANCY_SIMD
int32_t sy[nrc_y];
#endif
const int8_t * q8y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
auto dptr = (const float *)info.src1_row(iy);
dy[iy] = dptr[0];
#ifdef HAVE_FANCY_SIMD
auto iptr = (const int32_t *)(dptr + 1);
sy[iy] = -127*iptr[0];
#endif
q8y[iy] = (const int8_t *)(dptr + 2);
}
const int8_t * q8x[4];
@@ -6256,35 +6281,43 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]);
auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]);
auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]);
//qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
//qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
//qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
//qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
#ifdef HAVE_FANCY_SIMD
qx[0] = _mm256_add_epi8(_mm256_unpacklo_epi64(t0, t1), _mm256_set1_epi8(127));
qx[1] = _mm256_add_epi8(_mm256_unpackhi_epi64(t0, t1), _mm256_set1_epi8(127));
qx[2] = _mm256_add_epi8(_mm256_unpacklo_epi64(t2, t3), _mm256_set1_epi8(127));
qx[3] = _mm256_add_epi8(_mm256_unpackhi_epi64(t2, t3), _mm256_set1_epi8(127));
#else
qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
//acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
//acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
//acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
//acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
#ifdef HAVE_FANCY_SIMD
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
#else
auto dot1 = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
auto dot2 = _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
auto dot3 = _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
auto dot4 = _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
auto dot12 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot1), _mm256_madd_epi16(m1, dot2));
auto dot34 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot3), _mm256_madd_epi16(m1, dot4));
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(dot12, dot34));
#endif
}
}
auto scales_x = _mm_loadu_ps(dx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1));
#ifdef HAVE_FANCY_SIMD
sumi = _mm_add_epi32(sumi, _mm_set1_epi32(sy[iy]));
#endif
auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy]));
info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi)));
//auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[2*iy+0]));
//auto minus = _mm_mul_ps(scales_x, _mm_set1_ps(dy[2*iy+1]));
//info.store(ix, iy, _mm_fmadd_ps(scale, _mm_cvtepi32_ps(sumi), minus));
acc[iy] = _mm256_setzero_si256();
}
}