diff --git a/ggml/src/ggml-cuda/reduce.cu b/ggml/src/ggml-cuda/reduce.cu
index 371d2a6d..c4e8bf34 100644
--- a/ggml/src/ggml-cuda/reduce.cu
+++ b/ggml/src/ggml-cuda/reduce.cu
@@ -44,8 +44,10 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
         static const int devs[8] = {0,1, 2,3, 0,2, 1,3};
         for (int ip = 0; ip < 4; ++ip) {
             ncclGroupStart();
+            ggml_cuda_set_device(devs[2*ip+0]);
             auto status1 = ncclAllReduce(dst->src[devs[2*ip+0]]->data, dst->src[devs[2*ip+0]]->data, ggml_nelements(dst), data_type, ncclSum,
                     info.nccl_coms[2*ip+0], info.all_ctx[devs[2*ip+0]]->stream());
+            ggml_cuda_set_device(devs[2*ip+1]);
             auto status2 = ncclAllReduce(dst->src[devs[2*ip+1]]->data, dst->src[devs[2*ip+1]]->data, ggml_nelements(dst), data_type, ncclSum,
                     info.nccl_coms[2*ip+1], info.all_ctx[devs[2*ip+1]]->stream());
             ncclGroupEnd();
@@ -57,6 +59,7 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
     } else {
         ncclGroupStart();
         for (int i = 0; i < nreduce; ++i) {
+            ggml_cuda_set_device(i);
             auto stream = info.all_ctx[i]->stream();
             GGML_ASSERT(stream);
             auto status = ncclAllReduce(dst->src[i]->data, dst->src[i]->data, ggml_nelements(dst), data_type, ncclSum,
@@ -78,6 +81,7 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
             auto status = ncclCommSplit(info.nccl_coms[i], dst->src[i] ? 1 : 0, i, &this_comm, NULL);
             GGML_ASSERT(status == ncclSuccess);
         }
+        ggml_cuda_set_device(i);
         auto stream = info.all_ctx[i]->stream();
         GGML_ASSERT(stream);
         ncclResult_t status;
@@ -99,6 +103,7 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
         }
         ncclGroupEnd();
     }
+    ggml_cuda_set_device(ctx.device);
     //auto tim2 = std::chrono::steady_clock::now();
     //printf("%s: launched in %g us\n", __func__, 1e-3*std::chrono::duration_cast(tim2-tim1).count());
     return;
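
The pattern the patch enforces: each NCCL collective is enqueued with the device
that owns its communicator/stream made current via ggml_cuda_set_device(), and
the caller's device (ctx.device) is restored after all groups have been
launched. Below is a minimal standalone sketch of that same pattern, assuming
comms, streams and bufs were initialized elsewhere (e.g. with ncclCommInitAll);
all names here are illustrative and not part of the ggml sources:

    #include <cuda_runtime.h>
    #include <nccl.h>

    // Sum-reduce bufs[i] in place across ndev devices, one communicator and
    // stream per device, then restore whichever device was current on entry.
    static void all_reduce_sum(ncclComm_t * comms, cudaStream_t * streams,
                               float ** bufs, size_t count, int ndev) {
        int caller_dev;
        cudaGetDevice(&caller_dev);
        ncclGroupStart();
        for (int i = 0; i < ndev; ++i) {
            cudaSetDevice(i); // make the device owning streams[i] current
            ncclAllReduce(bufs[i], bufs[i], count, ncclFloat, ncclSum,
                          comms[i], streams[i]);
        }
        ncclGroupEnd();
        cudaSetDevice(caller_dev); // mirror of ggml_cuda_set_device(ctx.device)
    }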