mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-03 02:20:01 +00:00
Explicitly set device
This commit is contained in:
@@ -44,8 +44,10 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
|
||||
static const int devs[8] = {0,1, 2,3, 0,2, 1,3};
|
||||
for (int ip = 0; ip < 4; ++ip) {
|
||||
ncclGroupStart();
|
||||
ggml_cuda_set_device(devs[2*ip+0]);
|
||||
auto status1 = ncclAllReduce(dst->src[devs[2*ip+0]]->data, dst->src[devs[2*ip+0]]->data,
|
||||
ggml_nelements(dst), data_type, ncclSum, info.nccl_coms[2*ip+0], info.all_ctx[devs[2*ip+0]]->stream());
|
||||
ggml_cuda_set_device(devs[2*ip+1]);
|
||||
auto status2 = ncclAllReduce(dst->src[devs[2*ip+1]]->data, dst->src[devs[2*ip+1]]->data,
|
||||
ggml_nelements(dst), data_type, ncclSum, info.nccl_coms[2*ip+1], info.all_ctx[devs[2*ip+1]]->stream());
|
||||
ncclGroupEnd();
|
||||
@@ -57,6 +59,7 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
|
||||
} else {
|
||||
ncclGroupStart();
|
||||
for (int i = 0; i < nreduce; ++i) {
|
||||
ggml_cuda_set_device(i);
|
||||
auto stream = info.all_ctx[i]->stream();
|
||||
GGML_ASSERT(stream);
|
||||
auto status = ncclAllReduce(dst->src[i]->data, dst->src[i]->data, ggml_nelements(dst), data_type, ncclSum,
|
||||
@@ -78,6 +81,7 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
|
||||
auto status = ncclCommSplit(info.nccl_coms[i], dst->src[i] ? 1 : 0, i, &this_comm, NULL);
|
||||
GGML_ASSERT(status == ncclSuccess);
|
||||
}
|
||||
ggml_cuda_set_device(i);
|
||||
auto stream = info.all_ctx[i]->stream();
|
||||
GGML_ASSERT(stream);
|
||||
ncclResult_t status;
|
||||
@@ -99,6 +103,7 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
|
||||
}
|
||||
ncclGroupEnd();
|
||||
}
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
//auto tim2 = std::chrono::steady_clock::now();
|
||||
//printf("%s: launched in %g us\n", __func__, 1e-3*std::chrono::duration_cast<std::chrono::nanoseconds>(tim2-tim1).count());
|
||||
return;
|
||||
|
||||
Reference in New Issue
Block a user