// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. #include "communicator.hpp" #include "api.h" #include "debug.h" namespace mscclpp { Communicator::Impl::Impl(std::shared_ptr bootstrap, std::shared_ptr context) : bootstrap_(bootstrap) { if (!context) { context_ = std::make_shared(); } else { context_ = context; } } MSCCLPP_API_CPP Communicator::~Communicator() = default; MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr bootstrap, std::shared_ptr context) : pimpl_(std::make_unique(bootstrap, context)) {} MSCCLPP_API_CPP std::shared_ptr Communicator::bootstrap() { return pimpl_->bootstrap_; } MSCCLPP_API_CPP std::shared_ptr Communicator::context() { return pimpl_->context_; } MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports) { return context()->registerMemory(ptr, size, transports); } struct MemorySender : public Setuppable { MemorySender(RegisteredMemory memory, int remoteRank, int tag) : memory_(memory), remoteRank_(remoteRank), tag_(tag) {} void beginSetup(std::shared_ptr bootstrap) override { bootstrap->send(memory_.serialize(), remoteRank_, tag_); } RegisteredMemory memory_; int remoteRank_; int tag_; }; MSCCLPP_API_CPP void Communicator::sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag) { onSetup(std::make_shared(memory, remoteRank, tag)); } struct MemoryReceiver : public Setuppable { MemoryReceiver(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag) {} void endSetup(std::shared_ptr bootstrap) override { std::vector data; bootstrap->recv(data, remoteRank_, tag_); memoryPromise_.set_value(RegisteredMemory::deserialize(data)); } std::promise memoryPromise_; int remoteRank_; int tag_; }; MSCCLPP_API_CPP NonblockingFuture Communicator::recvMemoryOnSetup(int remoteRank, int tag) { auto memoryReceiver = std::make_shared(remoteRank, tag); onSetup(memoryReceiver); return NonblockingFuture(memoryReceiver->memoryPromise_.get_future()); } struct Communicator::Impl::Connector : public Setuppable { Connector(Communicator& comm, Communicator::Impl& commImpl_, int remoteRank, int tag, EndpointConfig localConfig) : comm_(comm), commImpl_(commImpl_), remoteRank_(remoteRank), tag_(tag), localEndpoint_(comm.context()->createEndpoint(localConfig)) {} void beginSetup(std::shared_ptr bootstrap) override { bootstrap->send(localEndpoint_.serialize(), remoteRank_, tag_); } void endSetup(std::shared_ptr bootstrap) override { std::vector data; bootstrap->recv(data, remoteRank_, tag_); auto remoteEndpoint = Endpoint::deserialize(data); auto connection = comm_.context()->connect(localEndpoint_, remoteEndpoint); commImpl_.connectionInfos_[connection.get()] = {remoteRank_, tag_}; connectionPromise_.set_value(connection); INFO(MSCCLPP_INIT, "Connection %d -> %d created (%s)", comm_.bootstrap()->getRank(), remoteRank_, connection->getTransportName().c_str()); } std::promise> connectionPromise_; Communicator& comm_; Communicator::Impl& commImpl_; int remoteRank_; int tag_; Endpoint localEndpoint_; }; MSCCLPP_API_CPP NonblockingFuture> Communicator::connectOnSetup( int remoteRank, int tag, EndpointConfig localConfig) { auto connector = std::make_shared(*this, *pimpl_, remoteRank, tag, localConfig); onSetup(connector); return NonblockingFuture>(connector->connectionPromise_.get_future()); } MSCCLPP_API_CPP std::shared_ptr Communicator::connctNvlsCollective(std::vector allRanks, EndpointConfig config) { auto bootstrap = this->bootstrap(); int rank = bootstrap->getRank(); bool isRoot = false; bool amongAllRanks = false; int rootRank = allRanks[0]; for (auto nvlsRank : allRanks) { if (nvlsRank == rank) amongAllRanks = true; rootRank = std::min(rootRank, nvlsRank); } if (amongAllRanks == false) { throw Error("rank is not among allRanks", ErrorCode::InvalidUsage); } if (rootRank == rank) isRoot = true; std::shared_ptr conn; if (isRoot) { conn = std::make_shared(config.nvlsBufferSize, allRanks.size()); auto serialized = conn->serialize(); for (auto nvlsRank : allRanks) { if (nvlsRank != rank) bootstrap->send(serialized, nvlsRank, 0); } } else { std::vector data; bootstrap->recv(data, rootRank, 0); conn = std::make_shared(data); } // Now let's synchronize all ranks bootstrap->groupBarrier(allRanks); // now it is safe to add my device conn->addDevice(); // sync here to make sure all ranks have added their devices bootstrap->groupBarrier(allRanks); return conn; } MSCCLPP_API_CPP int Communicator::remoteRankOf(const Connection& connection) { return pimpl_->connectionInfos_.at(&connection).remoteRank; } MSCCLPP_API_CPP int Communicator::tagOf(const Connection& connection) { return pimpl_->connectionInfos_.at(&connection).tag; } MSCCLPP_API_CPP void Communicator::onSetup(std::shared_ptr setuppable) { pimpl_->toSetup_.push_back(setuppable); } MSCCLPP_API_CPP void Communicator::setup() { for (auto& setuppable : pimpl_->toSetup_) { setuppable->beginSetup(pimpl_->bootstrap_); } for (auto& setuppable : pimpl_->toSetup_) { setuppable->endSetup(pimpl_->bootstrap_); } pimpl_->toSetup_.clear(); } } // namespace mscclpp