mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-08 15:30:15 +00:00
For the output ops, use the result of the split that ran on the main GPU
This commit is contained in:
@@ -2244,7 +2244,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
|
||||
}
|
||||
}
|
||||
|
||||
if (split->graph.nodes[0]->op == GGML_OP_REDUCE) {
|
||||
if (split->graph.nodes[0]->op == GGML_OP_REDUCE && i < sched->n_splits - 1) {
|
||||
last_reduce = split_backend_id;
|
||||
if (ith == split_backend_id) {
|
||||
auto node = split->graph.nodes[0];
|
||||
@@ -2318,7 +2318,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
|
||||
}
|
||||
}
|
||||
|
||||
if (split->graph.nodes[0]->op == GGML_OP_REDUCE) {
|
||||
if (split->graph.nodes[0]->op == GGML_OP_REDUCE && i < sched->n_splits - 1) {
|
||||
last_reduce = split_backend_id;
|
||||
barrier.arrive_and_wait();
|
||||
if (ith == split_backend_id) {
|
||||
|
||||
@@ -1759,7 +1759,8 @@ static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml
|
||||
return cur;
|
||||
}
|
||||
|
||||
static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml_tensor * cur, ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb) {
|
||||
static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml_tensor * cur,
|
||||
ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb) {
|
||||
// lm_head
|
||||
if (output->extra) {
|
||||
auto split_output = (ggml_split_tensor_t *)output->extra;
|
||||
@@ -1790,6 +1791,10 @@ static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (cur->op == GGML_OP_REDUCE && cur->src[lctx.model.main_gpu]) {
|
||||
// avoid copy to main GPU
|
||||
cur->view_src = cur->src[lctx.model.main_gpu];
|
||||
}
|
||||
if (output_norm) {
|
||||
cur = llm_build_context::llm_build_norm(ctx, cur, lctx.model.hparams, output_norm, NULL, LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
Reference in New Issue
Block a user