mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-25 23:34:49 +00:00
temp solution
This commit is contained in:
@@ -97,8 +97,6 @@ class CustomizedComm:
|
||||
self.rank = comm.my_rank
|
||||
self.world_size = comm.nranks
|
||||
self.nranks_per_node = comm.nranks_per_node
|
||||
if comm.communicator.get_ipc_domain_n_ranks() == 0 and self.world_size > 1:
|
||||
comm.communicator.set_ipc_domain_n_ranks(self.world_size)
|
||||
self.ipc_domain_n_ranks = comm.communicator.get_ipc_domain_n_ranks()
|
||||
self.multi_host_mnnvl = self.ipc_domain_n_ranks >= self.world_size and self.world_size > self.nranks_per_node
|
||||
self.symmetric_memory = symmetric_memory
|
||||
@@ -433,8 +431,8 @@ def benchmark_allgather(comm: CustomizedComm, dtype=torch.float16, n_warmup=10,
|
||||
# -- Bootstrap & main ---------------------------------------------------------
|
||||
|
||||
|
||||
def init_dist() -> mscclpp.CommGroup:
|
||||
return mscclpp.CommGroup(mpi_comm=MPI.COMM_WORLD)
|
||||
def init_dist(ipc_domain_n_ranks: int = 0) -> mscclpp.CommGroup:
|
||||
return mscclpp.CommGroup(mpi_comm=MPI.COMM_WORLD, ipc_domain_n_ranks=ipc_domain_n_ranks)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -447,8 +445,9 @@ def main():
|
||||
accum_str = os.environ.get("ACCUM_DTYPE")
|
||||
accum_dtype = accum_map.get(accum_str) if accum_str else None
|
||||
symmetric_memory = os.environ.get("SYMMETRIC_MEMORY", "1") == "1"
|
||||
ipc_domain_n_ranks = int(os.environ.get("IPC_DOMAIN_NRANKS", "0"))
|
||||
|
||||
comm_group = init_dist()
|
||||
comm_group = init_dist(ipc_domain_n_ranks=ipc_domain_n_ranks)
|
||||
cc = CustomizedComm(comm_group, symmetric_memory=symmetric_memory)
|
||||
|
||||
print(
|
||||
|
||||
@@ -829,8 +829,8 @@ class Communicator {
|
||||
/// @param ipcDomainNranks Number of ranks in the communicator's IPC domain, or 0 to use the default.
|
||||
void setIpcDomainNranks(int ipcDomainNranks);
|
||||
|
||||
/// Get the IPC-domain rank count override for this communicator.
|
||||
/// @return The configured IPC-domain rank count, or 0 if the communicator uses `bootstrap()->getNranksPerNode()`.
|
||||
/// Get the effective IPC-domain rank count for this communicator.
|
||||
/// @return The configured IPC-domain rank count, or `bootstrap()->getNranksPerNode()` if no override is set.
|
||||
int getIpcDomainNranks() const;
|
||||
|
||||
/// Register a region of GPU memory for use in this communicator's context.
|
||||
|
||||
@@ -3,8 +3,6 @@
|
||||
|
||||
#include "communicator.hpp"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "api.h"
|
||||
|
||||
namespace mscclpp {
|
||||
@@ -88,7 +86,9 @@ MSCCLPP_API_CPP void Communicator::setIpcDomainNranks(int ipcDomainNranks) {
|
||||
pimpl_->ipcDomainNranks_ = ipcDomainNranks;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int Communicator::getIpcDomainNranks() const { return pimpl_->ipcDomainNranks_; }
|
||||
MSCCLPP_API_CPP int Communicator::getIpcDomainNranks() const {
|
||||
return (pimpl_->ipcDomainNranks_ > 0) ? pimpl_->ipcDomainNranks_ : pimpl_->bootstrap_->getNranksPerNode();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports) {
|
||||
return context()->registerMemory(ptr, size, transports);
|
||||
|
||||
@@ -148,7 +148,7 @@ std::shared_ptr<void> AllgatherFullmesh::initAllgatherContext(std::shared_ptr<Co
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerNode();
|
||||
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
|
||||
|
||||
// setup semaphores
|
||||
ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection);
|
||||
|
||||
@@ -159,7 +159,7 @@ std::shared_ptr<void> AllgatherFullmesh2::initAllgatherContext(std::shared_ptr<m
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerNode();
|
||||
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
|
||||
|
||||
// setup semaphores
|
||||
ctx->memorySemaphores = this->memorySemaphores_;
|
||||
|
||||
@@ -140,7 +140,7 @@ std::shared_ptr<void> AllreduceAllpairPacket::initAllreduceContext(std::shared_p
|
||||
const int nChannelsPerConnection = maxBlockNum_;
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->ipcDomainNranks = getIpcDomainNranks(comm);
|
||||
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
|
||||
ctx->memorySemaphores = this->memorySemaphores_;
|
||||
ctx->registeredMemories = this->registeredMemories_;
|
||||
ctx->registeredMemories.pop_back(); // remove the local memory from previous context
|
||||
|
||||
@@ -250,7 +250,7 @@ std::shared_ptr<void> AllreduceFullmesh::initAllreduceContext(std::shared_ptr<Co
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerNode();
|
||||
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
|
||||
|
||||
// setup semaphores
|
||||
ctx->memorySemaphores = this->outputSemaphores_;
|
||||
|
||||
@@ -177,7 +177,7 @@ struct NvlsBlockPipelineAdapter {
|
||||
|
||||
void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr<Communicator> comm) {
|
||||
nSwitchChannels_ = 8;
|
||||
ipcDomainNranks_ = getIpcDomainNranks(comm);
|
||||
ipcDomainNranks_ = comm->getIpcDomainNranks();
|
||||
// Per-peer channel allocation must hold up to 4 * ipcDomainNranks entries (see kernel).
|
||||
nBaseChannels_ = std::max(64, 4 * ipcDomainNranks_);
|
||||
this->conns_ = setupConnections(comm);
|
||||
@@ -224,7 +224,7 @@ std::shared_ptr<void> AllreduceNvlsBlockPipeline::initAllreduceContext(std::shar
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->ipcDomainNranks = getIpcDomainNranks(comm);
|
||||
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
|
||||
|
||||
// setup channels
|
||||
ctx->switchChannels =
|
||||
|
||||
@@ -95,7 +95,7 @@ std::shared_ptr<void> AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr<
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->ipcDomainNranks = getIpcDomainNranks(comm);
|
||||
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
|
||||
|
||||
// setup channels
|
||||
ctx->switchChannels = this->switchChannels_;
|
||||
|
||||
@@ -141,7 +141,7 @@ struct NvlsWarpPipelineAdapter {
|
||||
|
||||
void AllreduceNvlsWarpPipeline::initialize(std::shared_ptr<Communicator> comm) {
|
||||
nSwitchChannels_ = NUM_NVLS_CONNECTION;
|
||||
ipcDomainNranks_ = getIpcDomainNranks(comm);
|
||||
ipcDomainNranks_ = comm->getIpcDomainNranks();
|
||||
// Per-peer channel allocation must hold 2 * nBlocks entries; default nBlocks = 4 * ipcDomainNranks.
|
||||
nBaseChannels_ = std::max(64, 8 * ipcDomainNranks_);
|
||||
this->conns_ = setupConnections(comm);
|
||||
@@ -188,7 +188,7 @@ std::shared_ptr<void> AllreduceNvlsWarpPipeline::initAllreduceContext(std::share
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->ipcDomainNranks = getIpcDomainNranks(comm);
|
||||
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
|
||||
|
||||
// setup channels
|
||||
ctx->switchChannels =
|
||||
|
||||
@@ -99,7 +99,6 @@ void AllreduceNvls::initialize(std::shared_ptr<mscclpp::Communicator> comm) {
|
||||
MSCCLPP_CUDATHROW(cudaGetDeviceProperties(&deviceProp, device));
|
||||
computeCapabilityMajor_ = deviceProp.major;
|
||||
nSwitchChannels_ = 32;
|
||||
getIpcDomainNranks(comm);
|
||||
this->conns_ = setupConnections(comm);
|
||||
// setup semaphores
|
||||
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores =
|
||||
@@ -177,7 +176,7 @@ std::shared_ptr<void> AllreduceNvls::initAllreduceContext(std::shared_ptr<mscclp
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->ipcDomainNranks = getIpcDomainNranks(comm);
|
||||
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
|
||||
|
||||
size_t sendBytes, recvBytes;
|
||||
CUdeviceptr sendBasePtr, recvBasePtr;
|
||||
|
||||
@@ -263,7 +263,7 @@ std::shared_ptr<void> AllreducePacket::initAllreduceContext(std::shared_ptr<Comm
|
||||
const int nChannelsPerConnection = maxBlockNum_;
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->ipcDomainNranks = getIpcDomainNranks(comm);
|
||||
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
|
||||
ctx->memorySemaphores = this->memorySemaphores_;
|
||||
ctx->registeredMemories = this->registeredMemories_;
|
||||
ctx->registeredMemories.pop_back(); // remove the local memory from previous context
|
||||
|
||||
@@ -203,7 +203,7 @@ std::shared_ptr<void> AllreduceRsAg::initAllreduceContext(std::shared_ptr<Commun
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->ipcDomainNranks = getIpcDomainNranks(comm);
|
||||
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
|
||||
|
||||
ctx->memorySemaphores = this->scratchSemaphores_;
|
||||
ctx->registeredMemories = this->remoteScratchMemories_;
|
||||
|
||||
@@ -306,7 +306,7 @@ std::shared_ptr<void> AllreduceRsAgPipeline::initAllreduceContext(std::shared_pt
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->ipcDomainNranks = comm->bootstrap()->getNranksPerNode();
|
||||
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
|
||||
|
||||
ctx->memorySemaphores = this->scratchSemaphores_;
|
||||
ctx->registeredMemories = this->remoteScratchMemories_;
|
||||
|
||||
@@ -200,7 +200,7 @@ std::shared_ptr<void> AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_pt
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->ipcDomainNranks = getIpcDomainNranks(comm);
|
||||
ctx->ipcDomainNranks = comm->getIpcDomainNranks();
|
||||
|
||||
ctx->memorySemaphores = this->semaphores_;
|
||||
|
||||
|
||||
@@ -72,26 +72,6 @@ std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> setupMemoryS
|
||||
return memorySemaphores;
|
||||
}
|
||||
|
||||
int getIpcDomainNranks(std::shared_ptr<Communicator> comm) {
|
||||
const int commValue = comm->getIpcDomainNranks();
|
||||
const int ipcDomainNranks = (commValue > 0) ? commValue : comm->bootstrap()->getNranksPerNode();
|
||||
const int worldSize = comm->bootstrap()->getNranks();
|
||||
const int rank = comm->bootstrap()->getRank();
|
||||
if (ipcDomainNranks < 2 || ipcDomainNranks > MAX_IPC_DOMAIN_NRANKS) {
|
||||
THROW(LogSubsys::ALGO, Error, ErrorCode::InvalidUsage, "ipcDomainNranks ", ipcDomainNranks,
|
||||
" is out of supported range [2, ", MAX_IPC_DOMAIN_NRANKS, "]");
|
||||
}
|
||||
if (worldSize != ipcDomainNranks) {
|
||||
THROW(LogSubsys::ALGO, Error, ErrorCode::InvalidUsage,
|
||||
"requires worldSize == ipcDomainNranks (got worldSize=", worldSize, ", ipcDomainNranks=", ipcDomainNranks,
|
||||
")");
|
||||
}
|
||||
if (rank < 0 || rank >= ipcDomainNranks) {
|
||||
THROW(LogSubsys::ALGO, Error, ErrorCode::InvalidUsage, "rank ", rank, " out of [0, ", ipcDomainNranks, ")");
|
||||
}
|
||||
return ipcDomainNranks;
|
||||
}
|
||||
|
||||
std::shared_ptr<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> setupMemoryChannelDeviceHandles(
|
||||
const std::vector<mscclpp::MemoryChannel>& memoryChannels) {
|
||||
std::vector<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> memoryChannelDeviceHandles;
|
||||
|
||||
@@ -51,11 +51,6 @@ std::vector<Connection> setupConnections(std::shared_ptr<Communicator> comm);
|
||||
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> setupMemorySemaphores(
|
||||
std::shared_ptr<Communicator> comm, const std::vector<Connection>& connections, int nChannelsPerConnection);
|
||||
|
||||
/// Returns the IPC-reachable peer-group size, validated to span the whole communicator and
|
||||
/// to be within `[2, MAX_IPC_DOMAIN_NRANKS]`. Reads the communicator's IPC-domain override
|
||||
/// if set; otherwise falls back to `bootstrap->getNranksPerNode()`. Throws `Error(InvalidUsage)` on violation.
|
||||
int getIpcDomainNranks(std::shared_ptr<Communicator> comm);
|
||||
|
||||
std::shared_ptr<DeviceHandle<MemoryChannel>> setupMemoryChannelDeviceHandles(
|
||||
const std::vector<MemoryChannel>& memoryChannels);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user