Drop non-MNNVL multi_node regime from torch-integration example

The example is now MNNVL-only: a run is either single-host (everything
fits in one node) or multi-host MNNVL (one cross-host NVLink domain).
Plain multi-node without MNNVL had its own algorithm branches that this
example will never exercise, so remove the multi_node flag and the
intermediate mnnvl_domain variable.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Binyang Li
2026-05-06 19:00:22 +00:00
parent 9aeeaf0f12
commit 905b23d9a8
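
For orientation before the diff: after this change, the regime question reduces to a single predicate over MSCCLPP_IPC_DOMAIN_NRANKS. Below is a minimal sketch of that predicate; the standalone detect_regime helper is illustrative and not part of the example, and where the example simply computes a boolean flag and assumes the non-MNNVL multi-node case never occurs, the sketch raises instead.

import os

def detect_regime(world_size: int, nranks_per_node: int) -> str:
    # MSCCLPP_IPC_DOMAIN_NRANKS reports how many ranks share one
    # NVLink/IPC domain; unset is read as 0 (no cross-host domain).
    nvlink_domain_nranks = int(os.environ.get("MSCCLPP_IPC_DOMAIN_NRANKS", "0"))
    if world_size <= nranks_per_node:
        return "single-host"        # everything fits in one node
    if nvlink_domain_nranks >= world_size:
        return "multi-host MNNVL"   # one cross-host NVLink domain
    # The regime this commit drops: plain multi-node without MNNVL.
    raise RuntimeError("non-MNNVL multi-node runs are not supported by this example")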


@@ -83,9 +83,7 @@ class CustomizedComm:
         self.world_size = comm.nranks
         self.nranks_per_node = comm.nranks_per_node
         nvlink_domain_nranks = int(os.environ.get("MSCCLPP_IPC_DOMAIN_NRANKS", "0"))
-        self.mnnvl_domain = self.world_size > 1 and nvlink_domain_nranks >= self.world_size
-        self.multi_node = self.world_size > self.nranks_per_node and not self.mnnvl_domain
-        self.multi_host_mnnvl = self.mnnvl_domain and self.world_size > self.nranks_per_node
+        self.multi_host_mnnvl = nvlink_domain_nranks >= self.world_size and self.world_size > self.nranks_per_node
         self.symmetric_memory = symmetric_memory
         self._nvls = mscclpp.is_nvls_supported()
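
The fold above is behavior-preserving: the old world_size > 1 guard inside mnnvl_domain is implied by world_size > nranks_per_node, so it can be dropped once only the multi-host case matters. A quick check with illustrative values (not taken from any real configuration):

# Two 8-GPU hosts whose 16 ranks form one MNNVL domain -> flag is True.
world_size, nranks_per_node, nvlink_domain_nranks = 16, 8, 16
assert nvlink_domain_nranks >= world_size and world_size > nranks_per_node

# A single 8-GPU host: the multi-host conjunct fails, so the flag is
# False even though the NVLink domain covers the whole job.
world_size, nranks_per_node, nvlink_domain_nranks = 8, 8, 8
assert not (nvlink_domain_nranks >= world_size and world_size > nranks_per_node)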
@@ -108,7 +106,7 @@ class CustomizedComm:
         pkt = self._algo("allreduce", "default_allreduce_nvls_packet")
         if self._nvls and pkt:
             return (pkt, 0, 0)
-        if self.multi_node or self.multi_host_mnnvl:
+        if self.multi_host_mnnvl:
             rsag = self._algo("allreduce", "default_allreduce_rsag")
             if rsag:
                 return (rsag, 0, 0)
@@ -194,18 +192,6 @@ class CustomizedComm:
             if a:
                 out.append(a)
             return out
-        if self.multi_node:
-            a = self._algo("allreduce", "default_allreduce_nvls_packet")
-            if self._nvls and a:
-                out.append(a)
-            a = self._algo("allreduce", "default_allreduce_packet")
-            if a:
-                out.append(a)
-            if size >= 512 << 10:
-                a = self._algo("allreduce", "default_allreduce_rsag")
-                if a:
-                    out.append(a)
-            return out
         if size <= 4 << 20:
             a = self._algo("allreduce", "default_allreduce_packet")
             if a:
@@ -230,7 +216,7 @@ class CustomizedComm:
         return out

     def _ag_candidates(self):
-        if self.multi_node or self.multi_host_mnnvl:
+        if self.multi_host_mnnvl:
             return []
         a = self._algo("allgather", "default_allgather_fullmesh2")
         return [a] if a else []
@@ -356,7 +342,7 @@ class CustomizedComm:
         )

     def all_gather(self, output_tensor, input_tensor, stream=None):
-        if self.multi_node or self.multi_host_mnnvl:
+        if self.multi_host_mnnvl:
             raise RuntimeError("all_gather in this example currently supports only single-node runs")
         sz = _round_pow2(input_tensor.nbytes)
         if sz not in self._tune_cache["allgather"]:
@@ -497,7 +483,7 @@ def main():
     cc.barrier()
     torch.cuda.synchronize()
-    if cc.multi_node or cc.multi_host_mnnvl:
+    if cc.multi_host_mnnvl:
         if cc.rank == 0:
             print("Skipping allgather benchmark on multi-node: this example's allgather path is single-node only.")
     else: