temp solution

This commit is contained in:
Binyang Li
2026-05-15 23:15:40 +00:00
parent dbebde2b58
commit 93b43547cc
17 changed files with 23 additions and 50 deletions

View File

@@ -97,8 +97,6 @@ class CustomizedComm:
self.rank = comm.my_rank
self.world_size = comm.nranks
self.nranks_per_node = comm.nranks_per_node
if comm.communicator.get_ipc_domain_n_ranks() == 0 and self.world_size > 1:
comm.communicator.set_ipc_domain_n_ranks(self.world_size)
self.ipc_domain_n_ranks = comm.communicator.get_ipc_domain_n_ranks()
self.multi_host_mnnvl = self.ipc_domain_n_ranks >= self.world_size and self.world_size > self.nranks_per_node
self.symmetric_memory = symmetric_memory
@@ -433,8 +431,8 @@ def benchmark_allgather(comm: CustomizedComm, dtype=torch.float16, n_warmup=10,
# -- Bootstrap & main ---------------------------------------------------------
def init_dist() -> mscclpp.CommGroup:
    """Build an ``mscclpp.CommGroup`` on top of ``MPI.COMM_WORLD``."""
    group = mscclpp.CommGroup(mpi_comm=MPI.COMM_WORLD)
    return group
def init_dist(ipc_domain_n_ranks: int = 0) -> mscclpp.CommGroup:
    """Build an ``mscclpp.CommGroup`` on top of ``MPI.COMM_WORLD``.

    Args:
        ipc_domain_n_ranks: Passed through unchanged to ``mscclpp.CommGroup``
            as its ``ipc_domain_n_ranks`` keyword argument. Defaults to 0.

    Returns:
        The newly constructed communication group.
    """
    group = mscclpp.CommGroup(
        mpi_comm=MPI.COMM_WORLD,
        ipc_domain_n_ranks=ipc_domain_n_ranks,
    )
    return group
def main():
@@ -447,8 +445,9 @@ def main():
accum_str = os.environ.get("ACCUM_DTYPE")
accum_dtype = accum_map.get(accum_str) if accum_str else None
symmetric_memory = os.environ.get("SYMMETRIC_MEMORY", "1") == "1"
ipc_domain_n_ranks = int(os.environ.get("IPC_DOMAIN_NRANKS", "0"))
comm_group = init_dist()
comm_group = init_dist(ipc_domain_n_ranks=ipc_domain_n_ranks)
cc = CustomizedComm(comm_group, symmetric_memory=symmetric_memory)
print(

View File

@@ -829,8 +829,8 @@ class Communicator {
/// @param ipcDomainNranks Number of ranks in the communicator's IPC domain, or 0 to use the default.
void setIpcDomainNranks(int ipcDomainNranks);
/// Get the IPC-domain rank count override for this communicator.
/// @return The configured IPC-domain rank count, or 0 if the communicator uses `bootstrap()->getNranksPerNode()`.
/// Get the effective IPC-domain rank count for this communicator.
/// @return The configured IPC-domain rank count, or `bootstrap()->getNranksPerNode()` if no override is set.
int getIpcDomainNranks() const;
/// Register a region of GPU memory for use in this communicator's context.

View File

@@ -3,8 +3,6 @@
#include "communicator.hpp"
#include <utility>
#include "api.h"
namespace mscclpp {
@@ -88,7 +86,9 @@ MSCCLPP_API_CPP void Communicator::setIpcDomainNranks(int ipcDomainNranks) {
pimpl_->ipcDomainNranks_ = ipcDomainNranks;
}
/// Return the IPC-domain rank count exactly as stored on the implementation object,
/// with no fallback applied.
MSCCLPP_API_CPP int Communicator::getIpcDomainNranks() const {
  const int storedNranks = pimpl_->ipcDomainNranks_;
  return storedNranks;
}
/// Return the effective IPC-domain rank count: the stored override when it is
/// positive, otherwise the bootstrap's ranks-per-node value.
MSCCLPP_API_CPP int Communicator::getIpcDomainNranks() const {
  const int overrideNranks = pimpl_->ipcDomainNranks_;
  if (overrideNranks > 0) {
    return overrideNranks;
  }
  // No positive override stored: fall back to the bootstrap's per-node rank count.
  return pimpl_->bootstrap_->getNranksPerNode();
}
MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports) {
return context()->registerMemory(ptr, size, transports);

View File

@@ -148,7 +148,7 @@ std::shared_ptr<void> AllgatherFullmesh::initAllgatherContext(std::shared_ptr<Co
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerNode();
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
// setup semaphores
ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection);

View File

@@ -159,7 +159,7 @@ std::shared_ptr<void> AllgatherFullmesh2::initAllgatherContext(std::shared_ptr<m
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerNode();
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
// setup semaphores
ctx->memorySemaphores = this->memorySemaphores_;

View File

@@ -140,7 +140,7 @@ std::shared_ptr<void> AllreduceAllpairPacket::initAllreduceContext(std::shared_p
const int nChannelsPerConnection = maxBlockNum_;
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->ipcDomainNranks = getIpcDomainNranks(comm);
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
ctx->memorySemaphores = this->memorySemaphores_;
ctx->registeredMemories = this->registeredMemories_;
ctx->registeredMemories.pop_back(); // remove the local memory from previous context

View File

@@ -250,7 +250,7 @@ std::shared_ptr<void> AllreduceFullmesh::initAllreduceContext(std::shared_ptr<Co
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerNode();
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
// setup semaphores
ctx->memorySemaphores = this->outputSemaphores_;

View File

@@ -177,7 +177,7 @@ struct NvlsBlockPipelineAdapter {
void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr<Communicator> comm) {
nSwitchChannels_ = 8;
ipcDomainNranks_ = getIpcDomainNranks(comm);
ipcDomainNranks_ = comm->getIpcDomainNranks();
// Per-peer channel allocation must hold up to 4 * ipcDomainNranks entries (see kernel).
nBaseChannels_ = std::max(64, 4 * ipcDomainNranks_);
this->conns_ = setupConnections(comm);
@@ -224,7 +224,7 @@ std::shared_ptr<void> AllreduceNvlsBlockPipeline::initAllreduceContext(std::shar
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->ipcDomainNranks = getIpcDomainNranks(comm);
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
// setup channels
ctx->switchChannels =

View File

@@ -95,7 +95,7 @@ std::shared_ptr<void> AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr<
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->ipcDomainNranks = getIpcDomainNranks(comm);
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
// setup channels
ctx->switchChannels = this->switchChannels_;

View File

@@ -141,7 +141,7 @@ struct NvlsWarpPipelineAdapter {
void AllreduceNvlsWarpPipeline::initialize(std::shared_ptr<Communicator> comm) {
nSwitchChannels_ = NUM_NVLS_CONNECTION;
ipcDomainNranks_ = getIpcDomainNranks(comm);
ipcDomainNranks_ = comm->getIpcDomainNranks();
// Per-peer channel allocation must hold 2 * nBlocks entries; default nBlocks = 4 * ipcDomainNranks.
nBaseChannels_ = std::max(64, 8 * ipcDomainNranks_);
this->conns_ = setupConnections(comm);
@@ -188,7 +188,7 @@ std::shared_ptr<void> AllreduceNvlsWarpPipeline::initAllreduceContext(std::share
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->ipcDomainNranks = getIpcDomainNranks(comm);
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
// setup channels
ctx->switchChannels =

View File

@@ -99,7 +99,6 @@ void AllreduceNvls::initialize(std::shared_ptr<mscclpp::Communicator> comm) {
MSCCLPP_CUDATHROW(cudaGetDeviceProperties(&deviceProp, device));
computeCapabilityMajor_ = deviceProp.major;
nSwitchChannels_ = 32;
getIpcDomainNranks(comm);
this->conns_ = setupConnections(comm);
// setup semaphores
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores =
@@ -177,7 +176,7 @@ std::shared_ptr<void> AllreduceNvls::initAllreduceContext(std::shared_ptr<mscclp
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->ipcDomainNranks = getIpcDomainNranks(comm);
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
size_t sendBytes, recvBytes;
CUdeviceptr sendBasePtr, recvBasePtr;

View File

@@ -263,7 +263,7 @@ std::shared_ptr<void> AllreducePacket::initAllreduceContext(std::shared_ptr<Comm
const int nChannelsPerConnection = maxBlockNum_;
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->ipcDomainNranks = getIpcDomainNranks(comm);
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
ctx->memorySemaphores = this->memorySemaphores_;
ctx->registeredMemories = this->registeredMemories_;
ctx->registeredMemories.pop_back(); // remove the local memory from previous context

View File

@@ -203,7 +203,7 @@ std::shared_ptr<void> AllreduceRsAg::initAllreduceContext(std::shared_ptr<Commun
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->ipcDomainNranks = getIpcDomainNranks(comm);
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
ctx->memorySemaphores = this->scratchSemaphores_;
ctx->registeredMemories = this->remoteScratchMemories_;

View File

@@ -306,7 +306,7 @@ std::shared_ptr<void> AllreduceRsAgPipeline::initAllreduceContext(std::shared_pt
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerNode();
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
ctx->memorySemaphores = this->scratchSemaphores_;
ctx->registeredMemories = this->remoteScratchMemories_;

View File

@@ -200,7 +200,7 @@ std::shared_ptr<void> AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_pt
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->ipcDomainNranks = getIpcDomainNranks(comm);
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
ctx->memorySemaphores = this->semaphores_;

View File

@@ -72,26 +72,6 @@ std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> setupMemoryS
return memorySemaphores;
}
/// Resolve the IPC-domain rank count for `comm` and validate it.
///
/// The value comes from `comm->getIpcDomainNranks()` when that is positive, and from
/// `comm->bootstrap()->getNranksPerNode()` otherwise. Checks run in this order, and the
/// first failure throws `Error` with `ErrorCode::InvalidUsage`:
///   1. the count lies within `[2, MAX_IPC_DOMAIN_NRANKS]`;
///   2. the count equals the bootstrap's total rank count;
///   3. the bootstrap's rank lies within `[0, count)`.
///
/// @param comm Communicator to query.
/// @return The validated IPC-domain rank count.
int getIpcDomainNranks(std::shared_ptr<Communicator> comm) {
  int ipcDomainNranks = comm->getIpcDomainNranks();
  if (ipcDomainNranks <= 0) {
    // No positive value reported by the communicator: use the per-node rank count.
    ipcDomainNranks = comm->bootstrap()->getNranksPerNode();
  }
  const int worldSize = comm->bootstrap()->getNranks();
  const int rank = comm->bootstrap()->getRank();

  const bool withinSupportedRange = (ipcDomainNranks >= 2) && (ipcDomainNranks <= MAX_IPC_DOMAIN_NRANKS);
  if (!withinSupportedRange) {
    THROW(LogSubsys::ALGO, Error, ErrorCode::InvalidUsage, "ipcDomainNranks ", ipcDomainNranks,
          " is out of supported range [2, ", MAX_IPC_DOMAIN_NRANKS, "]");
  }

  const bool spansWholeWorld = (worldSize == ipcDomainNranks);
  if (!spansWholeWorld) {
    THROW(LogSubsys::ALGO, Error, ErrorCode::InvalidUsage,
          "requires worldSize == ipcDomainNranks (got worldSize=", worldSize, ", ipcDomainNranks=", ipcDomainNranks,
          ")");
  }

  const bool rankInDomain = (rank >= 0) && (rank < ipcDomainNranks);
  if (!rankInDomain) {
    THROW(LogSubsys::ALGO, Error, ErrorCode::InvalidUsage, "rank ", rank, " out of [0, ", ipcDomainNranks, ")");
  }

  return ipcDomainNranks;
}
std::shared_ptr<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> setupMemoryChannelDeviceHandles(
const std::vector<mscclpp::MemoryChannel>& memoryChannels) {
std::vector<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> memoryChannelDeviceHandles;

View File

@@ -51,11 +51,6 @@ std::vector<Connection> setupConnections(std::shared_ptr<Communicator> comm);
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> setupMemorySemaphores(
std::shared_ptr<Communicator> comm, const std::vector<Connection>& connections, int nChannelsPerConnection);
/// Returns the IPC-reachable peer-group size, validated to span the whole communicator and
/// to be within `[2, MAX_IPC_DOMAIN_NRANKS]`. Reads the communicator's IPC-domain override
/// if set; otherwise falls back to `bootstrap->getNranksPerNode()`. Throws `Error(InvalidUsage)` on violation.
int getIpcDomainNranks(std::shared_ptr<Communicator> comm);
std::shared_ptr<DeviceHandle<MemoryChannel>> setupMemoryChannelDeviceHandles(
const std::vector<MemoryChannel>& memoryChannels);