Q8_KV: 8-bit quantization type targeting the KV cache (#208)

* Adding q8_KV - Basics + AVX2 gemm/gemv

* q8_KV: Better AVX2 gemm

* q8_KV: Better Zen4 gemm

We get 225.7 t/s for L3-8B. In comparison, q8_0 without
run-time repacking is at 169 t/s.

* q8_KV: AVX2 gemm/gemv

We get 254 t/s for L3-8B vs 194 t/s for q8_0 without rtr.

* q8_KV: be able to use it for K cache

This required quite a few fixes in ggml and llama.cpp:
* ggml: do not calculate row size as n/block_size*type_size. I had
  removed most of this when implementing the quants with a per-row scale,
  but it was still lurking in ggml_copy. Not sure if these were the last
  remnants of ggml-style row sizes, or if there are still places left.
  (See the sketch after this list for how the row size of a
  per-row-scale type is computed.)
* llama.cpp: get rid of the 1D K-cache assumption. Create and manage
  the K-cache as a 2D tensor so we can have per-row metadata as needed
  by q8_KV.
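
A minimal sketch of that row-size computation for a per-row-scale type
like q8_KV (the helper name is hypothetical; the header layout matches
iqk_quantize_row_q8_KV below: one float scale plus one int32 row sum,
followed by one int8 per element):

    #include <cstddef>
    #include <cstdint>

    // Hypothetical helper: the q8_KV row size is a fixed header plus one
    // byte per element - it cannot be computed as n/block_size*type_size.
    size_t q8_KV_row_size(int64_t n_per_row) {
        return 2*sizeof(float) + n_per_row*sizeof(int8_t);
    }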

Using q8_KV for K-cache results in non-negligible performance gains.
More details to follow, but for DeepSeek-Lite with MLA, we get
18% speedup for PP-8192 compared to q8_0 K-cache.

* q8_KV: be able to use it for K cache in FA

* q8_KV: repack it for K*Q in FA

* q8_KV: slightly faster gemv on Zen4

* q8_KV: ARM_NEON

We get PP-512 = 167 t/s for L3-8B without storing interleaved data!
We do the interleaving on the fly, so I wonder if this
could be done for other quants as well (see the sketch below).
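
A minimal sketch of the on-the-fly interleaving idea on NEON (illustrative
names, mirroring the 4x4 transpose used in repack_q8_KV below): load a
16-byte chunk from each of four rows and transpose the tile in registers
right before the dot products, instead of storing a repacked copy.

    #include <arm_neon.h>

    // Illustrative sketch: after the transpose, out[j] holds the j-th
    // 4-byte chunk of each of the four input rows.
    static inline void interleave_4rows_16bytes(const int8_t * r0, const int8_t * r1,
                                                const int8_t * r2, const int8_t * r3,
                                                int8x16_t out[4]) {
        auto t01 = vtrnq_s32(vreinterpretq_s32_s8(vld1q_s8(r0)), vreinterpretq_s32_s8(vld1q_s8(r1)));
        auto t23 = vtrnq_s32(vreinterpretq_s32_s8(vld1q_s8(r2)), vreinterpretq_s32_s8(vld1q_s8(r3)));
        out[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(t01.val[0]), vreinterpretq_s64_s32(t23.val[0])));
        out[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(t01.val[1]), vreinterpretq_s64_s32(t23.val[1])));
        out[2] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(t01.val[0]), vreinterpretq_s64_s32(t23.val[0])));
        out[3] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(t01.val[1]), vreinterpretq_s64_s32(t23.val[1])));
    }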

* q8_KV: use it in FA on NEON

* q8_KV_r8 - repacked q8_KV

On Zen4 it is slower than q8_k_r8 (292 vs 370 t/s).
This makes no sense whatsoever, as the q8_KV_r8 GEMM is
basically the q8_k_r8 GEMM with the unnecessary block bookkeeping
removed (so one would think it would be faster).

* q8_KV_r8: don't use nrc_y = 16 on Zen4

This is faster - 350 t/s. Why?
It is much better than the 290 t/s we had before, but still slower
than the 370 t/s for q8_k_r8.

* q8_KV: nrc_y = 16 also doesn't pay off in FA

* Minor

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>

@@ -2967,6 +2967,103 @@ void iqk_quantize_row_q8_K128(const float * x, void * vy, int64_t k) {
}
#endif
}
// TODO: merge this with the above template
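// Row layout: one float scale d, one int32 sum of the row's quants, then k int8 quants.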
void iqk_quantize_row_q8_KV(const float * x, void * vy, int64_t k) {
assert(k % 32 == 0);
auto dptr = (float *)vy;
auto q8 = (int8_t *)(dptr + 2);
#ifdef __AVX2__
const __m256 signBit = _mm256_set1_ps(-0.0f);
const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
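// The packs instructions below interleave the two 128-bit lanes; perm restores
// the original byte order after packing.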
__m256 maxAbs = _mm256_setzero_ps();
for (int ib = 0; ib < k/8; ++ib) {
const __m256 v = _mm256_loadu_ps(x + 8*ib);
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps(signBit, v));
}
const float maxScalar = hmax_f32_8(maxAbs);
if (!maxScalar) {
dptr[0] = dptr[1] = 0;
std::memset(q8, 0, k*sizeof(int8_t));
return;
}
dptr[0] = maxScalar / 127.f;
auto mul = _mm256_set1_ps(1/dptr[0]);
auto isum = _mm256_setzero_si256();
for (int i = 0; i < k/32; i++) {
__m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 0));
__m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 8));
__m256 v2 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 16));
__m256 v3 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 24));
v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST);
v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST);
v2 = _mm256_round_ps(v2, _MM_ROUND_NEAREST);
v3 = _mm256_round_ps(v3, _MM_ROUND_NEAREST);
__m256i i0 = _mm256_cvtps_epi32(v0);
__m256i i1 = _mm256_cvtps_epi32(v1);
__m256i i2 = _mm256_cvtps_epi32(v2);
__m256i i3 = _mm256_cvtps_epi32(v3);
isum = _mm256_add_epi32(isum, _mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
i0 = _mm256_packs_epi32( i0, i1 );
i2 = _mm256_packs_epi32( i2, i3 );
i0 = _mm256_packs_epi16( i0, i2 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
_mm256_storeu_si256((__m256i *)q8, i0);
q8 += 32;
}
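// Store the row sum next to the scale. On AVX512 (HAVE_FANCY_SIMD) the repacked
// quants are shifted by +127 for unsigned dot products, and this sum is
// presumably what lets the GEMM subtract the resulting 127*sum bias.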
auto iptr = (int32_t *)(dptr + 1);
iptr[0] = hsum_i32_8(isum);
#elif defined __ARM_NEON
int32x4_t ival[8];
auto vmax = vdupq_n_f32(0.f);
for (int j = 0; j < k; j += 4) {
vmax = vmaxq_f32(vmax, vabsq_f32(vld1q_f32(x + j)));
}
auto smax = vmaxvq_f32(vmax);
if (!smax) {
dptr[0] = dptr[1] = 0;
std::memset(q8, 0, k*sizeof(int8_t));
return;
}
dptr[0] = smax/127;
auto vid = vdupq_n_f32(1/dptr[0]);
auto isum = vdupq_n_s32(0);
for (int ib = 0; ib < k/32; ++ib) {
auto xb = x + 32*ib;
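// note: this inner loop index k shadows the row-length parameter k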
for (int k = 0; k < 8; ++k) {
auto val = vld1q_f32(xb + 4*k);
ival[k] = vcvtnq_s32_f32(vmulq_f32(val, vid));
isum = vaddq_s32(isum, ival[k]);
}
for (int k = 0; k < 4; ++k) {
auto i16 = vcombine_s16(vmovn_s32(ival[2*k+0]), vmovn_s32(ival[2*k+1]));
vst1_s8(q8, vmovn_s16(i16));
q8 += 8;
}
}
auto iptr = (int32_t *)(dptr + 1);
iptr[0] = vaddvq_s32(isum);
#else
float amax = 0;
for (int j = 0; j < k; ++j) {
float ax = std::abs(x[j]);
amax = std::max(amax, ax);
}
if (!amax) {
dptr[0] = dptr[1] = 0;
std::memset(q8, 0, k*sizeof(int8_t));
return;
}
dptr[0] = amax/127;
float id = 1/dptr[0];
int isum = 0;
for (int i = 0; i < k; i++) {
q8[i] = nearest_int(id*x[i]);
isum += q8[i];
}
auto iptr = (int32_t *)(dptr + 1);
iptr[0] = isum;
#endif
}
}
void quantize_row_q8_K128(const float * x, void * vy, int64_t k) {
@@ -3886,7 +3983,7 @@ static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8
#ifdef HAVE_FANCY_SIMD
static void modify_q8_0_r8(int64_t k, char * cy) {
auto y = (block_q8_0_r8 *)cy;
int nb = k/(32*8);
for (int ib = 0; ib < nb; ++ib) {
for (int l = 0; l < 4; ++l) {
@@ -5412,6 +5509,150 @@ void vec_dot_q8_k_r8_q8_k(int n, float * s, size_t bs, const void * vx, size_t b
GGML_UNUSED(by);
}
//
// ========================================= q8_KV_r8
//
void quantize_row_q8_KV_r8_ref(const float * x, void * y, int64_t k) {
quantize_q8_KV_r8(x, y, 8, k/8, nullptr);
}
void quantize_row_q8_KV_r8(const float * x, void * y, int64_t k) {
quantize_q8_KV_r8(x, y, 8, k/8, nullptr);
}
static void repack_q8_KV(int nrows, int n_per_row, const char * cx, char * cy, [[maybe_unused]] bool online) {
GGML_ASSERT(nrows%8 == 0);
GGML_ASSERT(n_per_row%16 == 0);
auto row_size_x = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row);
auto row_size_y = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row);
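// Output layout per 8-row group: 8 float scales, then the quants in 128-byte
// tiles; each tile interleaves 4-byte chunks from all 8 rows and covers 16
// columns (dequantize_row_q8_KV_r8 below shows the exact indexing).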
const int8_t * x8[8];
#ifdef __ARM_NEON
int8x16x2_t m0, m1, m2, m3;
#endif
for (int row = 0; row < nrows; row += 8) {
auto dy = (float *)cy;
auto qy = (int8_t *)(dy + 8);
for (int k = 0; k < 8; ++k) {
auto dx = (const float *)(cx + k*row_size_x);
dy[k] = dx[0];
x8[k] = (const int8_t *)(dx + 2);
}
for (int ib = 0; ib < n_per_row/16; ++ib) {
#ifdef __AVX2__
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4]+ib), _mm_loadu_si128((const __m128i *)x8[0]+ib));
auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5]+ib), _mm_loadu_si128((const __m128i *)x8[1]+ib));
auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6]+ib), _mm_loadu_si128((const __m128i *)x8[2]+ib));
auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7]+ib), _mm_loadu_si128((const __m128i *)x8[3]+ib));
auto t0 = _mm256_unpacklo_epi32(m0, m1);
auto t1 = _mm256_unpacklo_epi32(m2, m3);
auto t2 = _mm256_unpackhi_epi32(m0, m1);
auto t3 = _mm256_unpackhi_epi32(m2, m3);
m0 = _mm256_unpacklo_epi64(t0, t1);
m1 = _mm256_unpackhi_epi64(t0, t1);
m2 = _mm256_unpacklo_epi64(t2, t3);
m3 = _mm256_unpackhi_epi64(t2, t3);
#ifdef HAVE_FANCY_SIMD
if (online) {
m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
}
#endif
_mm256_storeu_si256((__m256i *)qy + 4*ib+0, m0);
_mm256_storeu_si256((__m256i *)qy + 4*ib+1, m1);
_mm256_storeu_si256((__m256i *)qy + 4*ib+2, m2);
_mm256_storeu_si256((__m256i *)qy + 4*ib+3, m3);
#elif defined __ARM_NEON
m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib);
m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib);
m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib);
m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib);
auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
vst1q_s8_x2(qy + 0 + 128*ib, m0);
vst1q_s8_x2(qy + 32 + 128*ib, m1);
vst1q_s8_x2(qy + 64 + 128*ib, m2);
vst1q_s8_x2(qy + 96 + 128*ib, m3);
#else
// TODO: proper SIMD also for other architectures.
// Scalar fallback following the same tile layout as the SIMD paths above
// (see dequantize_row_q8_KV_r8 for the inverse indexing):
for (int l = 0; l < 4; ++l) {
for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
qy[128*ib + 32*l + 4*k + i] = x8[k][16*ib + 4*l + i];
}
}
#endif
}
cx += 8*row_size_x;
cy += online ? 8*row_size_x : 8*row_size_y;
// If we are run-time repacking (online = true) we don't want to change the
// stride, so we just leave some unused space at the end of each row.
}
}
#ifdef HAVE_FANCY_SIMD
static void modify_q8_KV_r8(int64_t k, char * cy) {
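// Shifts the repacked quants into unsigned range, presumably so the AVX512
// path can use unsigned*signed dot products; the per-row sums stored by
// quantize_row_q8_KV let the GEMM subtract the resulting 127*sum bias.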
int8_t * q8 = (int8_t *)(cy + 8*sizeof(float));
for (int j = 0; j < k; ++j) q8[j] += 127;
}
#endif
size_t quantize_q8_KV_r8(const float * src, void * dst, int64_t nrows, int64_t n_per_row, [[maybe_unused]] const float * imatrix) {
GGML_ASSERT(nrows%8 == 0);
GGML_ASSERT(n_per_row%16 == 0);
char * qcur = (char *)dst;
auto row_size_0 = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row);
auto row_size_1 = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row);
std::vector<char> qtmp(8*row_size_0);
for (int row = 0; row < nrows; row += 8) {
quantize_q8_KV(src, (void *)qtmp.data(), 8, n_per_row, imatrix);
repack_q8_KV(8, n_per_row, qtmp.data(), qcur, false);
qcur += 8*row_size_1;
src += 8*n_per_row;
}
return nrows*row_size_1;
}
void dequantize_row_q8_KV_r8(const void * vx, float * y, int64_t k) {
auto n_per_row = k/8;
float * y8[8];
for (int k = 0; k < 8; ++k) y8[k] = y + n_per_row*k;
auto dptr = (const float *)vx;
auto q8 = (const int8_t *)(dptr + 8);
for (int ib = 0; ib < n_per_row/16; ++ib) {
for (int k = 0; k < 8; ++k) {
for (int l = 0; l < 4; ++l) {
for (int i = 0; i < 4; ++i) y8[k][16*ib + 4*l + i] = dptr[k] * q8[128*ib + 32*l + 4*k + i];
}
}
}
}
void vec_dot_q8_KV_r8_q8_KV(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
#if GGML_USE_IQK_MULMAT
if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_KV_R8, vx, 0, GGML_TYPE_Q8_KV, vy, 0, s, 0, 0, 1)) {
return;
}
#endif
GGML_ASSERT(n%QK4_NL == 0);
GGML_ASSERT(nrc == 1);
GGML_UNUSED(bs);
GGML_UNUSED(bx);
GGML_UNUSED(by);
}
//
// ========================================= bf16_r4
//
@@ -6450,6 +6691,47 @@ void vec_dot_iq1_m_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t
GGML_UNUSED(by);
}
void quantize_row_q8_KV(const float * x, void * vy, int64_t k) {
iqk_quantize_row_q8_KV(x, vy, k);
}
void quantize_row_q8_KV_ref(const float * x, void * y, int64_t k) {
quantize_row_q8_KV(x, y, k);
}
size_t quantize_q8_KV(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
(void)imatrix;
auto row_size = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row);
auto q = (char *)dst;
for (int row = 0; row < nrows; ++row) {
quantize_row_q8_KV(src, q, n_per_row);
src += n_per_row;
q += row_size;
}
return row_size*nrows;
}
void dequantize_row_q8_KV(const void * x, float * y, int64_t k) {
auto dptr = (const float *)x;
float d = dptr[0];
auto q8 = (const int8_t *)(dptr + 2);
for (int j = 0; j < k; ++j) y[j] = d * q8[j];
}
void vec_dot_q8_KV_q8_KV(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
#if GGML_USE_IQK_MULMAT
if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_KV, vx, 0, GGML_TYPE_Q8_KV, vy, 0, s, 0, 0, 1)) {
return;
}
#endif
GGML_ASSERT(n%QK4_NL == 0);
GGML_ASSERT(nrc == 1);
GGML_UNUSED(bs);
GGML_UNUSED(bx);
GGML_UNUSED(by);
}
//================================================
namespace {
@@ -6472,8 +6754,9 @@ bool iqk_modify_tensor(struct ggml_tensor * tensor) {
{ GGML_TYPE_Q4_0_R8, {modify_q4_0_r8, 8} },
#endif
#ifdef HAVE_FANCY_SIMD
{ GGML_TYPE_Q8_0_R8, {modify_q8_0_r8, 8} },
{ GGML_TYPE_Q8_K_R8, {modify_q8_k_r8, 8} },
{ GGML_TYPE_Q8_KV_R8, {modify_q8_KV_r8, 8} },
#endif
};
auto it = k_mod_map.find(tensor->type);
@@ -6532,6 +6815,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) {
{ GGML_TYPE_Q6_0, { GGML_TYPE_Q6_0_R4, 4, (Repack::repack_func)repack_q6_0} },
{ GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R8, 8, (Repack::repack_func)repack_q8_0} },
{ GGML_TYPE_Q8_K, { GGML_TYPE_Q8_K_R8, 8, (Repack::repack_func)repack_q8_k} },
{ GGML_TYPE_Q8_KV, { GGML_TYPE_Q8_KV_R8, 8, (Repack::repack_func)repack_q8_KV} },
#ifdef __AVX512BF16__
{ GGML_TYPE_BF16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16<ggml_bf16_t>}},
{ GGML_TYPE_F16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16<ggml_half>} },