Add cudaStreamSynchronize() at the end of fused up gate

That way we will know if this is what fails.
This commit is contained in:
Iwan Kawrakow
2025-05-22 13:22:52 +03:00
parent 2ca7e29d7d
commit 11e674af0e

View File

@@ -2796,6 +2796,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
}
}
CUDA_CHECK(cudaStreamSynchronize(stream));
return fuse_down;
}
@@ -3063,10 +3065,10 @@ GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_
CUDA_CHECK(cudaGetDevice(&cur_device));
if (backend_src != backend_dst) {
if (cuda_ctx_src->device != cur_device) {
GGML_CUDA_LOG_WARN("%s: attempt to copy on device %d while current device is %d\n", __func__, cuda_ctx_src->device, cur_device);
CUDA_CHECK(cudaSetDevice(cuda_ctx_src->device));
}
//if (cuda_ctx_src->device != cur_device) {
// GGML_CUDA_LOG_WARN("%s: attempt to copy on device %d while current device is %d\n", __func__, cuda_ctx_src->device, cur_device);
// CUDA_CHECK(cudaSetDevice(cuda_ctx_src->device));
//}
// copy on src stream
if (cuda_ctx_src->device == cuda_ctx_dst->device) {
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
@@ -3097,10 +3099,10 @@ GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_
CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx_dst->stream(), cuda_ctx_src->copy_event, 0));
} else {
// src and dst are on the same backend
if (cuda_ctx_src->device != cur_device) {
GGML_CUDA_LOG_WARN("%s: attempt to copy on device %d while current device is %d\n", __func__, cuda_ctx_src->device, cur_device);
CUDA_CHECK(cudaSetDevice(cuda_ctx_src->device));
}
//if (cuda_ctx_src->device != cur_device) {
// GGML_CUDA_LOG_WARN("%s: attempt to copy on device %d while current device is %d\n", __func__, cuda_ctx_src->device, cur_device);
// CUDA_CHECK(cudaSetDevice(cuda_ctx_src->device));
//}
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
}
return true;
@@ -3112,12 +3114,12 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend, cons
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
auto stream = cuda_ctx->stream();
GGML_ASSERT(stream);
int cur_device;
(void)cudaGetDevice(&cur_device);
if (cur_device != cuda_ctx->device) {
GGML_CUDA_LOG_WARN("%s: curent device is %d, context device is %d\n", __func__, cur_device, cuda_ctx->device);
CUDA_CHECK(cudaSetDevice(cuda_ctx->device));
}
//int cur_device;
//(void)cudaGetDevice(&cur_device);
//if (cur_device != cuda_ctx->device) {
// GGML_CUDA_LOG_WARN("%s: curent device is %d, context device is %d\n", __func__, cur_device, cuda_ctx->device);
// CUDA_CHECK(cudaSetDevice(cuda_ctx->device));
//}
auto err = cudaStreamSynchronize(stream);
if (err != cudaSuccess) {
@@ -3127,10 +3129,10 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend, cons
}
//CUDA_CHECK(cudaStreamSynchronize(stream));
if (cur_device != cuda_ctx->device) {
GGML_CUDA_LOG_WARN("%s: reverting device to %d\n", __func__, cur_device);
CUDA_CHECK(cudaSetDevice(cur_device));
}
//if (cur_device != cuda_ctx->device) {
// GGML_CUDA_LOG_WARN("%s: reverting device to %d\n", __func__, cur_device);
// CUDA_CHECK(cudaSetDevice(cur_device));
//}
GGML_UNUSED(backend);
}
@@ -3812,18 +3814,18 @@ GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) {
/* .context = */ ctx
};
#ifndef GGML_CUDA_NO_PEER_COPY
if (num_devices > 1) {
CUDA_CHECK(cudaSetDevice(device));
for (int i = 0; i < num_devices; ++i) {
if (i == device) continue;
cudaError_t err = cudaDeviceEnablePeerAccess(i, 0);
if (err != cudaSuccess && err != cudaErrorPeerAccessAlreadyEnabled) {
GGML_CUDA_LOG_ERROR("Failed to enable peer access from %d to %d: %s", device, i, cudaGetErrorString(err));
}
}
}
#endif
//#ifndef GGML_CUDA_NO_PEER_COPY
// if (num_devices > 1) {
// CUDA_CHECK(cudaSetDevice(device));
// for (int i = 0; i < num_devices; ++i) {
// if (i == device) continue;
// cudaError_t err = cudaDeviceEnablePeerAccess(i, 0);
// if (err != cudaSuccess && err != cudaErrorPeerAccessAlreadyEnabled) {
// GGML_CUDA_LOG_ERROR("Failed to enable peer access from %d to %d: %s", device, i, cudaGetErrorString(err));
// }
// }
// }
//#endif
return cuda_backend;
}