diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 707ba1a8..532143d5 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -121,7 +121,7 @@ void llm_build_context::free() { } ggml_cgraph * llm_build_context::build_k_shift() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(n_tokens), false); GGML_ASSERT(kv_self.size == n_ctx); @@ -189,7 +189,7 @@ ggml_cgraph * llm_build_context::build_k_shift() { } ggml_cgraph * llm_build_context::build_s_copy() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(n_tokens), false); const uint32_t qnext_state_slots = llama_kv_qnext_state_slots(kv_self); const bool has_qnext_state = qnext_state_slots > 0; @@ -225,7 +225,7 @@ ggml_cgraph * llm_build_context::build_s_copy() { } ggml_cgraph * llm_build_context::build_defrag(const std::vector & ids) { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(n_tokens), false); for (uint32_t i = 0; i < ids.size(); ++i) { const uint32_t id = ids[i]; @@ -1937,7 +1937,7 @@ static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml } ggml_cgraph * llm_build_context::build_llama() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(n_tokens), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -2153,7 +2153,7 @@ ggml_cgraph * llm_build_context::build_llama() { } ggml_cgraph * llm_build_context::build_mistral3() { - auto gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); + auto gf = ggml_new_graph_custom(ctx0, 
model.max_nodes(n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -2244,7 +2244,7 @@ ggml_cgraph * llm_build_context::build_mistral3() { } ggml_cgraph * llm_build_context::build_deci() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(n_tokens), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -2381,7 +2381,7 @@ ggml_cgraph * llm_build_context::build_deci() { } ggml_cgraph * llm_build_context::build_baichuan() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -2480,10 +2480,10 @@ ggml_cgraph * llm_build_context::build_baichuan() { return gf; } -static inline size_t llama_model_max_nodes(const llama_model & model) { return model.max_nodes(); } +static inline size_t llama_model_max_nodes(const llama_model & model, int n_tokens) { return model.max_nodes(n_tokens); } ggml_cgraph * llm_build_context::build_xverse() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -2573,7 +2573,7 @@ ggml_cgraph * llm_build_context::build_xverse() { } ggml_cgraph * llm_build_context::build_falcon() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; 
const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -2684,7 +2684,7 @@ ggml_cgraph * llm_build_context::build_falcon() { } ggml_cgraph * llm_build_context::build_grok() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -2824,7 +2824,7 @@ ggml_cgraph * llm_build_context::build_grok() { } ggml_cgraph * llm_build_context::build_dbrx() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -2946,7 +2946,7 @@ ggml_cgraph * llm_build_context::build_dbrx() { } ggml_cgraph * llm_build_context::build_starcoder() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -3038,7 +3038,7 @@ ggml_cgraph * llm_build_context::build_starcoder() { } ggml_cgraph * llm_build_context::build_refact() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -3118,7 +3118,7 @@ ggml_cgraph * llm_build_context::build_refact() { } ggml_cgraph * llm_build_context::build_bert() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, 
llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -3306,7 +3306,7 @@ ggml_cgraph * llm_build_context::build_bert() { } ggml_cgraph * llm_build_context::build_bloom() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -3392,7 +3392,7 @@ ggml_cgraph * llm_build_context::build_bloom() { } ggml_cgraph * llm_build_context::build_mpt() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -3632,7 +3632,7 @@ ggml_cgraph * llm_build_context::build_stablelm() { } ggml_cgraph * llm_build_context::build_seedoss() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -3682,7 +3682,7 @@ ggml_cgraph * llm_build_context::build_seedoss() { } ggml_cgraph * llm_build_context::build_step35() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); ggml_tensor * cur; auto inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); auto inp_pos = build_inp_pos(); @@ -3753,7 +3753,7 @@ ggml_cgraph 
* llm_build_context::build_step35() { } ggml_cgraph * llm_build_context::build_qwen() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -3856,7 +3856,7 @@ ggml_cgraph * llm_build_context::build_qwen() { } ggml_cgraph * llm_build_context::build_qwen2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -3947,7 +3947,7 @@ ggml_cgraph * llm_build_context::build_qwen2() { } ggml_cgraph * llm_build_context::build_qwen2vl() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; @@ -4048,7 +4048,7 @@ ggml_cgraph * llm_build_context::build_qwen2vl() { } ggml_cgraph * llm_build_context::build_qwen2moe() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -4178,7 +4178,7 @@ ggml_cgraph * llm_build_context::build_qwen2moe() { } ggml_cgraph * llm_build_context::build_qwen3() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = 
hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -4279,7 +4279,7 @@ ggml_cgraph * llm_build_context::build_qwen3() { } ggml_cgraph * llm_build_context::build_qwen3moe() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -4344,7 +4344,7 @@ ggml_cgraph * llm_build_context::build_qwen3moe() { ggml_cgraph * llm_build_context::build_qwen3next() { static constexpr int QWEN3NEXT_CHUNK_SIZE = 64; - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); delta_net delta(lctx, batch); @@ -4503,7 +4503,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() { ggml_cgraph * llm_build_context::build_qwen35moe() { static constexpr int QWEN3NEXT_CHUNK_SIZE = 64; - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); delta_net delta(lctx, batch); @@ -4638,7 +4638,7 @@ ggml_cgraph * llm_build_context::build_qwen35moe() { } ggml_cgraph * llm_build_context::build_qwen3vl() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_full = hparams.n_embd; // main embd + deepstack embds const size_t n_deepstack_layers = hparams.n_deepstack_layers; @@ -4707,7 +4707,7 @@ ggml_cgraph * llm_build_context::build_qwen3vl() { } ggml_cgraph * llm_build_context::build_qwen3vlmoe() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + 
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -4853,7 +4853,7 @@ ggml_cgraph * llm_build_context::build_qwen3vlmoe() { } ggml_cgraph * llm_build_context::build_phi2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -4968,7 +4968,7 @@ ggml_cgraph * llm_build_context::build_phi2() { } ggml_cgraph * llm_build_context::build_phi3() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -5183,7 +5183,7 @@ ggml_cgraph * llm_build_context::build_plamo() { } ggml_cgraph * llm_build_context::build_gpt2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -5276,7 +5276,7 @@ ggml_cgraph * llm_build_context::build_gpt2() { } ggml_cgraph * llm_build_context::build_codeshell() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -5375,7 +5375,7 @@ ggml_cgraph * llm_build_context::build_codeshell() { } 
ggml_cgraph * llm_build_context::build_orion() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -5465,7 +5465,7 @@ ggml_cgraph * llm_build_context::build_orion() { } ggml_cgraph * llm_build_context::build_internlm2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -5558,7 +5558,7 @@ ggml_cgraph * llm_build_context::build_internlm2() { // https://github.com/ggerganov/llama.cpp/issues/5276#issuecomment-1925774738 // based on the original build_llama() function ggml_cgraph * llm_build_context::build_minicpm() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -5675,7 +5675,7 @@ ggml_cgraph * llm_build_context::build_minicpm() { } ggml_cgraph * llm_build_context::build_gemma() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head_k = hparams.n_embd_head_k; @@ -5767,7 +5767,7 @@ ggml_cgraph * llm_build_context::build_gemma() { } ggml_cgraph * llm_build_context::build_gemma2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, 
llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head_k = hparams.n_embd_head_k; @@ -5884,7 +5884,7 @@ ggml_cgraph * llm_build_context::build_gemma2() { } ggml_cgraph * llm_build_context::build_gemma3() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); struct ggml_tensor * cur; struct ggml_tensor * inpL; @@ -6004,7 +6004,7 @@ ggml_cgraph * llm_build_context::build_gemma3() { } ggml_cgraph * llm_build_context::build_starcoder2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -6094,7 +6094,7 @@ ggml_cgraph * llm_build_context::build_starcoder2() { } ggml_cgraph * llm_build_context::build_mamba() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t d_model = n_embd; const int64_t d_conv = hparams.ssm_d_conv; @@ -6240,7 +6240,7 @@ ggml_cgraph * llm_build_context::build_mamba() { ggml_cgraph * llm_build_context::build_command_r() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -6366,7 +6366,7 @@ ggml_cgraph * llm_build_context::build_command_r() { // * removed bias // * removed MoE ggml_cgraph * llm_build_context::build_olmo() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct 
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -6484,7 +6484,7 @@ ggml_cgraph * llm_build_context::build_olmo() { } ggml_cgraph * llm_build_context::build_openelm() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -6596,7 +6596,7 @@ ggml_cgraph * llm_build_context::build_openelm() { } ggml_cgraph * llm_build_context::build_gptneox() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -6726,7 +6726,7 @@ ggml_cgraph * llm_build_context::build_gptneox() { } ggml_cgraph * llm_build_context::build_arctic() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -6848,7 +6848,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() { #else constexpr bool use_f32_attn_precision = false; #endif - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -7389,7 +7389,7 @@ 
ggml_cgraph * llm_build_context::build_deepseek2() { ggml_cgraph * llm_build_context::build_glm4_moe() { // create a new graph - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -7670,7 +7670,7 @@ struct ggml_tensor * llm_build_context::build_mtp_tail( } ggml_cgraph * llm_build_context::build_bitnet() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -7815,7 +7815,7 @@ ggml_cgraph * llm_build_context::build_bitnet() { } ggml_cgraph * llm_build_context::build_bitnet_158() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -7933,7 +7933,7 @@ ggml_cgraph * llm_build_context::build_bitnet_158() { } ggml_cgraph * llm_build_context::build_cohere2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -8010,7 +8010,7 @@ ggml_cgraph * llm_build_context::build_cohere2() { } ggml_cgraph * llm_build_context::build_t5_encoder() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = 
ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -8128,7 +8128,7 @@ ggml_cgraph * llm_build_context::build_t5_encoder() { } ggml_cgraph * llm_build_context::build_t5_decoder() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -8317,7 +8317,7 @@ ggml_cgraph * llm_build_context::build_t5_decoder() { } ggml_cgraph * llm_build_context::build_jais() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -8397,7 +8397,7 @@ ggml_cgraph * llm_build_context::build_jais() { } ggml_cgraph * llm_build_context::build_chatglm() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -8499,7 +8499,7 @@ ggml_cgraph * llm_build_context::build_chatglm() { } ggml_cgraph * llm_build_context::build_glm4() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -8629,7 +8629,7 @@ ggml_cgraph * llm_build_context::build_glm4() { 
} ggml_cgraph * llm_build_context::build_dots1() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; @@ -8770,7 +8770,7 @@ ggml_cgraph * llm_build_context::build_dots1() { } ggml_cgraph * llm_build_context::build_ernie4_5() { - struct ggml_cgraph* gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph* gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -8886,7 +8886,7 @@ ggml_cgraph * llm_build_context::build_ernie4_5() { } ggml_cgraph * llm_build_context::build_ernie4_5_moe() { - struct ggml_cgraph* gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph* gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -8955,7 +8955,7 @@ ggml_cgraph * llm_build_context::build_ernie4_5_moe() { } ggml_cgraph * llm_build_context::build_hunyuan_moe() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; @@ -9012,7 +9012,7 @@ ggml_cgraph * llm_build_context::build_hunyuan_moe() { } ggml_cgraph * llm_build_context::build_mimo2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); //const int64_t n_embd_head = hparams.n_embd_head_v; //GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -9084,7 +9084,7 @@ 
ggml_cgraph * llm_build_context::build_mimo2() { } ggml_cgraph * llm_build_context::build_openai_moe() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -9146,7 +9146,7 @@ ggml_cgraph * llm_build_context::build_openai_moe() { } ggml_cgraph * llm_build_context::build_bailingmoe2() { - ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; //const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -9275,7 +9275,7 @@ ggml_cgraph * llm_build_context::build_bailingmoe2() { } ggml_cgraph* llm_build_context::build_minimaxm2() { - ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); // GGML_ASSERT(n_embd_head == hparams.n_rot); this is wrong in case of minimax, head_dim = 128, n_rot = 64 @@ -9527,7 +9527,7 @@ ggml_cgraph* llm_build_context::build_minimaxm2() { } ggml_cgraph* llm_build_context::build_smollm3() { - ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); // GGML_ASSERT(n_embd_head == hparams.n_rot); this is wrong in case of minimax, head_dim = 128, n_rot = 64 diff --git a/src/llama-model.h b/src/llama-model.h index 02358fb5..0d27d937 100644 --- a/src/llama-model.h +++ b/src/llama-model.h 
@@ -422,8 +422,15 @@ struct llama_model { ~llama_model(); - // Not actually needed, but left in place for now - size_t max_nodes() const { return 65536 * 2; } + size_t max_nodes(int n_tokens) const { + auto n_tensors = tensors_by_name.size(); + if (split_mode == LLAMA_SPLIT_MODE_GRAPH && !devices.empty()) n_tensors *= devices.size(); + if (arch == LLM_ARCH_QWEN3NEXT || arch == LLM_ARCH_QWEN35MOE) { + return std::max<size_t>(n_tokens * 40, 32u * n_tensors); + } + return std::max<size_t>(1024, 8*n_tensors); + //return 65536 * 2; + } bool has_tensor_overrides() const { return tensor_overrides; diff --git a/src/llama.cpp b/src/llama.cpp index a21b3518..477a505f 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3792,7 +3792,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { // - x2 for keys and values //const uint32_t max_moves = model.max_nodes()/(6*n_layer); // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 - const uint32_t max_moves = (lctx.model.max_nodes() - 2*n_layer)/(6*n_layer); + const uint32_t max_moves = (lctx.model.max_nodes(1) - 2*n_layer)/(6*n_layer); // determine which KV cells to move where // @@ -5112,7 +5112,8 @@ struct llama_context * llama_init_from_model( } } - const size_t max_nodes = model->max_nodes(); + int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch); + const size_t max_nodes = model->max_nodes(n_tokens); // buffer used to store the computation graph and the tensor meta data ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false)); @@ -5137,7 +5138,6 @@ struct llama_context * llama_init_from_model( llama_repack_up_gate_exps(*ctx); // build worst-case graph - int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch); int n_past = cparams.n_ctx - n_tokens; llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph ggml_cgraph * 
gf = llm_build_context::llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);