diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index ecbcb265..b4d9ef2a 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -250,47 +250,35 @@ static ggml_cuda_device_info ggml_cuda_init() {
 #ifdef GGML_USE_NCCL
     info.have_nccl = false;
     if (info.device_count > 1) {
-        if (info.device_count == 4) {
-            int devs[8] = {0,1, 2,3, 0,2, 1,3};
-            for (int ip = 0; ip < 4; ++ip) {
-                if (auto status = ncclCommInitAll(info.nccl_coms+2*ip, 2, devs+2*ip); status != ncclSuccess) {
-                    printf("=============================== NCCL initialization of pair %d failed with status %d\n", ip, int(status));
-                    GGML_ABORT("Fatal error");
-                }
-            }
-            int gpus[4] = {0, 1, 2, 3};
-            if (auto status = ncclCommInitAll(info.nccl_coms+8, 4, gpus); status != ncclSuccess) {
-                printf("=============================== NCCL initialization of 4 GPUs failed with status %d\n", int(status));
-                GGML_ABORT("Fatal error");
-            }
-            info.have_nccl = true;
-            printf("=============================== NCCL initialized\n");
-        } else if (info.device_count == 3) {
-            int devs[4] = {0,1, 0,2};
-            for (int ip = 0; ip < 2; ++ip) {
-                if (auto status = ncclCommInitAll(info.nccl_coms+2*ip, 2, devs+2*ip); status != ncclSuccess) {
-                    printf("=============================== NCCL initialization of pair %d failed with status %d\n", ip, int(status));
-                    GGML_ABORT("Fatal error");
-                }
-            }
-            int gpus[3] = {0, 1, 2};
-            if (auto status = ncclCommInitAll(info.nccl_coms+4, 3, gpus); status != ncclSuccess) {
-                printf("=============================== NCCL initialization of 4 GPUs failed with status %d\n", int(status));
-                GGML_ABORT("Fatal error");
-            }
-            info.have_nccl = true;
-            printf("=============================== NCCL initialized\n");
-        } else {
             int gpu_list[GGML_CUDA_MAX_DEVICES];
             for(int i = 0; i < info.device_count; ++i) gpu_list[i] = i;
             auto status = ncclCommInitAll(info.nccl_coms, info.device_count, gpu_list);
             if (status == ncclSuccess) {
-                printf("=============================== NCCL initialized\n");
+                printf("=============================== NCCL main communicator initialized\n");
                 info.have_nccl = true;
             } else {
                 printf("=============================== NCCL initialization failed with status %d\n", int(status));
                 GGML_ABORT("Fatal error");
             }
+        auto com = info.nccl_coms + info.device_count;
+        if (info.device_count == 4) {
+            int devs[8] = {0,1, 2,3, 0,2, 1,3};
+            for (int ip = 0; ip < 4; ++ip) {
+                if (auto status = ncclCommInitAll(com+2*ip, 2, devs+2*ip); status != ncclSuccess) {
+                    printf("=============================== NCCL initialization of pair %d failed with status %d\n", ip, int(status));
+                    GGML_ABORT("Fatal error");
+                }
+            }
+            printf("=============================== NCCL pair communicators for %d GPUs initialized\n", info.device_count);
+        } else if (info.device_count == 3) {
+            int devs[4] = {0,1, 0,2};
+            for (int ip = 0; ip < 2; ++ip) {
+                if (auto status = ncclCommInitAll(com+2*ip, 2, devs+2*ip); status != ncclSuccess) {
+                    printf("=============================== NCCL initialization of pair %d failed with status %d\n", ip, int(status));
+                    GGML_ABORT("Fatal error");
+                }
+            }
+            printf("=============================== NCCL pair communicators for %d GPUs initialized\n", info.device_count);
         }
     }
 #endif
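The hunk above inverts the old structure: the all-GPU communicator is now always created first, and the pair communicators for the 3- and 4-GPU topologies are appended after it instead of occupying the front of the array. (The original hunk also declared `com` a second time inside the 4-GPU branch, shadowing the outer declaration; the duplicate is dropped here.) A minimal sketch of the resulting slot layout in info.nccl_coms (illustrative only, not part of the patch; device_count, devs and the offset arithmetic mirror the hunk, everything else is assumed):

    // Layout after initialization: slots [0, n) hold the all-GPU communicator,
    // one handle per device; slots [n, n + 2*npairs) hold the pair
    // communicators, two handles per pair, in the order of devs[].
    #include <cstdio>

    int main() {
        const int device_count = 4;                 // the 4-GPU case above
        const int devs[8] = {0,1, 2,3, 0,2, 1,3};   // pairing from the hunk
        const int npairs = 4;
        for (int i = 0; i < device_count; ++i) {
            printf("nccl_coms[%d]: main communicator, device %d\n", i, i);
        }
        // `com` in the hunk is info.nccl_coms + info.device_count:
        for (int ip = 0; ip < npairs; ++ip) {
            for (int j = 0; j < 2; ++j) {
                printf("nccl_coms[%d]: pair %d, device %d\n",
                       device_count + 2*ip + j, ip, devs[2*ip + j]);
            }
        }
        return 0;
    }

The 3-GPU case is identical with devs[4] = {0,1, 0,2} and two pairs.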
@@ -3512,8 +3501,23 @@ GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_
             needs_f16_f32_copy = true;
         } else {
+#ifdef GGML_USE_NCCL__
+            auto & info = ggml_cuda_info();
+            auto nbytes = ggml_nbytes(src);
+            ncclGroupStart();
+            ggml_cuda_set_device(cuda_ctx_src->device);
+            auto status1 = ncclSend(src->data, nbytes, ncclUint8, cuda_ctx_dst->device, info.nccl_coms[cuda_ctx_src->device],
+                    info.all_ctx[cuda_ctx_src->device]->stream());
+            ggml_cuda_set_device(cuda_ctx_dst->device);
+            auto status2 = ncclRecv(dst->data, nbytes, ncclUint8, cuda_ctx_src->device, info.nccl_coms[cuda_ctx_dst->device],
+                    info.all_ctx[cuda_ctx_dst->device]->stream());
+            ncclGroupEnd();
+            GGML_ASSERT(status1 == ncclSuccess && status2 == ncclSuccess);
+            return true;
+#else
             ggml_cuda_set_device(cuda_ctx_src->device);
             CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst),
                 cuda_ctx_src->stream()));
+#endif
         }
 #endif
     }
@@ -4435,6 +4439,13 @@ GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device, [[maybe_unused]] con
 #endif
     }
 
+#ifdef GGML_USE_NCCL
+    if (!enable_p2p) {
+        printf("================== P2P disabled, but needed for NCCL\n");
+        enable_p2p = true;
+    }
+#endif
+
 #if !defined(GGML_CUDA_NO_PEER_COPY)
     if (enable_p2p) {
         ggml_cuda_set_peer_access(device);
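Two notes on the hunks above. First, the new ncclSend/ncclRecv copy path is guarded by GGML_USE_NCCL__ (note the trailing underscores), which is presumably never defined, so the cudaMemcpyPeerAsync path remains the one that is compiled; the NCCL route appears to be kept around for experimentation. Second, P2P is now forced on whenever NCCL is enabled, since the NCCL communicators rely on peer access between the devices.

The printf-plus-GGML_ABORT pattern after every NCCL call recurs throughout the patch; a checked wrapper in the spirit of the existing CUDA_CHECK would shrink it considerably. A hypothetical helper (not part of the patch; the name and exact behavior are assumptions):

    // NCCL_CHECK: evaluate an NCCL call once and abort with a readable
    // message on failure, mirroring the CUDA_CHECK convention.
    #include <nccl.h>
    #include <cstdio>
    #include <cstdlib>

    #define NCCL_CHECK(call)                                             \
        do {                                                             \
            ncclResult_t status_ = (call);                               \
            if (status_ != ncclSuccess) {                                \
                fprintf(stderr, "NCCL error %d (%s) at %s:%d\n",         \
                        (int)status_, ncclGetErrorString(status_),       \
                        __FILE__, __LINE__);                             \
                exit(1);                                                 \
            }                                                            \
        } while (0)

    // Usage: NCCL_CHECK(ncclCommInitAll(com + 2*ip, 2, devs + 2*ip));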
diff --git a/ggml/src/ggml-cuda/reduce.cu b/ggml/src/ggml-cuda/reduce.cu
index 5a7f5725..307b52fa 100644
--- a/ggml/src/ggml-cuda/reduce.cu
+++ b/ggml/src/ggml-cuda/reduce.cu
@@ -38,114 +38,84 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
         GGML_ABORT("Not implemented");
     }
     //auto tim1 = std::chrono::steady_clock::now();
-    if (nreduce == 4) {
-        auto data_type = type == GGML_TYPE_F32 ? ncclFloat : ncclHalf;
-        if (dst->ne[1] > 32) {
-            static const int devs[8] = {0,1, 2,3, 0,2, 1,3};
-            for (int ip = 0; ip < 4; ++ip) {
-                ncclGroupStart();
-                ggml_cuda_set_device(devs[2*ip+0]);
-                auto status1 = ncclAllReduce(dst->src[devs[2*ip+0]]->data, dst->src[devs[2*ip+0]]->data,
-                    ggml_nelements(dst), data_type, ncclSum, info.nccl_coms[2*ip+0], info.all_ctx[devs[2*ip+0]]->stream());
-                ggml_cuda_set_device(devs[2*ip+1]);
-                auto status2 = ncclAllReduce(dst->src[devs[2*ip+1]]->data, dst->src[devs[2*ip+1]]->data,
-                    ggml_nelements(dst), data_type, ncclSum, info.nccl_coms[2*ip+1], info.all_ctx[devs[2*ip+1]]->stream());
-                ncclGroupEnd();
-                if (status1 != ncclSuccess || status2 != ncclSuccess) {
-                    fprintf(stderr, "%s: ncclAllReduce failed with statuses %d, %d\n", __func__, (int)status1, (int)status2);
-                    GGML_ABORT("Fatal error");
-                }
-            }
-        } else {
+    auto data_type = type == GGML_TYPE_F32 ? ncclFloat : ncclHalf;
+    if (nreduce == 4 && dst->ne[1] > 32) {
+        auto com = info.nccl_coms + info.device_count;
+        static const int devs[8] = {0,1, 2,3, 0,2, 1,3};
+        for (int ip = 0; ip < 4; ++ip) {
             ncclGroupStart();
-            for (int i = 0; i < nreduce; ++i) {
-                ggml_cuda_set_device(i);
-                auto stream = info.all_ctx[i]->stream();
-                GGML_ASSERT(stream);
-                auto status = ncclAllReduce(dst->src[i]->data, dst->src[i]->data, ggml_nelements(dst), data_type, ncclSum,
-                    info.nccl_coms[8+i], stream);
-                if (status != ncclSuccess) {
-                    fprintf(stderr, "%s: ncclAllReduce on device %d failed with status %d\n", __func__, i, (int)status);
-                    GGML_ABORT("Fatal error");
-                }
-            }
-            ncclGroupEnd();
-        }
-    }
-    else if (nreduce == 3) {
-        auto data_type = type == GGML_TYPE_F32 ? ncclFloat : ncclHalf;
-        if (dst->ne[1] > 32) {
-            static const int devs[4] = {0,1, 0,2};
-            for (int ip = 0; ip < 2; ++ip) {
-                ncclGroupStart();
-                ggml_cuda_set_device(devs[2*ip+0]);
-                auto status1 = ncclAllReduce(dst->src[devs[2*ip+0]]->data, dst->src[devs[2*ip+0]]->data,
-                    ggml_nelements(dst), data_type, ncclSum, info.nccl_coms[2*ip+0], info.all_ctx[devs[2*ip+0]]->stream());
-                ggml_cuda_set_device(devs[2*ip+1]);
-                auto status2 = ncclAllReduce(dst->src[devs[2*ip+1]]->data, dst->src[devs[2*ip+1]]->data,
-                    ggml_nelements(dst), data_type, ncclSum, info.nccl_coms[2*ip+1], info.all_ctx[devs[2*ip+1]]->stream());
-                ncclGroupEnd();
-                if (status1 != ncclSuccess || status2 != ncclSuccess) {
-                    fprintf(stderr, "%s: ncclAllReduce failed with statuses %d, %d\n", __func__, (int)status1, (int)status2);
-                    GGML_ABORT("Fatal error");
-                }
-            }
-            ncclGroupStart();
-            ggml_cuda_set_device(0);
-            auto status1 = ncclSend(dst->src[0]->data, ggml_nelements(dst), data_type, 1, info.nccl_coms[0], info.all_ctx[0]->stream());
-            ggml_cuda_set_device(1);
-            auto status2 = ncclRecv(dst->src[1]->data, ggml_nelements(dst), data_type, 0, info.nccl_coms[1], info.all_ctx[1]->stream());
+            ggml_cuda_set_device(devs[2*ip+0]);
+            auto status1 = ncclAllReduce(dst->src[devs[2*ip+0]]->data, dst->src[devs[2*ip+0]]->data,
+                ggml_nelements(dst), data_type, ncclSum, com[2*ip+0], info.all_ctx[devs[2*ip+0]]->stream());
+            ggml_cuda_set_device(devs[2*ip+1]);
+            auto status2 = ncclAllReduce(dst->src[devs[2*ip+1]]->data, dst->src[devs[2*ip+1]]->data,
+                ggml_nelements(dst), data_type, ncclSum, com[2*ip+1], info.all_ctx[devs[2*ip+1]]->stream());
             ncclGroupEnd();
             if (status1 != ncclSuccess || status2 != ncclSuccess) {
-                fprintf(stderr, "%s: ncclSend/Recv failed with statuses %d, %d\n", __func__, (int)status1, (int)status2);
+                fprintf(stderr, "%s: ncclAllReduce failed with statuses %d, %d\n", __func__, (int)status1, (int)status2);
                 GGML_ABORT("Fatal error");
             }
-        } else {
+        }
+    }
+    else if (nreduce == 3 && dst->ne[1] > 32) {
+        auto com = info.nccl_coms + info.device_count;
+        static const int devs[4] = {0,1, 0,2};
+        for (int ip = 0; ip < 2; ++ip) {
             ncclGroupStart();
-            for (int i = 0; i < nreduce; ++i) {
-                ggml_cuda_set_device(i);
-                auto stream = info.all_ctx[i]->stream();
-                GGML_ASSERT(stream);
-                auto status = ncclAllReduce(dst->src[i]->data, dst->src[i]->data, ggml_nelements(dst), data_type, ncclSum,
-                    info.nccl_coms[4+i], stream);
-                if (status != ncclSuccess) {
-                    fprintf(stderr, "%s: ncclAllReduce on device %d failed with status %d\n", __func__, i, (int)status);
-                    GGML_ABORT("Fatal error");
-                }
-            }
+            ggml_cuda_set_device(devs[2*ip+0]);
+            auto status1 = ncclAllReduce(dst->src[devs[2*ip+0]]->data, dst->src[devs[2*ip+0]]->data,
+                ggml_nelements(dst), data_type, ncclSum, com[2*ip+0], info.all_ctx[devs[2*ip+0]]->stream());
+            ggml_cuda_set_device(devs[2*ip+1]);
+            auto status2 = ncclAllReduce(dst->src[devs[2*ip+1]]->data, dst->src[devs[2*ip+1]]->data,
+                ggml_nelements(dst), data_type, ncclSum, com[2*ip+1], info.all_ctx[devs[2*ip+1]]->stream());
             ncclGroupEnd();
+            if (status1 != ncclSuccess || status2 != ncclSuccess) {
+                fprintf(stderr, "%s: ncclAllReduce failed with statuses %d, %d\n", __func__, (int)status1, (int)status2);
+                GGML_ABORT("Fatal error");
+            }
         }
-    } else {
-        ncclGroupStart();
-        for (int i = 0; i < nreduce; ++i) {
-            ncclComm_t this_comm;
-            if (nhave == nreduce) {
-                this_comm = info.nccl_coms[i];
-            } else {
-                auto status = ncclCommSplit(info.nccl_coms[i], dst->src[i] ? 1 : 0, i, &this_comm, NULL);
-                GGML_ASSERT(status == ncclSuccess);
-            }
-            ggml_cuda_set_device(i);
-            auto stream = info.all_ctx[i]->stream();
-            GGML_ASSERT(stream);
-            ncclResult_t status;
-            if (type == GGML_TYPE_F32) {
-                status = ncclAllReduce(dst->src[i] ? dst->src[i]->data : nullptr,
-                    dst->src[i] ? dst->src[i]->data : nullptr,
-                    ggml_nelements(dst),
-                    ncclFloat, ncclSum, this_comm, stream);
-            } else {
-                status = ncclAllReduce(dst->src[i] ? dst->src[i]->data : nullptr,
-                    dst->src[i] ? dst->src[i]->data : nullptr,
-                    ggml_nelements(dst),
-                    ncclHalf, ncclSum, this_comm, stream);
-            }
-            if (status != ncclSuccess) {
-                fprintf(stderr, "%s: ncclAllReduce failed with status %d\n", __func__, (int)status);
+        ncclGroupStart();
+        ggml_cuda_set_device(0);
+        auto status1 = ncclSend(dst->src[0]->data, ggml_nelements(dst), data_type, 1, info.nccl_coms[0], info.all_ctx[0]->stream());
+        ggml_cuda_set_device(1);
+        auto status2 = ncclRecv(dst->src[1]->data, ggml_nelements(dst), data_type, 0, info.nccl_coms[1], info.all_ctx[1]->stream());
+        ncclGroupEnd();
+        if (status1 != ncclSuccess || status2 != ncclSuccess) {
+            fprintf(stderr, "%s: ncclSend/Recv failed with statuses %d, %d\n", __func__, (int)status1, (int)status2);
                 GGML_ABORT("Fatal error");
             }
         }
-        ncclGroupEnd();
+    else {
+        ncclGroupStart();
+        for (int i = 0; i < nreduce; ++i) {
+            ncclComm_t this_comm;
+            if (nhave == nreduce) {
+                this_comm = info.nccl_coms[i];
+            } else {
+                auto status = ncclCommSplit(info.nccl_coms[i], dst->src[i] ? 1 : 0, i, &this_comm, NULL);
+                GGML_ASSERT(status == ncclSuccess);
+            }
+            ggml_cuda_set_device(i);
+            auto stream = info.all_ctx[i]->stream();
+            GGML_ASSERT(stream);
+            ncclResult_t status;
+            if (type == GGML_TYPE_F32) {
+                status = ncclAllReduce(dst->src[i] ? dst->src[i]->data : nullptr,
+                    dst->src[i] ? dst->src[i]->data : nullptr,
+                    ggml_nelements(dst),
+                    ncclFloat, ncclSum, this_comm, stream);
+            } else {
+                status = ncclAllReduce(dst->src[i] ? dst->src[i]->data : nullptr,
+                    dst->src[i] ? dst->src[i]->data : nullptr,
+                    ggml_nelements(dst),
+                    ncclHalf, ncclSum, this_comm, stream);
+            }
+            if (status != ncclSuccess) {
+                fprintf(stderr, "%s: ncclAllReduce failed with status %d\n", __func__, (int)status);
+                GGML_ABORT("Fatal error");
+            }
+        }
+        ncclGroupEnd();
     }
     ggml_cuda_set_device(ctx.device);
     //auto tim2 = std::chrono::steady_clock::now();
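For reference, the pairwise schedule in reduce.cu is a complete all-reduce. With 4 GPUs it is a two-stage butterfly: stage one reduces within the disjoint pairs {0,1} and {2,3}, stage two within {0,2} and {1,3}; because the operations on each device are issued in order on that device's stream, stage two sees stage one's partial sums and every device finishes with the full total. With 3 GPUs the pairs {0,1} and {0,2} serialize on device 0's stream, so devices 0 and 2 end with the full sum while device 1 still holds only the first partial sum, which is why that branch finishes with an ncclSend/ncclRecv from device 0 to device 1. A small host-side simulation of the 4-GPU case (illustrative only; plain arithmetic stands in for the NCCL calls):

    #include <cstdio>

    int main() {
        double v[4] = {1.0, 2.0, 3.0, 4.0};         // per-device partial sums
        const int devs[8] = {0,1, 2,3, 0,2, 1,3};   // pairing from reduce.cu
        for (int ip = 0; ip < 4; ++ip) {
            int a = devs[2*ip + 0], b = devs[2*ip + 1];
            double s = v[a] + v[b];                 // one pairwise all-reduce
            v[a] = v[b] = s;
        }
        for (int i = 0; i < 4; ++i) {
            printf("device %d: %g\n", i, v[i]);     // every device prints 10
        }
        return 0;
    }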