Option to use MLA without a transposed cache (#235)

The `-mla` command line option turns into an int from a bool.
mla = 0: use standard attention
mla = 1: use MLA with transposed cache
mla > 1: use MLA without transposed cache

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-02-27 16:40:49 +02:00
committed by GitHub
parent 51029edfdf
commit b762db7c92
6 changed files with 64 additions and 91 deletions

View File

@@ -111,14 +111,6 @@
#define LLAMA_MAX_LAYERS 512
#define LLAMA_MAX_EXPERTS 256 // DeepSeekV2
//
// === MLA cache
// If tou are desperate to reduce KV cache size, set MLA_USE_TRANSPOSED_CACHE to 0.
// TG perfornce will be slower (similar to no-MLA), but KV cache size will be cut to ~half.
// PP performance will be about the same as with MLA_USE_TRANSPOSED_CACHE = 1.
//
#define MLA_USE_TRANSPOSED_CACHE 1
//
// helpers
//
@@ -2518,7 +2510,7 @@ struct llama_cparams {
bool causal_attn;
bool offload_kqv;
bool flash_attn;
bool mla_attn;
int mla_attn;
bool fused_moe_up_gate;
enum llama_pooling_type pooling_type;
@@ -2695,9 +2687,7 @@ struct llama_kv_cache {
// DeepSeek MLA
std::vector<struct ggml_tensor *> kv_l;
#if MLA_USE_TRANSPOSED_CACHE
std::vector<struct ggml_tensor *> kvt_l;
#endif
std::vector<struct ggml_context *> ctxs;
std::vector<ggml_backend_buffer_t> bufs;
@@ -3175,9 +3165,9 @@ static bool llama_kv_cache_init(
// DeepSeek MLA
cache.kv_l.reserve(n_layer);
#if MLA_USE_TRANSPOSED_CACHE
cache.kvt_l.reserve(n_layer);
#endif
if (cparams.mla_attn == 1) {
cache.kvt_l.reserve(n_layer);
}
bool warn = true;
int n_mla = 0;
@@ -3208,25 +3198,18 @@ static bool llama_kv_cache_init(
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t kv_lora_rank = hparams.n_lora_kv;
LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
#if MLA_USE_TRANSPOSED_CACHE
ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size);
//ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
#else
ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_v, kv_lora_rank + n_embd_head_qk_rope, kv_size);
#endif
auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v;
ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size);
ggml_format_name(kv, "cache_kv_l%d", i);
cache.kv_l.push_back(kv);
#if MLA_USE_TRANSPOSED_CACHE
ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size);
ggml_format_name(kvt, "cache_kvt_l%d", i);
cache.kvt_l.push_back(kvt);
#endif
if (cparams.mla_attn == 1) {
ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size);
ggml_format_name(kvt, "cache_kvt_l%d", i);
cache.kvt_l.push_back(kvt);
}
n_mla++;
}
else {
//printf("Creating cache tensors:\n");
//printf("n_embd_k_gqa = %d, kv_size = %d, n_head = %d, n_head_kv = %d, n_embd_head_k = %d\n", (int)n_embd_k_gqa, (int)kv_size, (int)n_head, (int)n_head_kv, (int)n_embd_head_k);
//k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
ggml_format_name(k, "cache_k_l%d", i);
@@ -8940,7 +8923,7 @@ struct llm_build_context {
const int32_t n_ctx_orig;
const bool flash_attn;
const bool mla_attn;
const int mla_attn;
const bool fused_moe_up_gate;
const enum llama_pooling_type pooling_type;
@@ -13546,20 +13529,22 @@ struct llm_build_context {
if (lctx.cparams.mla_attn && model.layers[il].wk_b && model.layers[il].wv_b) {
#if MLA_USE_TRANSPOSED_CACHE
ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank,
ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), ggml_row_size(kv_self.kvt_l[il]->type, kv_head));
cb(kv_cache_trans_view, "kv_cache_trans_view", il);
ggml_tensor * kv_cache_trans;
// note: storing transposed c^KV in the transposed KV cache
ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view));
if (lctx.cparams.mla_attn == 1) {
ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank,
ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), ggml_row_size(kv_self.kvt_l[il]->type, kv_head));
cb(kv_cache_trans_view, "kv_cache_trans_view", il);
ggml_tensor * kv_cache_trans = ggml_view_2d(ctx0, kv_self.kvt_l[il],
n_kv, kv_lora_rank,
ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size),
0);
cb(kv_cache_trans, "kv_cache_trans", il);
#endif
// note: storing transposed c^KV in the transposed KV cache
ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view));
kv_cache_trans = ggml_view_2d(ctx0, kv_self.kvt_l[il],
n_kv, kv_lora_rank,
ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size),
0);
cb(kv_cache_trans, "kv_cache_trans", il);
}
ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0);
cb(kvr, "kvr", il);
@@ -13607,15 +13592,15 @@ struct llm_build_context {
cb(kq, "kq_soft_max_ext_perm", il);
}
#if !MLA_USE_TRANSPOSED_CACHE
ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
kv_lora_rank, n_kv,
ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
cb(kv_cache, "kv_cache_lora", il);
if (lctx.cparams.mla_attn > 1) {
ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
kv_lora_rank, n_kv,
ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
cb(kv_cache, "kv_cache_lora", il);
ggml_tensor * kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora));
cb(kv_cache_trans, "kv_cache_trans", il);
#endif
kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora));
cb(kv_cache_trans, "kv_cache_trans", il);
}
struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
cb(kqv_compressed, "kqv_compressed", il);
@@ -17658,7 +17643,7 @@ struct llama_context_params llama_context_default_params() {
/*.embeddings =*/ false,
/*.offload_kqv =*/ true,
/*.flash_attn =*/ false,
/*.mla_attn =*/ false,
/*.mla_attn =*/ 0,
/*.fused_moe_up_gate =*/ false,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
@@ -18140,18 +18125,23 @@ struct llama_context * llama_new_context_with_model(
kv_type = kv->type;
}
#if MLA_USE_TRANSPOSED_CACHE
for (auto & kvt : ctx->kv_self.kvt_l) {
memory_size_kvt += ggml_nbytes(kvt);
kvt_type = kvt->type;
}
#endif
if (memory_size_kv + memory_size_kvt > 0) {
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__,
(float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f),
ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f),
ggml_type_name(kvt_type), (float)memory_size_kvt / (1024.0f * 1024.0f));
if (cparams.mla_attn == 1) {
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__,
(float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f),
ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f),
ggml_type_name(kvt_type), (float)memory_size_kvt / (1024.0f * 1024.0f));
} else {
GGML_ASSERT(memory_size_kvt == 0);
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T: not used\n", __func__,
(float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f),
ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f));
}
}
}