diff --git a/src/bootstrap/bootstrap.cc b/src/bootstrap/bootstrap.cc index 84d6d25b..0dac4e2f 100644 --- a/src/bootstrap/bootstrap.cc +++ b/src/bootstrap/bootstrap.cc @@ -64,9 +64,6 @@ struct extInfo class MscclppBootstrap::Impl { public: - static char bootstrapNetIfName[MAX_IF_NAME_SIZE + 1]; - static union mscclppSocketAddress bootstrapNetIfAddr; - Impl(std::string ipPortPair, int rank, int nRanks, const mscclppBootstrapHandle handle); ~Impl(); mscclppResult_t initialize(); @@ -74,6 +71,7 @@ public: mscclppResult_t send(void* data, int size, int peer, int tag); mscclppResult_t recv(void* data, int size, int peer, int tag); mscclppResult_t barrier(); + mscclppResult_t close(); MscclppBootstrap::UniqueId uniqueId_; @@ -88,6 +86,8 @@ private: std::queue unexpectedConnections_; volatile uint32_t* abortFlag_; std::thread rootThread_; + char netIfName_[MAX_IF_NAME_SIZE + 1]; + union mscclppSocketAddress netIfAddr_; static mscclppResult_t netSend(mscclppSocket* sock, const void* data, int size); static mscclppResult_t netRecv(mscclppSocket* sock, void* data, int size); @@ -124,9 +124,9 @@ MscclppBootstrap::Impl::Impl(std::string ipPortPair, int rank, int nRanks, const throw std::runtime_error("getting random data failed"); } } - std::memcpy(&uniqueId_.addr, &bootstrapNetIfAddr, sizeof(union mscclppSocketAddress)); + std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(union mscclppSocketAddress)); if (rank_ == 0) { - rootThread_ = std::thread(&MscclppBootstrap::Impl::bootstrapRoot, this, &listenSock_, uniqueId_.magic, nRanks_); + rootThread_ = std::thread(&MscclppBootstrap::Impl::bootstrapRoot, this); } } @@ -270,13 +270,13 @@ mscclppResult_t MscclppBootstrap::Impl::netInit(std::string ipPortPair) WARN("Invalid MSCCLPP_COMM_ID, please use format: : or []: or :"); return mscclppInvalidArgument; } - if (mscclppFindInterfaceMatchSubnet(this->bootstrapNetIfName, &this->bootstrapNetIfAddr, &remoteAddr, - MAX_IF_NAME_SIZE, 1) <= 0) { + if (mscclppFindInterfaceMatchSubnet(netIfName_, &netIfAddr_, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) { WARN("NET/Socket : No usable listening interface found"); return mscclppSystemError; } } else { - int ret = mscclppFindInterfaces(this->bootstrapNetIfName, &this->bootstrapNetIfAddr, MAX_IF_NAME_SIZE, 1); + int ret = + mscclppFindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1); if (ret <= 0) { WARN("Bootstrap : no socket interface found"); return mscclppInternalError; @@ -284,8 +284,8 @@ mscclppResult_t MscclppBootstrap::Impl::netInit(std::string ipPortPair) } char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2]; - std::sprintf(line, " %s:", bootstrapNetIfName); - mscclppSocketToString(&bootstrapNetIfAddr, line + strlen(line)); + std::sprintf(line, " %s:", netIfName_); + mscclppSocketToString(&netIfAddr_, line + strlen(line)); INFO(MSCCLPP_INIT, "Bootstrap : Using%s", line); return mscclppSuccess; } @@ -305,13 +305,13 @@ mscclppResult_t MscclppBootstrap::Impl::initialize() uint64_t magic = this->uniqueId_.magic; // Create socket for other ranks to contact me MSCCLPPCHECK( - mscclppSocketInit(&this->listenSock_, &bootstrapNetIfAddr, magic, mscclppSocketTypeBootstrap, this->abortFlag_)); + mscclppSocketInit(&this->listenSock_, &netIfAddr_, magic, mscclppSocketTypeBootstrap, this->abortFlag_)); MSCCLPPCHECK(mscclppSocketListen(&this->listenSock_)); MSCCLPPCHECK(mscclppSocketGetAddr(&this->listenSock_, &info.extAddressListen)); // Create socket for root to contact me MSCCLPPCHECK( - mscclppSocketInit(&listenSockRoot, &bootstrapNetIfAddr, magic, mscclppSocketTypeBootstrap, this->abortFlag_)); + mscclppSocketInit(&listenSockRoot, &netIfAddr_, magic, mscclppSocketTypeBootstrap, this->abortFlag_)); MSCCLPPCHECK(mscclppSocketListen(&listenSockRoot)); MSCCLPPCHECK(mscclppSocketGetAddr(&listenSockRoot, &info.extAddressListenRoot)); @@ -353,7 +353,7 @@ mscclppResult_t MscclppBootstrap::Impl::initialize() // proxy is aborted through a message; don't set abortFlag MSCCLPPCHECK(mscclppCalloc(&proxySocket, 1)); - MSCCLPPCHECK(mscclppSocketInit(proxySocket, &bootstrapNetIfAddr, magic, mscclppSocketTypeProxy, this->abortFlag_)); + MSCCLPPCHECK(mscclppSocketInit(proxySocket, &netIfAddr_, magic, mscclppSocketTypeProxy, this->abortFlag_)); MSCCLPPCHECK(mscclppSocketListen(proxySocket)); MSCCLPPCHECK(mscclppSocketGetAddr(proxySocket, &this->peerProxyAddresses_[rank_])); MSCCLPPCHECK(allGather(this->peerProxyAddresses_.data(), sizeof(union mscclppSocketAddress))); @@ -419,6 +419,22 @@ mscclppResult_t MscclppBootstrap::Impl::send(void* data, int size, int peer, int MSCCLPPCHECK(netSend(&sock, data, size)); MSCCLPPCHECK(mscclppSocketClose(&sock)); + return mscclppSuccess; +} + +mscclppResult_t MscclppBootstrap::Impl::recv(void* data, int size, int peer, int tag) +{ + return mscclppSuccess; +} + +mscclppResult_t MscclppBootstrap::Impl::barrier() +{ + return mscclppSuccess; +} + +mscclppResult_t MscclppBootstrap::Impl::close() +{ + return mscclppSuccess; } MscclppBootstrap::MscclppBootstrap(std::string ipPortPair, int rank, int nRanks) @@ -475,3 +491,561 @@ void MscclppBootstrap::Barrier() throw std::runtime_error("MscclppBootstrap::Barrier failed"); } } + +void MscclppBootstrap::Close() +{ + mscclppResult_t res = pimpl_->close(); + if (res != mscclppSuccess) { + throw std::runtime_error("MscclppBootstrap::Close failed"); + } +} + + + +// ------------------- Old bootstrap functions ------------------- +struct bootstrapRootArgs +{ + struct mscclppSocket* listenSock; + uint64_t magic; +}; + +/* Init functions */ +static char bootstrapNetIfName[MAX_IF_NAME_SIZE + 1]; +static union mscclppSocketAddress bootstrapNetIfAddr; +static int bootstrapNetInitDone = 0; +pthread_mutex_t bootstrapNetLock = PTHREAD_MUTEX_INITIALIZER; + +mscclppResult_t bootstrapNetInit(const char* ip_port_pair) +{ + if (bootstrapNetInitDone == 0) { + pthread_mutex_lock(&bootstrapNetLock); + if (bootstrapNetInitDone == 0) { + const char* env; + if (ip_port_pair) { + env = ip_port_pair; + } else { + env = getenv("MSCCLPP_COMM_ID"); + } + if (env) { + union mscclppSocketAddress remoteAddr; + if (mscclppSocketGetAddrFromString(&remoteAddr, env) != mscclppSuccess) { + WARN("Invalid MSCCLPP_COMM_ID, please use format: : or []: or :"); + return mscclppInvalidArgument; + } + if (mscclppFindInterfaceMatchSubnet(bootstrapNetIfName, &bootstrapNetIfAddr, &remoteAddr, MAX_IF_NAME_SIZE, + 1) <= 0) { + WARN("NET/Socket : No usable listening interface found"); + return mscclppSystemError; + } + } else { + int nIfs = mscclppFindInterfaces(bootstrapNetIfName, &bootstrapNetIfAddr, MAX_IF_NAME_SIZE, 1); + if (nIfs <= 0) { + WARN("Bootstrap : no socket interface found"); + return mscclppInternalError; + } + } + char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2]; + sprintf(line, " %s:", bootstrapNetIfName); + mscclppSocketToString(&bootstrapNetIfAddr, line + strlen(line)); + INFO(MSCCLPP_INIT, "Bootstrap : Using%s", line); + bootstrapNetInitDone = 1; + } + pthread_mutex_unlock(&bootstrapNetLock); + } + return mscclppSuccess; +} + +// Additional sync functions +static mscclppResult_t bootstrapNetSend(struct mscclppSocket* sock, void* data, int size) +{ + MSCCLPPCHECK(mscclppSocketSend(sock, &size, sizeof(int))); + MSCCLPPCHECK(mscclppSocketSend(sock, data, size)); + return mscclppSuccess; +} +static mscclppResult_t bootstrapNetRecv(struct mscclppSocket* sock, void* data, int size) +{ + int recvSize; + MSCCLPPCHECK(mscclppSocketRecv(sock, &recvSize, sizeof(int))); + if (recvSize > size) { + WARN("Message truncated : received %d bytes instead of %d", recvSize, size); + return mscclppInternalError; + } + MSCCLPPCHECK(mscclppSocketRecv(sock, data, std::min(recvSize, size))); + return mscclppSuccess; +} + +// struct extInfo +// { +// int rank; +// int nranks; +// union mscclppSocketAddress extAddressListenRoot; +// union mscclppSocketAddress extAddressListen; +// }; + +#include + +// static mscclppResult_t setFilesLimit() +// { +// struct rlimit filesLimit; +// SYSCHECK(getrlimit(RLIMIT_NOFILE, &filesLimit), "getrlimit"); +// filesLimit.rlim_cur = filesLimit.rlim_max; +// SYSCHECK(setrlimit(RLIMIT_NOFILE, &filesLimit), "setrlimit"); +// return mscclppSuccess; +// } + +static void* bootstrapRoot(void* rargs) +{ + struct bootstrapRootArgs* args = (struct bootstrapRootArgs*)rargs; + struct mscclppSocket* listenSock = args->listenSock; + uint64_t magic = args->magic; + mscclppResult_t res = mscclppSuccess; + int nranks = 0, c = 0; + struct extInfo info; + union mscclppSocketAddress* rankAddresses = NULL; + union mscclppSocketAddress* rankAddressesRoot = NULL; // for initial rank <-> root information exchange + union mscclppSocketAddress* zero = NULL; + MSCCLPPCHECKGOTO(mscclppCalloc(&zero, 1), res, out); + setFilesLimit(); + + TRACE(MSCCLPP_INIT, "BEGIN"); + /* Receive addresses from all ranks */ + do { + struct mscclppSocket sock; + MSCCLPPCHECKGOTO(mscclppSocketInit(&sock), res, out); + MSCCLPPCHECKGOTO(mscclppSocketAccept(&sock, listenSock), res, out); + MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, &info, sizeof(info)), res, out); + MSCCLPPCHECKGOTO(mscclppSocketClose(&sock), res, out); + + if (c == 0) { + nranks = info.nRanks; + MSCCLPPCHECKGOTO(mscclppCalloc(&rankAddresses, nranks), res, out); + MSCCLPPCHECKGOTO(mscclppCalloc(&rankAddressesRoot, nranks), res, out); + } + + if (nranks != info.nRanks) { + WARN("Bootstrap Root : mismatch in rank count from procs %d : %d", nranks, info.nRanks); + goto out; + } + + if (memcmp(zero, &rankAddressesRoot[info.rank], sizeof(union mscclppSocketAddress)) != 0) { + WARN("Bootstrap Root : rank %d of %d ranks has already checked in", info.rank, nranks); + goto out; + } + + // Save the connection handle for that rank + memcpy(rankAddressesRoot + info.rank, &info.extAddressListenRoot, sizeof(union mscclppSocketAddress)); + memcpy(rankAddresses + info.rank, &info.extAddressListen, sizeof(union mscclppSocketAddress)); + + ++c; + TRACE(MSCCLPP_INIT, "Received connect from rank %d total %d/%d", info.rank, c, nranks); + } while (c < nranks); + TRACE(MSCCLPP_INIT, "COLLECTED ALL %d HANDLES", nranks); + + // Send the connect handle for the next rank in the AllGather ring + for (int r = 0; r < nranks; ++r) { + int next = (r + 1) % nranks; + struct mscclppSocket sock; + MSCCLPPCHECKGOTO(mscclppSocketInit(&sock, rankAddressesRoot + r, magic, mscclppSocketTypeBootstrap), res, out); + MSCCLPPCHECKGOTO(mscclppSocketConnect(&sock), res, out); + MSCCLPPCHECKGOTO(bootstrapNetSend(&sock, rankAddresses + next, sizeof(union mscclppSocketAddress)), res, out); + MSCCLPPCHECKGOTO(mscclppSocketClose(&sock), res, out); + } + TRACE(MSCCLPP_INIT, "SENT OUT ALL %d HANDLES", nranks); + +out: + if (listenSock != NULL) { + mscclppSocketClose(listenSock); + free(listenSock); + } + if (rankAddresses) + free(rankAddresses); + if (rankAddressesRoot) + free(rankAddressesRoot); + if (zero) + free(zero); + free(rargs); + + TRACE(MSCCLPP_INIT, "DONE"); + return NULL; +} + +mscclppResult_t bootstrapCreateRoot(struct mscclppBootstrapHandle* handle) +{ + struct mscclppSocket* listenSock; + struct bootstrapRootArgs* args; + pthread_t thread; + + MSCCLPPCHECK(mscclppCalloc(&listenSock, 1)); + MSCCLPPCHECK(mscclppSocketInit(listenSock, &handle->addr, handle->magic, mscclppSocketTypeBootstrap, NULL, 0)); + MSCCLPPCHECK(mscclppSocketListen(listenSock)); + MSCCLPPCHECK(mscclppSocketGetAddr(listenSock, &handle->addr)); + + MSCCLPPCHECK(mscclppCalloc(&args, 1)); + args->listenSock = listenSock; + args->magic = handle->magic; + NEQCHECK(pthread_create(&thread, NULL, bootstrapRoot, (void*)args), 0); + mscclppSetThreadName(thread, "MSCCLPP BootstrapR"); + NEQCHECK(pthread_detach(thread), 0); // will not be pthread_join()'d + return mscclppSuccess; +} + +// #include +// #include + +mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot, const char* ip_port_pair) +{ + memset(handle, 0, sizeof(mscclppBootstrapHandle)); + const char* env = NULL; + + if (ip_port_pair) { + env = ip_port_pair; + } else { + env = getenv("MSCCLPP_COMM_ID"); + } + if (env) { + handle->magic = 0xdeadbeef; + + INFO(MSCCLPP_ENV, "MSCCLPP_COMM_ID set by environment to %s", env); + if (mscclppSocketGetAddrFromString(&handle->addr, env) != mscclppSuccess) { + WARN("Invalid MSCCLPP_COMM_ID, please use format: : or []: or :"); + return mscclppInvalidArgument; + } + if (isRoot) + MSCCLPPCHECK(bootstrapCreateRoot(handle)); + } else { + MSCCLPPCHECK(getRandomData(&handle->magic, sizeof(handle->magic))); + memcpy(&handle->addr, &bootstrapNetIfAddr, sizeof(union mscclppSocketAddress)); + MSCCLPPCHECK(bootstrapCreateRoot(handle)); + } + // printf("addr = %s port = %d\n", inet_ntoa(handle->addr.sin.sin_addr), (int)ntohs(handle->addr.sin.sin_port)); + // printf("addr = %s\n", inet_ntoa((*(struct sockaddr_in*)&handle->addr.sa).sin_addr)); + + return mscclppSuccess; +} + +struct unexConn +{ + int peer; + int tag; + struct mscclppSocket sock; + struct unexConn* next; +}; + +struct bootstrapState +{ + struct mscclppSocket listenSock; + struct mscclppSocket ringRecvSocket; + struct mscclppSocket ringSendSocket; + union mscclppSocketAddress* peerCommAddresses; + union mscclppSocketAddress* peerProxyAddresses; + struct unexConn* unexpectedConnections; + int cudaDev; + int rank; + int nranks; + uint64_t magic; + volatile uint32_t* abortFlag; +}; + +mscclppResult_t bootstrapInit(struct mscclppBootstrapHandle* handle, struct mscclppComm* comm) +{ + int rank = comm->rank; + int nranks = comm->nRanks; + struct bootstrapState* state; + struct mscclppSocket* proxySocket; + mscclppSocketAddress nextAddr; + struct mscclppSocket sock, listenSockRoot; + struct extInfo info; + + MSCCLPPCHECK(mscclppCalloc(&state, 1)); + state->rank = rank; + state->nranks = nranks; + state->abortFlag = comm->abortFlag; + comm->bootstrap = state; + comm->magic = state->magic = handle->magic; + + TRACE(MSCCLPP_INIT, "rank %d nranks %d", rank, nranks); + + info.rank = rank; + info.nRanks = nranks; + + // Create socket for other ranks to contact me + MSCCLPPCHECK(mscclppSocketInit(&state->listenSock, &bootstrapNetIfAddr, comm->magic, mscclppSocketTypeBootstrap, + comm->abortFlag)); + MSCCLPPCHECK(mscclppSocketListen(&state->listenSock)); + MSCCLPPCHECK(mscclppSocketGetAddr(&state->listenSock, &info.extAddressListen)); + + // Create socket for root to contact me + MSCCLPPCHECK( + mscclppSocketInit(&listenSockRoot, &bootstrapNetIfAddr, comm->magic, mscclppSocketTypeBootstrap, comm->abortFlag)); + MSCCLPPCHECK(mscclppSocketListen(&listenSockRoot)); + MSCCLPPCHECK(mscclppSocketGetAddr(&listenSockRoot, &info.extAddressListenRoot)); + + // stagger connection times to avoid an overload of the root + if (nranks > 128) { + long msec = rank; + struct timespec tv; + tv.tv_sec = msec / 1000; + tv.tv_nsec = 1000000 * (msec % 1000); + TRACE(MSCCLPP_INIT, "rank %d delaying connection to root by %ld msec", rank, msec); + (void)nanosleep(&tv, NULL); + } + + // send info on my listening socket to root + MSCCLPPCHECK(mscclppSocketInit(&sock, &handle->addr, comm->magic, mscclppSocketTypeBootstrap, comm->abortFlag)); + MSCCLPPCHECK(mscclppSocketConnect(&sock)); + MSCCLPPCHECK(bootstrapNetSend(&sock, &info, sizeof(info))); + MSCCLPPCHECK(mscclppSocketClose(&sock)); + + // get info on my "next" rank in the bootstrap ring from root + MSCCLPPCHECK(mscclppSocketInit(&sock)); + MSCCLPPCHECK(mscclppSocketAccept(&sock, &listenSockRoot)); + MSCCLPPCHECK(bootstrapNetRecv(&sock, &nextAddr, sizeof(union mscclppSocketAddress))); + MSCCLPPCHECK(mscclppSocketClose(&sock)); + MSCCLPPCHECK(mscclppSocketClose(&listenSockRoot)); + + MSCCLPPCHECK( + mscclppSocketInit(&state->ringSendSocket, &nextAddr, comm->magic, mscclppSocketTypeBootstrap, comm->abortFlag)); + MSCCLPPCHECK(mscclppSocketConnect(&state->ringSendSocket)); + // Accept the connect request from the previous rank in the AllGather ring + MSCCLPPCHECK(mscclppSocketInit(&state->ringRecvSocket)); + MSCCLPPCHECK(mscclppSocketAccept(&state->ringRecvSocket, &state->listenSock)); + + // AllGather all listen handlers + MSCCLPPCHECK(mscclppCalloc(&state->peerCommAddresses, nranks)); + MSCCLPPCHECK(mscclppSocketGetAddr(&state->listenSock, state->peerCommAddresses + rank)); + MSCCLPPCHECK(bootstrapAllGather(state, state->peerCommAddresses, sizeof(union mscclppSocketAddress))); + + // Create the service proxy + MSCCLPPCHECK(mscclppCalloc(&state->peerProxyAddresses, nranks)); + + // proxy is aborted through a message; don't set abortFlag + MSCCLPPCHECK(mscclppCalloc(&proxySocket, 1)); + MSCCLPPCHECK( + mscclppSocketInit(proxySocket, &bootstrapNetIfAddr, comm->magic, mscclppSocketTypeProxy, comm->abortFlag)); + MSCCLPPCHECK(mscclppSocketListen(proxySocket)); + MSCCLPPCHECK(mscclppSocketGetAddr(proxySocket, state->peerProxyAddresses + rank)); + MSCCLPPCHECK(bootstrapAllGather(state, state->peerProxyAddresses, sizeof(union mscclppSocketAddress))); + // MSCCLPPCHECK(mscclppProxyInit(comm, proxySocket, state->peerProxyAddresses)); + + TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank, nranks); + + return mscclppSuccess; +} + +mscclppResult_t bootstrapAllGather(void* commState, void* allData, int size) +{ + struct bootstrapState* state = (struct bootstrapState*)commState; + char* data = (char*)allData; + int rank = state->rank; + int nranks = state->nranks; + + TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d", rank, nranks, size); + + /* Simple ring based AllGather + * At each step i receive data from (rank-i-1) from left + * and send previous step's data from (rank-i) to right + */ + for (int i = 0; i < nranks - 1; i++) { + size_t rslice = (rank - i - 1 + nranks) % nranks; + size_t sslice = (rank - i + nranks) % nranks; + + // Send slice to the right + MSCCLPPCHECK(bootstrapNetSend(&state->ringSendSocket, data + sslice * size, size)); + // Recv slice from the left + MSCCLPPCHECK(bootstrapNetRecv(&state->ringRecvSocket, data + rslice * size, size)); + } + + TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size); + return mscclppSuccess; +} + +mscclppResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size) +{ + mscclppResult_t ret = mscclppSuccess; + struct bootstrapState* state = (struct bootstrapState*)commState; + struct mscclppSocket sock; + + MSCCLPPCHECKGOTO(mscclppSocketInit(&sock, state->peerCommAddresses + peer, state->magic, mscclppSocketTypeBootstrap, + state->abortFlag), + ret, fail); + MSCCLPPCHECKGOTO(mscclppSocketConnect(&sock), ret, fail); + MSCCLPPCHECKGOTO(bootstrapNetSend(&sock, &state->rank, sizeof(int)), ret, fail); + MSCCLPPCHECKGOTO(bootstrapNetSend(&sock, &tag, sizeof(int)), ret, fail); + MSCCLPPCHECKGOTO(bootstrapNetSend(&sock, data, size), ret, fail); + +exit: + MSCCLPPCHECK(mscclppSocketClose(&sock)); + return ret; +fail: + goto exit; +} + +mscclppResult_t bootstrapBarrier(void* commState, int* ranks, int rank, int nranks, int tag) +{ + if (nranks == 1) + return mscclppSuccess; + TRACE(MSCCLPP_INIT, "rank %d nranks %d tag %x - ENTER", rank, nranks, tag); + + /* Simple intra process barrier + * + * Based on the dissemination algorithm by Debra Hensgen, Raphael Finkel, and Udi Manbet, + * "Two Algorithms for Barrier Synchronization," International Journal of Parallel Programming, 17(1):1-17, 1988" + */ + int data[1]; + for (int mask = 1; mask < nranks; mask <<= 1) { + int src = (rank - mask + nranks) % nranks; + int dst = (rank + mask) % nranks; + MSCCLPPCHECK(bootstrapSend(commState, ranks[dst], tag, data, sizeof(data))); + MSCCLPPCHECK(bootstrapRecv(commState, ranks[src], tag, data, sizeof(data))); + } + + TRACE(MSCCLPP_INIT, "rank %d nranks %d tag %x - DONE", rank, nranks, tag); + return mscclppSuccess; +} + +mscclppResult_t bootstrapIntraNodeAllGather(void* commState, int* ranks, int rank, int nranks, void* allData, int size) +{ + if (nranks == 1) + return mscclppSuccess; + char* data = (char*)allData; + TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - ENTER", rank, nranks, size); + + for (int i = 1; i < nranks; i++) { + int src = (rank - i + nranks) % nranks; + int dst = (rank + i) % nranks; + MSCCLPPCHECK(bootstrapSend(commState, ranks[dst], /*tag=*/i, data + rank * size, size)); + MSCCLPPCHECK(bootstrapRecv(commState, ranks[src], /*tag=*/i, data + src * size, size)); + } + + TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size); + return mscclppSuccess; +} + +mscclppResult_t unexpectedEnqueue(struct bootstrapState* state, int peer, int tag, struct mscclppSocket* sock) +{ + // New unex + struct unexConn* unex; + MSCCLPPCHECK(mscclppCalloc(&unex, 1)); + unex->peer = peer; + unex->tag = tag; + memcpy(&unex->sock, sock, sizeof(struct mscclppSocket)); + + // Enqueue + struct unexConn* list = state->unexpectedConnections; + if (list == NULL) { + state->unexpectedConnections = unex; + return mscclppSuccess; + } + while (list->next) + list = list->next; + list->next = unex; + return mscclppSuccess; +} + +mscclppResult_t unexpectedDequeue(struct bootstrapState* state, int peer, int tag, struct mscclppSocket* sock, + int* found) +{ + struct unexConn* elem = state->unexpectedConnections; + struct unexConn* prev = NULL; + *found = 0; + while (elem) { + if (elem->peer == peer && elem->tag == tag) { + if (prev == NULL) { + state->unexpectedConnections = elem->next; + } else { + prev->next = elem->next; + } + memcpy(sock, &elem->sock, sizeof(struct mscclppSocket)); + free(elem); + *found = 1; + return mscclppSuccess; + } + prev = elem; + elem = elem->next; + } + return mscclppSuccess; +} + +static void unexpectedFree(struct bootstrapState* state) +{ + struct unexConn* elem = state->unexpectedConnections; + struct unexConn* prev = NULL; + + while (elem) { + prev = elem; + elem = elem->next; + free(prev); + } + return; +} + +// We can't know who we'll receive from, so we need to receive everything at once +mscclppResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size) +{ + mscclppResult_t ret = mscclppSuccess; + struct bootstrapState* state = (struct bootstrapState*)commState; + struct mscclppSocket sock; + int newPeer, newTag; + + // Search unexpected connections first + int found; + MSCCLPPCHECK(unexpectedDequeue(state, peer, tag, &sock, &found)); + if (found) { + MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, ((char*)data), size), ret, fail); + goto exit; + } + + // Then look for new connections + while (1) { + MSCCLPPCHECKGOTO(mscclppSocketInit(&sock), ret, fail); + MSCCLPPCHECKGOTO(mscclppSocketAccept(&sock, &state->listenSock), ret, fail); + MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, &newPeer, sizeof(int)), ret, fail); + MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, &newTag, sizeof(int)), ret, fail); + if (newPeer == peer && newTag == tag) { + MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, ((char*)data), size), ret, fail); + goto exit; + } + // Unexpected connection. Save for later. + MSCCLPPCHECKGOTO(unexpectedEnqueue(state, newPeer, newTag, &sock), ret, fail); + } +exit: + MSCCLPPCHECK(mscclppSocketClose(&sock)); + return ret; +fail: + goto exit; +} + +mscclppResult_t bootstrapClose(void* commState) +{ + struct bootstrapState* state = (struct bootstrapState*)commState; + if (state->unexpectedConnections != NULL) { + unexpectedFree(state); + if (*state->abortFlag == 0) { + WARN("Unexpected connections are not empty"); + return mscclppInternalError; + } + } + + MSCCLPPCHECK(mscclppSocketClose(&state->listenSock)); + MSCCLPPCHECK(mscclppSocketClose(&state->ringSendSocket)); + MSCCLPPCHECK(mscclppSocketClose(&state->ringRecvSocket)); + + free(state->peerCommAddresses); + free(state); + + return mscclppSuccess; +} + +mscclppResult_t bootstrapAbort(void* commState) +{ + struct bootstrapState* state = (struct bootstrapState*)commState; + if (commState == NULL) + return mscclppSuccess; + MSCCLPPCHECK(mscclppSocketClose(&state->listenSock)); + MSCCLPPCHECK(mscclppSocketClose(&state->ringSendSocket)); + MSCCLPPCHECK(mscclppSocketClose(&state->ringRecvSocket)); + free(state->peerCommAddresses); + free(state->peerProxyAddresses); + free(state); + return mscclppSuccess; +} \ No newline at end of file diff --git a/src/include/bootstrap.h b/src/include/bootstrap.h index afe6eca3..8eb7b939 100644 --- a/src/include/bootstrap.h +++ b/src/include/bootstrap.h @@ -18,11 +18,13 @@ class MscclppBootstrap : Bootstrap { public: MscclppBootstrap(std::string ipPortPair, int rank, int nRanks); MscclppBootstrap(mscclppBootstrapHandle handle, int rank, int nRanks); + ~MscclppBootstrap() = default; + void Initialize(); - void Send(void* data, int size, int peer, int tag); - void Recv(void* data, int size, int peer, int tag); - void AllGather(void* allData, int size); - void Barrier(); + void Send(void* data, int size, int peer, int tag) override; + void Recv(void* data, int size, int peer, int tag) override; + void AllGather(void* allData, int size) override; + void Barrier() override; void Close(); struct UniqueId; UniqueId GetUniqueId(); @@ -32,6 +34,7 @@ private: std::unique_ptr pimpl_; }; +mscclppResult_t bootstrapNetInit(const char* ip_port_pair = NULL); mscclppResult_t bootstrapCreateRoot(struct mscclppBootstrapHandle* handle); mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot = true, const char* ip_port_pair = NULL); diff --git a/src/include/comm.h b/src/include/comm.h index b76c8c4f..8275e0cb 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -44,7 +44,7 @@ struct mscclppComm struct mscclppDevConn devConns[MAXCONNECTIONS]; int nConns; - MscclppBootstrap bootstrap; + void* bootstrap; // Magic number for all network communication. Not a security key -- only goal is to detect mismatches. uint64_t magic; diff --git a/src/init.cc b/src/init.cc index 7c3b76b9..1850f68c 100644 --- a/src/init.cc +++ b/src/init.cc @@ -926,3 +926,6 @@ MSCCLPP_API mscclppResult_t mscclppSetBootstrapConnTimeout(int timeout) config->setBootstrapConnectionTimeoutConfig(timeout); return mscclppSuccess; } + +Bootstrap::~Bootstrap() { +}