mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-29 02:41:47 +00:00
q8_KV: Better Zen4 gemm
We get 225.7 t/s for L3-8B. In comparison q8_0 without run-tinme-repacking is at 169 t/s.
This commit is contained in:
@@ -6230,13 +6230,16 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
|
||||
GGML_ASSERT(nrc_x%8 == 0);
|
||||
GGML_ASSERT(n%32 == 0);
|
||||
__m256i qx[4];
|
||||
__m256i sx[4];
|
||||
//__m256i sx[4];
|
||||
__m256i acc[nrc_y] = {};
|
||||
float dy[nrc_y];
|
||||
int32_t sy[nrc_y];
|
||||
const int8_t * q8y[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto dptr = (const float *)info.src1_row(iy);
|
||||
dy[iy] = dptr[0];
|
||||
auto iptr = (const int32_t *)(dptr + 1);
|
||||
sy[iy] = -127*iptr[0];
|
||||
q8y[iy] = (const int8_t *)(dptr + 2);
|
||||
}
|
||||
const int8_t * q8x[4];
|
||||
@@ -6253,23 +6256,35 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
|
||||
auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]);
|
||||
auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]);
|
||||
auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]);
|
||||
qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
|
||||
qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
|
||||
qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
|
||||
qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
|
||||
//qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
|
||||
//qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
|
||||
//qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
|
||||
//qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
|
||||
qx[0] = _mm256_add_epi8(_mm256_unpacklo_epi64(t0, t1), _mm256_set1_epi8(127));
|
||||
qx[1] = _mm256_add_epi8(_mm256_unpackhi_epi64(t0, t1), _mm256_set1_epi8(127));
|
||||
qx[2] = _mm256_add_epi8(_mm256_unpacklo_epi64(t2, t3), _mm256_set1_epi8(127));
|
||||
qx[3] = _mm256_add_epi8(_mm256_unpackhi_epi64(t2, t3), _mm256_set1_epi8(127));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
|
||||
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
|
||||
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
|
||||
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
|
||||
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
|
||||
//acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
|
||||
//acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
|
||||
//acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
|
||||
//acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
|
||||
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
|
||||
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
|
||||
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
|
||||
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
|
||||
}
|
||||
}
|
||||
auto scales_x = _mm_loadu_ps(dx);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1));
|
||||
sumi = _mm_add_epi32(sumi, _mm_set1_epi32(sy[iy]));
|
||||
auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy]));
|
||||
info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi)));
|
||||
//auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[2*iy+0]));
|
||||
//auto minus = _mm_mul_ps(scales_x, _mm_set1_ps(dy[2*iy+1]));
|
||||
//info.store(ix, iy, _mm_fmadd_ps(scale, _mm_cvtepi32_ps(sumi), minus));
|
||||
acc[iy] = _mm256_setzero_si256();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3010,7 +3010,8 @@ void iqk_quantize_row_q8_KV(const float * x, void * vy, int64_t k) {
|
||||
_mm256_storeu_si256((__m256i *)q8, i0);
|
||||
q8 += 32;
|
||||
}
|
||||
dptr[1] = dptr[0] * hsum_i32_8(isum);
|
||||
auto iptr = (int32_t *)(dptr + 1);
|
||||
iptr[0] = hsum_i32_8(isum);
|
||||
#elif defined __ARM_NEON
|
||||
int32x4_t ival[8];
|
||||
auto vmax = vdupq_n_f32(0.f);
|
||||
@@ -3037,7 +3038,8 @@ void iqk_quantize_row_q8_KV(const float * x, void * vy, int64_t k) {
|
||||
q8 += 8;
|
||||
}
|
||||
}
|
||||
dptr[1] = dptr[0] * vaddvq_s32(isum);
|
||||
auto iptr = (int32_t *)(dptr + 1);
|
||||
iptr[0] = vaddvq_s32(isum);
|
||||
#else
|
||||
float amax = 0;
|
||||
for (int j = 0; j < k; ++j) {
|
||||
@@ -3056,7 +3058,8 @@ void iqk_quantize_row_q8_KV(const float * x, void * vy, int64_t k) {
|
||||
q8[i] = nearest_int(id*x[i]);
|
||||
isum += q8[i];
|
||||
}
|
||||
dptr[1] = dptr[0]*isum;
|
||||
auto iptr = (int32_t *)(dptr + 1);
|
||||
iptr[0] = isum;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user