From f0c6ac081f23425e3a91c1493a1f4c7f40909600 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 6 May 2026 21:49:48 +0000 Subject: [PATCH] Fold validateIpcDomainSpansWorld into getIpcDomainNranks getIpcDomainNranks now performs the range / world-size / rank checks itself and throws on violation, so the separate validateIpcDomainSpansWorld helper is unnecessary. Update the 3 NVLS callsites (block_pipeline, warp_pipeline, nvls_zero_copy) to call getIpcDomainNranks directly. The non-NVLS callers also pick up the strict validation, which is fine because they are only invoked in single-host or multi-host MNNVL scenarios where worldSize == ipcDomainNranks (the NCCL adapter's multi-node path returns nullptr, falling back to NCCL/RCCL). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../allreduce/allreduce_nvls_block_pipeline.cu | 2 +- .../allreduce/allreduce_nvls_warp_pipeline.cu | 2 +- .../allreduce/allreduce_nvls_zero_copy.cu | 2 +- src/ext/collectives/collective_utils.cc | 9 +-------- src/ext/collectives/include/collective_utils.hpp | 15 ++++----------- 5 files changed, 8 insertions(+), 22 deletions(-) diff --git a/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu b/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu index 8c4a1e23..f5c0d2f8 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu @@ -178,7 +178,7 @@ struct NvlsBlockPipelineAdapter { void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr comm) { nSwitchChannels_ = 8; - ipcDomainNranks_ = validateIpcDomainSpansWorld(comm); + ipcDomainNranks_ = getIpcDomainNranks(comm); // Block-pipeline device-side semaphore indices grow as 6 * ipcDomainNranks (see kernel). if (6 * ipcDomainNranks_ > NUM_SEMAPHORES) { throw Error("AllreduceNvlsBlockPipeline: ipcDomainNranks " + std::to_string(ipcDomainNranks_) + diff --git a/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu b/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu index 950c287b..02b899aa 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu @@ -142,7 +142,7 @@ struct NvlsWarpPipelineAdapter { void AllreduceNvlsWarpPipeline::initialize(std::shared_ptr comm) { nSwitchChannels_ = NUM_NVLS_CONNECTION; - ipcDomainNranks_ = validateIpcDomainSpansWorld(comm); + ipcDomainNranks_ = getIpcDomainNranks(comm); // The warp-pipeline kernel addresses 2 * nPeers entries per block in `memoryChannels`, // so per-peer base channel allocation must be at least `2 * nBlocks`. Default // nBlocks = 4 * ipcDomainNranks (see allreduceKernelFunc), so size accordingly. diff --git a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu index 6ab0cd63..115a229a 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu @@ -105,7 +105,7 @@ void AllreduceNvls::initialize(std::shared_ptr comm) { MSCCLPP_CUDATHROW(cudaGetDeviceProperties(&deviceProp, device)); computeCapabilityMajor_ = deviceProp.major; nSwitchChannels_ = 32; - validateIpcDomainSpansWorld(comm); + getIpcDomainNranks(comm); this->conns_ = setupConnections(comm); // setup semaphores std::vector> memorySemaphores = diff --git a/src/ext/collectives/collective_utils.cc b/src/ext/collectives/collective_utils.cc index e4eb7142..6acfd7ce 100644 --- a/src/ext/collectives/collective_utils.cc +++ b/src/ext/collectives/collective_utils.cc @@ -75,14 +75,7 @@ std::vector> setupMemoryS int getIpcDomainNranks(std::shared_ptr comm) { const int envValue = mscclpp::env()->ipcDomainNranks; - if (envValue > 0) { - return envValue; - } - return comm->bootstrap()->getNranksPerNode(); -} - -int validateIpcDomainSpansWorld(std::shared_ptr comm) { - const int ipcDomainNranks = getIpcDomainNranks(comm); + const int ipcDomainNranks = (envValue > 0) ? envValue : comm->bootstrap()->getNranksPerNode(); const int worldSize = comm->bootstrap()->getNranks(); const int rank = comm->bootstrap()->getRank(); if (ipcDomainNranks < 2 || ipcDomainNranks > MAX_IPC_DOMAIN_NRANKS) { diff --git a/src/ext/collectives/include/collective_utils.hpp b/src/ext/collectives/include/collective_utils.hpp index 6b0c6ab4..280a6332 100644 --- a/src/ext/collectives/include/collective_utils.hpp +++ b/src/ext/collectives/include/collective_utils.hpp @@ -51,18 +51,11 @@ std::vector setupConnections(std::shared_ptr comm); std::vector> setupMemorySemaphores( std::shared_ptr comm, const std::vector& connections, int nChannelsPerConnection); -/// Number of ranks that participate in the same GPU-IPC-reachable peer group (e.g. a single host or -/// a Multi-Node NVLink fabric, or an AMD XGMI domain). Returns the value of `MSCCLPP_IPC_DOMAIN_NRANKS` -/// if set to a positive value; otherwise falls back to `bootstrap->getNranksPerNode()`. This is -/// intentionally independent of `Bootstrap::getNranksPerNode()` so that algorithms can opt in to -/// MNNVL-like behavior without changing the meaning of bootstrap-level APIs. -int getIpcDomainNranks(std::shared_ptr comm); - -/// Validates that the IPC domain spans the whole communicator and that the local rank fits within -/// the supported `[2, MAX_IPC_DOMAIN_NRANKS]` range. Used by NVLS allreduce algorithms whose -/// multicast group spans the whole communicator. Returns the validated `ipcDomainNranks`; throws +/// Returns the IPC-reachable peer-group size, validated to span the whole communicator and +/// to be within `[2, MAX_IPC_DOMAIN_NRANKS]`. Reads `MSCCLPP_IPC_DOMAIN_NRANKS` if set to a +/// positive value; otherwise falls back to `bootstrap->getNranksPerNode()`. Throws /// `Error(InvalidUsage)` on violation. -int validateIpcDomainSpansWorld(std::shared_ptr comm); +int getIpcDomainNranks(std::shared_ptr comm); std::shared_ptr> setupMemoryChannelDeviceHandles( const std::vector& memoryChannels);