From bf2a1dad985d1bd2523a671b93f6509e0ef99346 Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Sat, 29 Nov 2025 09:17:07 +0000
Subject: [PATCH] Make graph reuse work with split mode graph

---
 src/llama-build-context.cpp | 17 +++++-----
 src/llama.cpp               | 62 +++++++++++++++++++++++++++----------
 2 files changed, 55 insertions(+), 24 deletions(-)

diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 862e849e..9866124d 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -9197,36 +9197,37 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
 
         GGML_ASSERT(kv_self.size == cparams.n_ctx);
 
-        GGML_ASSERT(2*il+1 < (int)lctx.cache_copies.size());
+        auto idx = 2*wq->n_device*il + 2*id;
+        GGML_ASSERT(idx+1 < (int)lctx.cache_copies.size());
 
         auto k_row_size = ggml_row_size(split_kl->type, n_embd_head_k);
         ggml_tensor * k_cache_view = ggml_view_2d(ctx0, split_kl, n_embd_head_k, n_tokens*n_head_kv,
                 k_row_size, k_row_size*n_head_kv*kv_head);
 
-        lctx.cache_copies[2*il+0].cpy = ggml_cpy(ctx0, Kcur, k_cache_view);
-        lctx.cache_copies[2*il+0].step = k_row_size*n_head_kv;
+        lctx.cache_copies[idx+0].cpy = ggml_cpy(ctx0, Kcur, k_cache_view);
+        lctx.cache_copies[idx+0].step = k_row_size*n_head_kv;
 
         // note: storing RoPE-ed version of K in the KV cache
-        ggml_build_forward_expand(gf, lctx.cache_copies[2*il+0].cpy);
+        ggml_build_forward_expand(gf, lctx.cache_copies[idx+0].cpy);
 
         struct ggml_tensor * v_cache_view = nullptr;
         if (cparams.flash_attn) {
             v_cache_view = ggml_view_1d(ctx0, split_vl, n_tokens*split_wv->ne[1],
                     kv_head*ggml_row_size(split_vl->type, split_wv->ne[1]));
-            lctx.cache_copies[2*il+1].step = ggml_row_size(split_vl->type, split_wv->ne[1]);
+            lctx.cache_copies[idx+1].step = ggml_row_size(split_vl->type, split_wv->ne[1]);
         } else {
             // note: the V cache is transposed when not using flash attention
             v_cache_view = ggml_view_2d(ctx0, split_vl, n_tokens, split_wv->ne[1],
                     (  n_ctx)*ggml_element_size(split_vl),
                     (kv_head)*ggml_element_size(split_vl));
-            lctx.cache_copies[2*il+1].step = ggml_element_size(split_vl);
+            lctx.cache_copies[idx+1].step = ggml_element_size(split_vl);
             Vcur = ggml_transpose(ctx0, Vcur);
         }
         cb(v_cache_view, "v_cache_view", il_cb);
 
-        lctx.cache_copies[2*il+1].cpy = ggml_cpy(ctx0, Vcur, v_cache_view);
-        ggml_build_forward_expand(gf, lctx.cache_copies[2*il+1].cpy);
+        lctx.cache_copies[idx+1].cpy = ggml_cpy(ctx0, Vcur, v_cache_view);
+        ggml_build_forward_expand(gf, lctx.cache_copies[idx+1].cpy);
 
         auto q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
         cb(q, "q", il_cb);
 
diff --git a/src/llama.cpp b/src/llama.cpp
index df97976f..c0434018 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -556,23 +556,49 @@ bool llama_context::can_reuse_graph(const llama_batch & u_batch) {
 }
 
 bool llama_context::update_cache_copies() {
-    int n_layer = cache_copies.size()/2;
+    int n_layer = model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2;
     if ((int)kv_self.k_l.size() != n_layer) return false;
     if (!(kv_self.v_l.empty() || (int)kv_self.v_l.size() == n_layer)) return false;
-    for (int il = 0; il < n_layer; ++il) {
-        auto& c = cache_copies[2*il+0];
-        if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.k_l[il]) return false;
-        c.cpy->view_offs = kv_self.head*c.step;
-        c.cpy->src[1]->data = (char *)kv_self.k_l[il]->data + c.cpy->view_offs;
-        c.cpy->data = c.cpy->src[1]->data;
-    }
-    if (kv_self.v_l.empty()) return true;
-    for (int il = 0; il < n_layer; ++il) {
-        auto& c = cache_copies[2*il+1];
-        if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.v_l[il]) return false;
-        c.cpy->view_offs = kv_self.head*c.step;
-        c.cpy->src[1]->data = (char *)kv_self.v_l[il]->data + c.cpy->view_offs;
-        c.cpy->data = c.cpy->src[1]->data;
+    if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH && model.splits.size() > 1) {
+        for (int il = 0; il < n_layer; ++il) {
+            auto kl = (ggml_split_tensor_t *)kv_self.k_l[il]->extra;
+            auto vl = !kv_self.v_l.empty() && kv_self.v_l[il] ? (ggml_split_tensor_t *)kv_self.v_l[il]->extra : nullptr;
+            GGML_ASSERT(kl && (!kv_self.v_l[il] || vl));
+            if (vl) {
+                GGML_ASSERT(kl->n_device == vl->n_device);
+            }
+            for (int id = 0; id < kl->n_device; ++id) {
+                auto& c = cache_copies[2*model.splits.size()*il + 2*id + 0];
+                if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kl->splits[id]) return false;
+                c.cpy->view_offs = kv_self.head*c.step;
+                c.cpy->src[1]->data = (char *)kl->splits[id]->data + c.cpy->view_offs;
+                c.cpy->data = c.cpy->src[1]->data;
+            }
+            if (!vl) continue;
+            for (int id = 0; id < vl->n_device; ++id) {
+                auto& c = cache_copies[2*model.splits.size()*il + 2*id + 1];
+                if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != vl->splits[id]) return false;
+                c.cpy->view_offs = kv_self.head*c.step;
+                c.cpy->src[1]->data = (char *)vl->splits[id]->data + c.cpy->view_offs;
+                c.cpy->data = c.cpy->src[1]->data;
+            }
+        }
+    } else {
+        for (int il = 0; il < n_layer; ++il) {
+            auto& c = cache_copies[2*il+0];
+            if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.k_l[il]) return false;
+            c.cpy->view_offs = kv_self.head*c.step;
+            c.cpy->src[1]->data = (char *)kv_self.k_l[il]->data + c.cpy->view_offs;
+            c.cpy->data = c.cpy->src[1]->data;
+        }
+        if (kv_self.v_l.empty()) return true;
+        for (int il = 0; il < n_layer; ++il) {
+            auto& c = cache_copies[2*il+1];
+            if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.v_l[il]) return false;
+            c.cpy->view_offs = kv_self.head*c.step;
+            c.cpy->src[1]->data = (char *)kv_self.v_l[il]->data + c.cpy->view_offs;
+            c.cpy->data = c.cpy->src[1]->data;
+        }
     }
     return true;
 }
@@ -580,7 +606,11 @@ bool llama_context::update_cache_copies() {
 llama_context::llama_context(const llama_model & model)
     : model(model)
     , sampling(llama_n_vocab(&model))
     , t_start_us(model.t_start_us)
     , t_load_us(model.t_load_us) {
     const auto & hparams = model.hparams;
-    cache_copies.resize(2*hparams.n_layer);
+    if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH && model.splits.size() > 1) {
+        cache_copies.resize(2*model.splits.size()*hparams.n_layer);
+    } else {
+        cache_copies.resize(2*hparams.n_layer);
+    }
 }
 
 llama_context::~llama_context() {
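Note (not part of the patch): a minimal standalone sketch of the cache_copies index layout this change implies. The helper cache_copy_index and the n_splits parameter are illustrative names only; in the patch the split count comes from wq->n_device (build side) and model.splits.size() (update side), and the table is sized as 2*model.splits.size()*n_layer in the llama_context constructor.

#include <cstdio>

// Each layer reserves 2*n_splits slots: for every device id there is one K copy
// (even slot, offset 0) and one V copy (odd slot, offset 1). With a single split
// this degenerates to the old layout of 2*il + {0,1}.
static int cache_copy_index(int il, int id, bool is_v, int n_splits) {
    return 2*n_splits*il + 2*id + (is_v ? 1 : 0);
}

int main() {
    const int n_layer = 2, n_splits = 3;
    for (int il = 0; il < n_layer; ++il) {
        for (int id = 0; id < n_splits; ++id) {
            printf("layer %d, device %d -> K slot %d, V slot %d\n", il, id,
                   cache_copy_index(il, id, false, n_splits),
                   cache_copy_index(il, id, true,  n_splits));
        }
    }
    return 0;
}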