Simplify torch-integration tuning example

- Drop the multi_host_mnnvl-specific rsag fallback in _default_ar_config;
  fall through to default_allreduce_packet when NVLS is unavailable.
- Add SYMMETRIC_MEMORY env var so the tuning sweep can include the
  zero-copy NVLS / RSAG candidates without editing the source.
- Make _algo() raise on miss (direct dict lookup) and drop the
  defensive 'if a:' guards in _ar_candidates / _ag_candidates /
  _default_ar_config; merge existence checks into the platform
  conditions (self._nvls, self.symmetric_memory).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
commit 4a0d5b29d5
parent 905b23d9a8
Author: Binyang Li
Date:   2026-05-06 21:14:36 +00:00


@@ -99,17 +99,12 @@ class CustomizedComm:
         self._time_buf = None
 
     def _algo(self, collective: str, name: str):
-        return self._algos.get((collective, name))
+        return self._algos[(collective, name)]
 
     def _default_ar_config(self):
         """Fallback allreduce config for barrier / timing sync."""
-        pkt = self._algo("allreduce", "default_allreduce_nvls_packet")
-        if self._nvls and pkt:
-            return (pkt, 0, 0)
-        if self.multi_host_mnnvl:
-            rsag = self._algo("allreduce", "default_allreduce_rsag")
-            if rsag:
-                return (rsag, 0, 0)
+        if self._nvls:
+            return (self._algo("allreduce", "default_allreduce_nvls_packet"), 0, 0)
         return (self._algo("allreduce", "default_allreduce_packet"), 0, 0)
 
     # -- low-level execute --
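
For readers skimming the hunk: a minimal standalone sketch of what the direct lookup changes, assuming self._algos is a plain dict keyed by (collective, name) tuples (the attribute name is taken from the diff; the surrounding class is stubbed out here). A miss now fails fast with KeyError instead of returning None and surfacing later inside the executor.

# Standalone sketch; the dict contents are placeholders, not real algorithm handles.
_algos = {("allreduce", "default_allreduce_packet"): "packet-algo-handle"}

def _algo(collective: str, name: str):
    # Post-change behavior: an unregistered (collective, name) pair raises KeyError.
    return _algos[(collective, name)]

print(_algo("allreduce", "default_allreduce_packet"))       # -> "packet-algo-handle"
try:
    _algo("allreduce", "default_allreduce_nvls_packet")      # not registered in this sketch
except KeyError as missing:
    print(f"unregistered algorithm: {missing}")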
@@ -157,7 +152,7 @@ class CustomizedComm:
 
     def _barrier_internal(self):
         a, nb, nt = self._default_ar_config()
-        self._exec_ar(self._barrier_tensor, a, nb, nt, sym=self.symmetric_memory)
+        self._exec_ar(self._barrier_tensor, a, nb, nt, sym=True)
 
     # -- lazy tuning --
@@ -173,53 +168,33 @@ class CustomizedComm:
         if self.multi_host_mnnvl:
             if size <= 4 << 20:
                 if size <= 128 << 10:
-                    a = self._algo("allreduce", "default_allreduce_allpair_packet")
-                    if a:
-                        out.append(a)
-                if size <= 64 << 10:
-                    a = self._algo("allreduce", "default_allreduce_nvls_packet")
-                    if self._nvls and a:
-                        out.append(a)
+                    out.append(self._algo("allreduce", "default_allreduce_allpair_packet"))
+                if size <= 64 << 10 and self._nvls:
+                    out.append(self._algo("allreduce", "default_allreduce_nvls_packet"))
                 if size > 128 << 10:
-                    a = self._algo("allreduce", "default_allreduce_packet")
-                    if a:
-                        out.append(a)
+                    out.append(self._algo("allreduce", "default_allreduce_packet"))
             if size >= 512 << 10:
-                a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
-                if self._nvls and self.symmetric_memory and a:
-                    out.append(a)
-                a = self._algo("allreduce", "default_allreduce_rsag")
-                if a:
-                    out.append(a)
+                if self._nvls and self.symmetric_memory:
+                    out.append(self._algo("allreduce", "default_allreduce_nvls_zero_copy"))
+                out.append(self._algo("allreduce", "default_allreduce_rsag"))
             return out
         if size <= 4 << 20:
-            a = self._algo("allreduce", "default_allreduce_packet")
-            if a:
-                out.append(a)
-            a = self._algo("allreduce", "default_allreduce_allpair_packet")
-            if a:
-                out.append(a)
-            a = self._algo("allreduce", "default_allreduce_nvls_packet")
-            if self._nvls and a:
-                out.append(a)
+            out.append(self._algo("allreduce", "default_allreduce_packet"))
+            out.append(self._algo("allreduce", "default_allreduce_allpair_packet"))
+            if self._nvls:
+                out.append(self._algo("allreduce", "default_allreduce_nvls_packet"))
         if size >= 512 << 10:
-            a = self._algo("allreduce", "default_allreduce_rsag_zero_copy")
-            if a:
-                out.append(a)
-            a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
-            if self._nvls and self.symmetric_memory and a:
-                out.append(a)
+            out.append(self._algo("allreduce", "default_allreduce_rsag_zero_copy"))
+            if self._nvls and self.symmetric_memory:
+                out.append(self._algo("allreduce", "default_allreduce_nvls_zero_copy"))
         if torch.version.hip is not None:
-            a = self._algo("allreduce", "default_allreduce_fullmesh")
-            if a:
-                out.append(a)
+            out.append(self._algo("allreduce", "default_allreduce_fullmesh"))
         return out
 
     def _ag_candidates(self):
         if self.multi_host_mnnvl:
             return []
-        a = self._algo("allgather", "default_allgather_fullmesh2")
-        return [a] if a else []
+        return [self._algo("allgather", "default_allgather_fullmesh2")]
 
     def _run_tune(self, collective, algo, buf, size, nb, nt):
         """Single tune invocation for either collective."""
@@ -474,11 +449,15 @@ def main():
     accum_map = {"float32": mscclpp.DataType.float32, "float16": mscclpp.DataType.float16}
     accum_str = os.environ.get("ACCUM_DTYPE")
     accum_dtype = accum_map.get(accum_str) if accum_str else None
+    symmetric_memory = os.environ.get("SYMMETRIC_MEMORY", "0") == "1"
 
     comm_group = init_dist()
-    cc = CustomizedComm(comm_group)
+    cc = CustomizedComm(comm_group, symmetric_memory=symmetric_memory)
 
-    print(f"rank {local} starting benchmarks with dtype={dtype} accum_dtype={accum_dtype}...")
+    print(
+        f"rank {local} starting benchmarks with dtype={dtype} "
+        f"accum_dtype={accum_dtype} symmetric_memory={symmetric_memory}..."
+    )
     benchmark_allreduce(cc, dtype=dtype, accum_dtype=accum_dtype)
     cc.barrier()
     torch.cuda.synchronize()
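
Finally, a hedged launch sketch showing how the new environment variable is meant to be used; the script name, process count, and use of torchrun are assumptions for illustration, and only the SYMMETRIC_MEMORY and ACCUM_DTYPE variables come from the example itself.

# Launch sketch: enable the zero-copy NVLS / RSAG candidates for the sweep
# without editing the example source. Script name and launcher flags are placeholders.
import os
import subprocess

env = dict(os.environ, SYMMETRIC_MEMORY="1", ACCUM_DTYPE="float16")
subprocess.run(
    ["torchrun", "--nproc-per-node", "8", "torch_integration_tuning_example.py"],
    env=env,
    check=True,
)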