diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index c1034d8d..ba26db5b 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -843,6 +843,7 @@ struct ggml_backend_cuda_context {
     int device;
     std::string name;
     cudaEvent_t copy_event = nullptr;
+    cudaEvent_t compute_event = nullptr;
     bool p2p_enabled = false;
 
     cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
diff --git a/ggml/src/ggml-cuda/reduce.cu b/ggml/src/ggml-cuda/reduce.cu
index 81a48d88..bc82576b 100644
--- a/ggml/src/ggml-cuda/reduce.cu
+++ b/ggml/src/ggml-cuda/reduce.cu
@@ -196,7 +196,7 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
     // i = 0, peer = 1, ichunk = 1 -> copy part 1 from device 1, device 0 now has parts 0, 1, 2, 3
     // etc.
     //
-    if (false && dst->ne[1] >= 32) {
+    if (dst->ne[1] >= 32) {
         auto nelem = ggml_nelements(dst);
         auto elem_size = ggml_element_size(dst);
         auto nelem_per_device = (nelem + nhave - 1)/nhave;
@@ -204,17 +204,21 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
         for (int ii = 0; ii < nhave; ++ii) {
             int i = idx[ii];
             auto this_ctx = info.all_ctx[i];
-            if (!this_ctx->copy_event) {
+            if (!this_ctx->copy_event || !this_ctx->compute_event || required_size > this_ctx->copy_size) {
                 ggml_cuda_set_device(this_ctx->device);
-                CUDA_CHECK(cudaEventCreateWithFlags(&this_ctx->copy_event, cudaEventDisableTiming));
-            }
-            if (required_size > this_ctx->copy_size) {
-                ggml_cuda_set_device(this_ctx->device);
-                if (this_ctx->copy_buffer) {
-                    CUDA_CHECK(cudaFree(this_ctx->copy_buffer));
+                if (!this_ctx->copy_event) {
+                    CUDA_CHECK(cudaEventCreateWithFlags(&this_ctx->copy_event, cudaEventDisableTiming));
+                }
+                if (!this_ctx->compute_event) {
+                    CUDA_CHECK(cudaEventCreateWithFlags(&this_ctx->compute_event, cudaEventDisableTiming));
+                }
+                if (required_size > this_ctx->copy_size) {
+                    if (this_ctx->copy_buffer) {
+                        CUDA_CHECK(cudaFree(this_ctx->copy_buffer));
+                    }
+                    CUDA_CHECK(ggml_cuda_device_malloc(&this_ctx->copy_buffer, required_size, this_ctx->device));
+                    this_ctx->copy_size = required_size;
                 }
-                CUDA_CHECK(ggml_cuda_device_malloc(&this_ctx->copy_buffer, required_size, this_ctx->device));
-                this_ctx->copy_size = required_size;
             }
         }
         for (int stage = 0; stage < nhave-1; ++stage) {
@@ -224,10 +228,20 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
                 int peer = idx[(ii+1)%nhave];
                 auto this_nelem = std::min(nelem_per_device, nelem - ichunk*nelem_per_device);
                 ggml_cuda_set_device(info.all_ctx[peer]->device);
+                if (stage > 0) {
+                    CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[peer]->stream(), info.all_ctx[i]->compute_event, 0));
+                }
                 CUDA_CHECK(cudaMemcpyPeerAsync(info.all_ctx[i]->copy_buffer, info.all_ctx[i]->device,
                             (const char *)dst->src[peer]->data + ichunk*nelem_per_device*elem_size, info.all_ctx[peer]->device,
                             this_nelem*elem_size, info.all_ctx[peer]->stream()));
                 CUDA_CHECK(cudaEventRecord(info.all_ctx[peer]->copy_event, info.all_ctx[peer]->stream()));
+                ichunk = (ichunk + 1)%nhave;
+            }
+            ichunk = stage;
+            for (int ii = 0; ii < nhave; ++ii) {
+                int i = idx[ii];
+                int peer = idx[(ii+1)%nhave];
+                auto this_nelem = std::min(nelem_per_device, nelem - ichunk*nelem_per_device);
                 ggml_cuda_set_device(info.all_ctx[i]->device);
                 CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), info.all_ctx[peer]->copy_event, 0));
                 int num_blocks = (this_nelem + CUDA_REDUCE_BLOCK_SIZE - 1)/CUDA_REDUCE_BLOCK_SIZE;
@@ -238,6 +252,7 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
                     k_add<<<num_blocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(this_nelem,
                             (const float *)info.all_ctx[i]->copy_buffer, (float *)dst->src[i]->data + ichunk*nelem_per_device);
                 }
+                CUDA_CHECK(cudaEventRecord(info.all_ctx[i]->compute_event, info.all_ctx[i]->stream()));
                 ichunk = (ichunk + 1)%nhave;
             }
         }
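
Note on the synchronization added above: each device's add kernel reads from its copy_buffer on its own stream, compute_event is recorded right after that launch, and in the next ring stage the peer copy that would overwrite copy_buffer first waits on that event. The listing below is a minimal standalone sketch of this copy_event/compute_event handshake for two devices and one staging buffer. It is illustrative only, not the ggml code path: the buffer names, sizes, device ids 0/1, and the CHECK macro are assumptions made for the example, and it needs at least two visible CUDA devices.

    // Sketch of the copy_event / compute_event handshake (illustrative names).
    // Device 1 pushes its chunk into device 0's staging buffer, device 0 adds it
    // into its own data, and the next copy into the same staging buffer must not
    // start until that add has finished.
    #include <cuda_runtime.h>
    #include <cstdio>

    #define CHECK(call) do { cudaError_t err_ = (call); if (err_ != cudaSuccess) { \
        fprintf(stderr, "%s failed: %s\n", #call, cudaGetErrorString(err_)); return 1; } } while (0)

    __global__ void k_add(int n, const float * x, float * y) {
        int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i < n) y[i] += x[i];
    }

    int main() {
        int ndev = 0;
        CHECK(cudaGetDeviceCount(&ndev));
        if (ndev < 2) { fprintf(stderr, "needs at least 2 GPUs\n"); return 0; }

        const int n    = 1 << 20;
        const int nblk = (n + 255)/256;

        float *data0, *staging0, *data1;
        cudaStream_t s0, s1;
        cudaEvent_t copy_event, compute_event;

        // device 0: destination data, staging buffer, stream, compute_event
        CHECK(cudaSetDevice(0));
        CHECK(cudaMalloc(&data0, n*sizeof(float)));
        CHECK(cudaMalloc(&staging0, n*sizeof(float)));
        CHECK(cudaMemset(data0, 0, n*sizeof(float)));
        CHECK(cudaStreamCreate(&s0));
        CHECK(cudaEventCreateWithFlags(&compute_event, cudaEventDisableTiming));

        // device 1: source data, stream, copy_event
        CHECK(cudaSetDevice(1));
        CHECK(cudaMalloc(&data1, n*sizeof(float)));
        CHECK(cudaMemset(data1, 0, n*sizeof(float)));
        CHECK(cudaStreamCreate(&s1));
        CHECK(cudaEventCreateWithFlags(&copy_event, cudaEventDisableTiming));

        for (int stage = 0; stage < 2; ++stage) {
            // sender side: after stage 0, do not overwrite staging0 while
            // device 0 may still be reading it in its previous add
            CHECK(cudaSetDevice(1));
            if (stage > 0) {
                CHECK(cudaStreamWaitEvent(s1, compute_event, 0));
            }
            // peer copy issued on the sender's stream; completion marked by copy_event
            CHECK(cudaMemcpyPeerAsync(staging0, 0, data1, 1, n*sizeof(float), s1));
            CHECK(cudaEventRecord(copy_event, s1));

            // receiver side: reduce the chunk once the copy is done, then mark
            // the staging buffer as reusable with compute_event
            CHECK(cudaSetDevice(0));
            CHECK(cudaStreamWaitEvent(s0, copy_event, 0));
            k_add<<<nblk, 256, 0, s0>>>(n, staging0, data0);
            CHECK(cudaEventRecord(compute_event, s0));
        }

        CHECK(cudaSetDevice(0));
        CHECK(cudaDeviceSynchronize());
        CHECK(cudaSetDevice(1));
        CHECK(cudaDeviceSynchronize());
        return 0;
    }

Without the wait on compute_event before the second cudaMemcpyPeerAsync, the copy issued on device 1's stream could overwrite staging0 while device 0's k_add from the previous stage is still reading it, which is the ordering hazard the compute_event in this patch is meant to close.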