testing connection setup

This commit is contained in:
Saeed Maleki
2023-04-27 06:08:35 +00:00
parent 4d7a4a25db
commit 8eda6369ee
6 changed files with 34 additions and 18 deletions

View File

@@ -81,10 +81,10 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connect(int remoteRank
// Drives connection setup in two passes over pimpl->connections: startSetup()
// is invoked on every connection before endSetup() is invoked on any of them.
// NOTE(review): this span is a diff rendering without +/- markers; the calls
// passing `*this` look like the pre-change lines and the calls passing
// `pimpl->bootstrap_` like their replacements -- confirm against the repository.
MSCCLPP_API_CPP void Communicator::connectionSetup() {
for (auto& conn : pimpl->connections) {
conn->startSetup(*this);
conn->startSetup(pimpl->bootstrap_);
}
for (auto& conn : pimpl->connections) {
conn->endSetup(*this);
conn->endSetup(pimpl->bootstrap_);
}
}

View File

@@ -54,7 +54,7 @@ void CudaIpcConnection::flush() {
// IBConnection
// Constructs an IB connection: records the remote rank, the tag and the local
// transport, initializes remoteTransport_ to TransportNone, and obtains `qp`
// by calling createQp() on the IB context that commImpl returns for `transport`.
// NOTE(review): two signature lines are shown because this is a diff rendering;
// they differ only in the member names (remoteRank/tag vs. remoteRank_/tag_).
// The second one presumably is the post-change version -- confirm.
IBConnection::IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl) : remoteRank(remoteRank), tag(tag), transport_(transport), remoteTransport_(TransportNone) {
IBConnection::IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl) : remoteRank_(remoteRank), tag_(tag), transport_(transport), remoteTransport_(TransportNone) {
qp = commImpl.getIbContext(transport)->createQp();
}
@@ -114,15 +114,15 @@ void IBConnection::flush() {
// npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT);
}
// First half of connection setup: sends the bytes of qp->getInfo() to
// remoteRank_ under tag_ through the given bootstrap.
// NOTE(review): diff rendering -- the `Communicator& comm` signature and the
// commented-out comm.bootstrap().send(...) line appear to be the pre-change
// text that the shared_ptr<BaseBootstrap> version replaces; confirm.
void IBConnection::startSetup(Communicator& comm) {
void IBConnection::startSetup(std::shared_ptr<BaseBootstrap> bootstrap) {
// TODO(chhwang): temporarily disabled to compile
// comm.bootstrap().send(&qp->getInfo(), sizeof(qp->getInfo()), remoteRank, tag);
bootstrap->send(&qp->getInfo(), sizeof(qp->getInfo()), remoteRank_, tag_);
}
// Second half of connection setup: receives an IbQpInfo from remoteRank_ under
// tag_ through the given bootstrap, then calls qp->rtr(qpInfo) followed by
// qp->rts().
// NOTE(review): rtr()/rts() presumably move the queue pair to the
// ready-to-receive and ready-to-send states -- confirm in IbQp.
// NOTE(review): diff rendering -- the `Communicator& comm` signature and the
// commented-out comm.bootstrap().recv(...) line appear to be pre-change text.
void IBConnection::endSetup(Communicator& comm) {
void IBConnection::endSetup(std::shared_ptr<BaseBootstrap> bootstrap) {
IbQpInfo qpInfo;
// TODO(chhwang): temporarily disabled to compile
// comm.bootstrap().recv(&qpInfo, sizeof(qpInfo), remoteRank, tag);
bootstrap->recv(&qpInfo, sizeof(qpInfo), remoteRank_, tag_);
qp->rtr(qpInfo);
qp->rts();
}

View File

@@ -263,7 +263,7 @@ int IbQp::pollCq()
return ibv_poll_cq(reinterpret_cast<struct ibv_cq*>(this->cq), MSCCLPP_IB_CQ_POLL_NUM, reinterpret_cast<struct ibv_wc*>(this->wcs));
}
// Returns a reference to this queue pair's `info` member.
// NOTE(review): diff rendering -- both the const-qualified signature and the
// non-const one are shown; the non-const form is presumably the replacement,
// which lets callers obtain a mutable IbQpInfo&. Confirm whether dropping
// const was required by the bootstrap send() parameter type.
const IbQpInfo& IbQp::getInfo() const
IbQpInfo& IbQp::getInfo()
{
return this->info;
}

View File

@@ -12,8 +12,8 @@ namespace mscclpp {
// Extends Connection with two virtual setup hooks, startSetup() and endSetup(),
// whose default bodies are empty so a subclass only overrides what it needs.
// NOTE(review): diff rendering -- the overloads taking `Communicator&` appear
// to be the pre-change declarations and the ones taking
// std::shared_ptr<BaseBootstrap> their replacements; confirm.
class ConnectionBase : public Connection {
public:
virtual void startSetup(Communicator&) {};
virtual void endSetup(Communicator&) {};
virtual void startSetup(std::shared_ptr<BaseBootstrap> bootstrap) {};
virtual void endSetup(std::shared_ptr<BaseBootstrap> bootstrap) {};
};
class CudaIpcConnection : public ConnectionBase {
@@ -34,8 +34,8 @@ public:
};
class IBConnection : public ConnectionBase {
int remoteRank;
int tag;
int remoteRank_;
int tag_;
TransportFlags transport_;
TransportFlags remoteTransport_;
IbQp* qp;
@@ -53,9 +53,9 @@ public:
void flush() override;
void startSetup(Communicator& comm) override;
void startSetup(std::shared_ptr<BaseBootstrap> bootstrap) override;
void endSetup(Communicator& comm) override;
void endSetup(std::shared_ptr<BaseBootstrap> bootstrap) override;
};
} // namespace mscclpp

View File

@@ -61,7 +61,7 @@ public:
void postRecv(uint64_t wrId);
int pollCq();
const IbQpInfo& getInfo() const;
IbQpInfo& getInfo();
const void* getWc(int idx) const;
private:

View File

@@ -5,6 +5,20 @@
#include <iostream>
#include <mpi.h>
#include <stdexcept>
#include <string>
// Maps a node-local rank to an IB transport flag: rank i selects
// mscclpp::TransportIB<i>.
//
// Only eight transports (IB0..IB7) are listed, so localRank must lie in
// [0, 8). A value outside that range previously indexed past the table
// (undefined behavior); it now throws std::out_of_range instead.
//
// @param localRank  rank of the process within its node.
// @return           the transport flag at position localRank in the table.
// @throws std::out_of_range if localRank is negative or >= 8.
mscclpp::TransportFlags findIb(int localRank){
  const mscclpp::TransportFlags IBs[] = {
    mscclpp::TransportIB0,
    mscclpp::TransportIB1,
    mscclpp::TransportIB2,
    mscclpp::TransportIB3,
    mscclpp::TransportIB4,
    mscclpp::TransportIB5,
    mscclpp::TransportIB6,
    mscclpp::TransportIB7
  };
  // Derive the bound from the table itself so the check cannot drift from it.
  constexpr int numIbs = static_cast<int>(sizeof(IBs) / sizeof(IBs[0]));
  if (localRank < 0 || localRank >= numIbs) {
    throw std::out_of_range("findIb: no IB transport for localRank " + std::to_string(localRank));
  }
  return IBs[localRank];
}
void test_communicator(int rank, int worldSize, int nranksPerNode){
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(rank, worldSize);
mscclpp::UniqueId id;
@@ -16,12 +30,14 @@ void test_communicator(int rank, int worldSize, int nranksPerNode){
auto communicator = std::make_shared<mscclpp::Communicator>(bootstrap);
for (int i = 0; i < worldSize; i++){
if (i != rank){
if (i % nranksPerNode == rank % nranksPerNode)
if (i % nranksPerNode == rank % nranksPerNode){
auto connect = communicator->connect(i, 0, mscclpp::TransportCudaIpc);
else
auto connect = communicator->connect(i, 0, mscclpp::TransportAllIB);
} else {
auto connect = communicator->connect(i, 0, findIb(rank % nranksPerNode));
}
}
}
communicator->connectionSetup();
if (bootstrap->getRank() == 0)
std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl;