diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 8d97de48..24c03e73 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -16,6 +16,8 @@ #include #include +#include + #define IK_PRINT_TIMING 0 #define MAX(a, b) ((a) > (b) ? (a) : (b)) @@ -2223,6 +2225,18 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s // auto ec = ggml_backend_graph_compute_async(sched->backends[my_backend_id], &graph); // if (ec != GGML_STATUS_SUCCESS) return ec; //} + if (node->op == GGML_OP_REDUCE) { + ncclGroupStart(); + for (int ib = 0; ib < sched->n_backends; ++ib) { + if (ib != split_backend_id && !ggml_backend_is_cpu(sched->backends[ib])) { + printf("%s: triggering reduce for %s on backend %d\n", __func__, node->name, ib); + auto graph = split->graph; + graph.n_nodes = 1; + auto ec = ggml_backend_graph_compute_async(sched->backends[ib], &graph); + if (ec != GGML_STATUS_SUCCESS) return ec; + } + } + } //if (split_backend_id != my_backend_id) continue; @@ -2286,15 +2300,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } } if (node->op == GGML_OP_REDUCE) { - for (int ib = 0; ib < sched->n_backends; ++ib) { - if (ib != split_backend_id && !ggml_backend_is_cpu(sched->backends[ib])) { - printf("%s: triggering reduce for %s on backend %d\n", __func__, node->name, ib); - auto graph = split->graph; - graph.n_nodes = 1; - auto ec = ggml_backend_graph_compute_async(sched->backends[ib], &graph); - if (ec != GGML_STATUS_SUCCESS) return ec; - } - } + ncclGroupEnd(); } // record the event of this copy diff --git a/ggml/src/ggml-cuda/reduce.cu b/ggml/src/ggml-cuda/reduce.cu index 6e40d7d5..41f4cd3b 100644 --- a/ggml/src/ggml-cuda/reduce.cu +++ b/ggml/src/ggml-cuda/reduce.cu @@ -40,7 +40,7 @@ void ggml_cuda_op_reduce(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { this_comm = info.nccl_coms[device]; } else { int color = extra->splits[device] ? 1 : 0; - auto status = ncclCommSplit(info.nccl_coms[0], color, ctx.device, &this_comm, nullptr); + auto status = ncclCommSplit(info.nccl_coms[device], color, ctx.device, &this_comm, nullptr); GGML_ASSERT(status == ncclSuccess); } GGML_ASSERT(this_comm);