diff --git a/src/communicator.cc b/src/communicator.cc index c24b0c5e..7e1348e8 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -81,10 +81,10 @@ MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank MSCCLPP_API_CPP void Communicator::connectionSetup() { for (auto& conn : pimpl->connections) { - conn->startSetup(*this); + conn->startSetup(pimpl->bootstrap_); } for (auto& conn : pimpl->connections) { - conn->endSetup(*this); + conn->endSetup(pimpl->bootstrap_); } } diff --git a/src/connection.cc b/src/connection.cc index 1e21694c..fc653c2a 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -54,7 +54,7 @@ void CudaIpcConnection::flush() { // IBConnection -IBConnection::IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl) : remoteRank(remoteRank), tag(tag), transport_(transport), remoteTransport_(TransportNone) { +IBConnection::IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl) : remoteRank_(remoteRank), tag_(tag), transport_(transport), remoteTransport_(TransportNone) { qp = commImpl.getIbContext(transport)->createQp(); } @@ -114,15 +114,15 @@ void IBConnection::flush() { // npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT); } -void IBConnection::startSetup(Communicator& comm) { +void IBConnection::startSetup(std::shared_ptr bootstrap) { // TODO(chhwang): temporarily disabled to compile - // comm.bootstrap().send(&qp->getInfo(), sizeof(qp->getInfo()), remoteRank, tag); + bootstrap->send(&qp->getInfo(), sizeof(qp->getInfo()), remoteRank_, tag_); } -void IBConnection::endSetup(Communicator& comm) { +void IBConnection::endSetup(std::shared_ptr bootstrap) { IbQpInfo qpInfo; // TODO(chhwang): temporarily disabled to compile - // comm.bootstrap().recv(&qpInfo, sizeof(qpInfo), remoteRank, tag); + bootstrap->recv(&qpInfo, sizeof(qpInfo), remoteRank_, tag_); qp->rtr(qpInfo); qp->rts(); } diff --git a/src/ib.cc b/src/ib.cc index 4dc0285b..fe3334a3 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -263,7 +263,7 @@ int IbQp::pollCq() return ibv_poll_cq(reinterpret_cast(this->cq), MSCCLPP_IB_CQ_POLL_NUM, reinterpret_cast(this->wcs)); } -const IbQpInfo& IbQp::getInfo() const +IbQpInfo& IbQp::getInfo() { return this->info; } diff --git a/src/include/connection.hpp b/src/include/connection.hpp index dcf21362..132726f7 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -12,8 +12,8 @@ namespace mscclpp { class ConnectionBase : public Connection { public: - virtual void startSetup(Communicator&) {}; - virtual void endSetup(Communicator&) {}; + virtual void startSetup(std::shared_ptr bootstrap) {}; + virtual void endSetup(std::shared_ptr bootstrap) {}; }; class CudaIpcConnection : public ConnectionBase { @@ -34,8 +34,8 @@ public: }; class IBConnection : public ConnectionBase { - int remoteRank; - int tag; + int remoteRank_; + int tag_; TransportFlags transport_; TransportFlags remoteTransport_; IbQp* qp; @@ -53,9 +53,9 @@ public: void flush() override; - void startSetup(Communicator& comm) override; + void startSetup(std::shared_ptr bootstrap) override; - void endSetup(Communicator& comm) override; + void endSetup(std::shared_ptr bootstrap) override; }; } // namespace mscclpp diff --git a/src/include/ib.hpp b/src/include/ib.hpp index d04b75bd..b1baeb75 100644 --- a/src/include/ib.hpp +++ b/src/include/ib.hpp @@ -61,7 +61,7 @@ public: void postRecv(uint64_t wrId); int pollCq(); - const IbQpInfo& getInfo() const; + IbQpInfo& getInfo(); const void* getWc(int idx) const; private: diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc index fc3a72e8..d3fe15b0 100644 --- a/tests/communicator_test_cpp.cc +++ b/tests/communicator_test_cpp.cc @@ -5,6 +5,20 @@ #include #include +mscclpp::TransportFlags findIb(int localRank){ + mscclpp::TransportFlags IBs[] = { + mscclpp::TransportIB0, + mscclpp::TransportIB1, + mscclpp::TransportIB2, + mscclpp::TransportIB3, + mscclpp::TransportIB4, + mscclpp::TransportIB5, + mscclpp::TransportIB6, + mscclpp::TransportIB7 + }; + return IBs[localRank]; +} + void test_communicator(int rank, int worldSize, int nranksPerNode){ auto bootstrap = std::make_shared(rank, worldSize); mscclpp::UniqueId id; @@ -16,12 +30,14 @@ void test_communicator(int rank, int worldSize, int nranksPerNode){ auto communicator = std::make_shared(bootstrap); for (int i = 0; i < worldSize; i++){ if (i != rank){ - if (i % nranksPerNode == rank % nranksPerNode) + if (i % nranksPerNode == rank % nranksPerNode){ auto connect = communicator->connect(i, 0, mscclpp::TransportCudaIpc); - else - auto connect = communicator->connect(i, 0, mscclpp::TransportAllIB); + } else { + auto connect = communicator->connect(i, 0, findIb(rank % nranksPerNode)); + } } } + communicator->connectionSetup(); if (bootstrap->getRank() == 0) std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl;