diff --git a/src/llama.cpp b/src/llama.cpp
index 67a80eb2..e54dacf8 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3173,8 +3173,17 @@ static bool llama_kv_cache_init(
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
         struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
 
-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        ggml_tensor * k;
+        ggml_tensor * v;
+        if (cparams.mla_attn) {
+            k = ggml_new_tensor_1d(ctx, type_k, 1);
+            v = ggml_new_tensor_1d(ctx, type_v, 1);
+        }
+        else {
+            k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+            v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        }
+
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         cache.k_l.push_back(k);