diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 96103c1f..ecbcb265 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -265,6 +265,21 @@ static ggml_cuda_device_info ggml_cuda_init() { } info.have_nccl = true; printf("=============================== NCCL initialized\n"); + } else if (info.device_count == 3) { + int devs[4] = {0,1, 0,2}; + for (int ip = 0; ip < 2; ++ip) { + if (auto status = ncclCommInitAll(info.nccl_coms+2*ip, 2, devs+2*ip); status != ncclSuccess) { + printf("=============================== NCCL initialization of pair %d failed with status %d\n", ip, int(status)); + GGML_ABORT("Fatal error"); + } + } + int gpus[3] = {0, 1, 2}; + if (auto status = ncclCommInitAll(info.nccl_coms+4, 3, gpus); status != ncclSuccess) { + printf("=============================== NCCL initialization of 4 GPUs failed with status %d\n", int(status)); + GGML_ABORT("Fatal error"); + } + info.have_nccl = true; + printf("=============================== NCCL initialized\n"); } else { int gpu_list[GGML_CUDA_MAX_DEVICES]; for(int i = 0; i < info.device_count; ++i) gpu_list[i] = i; diff --git a/ggml/src/ggml-cuda/reduce.cu b/ggml/src/ggml-cuda/reduce.cu index c4e8bf34..5a7f5725 100644 --- a/ggml/src/ggml-cuda/reduce.cu +++ b/ggml/src/ggml-cuda/reduce.cu @@ -71,6 +71,50 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_ } ncclGroupEnd(); } + } + else if (nreduce == 3) { + auto data_type = type == GGML_TYPE_F32 ? ncclFloat : ncclHalf; + if (dst->ne[1] > 32) { + static const int devs[4] = {0,1, 0,2}; + for (int ip = 0; ip < 2; ++ip) { + ncclGroupStart(); + ggml_cuda_set_device(devs[2*ip+0]); + auto status1 = ncclAllReduce(dst->src[devs[2*ip+0]]->data, dst->src[devs[2*ip+0]]->data, + ggml_nelements(dst), data_type, ncclSum, info.nccl_coms[2*ip+0], info.all_ctx[devs[2*ip+0]]->stream()); + ggml_cuda_set_device(devs[2*ip+1]); + auto status2 = ncclAllReduce(dst->src[devs[2*ip+1]]->data, dst->src[devs[2*ip+1]]->data, + ggml_nelements(dst), data_type, ncclSum, info.nccl_coms[2*ip+1], info.all_ctx[devs[2*ip+1]]->stream()); + ncclGroupEnd(); + if (status1 != ncclSuccess || status2 != ncclSuccess) { + fprintf(stderr, "%s: ncclAllReduce failed with statuses %d, %d\n", __func__, (int)status1, (int)status2); + GGML_ABORT("Fatal error"); + } + } + ncclGroupStart(); + ggml_cuda_set_device(0); + auto status1 = ncclSend(dst->src[0]->data, ggml_nelements(dst), data_type, 1, info.nccl_coms[0], info.all_ctx[0]->stream()); + ggml_cuda_set_device(1); + auto status2 = ncclRecv(dst->src[1]->data, ggml_nelements(dst), data_type, 0, info.nccl_coms[1], info.all_ctx[1]->stream()); + ncclGroupEnd(); + if (status1 != ncclSuccess || status2 != ncclSuccess) { + fprintf(stderr, "%s: ncclSend/Recv failed with statuses %d, %d\n", __func__, (int)status1, (int)status2); + GGML_ABORT("Fatal error"); + } + } else { + ncclGroupStart(); + for (int i = 0; i < nreduce; ++i) { + ggml_cuda_set_device(i); + auto stream = info.all_ctx[i]->stream(); + GGML_ASSERT(stream); + auto status = ncclAllReduce(dst->src[i]->data, dst->src[i]->data, ggml_nelements(dst), data_type, ncclSum, + info.nccl_coms[4+i], stream); + if (status != ncclSuccess) { + fprintf(stderr, "%s: ncclAllReduce on device %d failed with status %d\n", __func__, i, (int)status); + GGML_ABORT("Fatal error"); + } + } + ncclGroupEnd(); + } } else { ncclGroupStart(); for (int i = 0; i < nreduce; ++i) {