mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 01:10:22 +00:00
testing connection setup
This commit is contained in:
@@ -81,10 +81,10 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connect(int remoteRank
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::connectionSetup() {
|
||||
for (auto& conn : pimpl->connections) {
|
||||
conn->startSetup(*this);
|
||||
conn->startSetup(pimpl->bootstrap_);
|
||||
}
|
||||
for (auto& conn : pimpl->connections) {
|
||||
conn->endSetup(*this);
|
||||
conn->endSetup(pimpl->bootstrap_);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ void CudaIpcConnection::flush() {
|
||||
|
||||
// IBConnection
|
||||
|
||||
IBConnection::IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl) : remoteRank(remoteRank), tag(tag), transport_(transport), remoteTransport_(TransportNone) {
|
||||
IBConnection::IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl) : remoteRank_(remoteRank), tag_(tag), transport_(transport), remoteTransport_(TransportNone) {
|
||||
qp = commImpl.getIbContext(transport)->createQp();
|
||||
}
|
||||
|
||||
@@ -114,15 +114,15 @@ void IBConnection::flush() {
|
||||
// npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT);
|
||||
}
|
||||
|
||||
void IBConnection::startSetup(Communicator& comm) {
|
||||
void IBConnection::startSetup(std::shared_ptr<BaseBootstrap> bootstrap) {
|
||||
// TODO(chhwang): temporarily disabled to compile
|
||||
// comm.bootstrap().send(&qp->getInfo(), sizeof(qp->getInfo()), remoteRank, tag);
|
||||
bootstrap->send(&qp->getInfo(), sizeof(qp->getInfo()), remoteRank_, tag_);
|
||||
}
|
||||
|
||||
void IBConnection::endSetup(Communicator& comm) {
|
||||
void IBConnection::endSetup(std::shared_ptr<BaseBootstrap> bootstrap) {
|
||||
IbQpInfo qpInfo;
|
||||
// TODO(chhwang): temporarily disabled to compile
|
||||
// comm.bootstrap().recv(&qpInfo, sizeof(qpInfo), remoteRank, tag);
|
||||
bootstrap->recv(&qpInfo, sizeof(qpInfo), remoteRank_, tag_);
|
||||
qp->rtr(qpInfo);
|
||||
qp->rts();
|
||||
}
|
||||
|
||||
@@ -263,7 +263,7 @@ int IbQp::pollCq()
|
||||
return ibv_poll_cq(reinterpret_cast<struct ibv_cq*>(this->cq), MSCCLPP_IB_CQ_POLL_NUM, reinterpret_cast<struct ibv_wc*>(this->wcs));
|
||||
}
|
||||
|
||||
const IbQpInfo& IbQp::getInfo() const
|
||||
IbQpInfo& IbQp::getInfo()
|
||||
{
|
||||
return this->info;
|
||||
}
|
||||
|
||||
@@ -12,8 +12,8 @@ namespace mscclpp {
|
||||
|
||||
class ConnectionBase : public Connection {
|
||||
public:
|
||||
virtual void startSetup(Communicator&) {};
|
||||
virtual void endSetup(Communicator&) {};
|
||||
virtual void startSetup(std::shared_ptr<BaseBootstrap> bootstrap) {};
|
||||
virtual void endSetup(std::shared_ptr<BaseBootstrap> bootstrap) {};
|
||||
};
|
||||
|
||||
class CudaIpcConnection : public ConnectionBase {
|
||||
@@ -34,8 +34,8 @@ public:
|
||||
};
|
||||
|
||||
class IBConnection : public ConnectionBase {
|
||||
int remoteRank;
|
||||
int tag;
|
||||
int remoteRank_;
|
||||
int tag_;
|
||||
TransportFlags transport_;
|
||||
TransportFlags remoteTransport_;
|
||||
IbQp* qp;
|
||||
@@ -53,9 +53,9 @@ public:
|
||||
|
||||
void flush() override;
|
||||
|
||||
void startSetup(Communicator& comm) override;
|
||||
void startSetup(std::shared_ptr<BaseBootstrap> bootstrap) override;
|
||||
|
||||
void endSetup(Communicator& comm) override;
|
||||
void endSetup(std::shared_ptr<BaseBootstrap> bootstrap) override;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -61,7 +61,7 @@ public:
|
||||
void postRecv(uint64_t wrId);
|
||||
int pollCq();
|
||||
|
||||
const IbQpInfo& getInfo() const;
|
||||
IbQpInfo& getInfo();
|
||||
const void* getWc(int idx) const;
|
||||
|
||||
private:
|
||||
|
||||
@@ -5,6 +5,20 @@
|
||||
#include <iostream>
|
||||
#include <mpi.h>
|
||||
|
||||
mscclpp::TransportFlags findIb(int localRank){
|
||||
mscclpp::TransportFlags IBs[] = {
|
||||
mscclpp::TransportIB0,
|
||||
mscclpp::TransportIB1,
|
||||
mscclpp::TransportIB2,
|
||||
mscclpp::TransportIB3,
|
||||
mscclpp::TransportIB4,
|
||||
mscclpp::TransportIB5,
|
||||
mscclpp::TransportIB6,
|
||||
mscclpp::TransportIB7
|
||||
};
|
||||
return IBs[localRank];
|
||||
}
|
||||
|
||||
void test_communicator(int rank, int worldSize, int nranksPerNode){
|
||||
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(rank, worldSize);
|
||||
mscclpp::UniqueId id;
|
||||
@@ -16,12 +30,14 @@ void test_communicator(int rank, int worldSize, int nranksPerNode){
|
||||
auto communicator = std::make_shared<mscclpp::Communicator>(bootstrap);
|
||||
for (int i = 0; i < worldSize; i++){
|
||||
if (i != rank){
|
||||
if (i % nranksPerNode == rank % nranksPerNode)
|
||||
if (i % nranksPerNode == rank % nranksPerNode){
|
||||
auto connect = communicator->connect(i, 0, mscclpp::TransportCudaIpc);
|
||||
else
|
||||
auto connect = communicator->connect(i, 0, mscclpp::TransportAllIB);
|
||||
} else {
|
||||
auto connect = communicator->connect(i, 0, findIb(rank % nranksPerNode));
|
||||
}
|
||||
}
|
||||
}
|
||||
communicator->connectionSetup();
|
||||
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl;
|
||||
|
||||
Reference in New Issue
Block a user