From ac409b4c7fc1473327497e744b9cc8dbc8af54f9 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 12 Nov 2025 14:52:13 +0200 Subject: [PATCH] Graph reuse: add command line argument to turn it on --- common/common.cpp | 9 +++- common/common.h | 1 + include/llama.h | 1 + src/llama-context.h | 7 +++ src/llama-cparams.h | 1 + src/llama.cpp | 115 ++++++++++++++++++++++++++++++++++---------- 6 files changed, 108 insertions(+), 26 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index eddd5117..92fc1c9c 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1180,6 +1180,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.rope_cache = true; return true; } + if (arg == "-reuse" || arg == "--graph-reuse") { + params.graph_reuse = true; + return true; + } if (arg == "-ser" || arg == "--smart-expert-reduction") { CHECK_ARG auto values = string_split_pairs(argv[i], ','); @@ -2004,6 +2008,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disaable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" }); options.push_back({ "*", "-no-mmad, --no-fused-mul-multiadd", "disaable fused mul-multi_add (default: %s)", params.fused_mmad? "enabled" : "disabled" }); options.push_back({ "*", "-rcache, --rope-cache", "enable RoPE cache (default: %s)", params.rope_cache ? "enabled" : "disabled" }); + options.push_back({ "*", "-reuse, --graph-reuse", "enable graph reuse (default: %s)", params.graph_reuse ? 
"enabled" : "disabled" }); options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts}); options.push_back({ "*", "-mqkv, --merge-qkv,", "merge Q,K,V (default: %d)", params.merge_qkv}); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" @@ -2979,6 +2984,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.fused_up_gate = params.fused_up_gate; cparams.fused_mmad = params.fused_mmad; cparams.rope_cache = params.rope_cache; + cparams.graph_reuse = params.graph_reuse; cparams.min_experts = params.min_experts; cparams.thresh_experts = params.thresh_experts; cparams.only_active_experts = params.only_active_exps; @@ -4123,7 +4129,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "grouped_expert_routing: %s # default: false\n", params.grouped_expert_routing ? "true" : "false"); fprintf(stream, "fused_up_gate: %s # default: true\n", params.fused_up_gate ? "true" : "false"); fprintf(stream, "fused_mmad: %s # default: true\n", params.fused_mmad ? "true" : "false"); - fprintf(stream, "rope_cache: %s # default: true\n", params.rope_cache ? "true" : "false"); + fprintf(stream, "rope_cache: %s # default: false\n", params.rope_cache ? "true" : "false"); + fprintf(stream, "graph_reuse: %s # default: false\n", params.graph_reuse ? 
"true" : "false"); fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); diff --git a/common/common.h b/common/common.h index ee959618..f27eca97 100644 --- a/common/common.h +++ b/common/common.h @@ -254,6 +254,7 @@ struct gpt_params { bool fused_mmad = true; // fused mul+multi_add op bool grouped_expert_routing = false; // if to use grouped expert routing (BailingMoeV2 arch) bool rope_cache = false; // if to use RoPE cache (for supported models) + bool graph_reuse = false; // if to reuse compute graphs int min_experts = -1; float thresh_experts = 0; diff --git a/include/llama.h b/include/llama.h index fe7bca4d..7682951e 100644 --- a/include/llama.h +++ b/include/llama.h @@ -429,6 +429,7 @@ extern "C" { bool fused_up_gate; // whether to use fused up/gate op [EXPERIMENTAL] bool fused_mmad; // whether to use fused mul+multi_add op [EXPERIMENTAL] bool rope_cache; // whether to use RoPE cache [EXPERIMENTAL] + bool graph_reuse; // whether to reuse graphs when possible [EXPERIMENTAL] int min_experts; float thresh_experts; bool only_active_experts; diff --git a/src/llama-context.h b/src/llama-context.h index b771d459..54764a3b 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -9,6 +9,7 @@ struct llama_model; #include #include #include +#include <memory> struct llama_kv_cell { llama_pos pos = -1; @@ -205,4 +206,10 @@ struct llama_context { ggml_backend_t ggml_backend_by_name(const char * name); + struct Prev; + std::unique_ptr<Prev> prev; + + void reset_scheduler(); + bool can_reuse_graph(const llama_batch & u_batch) const; + }; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index fcc107a3..0d118369 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -38,6 +38,7 @@ struct llama_cparams { bool fused_up_gate; bool fused_mmad; bool rope_cache; + bool graph_reuse; int min_experts; float thresh_experts; diff --git a/src/llama.cpp b/src/llama.cpp index 
8bb4e266..d5fab312 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -535,6 +535,42 @@ static size_t llama_get_device_memory(const llama_model & model, int device) { GGML_UNUSED(model); GGML_UNUSED(device); } + //llama_batch u_batch = { + // /* .n_tokens = */ (int32_t) n_tokens, + // /* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr, + // /* .embd = */ batch_all.embd ? batch_all.embd + cur_token*n_embd : nullptr, + // /* .pos = */ batch_all.pos ? batch_all.pos + cur_token : nullptr, + // /* .n_seq_id = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token : nullptr, + // /* .seq_id = */ batch_all.seq_id ? batch_all.seq_id + cur_token : nullptr, + // /* .logits = */ batch_all.logits ? batch_all.logits + cur_token : nullptr, + // /* .all_pos_0 = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1, + // /* .all_pos_1 = */ batch_all.all_pos_1, + // /* .all_seq_id = */ batch_all.all_seq_id, + //}; + +struct llama_context::Prev { + int all_pos_0, all_pos_1, all_seq_id; + int n_outputs; + int n_kv; + ggml_cgraph * graph; +}; + +void llama_context::reset_scheduler() { + ggml_backend_sched_reset(sched); + prev.reset(); +} + +bool llama_context::can_reuse_graph(const llama_batch & u_batch) const { + if (!prev || !prev->graph) return false; + if (u_batch.n_tokens > 1) return false; + if (u_batch.embd) return false; + if (!cparams.graph_reuse) return false; + return u_batch.all_pos_0 == prev->all_pos_0 && + u_batch.all_pos_1 == prev->all_pos_1 && + u_batch.all_seq_id == prev->all_seq_id && + kv_self.n == prev->n_kv && + n_outputs == prev->n_outputs; +} llama_context::llama_context(const llama_model & model) : model(model) , sampling(llama_n_vocab(&model)) , t_start_us(model.t_start_us) , t_load_us(model.t_load_us) {} @@ -2876,27 +2912,60 @@ static int llama_decode_internal( printf("prelude(...): %d us\n", int(tim2-tim1)); #endif + + //if (n_tokens_all == 1) { + // printf("================= %s\n", __func__); + // printf(" all_pos_0 = %d, 
all_pos_1 = %d, all_seq_id = %d\n", batch_all.all_pos_0, batch_all.all_pos_1, batch_all.all_seq_id); + // printf(" embd = %p, logits = %p, token = %p\n", (const void *)batch_all.embd, (const void *)batch_all.logits, (const void *)batch_all.token); + // printf(" n_outputs = %d, kv_self.n = %d\n", n_outputs, kv_self.n); + //} //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif - ggml_backend_sched_reset(lctx.sched); - ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); + ggml_cgraph * gf = nullptr; + if (!lctx.can_reuse_graph(u_batch)) { + lctx.reset_scheduler(); + ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); #if IK_PRINT_TIMING - tim2 = ggml_time_us(); - printf("sched_reset(...): %d us\n", int(tim2-tim1)); + tim2 = ggml_time_us(); + printf("sched_reset(...): %d us\n", int(tim2-tim1)); #endif #if IK_PRINT_TIMING - tim1 = ggml_time_us(); + tim1 = ggml_time_us(); #endif - ggml_cgraph * gf = llm_build_context::llama_build_graph(lctx, u_batch, false); + gf = llm_build_context::llama_build_graph(lctx, u_batch, false); #if IK_PRINT_TIMING - tim2 = ggml_time_us(); - printf("build_graph(...): %d us\n", int(tim2-tim1)); + tim2 = ggml_time_us(); + printf("build_graph(...): %d us\n", int(tim2-tim1)); #endif +#if IK_PRINT_TIMING + tim1 = ggml_time_us(); +#endif + ggml_backend_sched_alloc_graph(lctx.sched, gf); +#if IK_PRINT_TIMING + tim2 = ggml_time_us(); + printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1)); +#endif +//struct llama_context::Prev { +// int all_pos_0, all_pos_1, all_seq_id; +// int n_outputs; +// int n_kv; +// ggml_cgraph * graph; +//}; + if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) { + lctx.prev = std::make_unique<llama_context::Prev>(llama_context::Prev{ + (int)u_batch.all_pos_0, (int)u_batch.all_pos_1, 
(int)u_batch.all_seq_id, + (int)lctx.n_outputs, (int)lctx.kv_self.n, gf}); + } + } else { + //printf("Reusing graph\n"); + gf = lctx.prev->graph; + } + // the output is always the last tensor in the graph struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2]; @@ -2921,15 +2990,6 @@ static int llama_decode_internal( } // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); -#if IK_PRINT_TIMING - tim1 = ggml_time_us(); -#endif - ggml_backend_sched_alloc_graph(lctx.sched, gf); -#if IK_PRINT_TIMING - tim2 = ggml_time_us(); - printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1)); -#endif - #if IK_PRINT_TIMING == 1 tim1 = ggml_time_us(); #endif @@ -3060,9 +3120,11 @@ static int llama_decode_internal( // Reset state for the next token before backend sync, to allow the CPU activities in the reset to // overlap with device computation. #if IK_PRINT_TIMING - auto tim1 = ggml_time_us(); + auto tim1 = ggml_time_us(); #endif - ggml_backend_sched_reset(lctx.sched); + if (!lctx.prev) { + lctx.reset_scheduler(); + } #if IK_PRINT_TIMING auto tim2 = ggml_time_us(); printf("sched_reset(...): %d us\n", int(tim2-tim1)); @@ -3158,7 +3220,7 @@ static int llama_encode_internal( batch.seq_id = seq_id_arr.data(); } - ggml_backend_sched_reset(lctx.sched); + lctx.reset_scheduler(); ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); ggml_cgraph * gf = llm_build_context::llama_build_graph(lctx, batch, false); @@ -3248,7 +3310,7 @@ static int llama_encode_internal( // Reset state for the next token before backend sync, to allow the CPU activities in the reset to // overlap with device computation. 
- ggml_backend_sched_reset(lctx.sched); + lctx.reset_scheduler(); return 0; } @@ -3462,7 +3524,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { #else // ggml_graph defrag - ggml_backend_sched_reset(lctx.sched); + lctx.reset_scheduler(); ggml_cgraph * gf = llm_build_context::llama_build_graph_defrag(lctx, ids); @@ -3484,7 +3546,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) { } { - ggml_backend_sched_reset(lctx.sched); + lctx.reset_scheduler(); ggml_cgraph * gf = llm_build_context::llama_build_graph_k_shift(lctx); @@ -3510,7 +3572,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) { if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) { { - ggml_backend_sched_reset(lctx.sched); + lctx.reset_scheduler(); ggml_cgraph * gf = llm_build_context::llama_build_graph_s_copy(lctx); @@ -3553,7 +3615,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) { ggml_cgraph * gf = llm_build_context::llama_build_graph(lctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true); // initialize scheduler with the worst-case graph - ggml_backend_sched_reset(lctx.sched); + lctx.reset_scheduler(); if (!ggml_backend_sched_reserve(lctx.sched, gf)) { LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); } @@ -3840,6 +3902,7 @@ struct llama_context_params llama_context_default_params() { /*.fused_up_gate =*/ true, /*.fused_mmad =*/ true, /*.rope_cache =*/ false, + /*.graph_reuse =*/ false, /*.min_experts =*/ -1, /*.thtesh_experts =*/ 0.0f, /*.only_active_experts =*/ false, @@ -4144,6 +4207,7 @@ struct llama_context * llama_new_context_with_model( cparams.fused_up_gate = params.fused_up_gate; cparams.fused_mmad = params.fused_mmad; cparams.rope_cache = params.rope_cache; + cparams.graph_reuse = params.graph_reuse; cparams.min_experts = params.min_experts; cparams.thresh_experts = params.thresh_experts; cparams.cuda_params = params.cuda_params; @@ -4230,6 +4294,7 
@@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: fused_up_gate = %d\n", __func__, cparams.fused_up_gate); LLAMA_LOG_INFO("%s: fused_mmad = %d\n", __func__, cparams.fused_mmad); LLAMA_LOG_INFO("%s: rope_cache = %d\n", __func__, cparams.rope_cache); + LLAMA_LOG_INFO("%s: graph_reuse = %d\n", __func__, cparams.graph_reuse); LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);