Make graph reuse work with split mode graph

This commit is contained in:
Kawrakow
2025-11-29 09:17:07 +00:00
parent abc5bd6e74
commit bf2a1dad98
2 changed files with 55 additions and 24 deletions

View File

@@ -9197,36 +9197,37 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
GGML_ASSERT(kv_self.size == cparams.n_ctx);
GGML_ASSERT(2*il+1 < (int)lctx.cache_copies.size());
auto idx = 2*wq->n_device*il + 2*id;
GGML_ASSERT(idx+1 < (int)lctx.cache_copies.size());
auto k_row_size = ggml_row_size(split_kl->type, n_embd_head_k);
ggml_tensor * k_cache_view = ggml_view_2d(ctx0, split_kl, n_embd_head_k, n_tokens*n_head_kv,
k_row_size, k_row_size*n_head_kv*kv_head);
lctx.cache_copies[2*il+0].cpy = ggml_cpy(ctx0, Kcur, k_cache_view);
lctx.cache_copies[2*il+0].step = k_row_size*n_head_kv;
lctx.cache_copies[idx+0].cpy = ggml_cpy(ctx0, Kcur, k_cache_view);
lctx.cache_copies[idx+0].step = k_row_size*n_head_kv;
// note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(gf, lctx.cache_copies[2*il+0].cpy);
ggml_build_forward_expand(gf, lctx.cache_copies[idx+0].cpy);
struct ggml_tensor * v_cache_view = nullptr;
if (cparams.flash_attn) {
v_cache_view = ggml_view_1d(ctx0, split_vl, n_tokens*split_wv->ne[1],
kv_head*ggml_row_size(split_vl->type, split_wv->ne[1]));
lctx.cache_copies[2*il+1].step = ggml_row_size(split_vl->type, split_wv->ne[1]);
lctx.cache_copies[idx+1].step = ggml_row_size(split_vl->type, split_wv->ne[1]);
} else {
// note: the V cache is transposed when not using flash attention
v_cache_view = ggml_view_2d(ctx0, split_vl, n_tokens, split_wv->ne[1],
( n_ctx)*ggml_element_size(split_vl),
(kv_head)*ggml_element_size(split_vl));
lctx.cache_copies[2*il+1].step = ggml_element_size(split_vl);
lctx.cache_copies[idx+1].step = ggml_element_size(split_vl);
Vcur = ggml_transpose(ctx0, Vcur);
}
cb(v_cache_view, "v_cache_view", il_cb);
lctx.cache_copies[2*il+1].cpy = ggml_cpy(ctx0, Vcur, v_cache_view);
ggml_build_forward_expand(gf, lctx.cache_copies[2*il+1].cpy);
lctx.cache_copies[idx+1].cpy = ggml_cpy(ctx0, Vcur, v_cache_view);
ggml_build_forward_expand(gf, lctx.cache_copies[idx+1].cpy);
auto q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
cb(q, "q", il_cb);