host hashes in communicator

This commit is contained in:
Saeed Maleki
2023-04-27 19:17:19 +00:00
parent 8eda6369ee
commit aaa3f0e945
3 changed files with 22 additions and 2 deletions

View File

@@ -1,3 +1,5 @@
#include <sstream>
#include "mscclpp.hpp"
#include "communicator.hpp"
#include "host_connection.hpp"
@@ -12,7 +14,13 @@
namespace mscclpp {
Communicator::Impl::Impl(std::shared_ptr<BaseBootstrap> bootstrap) : bootstrap_(bootstrap) {}
// Constructs the communicator implementation and builds the rank -> host-hash table.
// Each rank stores its own host hash in its slot of rankToHash_, then an allGather over
// the bootstrap (one uint64_t per rank) fills in the remaining slots.
Communicator::Impl::Impl(std::shared_ptr<BaseBootstrap> bootstrap) : bootstrap_(bootstrap) {
  // One table entry per rank known to the bootstrap.
  const auto rankCount = bootstrap->getNranks();
  rankToHash_.resize(rankCount);

  const auto localHash = getHostHash();
  INFO(MSCCLPP_INIT, "Host hash: %lx", localHash);

  // Publish the local entry before exchanging the table with the other ranks.
  const auto selfRank = bootstrap->getRank();
  rankToHash_[selfRank] = localHash;
  bootstrap->allGather(rankToHash_.data(), sizeof(uint64_t));
}
Communicator::Impl::~Impl() {
ibContexts.clear();
@@ -67,11 +75,21 @@ RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportF
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connect(int remoteRank, int tag, TransportFlags transport) {
std::shared_ptr<ConnectionBase> conn;
if (transport | TransportCudaIpc) {
// sanity check: make sure the IPC connection is being made within a node
if (pimpl->rankToHash_[remoteRank] != pimpl->rankToHash_[pimpl->bootstrap_->getRank()]) {
std::stringstream ss;
ss << "Cuda IPC connection can only be made within a node: " << remoteRank << " != " << pimpl->bootstrap_->getRank();
throw std::runtime_error(ss.str());
}
auto cudaIpcConn = std::make_shared<CudaIpcConnection>();
conn = cudaIpcConn;
INFO(MSCCLPP_INIT, "Cuda IPC connection between %d(%lx) and %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()],
remoteRank, pimpl->rankToHash_[remoteRank]);
} else if (transport | TransportAllIB) {
auto ibConn = std::make_shared<IBConnection>(remoteRank, tag, transport, *pimpl);
conn = ibConn;
INFO(MSCCLPP_INIT, "IB connection between %d(%lx) via %s and %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()],
getIBDeviceName(transport).c_str(), remoteRank, pimpl->rankToHash_[remoteRank]);
} else {
throw std::runtime_error("Unsupported transport");
}

View File

@@ -18,6 +18,7 @@ struct Communicator::Impl {
std::vector<std::shared_ptr<ConnectionBase>> connections;
std::unordered_map<TransportFlags, std::unique_ptr<IbCtx>> ibContexts;
std::shared_ptr<BaseBootstrap> bootstrap_;
std::vector<uint64_t> rankToHash_;
Impl(std::shared_ptr<BaseBootstrap> bootstrap);

View File

@@ -30,7 +30,8 @@ void test_communicator(int rank, int worldSize, int nranksPerNode){
auto communicator = std::make_shared<mscclpp::Communicator>(bootstrap);
for (int i = 0; i < worldSize; i++){
if (i != rank){
if (i % nranksPerNode == rank % nranksPerNode){
if (i / nranksPerNode == rank / nranksPerNode){
printf("i %d rank %d nranksPerNode %d\n", i, rank, nranksPerNode);
auto connect = communicator->connect(i, 0, mscclpp::TransportCudaIpc);
} else {
auto connect = communicator->connect(i, 0, findIb(rank % nranksPerNode));