From c7deb32142d693b2baf9be38a0e624ed3147eb3a Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 18 Jan 2026 12:53:34 +0000 Subject: [PATCH] For the output ops use the result of the split that ran on the main GPU --- ggml/src/ggml-backend.cpp | 4 ++-- src/llama-build-context.cpp | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index c05bf566..d99d7022 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -2244,7 +2244,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } } - if (split->graph.nodes[0]->op == GGML_OP_REDUCE) { + if (split->graph.nodes[0]->op == GGML_OP_REDUCE && i < sched->n_splits - 1) { last_reduce = split_backend_id; if (ith == split_backend_id) { auto node = split->graph.nodes[0]; @@ -2318,7 +2318,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } } - if (split->graph.nodes[0]->op == GGML_OP_REDUCE) { + if (split->graph.nodes[0]->op == GGML_OP_REDUCE && i < sched->n_splits - 1) { last_reduce = split_backend_id; barrier.arrive_and_wait(); if (ith == split_backend_id) { diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 09fc383b..c6aecc36 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -1759,7 +1759,8 @@ static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml return cur; } -static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml_tensor * cur, ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb) { +static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml_tensor * cur, + ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb) { // lm_head if (output->extra) { auto split_output = (ggml_split_tensor_t *)output->extra; @@ -1790,6 +1791,10 @@ static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml } } } else { + if (cur->op == GGML_OP_REDUCE && cur->src[lctx.model.main_gpu]) { + // avoid copy to main GPU + cur->view_src = cur->src[lctx.model.main_gpu]; + } if (output_norm) { cur = llm_build_context::llm_build_norm(ctx, cur, lctx.model.hparams, output_norm, NULL, LLM_NORM_RMS, cb, -1); cb(cur, "result_norm", -1);