Make sure we are on the correct device before synchronizing

This commit is contained in:
Iwan Kawrakow
2025-05-22 08:37:15 +03:00
parent 8b339b1453
commit e87bfe39ba

View File

@@ -3085,6 +3085,12 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend, cons
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
auto stream = cuda_ctx->stream();
GGML_ASSERT(stream);
int cur_device;
(void)cudaGetDevice(&cur_device);
if (cur_device != cuda_ctx->device) {
GGML_CUDA_LOG_WARN("%s: curent device is %d, context device is %d\n", __func__, cur_device, cuda_ctx->device);
CUDA_CHECK(cudaSetDevice(cuda_ctx->device));
}
auto err = cudaStreamSynchronize(stream);
if (err != cudaSuccess) {
@@ -3094,6 +3100,10 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend, cons
}
//CUDA_CHECK(cudaStreamSynchronize(stream));
if (cur_device != cuda_ctx->device) {
GGML_CUDA_LOG_WARN("%s: reverting device to %d\n", __func__, cur_device);
CUDA_CHECK(cudaSetDevice(cur_device));
}
GGML_UNUSED(backend);
}