mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 11:21:56 +00:00
MLA: allow Q8_0 K-cache for MLA (#206)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -3201,11 +3201,7 @@ static bool llama_kv_cache_init(
|
|||||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||||
LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
|
LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
|
||||||
#if MLA_USE_TRANSPOSED_CACHE
|
#if MLA_USE_TRANSPOSED_CACHE
|
||||||
// TODO: The k-cache is contiguous and not permuted, so strictly speaking, it should be possible to quantize it.
|
ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
|
||||||
// Sadly, at this point something goes wrong with quantized k-cache, so for now we set the k-cache
|
|
||||||
// type to type_v, which is guaranteed to be f16 or bf16 without FA.
|
|
||||||
//ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
|
|
||||||
ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_v, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
|
|
||||||
#else
|
#else
|
||||||
ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_v, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
|
ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_v, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
|
||||||
#endif
|
#endif
|
||||||
@@ -13495,7 +13491,7 @@ struct llm_build_context {
|
|||||||
|
|
||||||
#if MLA_USE_TRANSPOSED_CACHE
|
#if MLA_USE_TRANSPOSED_CACHE
|
||||||
ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank,
|
ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank,
|
||||||
ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), ggml_row_size(kv_self.kv_l[il]->type, kv_head));
|
ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), ggml_row_size(kv_self.kvt_l[il]->type, kv_head));
|
||||||
cb(kv_cache_trans_view, "kv_cache_trans_view", il);
|
cb(kv_cache_trans_view, "kv_cache_trans_view", il);
|
||||||
|
|
||||||
// note: storing transposed c^KV in the transposed KV cache
|
// note: storing transposed c^KV in the transposed KV cache
|
||||||
@@ -13503,7 +13499,7 @@ struct llm_build_context {
|
|||||||
|
|
||||||
ggml_tensor * kv_cache_trans = ggml_view_2d(ctx0, kv_self.kvt_l[il],
|
ggml_tensor * kv_cache_trans = ggml_view_2d(ctx0, kv_self.kvt_l[il],
|
||||||
n_kv, kv_lora_rank,
|
n_kv, kv_lora_rank,
|
||||||
ggml_row_size(kv_self.kv_l[il]->type, kv_self.size),
|
ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size),
|
||||||
0);
|
0);
|
||||||
cb(kv_cache_trans, "kv_cache_trans", il);
|
cb(kv_cache_trans, "kv_cache_trans", il);
|
||||||
#endif
|
#endif
|
||||||
@@ -18047,21 +18043,26 @@ struct llama_context * llama_new_context_with_model(
|
|||||||
size_t memory_size_kv = 0;
|
size_t memory_size_kv = 0;
|
||||||
size_t memory_size_kvt = 0;
|
size_t memory_size_kvt = 0;
|
||||||
|
|
||||||
|
ggml_type kv_type = GGML_TYPE_COUNT;
|
||||||
|
ggml_type kvt_type = GGML_TYPE_COUNT;
|
||||||
|
|
||||||
for (auto & kv : ctx->kv_self.kv_l) {
|
for (auto & kv : ctx->kv_self.kv_l) {
|
||||||
memory_size_kv += ggml_nbytes(kv);
|
memory_size_kv += ggml_nbytes(kv);
|
||||||
|
kv_type = kv->type;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if MLA_USE_TRANSPOSED_CACHE
|
#if MLA_USE_TRANSPOSED_CACHE
|
||||||
for (auto & kvt : ctx->kv_self.kvt_l) {
|
for (auto & kvt : ctx->kv_self.kvt_l) {
|
||||||
memory_size_kvt += ggml_nbytes(kvt);
|
memory_size_kvt += ggml_nbytes(kvt);
|
||||||
|
kvt_type = kvt->type;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
if (memory_size_kv + memory_size_kvt > 0) {
|
if (memory_size_kv + memory_size_kvt > 0) {
|
||||||
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__,
|
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__,
|
||||||
(float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f),
|
(float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f),
|
||||||
ggml_type_name(type_v), (float)memory_size_kv / (1024.0f * 1024.0f),
|
ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f),
|
||||||
ggml_type_name(type_v), (float)memory_size_kvt / (1024.0f * 1024.0f));
|
ggml_type_name(kvt_type), (float)memory_size_kvt / (1024.0f * 1024.0f));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user