Enable MNNVL allreduce tuning

Add an MNNVL rank-domain override so MSCCL++ collectives can treat a multi-host NVLink fabric as a single CUDA IPC/NVLS peer group. Update the packet, RSAG, and NVLS allreduce paths to use the collective-domain size, and teach the torch-integration tuning example to select MNNVL-capable allreduce algorithms.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Binyang Li
2026-04-28 05:38:59 +00:00
parent dd8b301a65
commit 893a08e69c
12 changed files with 199 additions and 41 deletions
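For orientation, a minimal launch sketch under an assumed topology (two 8-GPU hosts sharing one MNNVL domain; the counts and the hostfile are illustrative, not from this commit): setting the override to the full world size makes the bootstrap treat all 16 ranks as one NVLink node.

import os

# Assumed topology: 2 hosts x 8 GPUs in one multi-node NVLink domain.
os.environ["MSCCLPP_MNNVL_NRANKS_PER_NODE"] = "16"  # full world size, not GPUs per host
# Then launch normally, e.g.:
#   mpirun -np 16 --hostfile <hostfile> \
#       python3 examples/torch-integration/customized_comm_with_tuning.py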

View File

@@ -2,11 +2,34 @@
# Licensed under the MIT License.
# torchrun --nnodes=1 --nproc_per_node=8 examples/torch-integration/customized_comm_with_tuning.py
# mpirun -np 2 --hostfile <hostfile> python3 examples/torch-integration/customized_comm_with_tuning.py
import os
import gc
import fcntl
import ipaddress
import os
import socket
import struct
import sys
import traceback
def _get_bootstrap_world_size():
for name in ("WORLD_SIZE", "OMPI_COMM_WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS"):
value = os.environ.get(name)
if value is not None:
return int(value)
return None
_bootstrap_world_size = _get_bootstrap_world_size()
if (
_bootstrap_world_size
and _bootstrap_world_size > 1
and "MSCCLPP_MNNVL_NRANKS_PER_NODE" not in os.environ
and os.environ.get("MSCCLPP_ENABLE_MNNVL", "1") != "0"
):
os.environ["MSCCLPP_MNNVL_NRANKS_PER_NODE"] = str(_bootstrap_world_size)
import netifaces as ni
import torch
import mscclpp
import mscclpp.ext
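The auto-enable block above boils down to three conditions; a sketch restating the same rule, with a hypothetical helper name:

def _should_auto_enable_mnnvl(env, world):
    # Apply the MNNVL override only when all three conditions hold.
    return (
        world is not None and world > 1                   # a real multi-rank job
        and "MSCCLPP_MNNVL_NRANKS_PER_NODE" not in env    # user has not pinned it
        and env.get("MSCCLPP_ENABLE_MNNVL", "1") != "0"   # opt-out escape hatch
    )

assert _should_auto_enable_mnnvl({}, 16)
assert not _should_auto_enable_mnnvl({"MSCCLPP_ENABLE_MNNVL": "0"}, 16)
assert not _should_auto_enable_mnnvl({}, None)  # no launcher detected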
@@ -37,15 +60,44 @@ def _load_algorithms(scratch: torch.Tensor, rank: int):
def _interfaces_for_ip(ip: str):
target = ipaddress.ip_address(ip)
for iface in ni.interfaces():
addrs = ni.ifaddresses(iface)
if ni.AF_INET in addrs:
for link in addrs[ni.AF_INET]:
if "addr" in link and ipaddress.ip_address(link["addr"]) == target:
return iface
for iface in os.listdir("/sys/class/net"):
try:
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
req = struct.pack("256s", iface.encode()[:15])
addr = socket.inet_ntoa(fcntl.ioctl(sock.fileno(), 0x8915, req)[20:24])
except OSError:
continue
if ipaddress.ip_address(addr) == target:
return iface
return None
def _resolve_interface(master_addr: str):
for env_name in ("MSCCLPP_INTERFACE", "MSCCLPP_SOCKET_IFNAME", "NCCL_SOCKET_IFNAME"):
value = os.environ.get(env_name)
if value:
iface = value.split(",")[0].strip()
if iface in os.listdir("/sys/class/net"):
return iface
raise ValueError(f"Interface {iface} from {env_name} does not exist")
return _interfaces_for_ip(master_addr)
def _get_env_int(*names: str, default=None):
for name in names:
value = os.environ.get(name)
if value is not None:
return int(value)
return default
def _running_under_mpi() -> bool:
return any(
name in os.environ
for name in ("OMPI_COMM_WORLD_RANK", "PMI_RANK", "PMIX_RANK", "MPI_LOCALRANKID", "SLURM_PROCID")
)
def _to_mscclpp_op(op) -> mscclpp.ReduceOp:
if op == torch.distributed.ReduceOp.SUM:
return mscclpp.ReduceOp.SUM
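The fallback branch of `_interfaces_for_ip` above uses the Linux SIOCGIFADDR ioctl (0x8915) instead of netifaces. A standalone sketch of that trick (Linux-only; the interface name is illustrative):

import fcntl
import socket
import struct

def ipv4_of_interface(iface):
    # SIOCGIFADDR: the kernel fills an ifreq struct; bytes 20..24 of the
    # result hold the IPv4 address. The interface name is capped at 15 chars.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        req = struct.pack("256s", iface.encode()[:15])
        return socket.inet_ntoa(fcntl.ioctl(sock.fileno(), 0x8915, req)[20:24])

# e.g. ipv4_of_interface("eth0") -> "10.0.0.12"  (illustrative)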
@@ -76,6 +128,7 @@ class CustomizedComm:
"default_allreduce_nvls_packet": 16,
"default_allreduce_packet": 56,
"default_allreduce_allpair_packet": 56,
"default_allreduce_rsag": 64,
"default_allreduce_fullmesh": 64,
"default_allgather_fullmesh2": 32,
}
@@ -84,6 +137,12 @@ class CustomizedComm:
self.comm = comm
self.rank = comm.my_rank
self.world_size = comm.nranks
self.nranks_per_node = comm.nranks_per_node
self.mnnvl_domain = self.world_size > 1 and os.environ.get("MSCCLPP_MNNVL_NRANKS_PER_NODE") == str(
self.world_size
)
self.multi_node = self.world_size > self.nranks_per_node and not self.mnnvl_domain
self.multi_host_mnnvl = self.mnnvl_domain and self.world_size > 1
self.symmetric_memory = symmetric_memory
self._nvls = mscclpp.is_nvls_supported()
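A worked example of the three topology flags introduced above (launch shapes are illustrative): only when the env override equals the full world size does a multi-host run classify as one MNNVL domain instead of a multi-node one.

def classify(world_size, nranks_per_node, override):
    # Same derivation as in __init__ above; `override` is the value of
    # MSCCLPP_MNNVL_NRANKS_PER_NODE, or None when unset.
    mnnvl_domain = world_size > 1 and override == str(world_size)
    multi_node = world_size > nranks_per_node and not mnnvl_domain
    multi_host_mnnvl = mnnvl_domain and world_size > 1
    return mnnvl_domain, multi_node, multi_host_mnnvl

print(classify(8, 8, None))    # single node        -> (False, False, False)
print(classify(16, 8, None))   # classic multi-node -> (False, True, False)
print(classify(16, 16, "16"))  # MNNVL domain       -> (True, False, True)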
@@ -106,6 +165,10 @@ class CustomizedComm:
pkt = self._algo("allreduce", "default_allreduce_nvls_packet")
if self._nvls and pkt:
return (pkt, 0, 0)
if self.multi_node or self.multi_host_mnnvl:
rsag = self._algo("allreduce", "default_allreduce_rsag")
if rsag:
return (rsag, 0, 0)
return (self._algo("allreduce", "default_allreduce_packet"), 0, 0)
# -- low-level execute --
@@ -166,23 +229,48 @@ class CustomizedComm:
def _ar_candidates(self, size: int):
out = []
if size <= 4 << 20:
if self.multi_host_mnnvl:
if size <= 4 << 20:
a = self._algo("allreduce", "default_allreduce_packet")
if a:
out.append(a)
a = self._algo("allreduce", "default_allreduce_nvls_packet")
if self._nvls and a:
out.append(a)
if size >= 512 << 10:
a = self._algo("allreduce", "default_allreduce_rsag")
if a:
out.append(a)
return out
if self.multi_node:
a = self._algo("allreduce", "default_allreduce_nvls_packet")
if self._nvls and a:
out.append(a)
a = self._algo("allreduce", "default_allreduce_packet")
if a:
out.append(a)
if size >= 512 << 10:
a = self._algo("allreduce", "default_allreduce_rsag")
if a:
out.append(a)
return out
if size <= 4 << 20:
a = self._algo("allreduce", "default_allreduce_packet")
if a:
out.append(a)
a = self._algo("allreduce", "default_allreduce_allpair_packet")
if a:
out.append(a)
if size >= 512 << 10:
a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
if self._nvls and self.symmetric_memory and a:
a = self._algo("allreduce", "default_allreduce_nvls_packet")
if self._nvls and a:
out.append(a)
if size >= 512 << 10:
a = self._algo("allreduce", "default_allreduce_rsag_zero_copy")
if a:
out.append(a)
a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
if self._nvls and self.symmetric_memory and a:
out.append(a)
if torch.version.hip is not None:
a = self._algo("allreduce", "default_allreduce_fullmesh")
if a:
@@ -190,6 +278,8 @@ class CustomizedComm:
return out
def _ag_candidates(self):
if self.multi_node or self.multi_host_mnnvl:
return []
a = self._algo("allgather", "default_allgather_fullmesh2")
return [a] if a else []
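To see which tier a given message size lands in, the candidate list can be previewed directly; a small sketch (the `name` attribute on algorithm handles is an assumption, hence the `getattr` fallback):

def preview_ar_candidates(comm, sizes=(256 << 10, 1 << 20, 32 << 20)):
    # Sizes chosen to straddle the 512 KiB and 4 MiB thresholds used above.
    for size in sizes:
        algos = comm._ar_candidates(size)
        print(f"{size >> 10:6d} KiB -> {[getattr(a, 'name', a) for a in algos]}")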
@@ -314,6 +404,8 @@ class CustomizedComm:
)
def all_gather(self, output_tensor, input_tensor, stream=None):
if self.multi_node or self.multi_host_mnnvl:
raise RuntimeError("all_gather in this example currently supports only single-node runs")
sz = _round_pow2(input_tensor.nbytes)
if sz not in self._tune_cache["allgather"]:
self._tune_size("allgather", sz)
@@ -332,7 +424,11 @@ class CustomizedComm:
# -- Benchmarks (standalone) --------------------------------------------------
def _bench_sizes(low=5 * 1024, high=80 << 20):
def _bench_sizes(low=None, high=None):
if low is None:
low = _get_env_int("MSCCLPP_BENCH_LOW_SIZE", default=5 * 1024)
if high is None:
high = _get_env_int("MSCCLPP_BENCH_HIGH_SIZE", default=80 << 20)
sizes, c = [], low
while c <= high:
sizes.append(c)
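The sweep bounds are now overridable without editing the script; the size step is cut off in this view, so the doubling progression below is an assumption:

import os

# Narrow the benchmark sweep to 64 KiB .. 4 MiB (values illustrative).
os.environ["MSCCLPP_BENCH_LOW_SIZE"] = str(64 << 10)
os.environ["MSCCLPP_BENCH_HIGH_SIZE"] = str(4 << 20)
# Assuming the step doubles each iteration, this yields 64 KiB, 128 KiB, ..., 4 MiB.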
@@ -433,13 +529,21 @@ def benchmark_allgather(comm: CustomizedComm, dtype=torch.float16, n_warmup=10,
def init_dist() -> mscclpp.CommGroup:
addr = os.environ.get("MSCCLPP_MASTER_ADDR")
if addr:
rank, world = int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"])
port = os.environ["MSCCLPP_MASTER_PORT"]
iface = _interfaces_for_ip(addr)
rank = _get_env_int("RANK", "OMPI_COMM_WORLD_RANK", "PMI_RANK", "SLURM_PROCID")
world = _get_env_int("WORLD_SIZE", "OMPI_COMM_WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS")
if addr and rank is not None and world is not None:
port = os.environ.get("MSCCLPP_MASTER_PORT", "29500")
iface = _resolve_interface(addr)
if not iface:
raise ValueError(f"No interface for {addr}")
return mscclpp.CommGroup(interfaceIpPortTrio=f"{iface}:{addr}:{port}", rank=rank, size=world)
if _running_under_mpi():
try:
from mpi4py import MPI
except ModuleNotFoundError as exc:
raise RuntimeError("mpi4py is required to launch this example with mpirun") from exc
return mscclpp.CommGroup(mpi_comm=MPI.COMM_WORLD)
import torch.distributed as dist
dist.init_process_group(backend="gloo")
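`init_dist` now resolves the bootstrap in three tiers: the explicit MSCCLPP_MASTER_ADDR trio, then mpi4py when an MPI/SLURM launcher is detected, then a gloo process group. A tier-1 sketch with illustrative values:

import os

# Launcher-free rendezvous (all values illustrative):
os.environ.update({
    "MSCCLPP_MASTER_ADDR": "10.0.0.1",  # matched against local IPs if no IFNAME is set
    "MSCCLPP_MASTER_PORT": "29500",     # optional; defaults to 29500
    "RANK": "0",
    "WORLD_SIZE": "16",
})
# _resolve_interface() honors MSCCLPP_INTERFACE / MSCCLPP_SOCKET_IFNAME /
# NCCL_SOCKET_IFNAME first, then falls back to the address match, yielding
# e.g. CommGroup(interfaceIpPortTrio="eth0:10.0.0.1:29500", rank=0, size=16).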
@@ -447,7 +551,7 @@ def init_dist() -> mscclpp.CommGroup:
def main():
local = int(os.environ["LOCAL_RANK"])
local = _get_env_int("LOCAL_RANK", "OMPI_COMM_WORLD_LOCAL_RANK", "MPI_LOCALRANKID", "SLURM_LOCALID", default=0)
torch.cuda.set_device(local)
dtype_str = os.environ.get("DTYPE", "float16")
@@ -455,22 +559,48 @@ def main():
accum_map = {"float32": mscclpp.DataType.float32, "float16": mscclpp.DataType.float16}
accum_str = os.environ.get("ACCUM_DTYPE")
accum_dtype = accum_map.get(accum_str) if accum_str else None
n_warmup = _get_env_int("MSCCLPP_BENCH_WARMUP", default=10)
n_graph_launches = _get_env_int("MSCCLPP_BENCH_GRAPH_LAUNCHES", default=10)
n_iter = _get_env_int("MSCCLPP_BENCH_ITERS", default=100)
comm_group = init_dist()
cc = CustomizedComm(comm_group)
print(f"rank {local} starting benchmarks with dtype={dtype} accum_dtype={accum_dtype}...")
benchmark_allreduce(cc, dtype=dtype, accum_dtype=accum_dtype)
benchmark_allreduce(
cc,
dtype=dtype,
accum_dtype=accum_dtype,
n_warmup=n_warmup,
n_graph_launches=n_graph_launches,
n_iter=n_iter,
)
cc.barrier()
torch.cuda.synchronize()
benchmark_allgather(cc, dtype=dtype)
cc.barrier()
torch.cuda.synchronize()
if cc.multi_node or cc.multi_host_mnnvl:
if cc.rank == 0:
print("Skipping allgather benchmark on multi-node: this example's allgather path is single-node only.")
else:
benchmark_allgather(cc, dtype=dtype, n_warmup=n_warmup, n_graph_launches=n_graph_launches, n_iter=n_iter)
cc.barrier()
torch.cuda.synchronize()
cc.destroy()
del cc
del comm_group
gc.collect()
print(f"rank {local} completed successfully.")
if __name__ == "__main__":
main()
exit_code = 0
try:
main()
except Exception:
exit_code = 1
traceback.print_exc()
finally:
sys.stdout.flush()
sys.stderr.flush()
os._exit(exit_code)

View File

@@ -119,6 +119,12 @@ class Env {
/// Default is 0. Used when `EndpointConfig::Ib::gidIndex` is -1 (unspecified).
const int ibGidIndex;
/// Env name: `MSCCLPP_MNNVL_NRANKS_PER_NODE`. Overrides the NVLink-domain size reported by the bootstrap.
/// This is intended for Multi-Node NVLink (MNNVL) deployments where a single CUDA IPC / NVLS domain spans
/// multiple hosts and should be treated as one collective peer group.
/// If unset or non-positive, the bootstrap falls back to physical-host-based detection.
const int mnnvlNranksPerNode;
private:
Env();

View File

@@ -5,6 +5,7 @@
#include <cstring>
#include <mscclpp/core.hpp>
#include <mscclpp/env.hpp>
#include <mscclpp/errors.hpp>
#include <sstream>
#include <thread>
@@ -433,6 +434,10 @@ void TcpBootstrap::Impl::establishConnections(int64_t timeoutSec) {
int TcpBootstrap::Impl::getNranksPerNode() {
if (nRanksPerNode_ > 0) return nRanksPerNode_;
if (env()->mnnvlNranksPerNode > 0) {
nRanksPerNode_ = env()->mnnvlNranksPerNode;
return nRanksPerNode_;
}
int nRanksPerNode = 0;
bool useIpv4 = peerCommAddresses_[rank_].sa.sa_family == AF_INET;
for (int i = 0; i < nRanks_; i++) {
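A Python restatement of the precedence in `getNranksPerNode` above (the physical-detection callback is a hypothetical stand-in for the address-comparison loop that follows in the C++):

def get_nranks_per_node(cached, env_override, detect_physical):
    if cached > 0:
        return cached               # already computed once
    if env_override > 0:
        return env_override         # MSCCLPP_MNNVL_NRANKS_PER_NODE wins
    return detect_physical()        # fall back to host-address matching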

View File

@@ -67,7 +67,8 @@ Env::Env()
ncclSymmetricMemory(readEnv<bool>("MSCCLPP_NCCL_SYMMETRIC_MEMORY", false)),
forceDisableNvls(readEnv<bool>("MSCCLPP_FORCE_DISABLE_NVLS", false)),
forceDisableGdr(readEnv<bool>("MSCCLPP_FORCE_DISABLE_GDR", false)),
ibGidIndex(readEnv<int>("MSCCLPP_IB_GID_INDEX", 0)) {}
ibGidIndex(readEnv<int>("MSCCLPP_IB_GID_INDEX", 0)),
mnnvlNranksPerNode(readEnv<int>("MSCCLPP_MNNVL_NRANKS_PER_NODE", 0)) {}
std::shared_ptr<Env> env() {
static std::shared_ptr<Env> globalEnv = std::shared_ptr<Env>(new Env());
@@ -97,6 +98,7 @@ std::shared_ptr<Env> env() {
logEnv("MSCCLPP_FORCE_DISABLE_NVLS", globalEnv->forceDisableNvls);
logEnv("MSCCLPP_FORCE_DISABLE_GDR", globalEnv->forceDisableGdr);
logEnv("MSCCLPP_IB_GID_INDEX", globalEnv->ibGidIndex);
logEnv("MSCCLPP_MNNVL_NRANKS_PER_NODE", globalEnv->mnnvlNranksPerNode);
}
return globalEnv;
}

View File

@@ -17,9 +17,6 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand
size_t channelDataOffset, size_t scratchBufferSize, int rank, int nRanksPerNode,
int worldSize, size_t nelems, uint32_t numScratchBuff, void* flags,
uint32_t flagSize) {
// This version of allreduce only works for single nodes
if (worldSize != nRanksPerNode) return;
if (sizeof(T) == 2 || sizeof(T) == 1) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int);
const int nPeers = nRanksPerNode - 1;
@@ -143,7 +140,7 @@ std::shared_ptr<void> AllreduceAllpairPacket::initAllreduceContext(std::shared_p
const int nChannelsPerConnection = maxBlockNum_;
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm, this->conns_);
ctx->memorySemaphores = this->memorySemaphores_;
ctx->registeredMemories = this->registeredMemories_;
ctx->registeredMemories.pop_back(); // remove the local memory from previous context
@@ -189,4 +186,4 @@ std::shared_ptr<Algorithm> AllreduceAllpairPacket::build() {
});
}
} // namespace collective
} // namespace mscclpp
} // namespace mscclpp
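The deleted `worldSize != nRanksPerNode` guard above is what confined these packet kernels to a single host; since the context now carries the collective-domain size, the arithmetic works out on MNNVL (a hedged illustration with assumed counts):

# On an assumed 2x8 MNNVL fabric with the override active:
#   worldSize = 16, nRanksPerNode = getCollectiveDomainNranksPerNode(...) = 16
#   -> nPeers = nRanksPerNode - 1 = 15, and the old guard would never have fired.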

View File

@@ -94,7 +94,7 @@ std::shared_ptr<void> AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr<
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm);
// setup channels
ctx->switchChannels = this->switchChannels_;
@@ -154,4 +154,4 @@ std::shared_ptr<mscclpp::Algorithm> AllreduceNvlsPacket::build() {
});
}
} // namespace collective
} // namespace mscclpp
} // namespace mscclpp

View File

@@ -183,7 +183,7 @@ std::shared_ptr<void> AllreduceNvls::initAllreduceContext(std::shared_ptr<mscclp
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm);
size_t sendBytes, recvBytes;
CUdeviceptr sendBasePtr, recvBasePtr;

View File

@@ -23,9 +23,6 @@ __global__ void __launch_bounds__(1024, 1)
#else
) {
#endif
// This version of allreduce only works for single nodes
if (worldSize != nRanksPerNode) return;
#if defined(ENABLE_NPKIT)
extern __shared__ int4 NpkitSharedMem[];
NpKitEvent* event_buffer = (NpKitEvent*)((char*)NpkitSharedMem);
@@ -267,7 +264,7 @@ std::shared_ptr<void> AllreducePacket::initAllreduceContext(std::shared_ptr<Comm
const int nChannelsPerConnection = maxBlockNum_;
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm, this->conns_);
ctx->memorySemaphores = this->memorySemaphores_;
ctx->registeredMemories = this->registeredMemories_;
ctx->registeredMemories.pop_back(); // remove the local memory from previous context
@@ -313,4 +310,4 @@ std::shared_ptr<Algorithm> AllreducePacket::build() {
}
} // namespace collective
} // namespace mscclpp
} // namespace mscclpp

View File

@@ -199,7 +199,7 @@ std::shared_ptr<void> AllreduceRsAg::initAllreduceContext(std::shared_ptr<Commun
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm, this->conns_);
ctx->memorySemaphores = this->scratchSemaphores_;
ctx->registeredMemories = this->remoteScratchMemories_;

View File

@@ -183,7 +183,7 @@ std::shared_ptr<void> AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_pt
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm, this->conns_);
ctx->memorySemaphores = this->semaphores_;

View File

@@ -69,6 +69,25 @@ std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> setupMemoryS
return memorySemaphores;
}
int getCollectiveDomainNranksPerNode(std::shared_ptr<mscclpp::Communicator> comm,
const std::vector<mscclpp::Connection>& connections) {
const int worldSize = comm->bootstrap()->getNranks();
const int nRanksPerNode = comm->bootstrap()->getNranksPerNode();
if (worldSize <= nRanksPerNode) {
return nRanksPerNode;
}
const bool allPeersUseCudaIpc =
std::all_of(connections.begin(), connections.end(),
[](const auto& connection) { return connection.transport() == mscclpp::Transport::CudaIpc; });
return allPeersUseCudaIpc ? worldSize : nRanksPerNode;
}
int getCollectiveDomainNranksPerNode(std::shared_ptr<mscclpp::Communicator> comm) {
const int worldSize = comm->bootstrap()->getNranks();
const int nRanksPerNode = comm->bootstrap()->getNranksPerNode();
return worldSize > nRanksPerNode ? worldSize : nRanksPerNode;
}
std::shared_ptr<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> setupMemoryChannelDeviceHandles(
const std::vector<mscclpp::MemoryChannel>& memoryChannels) {
std::vector<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> memoryChannelDeviceHandles;
@@ -153,4 +172,4 @@ std::shared_ptr<mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>> setupBaseMemo
} // namespace collective
} // namespace mscclpp
} // namespace mscclpp
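A Python restatement of the two overloads above (the transport strings are illustrative): the connection-aware variant widens the domain only when every peer connection is CUDA IPC, while the NVLS variant assumes the switch domain already spans the world.

def collective_domain_nranks(world_size, nranks_per_node, transports=None):
    # Sketch of getCollectiveDomainNranksPerNode; `transports` holds the peer
    # connection transports (e.g. "CudaIpc", "IB") -- illustrative strings.
    if transports is None:  # NVLS overload
        return max(world_size, nranks_per_node)
    if world_size <= nranks_per_node:
        return nranks_per_node
    if all(t == "CudaIpc" for t in transports):
        return world_size       # single CUDA IPC domain spanning hosts
    return nranks_per_node      # mixed transports: keep the physical node size

assert collective_domain_nranks(16, 8, ["CudaIpc"] * 15) == 16   # MNNVL domain
assert collective_domain_nranks(16, 8, ["CudaIpc"] * 7 + ["IB"] * 8) == 8
assert collective_domain_nranks(16, 8) == 16                     # NVLS overload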

View File

@@ -50,6 +50,8 @@ std::vector<MemoryChannel> setupMemoryChannels(
std::vector<Connection> setupConnections(std::shared_ptr<Communicator> comm);
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> setupMemorySemaphores(
std::shared_ptr<Communicator> comm, const std::vector<Connection>& connections, int nChannelsPerConnection);
int getCollectiveDomainNranksPerNode(std::shared_ptr<Communicator> comm, const std::vector<Connection>& connections);
int getCollectiveDomainNranksPerNode(std::shared_ptr<Communicator> comm);
std::shared_ptr<DeviceHandle<MemoryChannel>> setupMemoryChannelDeviceHandles(
const std::vector<MemoryChannel>& memoryChannels);
@@ -96,4 +98,4 @@ class AlgorithmCtx {
} // namespace collective
} // namespace mscclpp
#endif // MSCCLPP_EXT_COLLECTIVE_UTILS_HPP_
#endif // MSCCLPP_EXT_COLLECTIVE_UTILS_HPP_