Graph reuse (#947)

* Add mainline compatible FA command line option

* Graph reuse: add command line argument to turn it on

* WIP

* This seems to work

* This is perhaps cleaner

* Change the command line option to -gr

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-11-14 06:58:19 +02:00
committed by GitHub
parent 22c20fcd6d
commit 6b9d1bf4b4
9 changed files with 174 additions and 38 deletions

View File

@@ -536,8 +536,57 @@ static size_t llama_get_device_memory(const llama_model & model, int device) {
GGML_UNUSED(device);
}
struct llama_context::Prev {
int all_seq_id;
int n_outputs;
int n_kv;
ggml_cgraph * graph;
};
void llama_context::reset_scheduler() {
ggml_backend_sched_reset(sched);
prev.reset();
}
bool llama_context::can_reuse_graph(const llama_batch & u_batch) {
if (!prev || !prev->graph) return false;
if (u_batch.n_tokens > 1) return false;
if (u_batch.embd) return false;
if (!cparams.graph_reuse) return false;
return u_batch.all_seq_id == prev->all_seq_id &&
kv_self.head > 0 &&
kv_self.n == prev->n_kv &&
n_outputs == prev->n_outputs &&
update_cache_copies();
}
bool llama_context::update_cache_copies() {
int n_layer = cache_copies.size()/2;
if ((int)kv_self.k_l.size() != n_layer) return false;
if (!(kv_self.v_l.empty() || (int)kv_self.v_l.size() == n_layer)) return false;
for (int il = 0; il < n_layer; ++il) {
auto& c = cache_copies[2*il+0];
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.k_l[il]) return false;
c.cpy->view_offs = kv_self.head*c.step;
c.cpy->src[1]->data = (char *)kv_self.k_l[il]->data + c.cpy->view_offs;
c.cpy->data = c.cpy->src[1]->data;
}
if (kv_self.v_l.empty()) return true;
for (int il = 0; il < n_layer; ++il) {
auto& c = cache_copies[2*il+1];
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.v_l[il]) return false;
c.cpy->view_offs = kv_self.head*c.step;
c.cpy->src[1]->data = (char *)kv_self.v_l[il]->data + c.cpy->view_offs;
c.cpy->data = c.cpy->src[1]->data;
}
return true;
}
llama_context::llama_context(const llama_model & model)
: model(model) , sampling(llama_n_vocab(&model)) , t_start_us(model.t_start_us) , t_load_us(model.t_load_us) {}
: model(model) , sampling(llama_n_vocab(&model)) , t_start_us(model.t_start_us) , t_load_us(model.t_load_us) {
const auto & hparams = model.hparams;
cache_copies.resize(2*hparams.n_layer);
}
llama_context::~llama_context() {
ggml_backend_sched_free(sched);
@@ -2876,27 +2925,53 @@ static int llama_decode_internal(
printf("prelude(...): %d us\n", int(tim2-tim1));
#endif
//if (n_tokens_all == 1) {
// printf("================= %s\n", __func__);
// printf(" all_pos_0 = %d, all_pos_1 = %d, all_seq_id = %d\n", batch_all.all_pos_0, batch_all.all_pos_1, batch_all.all_seq_id);
// printf(" embd = %p, logits = %p, token = %p\n", (const void *)batch_all.embd, (const void *)batch_all.logits, (const void *)batch_all.token);
// printf(" n_outputs = %d, kv_self.n = %d\n", n_outputs, kv_self.n);
//}
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
#endif
ggml_backend_sched_reset(lctx.sched);
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
ggml_cgraph * gf = nullptr;
if (!lctx.can_reuse_graph(u_batch)) {
lctx.reset_scheduler();
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
#if IK_PRINT_TIMING
tim2 = ggml_time_us();
printf("sched_reset(...): %d us\n", int(tim2-tim1));
tim2 = ggml_time_us();
printf("sched_reset(...): %d us\n", int(tim2-tim1));
#endif
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
tim1 = ggml_time_us();
#endif
ggml_cgraph * gf = llm_build_context::llama_build_graph(lctx, u_batch, false);
gf = llm_build_context::llama_build_graph(lctx, u_batch, false);
#if IK_PRINT_TIMING
tim2 = ggml_time_us();
printf("build_graph(...): %d us\n", int(tim2-tim1));
tim2 = ggml_time_us();
printf("build_graph(...): %d us\n", int(tim2-tim1));
#endif
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
#endif
ggml_backend_sched_alloc_graph(lctx.sched, gf);
#if IK_PRINT_TIMING
tim2 = ggml_time_us();
printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1));
#endif
if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
lctx.prev = std::make_unique<llama_context::Prev>(llama_context::Prev{
(int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n, gf});
}
} else {
//printf("Reusing graph\n");
gf = lctx.prev->graph;
}
// the output is always the last tensor in the graph
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
@@ -2921,15 +2996,6 @@ static int llama_decode_internal(
}
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
#endif
ggml_backend_sched_alloc_graph(lctx.sched, gf);
#if IK_PRINT_TIMING
tim2 = ggml_time_us();
printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1));
#endif
#if IK_PRINT_TIMING == 1
tim1 = ggml_time_us();
#endif
@@ -3060,9 +3126,11 @@ static int llama_decode_internal(
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
#if IK_PRINT_TIMING
auto tim1 = ggml_time_us();
auto tim1 = ggml_time_us();
#endif
ggml_backend_sched_reset(lctx.sched);
if (!lctx.prev) {
lctx.reset_scheduler();
}
#if IK_PRINT_TIMING
auto tim2 = ggml_time_us();
printf("sched_reset(...): %d us\n", int(tim2-tim1));
@@ -3158,7 +3226,7 @@ static int llama_encode_internal(
batch.seq_id = seq_id_arr.data();
}
ggml_backend_sched_reset(lctx.sched);
lctx.reset_scheduler();
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
ggml_cgraph * gf = llm_build_context::llama_build_graph(lctx, batch, false);
@@ -3248,7 +3316,7 @@ static int llama_encode_internal(
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
ggml_backend_sched_reset(lctx.sched);
lctx.reset_scheduler();
return 0;
}
@@ -3462,7 +3530,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
#else
// ggml_graph defrag
ggml_backend_sched_reset(lctx.sched);
lctx.reset_scheduler();
ggml_cgraph * gf = llm_build_context::llama_build_graph_defrag(lctx, ids);
@@ -3484,7 +3552,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) {
}
{
ggml_backend_sched_reset(lctx.sched);
lctx.reset_scheduler();
ggml_cgraph * gf = llm_build_context::llama_build_graph_k_shift(lctx);
@@ -3510,7 +3578,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) {
if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
{
ggml_backend_sched_reset(lctx.sched);
lctx.reset_scheduler();
ggml_cgraph * gf = llm_build_context::llama_build_graph_s_copy(lctx);
@@ -3553,7 +3621,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) {
ggml_cgraph * gf = llm_build_context::llama_build_graph(lctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
// initialize scheduler with the worst-case graph
ggml_backend_sched_reset(lctx.sched);
lctx.reset_scheduler();
if (!ggml_backend_sched_reserve(lctx.sched, gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
}
@@ -3840,6 +3908,7 @@ struct llama_context_params llama_context_default_params() {
/*.fused_up_gate =*/ true,
/*.fused_mmad =*/ true,
/*.rope_cache =*/ false,
/*.graph_reuse =*/ false,
/*.min_experts =*/ -1,
/*.thtesh_experts =*/ 0.0f,
/*.only_active_experts =*/ false,
@@ -4144,6 +4213,7 @@ struct llama_context * llama_new_context_with_model(
cparams.fused_up_gate = params.fused_up_gate;
cparams.fused_mmad = params.fused_mmad;
cparams.rope_cache = params.rope_cache;
cparams.graph_reuse = params.graph_reuse;
cparams.min_experts = params.min_experts;
cparams.thresh_experts = params.thresh_experts;
cparams.cuda_params = params.cuda_params;
@@ -4230,6 +4300,7 @@ struct llama_context * llama_new_context_with_model(
LLAMA_LOG_INFO("%s: fused_up_gate = %d\n", __func__, cparams.fused_up_gate);
LLAMA_LOG_INFO("%s: fused_mmad = %d\n", __func__, cparams.fused_mmad);
LLAMA_LOG_INFO("%s: rope_cache = %d\n", __func__, cparams.rope_cache);
LLAMA_LOG_INFO("%s: graph_reuse = %d\n", __func__, cparams.graph_reuse);
LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);