Mirror of https://github.com/microsoft/mscclpp.git (synced 2026-05-11 08:50:21 +00:00)
Drop non-MNNVL multi_node regime from torch-integration example
The example is now MNNVL-only: a run is either single-host (everything fits in one node) or multi-host MNNVL (one cross-host NVLink domain). Plain multi-node without MNNVL had its own algorithm branch that this example will never exercise, so remove the multi_node flag and the intermediate mnnvl_domain variable.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
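For orientation, a minimal sketch of the regime detection that survives this change. The name detect_regime and the concrete rank counts are illustrative and not part of the example; only the environment variable and the comparison mirror the line added in the diff below.

import os

def detect_regime(world_size: int, nranks_per_node: int) -> str:
    # Illustrative helper, not part of the example: classify a run the way
    # the simplified multi_host_mnnvl flag does after this commit.
    # MSCCLPP_IPC_DOMAIN_NRANKS is the same variable the example reads; it
    # reports how many ranks share one NVLink/IPC domain.
    nvlink_domain_nranks = int(os.environ.get("MSCCLPP_IPC_DOMAIN_NRANKS", "0"))
    if nvlink_domain_nranks >= world_size and world_size > nranks_per_node:
        return "multi-host MNNVL"
    return "single-host"

# Illustrative rank counts (not taken from the repository):
os.environ["MSCCLPP_IPC_DOMAIN_NRANKS"] = "16"
assert detect_regime(world_size=16, nranks_per_node=8) == "multi-host MNNVL"
assert detect_regime(world_size=8, nranks_per_node=8) == "single-host"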
@@ -83,9 +83,7 @@ class CustomizedComm:
         self.world_size = comm.nranks
         self.nranks_per_node = comm.nranks_per_node
         nvlink_domain_nranks = int(os.environ.get("MSCCLPP_IPC_DOMAIN_NRANKS", "0"))
-        self.mnnvl_domain = self.world_size > 1 and nvlink_domain_nranks >= self.world_size
-        self.multi_node = self.world_size > self.nranks_per_node and not self.mnnvl_domain
-        self.multi_host_mnnvl = self.mnnvl_domain and self.world_size > self.nranks_per_node
+        self.multi_host_mnnvl = nvlink_domain_nranks >= self.world_size and self.world_size > self.nranks_per_node
         self.symmetric_memory = symmetric_memory
         self._nvls = mscclpp.is_nvls_supported()

@@ -108,7 +106,7 @@ class CustomizedComm:
         pkt = self._algo("allreduce", "default_allreduce_nvls_packet")
         if self._nvls and pkt:
             return (pkt, 0, 0)
-        if self.multi_node or self.multi_host_mnnvl:
+        if self.multi_host_mnnvl:
             rsag = self._algo("allreduce", "default_allreduce_rsag")
             if rsag:
                 return (rsag, 0, 0)
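With the multi_node test gone, the selector above first tries the NVLS packet algorithm and then reaches the rsag algorithm (presumably a reduce-scatter/all-gather schedule) only on multi-host MNNVL runs; a plain multi-node run no longer has a branch of its own here.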
@@ -194,18 +192,6 @@ class CustomizedComm:
             if a:
                 out.append(a)
             return out
-        if self.multi_node:
-            a = self._algo("allreduce", "default_allreduce_nvls_packet")
-            if self._nvls and a:
-                out.append(a)
-            a = self._algo("allreduce", "default_allreduce_packet")
-            if a:
-                out.append(a)
-            if size >= 512 << 10:
-                a = self._algo("allreduce", "default_allreduce_rsag")
-                if a:
-                    out.append(a)
-            return out
         if size <= 4 << 20:
             a = self._algo("allreduce", "default_allreduce_packet")
             if a:
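For readers unused to shift-style size literals, the thresholds in this hunk are plain byte counts; a quick check (illustrative, not from the repository):

# 512 << 10 shifts 512 left by 10 bits, i.e. multiplies it by 1024.
assert 512 << 10 == 512 * 1024          # 512 KiB: lower bound of the rsag branch removed above
assert 4 << 20 == 4 * 1024 * 1024       # 4 MiB: upper bound for the packet algorithm on the surviving path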
@@ -230,7 +216,7 @@ class CustomizedComm:
         return out

     def _ag_candidates(self):
-        if self.multi_node or self.multi_host_mnnvl:
+        if self.multi_host_mnnvl:
             return []
         a = self._algo("allgather", "default_allgather_fullmesh2")
         return [a] if a else []
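This is the design choice the remaining hunks carry through: multi-host MNNVL runs get no allgather candidates at all, because this example only implements a single-node allgather path, while every other run gets the fullmesh2 algorithm. The next two hunks make the same decision explicit at the call sites: all_gather raises on multi-host MNNVL, and main() skips that benchmark.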
@@ -356,7 +342,7 @@ class CustomizedComm:
         )

     def all_gather(self, output_tensor, input_tensor, stream=None):
-        if self.multi_node or self.multi_host_mnnvl:
+        if self.multi_host_mnnvl:
             raise RuntimeError("all_gather in this example currently supports only single-node runs")
         sz = _round_pow2(input_tensor.nbytes)
         if sz not in self._tune_cache["allgather"]:
@@ -497,7 +483,7 @@ def main():
     cc.barrier()
     torch.cuda.synchronize()

-    if cc.multi_node or cc.multi_host_mnnvl:
+    if cc.multi_host_mnnvl:
         if cc.rank == 0:
             print("Skipping allgather benchmark on multi-node: this example's allgather path is single-node only.")
     else: