From df257a07e626ef6103c38b9faf89a421abba2b3a Mon Sep 17 00:00:00 2001 From: saood06 Date: Fri, 30 May 2025 02:28:27 -0500 Subject: [PATCH] Replace MLA-specific KV cache with the standard KV cache V2 (#473) * Fix save and restore when there is no V cache * Fix double print --- src/llama.cpp | 76 +++++++++++++++++++++------------------------------ 1 file changed, 31 insertions(+), 45 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 099d94a9..b8555677 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20699,41 +20699,21 @@ struct llama_context * llama_new_context_with_model( } if (memory_size_k + memory_size_v > 0) { - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), - ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), - ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); - } - } - - { - size_t memory_size_kv = 0; - size_t memory_size_kvt = 0; - - ggml_type kv_type = GGML_TYPE_COUNT; - ggml_type kvt_type = GGML_TYPE_COUNT; - - for (auto & kv : ctx->kv_self.k_l) { - memory_size_kv += ggml_nbytes(kv); - kv_type = kv->type; - } - - for (auto & kvt : ctx->kv_self.v_l) { - memory_size_kvt += ggml_nbytes(kvt); - kvt_type = kvt->type; - } - - if (memory_size_kv + memory_size_kvt > 0) { - if (cparams.mla_attn == 1 && !cparams.flash_attn) { + if (cparams.mla_attn != 0 && !cparams.flash_attn) { LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__, - (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f), - ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f), - ggml_type_name(kvt_type), (float)memory_size_kvt / (1024.0f * 1024.0f)); - } else { + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } else if (cparams.mla_attn != 0 && cparams.flash_attn) { LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T: not used\n", __func__, - (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f), - ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f)); - } + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f)); + } else { + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } } } @@ -21450,13 +21430,13 @@ struct llama_data_write { const struct llama_kv_cache & kv_self = ctx->kv_self; const struct llama_hparams & hparams = ctx->model.hparams; - // Misuse v_trans: 0 -> not transposed V cache - // 1 -> transposed V cache - // 2 -> no V cache (as it maybe the case with MLA) - const uint32_t v_trans = kv_self.v_l.empty() ? 2 : kv_self.v_trans ? 1 : 0; + // v_state: 0 -> not transposed V cache + // 1 -> transposed V cache + // 2 -> no V cache (as it may be the case with MLA) + const uint32_t v_state = kv_self.v_l.empty() ? 2 : kv_self.v_trans ? 1 : 0; const uint32_t n_layer = hparams.n_layer; - write(&v_trans, sizeof(v_trans)); + write(&v_state, sizeof(v_state)); write(&n_layer, sizeof(n_layer)); std::vector tmp_buf; @@ -21482,7 +21462,7 @@ struct llama_data_write { } } - if (kv_self.v_trans == 0) { + if (v_state == 0) { for (uint32_t il = 0; il < n_layer; ++il) { const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); @@ -21502,7 +21482,7 @@ struct llama_data_write { } } } - else if (v_trans == 1) { + else if (v_state == 1) { // When v is transposed, we also need the element size and get the element ranges from each row const uint32_t kv_size = kv_self.size; for (uint32_t il = 0; il < n_layer; ++il) { @@ -21748,9 +21728,13 @@ struct llama_data_read { bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) { const struct llama_hparams & hparams = ctx->model.hparams; struct llama_kv_cache & kv_self = ctx->kv_self; - uint32_t v_trans; + + // v_state: 0 -> not transposed V cache + // 1 -> transposed V cache + // 2 -> no V cache (as it may be the case with MLA) + uint32_t v_state; uint32_t n_layer; - read_to(&v_trans, sizeof(v_trans)); + read_to(&v_state, sizeof(v_state)); read_to(&n_layer, sizeof(n_layer)); if (n_layer != hparams.n_layer) { @@ -21761,7 +21745,9 @@ struct llama_data_read { LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size); return false; } - if (kv_self.v_trans != (bool) v_trans) { + + // Currently the only way there is no V cache (and thus v_state is 2) requires flash_attn, and flash_attn sets kv_self.v_trans to false + if (kv_self.v_trans != (v_state == 1)) { LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); return false; } @@ -21794,7 +21780,7 @@ struct llama_data_read { } } - if (kv_self.v_trans == 0) { + if (v_state == 0) { for (uint32_t il = 0; il < n_layer; ++il) { const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); @@ -21822,7 +21808,7 @@ struct llama_data_read { } } } - else if (v_trans == 1) { + else if (v_state == 1) { // For each layer, read the values for each cell (transposed) for (uint32_t il = 0; il < n_layer; ++il) { const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();