From 6b9d1bf4b4d90639a08da2c054bd1ce17340f9a5 Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Fri, 14 Nov 2025 06:58:19 +0200
Subject: [PATCH] Graph reuse (#947)

* Add mainline compatible FA command line option

* Graph reuse: add command line argument to turn it on

* WIP

* This seems to work

* This is perhaps cleaner

* Change the command line option to -gr

---------

Co-authored-by: Iwan Kawrakow
---
 common/common.cpp                    |  10 ++-
 common/common.h                      |   1 +
 examples/llama-bench/llama-bench.cpp |  39 +++++++--
 include/llama.h                      |   1 +
 src/llama-build-context.cpp          |  20 +++--
 src/llama-build-context.h            |   2 +-
 src/llama-context.h                  |  15 ++++
 src/llama-cparams.h                  |   1 +
 src/llama.cpp                        | 123 +++++++++++++++++++++------
 9 files changed, 174 insertions(+), 38 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 979ca74e..d12e645b 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1135,6 +1135,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.flash_attn = false;
         return true;
     }
+
     if (arg == "-fa" || arg == "--flash-attn") {
         CHECK_ARG
         std::string next_arg{argv[i]};
@@ -1180,6 +1181,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.rope_cache = true;
         return true;
     }
+    if (arg == "-gr" || arg == "--graph-reuse") {
+        params.graph_reuse = true;
+        return true;
+    }
     if (arg == "-ser" || arg == "--smart-expert-reduction") {
         CHECK_ARG
         auto values = string_split_pairs(argv[i], ',');
@@ -2004,6 +2009,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disaable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" });
     options.push_back({ "*", "-no-mmad, --no-fused-mul-multiadd", "disaable fused mul-multi_add (default: %s)", params.fused_mmad? "enabled" : "disabled" });
     options.push_back({ "*", "-rcache, --rope-cache", "enable RoPE cache (default: %s)", params.rope_cache ? "enabled" : "disabled" });
+    options.push_back({ "*", "-gr, --graph-reuse", "enable graph reuse (default: %s)", params.graph_reuse ? "enabled" : "disabled" });
     options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts});
     options.push_back({ "*", "-mqkv, --merge-qkv,", "merge Q,K,V (default: %d)", params.merge_qkv});
     options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
@@ -2979,6 +2985,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.fused_up_gate = params.fused_up_gate;
     cparams.fused_mmad = params.fused_mmad;
     cparams.rope_cache = params.rope_cache;
+    cparams.graph_reuse = params.graph_reuse;
     cparams.min_experts = params.min_experts;
     cparams.thresh_experts = params.thresh_experts;
     cparams.only_active_experts = params.only_active_exps;
@@ -4123,7 +4130,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "grouped_expert_routing: %s # default: false\n", params.grouped_expert_routing ? "true" : "false");
     fprintf(stream, "fused_up_gate: %s # default: true\n", params.fused_up_gate ? "true" : "false");
     fprintf(stream, "fused_mmad: %s # default: true\n", params.fused_mmad ? "true" : "false");
-    fprintf(stream, "rope_cache: %s # default: true\n", params.rope_cache ? "true" : "false");
+    fprintf(stream, "rope_cache: %s # default: false\n", params.rope_cache ? "true" : "false");
+    fprintf(stream, "graph_reuse: %s # default: false\n", params.graph_reuse ? "true" : "false");
     fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts);
     fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
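Usage note (not part of the patch): with the argument handling above, any tool built on common.h picks up graph reuse from its command line and forwards it through llama_context_params_from_gpt_params(). A minimal, hypothetical sketch; gpt_params_parse() is assumed to be the usual common.h entry point and error handling is omitted:

    #include "common.h"
    #include "llama.h"

    int main(int argc, char ** argv) {
        gpt_params params;                                    // params.graph_reuse defaults to false
        if (!gpt_params_parse(argc, argv, params)) return 1;  // "-gr" / "--graph-reuse" flips it to true
        llama_context_params cparams = llama_context_params_from_gpt_params(params);
        // cparams.graph_reuse now mirrors params.graph_reuse, exactly as wired up in the hunk above
        (void)cparams;
        return 0;
    }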
"true" : "false"); fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); diff --git a/common/common.h b/common/common.h index ee959618..f27eca97 100644 --- a/common/common.h +++ b/common/common.h @@ -254,6 +254,7 @@ struct gpt_params { bool fused_mmad = true; // fused mul+multi_add op bool grouped_expert_routing = false; // if to use grouped expert routing (BailingMoeV2 arch) bool rope_cache = false; // if to use RoPE cache (for supported models) + bool graph_reuse = false; // if to reuse compute graphs int min_experts = -1; float thresh_experts = 0; diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 77327603..ed859e31 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -251,6 +251,7 @@ struct cmd_params { std::vector mla_attn; std::vector attn_max_batch; std::vector ser; + std::vector reuse; std::vector> tensor_split; std::vector use_mmap; std::vector embeddings; @@ -292,6 +293,7 @@ static const cmd_params cmd_params_defaults = { /* mla_attn */ {3}, /* attn_max_batch */ {0}, /* ser */ {{-1,0.0f}}, + /* reuse */ {false}, /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, @@ -339,6 +341,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -mla, --mla-attn <0|1|2> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str()); printf(" -amb, --attn-max-batch (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str()); printf(" -ser, --smart-expert-reduction (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str()); + printf(" -gr, --graph-reuse <0|1> (default: %s)\n", join(cmd_params_defaults.reuse, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" --numa (default: disabled)\n"); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); @@ -681,6 +684,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split(argv[i], split_delim); params.attn_max_batch.insert(params.attn_max_batch.end(), p.begin(), p.end()); + } else if (arg == "-gr" || arg == "--graph-reuse") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.reuse.insert(params.reuse.end(), p.begin(), p.end()); } else if (arg == "-ser" || arg == "--smart-expert-reduction") { if (++i >= argc) { invalid_param = true; @@ -852,6 +862,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; } if (params.attn_max_batch.empty()){ params.attn_max_batch = cmd_params_defaults.attn_max_batch; } + if (params.reuse.empty()) { params.reuse = cmd_params_defaults.reuse; } if (params.ser.empty()) { params.ser = cmd_params_defaults.ser; } if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } @@ -891,6 +902,7 @@ struct cmd_params_instance { bool flash_attn; int mla_attn; int attn_max_batch; + bool reuse; Ser ser; std::vector tensor_split; std::string cuda_params; @@ -950,6 +962,7 @@ struct cmd_params_instance { cparams.flash_attn = flash_attn; cparams.mla_attn = 
@@ -950,6 +962,7 @@ struct cmd_params_instance {
         cparams.flash_attn = flash_attn;
         cparams.mla_attn = mla_attn;
         cparams.attn_max_batch = attn_max_batch;
+        cparams.graph_reuse = reuse;
         cparams.fused_moe_up_gate = fmoe;
         cparams.grouped_expert_routing = ger;
         cparams.rope_cache = rcache;
@@ -984,6 +997,7 @@ static std::vector get_cmd_params_instances(const cmd_param
     for (const auto & fa : params.flash_attn)
     for (const auto & mla : params.mla_attn)
     for (const auto & amb : params.attn_max_batch)
+    for (const auto & reuse : params.reuse)
     for (const auto & ser : params.ser)
     for (const auto & nt : params.n_threads) {
         for (const auto & n_prompt : params.n_prompt) {
@@ -1008,6 +1022,7 @@ static std::vector get_cmd_params_instances(const cmd_param
                 /* .flash_attn = */ fa,
                 /* .mla_attn = */ mla,
                 /* .attn_max_b = */ amb,
+                /* .reuse = */ reuse,
                 /* .ser = */ ser,
                 /* .tensor_split = */ ts,
                 /* .cuda_params = */ params.cuda_params,
@@ -1048,6 +1063,7 @@ static std::vector get_cmd_params_instances(const cmd_param
                 /* .flash_attn = */ fa,
                 /* .mla_attn = */ mla,
                 /* .attn_max_b = */ amb,
+                /* .reuse = */ reuse,
                 /* .ser = */ ser,
                 /* .tensor_split = */ ts,
                 /* .cuda_params = */ params.cuda_params,
@@ -1088,6 +1104,7 @@ static std::vector get_cmd_params_instances(const cmd_param
                 /* .flash_attn = */ fa,
                 /* .mla_attn = */ mla,
                 /* .attn_max_b = */ amb,
+                /* .reuse = */ reuse,
                 /* .ser = */ ser,
                 /* .tensor_split = */ ts,
                 /* .cuda_params = */ params.cuda_params,
@@ -1128,6 +1145,7 @@ static std::vector get_cmd_params_instances(const cmd_param
                 /* .flash_attn = */ fa,
                 /* .mla_attn = */ mla,
                 /* .attn_max_b = */ amb,
+                /* .reuse = */ reuse,
                 /* .ser = */ ser,
                 /* .tensor_split = */ ts,
                 /* .cuda_params = */ params.cuda_params,
@@ -1179,6 +1197,7 @@ struct test {
     bool flash_attn;
     int mla_attn;
     int attn_max_batch;
+    bool reuse;
     Ser ser;
     std::vector tensor_split;
     std::string cuda_params;
@@ -1219,6 +1238,7 @@ struct test {
        flash_attn = inst.flash_attn;
        mla_attn = inst.mla_attn;
        attn_max_batch = inst.attn_max_batch;
+       reuse = inst.reuse;
        ser = inst.ser;
        tensor_split = inst.tensor_split;
        cuda_params = inst.cuda_params;
@@ -1321,7 +1341,7 @@ struct test {
            "n_batch", "n_ubatch",
            "n_threads", "type_k", "type_v",
            "n_gpu_layers", "split_mode",
-           "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser",
+           "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser", "reuse",
            "tensor_split", "use_mmap", "embeddings", "repack", "mqkv",
            "fused_moe", "grouped_er", "fused_up_gate", "use_thp", "ooae", "rcache",
            "n_prompt", "n_gen", "test_time",
@@ -1346,7 +1366,7 @@ struct test {
            field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
            field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "use_thp" ||
            field == "fused_moe" || field == "grouped_er" || field == "fused_up_gate" || field == "ooae" || field == "mqkv" ||
-           field == "rcache") {
+           field == "rcache" || field == "reuse") {
            return BOOL;
        }
        if (field == "avg_ts" || field == "stddev_ts") {
@@ -1387,7 +1407,7 @@ struct test {
            std::to_string(is_gen ? n_threads.first : n_threads.second), ggml_type_name(type_k), ggml_type_name(type_v),
            std::to_string(n_gpu_layers), split_mode_str(split_mode),
            std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
-           std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser),
+           std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser), std::to_string(reuse),
            tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack),
            std::to_string(fmoe), std::to_string(ger), std::to_string(rcache), std::to_string(no_fug), std::to_string(use_thp),
            std::to_string(no_ooae), std::to_string(mqkv),
@@ -1559,6 +1579,9 @@ struct markdown_printer : public printer {
        if (field == "attn_max_batch") {
            return 5;
        }
+       if (field == "reuse") {
+           return 2;
+       }
        if (field == "ser") {
            return 10;
        }
@@ -1623,7 +1646,10 @@ struct markdown_printer : public printer {
        if (field == "attn_max_batch") {
            return "amb";
        }
-       if (field == "attn_max_batch") {
+       if (field == "reuse") {
+           return "gr";
+       }
+       if (field == "ser") {
            return "ser";
        }
        if (field == "use_mmap") {
@@ -1702,9 +1728,12 @@ struct markdown_printer : public printer {
        if (params.mla_attn.size() > 1 || params.mla_attn != cmd_params_defaults.mla_attn) {
            fields.emplace_back("mla_attn");
        }
-       if (params.attn_max_batch.size() > 1 || params.attn_max_batch != cmd_params_defaults.mla_attn) {
+       if (params.attn_max_batch.size() > 1 || params.attn_max_batch != cmd_params_defaults.attn_max_batch) {
            fields.emplace_back("attn_max_batch");
        }
+       if (params.reuse.size() > 1 || params.reuse != cmd_params_defaults.reuse) {
+           fields.emplace_back("reuse");
+       }
        if (params.ser.size() > 1 || params.ser != cmd_params_defaults.ser) {
            fields.emplace_back("ser");
        }
diff --git a/include/llama.h b/include/llama.h
index fe7bca4d..7682951e 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -429,6 +429,7 @@ extern "C" {
        bool fused_up_gate; // whether to use fused up/gate op [EXPERIMENTAL]
        bool fused_mmad; // whether to use fused mul+multi_add op [EXPERIMENTAL]
        bool rope_cache; // whether to use RoPE cache [EXPERIMENTAL]
+       bool graph_reuse; // whether to reuse graphs when possible [EXPERIMENTAL]
        int min_experts;
        float thresh_experts;
        bool only_active_experts;
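Usage note (not part of the patch): for consumers of the public header only, the new field is enabled like the other experimental toggles. A minimal sketch, assuming a llama_model has already been loaded elsewhere:

    #include "llama.h"

    static struct llama_context * make_reusing_context(struct llama_model * model) {
        struct llama_context_params cparams = llama_context_default_params(); // graph_reuse = false by default
        cparams.graph_reuse = true;  // [EXPERIMENTAL] allow reusing the previous single-token decode graph
        return llama_new_context_with_model(model, cparams);
    }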
diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 96d39b24..f2b28a1a 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -469,6 +469,7 @@ ggml_tensor * llm_build_context::llm_build_inp_embd(
 }
 
 void llm_build_context::llm_build_kv_store(
+        struct llama_context & lctx,
         struct ggml_context * ctx,
         const llama_hparams & hparams,
         const llama_cparams & cparams,
@@ -494,29 +495,36 @@ void llm_build_context::llm_build_kv_store(
     //        (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
     //cb(k_cache_view, "k_cache_view", il);
 
+    GGML_ASSERT(2*il+1 < (int)lctx.cache_copies.size());
     auto k_row_size = ggml_row_size(kv.k_l[il]->type, n_embd_head_k);
     ggml_tensor * k_cache_view = ggml_view_2d(ctx, kv.k_l[il], n_embd_head_k, n_tokens*n_head_kv,
             k_row_size, k_row_size*n_head_kv*kv_head);
+    lctx.cache_copies[2*il+0].cpy = ggml_cpy(ctx, k_cur, k_cache_view);
+    lctx.cache_copies[2*il+0].step = k_row_size*n_head_kv;
+
     // note: storing RoPE-ed version of K in the KV cache
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
+    ggml_build_forward_expand(graph, lctx.cache_copies[2*il+0].cpy);
 
     struct ggml_tensor * v_cache_view = nullptr;
 
     if (cparams.flash_attn) {
         v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, (kv_head)*ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa));
+        lctx.cache_copies[2*il+1].step = ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa);
     } else {
         // note: the V cache is transposed when not using flash attention
         v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
                 (  n_ctx)*ggml_element_size(kv.v_l[il]),
                 (kv_head)*ggml_element_size(kv.v_l[il]));
+        lctx.cache_copies[2*il+1].step = ggml_element_size(kv.v_l[il]);
 
         v_cur = ggml_transpose(ctx, v_cur);
     }
     cb(v_cache_view, "v_cache_view", il);
 
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
+    lctx.cache_copies[2*il+1].cpy = ggml_cpy(ctx, v_cur, v_cache_view);
+    ggml_build_forward_expand(graph, lctx.cache_copies[2*il+1].cpy);
 }
 
 ggml_tensor * llm_build_context::llm_build_lora_mm(
@@ -1205,7 +1213,7 @@ ggml_tensor * llm_build_context::llm_build_kv(
     ggml_build_forward_expand(graph, k_cur);
     ggml_build_forward_expand(graph, v_cur);
 
-    llm_build_kv_store(ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il);
+    llm_build_kv_store(lctx, ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il);
 
     struct ggml_tensor * cur;
@@ -6045,7 +6053,9 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
             auto row_size = ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope);
             ggml_tensor * kv_cache_view = ggml_view_2d(ctx0, kv_self.k_l[il], kv_self.k_l[il]->ne[0], n_tokens,
                     row_size, row_size*kv_head);
-            ggml_build_forward_expand(gf, ggml_cpy(ctx0, kvr, kv_cache_view));
+            lctx.cache_copies[2*il+0].cpy = ggml_cpy(ctx0, kvr, kv_cache_view);
+            lctx.cache_copies[2*il+0].step = row_size;
+            ggml_build_forward_expand(gf, lctx.cache_copies[2*il+0].cpy);
             ggml_tensor * kv_cache = ggml_view_2d(ctx0, kv_self.k_l[il],
                     kv_lora_rank + n_embd_head_qk_rope, n_kv,
                     ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
@@ -7082,7 +7092,7 @@ ggml_cgraph * llm_build_context::build_t5_decoder() {
                 model.layers[il].wk, nullptr,
                 model.layers[il].wv, nullptr, 0, il);
 
-            llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+            llm_build_kv_store(lctx, ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
 
             struct ggml_tensor * k =
                 ggml_view_3d(ctx0, kv_self.k_l[il],
diff --git a/src/llama-build-context.h b/src/llama-build-context.h
index 8e9d7adb..a96a49d4 100644
--- a/src/llama-build-context.h
+++ b/src/llama-build-context.h
@@ -292,7 +292,7 @@ struct llm_build_context {
             llm_norm_type type, const llm_build_cb & cb, int il, float scale_eps = 1);
 
-    static void llm_build_kv_store(ggml_context * ctx, const llama_hparams & hparams,
+    static void llm_build_kv_store(llama_context & lctx, ggml_context * ctx, const llama_hparams & hparams,
             const llama_cparams & cparams,
             const llama_kv_cache & kv,
             ggml_cgraph * graph,
diff --git a/src/llama-context.h b/src/llama-context.h
index b771d459..bb21d880 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -9,6 +9,7 @@ struct llama_model;
 #include 
 #include 
 #include 
+#include <memory>
 
 struct llama_kv_cell {
     llama_pos pos = -1;
@@ -205,4 +206,18 @@ struct llama_context {
 
     ggml_backend_t ggml_backend_by_name(const char * name);
 
+    struct Prev;
+    std::unique_ptr<Prev> prev;
+
+    void reset_scheduler();
+    bool can_reuse_graph(const llama_batch & u_batch);
+
+    struct CacheCopy {
+        ggml_tensor * cpy = nullptr;
+        size_t step = 0;
+    };
+    std::vector<CacheCopy> cache_copies;
+
+    bool update_cache_copies();
+
 };
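Reading aid (not part of the patch): the members added to llama_context above are tied together in llama_decode_internal further down in this patch. A paraphrased sketch of that control flow, with the IK_PRINT_TIMING instrumentation stripped out:

    // inside llama_decode_internal(), once per ubatch
    ggml_cgraph * gf = nullptr;
    if (!lctx.can_reuse_graph(u_batch)) {              // also re-points cache_copies when it succeeds
        lctx.reset_scheduler();                        // drops lctx.prev together with the scheduler state
        gf = llm_build_context::llama_build_graph(lctx, u_batch, false);
        ggml_backend_sched_alloc_graph(lctx.sched, gf);
        if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
            lctx.prev = std::make_unique<llama_context::Prev>(
                llama_context::Prev{(int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n, gf});
        }
    } else {
        gf = lctx.prev->graph;                         // single-token decode: reuse last token's graph as-is
    }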
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index fcc107a3..0d118369 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -38,6 +38,7 @@ struct llama_cparams {
     bool fused_up_gate;
     bool fused_mmad;
     bool rope_cache;
+    bool graph_reuse;
     int min_experts;
     float thresh_experts;
diff --git a/src/llama.cpp b/src/llama.cpp
index 8bb4e266..bcf9433c 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -536,8 +536,57 @@ static size_t llama_get_device_memory(const llama_model & model, int device) {
     GGML_UNUSED(device);
 }
 
+struct llama_context::Prev {
+    int all_seq_id;
+    int n_outputs;
+    int n_kv;
+    ggml_cgraph * graph;
+};
+
+void llama_context::reset_scheduler() {
+    ggml_backend_sched_reset(sched);
+    prev.reset();
+}
+
+bool llama_context::can_reuse_graph(const llama_batch & u_batch) {
+    if (!prev || !prev->graph) return false;
+    if (u_batch.n_tokens > 1) return false;
+    if (u_batch.embd) return false;
+    if (!cparams.graph_reuse) return false;
+    return u_batch.all_seq_id == prev->all_seq_id &&
+           kv_self.head > 0 &&
+           kv_self.n == prev->n_kv &&
+           n_outputs == prev->n_outputs &&
+           update_cache_copies();
+}
+
+bool llama_context::update_cache_copies() {
+    int n_layer = cache_copies.size()/2;
+    if ((int)kv_self.k_l.size() != n_layer) return false;
+    if (!(kv_self.v_l.empty() || (int)kv_self.v_l.size() == n_layer)) return false;
+    for (int il = 0; il < n_layer; ++il) {
+        auto& c = cache_copies[2*il+0];
+        if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.k_l[il]) return false;
+        c.cpy->view_offs = kv_self.head*c.step;
+        c.cpy->src[1]->data = (char *)kv_self.k_l[il]->data + c.cpy->view_offs;
+        c.cpy->data = c.cpy->src[1]->data;
+    }
+    if (kv_self.v_l.empty()) return true;
+    for (int il = 0; il < n_layer; ++il) {
+        auto& c = cache_copies[2*il+1];
+        if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.v_l[il]) return false;
+        c.cpy->view_offs = kv_self.head*c.step;
+        c.cpy->src[1]->data = (char *)kv_self.v_l[il]->data + c.cpy->view_offs;
+        c.cpy->data = c.cpy->src[1]->data;
+    }
+    return true;
+}
+
 llama_context::llama_context(const llama_model & model)
-    : model(model) , sampling(llama_n_vocab(&model)) , t_start_us(model.t_start_us) , t_load_us(model.t_load_us) {}
+    : model(model) , sampling(llama_n_vocab(&model)) , t_start_us(model.t_start_us) , t_load_us(model.t_load_us) {
+    const auto & hparams = model.hparams;
+    cache_copies.resize(2*hparams.n_layer);
+}
 
 llama_context::~llama_context() {
     ggml_backend_sched_free(sched);
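Worked example (not part of the patch): the view_offs written by update_cache_copies() is the same byte offset llm_build_kv_store() would bake into a freshly built graph for the current head position. With illustrative numbers (F16 K cache, n_embd_head_k = 128, n_head_kv = 8; all values hypothetical):

    #include <cstdio>
    #include <cstddef>

    int main() {
        const size_t k_row_size = 128 * 2;         // ggml_row_size(F16, n_embd_head_k) = 256 bytes
        const size_t step       = k_row_size * 8;  // stored as cache_copies[2*il+0].step = 2048 bytes
        const size_t kv_head    = 100;             // current write position, i.e. kv_self.head
        // llm_build_kv_store() creates the K cache view at offset k_row_size*n_head_kv*kv_head;
        // update_cache_copies() re-points the cached copy to kv_self.head*step, the same value:
        std::printf("view_offs = %zu bytes\n", kv_head * step);  // prints 204800
        return 0;
    }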
@@ -2876,27 +2925,53 @@ static int llama_decode_internal(
     printf("prelude(...): %d us\n", int(tim2-tim1));
 #endif
+
+    //if (n_tokens_all == 1) {
+    //    printf("================= %s\n", __func__);
+    //    printf("    all_pos_0 = %d, all_pos_1 = %d, all_seq_id = %d\n", batch_all.all_pos_0, batch_all.all_pos_1, batch_all.all_seq_id);
+    //    printf("    embd = %p, logits = %p, token = %p\n", (const void *)batch_all.embd, (const void *)batch_all.logits, (const void *)batch_all.token);
+    //    printf("    n_outputs = %d, kv_self.n = %d\n", n_outputs, kv_self.n);
+    //}
     //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
 #if IK_PRINT_TIMING
     tim1 = ggml_time_us();
 #endif
-    ggml_backend_sched_reset(lctx.sched);
-    ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
+    ggml_cgraph * gf = nullptr;
+    if (!lctx.can_reuse_graph(u_batch)) {
+        lctx.reset_scheduler();
+        ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 #if IK_PRINT_TIMING
-    tim2 = ggml_time_us();
-    printf("sched_reset(...): %d us\n", int(tim2-tim1));
+        tim2 = ggml_time_us();
+        printf("sched_reset(...): %d us\n", int(tim2-tim1));
 #endif
 
 #if IK_PRINT_TIMING
-    tim1 = ggml_time_us();
+        tim1 = ggml_time_us();
 #endif
-    ggml_cgraph * gf = llm_build_context::llama_build_graph(lctx, u_batch, false);
+        gf = llm_build_context::llama_build_graph(lctx, u_batch, false);
 #if IK_PRINT_TIMING
-    tim2 = ggml_time_us();
-    printf("build_graph(...): %d us\n", int(tim2-tim1));
+        tim2 = ggml_time_us();
+        printf("build_graph(...): %d us\n", int(tim2-tim1));
 #endif
+#if IK_PRINT_TIMING
+        tim1 = ggml_time_us();
+#endif
+        ggml_backend_sched_alloc_graph(lctx.sched, gf);
+#if IK_PRINT_TIMING
+        tim2 = ggml_time_us();
+        printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1));
+#endif
+        if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
+            lctx.prev = std::make_unique<llama_context::Prev>(llama_context::Prev{
+                (int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n, gf});
+        }
+    } else {
+        //printf("Reusing graph\n");
+        gf = lctx.prev->graph;
+    }
+
     // the output is always the last tensor in the graph
     struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
     struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
@@ -2921,15 +2996,6 @@ static int llama_decode_internal(
     }
 
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
-#if IK_PRINT_TIMING
-    tim1 = ggml_time_us();
-#endif
-    ggml_backend_sched_alloc_graph(lctx.sched, gf);
-#if IK_PRINT_TIMING
-    tim2 = ggml_time_us();
-    printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1));
-#endif
-
 #if IK_PRINT_TIMING == 1
     tim1 = ggml_time_us();
 #endif
@@ -3060,9 +3126,11 @@ static int llama_decode_internal(
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
 #if IK_PRINT_TIMING
-    auto tim1 = ggml_time_us();
+    auto tim1 = ggml_time_us();
 #endif
-    ggml_backend_sched_reset(lctx.sched);
+    if (!lctx.prev) {
+        lctx.reset_scheduler();
+    }
 #if IK_PRINT_TIMING
     auto tim2 = ggml_time_us();
     printf("sched_reset(...): %d us\n", int(tim2-tim1));
@@ -3158,7 +3226,7 @@ static int llama_encode_internal(
         batch.seq_id = seq_id_arr.data();
     }
 
-    ggml_backend_sched_reset(lctx.sched);
+    lctx.reset_scheduler();
     ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
     ggml_cgraph * gf = llm_build_context::llama_build_graph(lctx, batch, false);
@@ -3248,7 +3316,7 @@ static int llama_encode_internal(
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
-    ggml_backend_sched_reset(lctx.sched);
+    lctx.reset_scheduler();
 
     return 0;
 }
@@ -3462,7 +3530,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
 #else
     // ggml_graph defrag
 
-    ggml_backend_sched_reset(lctx.sched);
+    lctx.reset_scheduler();
 
     ggml_cgraph * gf = llm_build_context::llama_build_graph_defrag(lctx, ids);
@@ -3484,7 +3552,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) {
     }
 
     {
-        ggml_backend_sched_reset(lctx.sched);
+        lctx.reset_scheduler();
 
         ggml_cgraph * gf = llm_build_context::llama_build_graph_k_shift(lctx);
@@ -3510,7 +3578,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) {
     if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
         {
-            ggml_backend_sched_reset(lctx.sched);
+            lctx.reset_scheduler();
 
             ggml_cgraph * gf = llm_build_context::llama_build_graph_s_copy(lctx);
@@ -3553,7 +3621,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) {
         ggml_cgraph * gf = llm_build_context::llama_build_graph(lctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
 
         // initialize scheduler with the worst-case graph
-        ggml_backend_sched_reset(lctx.sched);
+        lctx.reset_scheduler();
         if (!ggml_backend_sched_reserve(lctx.sched, gf)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
         }
@@ -3840,6 +3908,7 @@ struct llama_context_params llama_context_default_params() {
         /*.fused_up_gate =*/ true,
        /*.fused_mmad =*/ true,
        /*.rope_cache =*/ false,
+       /*.graph_reuse =*/ false,
        /*.min_experts =*/ -1,
        /*.thtesh_experts =*/ 0.0f,
        /*.only_active_experts =*/ false,
@@ -4144,6 +4213,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.fused_up_gate = params.fused_up_gate;
     cparams.fused_mmad = params.fused_mmad;
     cparams.rope_cache = params.rope_cache;
+    cparams.graph_reuse = params.graph_reuse;
     cparams.min_experts = params.min_experts;
     cparams.thresh_experts = params.thresh_experts;
     cparams.cuda_params = params.cuda_params;
@@ -4230,6 +4300,7 @@ struct llama_context * llama_new_context_with_model(
     LLAMA_LOG_INFO("%s: fused_up_gate = %d\n", __func__, cparams.fused_up_gate);
     LLAMA_LOG_INFO("%s: fused_mmad = %d\n", __func__, cparams.fused_mmad);
     LLAMA_LOG_INFO("%s: rope_cache = %d\n", __func__, cparams.rope_cache);
+    LLAMA_LOG_INFO("%s: graph_reuse = %d\n", __func__, cparams.graph_reuse);
     LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);