From 224b3deb84fb2977318c56acdb7131c4e9f49eeb Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 13 May 2026 01:22:51 +0000 Subject: [PATCH] Clean up completed communicator receives Erase completed receive bookkeeping from the communicator once the deferred receive future finishes, while preserving ordered receive chaining for repeated rank/tag operations. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/core/communicator.cc | 100 ++++++++++++++++++++++++--------------- 1 file changed, 63 insertions(+), 37 deletions(-) diff --git a/src/core/communicator.cc b/src/core/communicator.cc index c95ca421..97fadbbd 100644 --- a/src/core/communicator.cc +++ b/src/core/communicator.cc @@ -3,10 +3,60 @@ #include "communicator.hpp" +#include + #include "api.h" namespace mscclpp { +namespace { + +template +class ScopeGuard { + public: + explicit ScopeGuard(Fn fn) : fn_(std::move(fn)) {} + ScopeGuard(const ScopeGuard&) = delete; + ScopeGuard& operator=(const ScopeGuard&) = delete; + ~ScopeGuard() { fn_(); } + + private: + Fn fn_; +}; + +template +ScopeGuard makeScopeGuard(Fn fn) { + return ScopeGuard(std::move(fn)); +} + +template +std::shared_future makeOrderedRecvFuture(Impl* impl, int remoteRank, int tag, Fn fn) { + auto thisRecvItem = std::make_shared>(); + auto future = std::async(std::launch::deferred, [impl, remoteRank, tag, thisRecvItem, + lastRecvItem = impl->getLastRecvItem(remoteRank, tag), + fn = std::move(fn)]() mutable { + [[maybe_unused]] auto cleanup = makeScopeGuard([impl, remoteRank, tag, thisRecvItem]() { + auto item = thisRecvItem->lock(); + auto it = impl->lastRecvItems_.find({remoteRank, tag}); + if (item && it != impl->lastRecvItems_.end() && it->second == item) { + impl->lastRecvItems_.erase(it); + } + }); + + if (lastRecvItem) { + // Recursive call to the previous receive items + lastRecvItem->wait(); + } + return fn(); + }); + auto sharedFuture = std::shared_future(std::move(future)); + auto recvItem = std::make_shared>(sharedFuture); + *thisRecvItem = recvItem; + impl->setLastRecvItem(remoteRank, tag, recvItem); + return sharedFuture; +} + +} // namespace + Communicator::Impl::Impl(std::shared_ptr bootstrap, std::shared_ptr context) : bootstrap_(bootstrap) { if (!context) { @@ -83,19 +133,11 @@ MSCCLPP_API_CPP std::shared_future Communicator::recvMemory(in locRecvMemList.push_back(std::move(locRecvMem)); return future; } - auto future = std::async(std::launch::deferred, - [this, remoteRank, tag, lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag)]() { - if (lastRecvItem) { - // Recursive call to the previous receive items - lastRecvItem->wait(); - } - std::vector data; - bootstrap()->recv(data, remoteRank, tag); - return RegisteredMemory::deserialize(data); - }); - auto shared_future = std::shared_future(std::move(future)); - pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared>(shared_future)); - return shared_future; + return makeOrderedRecvFuture(pimpl_.get(), remoteRank, tag, [this, remoteRank, tag]() { + std::vector data; + bootstrap()->recv(data, remoteRank, tag); + return RegisteredMemory::deserialize(data); + }); } MSCCLPP_API_CPP std::shared_future Communicator::connect(const Endpoint& localEndpoint, int remoteRank, @@ -112,12 +154,8 @@ MSCCLPP_API_CPP std::shared_future Communicator::connect(const Endpo bootstrap()->send(localEndpoint.serialize(), remoteRank, tag); - auto future = std::async(std::launch::deferred, [this, remoteRank, tag, localEndpoint, - lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag)]() mutable { - if (lastRecvItem) { - // Recursive call to the previous receive items - lastRecvItem->wait(); - } + return makeOrderedRecvFuture(pimpl_.get(), remoteRank, tag, + [this, remoteRank, tag, localEndpoint]() mutable { std::vector data; bootstrap()->recv(data, remoteRank, tag); auto remoteEndpoint = Endpoint::deserialize(data); @@ -125,9 +163,6 @@ MSCCLPP_API_CPP std::shared_future Communicator::connect(const Endpo pimpl_->connectionInfos_[connection.impl_.get()] = {remoteRank, tag}; return connection; }); - auto shared_future = std::shared_future(std::move(future)); - pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared>(shared_future)); - return shared_future; } MSCCLPP_API_CPP std::shared_future Communicator::connect(const EndpointConfig& localConfig, int remoteRank, @@ -141,21 +176,12 @@ MSCCLPP_API_CPP std::shared_future Communicator::buildSemaphore(const SemaphoreStub localStub(connection); bootstrap()->send(localStub.serialize(), remoteRank, tag); - auto future = - std::async(std::launch::deferred, [this, remoteRank, tag, lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag), - localStub = localStub]() mutable { - if (lastRecvItem) { - // Recursive call to the previous receive items - lastRecvItem->wait(); - } - std::vector data; - bootstrap()->recv(data, remoteRank, tag); - auto remoteStub = SemaphoreStub::deserialize(data); - return Semaphore(localStub, remoteStub); - }); - auto shared_future = std::shared_future(std::move(future)); - pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared>(shared_future)); - return shared_future; + return makeOrderedRecvFuture(pimpl_.get(), remoteRank, tag, [this, remoteRank, tag, localStub]() mutable { + std::vector data; + bootstrap()->recv(data, remoteRank, tag); + auto remoteStub = SemaphoreStub::deserialize(data); + return Semaphore(localStub, remoteStub); + }); } MSCCLPP_API_CPP int Communicator::remoteRankOf(const Connection& connection) {