diff --git a/examples/torch-integration/customized_comm_with_tuning.py b/examples/torch-integration/customized_comm_with_tuning.py index 44a5c9c1..1243ca91 100644 --- a/examples/torch-integration/customized_comm_with_tuning.py +++ b/examples/torch-integration/customized_comm_with_tuning.py @@ -8,10 +8,6 @@ import os from mpi4py import MPI -_world_size = MPI.COMM_WORLD.Get_size() -if _world_size > 1 and "MSCCLPP_IPC_DOMAIN_NRANKS" not in os.environ: - os.environ["MSCCLPP_IPC_DOMAIN_NRANKS"] = str(_world_size) - import torch import mscclpp import mscclpp.ext @@ -101,8 +97,10 @@ class CustomizedComm: self.rank = comm.my_rank self.world_size = comm.nranks self.nranks_per_node = comm.nranks_per_node - nvlink_domain_nranks = int(os.environ.get("MSCCLPP_IPC_DOMAIN_NRANKS", "0")) - self.multi_host_mnnvl = nvlink_domain_nranks >= self.world_size and self.world_size > self.nranks_per_node + if comm.communicator.get_ipc_domain_n_ranks() == 0 and self.world_size > 1: + comm.communicator.set_ipc_domain_n_ranks(self.world_size) + self.ipc_domain_n_ranks = comm.communicator.get_ipc_domain_n_ranks() + self.multi_host_mnnvl = self.ipc_domain_n_ranks >= self.world_size and self.world_size > self.nranks_per_node self.symmetric_memory = symmetric_memory self._nvls = mscclpp.is_nvls_supported() diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index 45b56bcc..481f1d3c 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -821,6 +821,18 @@ class Communicator { /// @return The context held by this communicator. std::shared_ptr context(); + /// Set the IPC-domain rank count for collective algorithms using this communicator. + /// + /// The value describes how many ranks are in one GPU-IPC-reachable peer group, such as a Multi-Node NVLink + /// fabric. Set to 0 to use the default `bootstrap()->getNranksPerNode()` value. + /// + /// @param ipcDomainNranks Number of ranks in the communicator's IPC domain, or 0 to use the default. + void setIpcDomainNranks(int ipcDomainNranks); + + /// Get the IPC-domain rank count override for this communicator. + /// @return The configured IPC-domain rank count, or 0 if the communicator uses `bootstrap()->getNranksPerNode()`. + int getIpcDomainNranks() const; + /// Register a region of GPU memory for use in this communicator's context. /// /// @param ptr Base pointer to the memory. diff --git a/include/mscclpp/env.hpp b/include/mscclpp/env.hpp index 0dd63ed7..a6dd306b 100644 --- a/include/mscclpp/env.hpp +++ b/include/mscclpp/env.hpp @@ -119,12 +119,6 @@ class Env { /// Default is 0. Used when `EndpointConfig::Ib::gidIndex` is -1 (unspecified). const int ibGidIndex; - /// Env name: `MSCCLPP_IPC_DOMAIN_NRANKS`. Number of ranks that share a single GPU-IPC-reachable peer - /// group (e.g. a Multi-Node NVLink fabric such as GB200 NVL72, or an AMD XGMI domain). This hint is - /// consumed only by the collective algorithms; it does not affect `Bootstrap::getNranksPerNode()` or - /// any other layer. If unset or non-positive, algorithms fall back to `bootstrap->getNranksPerNode()`. - const int ipcDomainNranks; - private: Env(); diff --git a/python/csrc/core_py.cpp b/python/csrc/core_py.cpp index a94f9863..d748c6a0 100644 --- a/python/csrc/core_py.cpp +++ b/python/csrc/core_py.cpp @@ -282,6 +282,8 @@ void register_core(nb::module_& m) { nb::arg("context") = nullptr) .def("bootstrap", &Communicator::bootstrap) .def("context", &Communicator::context) + .def("set_ipc_domain_n_ranks", &Communicator::setIpcDomainNranks, nb::arg("n_ranks")) + .def("get_ipc_domain_n_ranks", &Communicator::getIpcDomainNranks) .def( "register_memory", [](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) { diff --git a/python/mscclpp/_core/comm.py b/python/mscclpp/_core/comm.py index d42349dd..f1940eae 100644 --- a/python/mscclpp/_core/comm.py +++ b/python/mscclpp/_core/comm.py @@ -35,6 +35,7 @@ class CommGroup: interfaceIpPortTrio: str = "", rank: int = None, size: int = None, + ipc_domain_n_ranks: int = 0, ): if interfaceIpPortTrio == "" and (mpi_comm is not None or torch_group is not None): uniq_id = None @@ -70,9 +71,11 @@ class CommGroup: else: raise RuntimeError("Either the interface or mpi_group need to be specified") self.communicator = CppCommunicator(self.bootstrap) + self.communicator.set_ipc_domain_n_ranks(ipc_domain_n_ranks) self.my_rank = self.bootstrap.get_rank() self.nranks = self.bootstrap.get_n_ranks() self.nranks_per_node = self.bootstrap.get_n_ranks_per_node() + self.ipc_domain_n_ranks = self.communicator.get_ipc_domain_n_ranks() def barrier(self): self.bootstrap.barrier() diff --git a/src/core/communicator.cc b/src/core/communicator.cc index 1ca029d6..2272175e 100644 --- a/src/core/communicator.cc +++ b/src/core/communicator.cc @@ -81,6 +81,15 @@ MSCCLPP_API_CPP std::shared_ptr Communicator::bootstrap() { return pi MSCCLPP_API_CPP std::shared_ptr Communicator::context() { return pimpl_->context_; } +MSCCLPP_API_CPP void Communicator::setIpcDomainNranks(int ipcDomainNranks) { + if (ipcDomainNranks < 0) { + throw Error("ipcDomainNranks must be non-negative", ErrorCode::InvalidUsage); + } + pimpl_->ipcDomainNranks_ = ipcDomainNranks; +} + +MSCCLPP_API_CPP int Communicator::getIpcDomainNranks() const { return pimpl_->ipcDomainNranks_; } + MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports) { return context()->registerMemory(ptr, size, transports); } diff --git a/src/core/env.cpp b/src/core/env.cpp index 18d548b0..7a42471b 100644 --- a/src/core/env.cpp +++ b/src/core/env.cpp @@ -67,8 +67,7 @@ Env::Env() ncclSymmetricMemory(readEnv("MSCCLPP_NCCL_SYMMETRIC_MEMORY", false)), forceDisableNvls(readEnv("MSCCLPP_FORCE_DISABLE_NVLS", false)), forceDisableGdr(readEnv("MSCCLPP_FORCE_DISABLE_GDR", false)), - ibGidIndex(readEnv("MSCCLPP_IB_GID_INDEX", 0)), - ipcDomainNranks(readEnv("MSCCLPP_IPC_DOMAIN_NRANKS", 0)) {} + ibGidIndex(readEnv("MSCCLPP_IB_GID_INDEX", 0)) {} std::shared_ptr env() { static std::shared_ptr globalEnv = std::shared_ptr(new Env()); @@ -98,7 +97,6 @@ std::shared_ptr env() { logEnv("MSCCLPP_FORCE_DISABLE_NVLS", globalEnv->forceDisableNvls); logEnv("MSCCLPP_FORCE_DISABLE_GDR", globalEnv->forceDisableGdr); logEnv("MSCCLPP_IB_GID_INDEX", globalEnv->ibGidIndex); - logEnv("MSCCLPP_IPC_DOMAIN_NRANKS", globalEnv->ipcDomainNranks); } return globalEnv; } diff --git a/src/core/include/communicator.hpp b/src/core/include/communicator.hpp index f15e20f7..b9f519b9 100644 --- a/src/core/include/communicator.hpp +++ b/src/core/include/communicator.hpp @@ -60,6 +60,7 @@ struct Communicator::Impl { std::shared_ptr bootstrap_; std::shared_ptr context_; std::unordered_map connectionInfos_; + int ipcDomainNranks_ = 0; // Temporary storage for the latest RecvItem of each {remoteRank, tag} pair. // The RecvItem is removed when it finishes or when getLastRecvItem observes that it is ready. diff --git a/src/ext/collectives/collective_utils.cc b/src/ext/collectives/collective_utils.cc index 6acfd7ce..192fac8d 100644 --- a/src/ext/collectives/collective_utils.cc +++ b/src/ext/collectives/collective_utils.cc @@ -6,7 +6,6 @@ #include #include #include -#include #include #include #include @@ -73,23 +72,22 @@ std::vector> setupMemoryS return memorySemaphores; } -int getIpcDomainNranks(std::shared_ptr comm) { - const int envValue = mscclpp::env()->ipcDomainNranks; - const int ipcDomainNranks = (envValue > 0) ? envValue : comm->bootstrap()->getNranksPerNode(); +int getIpcDomainNranks(std::shared_ptr comm) { + const int commValue = comm->getIpcDomainNranks(); + const int ipcDomainNranks = (commValue > 0) ? commValue : comm->bootstrap()->getNranksPerNode(); const int worldSize = comm->bootstrap()->getNranks(); const int rank = comm->bootstrap()->getRank(); if (ipcDomainNranks < 2 || ipcDomainNranks > MAX_IPC_DOMAIN_NRANKS) { - THROW(mscclpp::LogSubsys::ALGO, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage, "ipcDomainNranks ", - ipcDomainNranks, " is out of supported range [2, ", MAX_IPC_DOMAIN_NRANKS, "]"); + THROW(LogSubsys::ALGO, Error, ErrorCode::InvalidUsage, "ipcDomainNranks ", ipcDomainNranks, + " is out of supported range [2, ", MAX_IPC_DOMAIN_NRANKS, "]"); } if (worldSize != ipcDomainNranks) { - THROW(mscclpp::LogSubsys::ALGO, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage, + THROW(LogSubsys::ALGO, Error, ErrorCode::InvalidUsage, "requires worldSize == ipcDomainNranks (got worldSize=", worldSize, ", ipcDomainNranks=", ipcDomainNranks, ")"); } if (rank < 0 || rank >= ipcDomainNranks) { - THROW(mscclpp::LogSubsys::ALGO, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage, "rank ", rank, " out of [0, ", - ipcDomainNranks, ")"); + THROW(LogSubsys::ALGO, Error, ErrorCode::InvalidUsage, "rank ", rank, " out of [0, ", ipcDomainNranks, ")"); } return ipcDomainNranks; } diff --git a/src/ext/collectives/include/collective_utils.hpp b/src/ext/collectives/include/collective_utils.hpp index 280a6332..217c7f55 100644 --- a/src/ext/collectives/include/collective_utils.hpp +++ b/src/ext/collectives/include/collective_utils.hpp @@ -52,9 +52,8 @@ std::vector> setupMemorySemaphores std::shared_ptr comm, const std::vector& connections, int nChannelsPerConnection); /// Returns the IPC-reachable peer-group size, validated to span the whole communicator and -/// to be within `[2, MAX_IPC_DOMAIN_NRANKS]`. Reads `MSCCLPP_IPC_DOMAIN_NRANKS` if set to a -/// positive value; otherwise falls back to `bootstrap->getNranksPerNode()`. Throws -/// `Error(InvalidUsage)` on violation. +/// to be within `[2, MAX_IPC_DOMAIN_NRANKS]`. Reads the communicator's IPC-domain override +/// if set; otherwise falls back to `bootstrap->getNranksPerNode()`. Throws `Error(InvalidUsage)` on violation. int getIpcDomainNranks(std::shared_ptr comm); std::shared_ptr> setupMemoryChannelDeviceHandles(