Graph reuse: add command line argument to turn it on

This commit is contained in:
Iwan Kawrakow
2025-11-12 14:52:13 +02:00
parent 14e06e26a5
commit ac409b4c7f
6 changed files with 108 additions and 26 deletions

View File

@@ -1180,6 +1180,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.rope_cache = true;
return true;
}
if (arg == "-reuse" || arg == "--graph-reuse") {
params.graph_reuse = true;
return true;
}
if (arg == "-ser" || arg == "--smart-expert-reduction") {
CHECK_ARG
auto values = string_split_pairs<int,float>(argv[i], ',');
@@ -2004,6 +2008,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disaable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" });
options.push_back({ "*", "-no-mmad, --no-fused-mul-multiadd", "disaable fused mul-multi_add (default: %s)", params.fused_mmad? "enabled" : "disabled" });
options.push_back({ "*", "-rcache, --rope-cache", "enable RoPE cache (default: %s)", params.rope_cache ? "enabled" : "disabled" });
options.push_back({ "*", "-reuse, --graph-reuse", "enable graph reuse (default: %s)", params.graph_reuse ? "enabled" : "disabled" });
options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts});
options.push_back({ "*", "-mqkv, --merge-qkv,", "merge Q,K,V (default: %d)", params.merge_qkv});
options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
@@ -2979,6 +2984,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.fused_up_gate = params.fused_up_gate;
cparams.fused_mmad = params.fused_mmad;
cparams.rope_cache = params.rope_cache;
cparams.graph_reuse = params.graph_reuse;
cparams.min_experts = params.min_experts;
cparams.thresh_experts = params.thresh_experts;
cparams.only_active_experts = params.only_active_exps;
@@ -4123,7 +4129,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "grouped_expert_routing: %s # default: false\n", params.grouped_expert_routing ? "true" : "false");
fprintf(stream, "fused_up_gate: %s # default: true\n", params.fused_up_gate ? "true" : "false");
fprintf(stream, "fused_mmad: %s # default: true\n", params.fused_mmad ? "true" : "false");
fprintf(stream, "rope_cache: %s # default: true\n", params.rope_cache ? "true" : "false");
fprintf(stream, "rope_cache: %s # default: false\n", params.rope_cache ? "true" : "false");
fprintf(stream, "graph_reuse: %s # default: false\n", params.graph_reuse ? "true" : "false");
fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts);
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);

View File

@@ -254,6 +254,7 @@ struct gpt_params {
bool fused_mmad = true; // fused mul+multi_add op
bool grouped_expert_routing = false; // if to use grouped expert routing (BailingMoeV2 arch)
bool rope_cache = false; // if to use RoPE cache (for supported models)
bool graph_reuse = false; // if to reuse compute graphs
int min_experts = -1;
float thresh_experts = 0;

View File

@@ -429,6 +429,7 @@ extern "C" {
bool fused_up_gate; // whether to use fused up/gate op [EXPERIMENTAL]
bool fused_mmad; // whether to use fused mul+multi_add op [EXPERIMENTAL]
bool rope_cache; // whether to use RoPE cache [EXPERIMENTAL]
bool graph_reuse; // whether to reuse graphs when possible [EXPERIMENTAL]
int min_experts;
float thresh_experts;
bool only_active_experts;

View File

@@ -9,6 +9,7 @@ struct llama_model;
#include <vector>
#include <map>
#include <set>
#include <memory>
struct llama_kv_cell {
llama_pos pos = -1;
@@ -205,4 +206,10 @@ struct llama_context {
ggml_backend_t ggml_backend_by_name(const char * name);
struct Prev;
std::unique_ptr<Prev> prev;
void reset_scheduler();
bool can_reuse_graph(const llama_batch & u_batch) const;
};

View File

@@ -38,6 +38,7 @@ struct llama_cparams {
bool fused_up_gate;
bool fused_mmad;
bool rope_cache;
bool graph_reuse;
int min_experts;
float thresh_experts;

View File

@@ -535,6 +535,42 @@ static size_t llama_get_device_memory(const llama_model & model, int device) {
GGML_UNUSED(model);
GGML_UNUSED(device);
}
//llama_batch u_batch = {
// /* .n_tokens = */ (int32_t) n_tokens,
// /* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr,
// /* .embd = */ batch_all.embd ? batch_all.embd + cur_token*n_embd : nullptr,
// /* .pos = */ batch_all.pos ? batch_all.pos + cur_token : nullptr,
// /* .n_seq_id = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token : nullptr,
// /* .seq_id = */ batch_all.seq_id ? batch_all.seq_id + cur_token : nullptr,
// /* .logits = */ batch_all.logits ? batch_all.logits + cur_token : nullptr,
// /* .all_pos_0 = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1,
// /* .all_pos_1 = */ batch_all.all_pos_1,
// /* .all_seq_id = */ batch_all.all_seq_id,
//};
struct llama_context::Prev {
int all_pos_0, all_pos_1, all_seq_id;
int n_outputs;
int n_kv;
ggml_cgraph * graph;
};
void llama_context::reset_scheduler() {
ggml_backend_sched_reset(sched);
prev.reset();
}
bool llama_context::can_reuse_graph(const llama_batch & u_batch) const {
if (!prev || !prev->graph) return false;
if (u_batch.n_tokens > 1) return false;
if (u_batch.embd) return false;
if (!cparams.graph_reuse) return false;
return u_batch.all_pos_0 == prev->all_pos_0 &&
u_batch.all_pos_1 == prev->all_pos_1 &&
u_batch.all_seq_id == prev->all_seq_id &&
kv_self.n == prev->n_kv &&
n_outputs == prev->n_outputs;
}
llama_context::llama_context(const llama_model & model)
: model(model) , sampling(llama_n_vocab(&model)) , t_start_us(model.t_start_us) , t_load_us(model.t_load_us) {}
@@ -2876,27 +2912,60 @@ static int llama_decode_internal(
printf("prelude(...): %d us\n", int(tim2-tim1));
#endif
//if (n_tokens_all == 1) {
// printf("================= %s\n", __func__);
// printf(" all_pos_0 = %d, all_pos_1 = %d, all_seq_id = %d\n", batch_all.all_pos_0, batch_all.all_pos_1, batch_all.all_seq_id);
// printf(" embd = %p, logits = %p, token = %p\n", (const void *)batch_all.embd, (const void *)batch_all.logits, (const void *)batch_all.token);
// printf(" n_outputs = %d, kv_self.n = %d\n", n_outputs, kv_self.n);
//}
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
#endif
ggml_backend_sched_reset(lctx.sched);
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
ggml_cgraph * gf = nullptr;
if (!lctx.can_reuse_graph(u_batch)) {
lctx.reset_scheduler();
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
#if IK_PRINT_TIMING
tim2 = ggml_time_us();
printf("sched_reset(...): %d us\n", int(tim2-tim1));
tim2 = ggml_time_us();
printf("sched_reset(...): %d us\n", int(tim2-tim1));
#endif
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
tim1 = ggml_time_us();
#endif
ggml_cgraph * gf = llm_build_context::llama_build_graph(lctx, u_batch, false);
gf = llm_build_context::llama_build_graph(lctx, u_batch, false);
#if IK_PRINT_TIMING
tim2 = ggml_time_us();
printf("build_graph(...): %d us\n", int(tim2-tim1));
tim2 = ggml_time_us();
printf("build_graph(...): %d us\n", int(tim2-tim1));
#endif
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
#endif
ggml_backend_sched_alloc_graph(lctx.sched, gf);
#if IK_PRINT_TIMING
tim2 = ggml_time_us();
printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1));
#endif
//struct llama_context::Prev {
// int all_pos_0, all_pos_1, all_seq_id;
// int n_outputs;
// int n_kv;
// ggml_cgraph * graph;
//};
if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
lctx.prev = std::make_unique<llama_context::Prev>(llama_context::Prev{
(int)u_batch.all_pos_0, (int)u_batch.all_pos_1, (int)u_batch.all_seq_id,
(int)lctx.n_outputs, (int)lctx.kv_self.n, gf});
}
} else {
//printf("Reusing graph\n");
gf = lctx.prev->graph;
}
// the output is always the last tensor in the graph
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
@@ -2921,15 +2990,6 @@ static int llama_decode_internal(
}
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
#endif
ggml_backend_sched_alloc_graph(lctx.sched, gf);
#if IK_PRINT_TIMING
tim2 = ggml_time_us();
printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1));
#endif
#if IK_PRINT_TIMING == 1
tim1 = ggml_time_us();
#endif
@@ -3060,9 +3120,11 @@ static int llama_decode_internal(
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
#if IK_PRINT_TIMING
auto tim1 = ggml_time_us();
auto tim1 = ggml_time_us();
#endif
ggml_backend_sched_reset(lctx.sched);
if (!lctx.prev) {
lctx.reset_scheduler();
}
#if IK_PRINT_TIMING
auto tim2 = ggml_time_us();
printf("sched_reset(...): %d us\n", int(tim2-tim1));
@@ -3158,7 +3220,7 @@ static int llama_encode_internal(
batch.seq_id = seq_id_arr.data();
}
ggml_backend_sched_reset(lctx.sched);
lctx.reset_scheduler();
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
ggml_cgraph * gf = llm_build_context::llama_build_graph(lctx, batch, false);
@@ -3248,7 +3310,7 @@ static int llama_encode_internal(
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
ggml_backend_sched_reset(lctx.sched);
lctx.reset_scheduler();
return 0;
}
@@ -3462,7 +3524,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
#else
// ggml_graph defrag
ggml_backend_sched_reset(lctx.sched);
lctx.reset_scheduler();
ggml_cgraph * gf = llm_build_context::llama_build_graph_defrag(lctx, ids);
@@ -3484,7 +3546,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) {
}
{
ggml_backend_sched_reset(lctx.sched);
lctx.reset_scheduler();
ggml_cgraph * gf = llm_build_context::llama_build_graph_k_shift(lctx);
@@ -3510,7 +3572,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) {
if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
{
ggml_backend_sched_reset(lctx.sched);
lctx.reset_scheduler();
ggml_cgraph * gf = llm_build_context::llama_build_graph_s_copy(lctx);
@@ -3553,7 +3615,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) {
ggml_cgraph * gf = llm_build_context::llama_build_graph(lctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
// initialize scheduler with the worst-case graph
ggml_backend_sched_reset(lctx.sched);
lctx.reset_scheduler();
if (!ggml_backend_sched_reserve(lctx.sched, gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
}
@@ -3840,6 +3902,7 @@ struct llama_context_params llama_context_default_params() {
/*.fused_up_gate =*/ true,
/*.fused_mmad =*/ true,
/*.rope_cache =*/ false,
/*.graph_reuse =*/ false,
/*.min_experts =*/ -1,
/*.thtesh_experts =*/ 0.0f,
/*.only_active_experts =*/ false,
@@ -4144,6 +4207,7 @@ struct llama_context * llama_new_context_with_model(
cparams.fused_up_gate = params.fused_up_gate;
cparams.fused_mmad = params.fused_mmad;
cparams.rope_cache = params.rope_cache;
cparams.graph_reuse = params.graph_reuse;
cparams.min_experts = params.min_experts;
cparams.thresh_experts = params.thresh_experts;
cparams.cuda_params = params.cuda_params;
@@ -4230,6 +4294,7 @@ struct llama_context * llama_new_context_with_model(
LLAMA_LOG_INFO("%s: fused_up_gate = %d\n", __func__, cparams.fused_up_gate);
LLAMA_LOG_INFO("%s: fused_mmad = %d\n", __func__, cparams.fused_mmad);
LLAMA_LOG_INFO("%s: rope_cache = %d\n", __func__, cparams.rope_cache);
LLAMA_LOG_INFO("%s: graph_reuse = %d\n", __func__, cparams.graph_reuse);
LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);