From 91db234fb5cac7195f517c5cd12abf2bd2c32e9a Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Thu, 13 Feb 2025 08:40:24 +0200
Subject: [PATCH] Warn user when disabling MLA

---
 src/llama.cpp | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/src/llama.cpp b/src/llama.cpp
index f9a59c79..52291cab 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3168,6 +3168,8 @@ static bool llama_kv_cache_init(
     cache.kv_l.reserve(n_layer);
     cache.kvt_l.reserve(n_layer);
 
+    bool warn = true;
+    int n_mla = 0;
     for (int i = 0; i < (int) n_layer; i++) {
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
@@ -3175,6 +3177,17 @@ static bool llama_kv_cache_init(
         struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
         ggml_tensor * k;
         ggml_tensor * v;
+        if (cparams.mla_attn) {
+            if (!model.layers[i].wk_b || !model.layers[i].wv_b) {
+                if (warn) {
+                    LLAMA_LOG_WARN("=======================================================================================\n");
+                    LLAMA_LOG_WARN("%s: missing MLA tensors => disabling MLA\n", __func__);
+                    LLAMA_LOG_WARN("%s: you need to reconvert your model in order to use MLA\n", __func__);
+                    LLAMA_LOG_WARN("=======================================================================================\n");
+                    warn = false;
+                }
+            }
+        }
         if (cparams.mla_attn && model.layers[i].wk_b && model.layers[i].wv_b) {
             // DeepSeek MLA
             const uint32_t n_embd_head_qk_rope = hparams.n_rot;
@@ -3186,6 +3199,7 @@
             ggml_format_name(kvt, "cache_kvt_l%d", i);
             cache.kv_l.push_back(kv);
             cache.kvt_l.push_back(kvt);
+            n_mla++;
         }
         else {
             k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
@@ -3196,6 +3210,11 @@
             cache.v_l.push_back(v);
         }
     }
+    if (cparams.mla_attn && n_mla < n_layer && n_mla > 0) {
+        LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d layers having MLA enabled\n", __func__, n_mla, int(n_layer));
+        LLAMA_LOG_ERROR("%s: bailing out\n", __func__);
+        GGML_ABORT("fatal error");
+    }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
     for (auto it : ctx_map) {