diff --git a/examples/torch-integration/customized_comm_with_tuning.py b/examples/torch-integration/customized_comm_with_tuning.py index 6f8f097d..1d54cfa7 100644 --- a/examples/torch-integration/customized_comm_with_tuning.py +++ b/examples/torch-integration/customized_comm_with_tuning.py @@ -13,6 +13,7 @@ import struct import sys import traceback + def _get_bootstrap_world_size(): for name in ("WORLD_SIZE", "OMPI_COMM_WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS"): value = os.environ.get(name) @@ -22,13 +23,8 @@ def _get_bootstrap_world_size(): _bootstrap_world_size = _get_bootstrap_world_size() -if ( - _bootstrap_world_size - and _bootstrap_world_size > 1 - and "MSCCLPP_MNNVL_NRANKS_PER_NODE" not in os.environ - and os.environ.get("MSCCLPP_ENABLE_MNNVL", "1") != "0" -): - os.environ["MSCCLPP_MNNVL_NRANKS_PER_NODE"] = str(_bootstrap_world_size) +if _bootstrap_world_size and _bootstrap_world_size > 1 and "MSCCLPP_IPC_DOMAIN_NRANKS" not in os.environ: + os.environ["MSCCLPP_IPC_DOMAIN_NRANKS"] = str(_bootstrap_world_size) import torch import mscclpp @@ -140,11 +136,10 @@ class CustomizedComm: self.rank = comm.my_rank self.world_size = comm.nranks self.nranks_per_node = comm.nranks_per_node - self.mnnvl_domain = self.world_size > 1 and os.environ.get("MSCCLPP_MNNVL_NRANKS_PER_NODE") == str( - self.world_size - ) + nvlink_domain_nranks = int(os.environ.get("MSCCLPP_IPC_DOMAIN_NRANKS", "0")) + self.mnnvl_domain = self.world_size > 1 and nvlink_domain_nranks >= self.world_size self.multi_node = self.world_size > self.nranks_per_node and not self.mnnvl_domain - self.multi_host_mnnvl = self.mnnvl_domain and self.world_size > 1 + self.multi_host_mnnvl = self.mnnvl_domain and self.world_size > self.nranks_per_node self.symmetric_memory = symmetric_memory self._nvls = mscclpp.is_nvls_supported() diff --git a/include/mscclpp/env.hpp b/include/mscclpp/env.hpp index 09d364c3..0dd63ed7 100644 --- a/include/mscclpp/env.hpp +++ b/include/mscclpp/env.hpp @@ -119,11 +119,11 @@ class Env { /// Default is 0. Used when `EndpointConfig::Ib::gidIndex` is -1 (unspecified). const int ibGidIndex; - /// Env name: `MSCCLPP_MNNVL_NRANKS_PER_NODE`. Overrides the NVLink-domain size reported by the bootstrap. - /// This is intended for Multi-Node NVLink (MNNVL) deployments where a single CUDA IPC / NVLS domain spans - /// multiple hosts and should be treated as one collective peer group. - /// If unset or non-positive, the bootstrap falls back to physical-host-based detection. - const int mnnvlNranksPerNode; + /// Env name: `MSCCLPP_IPC_DOMAIN_NRANKS`. Number of ranks that share a single GPU-IPC-reachable peer + /// group (e.g. a Multi-Node NVLink fabric such as GB200 NVL72, or an AMD XGMI domain). This hint is + /// consumed only by the collective algorithms; it does not affect `Bootstrap::getNranksPerNode()` or + /// any other layer. If unset or non-positive, algorithms fall back to `bootstrap->getNranksPerNode()`. + const int ipcDomainNranks; private: Env(); diff --git a/src/core/bootstrap/bootstrap.cc b/src/core/bootstrap/bootstrap.cc index c84ef4c0..b3032e50 100644 --- a/src/core/bootstrap/bootstrap.cc +++ b/src/core/bootstrap/bootstrap.cc @@ -5,7 +5,6 @@ #include #include -#include #include #include #include @@ -434,10 +433,6 @@ void TcpBootstrap::Impl::establishConnections(int64_t timeoutSec) { int TcpBootstrap::Impl::getNranksPerNode() { if (nRanksPerNode_ > 0) return nRanksPerNode_; - if (env()->mnnvlNranksPerNode > 0) { - nRanksPerNode_ = env()->mnnvlNranksPerNode; - return nRanksPerNode_; - } int nRanksPerNode = 0; bool useIpv4 = peerCommAddresses_[rank_].sa.sa_family == AF_INET; for (int i = 0; i < nRanks_; i++) { diff --git a/src/core/env.cpp b/src/core/env.cpp index b46670d7..18d548b0 100644 --- a/src/core/env.cpp +++ b/src/core/env.cpp @@ -68,7 +68,7 @@ Env::Env() forceDisableNvls(readEnv("MSCCLPP_FORCE_DISABLE_NVLS", false)), forceDisableGdr(readEnv("MSCCLPP_FORCE_DISABLE_GDR", false)), ibGidIndex(readEnv("MSCCLPP_IB_GID_INDEX", 0)), - mnnvlNranksPerNode(readEnv("MSCCLPP_MNNVL_NRANKS_PER_NODE", 0)) {} + ipcDomainNranks(readEnv("MSCCLPP_IPC_DOMAIN_NRANKS", 0)) {} std::shared_ptr env() { static std::shared_ptr globalEnv = std::shared_ptr(new Env()); @@ -98,7 +98,7 @@ std::shared_ptr env() { logEnv("MSCCLPP_FORCE_DISABLE_NVLS", globalEnv->forceDisableNvls); logEnv("MSCCLPP_FORCE_DISABLE_GDR", globalEnv->forceDisableGdr); logEnv("MSCCLPP_IB_GID_INDEX", globalEnv->ibGidIndex); - logEnv("MSCCLPP_MNNVL_NRANKS_PER_NODE", globalEnv->mnnvlNranksPerNode); + logEnv("MSCCLPP_IPC_DOMAIN_NRANKS", globalEnv->ipcDomainNranks); } return globalEnv; } diff --git a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu index 9516ad78..690d0eb4 100644 --- a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu @@ -140,7 +140,7 @@ std::shared_ptr AllreduceAllpairPacket::initAllreduceContext(std::shared_p const int nChannelsPerConnection = maxBlockNum_; ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm, this->conns_); + ctx->nRanksPerNode = getIpcDomainNranks(comm); ctx->memorySemaphores = this->memorySemaphores_; ctx->registeredMemories = this->registeredMemories_; ctx->registeredMemories.pop_back(); // remove the local memory from previous context diff --git a/src/ext/collectives/allreduce/allreduce_nvls_packet.cu b/src/ext/collectives/allreduce/allreduce_nvls_packet.cu index 21f71028..d331cc67 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_packet.cu @@ -94,7 +94,7 @@ std::shared_ptr AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr< auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm); + ctx->nRanksPerNode = getIpcDomainNranks(comm); // setup channels ctx->switchChannels = this->switchChannels_; diff --git a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu index 25077004..36fcf860 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu @@ -183,7 +183,7 @@ std::shared_ptr AllreduceNvls::initAllreduceContext(std::shared_ptr(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm); + ctx->nRanksPerNode = getIpcDomainNranks(comm); size_t sendBytes, recvBytes; CUdeviceptr sendBasePtr, recvBasePtr; diff --git a/src/ext/collectives/allreduce/allreduce_packet.cu b/src/ext/collectives/allreduce/allreduce_packet.cu index c195aefa..d631c35a 100644 --- a/src/ext/collectives/allreduce/allreduce_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_packet.cu @@ -264,7 +264,7 @@ std::shared_ptr AllreducePacket::initAllreduceContext(std::shared_ptrrank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm, this->conns_); + ctx->nRanksPerNode = getIpcDomainNranks(comm); ctx->memorySemaphores = this->memorySemaphores_; ctx->registeredMemories = this->registeredMemories_; ctx->registeredMemories.pop_back(); // remove the local memory from previous context diff --git a/src/ext/collectives/allreduce/allreduce_rsag.cu b/src/ext/collectives/allreduce/allreduce_rsag.cu index 7f9e6bfd..4c46bf9b 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag.cu @@ -203,7 +203,7 @@ std::shared_ptr AllreduceRsAg::initAllreduceContext(std::shared_ptr(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm, this->conns_); + ctx->nRanksPerNode = getIpcDomainNranks(comm); ctx->memorySemaphores = this->scratchSemaphores_; ctx->registeredMemories = this->remoteScratchMemories_; diff --git a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu index a11da0f8..67eed6d3 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu @@ -169,7 +169,7 @@ CommResult AllreduceRsAgZeroCopy::allreduceKernelFunc(const std::shared_ptr AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_pt auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm, this->conns_); + ctx->nRanksPerNode = getIpcDomainNranks(comm); ctx->memorySemaphores = this->semaphores_; diff --git a/src/ext/collectives/collective_utils.cc b/src/ext/collectives/collective_utils.cc index 4d46c53b..de33009c 100644 --- a/src/ext/collectives/collective_utils.cc +++ b/src/ext/collectives/collective_utils.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -69,23 +70,12 @@ std::vector> setupMemoryS return memorySemaphores; } -int getCollectiveDomainNranksPerNode(std::shared_ptr comm, - const std::vector& connections) { - const int worldSize = comm->bootstrap()->getNranks(); - const int nRanksPerNode = comm->bootstrap()->getNranksPerNode(); - if (worldSize <= nRanksPerNode) { - return nRanksPerNode; +int getIpcDomainNranks(std::shared_ptr comm) { + const int envValue = mscclpp::env()->ipcDomainNranks; + if (envValue > 0) { + return envValue; } - const bool allPeersUseCudaIpc = - std::all_of(connections.begin(), connections.end(), - [](const auto& connection) { return connection.transport() == mscclpp::Transport::CudaIpc; }); - return allPeersUseCudaIpc ? worldSize : nRanksPerNode; -} - -int getCollectiveDomainNranksPerNode(std::shared_ptr comm) { - const int worldSize = comm->bootstrap()->getNranks(); - const int nRanksPerNode = comm->bootstrap()->getNranksPerNode(); - return worldSize > nRanksPerNode ? worldSize : nRanksPerNode; + return comm->bootstrap()->getNranksPerNode(); } std::shared_ptr> setupMemoryChannelDeviceHandles( diff --git a/src/ext/collectives/include/collective_utils.hpp b/src/ext/collectives/include/collective_utils.hpp index 38362a65..44a21402 100644 --- a/src/ext/collectives/include/collective_utils.hpp +++ b/src/ext/collectives/include/collective_utils.hpp @@ -50,8 +50,13 @@ std::vector setupMemoryChannels( std::vector setupConnections(std::shared_ptr comm); std::vector> setupMemorySemaphores( std::shared_ptr comm, const std::vector& connections, int nChannelsPerConnection); -int getCollectiveDomainNranksPerNode(std::shared_ptr comm, const std::vector& connections); -int getCollectiveDomainNranksPerNode(std::shared_ptr comm); + +/// Number of ranks that participate in the same GPU-IPC-reachable peer group (e.g. a single host or +/// a Multi-Node NVLink fabric, or an AMD XGMI domain). Returns the value of `MSCCLPP_IPC_DOMAIN_NRANKS` +/// if set to a positive value; otherwise falls back to `bootstrap->getNranksPerNode()`. This is +/// intentionally independent of `nRanksPerNode` so that algorithms can opt in to MNNVL-like behavior +/// without changing the meaning of bootstrap-level APIs. +int getIpcDomainNranks(std::shared_ptr comm); std::shared_ptr> setupMemoryChannelDeviceHandles( const std::vector& memoryChannels);