diff --git a/src/llama.cpp b/src/llama.cpp
index 2bdbf2a0..f9a59c79 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2685,7 +2685,6 @@ struct llama_kv_cache {
     std::vector<struct ggml_tensor *> v_l;
 
     // DeepSeek MLA
-    std::vector<struct ggml_tensor *> kr_l; // per layer
     std::vector<struct ggml_tensor *> kv_l;
     std::vector<struct ggml_tensor *> kvt_l;
 
@@ -3166,7 +3165,6 @@ static bool llama_kv_cache_init(
     cache.v_l.reserve(n_layer);
 
     // DeepSeek MLA
-    cache.kr_l.reserve(n_layer);
     cache.kv_l.reserve(n_layer);
     cache.kvt_l.reserve(n_layer);
 
@@ -3179,18 +3177,13 @@ static bool llama_kv_cache_init(
         ggml_tensor * v;
         if (cparams.mla_attn && model.layers[i].wk_b && model.layers[i].wv_b) {
             // DeepSeek MLA
-            //k = ggml_new_tensor_1d(ctx, type_k, 1);
-            //v = ggml_new_tensor_1d(ctx, type_v, 1);
             const uint32_t n_embd_head_qk_rope = hparams.n_rot;
             const uint32_t kv_lora_rank = hparams.n_lora_kv;
             LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
-            ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size);
-            ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
+            ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
             ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
-            ggml_format_name(kr, "cache_kr_l%d", i);
             ggml_format_name(kv, "cache_kv_l%d", i);
             ggml_format_name(kvt, "cache_kvt_l%d", i);
-            cache.kr_l.push_back(kr);
             cache.kv_l.push_back(kv);
             cache.kvt_l.push_back(kvt);
         }
@@ -13457,7 +13450,6 @@ struct llm_build_context {
                         0);
                 cb(kv_compressed, "kv_compressed", il);
 
-                //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
                 kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
                         model.layers[il].attn_kv_a_norm, NULL,
                         LLM_NORM_RMS, cb, il);
@@ -13465,13 +13457,6 @@ struct llm_build_context {
                 if (lctx.cparams.mla_attn && model.layers[il].wk_b && model.layers[il].wv_b) {
 
-                    struct ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*kv_lora_rank,
-                            ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)*kv_head);
-                    cb(kv_cache_view, "kv_cache_view", il);
-
-                    // note: storing c^KV in the KV cache
-                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view));
-
                     struct ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank,
                             ggml_row_size(kv_self.kv_l[il]->type, kv_self.size),
                             ggml_row_size(kv_self.kv_l[il]->type, kv_head));
                     cb(kv_cache_trans_view, "kv_cache_trans_view", il);
@@ -13479,13 +13464,6 @@ struct llm_build_context {
                     // note: storing transposed c^KV in the transposed KV cache
                     ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view));
 
-                    struct ggml_tensor * kv_cache =
-                        ggml_view_2d(ctx0, kv_self.kv_l[il],
-                                kv_lora_rank, n_kv,
-                                ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank),
-                                0);
-                    cb(kv_cache, "kv_cache", il);
-
                     struct ggml_tensor * kv_cache_trans =
                         ggml_view_2d(ctx0, kv_self.kvt_l[il],
                                 n_kv, kv_lora_rank,
@@ -13493,19 +13471,16 @@ struct llm_build_context {
                                 0);
                     cb(kv_cache_trans, "kv_cache_trans", il);
 
-                    struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope,
-                            ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head);
-                    cb(kr_cache_view, "kr_cache_view", il);
+                    ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0);
+                    cb(kvr, "kvr", il);
 
-                    // note: storing RoPE-ed version of K^R in the KV cache
-                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_rope, kr_cache_view));
-
-                    struct ggml_tensor * kr_cache =
-                        ggml_view_2d(ctx0, kv_self.kr_l[il],
-                                n_embd_head_qk_rope, n_kv,
-                                ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope),
-                                0);
-                    cb(kr_cache, "kr_cache", il);
+                    ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*(kv_lora_rank + n_embd_head_qk_rope),
+                            ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope)*kv_head);
+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, kvr, kv_cache_view));
+                    ggml_tensor * kv_cache = ggml_view_2d(ctx0, kv_self.kv_l[il],
+                            kv_lora_rank + n_embd_head_qk_rope, n_kv,
+                            ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
+                    cb(kv_cache, "kv_cache", il);
 
                     struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head,
                             ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope),
@@ -13518,33 +13493,20 @@ struct llm_build_context {
                     struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope);
                     cb(q_nope2, "q_nope2", il);
+                    ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0);
+                    cb(q, "q", il);
 
                     if (!pp_opt) {
-                        q_nope2 = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
-                        cb(q_nope2, "q_nope2_perm", il);
+                        q = ggml_permute(ctx0, q, 0, 2, 1, 3);
+                        cb(q, "q_perm", il);
                     }
 
-                    struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2);
-                    cb(kq_nope, "kq_nope", il);
-
-                    if (!pp_opt) {
-                        kq_nope = ggml_permute(ctx0, kq_nope, 0, 2, 1, 3);
-                        cb(kq_nope, "kq_nope_perm", il);
-                    }
-
-                    if (pp_opt) {
-                        q_rope = ggml_permute(ctx0, q_rope, 0, 2, 1, 3);
-                        cb(q_rope, "q_rope_perm", il);
-                    }
-                    struct ggml_tensor * kq_rope = ggml_mul_mat(ctx0, kr_cache, q_rope);
-                    cb(kq_rope, "kq_rope", il);
-
-                    if (!pp_opt) {
-                        kq_rope = ggml_permute(ctx0, kq_rope, 0, 2, 1, 3);
-                        cb(kq_rope, "kq_rope_perm", il);
-                    }
-
-                    struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_rope);
+                    ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
                     cb(kq, "kq", il);
 
+                    if (!pp_opt) {
+                        kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
+                        cb(kq, "kq_perm", il);
+                    }
+
                     kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
                     cb(kq, "kq_soft_max_ext", il);
@@ -13561,7 +13523,9 @@ struct llm_build_context {
                         cb(kqv_compressed, "kqv_compressed_perm", il);
                     }
 
-                    struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0);
+                    struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head,
+                            ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank),
+                            ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank)*n_embd_head_v, 0);
                     cb(wv_b, "wv_b", il);
 
                     struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed);
@@ -18033,14 +17997,9 @@ struct llama_context * llama_new_context_with_model(
         }
 
        {
-            size_t memory_size_kr = 0;
             size_t memory_size_kv = 0;
             size_t memory_size_kvt = 0;
 
-            for (auto & kr : ctx->kv_self.kr_l) {
-                memory_size_kr += ggml_nbytes(kr);
-            }
-
             for (auto & kv : ctx->kv_self.kv_l) {
                 memory_size_kv += ggml_nbytes(kv);
             }
@@ -18049,10 +18008,9 @@ struct llama_context * llama_new_context_with_model(
                 memory_size_kvt += ggml_nbytes(kvt);
             }
 
-            if (memory_size_kr + memory_size_kv + memory_size_kvt > 0) {
-                LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__,
-                    (float)(memory_size_kr + memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f),
-                    ggml_type_name(type_k), (float)memory_size_kr / (1024.0f * 1024.0f),
+            if (memory_size_kv + memory_size_kvt > 0) {
+                LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__,
+                    (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f),
                     ggml_type_name(type_v), (float)memory_size_kv / (1024.0f * 1024.0f),
                     ggml_type_name(type_v), (float)memory_size_kvt / (1024.0f * 1024.0f));
             }
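
For context, a minimal standalone sketch (not part of the patch) of the per-layer cache sizes implied by the merged layout above: each cache slot now stores one row of kv_lora_rank + n_embd_head_qk_rope cells in cache_kv_l (c^KV concatenated with the RoPE-ed k^R), plus a transposed c^KV copy in cache_kvt_l. The concrete numbers below (kv_lora_rank = 512, n_embd_head_qk_rope = 64, kv_size = 4096) and the F16 cell size are assumptions for illustration only; the real values come from hparams and the chosen cache type.

#include <cstdint>
#include <cstdio>

int main() {
    // assumed DeepSeek-V2-like dimensions (hparams.n_lora_kv, hparams.n_rot) and cache size
    const uint32_t kv_lora_rank        = 512;
    const uint32_t n_embd_head_qk_rope = 64;
    const uint32_t kv_size             = 4096; // number of cache slots
    const size_t   cell_bytes          = 2;    // assuming F16 cache cells

    // cache_kv_l: one (c^KV ++ k^R) row per slot, matching the merged ggml_new_tensor_1d above
    const size_t kv_bytes  = (size_t)(kv_lora_rank + n_embd_head_qk_rope) * kv_size * cell_bytes;
    // cache_kvt_l: transposed c^KV, kv_lora_rank * kv_size cells
    const size_t kvt_bytes = (size_t)kv_lora_rank * kv_size * cell_bytes;

    printf("per layer: c^KV+k^R = %.2f MiB, kv^T = %.2f MiB\n",
           kv_bytes  / (1024.0 * 1024.0),
           kvt_bytes / (1024.0 * 1024.0));
    return 0;
}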