mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
Decouple IPC-domain hint from bootstrap nRanksPerNode
Replace MSCCLPP_MNNVL_NRANKS_PER_NODE (which overrode TcpBootstrap and silently changed getNranksPerNode() for every consumer) with a single algorithm-level helper getIpcDomainNranks(comm) backed by a new MSCCLPP_IPC_DOMAIN_NRANKS env. The neutral IPC name covers both NVLink/MNNVL on NVIDIA and XGMI on AMD. Bootstrap is unchanged and continues to report physical-host detection. Collapse the two getCollectiveDomainNranksPerNode overloads into one canonical helper and route all six allreduce algorithms (packet, allpair_packet, nvls_packet, nvls_zero_copy, rsag, rsag_zero_copy) through it. Update the standalone tuning example to use the new env name; drop the undeclared MSCCLPP_ENABLE_MNNVL gate; fix multi_host_mnnvl detection now that nranks_per_node is no longer overridden by the bootstrap. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -13,6 +13,7 @@ import struct
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
|
||||
def _get_bootstrap_world_size():
|
||||
for name in ("WORLD_SIZE", "OMPI_COMM_WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS"):
|
||||
value = os.environ.get(name)
|
||||
@@ -22,13 +23,8 @@ def _get_bootstrap_world_size():
|
||||
|
||||
|
||||
_bootstrap_world_size = _get_bootstrap_world_size()
|
||||
if (
|
||||
_bootstrap_world_size
|
||||
and _bootstrap_world_size > 1
|
||||
and "MSCCLPP_MNNVL_NRANKS_PER_NODE" not in os.environ
|
||||
and os.environ.get("MSCCLPP_ENABLE_MNNVL", "1") != "0"
|
||||
):
|
||||
os.environ["MSCCLPP_MNNVL_NRANKS_PER_NODE"] = str(_bootstrap_world_size)
|
||||
if _bootstrap_world_size and _bootstrap_world_size > 1 and "MSCCLPP_IPC_DOMAIN_NRANKS" not in os.environ:
|
||||
os.environ["MSCCLPP_IPC_DOMAIN_NRANKS"] = str(_bootstrap_world_size)
|
||||
|
||||
import torch
|
||||
import mscclpp
|
||||
@@ -140,11 +136,10 @@ class CustomizedComm:
|
||||
self.rank = comm.my_rank
|
||||
self.world_size = comm.nranks
|
||||
self.nranks_per_node = comm.nranks_per_node
|
||||
self.mnnvl_domain = self.world_size > 1 and os.environ.get("MSCCLPP_MNNVL_NRANKS_PER_NODE") == str(
|
||||
self.world_size
|
||||
)
|
||||
nvlink_domain_nranks = int(os.environ.get("MSCCLPP_IPC_DOMAIN_NRANKS", "0"))
|
||||
self.mnnvl_domain = self.world_size > 1 and nvlink_domain_nranks >= self.world_size
|
||||
self.multi_node = self.world_size > self.nranks_per_node and not self.mnnvl_domain
|
||||
self.multi_host_mnnvl = self.mnnvl_domain and self.world_size > 1
|
||||
self.multi_host_mnnvl = self.mnnvl_domain and self.world_size > self.nranks_per_node
|
||||
self.symmetric_memory = symmetric_memory
|
||||
self._nvls = mscclpp.is_nvls_supported()
|
||||
|
||||
|
||||
@@ -119,11 +119,11 @@ class Env {
|
||||
/// Default is 0. Used when `EndpointConfig::Ib::gidIndex` is -1 (unspecified).
|
||||
const int ibGidIndex;
|
||||
|
||||
/// Env name: `MSCCLPP_MNNVL_NRANKS_PER_NODE`. Overrides the NVLink-domain size reported by the bootstrap.
|
||||
/// This is intended for Multi-Node NVLink (MNNVL) deployments where a single CUDA IPC / NVLS domain spans
|
||||
/// multiple hosts and should be treated as one collective peer group.
|
||||
/// If unset or non-positive, the bootstrap falls back to physical-host-based detection.
|
||||
const int mnnvlNranksPerNode;
|
||||
/// Env name: `MSCCLPP_IPC_DOMAIN_NRANKS`. Number of ranks that share a single GPU-IPC-reachable peer
|
||||
/// group (e.g. a Multi-Node NVLink fabric such as GB200 NVL72, or an AMD XGMI domain). This hint is
|
||||
/// consumed only by the collective algorithms; it does not affect `Bootstrap::getNranksPerNode()` or
|
||||
/// any other layer. If unset or non-positive, algorithms fall back to `bootstrap->getNranksPerNode()`.
|
||||
const int ipcDomainNranks;
|
||||
|
||||
private:
|
||||
Env();
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
|
||||
#include <cstring>
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/env.hpp>
|
||||
#include <mscclpp/errors.hpp>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
@@ -434,10 +433,6 @@ void TcpBootstrap::Impl::establishConnections(int64_t timeoutSec) {
|
||||
|
||||
int TcpBootstrap::Impl::getNranksPerNode() {
|
||||
if (nRanksPerNode_ > 0) return nRanksPerNode_;
|
||||
if (env()->mnnvlNranksPerNode > 0) {
|
||||
nRanksPerNode_ = env()->mnnvlNranksPerNode;
|
||||
return nRanksPerNode_;
|
||||
}
|
||||
int nRanksPerNode = 0;
|
||||
bool useIpv4 = peerCommAddresses_[rank_].sa.sa_family == AF_INET;
|
||||
for (int i = 0; i < nRanks_; i++) {
|
||||
|
||||
@@ -68,7 +68,7 @@ Env::Env()
|
||||
forceDisableNvls(readEnv<bool>("MSCCLPP_FORCE_DISABLE_NVLS", false)),
|
||||
forceDisableGdr(readEnv<bool>("MSCCLPP_FORCE_DISABLE_GDR", false)),
|
||||
ibGidIndex(readEnv<int>("MSCCLPP_IB_GID_INDEX", 0)),
|
||||
mnnvlNranksPerNode(readEnv<int>("MSCCLPP_MNNVL_NRANKS_PER_NODE", 0)) {}
|
||||
ipcDomainNranks(readEnv<int>("MSCCLPP_IPC_DOMAIN_NRANKS", 0)) {}
|
||||
|
||||
std::shared_ptr<Env> env() {
|
||||
static std::shared_ptr<Env> globalEnv = std::shared_ptr<Env>(new Env());
|
||||
@@ -98,7 +98,7 @@ std::shared_ptr<Env> env() {
|
||||
logEnv("MSCCLPP_FORCE_DISABLE_NVLS", globalEnv->forceDisableNvls);
|
||||
logEnv("MSCCLPP_FORCE_DISABLE_GDR", globalEnv->forceDisableGdr);
|
||||
logEnv("MSCCLPP_IB_GID_INDEX", globalEnv->ibGidIndex);
|
||||
logEnv("MSCCLPP_MNNVL_NRANKS_PER_NODE", globalEnv->mnnvlNranksPerNode);
|
||||
logEnv("MSCCLPP_IPC_DOMAIN_NRANKS", globalEnv->ipcDomainNranks);
|
||||
}
|
||||
return globalEnv;
|
||||
}
|
||||
|
||||
@@ -140,7 +140,7 @@ std::shared_ptr<void> AllreduceAllpairPacket::initAllreduceContext(std::shared_p
|
||||
const int nChannelsPerConnection = maxBlockNum_;
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm, this->conns_);
|
||||
ctx->nRanksPerNode = getIpcDomainNranks(comm);
|
||||
ctx->memorySemaphores = this->memorySemaphores_;
|
||||
ctx->registeredMemories = this->registeredMemories_;
|
||||
ctx->registeredMemories.pop_back(); // remove the local memory from previous context
|
||||
|
||||
@@ -94,7 +94,7 @@ std::shared_ptr<void> AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr<
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm);
|
||||
ctx->nRanksPerNode = getIpcDomainNranks(comm);
|
||||
|
||||
// setup channels
|
||||
ctx->switchChannels = this->switchChannels_;
|
||||
|
||||
@@ -183,7 +183,7 @@ std::shared_ptr<void> AllreduceNvls::initAllreduceContext(std::shared_ptr<mscclp
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm);
|
||||
ctx->nRanksPerNode = getIpcDomainNranks(comm);
|
||||
|
||||
size_t sendBytes, recvBytes;
|
||||
CUdeviceptr sendBasePtr, recvBasePtr;
|
||||
|
||||
@@ -264,7 +264,7 @@ std::shared_ptr<void> AllreducePacket::initAllreduceContext(std::shared_ptr<Comm
|
||||
const int nChannelsPerConnection = maxBlockNum_;
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm, this->conns_);
|
||||
ctx->nRanksPerNode = getIpcDomainNranks(comm);
|
||||
ctx->memorySemaphores = this->memorySemaphores_;
|
||||
ctx->registeredMemories = this->registeredMemories_;
|
||||
ctx->registeredMemories.pop_back(); // remove the local memory from previous context
|
||||
|
||||
@@ -203,7 +203,7 @@ std::shared_ptr<void> AllreduceRsAg::initAllreduceContext(std::shared_ptr<Commun
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm, this->conns_);
|
||||
ctx->nRanksPerNode = getIpcDomainNranks(comm);
|
||||
|
||||
ctx->memorySemaphores = this->scratchSemaphores_;
|
||||
ctx->registeredMemories = this->remoteScratchMemories_;
|
||||
|
||||
@@ -169,7 +169,7 @@ CommResult AllreduceRsAgZeroCopy::allreduceKernelFunc(const std::shared_ptr<void
|
||||
}
|
||||
|
||||
AlgorithmCtxKey AllreduceRsAgZeroCopy::generateAllreduceContextKey(const void* inputBuffer, void* outputBuffer,
|
||||
size_t size, DataType, bool symmetricMemory) {
|
||||
size_t size, DataType, bool symmetricMemory) {
|
||||
// For non-symmetric algorithms, we use both input and output buffer pointers in the key.
|
||||
if (symmetricMemory) {
|
||||
size_t inputBytes, outputBytes;
|
||||
@@ -186,7 +186,7 @@ std::shared_ptr<void> AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_pt
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerNode = getCollectiveDomainNranksPerNode(comm, this->conns_);
|
||||
ctx->nRanksPerNode = getIpcDomainNranks(comm);
|
||||
|
||||
ctx->memorySemaphores = this->semaphores_;
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <algorithm>
|
||||
#include <mscclpp/algorithm.hpp>
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/env.hpp>
|
||||
#include <mscclpp/memory_channel.hpp>
|
||||
#include <mscclpp/switch_channel.hpp>
|
||||
|
||||
@@ -69,23 +70,12 @@ std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> setupMemoryS
|
||||
return memorySemaphores;
|
||||
}
|
||||
|
||||
int getCollectiveDomainNranksPerNode(std::shared_ptr<mscclpp::Communicator> comm,
|
||||
const std::vector<mscclpp::Connection>& connections) {
|
||||
const int worldSize = comm->bootstrap()->getNranks();
|
||||
const int nRanksPerNode = comm->bootstrap()->getNranksPerNode();
|
||||
if (worldSize <= nRanksPerNode) {
|
||||
return nRanksPerNode;
|
||||
int getIpcDomainNranks(std::shared_ptr<mscclpp::Communicator> comm) {
|
||||
const int envValue = mscclpp::env()->ipcDomainNranks;
|
||||
if (envValue > 0) {
|
||||
return envValue;
|
||||
}
|
||||
const bool allPeersUseCudaIpc =
|
||||
std::all_of(connections.begin(), connections.end(),
|
||||
[](const auto& connection) { return connection.transport() == mscclpp::Transport::CudaIpc; });
|
||||
return allPeersUseCudaIpc ? worldSize : nRanksPerNode;
|
||||
}
|
||||
|
||||
int getCollectiveDomainNranksPerNode(std::shared_ptr<mscclpp::Communicator> comm) {
|
||||
const int worldSize = comm->bootstrap()->getNranks();
|
||||
const int nRanksPerNode = comm->bootstrap()->getNranksPerNode();
|
||||
return worldSize > nRanksPerNode ? worldSize : nRanksPerNode;
|
||||
return comm->bootstrap()->getNranksPerNode();
|
||||
}
|
||||
|
||||
std::shared_ptr<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> setupMemoryChannelDeviceHandles(
|
||||
|
||||
@@ -50,8 +50,13 @@ std::vector<MemoryChannel> setupMemoryChannels(
|
||||
std::vector<Connection> setupConnections(std::shared_ptr<Communicator> comm);
|
||||
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> setupMemorySemaphores(
|
||||
std::shared_ptr<Communicator> comm, const std::vector<Connection>& connections, int nChannelsPerConnection);
|
||||
int getCollectiveDomainNranksPerNode(std::shared_ptr<Communicator> comm, const std::vector<Connection>& connections);
|
||||
int getCollectiveDomainNranksPerNode(std::shared_ptr<Communicator> comm);
|
||||
|
||||
/// Number of ranks that participate in the same GPU-IPC-reachable peer group (e.g. a single host or
|
||||
/// a Multi-Node NVLink fabric, or an AMD XGMI domain). Returns the value of `MSCCLPP_IPC_DOMAIN_NRANKS`
|
||||
/// if set to a positive value; otherwise falls back to `bootstrap->getNranksPerNode()`. This is
|
||||
/// intentionally independent of `nRanksPerNode` so that algorithms can opt in to MNNVL-like behavior
|
||||
/// without changing the meaning of bootstrap-level APIs.
|
||||
int getIpcDomainNranks(std::shared_ptr<Communicator> comm);
|
||||
|
||||
std::shared_ptr<DeviceHandle<MemoryChannel>> setupMemoryChannelDeviceHandles(
|
||||
const std::vector<MemoryChannel>& memoryChannels);
|
||||
|
||||
Reference in New Issue
Block a user