q8_KV: be able to use it for K cache

This required quite a few fixes in ggml and llama.cpp:
* ggml: do not calculate row size as n/block_size*type_size. I had
  removed most of it when implementing the quants with per row scale,
  bit it was stull lurking in ggml_copy. Not sure if these were the last
  remnants of ggmil-style row sizes, or if there are still places left
* llama.cpp: get rid of the the 1d K cache assumption. Create and manage
  the K-cache as a 2D tensor so we can have per row meta data as needed
  by q8_KV.

Using q8_KV for K-cache results in non-negligible performance gains.
More details to follow, but for DeepSeek-Lite with MLA, we get
18% speedup for PP-8192 compared to q8_0 K-cache.
This commit is contained in:
Iwan Kawrakow
2025-02-17 15:19:05 +02:00
parent a4ffe2e69e
commit 0280b8d52b
4 changed files with 38 additions and 14 deletions

View File

@@ -3180,6 +3180,10 @@ static bool llama_kv_cache_init(
for (int i = 0; i < (int) n_layer; i++) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
const uint32_t n_head = hparams.n_head(i);
const uint32_t n_head_kv = hparams.n_head_kv(i);
const uint32_t n_embd_head_k= hparams.n_embd_head_k;
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
ggml_tensor * k;
@@ -3201,7 +3205,8 @@ static bool llama_kv_cache_init(
const uint32_t kv_lora_rank = hparams.n_lora_kv;
LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
#if MLA_USE_TRANSPOSED_CACHE
ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size);
//ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
#else
ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_v, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
#endif
@@ -3215,7 +3220,10 @@ static bool llama_kv_cache_init(
n_mla++;
}
else {
k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
//printf("Creating cache tensors:\n");
//printf("n_embd_k_gqa = %d, kv_size = %d, n_head = %d, n_head_kv = %d, n_embd_head_k = %d\n", (int)n_embd_k_gqa, (int)kv_size, (int)n_head, (int)n_head_kv, (int)n_embd_head_k);
//k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
@@ -8285,11 +8293,20 @@ static void llm_build_kv_store(
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
const int64_t n_head = hparams.n_head(il);
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_head_v = hparams.n_embd_head_v;
GGML_ASSERT(kv.size == n_ctx);
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
(ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
cb(k_cache_view, "k_cache_view", il);
//struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
// (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
//cb(k_cache_view, "k_cache_view", il);
auto k_row_size = ggml_row_size(kv.k_l[il]->type, n_embd_head_k);
ggml_tensor * k_cache_view = ggml_view_2d(ctx, kv.k_l[il], n_embd_head_k, n_tokens*n_head_kv,
k_row_size, k_row_size*n_head_kv*kv_head);
// note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
@@ -8708,7 +8725,7 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * k =
ggml_view_3d(ctx, kv.k_l[il],
n_embd_head_k, n_kv, n_head_kv,
ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv.k_l[il]->type, n_embd_head_k)*n_head_kv, //n_embd_k_gqa),
ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
0);
cb(k, "k", il);
@@ -13509,8 +13526,9 @@ struct llm_build_context {
ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0);
cb(kvr, "kvr", il);
ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*(kv_lora_rank + n_embd_head_qk_rope),
ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope)*kv_head);
auto row_size = ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope);
ggml_tensor * kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_tokens,
row_size, row_size*kv_head);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, kvr, kv_cache_view));
ggml_tensor * kv_cache = ggml_view_2d(ctx0, kv_self.kv_l[il],
kv_lora_rank + n_embd_head_qk_rope, n_kv,