Reduce compute buffer size for mla=3

This commit is contained in:
Kawrakow
2026-01-31 10:43:05 +00:00
parent 686fd1ebec
commit b85a2a50d5

View File

@@ -6492,9 +6492,14 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
auto kv_f32_size = model.layers[il].wkv_b->ne[1] * kv_cache_nope->ne[1] * sizeof(float) / (1024*1024);
int n_max_head = n_head;
if (cparams.attn_max_batch > 0 && kv_f32_size > cparams.attn_max_batch) {
while (n_max_head%2 == 0 && kv_f32_size > cparams.attn_max_batch) {
n_max_head /= 2; kv_f32_size /= 2;
n_max_head = 1;
for (int niter = 2; niter < n_head; ++niter) {
if (n_head % niter == 0 && kv_f32_size/(n_head/niter) <= cparams.attn_max_batch) {
n_max_head = n_head/niter;
break;
}
}
//printf("Using n_max_head = %d -> kv_f32_size = %zu\n", n_max_head, kv_f32_size/(n_head/n_max_head));
}
GGML_ASSERT(n_head % n_max_head == 0);
@@ -6575,6 +6580,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
} else {
cur = ggml_concat(ctx0, cur, ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens), 0);
}
ggml_build_forward_expand(gf, cur);
}