diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index e10e6cfe..097b1bc5 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -6492,9 +6492,14 @@ ggml_cgraph * llm_build_context::build_deepseek2() { auto kv_f32_size = model.layers[il].wkv_b->ne[1] * kv_cache_nope->ne[1] * sizeof(float) / (1024*1024); int n_max_head = n_head; if (cparams.attn_max_batch > 0 && kv_f32_size > cparams.attn_max_batch) { - while (n_max_head%2 == 0 && kv_f32_size > cparams.attn_max_batch) { - n_max_head /= 2; kv_f32_size /= 2; + n_max_head = 1; + for (int niter = 2; niter < n_head; ++niter) { + if (n_head % niter == 0 && kv_f32_size/(n_head/niter) <= cparams.attn_max_batch) { + n_max_head = n_head/niter; + break; + } } + //printf("Using n_max_head = %d -> kv_f32_size = %zu\n", n_max_head, kv_f32_size/(n_head/n_max_head)); } GGML_ASSERT(n_head % n_max_head == 0); @@ -6575,6 +6580,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() { } else { cur = ggml_concat(ctx0, cur, ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens), 0); } + ggml_build_forward_expand(gf, cur); }