Files
mscclpp/src/bootstrap/bootstrap.cc
2023-04-21 09:41:52 +00:00

1051 lines
35 KiB
C++

#include "bootstrap.h"
#include "utils.h"
#include <cstring>
#include <mutex>
#include <queue>
#include <thread>
#include <sys/resource.h>
#include <sys/types.h>
namespace {
/// Folds every byte of a bootstrap handle into a single 64-bit hash value.
uint64_t hashUniqueId(const mscclppBootstrapHandle& id)
{
  const char* cursor = (const char*)&id;
  const char* const end = cursor + sizeof(mscclppBootstrapHandle);
  uint64_t hash = 0xdeadbeef;
  while (cursor != end) {
    // xor-shift, multiply, then add the next byte (as a plain char)
    hash ^= hash >> 32;
    hash *= 0x8db3db47fa2994ad;
    hash += *cursor++;
  }
  return hash;
}
/// Raises the soft limit on open file descriptors (RLIMIT_NOFILE) to the
/// current hard limit.
mscclppResult_t setFilesLimit()
{
  struct rlimit nofile;
  SYSCHECK(getrlimit(RLIMIT_NOFILE, &nofile), "getrlimit");
  // soft limit := hard limit
  nofile.rlim_cur = nofile.rlim_max;
  SYSCHECK(setrlimit(RLIMIT_NOFILE, &nofile), "setrlimit");
  return mscclppSuccess;
}
} // namespace
/* Socket interface selection modes (negative sentinel values). */
enum bootstrapInterface_t
{
  dontCareIf = -2,
  findSubnetIf = -1
};
// Identifier of one bootstrap session, as returned by GetUniqueId().
struct MscclppBootstrap::UniqueId
{
// Magic number handed to every mscclppSocketInit call made for this session.
uint64_t magic;
// Address the root listens on and that ranks connect to when checking in.
union mscclppSocketAddress addr;
};
// Element type of Impl::unexpectedConnections_.
// NOTE(review): presumably a connection whose (peer, tag) did not match the
// pending receive; nothing in this file fills the queue yet - confirm.
struct unexpectedConn
{
int peer;
int tag;
struct mscclppSocket sock;
};
// Check-in message each rank sends to the bootstrap root.
struct extInfo
{
// Rank of the sender.
int rank;
// Total number of ranks the sender expects; the root rejects mismatches.
int nRanks;
// Address of the sender's temporary socket on which the root connects back.
union mscclppSocketAddress extAddressListenRoot;
// Address of the sender's long-lived listen socket; the root forwards it to
// the previous rank in the ring.
union mscclppSocketAddress extAddressListen;
};
// Private implementation behind MscclppBootstrap (pimpl).
class MscclppBootstrap::Impl
{
public:
// Selects the network interface, fills uniqueId_ and, on rank 0 without a
// pre-made handle, starts the root thread. Throws std::runtime_error on failure.
Impl(std::string ipPortPair, int rank, int nRanks, const mscclppBootstrapHandle handle);
// Joins the root thread if one was started.
~Impl();
// Checks in with the root, builds the ring and all-gathers peer addresses.
mscclppResult_t initialize();
// Ring all-gather of nRanks_ slices of `size` bytes each, in place in allData.
mscclppResult_t allGather(void* allData, int size);
// Opens a connection to `peer` and sends rank, tag and payload.
mscclppResult_t send(void* data, int size, int peer, int tag);
// Currently stubs that return mscclppSuccess without doing any work.
mscclppResult_t recv(void* data, int size, int peer, int tag);
mscclppResult_t barrier();
mscclppResult_t close();
// Session identifier; exposed through MscclppBootstrap::GetUniqueId().
MscclppBootstrap::UniqueId uniqueId_;
private:
int rank_;
int nRanks_;
// Socket other ranks connect to (ring predecessor).
mscclppSocket listenSock_;
// Ring connections to the previous / next rank, used by allGather().
mscclppSocket ringRecvSocket_;
mscclppSocket ringSendSocket_;
// Per-rank listen and proxy addresses, filled by allGather() in initialize().
std::vector<mscclppSocketAddress> peerCommAddresses_;
std::vector<mscclppSocketAddress> peerProxyAddresses_;
// NOTE(review): not used by any code visible in this file yet.
std::queue<unexpectedConn> unexpectedConnections_;
// Passed to socket initialisation; the constructor sets it to nullptr.
volatile uint32_t* abortFlag_;
// Runs bootstrapRoot() on rank 0.
std::thread rootThread_;
// Interface name and address chosen by netInit().
char netIfName_[MAX_IF_NAME_SIZE + 1];
union mscclppSocketAddress netIfAddr_;
// Length-prefixed send / receive on an already connected socket.
static mscclppResult_t netSend(mscclppSocket* sock, const void* data, int size);
static mscclppResult_t netRecv(mscclppSocket* sock, void* data, int size);
// Root side: collect every rank's addresses, then tell each rank its ring successor.
mscclppResult_t bootstrapRoot();
// Accepts one check-in on listenSock and records the sender's addresses.
mscclppResult_t getRemoteAddresses(mscclppSocket* listenSock, std::vector<mscclppSocketAddress>& rankAddresses,
std::vector<mscclppSocketAddress>& rankAddressesRoot, int& rank);
// Sends to `peer` the listen address of rank (peer + 1) % nRanks_.
mscclppResult_t sendHandleToPeer(int peer, const std::vector<mscclppSocketAddress>& rankAddresses,
const std::vector<mscclppSocketAddress>& rankAddressesRoot);
// Picks the local interface (matching ipPortPair's subnet when given).
mscclppResult_t netInit(std::string ipPortPair);
};
/// Builds the bootstrap state for one rank.
///
/// Picks the local network interface via netInit(), then fills uniqueId_:
///  - from `handle` when it is not all-zero (no root thread is started then);
///  - otherwise with magic 0xdeadbeef (ipPortPair given) or random data (no
///    ipPortPair), plus the selected interface address. In this case rank 0
///    also launches the root thread.
/// Throws std::runtime_error if interface selection or random data fails.
MscclppBootstrap::Impl::Impl(std::string ipPortPair, int rank, int nRanks, const mscclppBootstrapHandle handle)
  : rank_(rank), nRanks_(nRanks), peerCommAddresses_(nRanks, mscclppSocketAddress()),
    peerProxyAddresses_(nRanks, mscclppSocketAddress()), abortFlag_(nullptr)
{
  if (netInit(ipPortPair) != mscclppSuccess) {
    throw std::runtime_error("Failed to initialize network");
  }

  // A handle that differs from the all-zero pattern was supplied by the caller.
  const mscclppBootstrapHandle emptyHandle = {0};
  const bool handleProvided = memcmp(&handle, &emptyHandle, sizeof(mscclppBootstrapHandle)) != 0;
  if (handleProvided) {
    uniqueId_.magic = handle.magic;
    uniqueId_.addr = handle.addr;
    return;
  }

  if (ipPortPair.empty()) {
    if (getRandomData(&uniqueId_.magic, sizeof(uniqueId_.magic)) != mscclppSuccess) {
      throw std::runtime_error("getting random data failed");
    }
  } else {
    uniqueId_.magic = 0xdeadbeef;
  }

  std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(union mscclppSocketAddress));

  if (rank_ == 0) {
    rootThread_ = std::thread(&MscclppBootstrap::Impl::bootstrapRoot, this);
  }
}
/// Waits for the root thread (if one was started) before the object goes away.
MscclppBootstrap::Impl::~Impl()
{
  if (!rootThread_.joinable()) {
    return;
  }
  rootThread_.join();
}
/// Accepts one check-in connection on `listenSock` and records the sender's
/// addresses.
///
/// @param listenSock         Root listen socket to accept on.
/// @param rankAddresses      Per-rank listen addresses; slot info.rank is filled.
/// @param rankAddressesRoot  Per-rank "root contact" addresses; slot info.rank is
///                           filled. An all-zero slot means "not checked in yet".
/// @param rank               Set to the rank that checked in (only on success).
/// @return mscclppSuccess when a valid, first-time check-in was recorded;
///         the socket error if accept/recv/close failed; mscclppInternalError
///         when the message is inconsistent (wrong rank count, rank out of
///         range, duplicate check-in).
mscclppResult_t MscclppBootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock,
                                                           std::vector<mscclppSocketAddress>& rankAddresses,
                                                           std::vector<mscclppSocketAddress>& rankAddressesRoot,
                                                           int& rank)
{
  mscclppSocket sock;
  extInfo info;
  // netRecv accepts messages shorter than the buffer; start from values that
  // fail validation so a short message cannot leave these fields indeterminate.
  info.rank = -1;
  info.nRanks = -1;

  mscclppSocketAddress zero;
  std::memset(&zero, 0, sizeof(mscclppSocketAddress));

  mscclppResult_t res = mscclppSocketInit(&sock);
  if (res != mscclppSuccess) {
    WARN("Bootstrap Root : mscclppSocketInit failed");
    return res;
  }
  res = mscclppSocketAccept(&sock, listenSock);
  if (res != mscclppSuccess) {
    WARN("Bootstrap Root : mscclppSocketAccept failed");
    return res;
  }
  res = netRecv(&sock, &info, sizeof(info));
  if (res != mscclppSuccess) {
    WARN("Bootstrap Root : netRecv failed");
    // Do not leak the accepted connection on the error path.
    (void)mscclppSocketClose(&sock);
    return res;
  }
  res = mscclppSocketClose(&sock);
  if (res != mscclppSuccess) {
    WARN("Bootstrap Root : mscclppSocketClose failed");
    return res;
  }

  // Everything below validates data that came from the network. These are
  // failures: returning success here would make the caller count a bogus rank.
  if (this->nRanks_ != info.nRanks) {
    WARN("Bootstrap Root : mismatch in rank count from procs %d : %d", this->nRanks_, info.nRanks);
    return mscclppInternalError;
  }
  if (info.rank < 0 || info.rank >= this->nRanks_) {
    WARN("Bootstrap Root : received out-of-range rank %d (nRanks %d)", info.rank, this->nRanks_);
    return mscclppInternalError;
  }
  if (std::memcmp(&zero, &rankAddressesRoot[info.rank], sizeof(mscclppSocketAddress)) != 0) {
    WARN("Bootstrap Root : rank %d of %d ranks has already checked in", info.rank, this->nRanks_);
    return mscclppInternalError;
  }

  // Save the connection handle for that rank
  rankAddressesRoot[info.rank] = info.extAddressListenRoot;
  rankAddresses[info.rank] = info.extAddressListen;
  rank = info.rank;
  return mscclppSuccess;
}
/// Tells rank `peer` who its successor in the ring is.
///
/// Connects to the address `peer` registered for root contact
/// (rankAddressesRoot[peer]) and sends the listen address of rank
/// (peer + 1) % nRanks_.
///
/// @return mscclppSuccess, or the error of the first socket operation that
///         failed. The socket is closed on the failure paths too.
mscclppResult_t MscclppBootstrap::Impl::sendHandleToPeer(int peer,
                                                         const std::vector<mscclppSocketAddress>& rankAddresses,
                                                         const std::vector<mscclppSocketAddress>& rankAddressesRoot)
{
  mscclppSocket sock;
  const int next = (peer + 1) % this->nRanks_;

  mscclppResult_t res =
    mscclppSocketInit(&sock, &rankAddressesRoot[peer], this->uniqueId_.magic, mscclppSocketTypeBootstrap);
  if (res != mscclppSuccess) {
    WARN("Bootstrap Root : mscclppSocketInit failed");
    return res;
  }
  res = mscclppSocketConnect(&sock);
  if (res != mscclppSuccess) {
    WARN("Bootstrap Root : mscclppSocketConnect failed");
    (void)mscclppSocketClose(&sock); // release the descriptor instead of leaking it
    return res;
  }
  res = netSend(&sock, &rankAddresses[next], sizeof(mscclppSocketAddress));
  if (res != mscclppSuccess) {
    WARN("Bootstrap Root : netSend failed");
    (void)mscclppSocketClose(&sock); // release the descriptor instead of leaking it
    return res;
  }
  res = mscclppSocketClose(&sock);
  if (res != mscclppSuccess) {
    WARN("Bootstrap Root : mscclppSocketClose failed");
    return res;
  }
  return mscclppSuccess;
}
/// Root-side rendezvous, run on rank 0's root thread.
///
/// 1. Listens on uniqueId_.addr and collects one check-in per rank.
/// 2. Sends every rank the listen address of its successor in the ring.
///
/// The listen socket is closed before returning. If collecting the check-ins
/// fails, the send phase is skipped: the address tables are incomplete and
/// would only produce connections to zeroed addresses.
///
/// @return mscclppSuccess, or the first error encountered.
mscclppResult_t MscclppBootstrap::Impl::bootstrapRoot()
{
  mscclppResult_t res = mscclppSuccess;
  int numCollected = 0;
  std::vector<mscclppSocketAddress> rankAddresses(this->nRanks_, mscclppSocketAddress());
  // for initial rank <-> root information exchange
  std::vector<mscclppSocketAddress> rankAddressesRoot(this->nRanks_, mscclppSocketAddress());
  // getRemoteAddresses() detects "not checked in yet" with a byte-wise compare
  // against zero, so make sure every byte of every slot is zero.
  std::memset(rankAddresses.data(), 0, sizeof(mscclppSocketAddress) * this->nRanks_);
  std::memset(rankAddressesRoot.data(), 0, sizeof(mscclppSocketAddress) * this->nRanks_);

  // Best effort: the result is deliberately ignored.
  setFilesLimit();

  mscclppSocket listenSock;
  MSCCLPPCHECK(
    mscclppSocketInit(&listenSock, &uniqueId_.addr, uniqueId_.magic, mscclppSocketTypeBootstrap, nullptr, 0));
  MSCCLPPCHECK(mscclppSocketListen(&listenSock));

  TRACE(MSCCLPP_INIT, "BEGIN");
  /* Receive addresses from all ranks */
  do {
    int rank;
    res = getRemoteAddresses(&listenSock, rankAddresses, rankAddressesRoot, rank);
    if (res != mscclppSuccess) {
      WARN("Bootstrap Root : getRemoteAddresses failed");
      break;
    }
    ++numCollected;
    TRACE(MSCCLPP_INIT, "Received connect from rank %d total %d/%d", rank, numCollected, this->nRanks_);
  } while (numCollected < this->nRanks_);

  if (res == mscclppSuccess) {
    TRACE(MSCCLPP_INIT, "COLLECTED ALL %d HANDLES", this->nRanks_);

    // Send the connect handle for the next rank in the AllGather ring
    for (int peer = 0; peer < this->nRanks_; ++peer) {
      res = sendHandleToPeer(peer, rankAddresses, rankAddressesRoot);
      if (res != mscclppSuccess) {
        WARN("Bootstrap Root : sendHandleToPeer failed");
        break;
      }
    }
    if (res == mscclppSuccess) {
      TRACE(MSCCLPP_INIT, "SENT OUT ALL %d HANDLES", this->nRanks_);
    }
  }

  // The root no longer needs to accept connections; release the listen socket.
  const mscclppResult_t closeRes = mscclppSocketClose(&listenSock);
  if (res == mscclppSuccess) {
    res = closeRes;
  }

  TRACE(MSCCLPP_INIT, "DONE");
  return res;
}
/// Chooses the local network interface used for bootstrap sockets and stores
/// its name/address in netIfName_ / netIfAddr_.
///
/// With a non-empty `ipPortPair` the interface must match that address's
/// subnet; otherwise the first interface found is used.
///
/// @return mscclppSuccess, mscclppInvalidArgument for an unparsable
///         ipPortPair, mscclppSystemError / mscclppInternalError when no
///         interface is found.
mscclppResult_t MscclppBootstrap::Impl::netInit(std::string ipPortPair)
{
  if (ipPortPair.empty()) {
    if (mscclppFindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1) <= 0) {
      WARN("Bootstrap : no socket interface found");
      return mscclppInternalError;
    }
  } else {
    union mscclppSocketAddress rootAddr;
    if (mscclppSocketGetAddrFromString(&rootAddr, ipPortPair.c_str()) != mscclppSuccess) {
      WARN("Invalid MSCCLPP_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
      return mscclppInvalidArgument;
    }
    if (mscclppFindInterfaceMatchSubnet(netIfName_, &netIfAddr_, &rootAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
      WARN("NET/Socket : No usable listening interface found");
      return mscclppSystemError;
    }
  }

  // Log " <ifname>:<address>" for the interface that was picked.
  char description[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2];
  std::sprintf(description, " %s:", netIfName_);
  mscclppSocketToString(&netIfAddr_, description + strlen(description));
  INFO(MSCCLPP_INIT, "Bootstrap : Using%s", description);
  return mscclppSuccess;
}
/// Per-rank rendezvous with the root followed by ring construction.
///
/// Steps, in order:
///  1. open the long-lived listen socket and a temporary socket for the root;
///  2. send both addresses (extInfo) to the root at uniqueId_.addr;
///  3. receive the successor's address from the root;
///  4. connect to the successor and accept the predecessor (the ring);
///  5. all-gather the listen addresses and the proxy socket addresses.
///
/// @return mscclppSuccess, or the first error reported by a socket operation.
mscclppResult_t MscclppBootstrap::Impl::initialize()
{
  mscclppSocket* proxySocket;
  mscclppSocketAddress nextAddr;
  mscclppSocket sock, listenSockRoot;
  extInfo info;

  // Trace arguments must name real variables: this function has no locals
  // called rank/nranks, so use the members.
  TRACE(MSCCLPP_INIT, "rank %d nranks %d", this->rank_, this->nRanks_);

  info.rank = this->rank_;
  info.nRanks = this->nRanks_;

  uint64_t magic = this->uniqueId_.magic;

  // Create socket for other ranks to contact me
  MSCCLPPCHECK(
    mscclppSocketInit(&this->listenSock_, &netIfAddr_, magic, mscclppSocketTypeBootstrap, this->abortFlag_));
  MSCCLPPCHECK(mscclppSocketListen(&this->listenSock_));
  MSCCLPPCHECK(mscclppSocketGetAddr(&this->listenSock_, &info.extAddressListen));

  // Create socket for root to contact me
  MSCCLPPCHECK(
    mscclppSocketInit(&listenSockRoot, &netIfAddr_, magic, mscclppSocketTypeBootstrap, this->abortFlag_));
  MSCCLPPCHECK(mscclppSocketListen(&listenSockRoot));
  MSCCLPPCHECK(mscclppSocketGetAddr(&listenSockRoot, &info.extAddressListenRoot));

  // stagger connection times to avoid an overload of the root:
  // rank r waits r milliseconds.
  auto randomSleep = [](int rank) {
    struct timespec tv;
    tv.tv_sec = rank / 1000;
    tv.tv_nsec = 1000000 * (rank % 1000);
    // %ld needs a long argument; passing the int directly is a format mismatch.
    TRACE(MSCCLPP_INIT, "rank %d delaying connection to root by %ld msec", rank, static_cast<long>(rank));
    (void)nanosleep(&tv, NULL);
  };
  if (this->nRanks_ > 128) {
    randomSleep(this->rank_);
  }

  // send info on my listening socket to root
  MSCCLPPCHECK(mscclppSocketInit(&sock, &this->uniqueId_.addr, magic, mscclppSocketTypeBootstrap, this->abortFlag_));
  MSCCLPPCHECK(mscclppSocketConnect(&sock));
  MSCCLPPCHECK(netSend(&sock, &info, sizeof(info)));
  MSCCLPPCHECK(mscclppSocketClose(&sock));

  // get info on my "next" rank in the bootstrap ring from root
  MSCCLPPCHECK(mscclppSocketInit(&sock));
  MSCCLPPCHECK(mscclppSocketAccept(&sock, &listenSockRoot));
  MSCCLPPCHECK(netRecv(&sock, &nextAddr, sizeof(union mscclppSocketAddress)));
  MSCCLPPCHECK(mscclppSocketClose(&sock));
  MSCCLPPCHECK(mscclppSocketClose(&listenSockRoot));

  MSCCLPPCHECK(
    mscclppSocketInit(&this->ringSendSocket_, &nextAddr, magic, mscclppSocketTypeBootstrap, this->abortFlag_));
  MSCCLPPCHECK(mscclppSocketConnect(&this->ringSendSocket_));
  // Accept the connect request from the previous rank in the AllGather ring
  MSCCLPPCHECK(mscclppSocketInit(&this->ringRecvSocket_));
  MSCCLPPCHECK(mscclppSocketAccept(&this->ringRecvSocket_, &this->listenSock_));

  // AllGather all listen handlers
  MSCCLPPCHECK(mscclppSocketGetAddr(&this->listenSock_, &this->peerCommAddresses_[rank_]));
  MSCCLPPCHECK(allGather(this->peerCommAddresses_.data(), sizeof(union mscclppSocketAddress)));

  // proxy is aborted through a message; don't set abortFlag
  // NOTE(review): abortFlag_ is nevertheless passed below, and proxySocket is
  // neither stored nor released in this function - confirm intended ownership.
  MSCCLPPCHECK(mscclppCalloc(&proxySocket, 1));
  MSCCLPPCHECK(mscclppSocketInit(proxySocket, &netIfAddr_, magic, mscclppSocketTypeProxy, this->abortFlag_));
  MSCCLPPCHECK(mscclppSocketListen(proxySocket));
  MSCCLPPCHECK(mscclppSocketGetAddr(proxySocket, &this->peerProxyAddresses_[rank_]));
  MSCCLPPCHECK(allGather(this->peerProxyAddresses_.data(), sizeof(union mscclppSocketAddress)));

  TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", this->rank_, this->nRanks_);
  return mscclppSuccess;
}
/// Ring all-gather over the bootstrap ring sockets.
///
/// @param allData  Buffer of nRanks_ slices of `size` bytes; slice rank_ must
///                 already hold this rank's contribution. On success every
///                 slice has been filled.
/// @param size     Size in bytes of one slice.
/// @return mscclppSuccess, or the first socket error.
mscclppResult_t MscclppBootstrap::Impl::allGather(void* allData, int size)
{
  char* data = static_cast<char*>(allData);
  int rank = this->rank_;
  int nRanks = this->nRanks_;
  TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d", rank, nRanks, size);
  /* Simple ring based AllGather
   * At each step i receive data from (rank-i-1) from left
   * and send previous step's data from (rank-i) to right
   */
  for (int i = 0; i < nRanks - 1; i++) {
    size_t rSlice = (rank - i - 1 + nRanks) % nRanks;
    size_t sSlice = (rank - i + nRanks) % nRanks;
    // Send slice to the right
    MSCCLPPCHECK(netSend(&this->ringSendSocket_, data + sSlice * size, size));
    // Recv slice from the left
    MSCCLPPCHECK(netRecv(&this->ringRecvSocket_, data + rSlice * size, size));
  }
  // The local is spelled nRanks; `nranks` does not exist in this scope.
  TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nRanks, size);
  return mscclppSuccess;
}
/// Sends one length-prefixed message: first `size` as an int, then `size`
/// bytes of `data`.
mscclppResult_t MscclppBootstrap::Impl::netSend(mscclppSocket* sock, const void* data, int size)
{
  // header: payload length
  int length = size;
  MSCCLPPCHECK(mscclppSocketSend(sock, &length, sizeof(int)));

  // body: the socket API takes a non-const pointer
  void* payload = const_cast<void*>(data);
  MSCCLPPCHECK(mscclppSocketSend(sock, payload, size));
  return mscclppSuccess;
}
/// Receives one length-prefixed message into `data`.
///
/// Reads the int length prefix written by netSend(), validates it against the
/// caller's buffer and then reads exactly that many bytes.
///
/// @param size  Capacity of `data` in bytes; shorter messages are accepted.
/// @return mscclppSuccess; mscclppInternalError if the announced length is
///         negative or larger than `size`; otherwise the socket error.
mscclppResult_t MscclppBootstrap::Impl::netRecv(mscclppSocket* sock, void* data, int size)
{
  int recvSize;
  MSCCLPPCHECK(mscclppSocketRecv(sock, &recvSize, sizeof(int)));
  // The prefix comes from the peer: never hand a negative length to the
  // socket layer.
  if (recvSize < 0) {
    WARN("Invalid message size : received negative length %d", recvSize);
    return mscclppInternalError;
  }
  if (recvSize > size) {
    WARN("Message truncated : received %d bytes instead of %d", recvSize, size);
    return mscclppInternalError;
  }
  // 0 <= recvSize <= size holds here.
  MSCCLPPCHECK(mscclppSocketRecv(sock, data, recvSize));
  return mscclppSuccess;
}
/// Sends a tagged message to `peer` over a fresh connection.
///
/// Wire format (each part length-prefixed by netSend): this rank, `tag`,
/// then `size` bytes of `data`.
///
/// @param peer  Destination rank; must be in [0, nRanks_).
/// @return mscclppSuccess; mscclppInvalidArgument for an out-of-range peer;
///         otherwise the first socket error. The socket is closed on every
///         path once it has been initialised.
mscclppResult_t MscclppBootstrap::Impl::send(void* data, int size, int peer, int tag)
{
  // peerCommAddresses_ has nRanks_ entries; reject indices outside of it.
  if (peer < 0 || peer >= this->nRanks_) {
    WARN("Bootstrap : invalid peer rank %d (nRanks %d)", peer, this->nRanks_);
    return mscclppInvalidArgument;
  }

  mscclppSocket sock;
  MSCCLPPCHECK(mscclppSocketInit(&sock, &this->peerCommAddresses_[peer], this->uniqueId_.magic,
                                 mscclppSocketTypeBootstrap, this->abortFlag_));

  mscclppResult_t res = mscclppSocketConnect(&sock);
  if (res == mscclppSuccess) {
    res = netSend(&sock, &this->rank_, sizeof(int));
  }
  if (res == mscclppSuccess) {
    res = netSend(&sock, &tag, sizeof(int));
  }
  if (res == mscclppSuccess) {
    res = netSend(&sock, data, size);
  }
  if (res != mscclppSuccess) {
    WARN("Bootstrap : send to peer %d failed", peer);
  }

  // Always release the socket; report the first failure if there was one.
  const mscclppResult_t closeRes = mscclppSocketClose(&sock);
  return (res != mscclppSuccess) ? res : closeRes;
}
// Stub: receives nothing, leaves `data` untouched and always returns
// mscclppSuccess. All parameters are currently unused.
// TODO: implement the receiving side of send().
mscclppResult_t MscclppBootstrap::Impl::recv(void* data, int size, int peer, int tag)
{
return mscclppSuccess;
}
// Stub: performs no synchronisation and always returns mscclppSuccess.
// TODO: implement.
mscclppResult_t MscclppBootstrap::Impl::barrier()
{
return mscclppSuccess;
}
// Stub: releases nothing (the sockets opened by initialize() stay open) and
// always returns mscclppSuccess.
// TODO: implement.
mscclppResult_t MscclppBootstrap::Impl::close()
{
return mscclppSuccess;
}
MscclppBootstrap::MscclppBootstrap(std::string ipPortPair, int rank, int nRanks)
{
pimpl_ = std::make_unique<Impl>(ipPortPair, rank, nRanks, mscclppBootstrapHandle{0});
}
/// Creates a bootstrap from an existing handle; no "ip:port" string is given
/// to the implementation.
MscclppBootstrap::MscclppBootstrap(mscclppBootstrapHandle handle, int rank, int nRanks)
  : pimpl_(std::make_unique<Impl>("", rank, nRanks, handle))
{
}
/// Returns a copy of the unique id held by the implementation.
MscclppBootstrap::UniqueId MscclppBootstrap::GetUniqueId()
{
  MscclppBootstrap::UniqueId id = pimpl_->uniqueId_;
  return id;
}
void MscclppBootstrap::Send(void* data, int size, int peer, int tag)
{
mscclppResult_t res = pimpl_->send(data, size, peer, tag);
if (res != mscclppSuccess) {
throw std::runtime_error("MscclppBootstrap::Send failed");
}
}
void MscclppBootstrap::Recv(void* data, int size, int peer, int tag)
{
mscclppResult_t res = pimpl_->recv(data, size, peer, tag);
if (res != mscclppSuccess) {
throw std::runtime_error("MscclppBootstrap::Recv failed");
}
}
void MscclppBootstrap::AllGather(void* allData, int size)
{
mscclppResult_t res = pimpl_->allGather(allData, size);
if (res != mscclppSuccess) {
throw std::runtime_error("MscclppBootstrap::AllGather failed");
}
}
void MscclppBootstrap::Initialize()
{
mscclppResult_t res = pimpl_->initialize();
if (res != mscclppSuccess) {
throw std::runtime_error("MscclppBootstrap::Initialize failed");
}
}
void MscclppBootstrap::Barrier()
{
mscclppResult_t res = pimpl_->barrier();
if (res != mscclppSuccess) {
throw std::runtime_error("MscclppBootstrap::Barrier failed");
}
}
void MscclppBootstrap::Close()
{
mscclppResult_t res = pimpl_->close();
if (res != mscclppSuccess) {
throw std::runtime_error("MscclppBootstrap::Close failed");
}
}
// ------------------- Old bootstrap functions -------------------
// Arguments handed to the detached root thread created by bootstrapCreateRoot().
// The thread function frees both this struct and listenSock when it finishes.
struct bootstrapRootArgs
{
// Heap-allocated, already listening socket on which ranks check in.
struct mscclppSocket* listenSock;
// Magic number used when the root connects back to each rank.
uint64_t magic;
};
/* Init functions */
// Name and address of the interface selected by bootstrapNetInit().
static char bootstrapNetIfName[MAX_IF_NAME_SIZE + 1];
static union mscclppSocketAddress bootstrapNetIfAddr;
// Set to 1 once bootstrapNetInit() has completed successfully.
static int bootstrapNetInitDone = 0;
// Serialises the one-time initialisation in bootstrapNetInit().
pthread_mutex_t bootstrapNetLock = PTHREAD_MUTEX_INITIALIZER;
/// One-time selection of the network interface used by the legacy bootstrap.
///
/// The address to match comes from `ip_port_pair` or, when that is NULL, from
/// the MSCCLPP_COMM_ID environment variable; if neither is set the first
/// interface found is used. Results are stored in bootstrapNetIfName /
/// bootstrapNetIfAddr.
///
/// Every path releases bootstrapNetLock before returning: an early return
/// with the mutex held would block every later caller forever.
///
/// @return mscclppSuccess (also when already initialised),
///         mscclppInvalidArgument for an unparsable address,
///         mscclppSystemError / mscclppInternalError when no interface is found.
mscclppResult_t bootstrapNetInit(const char* ip_port_pair)
{
  // NOTE(review): this unlocked read of a plain int races with the locked
  // write below; consider an atomic flag or std::call_once.
  if (bootstrapNetInitDone != 0) {
    return mscclppSuccess;
  }

  mscclppResult_t res = mscclppSuccess;
  pthread_mutex_lock(&bootstrapNetLock);
  if (bootstrapNetInitDone == 0) {
    const char* env = ip_port_pair ? ip_port_pair : getenv("MSCCLPP_COMM_ID");
    if (env) {
      union mscclppSocketAddress remoteAddr;
      if (mscclppSocketGetAddrFromString(&remoteAddr, env) != mscclppSuccess) {
        WARN("Invalid MSCCLPP_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
        res = mscclppInvalidArgument;
      } else if (mscclppFindInterfaceMatchSubnet(bootstrapNetIfName, &bootstrapNetIfAddr, &remoteAddr,
                                                 MAX_IF_NAME_SIZE, 1) <= 0) {
        WARN("NET/Socket : No usable listening interface found");
        res = mscclppSystemError;
      }
    } else {
      int nIfs = mscclppFindInterfaces(bootstrapNetIfName, &bootstrapNetIfAddr, MAX_IF_NAME_SIZE, 1);
      if (nIfs <= 0) {
        WARN("Bootstrap : no socket interface found");
        res = mscclppInternalError;
      }
    }

    if (res == mscclppSuccess) {
      char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2];
      sprintf(line, " %s:", bootstrapNetIfName);
      mscclppSocketToString(&bootstrapNetIfAddr, line + strlen(line));
      INFO(MSCCLPP_INIT, "Bootstrap : Using%s", line);
      // Only mark as done on success so a later call can retry.
      bootstrapNetInitDone = 1;
    }
  }
  pthread_mutex_unlock(&bootstrapNetLock);
  return res;
}
// Additional sync functions
/// Sends one length-prefixed message: `size` as an int, then `size` bytes.
static mscclppResult_t bootstrapNetSend(struct mscclppSocket* sock, void* data, int size)
{
  // header: payload length
  int length = size;
  MSCCLPPCHECK(mscclppSocketSend(sock, &length, sizeof(int)));
  // body
  MSCCLPPCHECK(mscclppSocketSend(sock, data, size));
  return mscclppSuccess;
}
/// Receives one length-prefixed message into `data`.
///
/// @param size  Capacity of `data` in bytes; shorter messages are accepted.
/// @return mscclppSuccess; mscclppInternalError if the announced length is
///         negative or larger than `size`; otherwise the socket error.
static mscclppResult_t bootstrapNetRecv(struct mscclppSocket* sock, void* data, int size)
{
  int recvSize;
  MSCCLPPCHECK(mscclppSocketRecv(sock, &recvSize, sizeof(int)));
  // The prefix comes from the peer: never hand a negative length to the
  // socket layer.
  if (recvSize < 0) {
    WARN("Invalid message size : received negative length %d", recvSize);
    return mscclppInternalError;
  }
  if (recvSize > size) {
    WARN("Message truncated : received %d bytes instead of %d", recvSize, size);
    return mscclppInternalError;
  }
  // 0 <= recvSize <= size holds here.
  MSCCLPPCHECK(mscclppSocketRecv(sock, data, recvSize));
  return mscclppSuccess;
}
// struct extInfo
// {
// int rank;
// int nranks;
// union mscclppSocketAddress extAddressListenRoot;
// union mscclppSocketAddress extAddressListen;
// };
#include <sys/resource.h>
// static mscclppResult_t setFilesLimit()
// {
// struct rlimit filesLimit;
// SYSCHECK(getrlimit(RLIMIT_NOFILE, &filesLimit), "getrlimit");
// filesLimit.rlim_cur = filesLimit.rlim_max;
// SYSCHECK(setrlimit(RLIMIT_NOFILE, &filesLimit), "setrlimit");
// return mscclppSuccess;
// }
/// Thread body of the legacy bootstrap root.
///
/// Collects one extInfo check-in per rank on args->listenSock, then sends
/// every rank the listen address of its successor in the ring. The rank count
/// is taken from the first check-in. Takes ownership of `rargs` and of the
/// listen socket it points to: both are freed before returning.
///
/// @param rargs  Heap-allocated struct bootstrapRootArgs.
/// @return Always NULL; failures are only logged.
static void* bootstrapRoot(void* rargs)
{
  struct bootstrapRootArgs* args = (struct bootstrapRootArgs*)rargs;
  struct mscclppSocket* listenSock = args->listenSock;
  uint64_t magic = args->magic;
  mscclppResult_t res = mscclppSuccess;
  int nranks = 0, c = 0;
  struct extInfo info;

  union mscclppSocketAddress* rankAddresses = NULL;
  union mscclppSocketAddress* rankAddressesRoot = NULL; // for initial rank <-> root information exchange
  union mscclppSocketAddress* zero = NULL;
  MSCCLPPCHECKGOTO(mscclppCalloc(&zero, 1), res, out);
  // Best effort: the result is deliberately ignored.
  setFilesLimit();

  TRACE(MSCCLPP_INIT, "BEGIN");
  /* Receive addresses from all ranks */
  do {
    struct mscclppSocket sock;
    MSCCLPPCHECKGOTO(mscclppSocketInit(&sock), res, out);
    MSCCLPPCHECKGOTO(mscclppSocketAccept(&sock, listenSock), res, out);
    MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, &info, sizeof(info)), res, out);
    MSCCLPPCHECKGOTO(mscclppSocketClose(&sock), res, out);

    if (c == 0) {
      nranks = info.nRanks;
      // The count comes from the network; refuse to size the tables with a
      // non-positive value.
      if (nranks <= 0) {
        WARN("Bootstrap Root : invalid rank count %d", nranks);
        goto out;
      }
      MSCCLPPCHECKGOTO(mscclppCalloc(&rankAddresses, nranks), res, out);
      MSCCLPPCHECKGOTO(mscclppCalloc(&rankAddressesRoot, nranks), res, out);
    }

    if (nranks != info.nRanks) {
      WARN("Bootstrap Root : mismatch in rank count from procs %d : %d", nranks, info.nRanks);
      goto out;
    }

    // info.rank indexes both tables below; reject values outside [0, nranks).
    if (info.rank < 0 || info.rank >= nranks) {
      WARN("Bootstrap Root : received out-of-range rank %d (nranks %d)", info.rank, nranks);
      goto out;
    }

    // A non-zero slot means this rank has already been recorded.
    if (memcmp(zero, &rankAddressesRoot[info.rank], sizeof(union mscclppSocketAddress)) != 0) {
      WARN("Bootstrap Root : rank %d of %d ranks has already checked in", info.rank, nranks);
      goto out;
    }

    // Save the connection handle for that rank
    memcpy(rankAddressesRoot + info.rank, &info.extAddressListenRoot, sizeof(union mscclppSocketAddress));
    memcpy(rankAddresses + info.rank, &info.extAddressListen, sizeof(union mscclppSocketAddress));

    ++c;
    TRACE(MSCCLPP_INIT, "Received connect from rank %d total %d/%d", info.rank, c, nranks);
  } while (c < nranks);
  TRACE(MSCCLPP_INIT, "COLLECTED ALL %d HANDLES", nranks);

  // Send the connect handle for the next rank in the AllGather ring
  for (int r = 0; r < nranks; ++r) {
    int next = (r + 1) % nranks;
    struct mscclppSocket sock;
    MSCCLPPCHECKGOTO(mscclppSocketInit(&sock, rankAddressesRoot + r, magic, mscclppSocketTypeBootstrap), res, out);
    MSCCLPPCHECKGOTO(mscclppSocketConnect(&sock), res, out);
    MSCCLPPCHECKGOTO(bootstrapNetSend(&sock, rankAddresses + next, sizeof(union mscclppSocketAddress)), res, out);
    MSCCLPPCHECKGOTO(mscclppSocketClose(&sock), res, out);
  }
  TRACE(MSCCLPP_INIT, "SENT OUT ALL %d HANDLES", nranks);

out:
  // Common cleanup for success and failure.
  if (listenSock != NULL) {
    mscclppSocketClose(listenSock);
    free(listenSock);
  }
  if (rankAddresses)
    free(rankAddresses);
  if (rankAddressesRoot)
    free(rankAddressesRoot);
  if (zero)
    free(zero);
  free(rargs);

  TRACE(MSCCLPP_INIT, "DONE");
  return NULL;
}
/// Starts the legacy bootstrap root in a detached thread.
///
/// Opens a listening socket on handle->addr, writes the address actually bound
/// back into handle->addr, and hands the socket plus handle->magic to the
/// bootstrapRoot thread.
mscclppResult_t bootstrapCreateRoot(struct mscclppBootstrapHandle* handle)
{
  struct mscclppSocket* rootSock;
  struct bootstrapRootArgs* rootArgs;
  pthread_t rootThread;

  // Listening socket; the handle receives the bound address.
  MSCCLPPCHECK(mscclppCalloc(&rootSock, 1));
  MSCCLPPCHECK(mscclppSocketInit(rootSock, &handle->addr, handle->magic, mscclppSocketTypeBootstrap, NULL, 0));
  MSCCLPPCHECK(mscclppSocketListen(rootSock));
  MSCCLPPCHECK(mscclppSocketGetAddr(rootSock, &handle->addr));

  // Thread arguments.
  MSCCLPPCHECK(mscclppCalloc(&rootArgs, 1));
  rootArgs->magic = handle->magic;
  rootArgs->listenSock = rootSock;

  NEQCHECK(pthread_create(&rootThread, NULL, bootstrapRoot, (void*)rootArgs), 0);
  mscclppSetThreadName(rootThread, "MSCCLPP BootstrapR");
  NEQCHECK(pthread_detach(rootThread), 0); // will not be pthread_join()'d
  return mscclppSuccess;
}
// #include <netinet/in.h>
// #include <arpa/inet.h>
/// Fills `handle` with the magic number and root address of a bootstrap session.
///
/// When an address is available (`ip_port_pair`, else the MSCCLPP_COMM_ID
/// environment variable) the magic is 0xdeadbeef, the address is parsed from
/// that string and a root is created only if `isRoot`. Otherwise the magic is
/// random, the address is bootstrapNetIfAddr and a root is always created.
mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot, const char* ip_port_pair)
{
  memset(handle, 0, sizeof(mscclppBootstrapHandle));

  const char* commId = ip_port_pair ? ip_port_pair : getenv("MSCCLPP_COMM_ID");

  if (commId == NULL) {
    // No explicit address: random magic, local interface address, local root.
    MSCCLPPCHECK(getRandomData(&handle->magic, sizeof(handle->magic)));
    memcpy(&handle->addr, &bootstrapNetIfAddr, sizeof(union mscclppSocketAddress));
    MSCCLPPCHECK(bootstrapCreateRoot(handle));
    return mscclppSuccess;
  }

  handle->magic = 0xdeadbeef;
  INFO(MSCCLPP_ENV, "MSCCLPP_COMM_ID set by environment to %s", commId);
  if (mscclppSocketGetAddrFromString(&handle->addr, commId) != mscclppSuccess) {
    WARN("Invalid MSCCLPP_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
    return mscclppInvalidArgument;
  }
  if (isRoot) {
    MSCCLPPCHECK(bootstrapCreateRoot(handle));
  }
  return mscclppSuccess;
}
// Node of the singly linked list of connections that bootstrapRecv() accepted
// but that did not match the (peer, tag) it was waiting for.
struct unexConn
{
// Rank and tag announced by the sender on this connection.
int peer;
int tag;
// Byte-wise copy of the accepted socket; the payload has not been read yet.
struct mscclppSocket sock;
// Next node, or NULL at the tail.
struct unexConn* next;
};
// Per-communicator state of the legacy bootstrap, allocated by bootstrapInit().
struct bootstrapState
{
// Socket other ranks connect to (ring predecessor and bootstrapSend peers).
struct mscclppSocket listenSock;
// Ring connections to the previous / next rank, used by bootstrapAllGather().
struct mscclppSocket ringRecvSocket;
struct mscclppSocket ringSendSocket;
// Arrays of nranks addresses, filled by all-gather in bootstrapInit().
union mscclppSocketAddress* peerCommAddresses;
union mscclppSocketAddress* peerProxyAddresses;
// Head of the list of accepted-but-unmatched connections (see unexConn).
struct unexConn* unexpectedConnections;
// NOTE(review): not written or read anywhere in this file.
int cudaDev;
int rank;
int nranks;
// Magic number copied from the bootstrap handle.
uint64_t magic;
// Copied from comm->abortFlag; read in bootstrapClose().
volatile uint32_t* abortFlag;
};
// Legacy per-rank bootstrap: allocates the bootstrapState for `comm`, checks in
// with the root at handle->addr, builds the ring and all-gathers the listen and
// proxy addresses of every rank.
//
// handle: bootstrap handle (root address and magic number).
// comm:   communicator; rank, nRanks and abortFlag are read, bootstrap and magic
//         are written.
// Returns mscclppSuccess, or the first error reported by MSCCLPPCHECK.
// NOTE(review): uses bootstrapNetIfAddr, so bootstrapNetInit() presumably has to
// have run before - confirm with callers.
mscclppResult_t bootstrapInit(struct mscclppBootstrapHandle* handle, struct mscclppComm* comm)
{
int rank = comm->rank;
int nranks = comm->nRanks;
struct bootstrapState* state;
struct mscclppSocket* proxySocket;
mscclppSocketAddress nextAddr;
struct mscclppSocket sock, listenSockRoot;
struct extInfo info;
MSCCLPPCHECK(mscclppCalloc(&state, 1));
state->rank = rank;
state->nranks = nranks;
state->abortFlag = comm->abortFlag;
// The state is attached to comm before any socket work, so it stays reachable
// through comm->bootstrap even if a later step fails.
comm->bootstrap = state;
comm->magic = state->magic = handle->magic;
TRACE(MSCCLPP_INIT, "rank %d nranks %d", rank, nranks);
info.rank = rank;
info.nRanks = nranks;
// Create socket for other ranks to contact me
MSCCLPPCHECK(mscclppSocketInit(&state->listenSock, &bootstrapNetIfAddr, comm->magic, mscclppSocketTypeBootstrap,
comm->abortFlag));
MSCCLPPCHECK(mscclppSocketListen(&state->listenSock));
MSCCLPPCHECK(mscclppSocketGetAddr(&state->listenSock, &info.extAddressListen));
// Create socket for root to contact me
MSCCLPPCHECK(
mscclppSocketInit(&listenSockRoot, &bootstrapNetIfAddr, comm->magic, mscclppSocketTypeBootstrap, comm->abortFlag));
MSCCLPPCHECK(mscclppSocketListen(&listenSockRoot));
MSCCLPPCHECK(mscclppSocketGetAddr(&listenSockRoot, &info.extAddressListenRoot));
// stagger connection times to avoid an overload of the root
// (rank r sleeps r milliseconds, only when there are more than 128 ranks)
if (nranks > 128) {
long msec = rank;
struct timespec tv;
tv.tv_sec = msec / 1000;
tv.tv_nsec = 1000000 * (msec % 1000);
TRACE(MSCCLPP_INIT, "rank %d delaying connection to root by %ld msec", rank, msec);
(void)nanosleep(&tv, NULL);
}
// send info on my listening socket to root
MSCCLPPCHECK(mscclppSocketInit(&sock, &handle->addr, comm->magic, mscclppSocketTypeBootstrap, comm->abortFlag));
MSCCLPPCHECK(mscclppSocketConnect(&sock));
MSCCLPPCHECK(bootstrapNetSend(&sock, &info, sizeof(info)));
MSCCLPPCHECK(mscclppSocketClose(&sock));
// get info on my "next" rank in the bootstrap ring from root
MSCCLPPCHECK(mscclppSocketInit(&sock));
MSCCLPPCHECK(mscclppSocketAccept(&sock, &listenSockRoot));
MSCCLPPCHECK(bootstrapNetRecv(&sock, &nextAddr, sizeof(union mscclppSocketAddress)));
MSCCLPPCHECK(mscclppSocketClose(&sock));
MSCCLPPCHECK(mscclppSocketClose(&listenSockRoot));
// Connect to the next rank in the ring
MSCCLPPCHECK(
mscclppSocketInit(&state->ringSendSocket, &nextAddr, comm->magic, mscclppSocketTypeBootstrap, comm->abortFlag));
MSCCLPPCHECK(mscclppSocketConnect(&state->ringSendSocket));
// Accept the connect request from the previous rank in the AllGather ring
MSCCLPPCHECK(mscclppSocketInit(&state->ringRecvSocket));
MSCCLPPCHECK(mscclppSocketAccept(&state->ringRecvSocket, &state->listenSock));
// AllGather all listen handlers
MSCCLPPCHECK(mscclppCalloc(&state->peerCommAddresses, nranks));
MSCCLPPCHECK(mscclppSocketGetAddr(&state->listenSock, state->peerCommAddresses + rank));
MSCCLPPCHECK(bootstrapAllGather(state, state->peerCommAddresses, sizeof(union mscclppSocketAddress)));
// Create the service proxy
MSCCLPPCHECK(mscclppCalloc(&state->peerProxyAddresses, nranks));
// proxy is aborted through a message; don't set abortFlag
// NOTE(review): comm->abortFlag is nevertheless passed below, and proxySocket is
// not stored or freed while the mscclppProxyInit call stays commented out.
MSCCLPPCHECK(mscclppCalloc(&proxySocket, 1));
MSCCLPPCHECK(
mscclppSocketInit(proxySocket, &bootstrapNetIfAddr, comm->magic, mscclppSocketTypeProxy, comm->abortFlag));
MSCCLPPCHECK(mscclppSocketListen(proxySocket));
MSCCLPPCHECK(mscclppSocketGetAddr(proxySocket, state->peerProxyAddresses + rank));
MSCCLPPCHECK(bootstrapAllGather(state, state->peerProxyAddresses, sizeof(union mscclppSocketAddress)));
// MSCCLPPCHECK(mscclppProxyInit(comm, proxySocket, state->peerProxyAddresses));
TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank, nranks);
return mscclppSuccess;
}
/// Ring all-gather over the legacy bootstrap ring sockets.
///
/// `allData` holds nranks slices of `size` bytes; the slice of the calling
/// rank must be filled on entry, all slices are filled on success.
mscclppResult_t bootstrapAllGather(void* commState, void* allData, int size)
{
  struct bootstrapState* state = (struct bootstrapState*)commState;
  char* data = (char*)allData;
  int rank = state->rank;
  int nranks = state->nranks;

  TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d", rank, nranks, size);

  /* Ring schedule: at step `step` the slice that originated at (rank - step)
   * is pushed to the right neighbour while the slice that originated at
   * (rank - step - 1) arrives from the left neighbour.
   */
  int step = 0;
  while (step < nranks - 1) {
    size_t sendSlice = (rank - step + nranks) % nranks;
    size_t recvSlice = (rank - step - 1 + nranks) % nranks;

    // right neighbour first, then left neighbour
    MSCCLPPCHECK(bootstrapNetSend(&state->ringSendSocket, data + sendSlice * size, size));
    MSCCLPPCHECK(bootstrapNetRecv(&state->ringRecvSocket, data + recvSlice * size, size));
    step++;
  }

  TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size);
  return mscclppSuccess;
}
/// Sends a tagged message to `peer` over a fresh connection: this rank, the
/// tag, then `size` bytes of `data` (each part length-prefixed). The socket is
/// closed on every path.
mscclppResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size)
{
  mscclppResult_t ret = mscclppSuccess;
  struct bootstrapState* state = (struct bootstrapState*)commState;
  struct mscclppSocket sock;

  MSCCLPPCHECKGOTO(mscclppSocketInit(&sock, state->peerCommAddresses + peer, state->magic, mscclppSocketTypeBootstrap,
                                     state->abortFlag),
                   ret, done);
  MSCCLPPCHECKGOTO(mscclppSocketConnect(&sock), ret, done);
  // header: sender rank and tag
  MSCCLPPCHECKGOTO(bootstrapNetSend(&sock, &state->rank, sizeof(int)), ret, done);
  MSCCLPPCHECKGOTO(bootstrapNetSend(&sock, &tag, sizeof(int)), ret, done);
  // payload
  MSCCLPPCHECKGOTO(bootstrapNetSend(&sock, data, size), ret, done);

done:
  // Reached on success and on failure alike.
  MSCCLPPCHECK(mscclppSocketClose(&sock));
  return ret;
}
/// Barrier among the `nranks` ranks listed in `ranks`, built on
/// bootstrapSend/bootstrapRecv.
///
/// @param ranks   Maps a local index in [0, nranks) to the global rank passed to
///                bootstrapSend/bootstrapRecv.
/// @param rank    Local index of the caller within `ranks`.
/// @param nranks  Number of participants; 1 returns immediately.
/// @param tag     Tag used for every message of this barrier.
mscclppResult_t bootstrapBarrier(void* commState, int* ranks, int rank, int nranks, int tag)
{
  if (nranks == 1)
    return mscclppSuccess;
  TRACE(MSCCLPP_INIT, "rank %d nranks %d tag %x - ENTER", rank, nranks, tag);

  /* Simple intra process barrier
   *
   * Based on the dissemination algorithm by Debra Hensgen, Raphael Finkel, and Udi Manber,
   * "Two Algorithms for Barrier Synchronization," International Journal of Parallel Programming, 17(1):1-17, 1988"
   */
  // Only the arrival of the message matters, but the payload still goes on the
  // wire: initialise it instead of sending indeterminate stack contents.
  int data[1] = {0};
  for (int mask = 1; mask < nranks; mask <<= 1) {
    int src = (rank - mask + nranks) % nranks;
    int dst = (rank + mask) % nranks;
    MSCCLPPCHECK(bootstrapSend(commState, ranks[dst], tag, data, sizeof(data)));
    MSCCLPPCHECK(bootstrapRecv(commState, ranks[src], tag, data, sizeof(data)));
  }

  TRACE(MSCCLPP_INIT, "rank %d nranks %d tag %x - DONE", rank, nranks, tag);
  return mscclppSuccess;
}
/// All-gather among the `nranks` ranks listed in `ranks` using point-to-point
/// bootstrap messages. At distance d the caller sends its own slice to the
/// rank d positions ahead and receives the slice of the rank d positions
/// behind; d is also used as the message tag.
mscclppResult_t bootstrapIntraNodeAllGather(void* commState, int* ranks, int rank, int nranks, void* allData, int size)
{
  if (nranks == 1) {
    return mscclppSuccess;
  }

  char* data = (char*)allData;
  TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - ENTER", rank, nranks, size);

  int distance = 1;
  while (distance < nranks) {
    int dst = (rank + distance) % nranks;
    int src = (rank - distance + nranks) % nranks;
    MSCCLPPCHECK(bootstrapSend(commState, ranks[dst], /*tag=*/distance, data + rank * size, size));
    MSCCLPPCHECK(bootstrapRecv(commState, ranks[src], /*tag=*/distance, data + src * size, size));
    distance++;
  }

  TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size);
  return mscclppSuccess;
}
/// Appends a (peer, tag, socket) record to the tail of the state's list of
/// unexpected connections. The socket is copied byte-wise into the new node.
mscclppResult_t unexpectedEnqueue(struct bootstrapState* state, int peer, int tag, struct mscclppSocket* sock)
{
  // Build the new node.
  struct unexConn* node;
  MSCCLPPCHECK(mscclppCalloc(&node, 1));
  node->peer = peer;
  node->tag = tag;
  memcpy(&node->sock, sock, sizeof(struct mscclppSocket));

  // Walk to the link that terminates the list and hook the node in there.
  struct unexConn** tail = &state->unexpectedConnections;
  while (*tail != NULL) {
    tail = &(*tail)->next;
  }
  *tail = node;
  return mscclppSuccess;
}
/// Removes the first queued connection matching (peer, tag), if any.
///
/// On a match the stored socket is copied into `sock`, the node is freed and
/// `*found` is set to 1; otherwise `*found` is 0 and `sock` is left untouched.
/// Always returns mscclppSuccess.
mscclppResult_t unexpectedDequeue(struct bootstrapState* state, int peer, int tag, struct mscclppSocket* sock,
                                  int* found)
{
  *found = 0;
  // `link` always points at the pointer that references the current node, so
  // unlinking works the same for the head and for interior nodes.
  for (struct unexConn** link = &state->unexpectedConnections; *link != NULL; link = &(*link)->next) {
    struct unexConn* node = *link;
    if (node->peer != peer || node->tag != tag) {
      continue;
    }
    *link = node->next;
    memcpy(sock, &node->sock, sizeof(struct mscclppSocket));
    free(node);
    *found = 1;
    return mscclppSuccess;
  }
  return mscclppSuccess;
}
/// Frees every node of the state's unexpected-connection list and resets the
/// list head, so the state never keeps a pointer to freed memory.
/// NOTE(review): the sockets stored in the nodes are not closed here.
static void unexpectedFree(struct bootstrapState* state)
{
  struct unexConn* elem = state->unexpectedConnections;
  while (elem) {
    struct unexConn* next = elem->next;
    free(elem);
    elem = next;
  }
  // The head would otherwise dangle and be unsafe for any later list access.
  state->unexpectedConnections = NULL;
}
// We can't know who we'll receive from, so we need to receive everything at once
/// Receives the message sent by `peer` with `tag` into `data`.
///
/// Connections can arrive from any rank in any order, so a queued connection
/// matching (peer, tag) is tried first; otherwise new connections are accepted
/// until the matching one shows up, and every non-matching one is queued for a
/// later call.
mscclppResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size)
{
  mscclppResult_t ret = mscclppSuccess;
  struct bootstrapState* state = (struct bootstrapState*)commState;
  struct mscclppSocket sock;
  int senderRank, senderTag;
  int wasQueued;

  // 1. Previously accepted connections.
  MSCCLPPCHECK(unexpectedDequeue(state, peer, tag, &sock, &wasQueued));
  if (wasQueued) {
    MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, ((char*)data), size), ret, done);
    goto done;
  }

  // 2. Fresh connections on the listen socket.
  for (;;) {
    MSCCLPPCHECKGOTO(mscclppSocketInit(&sock), ret, done);
    MSCCLPPCHECKGOTO(mscclppSocketAccept(&sock, &state->listenSock), ret, done);
    MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, &senderRank, sizeof(int)), ret, done);
    MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, &senderTag, sizeof(int)), ret, done);
    if (senderRank == peer && senderTag == tag) {
      MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, ((char*)data), size), ret, done);
      break;
    }
    // Not the one we are waiting for: keep it for a later call.
    MSCCLPPCHECKGOTO(unexpectedEnqueue(state, senderRank, senderTag, &sock), ret, done);
  }

done:
  // Reached on success and on failure alike.
  MSCCLPPCHECK(mscclppSocketClose(&sock));
  return ret;
}
/// Tears down the legacy bootstrap state created by bootstrapInit().
///
/// Leftover unexpected connections are freed; unless an abort is in progress
/// (non-zero *abortFlag) they are reported as an error and the rest of the
/// state is left in place. Otherwise the sockets are closed and all memory
/// owned by the state is released, including peerProxyAddresses, matching
/// bootstrapAbort().
///
/// @param commState  State pointer, or NULL (treated as nothing to do).
/// @return mscclppSuccess, mscclppInternalError for leftover connections
///         without abort, or a socket close error.
mscclppResult_t bootstrapClose(void* commState)
{
  if (commState == NULL)
    return mscclppSuccess;
  struct bootstrapState* state = (struct bootstrapState*)commState;

  if (state->unexpectedConnections != NULL) {
    unexpectedFree(state);
    // Make sure the state never keeps a pointer to the freed list, whichever
    // path is taken below.
    state->unexpectedConnections = NULL;
    // A missing abort flag is treated as "no abort requested".
    if (state->abortFlag == NULL || *state->abortFlag == 0) {
      WARN("Unexpected connections are not empty");
      return mscclppInternalError;
    }
  }

  MSCCLPPCHECK(mscclppSocketClose(&state->listenSock));
  MSCCLPPCHECK(mscclppSocketClose(&state->ringSendSocket));
  MSCCLPPCHECK(mscclppSocketClose(&state->ringRecvSocket));

  free(state->peerCommAddresses);
  free(state->peerProxyAddresses);
  free(state);

  return mscclppSuccess;
}
/// Abort-path teardown of the legacy bootstrap state: closes the three sockets
/// and frees both address tables and the state itself. A NULL state is a no-op.
mscclppResult_t bootstrapAbort(void* commState)
{
  if (commState == NULL) {
    return mscclppSuccess;
  }
  struct bootstrapState* st = (struct bootstrapState*)commState;

  // sockets
  MSCCLPPCHECK(mscclppSocketClose(&st->listenSock));
  MSCCLPPCHECK(mscclppSocketClose(&st->ringSendSocket));
  MSCCLPPCHECK(mscclppSocketClose(&st->ringRecvSocket));

  // memory
  free(st->peerCommAddresses);
  free(st->peerProxyAddresses);
  free(st);
  return mscclppSuccess;
}