From 4a0d5b29d509b00268a64f6e0a5b4db602e8cb46 Mon Sep 17 00:00:00 2001
From: Binyang Li
Date: Wed, 6 May 2026 21:14:36 +0000
Subject: [PATCH] Simplify torch-integration tuning example

- Drop the multi_host_mnnvl-specific rsag fallback in _default_ar_config;
  fall through to default_allreduce_packet when NVLS is unavailable.
- Always run the internal barrier allreduce with sym=True instead of gating
  it on self.symmetric_memory.
- Add SYMMETRIC_MEMORY env var so the tuning sweep can include the zero-copy
  NVLS / RSAG candidates without editing the source.
- Make _algo() raise on miss (direct dict lookup) and drop the defensive
  'if a:' guards in _ar_candidates / _ag_candidates / _default_ar_config;
  merge existence checks into the platform conditions (self._nvls,
  self.symmetric_memory).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
---
 .../customized_comm_with_tuning.py | 73 +++++++------------
 1 file changed, 26 insertions(+), 47 deletions(-)

diff --git a/examples/torch-integration/customized_comm_with_tuning.py b/examples/torch-integration/customized_comm_with_tuning.py
index 6da9d713..18fdd6f1 100644
--- a/examples/torch-integration/customized_comm_with_tuning.py
+++ b/examples/torch-integration/customized_comm_with_tuning.py
@@ -99,17 +99,12 @@ class CustomizedComm:
         self._time_buf = None
 
     def _algo(self, collective: str, name: str):
-        return self._algos.get((collective, name))
+        return self._algos[(collective, name)]
 
     def _default_ar_config(self):
         """Fallback allreduce config for barrier / timing sync."""
-        pkt = self._algo("allreduce", "default_allreduce_nvls_packet")
-        if self._nvls and pkt:
-            return (pkt, 0, 0)
-        if self.multi_host_mnnvl:
-            rsag = self._algo("allreduce", "default_allreduce_rsag")
-            if rsag:
-                return (rsag, 0, 0)
+        if self._nvls:
+            return (self._algo("allreduce", "default_allreduce_nvls_packet"), 0, 0)
         return (self._algo("allreduce", "default_allreduce_packet"), 0, 0)
 
     # -- low-level execute --
@@ -157,7 +152,7 @@
 
     def _barrier_internal(self):
         a, nb, nt = self._default_ar_config()
-        self._exec_ar(self._barrier_tensor, a, nb, nt, sym=self.symmetric_memory)
+        self._exec_ar(self._barrier_tensor, a, nb, nt, sym=True)
 
     # -- lazy tuning --
 
@@ -173,53 +168,33 @@
         if self.multi_host_mnnvl:
             if size <= 4 << 20:
                 if size <= 128 << 10:
-                    a = self._algo("allreduce", "default_allreduce_allpair_packet")
-                    if a:
-                        out.append(a)
-                    if size <= 64 << 10:
-                        a = self._algo("allreduce", "default_allreduce_nvls_packet")
-                        if self._nvls and a:
-                            out.append(a)
+                    out.append(self._algo("allreduce", "default_allreduce_allpair_packet"))
+                    if size <= 64 << 10 and self._nvls:
+                        out.append(self._algo("allreduce", "default_allreduce_nvls_packet"))
                 if size > 128 << 10:
-                    a = self._algo("allreduce", "default_allreduce_packet")
-                    if a:
-                        out.append(a)
+                    out.append(self._algo("allreduce", "default_allreduce_packet"))
             if size >= 512 << 10:
-                a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
-                if self._nvls and self.symmetric_memory and a:
-                    out.append(a)
-                a = self._algo("allreduce", "default_allreduce_rsag")
-                if a:
-                    out.append(a)
+                if self._nvls and self.symmetric_memory:
+                    out.append(self._algo("allreduce", "default_allreduce_nvls_zero_copy"))
+                out.append(self._algo("allreduce", "default_allreduce_rsag"))
             return out
         if size <= 4 << 20:
-            a = self._algo("allreduce", "default_allreduce_packet")
-            if a:
-                out.append(a)
-            a = self._algo("allreduce", "default_allreduce_allpair_packet")
-            if a:
-                out.append(a)
-            a = self._algo("allreduce", "default_allreduce_nvls_packet")
-            if self._nvls and a:
-                out.append(a)
+            out.append(self._algo("allreduce", "default_allreduce_packet"))
+            out.append(self._algo("allreduce", "default_allreduce_allpair_packet"))
+            if self._nvls:
+                out.append(self._algo("allreduce", "default_allreduce_nvls_packet"))
         if size >= 512 << 10:
-            a = self._algo("allreduce", "default_allreduce_rsag_zero_copy")
-            if a:
-                out.append(a)
-            a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
-            if self._nvls and self.symmetric_memory and a:
-                out.append(a)
+            out.append(self._algo("allreduce", "default_allreduce_rsag_zero_copy"))
+            if self._nvls and self.symmetric_memory:
+                out.append(self._algo("allreduce", "default_allreduce_nvls_zero_copy"))
         if torch.version.hip is not None:
-            a = self._algo("allreduce", "default_allreduce_fullmesh")
-            if a:
-                out.append(a)
+            out.append(self._algo("allreduce", "default_allreduce_fullmesh"))
         return out
 
     def _ag_candidates(self):
         if self.multi_host_mnnvl:
             return []
-        a = self._algo("allgather", "default_allgather_fullmesh2")
-        return [a] if a else []
+        return [self._algo("allgather", "default_allgather_fullmesh2")]
 
     def _run_tune(self, collective, algo, buf, size, nb, nt):
         """Single tune invocation for either collective."""
@@ -474,11 +449,15 @@ def main():
     accum_map = {"float32": mscclpp.DataType.float32, "float16": mscclpp.DataType.float16}
     accum_str = os.environ.get("ACCUM_DTYPE")
     accum_dtype = accum_map.get(accum_str) if accum_str else None
+    symmetric_memory = os.environ.get("SYMMETRIC_MEMORY", "0") == "1"
 
     comm_group = init_dist()
-    cc = CustomizedComm(comm_group)
+    cc = CustomizedComm(comm_group, symmetric_memory=symmetric_memory)
 
-    print(f"rank {local} starting benchmarks with dtype={dtype} accum_dtype={accum_dtype}...")
+    print(
+        f"rank {local} starting benchmarks with dtype={dtype} "
+        f"accum_dtype={accum_dtype} symmetric_memory={symmetric_memory}..."
+    )
     benchmark_allreduce(cc, dtype=dtype, accum_dtype=accum_dtype)
     cc.barrier()
     torch.cuda.synchronize()
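
Note for reviewers (not part of the patch): a minimal, standalone sketch of the
new toggle's parsing semantics, using only the standard library. The check
mirrors the SYMMETRIC_MEMORY line added to main(); no other names from the
example are assumed here.

    # SYMMETRIC_MEMORY must be exactly "1" to opt in; any other value,
    # including "true", keeps the zero-copy NVLS / RSAG candidates out of
    # the tuning sweep.
    import os

    for value, expected in [("1", True), ("0", False), ("true", False), (None, False)]:
        if value is None:
            os.environ.pop("SYMMETRIC_MEMORY", None)
        else:
            os.environ["SYMMETRIC_MEMORY"] = value
        assert (os.environ.get("SYMMETRIC_MEMORY", "0") == "1") is expected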