mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-26 08:04:09 +00:00
Reduce compute buffer size for mla=3
This commit is contained in:
@@ -6492,9 +6492,14 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
|
||||
auto kv_f32_size = model.layers[il].wkv_b->ne[1] * kv_cache_nope->ne[1] * sizeof(float) / (1024*1024);
|
||||
int n_max_head = n_head;
|
||||
if (cparams.attn_max_batch > 0 && kv_f32_size > cparams.attn_max_batch) {
|
||||
while (n_max_head%2 == 0 && kv_f32_size > cparams.attn_max_batch) {
|
||||
n_max_head /= 2; kv_f32_size /= 2;
|
||||
n_max_head = 1;
|
||||
for (int niter = 2; niter < n_head; ++niter) {
|
||||
if (n_head % niter == 0 && kv_f32_size/(n_head/niter) <= cparams.attn_max_batch) {
|
||||
n_max_head = n_head/niter;
|
||||
break;
|
||||
}
|
||||
}
|
||||
//printf("Using n_max_head = %d -> kv_f32_size = %zu\n", n_max_head, kv_f32_size/(n_head/n_max_head));
|
||||
}
|
||||
GGML_ASSERT(n_head % n_max_head == 0);
|
||||
|
||||
@@ -6575,6 +6580,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
|
||||
} else {
|
||||
cur = ggml_concat(ctx0, cur, ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens), 0);
|
||||
}
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user