diff --git a/CMakeLists.txt b/CMakeLists.txt index 49154e0b..3f9bf8e0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -206,6 +206,7 @@ if(MSCCLPP_USE_CUDA) else() set(GPU_LIBRARIES CUDA::cudart CUDA::cuda_driver) endif() + list(APPEND GPU_LIBRARIES CUDA::nvml) else() set(CMAKE_HIP_STANDARD 17) set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Wall -Wextra") diff --git a/examples/torch-integration/customized_comm_with_tuning.py b/examples/torch-integration/customized_comm_with_tuning.py index d0da8c68..6cef88fe 100644 --- a/examples/torch-integration/customized_comm_with_tuning.py +++ b/examples/torch-integration/customized_comm_with_tuning.py @@ -97,7 +97,7 @@ class CustomizedComm: self.rank = comm.my_rank self.world_size = comm.nranks self.nranks_per_node = comm.nranks_per_node - self.ipc_domain_n_ranks = comm.communicator.get_ipc_domain_n_ranks() + self.ipc_domain_n_ranks = comm.ipc_domain_n_ranks self.multi_host_mnnvl = self.ipc_domain_n_ranks >= self.world_size and self.world_size > self.nranks_per_node self.symmetric_memory = symmetric_memory self._nvls = mscclpp.is_nvls_supported() @@ -431,8 +431,8 @@ def benchmark_allgather(comm: CustomizedComm, dtype=torch.float16, n_warmup=10, # -- Bootstrap & main --------------------------------------------------------- -def init_dist(ipc_domain_n_ranks: int = 0) -> mscclpp.CommGroup: - return mscclpp.CommGroup(mpi_comm=MPI.COMM_WORLD, ipc_domain_n_ranks=ipc_domain_n_ranks) +def init_dist() -> mscclpp.CommGroup: + return mscclpp.CommGroup(mpi_comm=MPI.COMM_WORLD) def main(): @@ -445,9 +445,8 @@ def main(): accum_str = os.environ.get("ACCUM_DTYPE") accum_dtype = accum_map.get(accum_str) if accum_str else None symmetric_memory = os.environ.get("SYMMETRIC_MEMORY", "1") == "1" - ipc_domain_n_ranks = int(os.environ.get("IPC_DOMAIN_NRANKS", "0")) - comm_group = init_dist(ipc_domain_n_ranks=ipc_domain_n_ranks) + comm_group = init_dist() cc = CustomizedComm(comm_group, symmetric_memory=symmetric_memory) print( diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index 832323ad..4c14f1ee 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -46,6 +46,10 @@ class Bootstrap { /// @return The total number of ranks per node. virtual int getNranksPerNode() const = 0; + /// Return the number of ranks in this rank's GPU IPC domain. + /// @return The number of ranks in the GPU IPC domain. + virtual int getNranksPerIpcDomain() const; + /// Send arbitrary data to another process. /// /// Data sent via `send(senderBuff, size, receiverRank, tag)` can be received via `recv(receiverBuff, size, @@ -144,6 +148,9 @@ class TcpBootstrap : public Bootstrap { /// Return the total number of ranks per node. int getNranksPerNode() const override; + /// Return the number of ranks in this rank's GPU IPC domain. + int getNranksPerIpcDomain() const override; + /// Send arbitrary data to another process. /// /// Data sent via `send(senderBuff, size, receiverRank, tag)` can be received via `recv(receiverBuff, size, @@ -821,18 +828,6 @@ class Communicator { /// @return The context held by this communicator. std::shared_ptr context(); - /// Set the IPC-domain rank count for collective algorithms using this communicator. - /// - /// The value describes how many ranks are in one GPU-IPC-reachable peer group, such as a Multi-Node NVLink - /// fabric. Set to 0 to use the default `bootstrap()->getNranksPerNode()` value. - /// - /// @param ipcDomainNranks Number of ranks in the communicator's IPC domain, or 0 to use the default. - void setIpcDomainNranks(int ipcDomainNranks); - - /// Get the effective IPC-domain rank count for this communicator. - /// @return The configured IPC-domain rank count, or `bootstrap()->getNranksPerNode()` if no override is set. - int getIpcDomainNranks() const; - /// Register a region of GPU memory for use in this communicator's context. /// /// @param ptr Base pointer to the memory. diff --git a/python/csrc/core_py.cpp b/python/csrc/core_py.cpp index d748c6a0..7e9af6c1 100644 --- a/python/csrc/core_py.cpp +++ b/python/csrc/core_py.cpp @@ -56,6 +56,7 @@ void register_core(nb::module_& m) { .def("get_rank", &Bootstrap::getRank) .def("get_n_ranks", &Bootstrap::getNranks) .def("get_n_ranks_per_node", &Bootstrap::getNranksPerNode) + .def("get_n_ranks_per_ipc_domain", &Bootstrap::getNranksPerIpcDomain) .def( "send", [](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) { @@ -282,8 +283,6 @@ void register_core(nb::module_& m) { nb::arg("context") = nullptr) .def("bootstrap", &Communicator::bootstrap) .def("context", &Communicator::context) - .def("set_ipc_domain_n_ranks", &Communicator::setIpcDomainNranks, nb::arg("n_ranks")) - .def("get_ipc_domain_n_ranks", &Communicator::getIpcDomainNranks) .def( "register_memory", [](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) { diff --git a/python/mscclpp/_core/comm.py b/python/mscclpp/_core/comm.py index f1940eae..875e07f1 100644 --- a/python/mscclpp/_core/comm.py +++ b/python/mscclpp/_core/comm.py @@ -35,7 +35,6 @@ class CommGroup: interfaceIpPortTrio: str = "", rank: int = None, size: int = None, - ipc_domain_n_ranks: int = 0, ): if interfaceIpPortTrio == "" and (mpi_comm is not None or torch_group is not None): uniq_id = None @@ -71,11 +70,10 @@ class CommGroup: else: raise RuntimeError("Either the interface or mpi_group need to be specified") self.communicator = CppCommunicator(self.bootstrap) - self.communicator.set_ipc_domain_n_ranks(ipc_domain_n_ranks) self.my_rank = self.bootstrap.get_rank() self.nranks = self.bootstrap.get_n_ranks() self.nranks_per_node = self.bootstrap.get_n_ranks_per_node() - self.ipc_domain_n_ranks = self.communicator.get_ipc_domain_n_ranks() + self.ipc_domain_n_ranks = self.bootstrap.get_n_ranks_per_ipc_domain() def barrier(self): self.bootstrap.barrier() diff --git a/src/core/bootstrap/bootstrap.cc b/src/core/bootstrap/bootstrap.cc index b3032e50..a5835751 100644 --- a/src/core/bootstrap/bootstrap.cc +++ b/src/core/bootstrap/bootstrap.cc @@ -50,6 +50,8 @@ MSCCLPP_API_CPP void Bootstrap::groupBarrier(const std::vector& ranks) { } } +MSCCLPP_API_CPP int Bootstrap::getNranksPerIpcDomain() const { return getNranksPerNode(); } + MSCCLPP_API_CPP void Bootstrap::send(const std::vector& data, int peer, int tag) { size_t size = data.size(); send((void*)&size, sizeof(size_t), peer, tag); @@ -83,6 +85,7 @@ class TcpBootstrap::Impl { int getRank(); int getNranks(); int getNranksPerNode(); + int getNranksPerIpcDomain(); void allGather(void* allData, int size); void broadcast(void* data, int size, int root); void send(void* data, int size, int peer, int tag); @@ -95,6 +98,7 @@ class TcpBootstrap::Impl { int rank_; int nRanks_; int nRanksPerNode_; + int nRanksPerIpcDomain_; bool netInitialized; std::unique_ptr listenSockRoot_; std::unique_ptr listenSock_; @@ -148,6 +152,7 @@ TcpBootstrap::Impl::Impl(int rank, int nRanks) : rank_(rank), nRanks_(nRanks), nRanksPerNode_(0), + nRanksPerIpcDomain_(0), netInitialized(false), peerCommAddresses_(nRanks, SocketAddress()), barrierArr_(nRanks, 0), @@ -451,6 +456,22 @@ int TcpBootstrap::Impl::getNranksPerNode() { return nRanksPerNode_; } +int TcpBootstrap::Impl::getNranksPerIpcDomain() { + if (nRanksPerIpcDomain_ > 0) return nRanksPerIpcDomain_; + std::vector ipcDomainHashes(nRanks_); + ipcDomainHashes[rank_] = getIpcDomainHash(); + allGather(ipcDomainHashes.data(), sizeof(uint64_t)); + + int nRanksPerIpcDomain = 0; + for (int i = 0; i < nRanks_; ++i) { + if (ipcDomainHashes[i] == ipcDomainHashes[rank_]) { + ++nRanksPerIpcDomain; + } + } + nRanksPerIpcDomain_ = nRanksPerIpcDomain; + return nRanksPerIpcDomain_; +} + void TcpBootstrap::Impl::allGather(void* allData, int size) { char* data = static_cast(allData); int rank = rank_; @@ -592,6 +613,8 @@ MSCCLPP_API_CPP int TcpBootstrap::getNranks() const { return pimpl_->getNranks() MSCCLPP_API_CPP int TcpBootstrap::getNranksPerNode() const { return pimpl_->getNranksPerNode(); } +MSCCLPP_API_CPP int TcpBootstrap::getNranksPerIpcDomain() const { return pimpl_->getNranksPerIpcDomain(); } + MSCCLPP_API_CPP void TcpBootstrap::send(void* data, int size, int peer, int tag) { pimpl_->send(data, size, peer, tag); } diff --git a/src/core/communicator.cc b/src/core/communicator.cc index 9bbbff3b..41e46bc5 100644 --- a/src/core/communicator.cc +++ b/src/core/communicator.cc @@ -79,17 +79,6 @@ MSCCLPP_API_CPP std::shared_ptr Communicator::bootstrap() { return pi MSCCLPP_API_CPP std::shared_ptr Communicator::context() { return pimpl_->context_; } -MSCCLPP_API_CPP void Communicator::setIpcDomainNranks(int ipcDomainNranks) { - if (ipcDomainNranks < 0) { - throw Error("ipcDomainNranks must be non-negative", ErrorCode::InvalidUsage); - } - pimpl_->ipcDomainNranks_ = ipcDomainNranks; -} - -MSCCLPP_API_CPP int Communicator::getIpcDomainNranks() const { - return (pimpl_->ipcDomainNranks_ > 0) ? pimpl_->ipcDomainNranks_ : pimpl_->bootstrap_->getNranksPerNode(); -} - MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports) { return context()->registerMemory(ptr, size, transports); } diff --git a/src/core/include/communicator.hpp b/src/core/include/communicator.hpp index b9f519b9..333cc982 100644 --- a/src/core/include/communicator.hpp +++ b/src/core/include/communicator.hpp @@ -60,8 +60,6 @@ struct Communicator::Impl { std::shared_ptr bootstrap_; std::shared_ptr context_; std::unordered_map connectionInfos_; - int ipcDomainNranks_ = 0; - // Temporary storage for the latest RecvItem of each {remoteRank, tag} pair. // The RecvItem is removed when it finishes or when getLastRecvItem observes that it is ready. std::unordered_map, std::shared_ptr, PairHash> lastRecvItems_; diff --git a/src/core/include/utils_internal.hpp b/src/core/include/utils_internal.hpp index c5c67e26..c6934194 100644 --- a/src/core/include/utils_internal.hpp +++ b/src/core/include/utils_internal.hpp @@ -37,6 +37,7 @@ int64_t busIdToInt64(const std::string busId); uint64_t getHash(const char* string, int n); uint64_t getHostHash(); uint64_t getPidHash(); +uint64_t getIpcDomainHash(); void getRandomData(void* buffer, size_t bytes); struct netIf { diff --git a/src/core/utils_internal.cc b/src/core/utils_internal.cc index 8cc55430..2e620b66 100644 --- a/src/core/utils_internal.cc +++ b/src/core/utils_internal.cc @@ -6,6 +6,10 @@ #include #include +#if defined(MSCCLPP_USE_CUDA) +#include +#endif + #include #include #include @@ -175,6 +179,79 @@ uint64_t getPidHash(void) { return *pidHash; } +#if defined(MSCCLPP_USE_CUDA) && defined(NVML_GPU_FABRIC_UUID_LEN) +namespace { + +class NvmlState { + public: + NvmlState() : initialized_(nvmlInit_v2() == NVML_SUCCESS) {} + + ~NvmlState() { + if (initialized_) { + (void)nvmlShutdown(); + } + } + + bool isInitialized() const { return initialized_; } + + private: + bool initialized_ = false; +}; + +uint64_t getFabricHash(const nvmlGpuFabricInfo_t& fabricInfo) { + char hashData[NVML_GPU_FABRIC_UUID_LEN + sizeof(fabricInfo.cliqueId)]; + std::memcpy(hashData, fabricInfo.clusterUuid, NVML_GPU_FABRIC_UUID_LEN); + std::memcpy(hashData + NVML_GPU_FABRIC_UUID_LEN, &fabricInfo.cliqueId, sizeof(fabricInfo.cliqueId)); + return getHash(hashData, sizeof(hashData)); +} + +bool tryGetNvmlIpcDomainHash(uint64_t& ipcDomainHash) { + // Use the current CUDA device; callers must set the rank's device before querying. + int deviceId; + if (cudaGetDevice(&deviceId) != cudaSuccess) { + return false; + } + + char pciBusId[] = "00000000:00:00.0"; + if (cudaDeviceGetPCIBusId(pciBusId, sizeof(pciBusId), deviceId) != cudaSuccess) { + return false; + } + + static NvmlState nvml; + if (!nvml.isInitialized()) { + return false; + } + + nvmlDevice_t nvmlDevice; + if (nvmlDeviceGetHandleByPciBusId_v2(pciBusId, &nvmlDevice) != NVML_SUCCESS) { + return false; + } + + nvmlGpuFabricInfo_t fabricInfo = {}; + if (nvmlDeviceGetGpuFabricInfo(nvmlDevice, &fabricInfo) != NVML_SUCCESS) { + return false; + } + if (fabricInfo.state != NVML_GPU_FABRIC_STATE_COMPLETED || fabricInfo.status != NVML_SUCCESS) { + return false; + } + + ipcDomainHash = getFabricHash(fabricInfo); + return true; +} + +} // namespace +#endif + +uint64_t getIpcDomainHash(void) { +#if defined(MSCCLPP_USE_CUDA) && defined(NVML_GPU_FABRIC_UUID_LEN) + uint64_t ipcDomainHash; + if (tryGetNvmlIpcDomainHash(ipcDomainHash)) { + return ipcDomainHash; + } +#endif + return getHostHash(); +} + int parseStringList(const char* string, netIf* ifList, int maxList) { if (!string) return 0; diff --git a/src/ext/collectives/allgather/allgather_fullmesh.cu b/src/ext/collectives/allgather/allgather_fullmesh.cu index 8b5cf3b7..84dd4d47 100644 --- a/src/ext/collectives/allgather/allgather_fullmesh.cu +++ b/src/ext/collectives/allgather/allgather_fullmesh.cu @@ -148,7 +148,7 @@ std::shared_ptr AllgatherFullmesh::initAllgatherContext(std::shared_ptr(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->ipcDomainNranks = comm->getIpcDomainNranks(); + ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerIpcDomain(); // setup semaphores ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection); diff --git a/src/ext/collectives/allgather/allgather_fullmesh_2.cu b/src/ext/collectives/allgather/allgather_fullmesh_2.cu index de9d9384..5a353922 100644 --- a/src/ext/collectives/allgather/allgather_fullmesh_2.cu +++ b/src/ext/collectives/allgather/allgather_fullmesh_2.cu @@ -159,7 +159,7 @@ std::shared_ptr AllgatherFullmesh2::initAllgatherContext(std::shared_ptr(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->ipcDomainNranks = comm->getIpcDomainNranks(); + ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerIpcDomain(); // setup semaphores ctx->memorySemaphores = this->memorySemaphores_; diff --git a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu index 6c4f972f..29ef2055 100644 --- a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu @@ -140,7 +140,7 @@ std::shared_ptr AllreduceAllpairPacket::initAllreduceContext(std::shared_p const int nChannelsPerConnection = maxBlockNum_; ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->ipcDomainNranks = comm->getIpcDomainNranks(); + ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerIpcDomain(); ctx->memorySemaphores = this->memorySemaphores_; ctx->registeredMemories = this->registeredMemories_; ctx->registeredMemories.pop_back(); // remove the local memory from previous context diff --git a/src/ext/collectives/allreduce/allreduce_fullmesh.cu b/src/ext/collectives/allreduce/allreduce_fullmesh.cu index a5427070..b158f817 100644 --- a/src/ext/collectives/allreduce/allreduce_fullmesh.cu +++ b/src/ext/collectives/allreduce/allreduce_fullmesh.cu @@ -250,7 +250,7 @@ std::shared_ptr AllreduceFullmesh::initAllreduceContext(std::shared_ptr(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->ipcDomainNranks = comm->getIpcDomainNranks(); + ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerIpcDomain(); // setup semaphores ctx->memorySemaphores = this->outputSemaphores_; diff --git a/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu b/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu index 07418f74..890e50f5 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu @@ -177,7 +177,7 @@ struct NvlsBlockPipelineAdapter { void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr comm) { nSwitchChannels_ = 8; - ipcDomainNranks_ = comm->getIpcDomainNranks(); + ipcDomainNranks_ = comm->bootstrap()->getNranksPerIpcDomain(); // Per-peer channel allocation must hold up to 4 * ipcDomainNranks entries (see kernel). nBaseChannels_ = std::max(64, 4 * ipcDomainNranks_); this->conns_ = setupConnections(comm); @@ -224,7 +224,7 @@ std::shared_ptr AllreduceNvlsBlockPipeline::initAllreduceContext(std::shar auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->ipcDomainNranks = comm->getIpcDomainNranks(); + ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerIpcDomain(); // setup channels ctx->switchChannels = diff --git a/src/ext/collectives/allreduce/allreduce_nvls_packet.cu b/src/ext/collectives/allreduce/allreduce_nvls_packet.cu index cb9ad17e..e8ecfb73 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_packet.cu @@ -95,7 +95,7 @@ std::shared_ptr AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr< auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->ipcDomainNranks = comm->getIpcDomainNranks(); + ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerIpcDomain(); // setup channels ctx->switchChannels = this->switchChannels_; diff --git a/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu b/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu index a0669294..68efc2ab 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu @@ -141,7 +141,7 @@ struct NvlsWarpPipelineAdapter { void AllreduceNvlsWarpPipeline::initialize(std::shared_ptr comm) { nSwitchChannels_ = NUM_NVLS_CONNECTION; - ipcDomainNranks_ = comm->getIpcDomainNranks(); + ipcDomainNranks_ = comm->bootstrap()->getNranksPerIpcDomain(); // Per-peer channel allocation must hold 2 * nBlocks entries; default nBlocks = 4 * ipcDomainNranks. nBaseChannels_ = std::max(64, 8 * ipcDomainNranks_); this->conns_ = setupConnections(comm); @@ -188,7 +188,7 @@ std::shared_ptr AllreduceNvlsWarpPipeline::initAllreduceContext(std::share auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->ipcDomainNranks = comm->getIpcDomainNranks(); + ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerIpcDomain(); // setup channels ctx->switchChannels = diff --git a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu index 36095e73..a6f699b2 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu @@ -176,7 +176,7 @@ std::shared_ptr AllreduceNvls::initAllreduceContext(std::shared_ptr(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->ipcDomainNranks = comm->getIpcDomainNranks(); + ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerIpcDomain(); size_t sendBytes, recvBytes; CUdeviceptr sendBasePtr, recvBasePtr; diff --git a/src/ext/collectives/allreduce/allreduce_packet.cu b/src/ext/collectives/allreduce/allreduce_packet.cu index f88389dc..a0bc0e26 100644 --- a/src/ext/collectives/allreduce/allreduce_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_packet.cu @@ -263,7 +263,7 @@ std::shared_ptr AllreducePacket::initAllreduceContext(std::shared_ptrrank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->ipcDomainNranks = comm->getIpcDomainNranks(); + ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerIpcDomain(); ctx->memorySemaphores = this->memorySemaphores_; ctx->registeredMemories = this->registeredMemories_; ctx->registeredMemories.pop_back(); // remove the local memory from previous context diff --git a/src/ext/collectives/allreduce/allreduce_rsag.cu b/src/ext/collectives/allreduce/allreduce_rsag.cu index 43ff5610..22e3a4ee 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag.cu @@ -203,7 +203,7 @@ std::shared_ptr AllreduceRsAg::initAllreduceContext(std::shared_ptr(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->ipcDomainNranks = comm->getIpcDomainNranks(); + ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerIpcDomain(); ctx->memorySemaphores = this->scratchSemaphores_; ctx->registeredMemories = this->remoteScratchMemories_; diff --git a/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu b/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu index 1e59c7e4..bedf15c5 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu @@ -306,7 +306,7 @@ std::shared_ptr AllreduceRsAgPipeline::initAllreduceContext(std::shared_pt auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->ipcDomainNranks = comm->getIpcDomainNranks(); + ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerIpcDomain(); ctx->memorySemaphores = this->scratchSemaphores_; ctx->registeredMemories = this->remoteScratchMemories_; diff --git a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu index f8d61279..10d3a35c 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu @@ -200,7 +200,7 @@ std::shared_ptr AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_pt auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->ipcDomainNranks = comm->getIpcDomainNranks(); + ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerIpcDomain(); ctx->memorySemaphores = this->semaphores_; diff --git a/test/mp_unit/bootstrap_tests.cc b/test/mp_unit/bootstrap_tests.cc index c28087a4..eb6985a8 100644 --- a/test/mp_unit/bootstrap_tests.cc +++ b/test/mp_unit/bootstrap_tests.cc @@ -127,6 +127,7 @@ class MPIBootstrap : public mscclpp::Bootstrap { MPI_Comm_size(shmcomm, &shmrank); return shmrank; } + int getNranksPerIpcDomain() const override { return getNranksPerNode(); } void allGather(void* sendbuf, int size) override { MPI_Allgather(MPI_IN_PLACE, 0, MPI_BYTE, sendbuf, size, MPI_BYTE, MPI_COMM_WORLD); }