diff --git a/examples/torch-integration/customized_comm_with_tuning.py b/examples/torch-integration/customized_comm_with_tuning.py
index 8d1efd53..6da9d713 100644
--- a/examples/torch-integration/customized_comm_with_tuning.py
+++ b/examples/torch-integration/customized_comm_with_tuning.py
@@ -83,9 +83,7 @@ class CustomizedComm:
         self.world_size = comm.nranks
         self.nranks_per_node = comm.nranks_per_node
         nvlink_domain_nranks = int(os.environ.get("MSCCLPP_IPC_DOMAIN_NRANKS", "0"))
-        self.mnnvl_domain = self.world_size > 1 and nvlink_domain_nranks >= self.world_size
-        self.multi_node = self.world_size > self.nranks_per_node and not self.mnnvl_domain
-        self.multi_host_mnnvl = self.mnnvl_domain and self.world_size > self.nranks_per_node
+        self.multi_host_mnnvl = nvlink_domain_nranks >= self.world_size and self.world_size > self.nranks_per_node
         self.symmetric_memory = symmetric_memory
         self._nvls = mscclpp.is_nvls_supported()
 
@@ -108,7 +106,7 @@ class CustomizedComm:
         pkt = self._algo("allreduce", "default_allreduce_nvls_packet")
         if self._nvls and pkt:
             return (pkt, 0, 0)
-        if self.multi_node or self.multi_host_mnnvl:
+        if self.multi_host_mnnvl:
             rsag = self._algo("allreduce", "default_allreduce_rsag")
             if rsag:
                 return (rsag, 0, 0)
@@ -194,18 +192,6 @@ class CustomizedComm:
             if a:
                 out.append(a)
             return out
-        if self.multi_node:
-            a = self._algo("allreduce", "default_allreduce_nvls_packet")
-            if self._nvls and a:
-                out.append(a)
-            a = self._algo("allreduce", "default_allreduce_packet")
-            if a:
-                out.append(a)
-            if size >= 512 << 10:
-                a = self._algo("allreduce", "default_allreduce_rsag")
-                if a:
-                    out.append(a)
-            return out
         if size <= 4 << 20:
             a = self._algo("allreduce", "default_allreduce_packet")
             if a:
@@ -230,7 +216,7 @@ class CustomizedComm:
         return out
 
     def _ag_candidates(self):
-        if self.multi_node or self.multi_host_mnnvl:
+        if self.multi_host_mnnvl:
            return []
        a = self._algo("allgather", "default_allgather_fullmesh2")
        return [a] if a else []
@@ -356,7 +342,7 @@ class CustomizedComm:
         )
 
     def all_gather(self, output_tensor, input_tensor, stream=None):
-        if self.multi_node or self.multi_host_mnnvl:
+        if self.multi_host_mnnvl:
             raise RuntimeError("all_gather in this example currently supports only single-node runs")
         sz = _round_pow2(input_tensor.nbytes)
         if sz not in self._tune_cache["allgather"]:
@@ -497,7 +483,7 @@ def main():
     cc.barrier()
     torch.cuda.synchronize()
 
-    if cc.multi_node or cc.multi_host_mnnvl:
+    if cc.multi_host_mnnvl:
         if cc.rank == 0:
             print("Skipping allgather benchmark on multi-node: this example's allgather path is single-node only.")
         else:
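
Not part of the patch: a small standalone Python sketch of how the collapsed flag behaves, assuming MSCCLPP_IPC_DOMAIN_NRANKS reports the number of ranks reachable over the NVLink/IPC domain (the same variable the patch reads). The helper name and the sample topologies below are hypothetical and only illustrate the condition "nvlink_domain_nranks >= world_size and world_size > nranks_per_node" that replaces the old mnnvl_domain/multi_node pair.

# Illustrative only -- not part of customized_comm_with_tuning.py.
def multi_host_mnnvl(world_size: int, nranks_per_node: int, env: dict) -> bool:
    # Mirrors the patched expression: the whole job must fit in one NVLink/IPC
    # domain AND span more than one node.
    nvlink_domain_nranks = int(env.get("MSCCLPP_IPC_DOMAIN_NRANKS", "0"))
    return nvlink_domain_nranks >= world_size and world_size > nranks_per_node


if __name__ == "__main__":
    cases = [
        (8, 8, {}),                                    # single node -> False
        (16, 8, {}),                                   # two nodes, no NVLink domain info -> False
        (16, 8, {"MSCCLPP_IPC_DOMAIN_NRANKS": "16"}),  # two nodes, one NVLink domain -> True
        (16, 8, {"MSCCLPP_IPC_DOMAIN_NRANKS": "8"}),   # NVLink domain covers one node only -> False
    ]
    for world_size, nranks_per_node, env in cases:
        print(world_size, nranks_per_node, env, "->",
              multi_host_mnnvl(world_size, nranks_per_node, env))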