mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 08:50:21 +00:00
Fold validateIpcDomainSpansWorld into getIpcDomainNranks
getIpcDomainNranks now performs the range / world-size / rank checks itself and throws on violation, so the separate validateIpcDomainSpansWorld helper is unnecessary. Update the 3 NVLS callsites (block_pipeline, warp_pipeline, nvls_zero_copy) to call getIpcDomainNranks directly. The non-NVLS callers also pick up the strict validation, which is fine because they are only invoked in single-host or multi-host MNNVL scenarios where worldSize == ipcDomainNranks (the NCCL adapter's multi-node path returns nullptr, falling back to NCCL/RCCL). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -178,7 +178,7 @@ struct NvlsBlockPipelineAdapter {
|
||||
|
||||
void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr<Communicator> comm) {
|
||||
nSwitchChannels_ = 8;
|
||||
ipcDomainNranks_ = validateIpcDomainSpansWorld(comm);
|
||||
ipcDomainNranks_ = getIpcDomainNranks(comm);
|
||||
// Block-pipeline device-side semaphore indices grow as 6 * ipcDomainNranks (see kernel).
|
||||
if (6 * ipcDomainNranks_ > NUM_SEMAPHORES) {
|
||||
throw Error("AllreduceNvlsBlockPipeline: ipcDomainNranks " + std::to_string(ipcDomainNranks_) +
|
||||
|
||||
@@ -142,7 +142,7 @@ struct NvlsWarpPipelineAdapter {
|
||||
|
||||
void AllreduceNvlsWarpPipeline::initialize(std::shared_ptr<Communicator> comm) {
|
||||
nSwitchChannels_ = NUM_NVLS_CONNECTION;
|
||||
ipcDomainNranks_ = validateIpcDomainSpansWorld(comm);
|
||||
ipcDomainNranks_ = getIpcDomainNranks(comm);
|
||||
// The warp-pipeline kernel addresses 2 * nPeers entries per block in `memoryChannels`,
|
||||
// so per-peer base channel allocation must be at least `2 * nBlocks`. Default
|
||||
// nBlocks = 4 * ipcDomainNranks (see allreduceKernelFunc), so size accordingly.
|
||||
|
||||
@@ -105,7 +105,7 @@ void AllreduceNvls::initialize(std::shared_ptr<mscclpp::Communicator> comm) {
|
||||
MSCCLPP_CUDATHROW(cudaGetDeviceProperties(&deviceProp, device));
|
||||
computeCapabilityMajor_ = deviceProp.major;
|
||||
nSwitchChannels_ = 32;
|
||||
validateIpcDomainSpansWorld(comm);
|
||||
getIpcDomainNranks(comm);
|
||||
this->conns_ = setupConnections(comm);
|
||||
// setup semaphores
|
||||
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores =
|
||||
|
||||
@@ -75,14 +75,7 @@ std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> setupMemoryS
|
||||
|
||||
int getIpcDomainNranks(std::shared_ptr<mscclpp::Communicator> comm) {
|
||||
const int envValue = mscclpp::env()->ipcDomainNranks;
|
||||
if (envValue > 0) {
|
||||
return envValue;
|
||||
}
|
||||
return comm->bootstrap()->getNranksPerNode();
|
||||
}
|
||||
|
||||
int validateIpcDomainSpansWorld(std::shared_ptr<mscclpp::Communicator> comm) {
|
||||
const int ipcDomainNranks = getIpcDomainNranks(comm);
|
||||
const int ipcDomainNranks = (envValue > 0) ? envValue : comm->bootstrap()->getNranksPerNode();
|
||||
const int worldSize = comm->bootstrap()->getNranks();
|
||||
const int rank = comm->bootstrap()->getRank();
|
||||
if (ipcDomainNranks < 2 || ipcDomainNranks > MAX_IPC_DOMAIN_NRANKS) {
|
||||
|
||||
@@ -51,18 +51,11 @@ std::vector<Connection> setupConnections(std::shared_ptr<Communicator> comm);
|
||||
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> setupMemorySemaphores(
|
||||
std::shared_ptr<Communicator> comm, const std::vector<Connection>& connections, int nChannelsPerConnection);
|
||||
|
||||
/// Number of ranks that participate in the same GPU-IPC-reachable peer group (e.g. a single host or
|
||||
/// a Multi-Node NVLink fabric, or an AMD XGMI domain). Returns the value of `MSCCLPP_IPC_DOMAIN_NRANKS`
|
||||
/// if set to a positive value; otherwise falls back to `bootstrap->getNranksPerNode()`. This is
|
||||
/// intentionally independent of `Bootstrap::getNranksPerNode()` so that algorithms can opt in to
|
||||
/// MNNVL-like behavior without changing the meaning of bootstrap-level APIs.
|
||||
int getIpcDomainNranks(std::shared_ptr<Communicator> comm);
|
||||
|
||||
/// Validates that the IPC domain spans the whole communicator and that the local rank fits within
|
||||
/// the supported `[2, MAX_IPC_DOMAIN_NRANKS]` range. Used by NVLS allreduce algorithms whose
|
||||
/// multicast group spans the whole communicator. Returns the validated `ipcDomainNranks`; throws
|
||||
/// Returns the IPC-reachable peer-group size, validated to span the whole communicator and
|
||||
/// to be within `[2, MAX_IPC_DOMAIN_NRANKS]`. Reads `MSCCLPP_IPC_DOMAIN_NRANKS` if set to a
|
||||
/// positive value; otherwise falls back to `bootstrap->getNranksPerNode()`. Throws
|
||||
/// `Error(InvalidUsage)` on violation.
|
||||
int validateIpcDomainSpansWorld(std::shared_ptr<Communicator> comm);
|
||||
int getIpcDomainNranks(std::shared_ptr<Communicator> comm);
|
||||
|
||||
std::shared_ptr<DeviceHandle<MemoryChannel>> setupMemoryChannelDeviceHandles(
|
||||
const std::vector<MemoryChannel>& memoryChannels);
|
||||
|
||||
Reference in New Issue
Block a user