// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. #include #include #include #include #include #include #include #include #include "api.h" #include "debug.h" #include "socket.h" #include "unix_socket.hpp" #include "utils_internal.hpp" namespace mscclpp { static void setFilesLimit() { rlimit filesLimit; if (getrlimit(RLIMIT_NOFILE, &filesLimit) != 0) throw SysError("getrlimit failed", errno); filesLimit.rlim_cur = filesLimit.rlim_max; if (setrlimit(RLIMIT_NOFILE, &filesLimit) != 0) throw SysError("setrlimit failed", errno); } /* Socket Interface Selection type */ enum bootstrapInterface_t { findSubnetIf = -1, dontCareIf = -2 }; struct ExtInfo { int rank; int nRanks; SocketAddress extAddressListenRoot; SocketAddress extAddressListen; }; MSCCLPP_API_CPP void Bootstrap::groupBarrier(const std::vector& ranks) { int dummy = 0; for (auto rank : ranks) { if (rank != this->getRank()) { this->send(static_cast(&dummy), sizeof(dummy), rank, 0); } } for (auto rank : ranks) { if (rank != this->getRank()) { this->recv(static_cast(&dummy), sizeof(dummy), rank, 0); } } } MSCCLPP_API_CPP int Bootstrap::getNranksPerIpcDomain() const { return getNranksPerNode(); } MSCCLPP_API_CPP void Bootstrap::send(const std::vector& data, int peer, int tag) { size_t size = data.size(); send((void*)&size, sizeof(size_t), peer, tag); send((void*)data.data(), data.size(), peer, tag); } MSCCLPP_API_CPP void Bootstrap::recv(std::vector& data, int peer, int tag) { size_t size; recv((void*)&size, sizeof(size_t), peer, tag); data.resize(size); recv((void*)data.data(), data.size(), peer, tag); } struct UniqueIdInternal { uint64_t magic; union SocketAddress addr; }; static_assert(sizeof(UniqueIdInternal) <= sizeof(UniqueId), "UniqueIdInternal is too large to fit into UniqueId"); class TcpBootstrap::Impl { public: static UniqueId createUniqueId(); static UniqueId getUniqueId(const UniqueIdInternal& uniqueId); Impl(int rank, int nRanks); ~Impl(); void initialize(const UniqueId& uniqueId, int64_t timeoutSec); void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec); void establishConnections(int64_t timeoutSec); UniqueId getUniqueId() const; int getRank(); int getNranks(); int getNranksPerNode(); int getNranksPerIpcDomain(); void allGather(void* allData, int size); void broadcast(void* data, int size, int root); void send(void* data, int size, int peer, int tag); void recv(void* data, int size, int peer, int tag); void barrier(); void close(); private: UniqueIdInternal uniqueId_; int rank_; int nRanks_; int nRanksPerNode_; int nRanksPerIpcDomain_; bool netInitialized; std::unique_ptr listenSockRoot_; std::unique_ptr listenSock_; std::unique_ptr ringRecvSocket_; std::unique_ptr ringSendSocket_; std::vector peerCommAddresses_; std::vector barrierArr_; std::unique_ptr abortFlagStorage_; volatile uint32_t* abortFlag_; std::thread rootThread_; SocketAddress netIfAddr_; std::unordered_map, std::shared_ptr, PairHash> peerSendSockets_; std::unordered_map, std::shared_ptr, PairHash> peerRecvSockets_; void netSend(Socket* sock, const void* data, int size); void netRecv(Socket* sock, void* data, int size); std::shared_ptr getPeerSendSocket(int peer, int tag); std::shared_ptr getPeerRecvSocket(int peer, int tag); UnixSocketServer& unixSocketServer_; static void assignPortToUniqueId(UniqueIdInternal& uniqueId); static void netInit(std::string ipPortPair, std::string interface, SocketAddress& netIfAddr); void bootstrapCreateRoot(); void bootstrapRoot(); void getRemoteAddresses(Socket* listenSock, std::vector& rankAddresses, std::vector& rankAddressesRoot, int& rank); void sendHandleToPeer(int peer, const std::vector& rankAddresses, const std::vector& rankAddressesRoot); }; UniqueId TcpBootstrap::Impl::createUniqueId() { UniqueIdInternal uniqueId; SocketAddress netIfAddr; netInit("", "", netIfAddr); getRandomData(&uniqueId.magic, sizeof(uniqueId_.magic)); std::memcpy(&uniqueId.addr, &netIfAddr, sizeof(SocketAddress)); assignPortToUniqueId(uniqueId); return getUniqueId(uniqueId); } UniqueId TcpBootstrap::Impl::getUniqueId(const UniqueIdInternal& uniqueId) { UniqueId ret; std::memcpy(&ret, &uniqueId, sizeof(uniqueId)); return ret; } TcpBootstrap::Impl::Impl(int rank, int nRanks) : rank_(rank), nRanks_(nRanks), nRanksPerNode_(0), nRanksPerIpcDomain_(0), netInitialized(false), peerCommAddresses_(nRanks, SocketAddress()), barrierArr_(nRanks, 0), abortFlagStorage_(new uint32_t(0)), abortFlag_(abortFlagStorage_.get()), unixSocketServer_(UnixSocketServer::instance()) {} UniqueId TcpBootstrap::Impl::getUniqueId() const { return getUniqueId(uniqueId_); } int TcpBootstrap::Impl::getRank() { return rank_; } int TcpBootstrap::Impl::getNranks() { return nRanks_; } void TcpBootstrap::Impl::initialize(const UniqueId& uniqueId, int64_t timeoutSec) { if (!netInitialized) { netInit("", "", netIfAddr_); netInitialized = true; } std::memcpy(&uniqueId_, &uniqueId, sizeof(uniqueId_)); if (rank_ == 0) { bootstrapCreateRoot(); } char line[MAX_IF_NAME_SIZE + 1]; SocketToString(&uniqueId_.addr, line); INFO(MSCCLPP_INIT, "rank %d nranks %d - connecting to %s", rank_, nRanks_, line); establishConnections(timeoutSec); unixSocketServer_.start(); INFO(MSCCLPP_INIT, "rank %d - unix socket server started", rank_); } void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio, int64_t timeoutSec) { // first check if it is a trio int nColons = 0; for (auto c : ifIpPortTrio) { if (c == ':') { nColons++; } } std::string ipPortPair = ifIpPortTrio; std::string interface = ""; if (nColons == 2) { // we know the interface = ifIpPortTrio.substr(0, ipPortPair.find_first_of(':')); ipPortPair = ifIpPortTrio.substr(ipPortPair.find_first_of(':') + 1); } if (!netInitialized) { netInit(ipPortPair, interface, netIfAddr_); netInitialized = true; } uniqueId_.magic = 0xdeadbeef; std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(SocketAddress)); SocketGetAddrFromString(&uniqueId_.addr, ipPortPair.c_str()); if (rank_ == 0) { bootstrapCreateRoot(); } establishConnections(timeoutSec); unixSocketServer_.start(); INFO(MSCCLPP_INIT, "rank %d - unix socket server started", rank_); } TcpBootstrap::Impl::~Impl() { if (abortFlag_) { *abortFlag_ = 1; } if (rootThread_.joinable()) { rootThread_.join(); } } void TcpBootstrap::Impl::getRemoteAddresses(Socket* listenSock, std::vector& rankAddresses, std::vector& rankAddressesRoot, int& rank) { ExtInfo info; SocketAddress zero; std::memset(&zero, 0, sizeof(SocketAddress)); { Socket sock(nullptr, MSCCLPP_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_); sock.accept(listenSock); netRecv(&sock, &info, sizeof(info)); } if (this->nRanks_ != info.nRanks) { throw Error("Bootstrap Root : mismatch in rank count from procs " + std::to_string(this->nRanks_) + " : " + std::to_string(info.nRanks), ErrorCode::InternalError); } if (std::memcmp(&zero, &rankAddressesRoot[info.rank], sizeof(SocketAddress)) != 0) { throw Error("Bootstrap Root : rank " + std::to_string(info.rank) + " of " + std::to_string(this->nRanks_) + " has already checked in", ErrorCode::InternalError); } // Save the connection handle for that rank rankAddressesRoot[info.rank] = info.extAddressListenRoot; rankAddresses[info.rank] = info.extAddressListen; rank = info.rank; } void TcpBootstrap::Impl::sendHandleToPeer(int peer, const std::vector& rankAddresses, const std::vector& rankAddressesRoot) { int next = (peer + 1) % nRanks_; Socket sock(&rankAddressesRoot[peer], uniqueId_.magic, SocketTypeBootstrap, abortFlag_); sock.connect(); netSend(&sock, &rankAddresses[next], sizeof(SocketAddress)); } void TcpBootstrap::Impl::assignPortToUniqueId(UniqueIdInternal& uniqueId) { std::unique_ptr socket = std::make_unique(&uniqueId.addr, uniqueId.magic, SocketTypeBootstrap); socket->bind(); uniqueId.addr = socket->getAddr(); } void TcpBootstrap::Impl::bootstrapCreateRoot() { listenSockRoot_ = std::make_unique(&uniqueId_.addr, uniqueId_.magic, SocketTypeBootstrap, abortFlag_, 0); listenSockRoot_->bindAndListen(); uniqueId_.addr = listenSockRoot_->getAddr(); rootThread_ = std::thread([this]() { try { bootstrapRoot(); } catch (const std::exception& e) { if (abortFlag_ && *abortFlag_) return; throw e; } }); } void TcpBootstrap::Impl::bootstrapRoot() { int numCollected = 0; std::vector rankAddresses(nRanks_, SocketAddress()); // for initial rank <-> root information exchange std::vector rankAddressesRoot(nRanks_, SocketAddress()); std::memset(rankAddresses.data(), 0, sizeof(SocketAddress) * nRanks_); std::memset(rankAddressesRoot.data(), 0, sizeof(SocketAddress) * nRanks_); setFilesLimit(); TRACE(MSCCLPP_INIT, "BEGIN"); /* Receive addresses from all ranks */ do { int rank; getRemoteAddresses(listenSockRoot_.get(), rankAddresses, rankAddressesRoot, rank); ++numCollected; TRACE(MSCCLPP_INIT, "Received connect from rank %d total %d/%d", rank, numCollected, nRanks_); } while (numCollected < nRanks_ && (!abortFlag_ || *abortFlag_ == 0)); if (abortFlag_ && *abortFlag_) { TRACE(MSCCLPP_INIT, "ABORTED"); return; } TRACE(MSCCLPP_INIT, "COLLECTED ALL %d HANDLES", nRanks_); // Send the connect handle for the next rank in the AllGather ring for (int peer = 0; peer < nRanks_; ++peer) { sendHandleToPeer(peer, rankAddresses, rankAddressesRoot); } TRACE(MSCCLPP_INIT, "DONE"); } void TcpBootstrap::Impl::netInit(std::string ipPortPair, std::string interface, SocketAddress& netIfAddr) { char netIfName[MAX_IF_NAME_SIZE + 1]; if (!ipPortPair.empty()) { if (interface != "") { // we know the int ret = FindInterfaces(netIfName, &netIfAddr, MAX_IF_NAME_SIZE, 1, interface.c_str()); if (ret <= 0) throw Error("NET/Socket : No interface named " + interface + " found.", ErrorCode::InternalError); } else { // we do not know the try to match it next SocketAddress remoteAddr; SocketGetAddrFromString(&remoteAddr, ipPortPair.c_str()); if (FindInterfaceMatchSubnet(netIfName, &netIfAddr, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) { throw Error("NET/Socket : No usable listening interface found", ErrorCode::InternalError); } } } else { int ret = FindInterfaces(netIfName, &netIfAddr, MAX_IF_NAME_SIZE, 1); if (ret <= 0) { throw Error("TcpBootstrap : no socket interface found", ErrorCode::InternalError); } } char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2]; std::sprintf(line, " %s:", netIfName); SocketToString(&netIfAddr, line + strlen(line)); INFO(MSCCLPP_INIT, "TcpBootstrap : Using%s", line); } #define TIMEOUT(__exp) \ do { \ try { \ __exp; \ } catch (const Error& e) { \ if (e.getErrorCode() == ErrorCode::Timeout) { \ throw Error("TcpBootstrap connection timeout", ErrorCode::Timeout); \ } \ throw; \ } \ } while (0); void TcpBootstrap::Impl::establishConnections(int64_t timeoutSec) { const int64_t connectionTimeoutUs = timeoutSec * 1000000; Timer timer; SocketAddress nextAddr; ExtInfo info; TRACE(MSCCLPP_INIT, "rank %d nranks %d", rank_, nRanks_); auto getLeftTime = [&]() { if (connectionTimeoutUs < 0) { // no timeout: always return a large number return int64_t(1e9); } int64_t timeout = connectionTimeoutUs - timer.elapsed(); if (timeout <= 0) throw Error("TcpBootstrap connection timeout", ErrorCode::Timeout); return timeout; }; info.rank = rank_; info.nRanks = nRanks_; uint64_t magic = uniqueId_.magic; // Create socket for other ranks to contact me listenSock_ = std::make_unique(&netIfAddr_, magic, SocketTypeBootstrap, abortFlag_); listenSock_->bindAndListen(); info.extAddressListen = listenSock_->getAddr(); { // Create socket for root to contact me Socket lsock(&netIfAddr_, magic, SocketTypeBootstrap, abortFlag_); lsock.bindAndListen(); info.extAddressListenRoot = lsock.getAddr(); // stagger connection times to avoid an overload of the root auto randomSleep = [](int rank) { timespec tv; tv.tv_sec = rank / 1000; tv.tv_nsec = 1000000 * (rank % 1000); TRACE(MSCCLPP_INIT, "rank %d delaying connection to root by %ld msec", rank, rank); (void)nanosleep(&tv, NULL); }; if (nRanks_ > 128) { randomSleep(rank_); } // send info on my listening socket to root { Socket sock(&uniqueId_.addr, magic, SocketTypeBootstrap, abortFlag_); TIMEOUT(sock.connect(getLeftTime())); netSend(&sock, &info, sizeof(info)); } // get info on my "next" rank in the bootstrap ring from root { Socket sock(nullptr, MSCCLPP_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_); TIMEOUT(sock.accept(&lsock, getLeftTime())); netRecv(&sock, &nextAddr, sizeof(SocketAddress)); } } ringSendSocket_ = std::make_unique(&nextAddr, magic, SocketTypeBootstrap, abortFlag_); TIMEOUT(ringSendSocket_->connect(getLeftTime())); // Accept the connect request from the previous rank in the AllGather ring ringRecvSocket_ = std::make_unique(nullptr, MSCCLPP_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_); TIMEOUT(ringRecvSocket_->accept(listenSock_.get(), getLeftTime())); // AllGather all listen handlers peerCommAddresses_[rank_] = listenSock_->getAddr(); allGather(peerCommAddresses_.data(), sizeof(SocketAddress)); TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank_, nRanks_); } int TcpBootstrap::Impl::getNranksPerNode() { if (nRanksPerNode_ > 0) return nRanksPerNode_; int nRanksPerNode = 0; bool useIpv4 = peerCommAddresses_[rank_].sa.sa_family == AF_INET; for (int i = 0; i < nRanks_; i++) { if (useIpv4) { if (peerCommAddresses_[i].sin.sin_addr.s_addr == peerCommAddresses_[rank_].sin.sin_addr.s_addr) { nRanksPerNode++; } } else { if (std::memcmp(&(peerCommAddresses_[i].sin6.sin6_addr), &(peerCommAddresses_[rank_].sin6.sin6_addr), sizeof(in6_addr)) == 0) { nRanksPerNode++; } } } nRanksPerNode_ = nRanksPerNode; return nRanksPerNode_; } int TcpBootstrap::Impl::getNranksPerIpcDomain() { if (nRanksPerIpcDomain_ > 0) return nRanksPerIpcDomain_; std::vector ipcDomainHashes(nRanks_); ipcDomainHashes[rank_] = getIpcDomainHash(); allGather(ipcDomainHashes.data(), sizeof(uint64_t)); int nRanksPerIpcDomain = 0; for (int i = 0; i < nRanks_; ++i) { if (ipcDomainHashes[i] == ipcDomainHashes[rank_]) { ++nRanksPerIpcDomain; } } INFO(MSCCLPP_INIT, "rank %d IPC domain fabric hash 0x%016llx nRanksPerIpcDomain %d", rank_, static_cast(ipcDomainHashes[rank_]), nRanksPerIpcDomain); nRanksPerIpcDomain_ = nRanksPerIpcDomain; return nRanksPerIpcDomain_; } void TcpBootstrap::Impl::allGather(void* allData, int size) { char* data = static_cast(allData); int rank = rank_; int nRanks = nRanks_; TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d", rank, nRanks, size); /* Simple ring based AllGather * At each step i receive data from (rank-i-1) from left * and send previous step's data from (rank-i) to right */ for (int i = 0; i < nRanks - 1; i++) { size_t rSlice = (rank - i - 1 + nRanks) % nRanks; size_t sSlice = (rank - i + nRanks) % nRanks; // Send slice to the right netSend(ringSendSocket_.get(), data + sSlice * size, size); // Recv slice from the left netRecv(ringRecvSocket_.get(), data + rSlice * size, size); } TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nRanks, size); } void TcpBootstrap::Impl::broadcast(void* data, int size, int root) { int rank = rank_; int nRanks = nRanks_; if (nRanks == 1) return; TRACE(MSCCLPP_INIT, "rank %d nranks %d root %d size %d", rank, nRanks, root, size); /*Ring-based broadcast, propagate data in nRanks-1 steps*/ // For ring-based broadcast, propagate data in nRanks-1 steps for (int step = 0; step < nRanks - 1; step++) { if (rank == root) { // Root sends data to next rank in first step if (step == 0) { netSend(ringSendSocket_.get(), data, size); } } else { // Calculate when this rank should receive data int receiveStep = (rank - root - 1 + nRanks) % nRanks; if (step == receiveStep) { // Receive from previous rank netRecv(ringRecvSocket_.get(), data, size); // Forward to next rank (if not last step) if (step < nRanks - 2) { netSend(ringSendSocket_.get(), data, size); } } } } TRACE(MSCCLPP_INIT, "rank %d nranks %d root %d size %d - DONE", rank, nRanks, root, size); } std::shared_ptr TcpBootstrap::Impl::getPeerSendSocket(int peer, int tag) { auto it = peerSendSockets_.find(std::make_pair(peer, tag)); if (it != peerSendSockets_.end()) { return it->second; } auto sock = std::make_shared(&peerCommAddresses_[peer], uniqueId_.magic, SocketTypeBootstrap, abortFlag_); sock->connect(); netSend(sock.get(), &rank_, sizeof(int)); netSend(sock.get(), &tag, sizeof(int)); peerSendSockets_[std::make_pair(peer, tag)] = sock; return sock; } std::shared_ptr TcpBootstrap::Impl::getPeerRecvSocket(int peer, int tag) { auto it = peerRecvSockets_.find(std::make_pair(peer, tag)); if (it != peerRecvSockets_.end()) { return it->second; } for (;;) { auto sock = std::make_shared(nullptr, MSCCLPP_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_); sock->accept(listenSock_.get()); int recvPeer, recvTag; netRecv(sock.get(), &recvPeer, sizeof(int)); netRecv(sock.get(), &recvTag, sizeof(int)); peerRecvSockets_[std::make_pair(recvPeer, recvTag)] = sock; if (recvPeer == peer && recvTag == tag) { return sock; } // TODO(chhwang): set an exit condition or timeout } } void TcpBootstrap::Impl::netSend(Socket* sock, const void* data, int size) { sock->send(&size, sizeof(int)); sock->send(const_cast(data), size); } void TcpBootstrap::Impl::netRecv(Socket* sock, void* data, int size) { int recvSize; sock->recv(&recvSize, sizeof(int)); if (recvSize > size) { std::stringstream ss; ss << "Message truncated : received " << recvSize << " bytes instead of " << size; throw Error(ss.str(), ErrorCode::InvalidUsage); } sock->recv(data, std::min(recvSize, size)); } void TcpBootstrap::Impl::send(void* data, int size, int peer, int tag) { auto sock = getPeerSendSocket(peer, tag); netSend(sock.get(), data, size); } void TcpBootstrap::Impl::recv(void* data, int size, int peer, int tag) { auto sock = getPeerRecvSocket(peer, tag); netRecv(sock.get(), data, size); } void TcpBootstrap::Impl::barrier() { allGather(barrierArr_.data(), sizeof(int)); } void TcpBootstrap::Impl::close() { listenSockRoot_.reset(nullptr); listenSock_.reset(nullptr); ringRecvSocket_.reset(nullptr); ringSendSocket_.reset(nullptr); peerSendSockets_.clear(); peerRecvSockets_.clear(); unixSocketServer_.stop(); UnixSocketClient::instance().reset(); } MSCCLPP_API_CPP UniqueId TcpBootstrap::createUniqueId() { return Impl::createUniqueId(); } MSCCLPP_API_CPP TcpBootstrap::TcpBootstrap(int rank, int nRanks) { pimpl_ = std::make_unique(rank, nRanks); } MSCCLPP_API_CPP UniqueId TcpBootstrap::getUniqueId() const { return pimpl_->getUniqueId(); } MSCCLPP_API_CPP int TcpBootstrap::getRank() const { return pimpl_->getRank(); } MSCCLPP_API_CPP int TcpBootstrap::getNranks() const { return pimpl_->getNranks(); } MSCCLPP_API_CPP int TcpBootstrap::getNranksPerNode() const { return pimpl_->getNranksPerNode(); } MSCCLPP_API_CPP int TcpBootstrap::getNranksPerIpcDomain() const { return pimpl_->getNranksPerIpcDomain(); } MSCCLPP_API_CPP void TcpBootstrap::send(void* data, int size, int peer, int tag) { pimpl_->send(data, size, peer, tag); } MSCCLPP_API_CPP void TcpBootstrap::recv(void* data, int size, int peer, int tag) { pimpl_->recv(data, size, peer, tag); } MSCCLPP_API_CPP void TcpBootstrap::allGather(void* allData, int size) { pimpl_->allGather(allData, size); } MSCCLPP_API_CPP void TcpBootstrap::broadcast(void* data, int size, int root) { pimpl_->broadcast(data, size, root); } MSCCLPP_API_CPP void TcpBootstrap::initialize(UniqueId uniqueId, int64_t timeoutSec) { pimpl_->initialize(uniqueId, timeoutSec); } MSCCLPP_API_CPP void TcpBootstrap::initialize(const std::string& ipPortPair, int64_t timeoutSec) { pimpl_->initialize(ipPortPair, timeoutSec); } MSCCLPP_API_CPP void TcpBootstrap::barrier() { pimpl_->barrier(); } MSCCLPP_API_CPP TcpBootstrap::~TcpBootstrap() { pimpl_->close(); } } // namespace mscclpp