// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. #include "communicator.hpp" #include #include "api.h" namespace mscclpp { namespace { template std::shared_future makeOrderedRecvFuture(Impl* impl, int remoteRank, int tag, Func func) { // Weak placeholder to avoid a reference cycle; updated with the real recvItem after the future is created. auto thisRecvItem = std::make_shared>(); auto future = std::async(std::launch::deferred, [impl, remoteRank, tag, thisRecvItem, lastRecvItem = impl->getLastRecvItem(remoteRank, tag), func = std::move(func)]() mutable { auto cleanup = [impl, remoteRank, tag, thisRecvItem]() { impl->clearLastRecvItemIfMatches(remoteRank, tag, thisRecvItem->lock()); }; if (lastRecvItem) { // Recursive call to the previous receive items lastRecvItem->wait(); } auto result = func(); cleanup(); return result; }); auto sharedFuture = std::shared_future(std::move(future)); auto recvItem = std::make_shared>(sharedFuture); *thisRecvItem = recvItem; impl->setLastRecvItem(remoteRank, tag, recvItem); return sharedFuture; } } // namespace Communicator::Impl::Impl(std::shared_ptr bootstrap, std::shared_ptr context) : bootstrap_(bootstrap) { if (!context) { context_ = Context::create(); } else { context_ = context; } } void Communicator::Impl::setLastRecvItem(int remoteRank, int tag, std::shared_ptr item) { lastRecvItems_[{remoteRank, tag}] = item; } std::shared_ptr Communicator::Impl::getLastRecvItem(int remoteRank, int tag) { auto it = lastRecvItems_.find({remoteRank, tag}); if (it == lastRecvItems_.end()) { return nullptr; } if (it->second->isReady()) { lastRecvItems_.erase(it); return nullptr; } return it->second; } void Communicator::Impl::clearLastRecvItemIfMatches(int remoteRank, int tag, const std::shared_ptr& expectedItem) { auto it = lastRecvItems_.find({remoteRank, tag}); if (it != lastRecvItems_.end() && it->second == expectedItem) { lastRecvItems_.erase(it); } } MSCCLPP_API_CPP Communicator::~Communicator() = default; 
// NOTE(review): in this view of the file the contents of angle brackets (template arguments such as the element
// type of std::vector or the value type of std::shared_future) appear to be missing. The code tokens below are
// left exactly as found; restore the bracketed parts from the original file before building.

// Constructs the communicator by forwarding both arguments to the Impl constructor.
MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr bootstrap, std::shared_ptr context)
    : pimpl_(std::make_unique(bootstrap, context)) {}

// Returns the bootstrap held by the implementation object.
MSCCLPP_API_CPP std::shared_ptr Communicator::bootstrap() { return pimpl_->bootstrap_; }

// Returns the context held by the implementation object.
MSCCLPP_API_CPP std::shared_ptr Communicator::context() { return pimpl_->context_; }

// Registers the memory region [ptr, ptr + size) by delegating to the context's registerMemory.
MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports) {
  return context()->registerMemory(ptr, size, transports);
}

// Sends `memory` to `remoteRank` under `tag`.
// When `remoteRank` is this process's own rank, nothing goes through the bootstrap: the memory is handed over
// through the per-tag list pimpl_->localRecvMemories_, which recvMemory() also uses.
// Otherwise the memory is serialized and sent through the bootstrap.
MSCCLPP_API_CPP void Communicator::sendMemory(RegisteredMemory memory, int remoteRank, int tag) {
  if (remoteRank == bootstrap()->getRank()) {
    // Sending memory to self
    auto& locRecvMemList = pimpl_->localRecvMemories_[tag];
    for (auto& locRecvMem : locRecvMemList) {
      if (!locRecvMem.isReady()) {
        // First entry that is not ready yet (recvMemory() appends such entries when it runs before the
        // matching send): fill it and stop.
        locRecvMem.set(std::move(memory));
        return;
      }
    }
    // No unfilled entry exists: append an already-filled one for a later recvMemory() to pick up.
    LocalRecvMemory locRecvMem;
    locRecvMem.set(std::move(memory));
    locRecvMemList.push_back(std::move(locRecvMem));
    return;
  }
  bootstrap()->send(memory.serialize(), remoteRank, tag);
}

// Returns a shared future for the memory sent by `remoteRank` under `tag`.
// When `remoteRank` is this process's own rank, the future comes from the per-tag list
// pimpl_->localRecvMemories_. Otherwise the receive is wrapped by makeOrderedRecvFuture.
MSCCLPP_API_CPP std::shared_future Communicator::recvMemory(int remoteRank, int tag) {
  if (remoteRank == bootstrap()->getRank()) {
    // Receiving memory from self
    auto& locRecvMemList = pimpl_->localRecvMemories_[tag];
    for (auto it = locRecvMemList.begin(); it != locRecvMemList.end(); ++it) {
      if (it->isReady()) {
        // First entry already filled by sendMemory(): take its future and remove the entry. The function
        // returns right after erase(), so the iterator is not used again.
        auto future = it->reference();
        locRecvMemList.erase(it);
        return future;
      }
    }
    // Nothing has been sent yet: append an unfilled entry for a later sendMemory() to fill, and return its future.
    LocalRecvMemory locRecvMem;
    auto future = locRecvMem.reference();
    locRecvMemList.push_back(std::move(locRecvMem));
    return future;
  }
  // NOTE(review): the lambda captures `this`, and makeOrderedRecvFuture in this file creates its task with
  // std::launch::deferred, so the bootstrap recv only happens when the future is waited on -- presumably the
  // Communicator must still be alive at that point; confirm against callers.
  return makeOrderedRecvFuture(pimpl_.get(), remoteRank, tag, [this, remoteRank, tag]() {
    std::vector data;
    bootstrap()->recv(data, remoteRank, tag);
    return RegisteredMemory::deserialize(data);
  });
}

// Establishes a connection from `localEndpoint` to `remoteRank` and returns a shared future for it.
// For the self case a second endpoint is created from the local endpoint's config, the two are connected
// immediately, and an already-fulfilled future is returned.
// Otherwise the serialized local endpoint is sent right away, and receiving the remote endpoint plus
// connecting is wrapped by makeOrderedRecvFuture.
// In both cases {remoteRank, tag} is recorded in pimpl_->connectionInfos_, keyed by connection.impl_.get().
MSCCLPP_API_CPP std::shared_future Communicator::connect(const Endpoint& localEndpoint, int remoteRank, int tag) {
  if (remoteRank == bootstrap()->getRank()) {
    // Connection to self
    auto remoteEndpoint = context()->createEndpoint(localEndpoint.config());
    auto connection = context()->connect(localEndpoint, remoteEndpoint);
    std::promise promise;
    promise.set_value(connection);
    pimpl_->connectionInfos_[connection.impl_.get()] = {remoteRank, tag};
    return std::shared_future(promise.get_future());
  }
  bootstrap()->send(localEndpoint.serialize(), remoteRank, tag);
  // The endpoint is captured by copy. NOTE(review): `this` is captured too, and connectionInfos_ is only
  // updated when the task runs, so remoteRankOf()/tagOf() have no entry for this connection before then.
  return makeOrderedRecvFuture(pimpl_.get(), remoteRank, tag, [this, remoteRank, tag, localEndpoint]() mutable {
    std::vector data;
    bootstrap()->recv(data, remoteRank, tag);
    auto remoteEndpoint = Endpoint::deserialize(data);
    auto connection = context()->connect(localEndpoint, remoteEndpoint);
    pimpl_->connectionInfos_[connection.impl_.get()] = {remoteRank, tag};
    return connection;
  });
}

// Convenience overload: creates a local endpoint from `localConfig` and forwards to the Endpoint overload.
MSCCLPP_API_CPP std::shared_future Communicator::connect(const EndpointConfig& localConfig, int remoteRank, int tag) {
  auto localEndpoint = context()->createEndpoint(localConfig);
  return connect(localEndpoint, remoteRank, tag);
}

// Builds a semaphore with `remoteRank`: a local stub is created from `connection` and sent immediately, and
// receiving the remote stub plus constructing the Semaphore is wrapped by makeOrderedRecvFuture.
// NOTE(review): unlike sendMemory/recvMemory/connect there is no self-rank branch here; the send and recv go
// through the bootstrap even when remoteRank is this process's own rank -- confirm the bootstrap supports that.
MSCCLPP_API_CPP std::shared_future Communicator::buildSemaphore(const Connection& connection, int remoteRank, int tag) {
  SemaphoreStub localStub(connection);
  bootstrap()->send(localStub.serialize(), remoteRank, tag);
  return makeOrderedRecvFuture(pimpl_.get(), remoteRank, tag, [this, remoteRank, tag, localStub]() mutable {
    std::vector data;
    bootstrap()->recv(data, remoteRank, tag);
    auto remoteStub = SemaphoreStub::deserialize(data);
    return Semaphore(localStub, remoteStub);
  });
}

// Returns the remote rank recorded for `connection` by connect().
// NOTE(review): looked up with .at(); presumably connectionInfos_ is a standard map, in which case an unknown
// connection raises std::out_of_range -- confirm against the declaration.
MSCCLPP_API_CPP int Communicator::remoteRankOf(const Connection& connection) {
  return pimpl_->connectionInfos_.at(connection.impl_.get()).remoteRank;
}

// Returns the tag recorded for `connection` by connect(); same .at() lookup as remoteRankOf().
MSCCLPP_API_CPP int Communicator::tagOf(const Connection& connection) {
  return pimpl_->connectionInfos_.at(connection.impl_.get()).tag;
}

}  // namespace mscclpp