mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-24 06:44:40 +00:00
host hashes in communicator
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "mscclpp.hpp"
|
||||
#include "communicator.hpp"
|
||||
#include "host_connection.hpp"
|
||||
@@ -12,7 +14,13 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
Communicator::Impl::Impl(std::shared_ptr<BaseBootstrap> bootstrap) : bootstrap_(bootstrap) {}
|
||||
Communicator::Impl::Impl(std::shared_ptr<BaseBootstrap> bootstrap) : bootstrap_(bootstrap) {
|
||||
rankToHash_.resize(bootstrap->getNranks());
|
||||
auto hostHash = getHostHash();
|
||||
INFO(MSCCLPP_INIT, "Host hash: %lx", hostHash);
|
||||
rankToHash_[bootstrap->getRank()] = hostHash;
|
||||
bootstrap->allGather(rankToHash_.data(), sizeof(uint64_t));
|
||||
}
|
||||
|
||||
Communicator::Impl::~Impl() {
|
||||
ibContexts.clear();
|
||||
@@ -67,11 +75,21 @@ RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportF
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connect(int remoteRank, int tag, TransportFlags transport) {
|
||||
std::shared_ptr<ConnectionBase> conn;
|
||||
if (transport | TransportCudaIpc) {
|
||||
// sanity check: make sure the IPC connection is being made within a node
|
||||
if (pimpl->rankToHash_[remoteRank] != pimpl->rankToHash_[pimpl->bootstrap_->getRank()]) {
|
||||
std::stringstream ss;
|
||||
ss << "Cuda IPC connection can only be made within a node: " << remoteRank << " != " << pimpl->bootstrap_->getRank();
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
auto cudaIpcConn = std::make_shared<CudaIpcConnection>();
|
||||
conn = cudaIpcConn;
|
||||
INFO(MSCCLPP_INIT, "Cuda IPC connection between %d(%lx) and %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()],
|
||||
remoteRank, pimpl->rankToHash_[remoteRank]);
|
||||
} else if (transport | TransportAllIB) {
|
||||
auto ibConn = std::make_shared<IBConnection>(remoteRank, tag, transport, *pimpl);
|
||||
conn = ibConn;
|
||||
INFO(MSCCLPP_INIT, "IB connection between %d(%lx) via %s and %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()],
|
||||
getIBDeviceName(transport).c_str(), remoteRank, pimpl->rankToHash_[remoteRank]);
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported transport");
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ struct Communicator::Impl {
|
||||
std::vector<std::shared_ptr<ConnectionBase>> connections;
|
||||
std::unordered_map<TransportFlags, std::unique_ptr<IbCtx>> ibContexts;
|
||||
std::shared_ptr<BaseBootstrap> bootstrap_;
|
||||
std::vector<uint64_t> rankToHash_;
|
||||
|
||||
Impl(std::shared_ptr<BaseBootstrap> bootstrap);
|
||||
|
||||
|
||||
@@ -30,7 +30,8 @@ void test_communicator(int rank, int worldSize, int nranksPerNode){
|
||||
auto communicator = std::make_shared<mscclpp::Communicator>(bootstrap);
|
||||
for (int i = 0; i < worldSize; i++){
|
||||
if (i != rank){
|
||||
if (i % nranksPerNode == rank % nranksPerNode){
|
||||
if (i / nranksPerNode == rank / nranksPerNode){
|
||||
printf("i %d rank %d nranksPerNode %d\n", i, rank, nranksPerNode);
|
||||
auto connect = communicator->connect(i, 0, mscclpp::TransportCudaIpc);
|
||||
} else {
|
||||
auto connect = communicator->connect(i, 0, findIb(rank % nranksPerNode));
|
||||
|
||||
Reference in New Issue
Block a user