mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-21 13:29:45 +00:00
Clean up completed communicator receives
Erase completed receive bookkeeping from the communicator once the deferred receive future finishes, while preserving ordered receive chaining for repeated rank/tag operations. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -3,10 +3,60 @@
|
||||
|
||||
#include "communicator.hpp"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "api.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename Fn>
|
||||
class ScopeGuard {
|
||||
public:
|
||||
explicit ScopeGuard(Fn fn) : fn_(std::move(fn)) {}
|
||||
ScopeGuard(const ScopeGuard&) = delete;
|
||||
ScopeGuard& operator=(const ScopeGuard&) = delete;
|
||||
~ScopeGuard() { fn_(); }
|
||||
|
||||
private:
|
||||
Fn fn_;
|
||||
};
|
||||
|
||||
template <typename Fn>
|
||||
ScopeGuard<Fn> makeScopeGuard(Fn fn) {
|
||||
return ScopeGuard<Fn>(std::move(fn));
|
||||
}
|
||||
|
||||
template <typename T, typename Impl, typename Fn>
|
||||
std::shared_future<T> makeOrderedRecvFuture(Impl* impl, int remoteRank, int tag, Fn fn) {
|
||||
auto thisRecvItem = std::make_shared<std::weak_ptr<BaseRecvItem>>();
|
||||
auto future = std::async(std::launch::deferred, [impl, remoteRank, tag, thisRecvItem,
|
||||
lastRecvItem = impl->getLastRecvItem(remoteRank, tag),
|
||||
fn = std::move(fn)]() mutable {
|
||||
[[maybe_unused]] auto cleanup = makeScopeGuard([impl, remoteRank, tag, thisRecvItem]() {
|
||||
auto item = thisRecvItem->lock();
|
||||
auto it = impl->lastRecvItems_.find({remoteRank, tag});
|
||||
if (item && it != impl->lastRecvItems_.end() && it->second == item) {
|
||||
impl->lastRecvItems_.erase(it);
|
||||
}
|
||||
});
|
||||
|
||||
if (lastRecvItem) {
|
||||
// Recursive call to the previous receive items
|
||||
lastRecvItem->wait();
|
||||
}
|
||||
return fn();
|
||||
});
|
||||
auto sharedFuture = std::shared_future<T>(std::move(future));
|
||||
auto recvItem = std::make_shared<RecvItem<T>>(sharedFuture);
|
||||
*thisRecvItem = recvItem;
|
||||
impl->setLastRecvItem(remoteRank, tag, recvItem);
|
||||
return sharedFuture;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Communicator::Impl::Impl(std::shared_ptr<Bootstrap> bootstrap, std::shared_ptr<Context> context)
|
||||
: bootstrap_(bootstrap) {
|
||||
if (!context) {
|
||||
@@ -83,19 +133,11 @@ MSCCLPP_API_CPP std::shared_future<RegisteredMemory> Communicator::recvMemory(in
|
||||
locRecvMemList.push_back(std::move(locRecvMem));
|
||||
return future;
|
||||
}
|
||||
auto future = std::async(std::launch::deferred,
|
||||
[this, remoteRank, tag, lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag)]() {
|
||||
if (lastRecvItem) {
|
||||
// Recursive call to the previous receive items
|
||||
lastRecvItem->wait();
|
||||
}
|
||||
std::vector<char> data;
|
||||
bootstrap()->recv(data, remoteRank, tag);
|
||||
return RegisteredMemory::deserialize(data);
|
||||
});
|
||||
auto shared_future = std::shared_future<RegisteredMemory>(std::move(future));
|
||||
pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared<RecvItem<RegisteredMemory>>(shared_future));
|
||||
return shared_future;
|
||||
return makeOrderedRecvFuture<RegisteredMemory>(pimpl_.get(), remoteRank, tag, [this, remoteRank, tag]() {
|
||||
std::vector<char> data;
|
||||
bootstrap()->recv(data, remoteRank, tag);
|
||||
return RegisteredMemory::deserialize(data);
|
||||
});
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_future<Connection> Communicator::connect(const Endpoint& localEndpoint, int remoteRank,
|
||||
@@ -112,12 +154,8 @@ MSCCLPP_API_CPP std::shared_future<Connection> Communicator::connect(const Endpo
|
||||
|
||||
bootstrap()->send(localEndpoint.serialize(), remoteRank, tag);
|
||||
|
||||
auto future = std::async(std::launch::deferred, [this, remoteRank, tag, localEndpoint,
|
||||
lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag)]() mutable {
|
||||
if (lastRecvItem) {
|
||||
// Recursive call to the previous receive items
|
||||
lastRecvItem->wait();
|
||||
}
|
||||
return makeOrderedRecvFuture<Connection>(pimpl_.get(), remoteRank, tag,
|
||||
[this, remoteRank, tag, localEndpoint]() mutable {
|
||||
std::vector<char> data;
|
||||
bootstrap()->recv(data, remoteRank, tag);
|
||||
auto remoteEndpoint = Endpoint::deserialize(data);
|
||||
@@ -125,9 +163,6 @@ MSCCLPP_API_CPP std::shared_future<Connection> Communicator::connect(const Endpo
|
||||
pimpl_->connectionInfos_[connection.impl_.get()] = {remoteRank, tag};
|
||||
return connection;
|
||||
});
|
||||
auto shared_future = std::shared_future<Connection>(std::move(future));
|
||||
pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared<RecvItem<Connection>>(shared_future));
|
||||
return shared_future;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_future<Connection> Communicator::connect(const EndpointConfig& localConfig, int remoteRank,
|
||||
@@ -141,21 +176,12 @@ MSCCLPP_API_CPP std::shared_future<Semaphore> Communicator::buildSemaphore(const
|
||||
SemaphoreStub localStub(connection);
|
||||
bootstrap()->send(localStub.serialize(), remoteRank, tag);
|
||||
|
||||
auto future =
|
||||
std::async(std::launch::deferred, [this, remoteRank, tag, lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag),
|
||||
localStub = localStub]() mutable {
|
||||
if (lastRecvItem) {
|
||||
// Recursive call to the previous receive items
|
||||
lastRecvItem->wait();
|
||||
}
|
||||
std::vector<char> data;
|
||||
bootstrap()->recv(data, remoteRank, tag);
|
||||
auto remoteStub = SemaphoreStub::deserialize(data);
|
||||
return Semaphore(localStub, remoteStub);
|
||||
});
|
||||
auto shared_future = std::shared_future<Semaphore>(std::move(future));
|
||||
pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared<RecvItem<Semaphore>>(shared_future));
|
||||
return shared_future;
|
||||
return makeOrderedRecvFuture<Semaphore>(pimpl_.get(), remoteRank, tag, [this, remoteRank, tag, localStub]() mutable {
|
||||
std::vector<char> data;
|
||||
bootstrap()->recv(data, remoteRank, tag);
|
||||
auto remoteStub = SemaphoreStub::deserialize(data);
|
||||
return Semaphore(localStub, remoteStub);
|
||||
});
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int Communicator::remoteRankOf(const Connection& connection) {
|
||||
|
||||
Reference in New Issue
Block a user