Mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-03-13 07:20:15 +00:00
Reduce compute buffer size for mla=3
@@ -6492,9 +6492,14 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
             auto kv_f32_size = model.layers[il].wkv_b->ne[1] * kv_cache_nope->ne[1] * sizeof(float) / (1024*1024);
             int n_max_head = n_head;
             if (cparams.attn_max_batch > 0 && kv_f32_size > cparams.attn_max_batch) {
-                while (n_max_head%2 == 0 && kv_f32_size > cparams.attn_max_batch) {
-                    n_max_head /= 2; kv_f32_size /= 2;
+                n_max_head = 1;
+                for (int niter = 2; niter < n_head; ++niter) {
+                    if (n_head % niter == 0 && kv_f32_size/(n_head/niter) <= cparams.attn_max_batch) {
+                        n_max_head = n_head/niter;
+                        break;
+                    }
                 }
+                //printf("Using n_max_head = %d -> kv_f32_size = %zu\n", n_max_head, kv_f32_size/(n_head/n_max_head));
             }
             GGML_ASSERT(n_head % n_max_head == 0);

@@ -6575,6 +6580,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
             } else {
                 cur = ggml_concat(ctx0, cur, ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens), 0);
             }
+            ggml_build_forward_expand(gf, cur);

         }

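Note (not part of the commit): the head-splitting heuristic introduced in the first hunk can be read in isolation. The sketch below mirrors the diff's expressions verbatim in a standalone helper; the function name pick_n_max_head, the size_t parameter types, and the numbers used in main() are illustrative assumptions, not code from the repository.

// Standalone sketch of the n_max_head selection above (illustrative only).
// kv_f32_size is the estimated f32 buffer size in MiB for all n_head heads;
// attn_max_batch is the user-configured cap (0 disables the limit).
#include <cstdio>
#include <cstddef>

static int pick_n_max_head(int n_head, size_t kv_f32_size, size_t attn_max_batch) {
    int n_max_head = n_head;            // default: process all heads in one pass
    if (attn_max_batch > 0 && kv_f32_size > attn_max_batch) {
        n_max_head = 1;                 // fallback: one head at a time
        for (int niter = 2; niter < n_head; ++niter) {
            // same acceptance test as the diff: the first niter that divides n_head
            // and whose reduced buffer estimate fits under the cap wins
            if (n_head % niter == 0 && kv_f32_size/(n_head/niter) <= attn_max_batch) {
                n_max_head = n_head/niter;
                break;
            }
        }
    }
    return n_max_head;                  // always a divisor of n_head
}

int main() {
    // hypothetical numbers: 128 heads, 2048 MiB estimated buffer, 512 MiB cap
    int n_max_head = pick_n_max_head(128, 2048, 512);
    printf("n_max_head = %d\n", n_max_head);   // prints a divisor of 128
    return 0;
}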