diff --git a/src/bootstrap/bootstrap.cc b/src/bootstrap/bootstrap.cc index e54b38e8..b12afb61 100644 --- a/src/bootstrap/bootstrap.cc +++ b/src/bootstrap/bootstrap.cc @@ -59,20 +59,24 @@ struct extInfo class MscclppBootstrap::Impl { public: - Impl(std::string ipPortPair, int rank, int nRanks, const UniqueId uniqueId); + Impl(int rank, int nRanks); ~Impl(); - mscclppResult_t initialize(); + void initialize(const UniqueId uniqueId); + void initialize(std::string ipPortPair); + mscclppResult_t establishConnections(); + UniqueId getUniqueId(); mscclppResult_t allGather(void* allData, int size); mscclppResult_t send(void* data, int size, int peer, int tag); mscclppResult_t recv(void* data, int size, int peer, int tag); mscclppResult_t barrier(); mscclppResult_t close(); - static UniqueId uniqueId_; + UniqueId uniqueId_; private: int rank_; int nRanks_; + bool netInitialized; mscclppSocket listenSock_; mscclppSocket ringRecvSocket_; mscclppSocket ringSendSocket_; @@ -95,37 +99,67 @@ private: mscclppResult_t netInit(std::string ipPortPair); }; -UniqueId MscclppBootstrap::Impl::uniqueId_; +// UniqueId MscclppBootstrap::Impl::uniqueId_; -MscclppBootstrap::Impl::Impl(std::string ipPortPair, int rank, int nRanks, const UniqueId uniqueId) - : rank_(rank), nRanks_(nRanks), peerCommAddresses_(nRanks, mscclppSocketAddress()), +MscclppBootstrap::Impl::Impl(int rank, int nRanks) + : rank_(rank), nRanks_(nRanks), netInitialized(false), peerCommAddresses_(nRanks, mscclppSocketAddress()), peerProxyAddresses_(nRanks, mscclppSocketAddress()), abortFlag_(nullptr) +{ +} + +UniqueId MscclppBootstrap::Impl::getUniqueId() +{ + UniqueId uniqueId; + auto ret = netInit(""); + if (ret != mscclppSuccess) { + throw std::runtime_error("Failed to initialize network"); + } + ret = getRandomData(&uniqueId.magic, sizeof(uniqueId.magic)); + if (ret != mscclppSuccess) { + throw std::runtime_error("getting random data failed"); + } + std::memcpy(&uniqueId.addr, &netIfAddr_, sizeof(union mscclppSocketAddress)); + + return uniqueId; +} + +void MscclppBootstrap::Impl::initialize(const UniqueId uniqueId) +{ + int ret = netInit(""); + if (ret != mscclppSuccess) { + throw std::runtime_error("Failed to initialize network"); + } + + uniqueId_.magic = uniqueId.magic; + uniqueId_.addr = uniqueId.addr; + + if (rank_ == 0) { + rootThread_ = std::thread(&MscclppBootstrap::Impl::bootstrapRoot, this); + } + + ret = establishConnections(); + if (ret != mscclppSuccess) { + throw std::runtime_error("Failed to establish connections"); + } +} + +void MscclppBootstrap::Impl::initialize(std::string ipPortPair) { int ret = netInit(ipPortPair); if (ret != mscclppSuccess) { throw std::runtime_error("Failed to initialize network"); } - UniqueId zeroId; - std::memset(&zeroId, 0, sizeof(UniqueId)); - if (std::memcmp(&uniqueId, &zeroId, sizeof(UniqueId)) != 0) { - uniqueId_.magic = uniqueId.magic; - uniqueId_.addr = uniqueId.addr; - return; - } - - if (!ipPortPair.empty()) { - uniqueId_.magic = 0xdeadbeef; - } else { - mscclppResult_t ret = getRandomData(&uniqueId_.magic, sizeof(uniqueId_.magic)); - if (ret != mscclppSuccess) { - throw std::runtime_error("getting random data failed"); - } - } + uniqueId_.magic = 0xdeadbeef; std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(union mscclppSocketAddress)); if (rank_ == 0) { rootThread_ = std::thread(&MscclppBootstrap::Impl::bootstrapRoot, this); } + + ret = establishConnections(); + if (ret != mscclppSuccess) { + throw std::runtime_error("Failed to establish connections"); + } } MscclppBootstrap::Impl::~Impl() @@ -145,33 +179,39 @@ mscclppResult_t MscclppBootstrap::Impl::getRemoteAddresses(mscclppSocket* listen mscclppResult_t res = mscclppSuccess; mscclppSocketAddress zero; + printf("hh 0\n"); std::memset(&zero, 0, sizeof(mscclppSocketAddress)); res = mscclppSocketInit(&sock); if (res != mscclppSuccess) { WARN("Bootstrap Root : mscclppSocketInit failed"); return res; } + printf("hh 1\n"); res = mscclppSocketAccept(&sock, listenSock); if (res != mscclppSuccess) { WARN("Bootstrap Root : mscclppSocketAccept failed"); return res; } + printf("hh 2\n"); res = netRecv(&sock, &info, sizeof(info)); if (res != mscclppSuccess) { WARN("Bootstrap Root : netRecv failed"); return res; } + printf("hh 3\n"); res = mscclppSocketClose(&sock); if (res != mscclppSuccess) { WARN("Bootstrap Root : mscclppSocketClose failed"); return res; } + printf("hh 4\n"); if (this->nRanks_ != info.nRanks) { WARN("Bootstrap Root : mismatch in rank count from procs %d : %d", this->nRanks_, info.nRanks); return res; } + printf("hh 5\n"); if (std::memcmp(&zero, &rankAddressesRoot[info.rank], sizeof(mscclppSocketAddress)) != 0) { WARN("Bootstrap Root : rank %d of %d ranks has already checked in", info.rank, this->nRanks_); return res; @@ -216,6 +256,7 @@ mscclppResult_t MscclppBootstrap::Impl::sendHandleToPeer(int peer, mscclppResult_t MscclppBootstrap::Impl::bootstrapRoot() { + printf("I am here0 magic %x\n", uniqueId_.magic); mscclppResult_t res = mscclppSuccess; int numCollected = 0; std::vector rankAddresses(this->nRanks_, mscclppSocketAddress()); @@ -226,16 +267,20 @@ mscclppResult_t MscclppBootstrap::Impl::bootstrapRoot() std::memset(rankAddressesRoot.data(), 0, sizeof(mscclppSocketAddress) * this->nRanks_); setFilesLimit(); + printf("I am here1 %x\n", uniqueId_.magic); mscclppSocket listenSock; MSCCLPPCHECK( mscclppSocketInit(&listenSock, &uniqueId_.addr, uniqueId_.magic, mscclppSocketTypeBootstrap, nullptr, 0)); MSCCLPPCHECK(mscclppSocketListen(&listenSock)); + printf("I am here2\n"); TRACE(MSCCLPP_INIT, "BEGIN"); + printf("I am here3\n"); /* Receive addresses from all ranks */ do { int rank; res = getRemoteAddresses(&listenSock, rankAddresses, rankAddressesRoot, rank); + printf("I am here4\n"); if (res != mscclppSuccess) { WARN("Bootstrap Root : getRemoteAddresses failed"); break; @@ -262,6 +307,8 @@ mscclppResult_t MscclppBootstrap::Impl::bootstrapRoot() mscclppResult_t MscclppBootstrap::Impl::netInit(std::string ipPortPair) { + if (netInitialized) + return mscclppSuccess; if (!ipPortPair.empty()) { union mscclppSocketAddress remoteAddr; if (mscclppSocketGetAddrFromString(&remoteAddr, ipPortPair.c_str()) != mscclppSuccess) { @@ -284,17 +331,18 @@ mscclppResult_t MscclppBootstrap::Impl::netInit(std::string ipPortPair) std::sprintf(line, " %s:", netIfName_); mscclppSocketToString(&netIfAddr_, line + strlen(line)); INFO(MSCCLPP_INIT, "Bootstrap : Using%s", line); + netInitialized = true; return mscclppSuccess; } -mscclppResult_t MscclppBootstrap::Impl::initialize() +mscclppResult_t MscclppBootstrap::Impl::establishConnections() { mscclppSocket* proxySocket; mscclppSocketAddress nextAddr; mscclppSocket sock, listenSockRoot; extInfo info; - TRACE(MSCCLPP_INIT, "rank %d nranks %d", rank, nranks); + TRACE(MSCCLPP_INIT, "rank %d nranks %d", rank_, nRanks_); info.rank = this->rank_; info.nRanks = this->nRanks_; @@ -322,11 +370,21 @@ mscclppResult_t MscclppBootstrap::Impl::initialize() randomSleep(this->rank_); } + + char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2]; + std::sprintf(line, " %s:", netIfName_); + mscclppSocketToString(&this->uniqueId_.addr, line + strlen(line)); + + printf("tt 1 %s\n", line); // send info on my listening socket to root MSCCLPPCHECK(mscclppSocketInit(&sock, &this->uniqueId_.addr, magic, mscclppSocketTypeBootstrap, this->abortFlag_)); + printf("tt 2\n"); MSCCLPPCHECK(mscclppSocketConnect(&sock)); + printf("tt 3\n"); MSCCLPPCHECK(netSend(&sock, &info, sizeof(info))); + printf("tt 4\n"); MSCCLPPCHECK(mscclppSocketClose(&sock)); + printf("tt 5\n"); // get info on my "next" rank in the bootstrap ring from root MSCCLPPCHECK(mscclppSocketInit(&sock)); @@ -353,7 +411,7 @@ mscclppResult_t MscclppBootstrap::Impl::initialize() MSCCLPPCHECK(mscclppSocketGetAddr(proxySocket, &this->peerProxyAddresses_[rank_])); MSCCLPPCHECK(allGather(this->peerProxyAddresses_.data(), sizeof(union mscclppSocketAddress))); - TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank, nranks); + TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank_, nRanks_); return mscclppSuccess; } @@ -380,7 +438,7 @@ mscclppResult_t MscclppBootstrap::Impl::allGather(void* allData, int size) MSCCLPPCHECK(netRecv(&this->ringRecvSocket_, data + rSlice * size, size)); } - TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size); + TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nRanks, size); return mscclppSuccess; } @@ -432,23 +490,15 @@ mscclppResult_t MscclppBootstrap::Impl::close() return mscclppSuccess; } -MscclppBootstrap::MscclppBootstrap(std::string ipPortPair, int rank, int nRanks) +MscclppBootstrap::MscclppBootstrap(int rank, int nRanks) { - UniqueId uniqueId; - std::memset(&uniqueId, 0, sizeof(uniqueId)); // pimpl_ = std::make_unique(ipPortPair, rank, nRanks, uniqueId); - pimpl_ = new Impl(ipPortPair, rank, nRanks, uniqueId); -} - -MscclppBootstrap::MscclppBootstrap(UniqueId uniqueId, int rank, int nRanks) -{ - pimpl_ = new Impl("", rank, nRanks, uniqueId); - // pimpl_ = std::make_unique("", rank, nRanks, uniqueId); + pimpl_ = new Impl(rank, nRanks); } UniqueId MscclppBootstrap::GetUniqueId() { - return Impl::uniqueId_; + return pimpl_->getUniqueId(); } void MscclppBootstrap::Send(void* data, int size, int peer, int tag) @@ -475,12 +525,14 @@ void MscclppBootstrap::AllGather(void* allData, int size) } } -void MscclppBootstrap::Initialize() +void MscclppBootstrap::Initialize(const UniqueId uniqueId) { - mscclppResult_t res = pimpl_->initialize(); - if (res != mscclppSuccess) { - throw std::runtime_error("MscclppBootstrap::Initialize failed"); - } + pimpl_->initialize(uniqueId); +} + +void MscclppBootstrap::Initialize(std::string ipPortPair) +{ + pimpl_->initialize(ipPortPair); } void MscclppBootstrap::Barrier() diff --git a/src/include/bootstrap.h b/src/include/bootstrap.h index 175981e4..a76c99b7 100644 --- a/src/include/bootstrap.h +++ b/src/include/bootstrap.h @@ -17,13 +17,13 @@ static_assert(sizeof(UniqueId) <= sizeof(mscclppUniqueId), class __attribute__((visibility("default"))) MscclppBootstrap : public Bootstrap { public: - MscclppBootstrap(std::string ipPortPair, int rank, int nRanks); - MscclppBootstrap(UniqueId uniqueId, int rank, int nRanks); + MscclppBootstrap(int rank, int nRanks); ~MscclppBootstrap() override = default; - static UniqueId GetUniqueId(); + UniqueId GetUniqueId(); - void Initialize(); + void Initialize(const UniqueId uniqueId); + void Initialize(std::string ipPortPair); void Send(void* data, int size, int peer, int tag) override; void Recv(void* data, int size, int peer, int tag) override; void AllGather(void* allData, int size) override; diff --git a/tests/bootstrap_test_cpp.cc b/tests/bootstrap_test_cpp.cc index 34137577..8e3b1e87 100644 --- a/tests/bootstrap_test_cpp.cc +++ b/tests/bootstrap_test_cpp.cc @@ -11,7 +11,13 @@ int main() MPI_Comm_rank(MPI_COMM_WORLD, &rank); MPI_Comm_size(MPI_COMM_WORLD, &worldSize); - std::shared_ptr bootstrap(new MscclppBootstrap("", rank, worldSize)); + std::shared_ptr bootstrap(new MscclppBootstrap(rank, worldSize)); + // bootstrap->Initialize("costsim-dev-00000A:50000"); + UniqueId id; + if (rank == 0) + id = bootstrap->GetUniqueId(); + MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD); + bootstrap->Initialize(id); // need to call initialization first MPI_Finalize();