diff --git a/examples/torch-integration/customized_comm_with_tuning.py b/examples/torch-integration/customized_comm_with_tuning.py index 060a0097..035c1dbb 100644 --- a/examples/torch-integration/customized_comm_with_tuning.py +++ b/examples/torch-integration/customized_comm_with_tuning.py @@ -2,11 +2,34 @@ # Licensed under the MIT License. # torchrun --nnodes=1 --nproc_per_node=8 examples/torch-integration/customized_comm_with_tuning.py +# mpirun -np 2 --hostfile python3 examples/torch-integration/customized_comm_with_tuning.py -import os +import gc +import fcntl import ipaddress +import os +import socket +import struct +import sys +import traceback + +def _get_bootstrap_world_size(): + for name in ("WORLD_SIZE", "OMPI_COMM_WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS"): + value = os.environ.get(name) + if value is not None: + return int(value) + return None + + +_bootstrap_world_size = _get_bootstrap_world_size() +if ( + _bootstrap_world_size + and _bootstrap_world_size > 1 + and "MSCCLPP_MNNVL_NRANKS_PER_NODE" not in os.environ + and os.environ.get("MSCCLPP_ENABLE_MNNVL", "1") != "0" +): + os.environ["MSCCLPP_MNNVL_NRANKS_PER_NODE"] = str(_bootstrap_world_size) -import netifaces as ni import torch import mscclpp import mscclpp.ext @@ -37,15 +60,44 @@ def _load_algorithms(scratch: torch.Tensor, rank: int): def _interfaces_for_ip(ip: str): target = ipaddress.ip_address(ip) - for iface in ni.interfaces(): - addrs = ni.ifaddresses(iface) - if ni.AF_INET in addrs: - for link in addrs[ni.AF_INET]: - if "addr" in link and ipaddress.ip_address(link["addr"]) == target: - return iface + for iface in os.listdir("/sys/class/net"): + try: + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: + req = struct.pack("256s", iface.encode()[:15]) + addr = socket.inet_ntoa(fcntl.ioctl(sock.fileno(), 0x8915, req)[20:24]) + except OSError: + continue + if ipaddress.ip_address(addr) == target: + return iface return None +def _resolve_interface(master_addr: str): + for env_name in ("MSCCLPP_INTERFACE", "MSCCLPP_SOCKET_IFNAME", "NCCL_SOCKET_IFNAME"): + value = os.environ.get(env_name) + if value: + iface = value.split(",")[0].strip() + if iface in os.listdir("/sys/class/net"): + return iface + raise ValueError(f"Interface {iface} from {env_name} does not exist") + return _interfaces_for_ip(master_addr) + + +def _get_env_int(*names: str, default=None): + for name in names: + value = os.environ.get(name) + if value is not None: + return int(value) + return default + + +def _running_under_mpi() -> bool: + return any( + name in os.environ + for name in ("OMPI_COMM_WORLD_RANK", "PMI_RANK", "PMIX_RANK", "MPI_LOCALRANKID", "SLURM_PROCID") + ) + + def _to_mscclpp_op(op) -> mscclpp.ReduceOp: if op == torch.distributed.ReduceOp.SUM: return mscclpp.ReduceOp.SUM @@ -76,6 +128,7 @@ class CustomizedComm: "default_allreduce_nvls_packet": 16, "default_allreduce_packet": 56, "default_allreduce_allpair_packet": 56, + "default_allreduce_rsag": 64, "default_allreduce_fullmesh": 64, "default_allgather_fullmesh2": 32, } @@ -84,6 +137,12 @@ class CustomizedComm: self.comm = comm self.rank = comm.my_rank self.world_size = comm.nranks + self.nranks_per_node = comm.nranks_per_node + self.mnnvl_domain = self.world_size > 1 and os.environ.get("MSCCLPP_MNNVL_NRANKS_PER_NODE") == str( + self.world_size + ) + self.multi_node = self.world_size > self.nranks_per_node and not self.mnnvl_domain + self.multi_host_mnnvl = self.mnnvl_domain and self.world_size > 1 self.symmetric_memory = symmetric_memory self._nvls = mscclpp.is_nvls_supported() @@ -106,6 +165,10 @@ class CustomizedComm: pkt = self._algo("allreduce", "default_allreduce_nvls_packet") if self._nvls and pkt: return (pkt, 0, 0) + if self.multi_node or self.multi_host_mnnvl: + rsag = self._algo("allreduce", "default_allreduce_rsag") + if rsag: + return (rsag, 0, 0) return (self._algo("allreduce", "default_allreduce_packet"), 0, 0) # -- low-level execute -- @@ -166,23 +229,48 @@ class CustomizedComm: def _ar_candidates(self, size: int): out = [] - if size <= 4 << 20: + if self.multi_host_mnnvl: + if size <= 4 << 20: + a = self._algo("allreduce", "default_allreduce_packet") + if a: + out.append(a) + a = self._algo("allreduce", "default_allreduce_nvls_packet") + if self._nvls and a: + out.append(a) + if size >= 512 << 10: + a = self._algo("allreduce", "default_allreduce_rsag") + if a: + out.append(a) + return out + if self.multi_node: a = self._algo("allreduce", "default_allreduce_nvls_packet") if self._nvls and a: out.append(a) a = self._algo("allreduce", "default_allreduce_packet") + if a: + out.append(a) + if size >= 512 << 10: + a = self._algo("allreduce", "default_allreduce_rsag") + if a: + out.append(a) + return out + if size <= 4 << 20: + a = self._algo("allreduce", "default_allreduce_packet") if a: out.append(a) a = self._algo("allreduce", "default_allreduce_allpair_packet") if a: out.append(a) - if size >= 512 << 10: - a = self._algo("allreduce", "default_allreduce_nvls_zero_copy") - if self._nvls and self.symmetric_memory and a: + a = self._algo("allreduce", "default_allreduce_nvls_packet") + if self._nvls and a: out.append(a) + if size >= 512 << 10: a = self._algo("allreduce", "default_allreduce_rsag_zero_copy") if a: out.append(a) + a = self._algo("allreduce", "default_allreduce_nvls_zero_copy") + if self._nvls and self.symmetric_memory and a: + out.append(a) if torch.version.hip is not None: a = self._algo("allreduce", "default_allreduce_fullmesh") if a: @@ -190,6 +278,8 @@ class CustomizedComm: return out def _ag_candidates(self): + if self.multi_node or self.multi_host_mnnvl: + return [] a = self._algo("allgather", "default_allgather_fullmesh2") return [a] if a else [] @@ -314,6 +404,8 @@ class CustomizedComm: ) def all_gather(self, output_tensor, input_tensor, stream=None): + if self.multi_node or self.multi_host_mnnvl: + raise RuntimeError("all_gather in this example currently supports only single-node runs") sz = _round_pow2(input_tensor.nbytes) if sz not in self._tune_cache["allgather"]: self._tune_size("allgather", sz) @@ -332,7 +424,11 @@ class CustomizedComm: # -- Benchmarks (standalone) -------------------------------------------------- -def _bench_sizes(low=5 * 1024, high=80 << 20): +def _bench_sizes(low=None, high=None): + if low is None: + low = _get_env_int("MSCCLPP_BENCH_LOW_SIZE", default=5 * 1024) + if high is None: + high = _get_env_int("MSCCLPP_BENCH_HIGH_SIZE", default=80 << 20) sizes, c = [], low while c <= high: sizes.append(c) @@ -433,13 +529,21 @@ def benchmark_allgather(comm: CustomizedComm, dtype=torch.float16, n_warmup=10, def init_dist() -> mscclpp.CommGroup: addr = os.environ.get("MSCCLPP_MASTER_ADDR") - if addr: - rank, world = int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]) - port = os.environ["MSCCLPP_MASTER_PORT"] - iface = _interfaces_for_ip(addr) + rank = _get_env_int("RANK", "OMPI_COMM_WORLD_RANK", "PMI_RANK", "SLURM_PROCID") + world = _get_env_int("WORLD_SIZE", "OMPI_COMM_WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS") + if addr and rank is not None and world is not None: + port = os.environ.get("MSCCLPP_MASTER_PORT", "29500") + iface = _resolve_interface(addr) if not iface: raise ValueError(f"No interface for {addr}") return mscclpp.CommGroup(interfaceIpPortTrio=f"{iface}:{addr}:{port}", rank=rank, size=world) + if _running_under_mpi(): + try: + from mpi4py import MPI + except ModuleNotFoundError as exc: + raise RuntimeError("mpi4py is required to launch this example with mpirun") from exc + + return mscclpp.CommGroup(mpi_comm=MPI.COMM_WORLD) import torch.distributed as dist dist.init_process_group(backend="gloo") @@ -447,7 +551,7 @@ def init_dist() -> mscclpp.CommGroup: def main(): - local = int(os.environ["LOCAL_RANK"]) + local = _get_env_int("LOCAL_RANK", "OMPI_COMM_WORLD_LOCAL_RANK", "MPI_LOCALRANKID", "SLURM_LOCALID", default=0) torch.cuda.set_device(local) dtype_str = os.environ.get("DTYPE", "float16") @@ -455,22 +559,48 @@ def main(): accum_map = {"float32": mscclpp.DataType.float32, "float16": mscclpp.DataType.float16} accum_str = os.environ.get("ACCUM_DTYPE") accum_dtype = accum_map.get(accum_str) if accum_str else None + n_warmup = _get_env_int("MSCCLPP_BENCH_WARMUP", default=10) + n_graph_launches = _get_env_int("MSCCLPP_BENCH_GRAPH_LAUNCHES", default=10) + n_iter = _get_env_int("MSCCLPP_BENCH_ITERS", default=100) comm_group = init_dist() cc = CustomizedComm(comm_group) print(f"rank {local} starting benchmarks with dtype={dtype} accum_dtype={accum_dtype}...") - benchmark_allreduce(cc, dtype=dtype, accum_dtype=accum_dtype) + benchmark_allreduce( + cc, + dtype=dtype, + accum_dtype=accum_dtype, + n_warmup=n_warmup, + n_graph_launches=n_graph_launches, + n_iter=n_iter, + ) cc.barrier() torch.cuda.synchronize() - benchmark_allgather(cc, dtype=dtype) - cc.barrier() - torch.cuda.synchronize() + if cc.multi_node or cc.multi_host_mnnvl: + if cc.rank == 0: + print("Skipping allgather benchmark on multi-node: this example's allgather path is single-node only.") + else: + benchmark_allgather(cc, dtype=dtype, n_warmup=n_warmup, n_graph_launches=n_graph_launches, n_iter=n_iter) + cc.barrier() + torch.cuda.synchronize() cc.destroy() + del cc + del comm_group + gc.collect() print(f"rank {local} completed successfully.") if __name__ == "__main__": - main() + exit_code = 0 + try: + main() + except Exception: + exit_code = 1 + traceback.print_exc() + finally: + sys.stdout.flush() + sys.stderr.flush() + os._exit(exit_code) diff --git a/include/mscclpp/env.hpp b/include/mscclpp/env.hpp index a6dd306b..09d364c3 100644 --- a/include/mscclpp/env.hpp +++ b/include/mscclpp/env.hpp @@ -119,6 +119,12 @@ class Env { /// Default is 0. Used when `EndpointConfig::Ib::gidIndex` is -1 (unspecified). const int ibGidIndex; + /// Env name: `MSCCLPP_MNNVL_NRANKS_PER_NODE`. Overrides the NVLink-domain size reported by the bootstrap. + /// This is intended for Multi-Node NVLink (MNNVL) deployments where a single CUDA IPC / NVLS domain spans + /// multiple hosts and should be treated as one collective peer group. + /// If unset or non-positive, the bootstrap falls back to physical-host-based detection. + const int mnnvlNranksPerNode; + private: Env(); diff --git a/src/core/bootstrap/bootstrap.cc b/src/core/bootstrap/bootstrap.cc index b3032e50..c84ef4c0 100644 --- a/src/core/bootstrap/bootstrap.cc +++ b/src/core/bootstrap/bootstrap.cc @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -433,6 +434,10 @@ void TcpBootstrap::Impl::establishConnections(int64_t timeoutSec) { int TcpBootstrap::Impl::getNranksPerNode() { if (nRanksPerNode_ > 0) return nRanksPerNode_; + if (env()->mnnvlNranksPerNode > 0) { + nRanksPerNode_ = env()->mnnvlNranksPerNode; + return nRanksPerNode_; + } int nRanksPerNode = 0; bool useIpv4 = peerCommAddresses_[rank_].sa.sa_family == AF_INET; for (int i = 0; i < nRanks_; i++) { diff --git a/src/core/env.cpp b/src/core/env.cpp index 7a42471b..b46670d7 100644 --- a/src/core/env.cpp +++ b/src/core/env.cpp @@ -67,7 +67,8 @@ Env::Env() ncclSymmetricMemory(readEnv("MSCCLPP_NCCL_SYMMETRIC_MEMORY", false)), forceDisableNvls(readEnv("MSCCLPP_FORCE_DISABLE_NVLS", false)), forceDisableGdr(readEnv("MSCCLPP_FORCE_DISABLE_GDR", false)), - ibGidIndex(readEnv("MSCCLPP_IB_GID_INDEX", 0)) {} + ibGidIndex(readEnv("MSCCLPP_IB_GID_INDEX", 0)), + mnnvlNranksPerNode(readEnv("MSCCLPP_MNNVL_NRANKS_PER_NODE", 0)) {} std::shared_ptr env() { static std::shared_ptr globalEnv = std::shared_ptr(new Env()); @@ -97,6 +98,7 @@ std::shared_ptr env() { logEnv("MSCCLPP_FORCE_DISABLE_NVLS", globalEnv->forceDisableNvls); logEnv("MSCCLPP_FORCE_DISABLE_GDR", globalEnv->forceDisableGdr); logEnv("MSCCLPP_IB_GID_INDEX", globalEnv->ibGidIndex); + logEnv("MSCCLPP_MNNVL_NRANKS_PER_NODE", globalEnv->mnnvlNranksPerNode); } return globalEnv; } diff --git a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu index 17bcfc33..9516ad78 100644 --- a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu @@ -17,9 +17,6 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand size_t channelDataOffset, size_t scratchBufferSize, int rank, int nRanksPerNode, int worldSize, size_t nelems, uint32_t numScratchBuff, void* flags, uint32_t flagSize) { - // This version of allreduce only works for single nodes - if (worldSize != nRanksPerNode) return; - if (sizeof(T) == 2 || sizeof(T) == 1) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int); const int nPeers = nRanksPerNode - 1; @@ -143,7 +140,7 @@ std::shared_ptr AllreduceAllpairPacket::initAllreduceContext(std::shared_p const int nChannelsPerConnection = maxBlockNum_; ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm, this->conns_); ctx->memorySemaphores = this->memorySemaphores_; ctx->registeredMemories = this->registeredMemories_; ctx->registeredMemories.pop_back(); // remove the local memory from previous context @@ -189,4 +186,4 @@ std::shared_ptr AllreduceAllpairPacket::build() { }); } } // namespace collective -} // namespace mscclpp \ No newline at end of file +} // namespace mscclpp diff --git a/src/ext/collectives/allreduce/allreduce_nvls_packet.cu b/src/ext/collectives/allreduce/allreduce_nvls_packet.cu index a616485e..21f71028 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_packet.cu @@ -94,7 +94,7 @@ std::shared_ptr AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr< auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm); // setup channels ctx->switchChannels = this->switchChannels_; @@ -154,4 +154,4 @@ std::shared_ptr AllreduceNvlsPacket::build() { }); } } // namespace collective -} // namespace mscclpp \ No newline at end of file +} // namespace mscclpp diff --git a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu index 735deb0a..25077004 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu @@ -183,7 +183,7 @@ std::shared_ptr AllreduceNvls::initAllreduceContext(std::shared_ptr(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm); size_t sendBytes, recvBytes; CUdeviceptr sendBasePtr, recvBasePtr; diff --git a/src/ext/collectives/allreduce/allreduce_packet.cu b/src/ext/collectives/allreduce/allreduce_packet.cu index d39da408..c195aefa 100644 --- a/src/ext/collectives/allreduce/allreduce_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_packet.cu @@ -23,9 +23,6 @@ __global__ void __launch_bounds__(1024, 1) #else ) { #endif - // This version of allreduce only works for single nodes - if (worldSize != nRanksPerNode) return; - #if defined(ENABLE_NPKIT) extern __shared__ int4 NpkitSharedMem[]; NpKitEvent* event_buffer = (NpKitEvent*)((char*)NpkitSharedMem); @@ -267,7 +264,7 @@ std::shared_ptr AllreducePacket::initAllreduceContext(std::shared_ptrrank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm, this->conns_); ctx->memorySemaphores = this->memorySemaphores_; ctx->registeredMemories = this->registeredMemories_; ctx->registeredMemories.pop_back(); // remove the local memory from previous context @@ -313,4 +310,4 @@ std::shared_ptr AllreducePacket::build() { } } // namespace collective -} // namespace mscclpp \ No newline at end of file +} // namespace mscclpp diff --git a/src/ext/collectives/allreduce/allreduce_rsag.cu b/src/ext/collectives/allreduce/allreduce_rsag.cu index db471b93..f964b87e 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag.cu @@ -199,7 +199,7 @@ std::shared_ptr AllreduceRsAg::initAllreduceContext(std::shared_ptr(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm, this->conns_); ctx->memorySemaphores = this->scratchSemaphores_; ctx->registeredMemories = this->remoteScratchMemories_; diff --git a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu index 42d86fc8..c4dea321 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu @@ -183,7 +183,7 @@ std::shared_ptr AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_pt auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm, this->conns_); ctx->memorySemaphores = this->semaphores_; diff --git a/src/ext/collectives/collective_utils.cc b/src/ext/collectives/collective_utils.cc index 016c4a5c..4d46c53b 100644 --- a/src/ext/collectives/collective_utils.cc +++ b/src/ext/collectives/collective_utils.cc @@ -69,6 +69,25 @@ std::vector> setupMemoryS return memorySemaphores; } +int getCollectiveDomainNranksPerNode(std::shared_ptr comm, + const std::vector& connections) { + const int worldSize = comm->bootstrap()->getNranks(); + const int nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + if (worldSize <= nRanksPerNode) { + return nRanksPerNode; + } + const bool allPeersUseCudaIpc = + std::all_of(connections.begin(), connections.end(), + [](const auto& connection) { return connection.transport() == mscclpp::Transport::CudaIpc; }); + return allPeersUseCudaIpc ? worldSize : nRanksPerNode; +} + +int getCollectiveDomainNranksPerNode(std::shared_ptr comm) { + const int worldSize = comm->bootstrap()->getNranks(); + const int nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + return worldSize > nRanksPerNode ? worldSize : nRanksPerNode; +} + std::shared_ptr> setupMemoryChannelDeviceHandles( const std::vector& memoryChannels) { std::vector> memoryChannelDeviceHandles; @@ -153,4 +172,4 @@ std::shared_ptr> setupBaseMemo } // namespace collective -} // namespace mscclpp \ No newline at end of file +} // namespace mscclpp diff --git a/src/ext/collectives/include/collective_utils.hpp b/src/ext/collectives/include/collective_utils.hpp index 638214dd..38362a65 100644 --- a/src/ext/collectives/include/collective_utils.hpp +++ b/src/ext/collectives/include/collective_utils.hpp @@ -50,6 +50,8 @@ std::vector setupMemoryChannels( std::vector setupConnections(std::shared_ptr comm); std::vector> setupMemorySemaphores( std::shared_ptr comm, const std::vector& connections, int nChannelsPerConnection); +int getCollectiveDomainNranksPerNode(std::shared_ptr comm, const std::vector& connections); +int getCollectiveDomainNranksPerNode(std::shared_ptr comm); std::shared_ptr> setupMemoryChannelDeviceHandles( const std::vector& memoryChannels); @@ -96,4 +98,4 @@ class AlgorithmCtx { } // namespace collective } // namespace mscclpp -#endif // MSCCLPP_EXT_COLLECTIVE_UTILS_HPP_ \ No newline at end of file +#endif // MSCCLPP_EXT_COLLECTIVE_UTILS_HPP_