mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 08:50:21 +00:00
Simplify torch-integration tuning example
- Drop the multi_host_mnnvl-specific rsag fallback in `_default_ar_config`; fall through to `default_allreduce_packet` when NVLS is unavailable.
- Add a `SYMMETRIC_MEMORY` env var so the tuning sweep can include the zero-copy NVLS / RSAG candidates without editing the source.
- Make `_algo()` raise on a miss (direct dict lookup) and drop the defensive `if a:` guards in `_ar_candidates` / `_ag_candidates` / `_default_ar_config`; merge the existence checks into the platform conditions (`self._nvls`, `self.symmetric_memory`).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -99,17 +99,12 @@ class CustomizedComm:
|
||||
self._time_buf = None
|
||||
|
||||
def _algo(self, collective: str, name: str):
|
||||
return self._algos.get((collective, name))
|
||||
return self._algos[(collective, name)]
|
||||
|
||||
def _default_ar_config(self):
|
||||
"""Fallback allreduce config for barrier / timing sync."""
|
||||
pkt = self._algo("allreduce", "default_allreduce_nvls_packet")
|
||||
if self._nvls and pkt:
|
||||
return (pkt, 0, 0)
|
||||
if self.multi_host_mnnvl:
|
||||
rsag = self._algo("allreduce", "default_allreduce_rsag")
|
||||
if rsag:
|
||||
return (rsag, 0, 0)
|
||||
if self._nvls:
|
||||
return (self._algo("allreduce", "default_allreduce_nvls_packet"), 0, 0)
|
||||
return (self._algo("allreduce", "default_allreduce_packet"), 0, 0)
|
||||
|
||||
# -- low-level execute --
|
||||
@@ -157,7 +152,7 @@ class CustomizedComm:
|
||||
|
||||
def _barrier_internal(self):
|
||||
a, nb, nt = self._default_ar_config()
|
||||
self._exec_ar(self._barrier_tensor, a, nb, nt, sym=self.symmetric_memory)
|
||||
self._exec_ar(self._barrier_tensor, a, nb, nt, sym=True)
|
||||
|
||||
# -- lazy tuning --
|
||||
|
||||
@@ -173,53 +168,33 @@ class CustomizedComm:
|
||||
if self.multi_host_mnnvl:
|
||||
if size <= 4 << 20:
|
||||
if size <= 128 << 10:
|
||||
a = self._algo("allreduce", "default_allreduce_allpair_packet")
|
||||
if a:
|
||||
out.append(a)
|
||||
if size <= 64 << 10:
|
||||
a = self._algo("allreduce", "default_allreduce_nvls_packet")
|
||||
if self._nvls and a:
|
||||
out.append(a)
|
||||
out.append(self._algo("allreduce", "default_allreduce_allpair_packet"))
|
||||
if size <= 64 << 10 and self._nvls:
|
||||
out.append(self._algo("allreduce", "default_allreduce_nvls_packet"))
|
||||
if size > 128 << 10:
|
||||
a = self._algo("allreduce", "default_allreduce_packet")
|
||||
if a:
|
||||
out.append(a)
|
||||
out.append(self._algo("allreduce", "default_allreduce_packet"))
|
||||
if size >= 512 << 10:
|
||||
a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
|
||||
if self._nvls and self.symmetric_memory and a:
|
||||
out.append(a)
|
||||
a = self._algo("allreduce", "default_allreduce_rsag")
|
||||
if a:
|
||||
out.append(a)
|
||||
if self._nvls and self.symmetric_memory:
|
||||
out.append(self._algo("allreduce", "default_allreduce_nvls_zero_copy"))
|
||||
out.append(self._algo("allreduce", "default_allreduce_rsag"))
|
||||
return out
|
||||
if size <= 4 << 20:
|
||||
a = self._algo("allreduce", "default_allreduce_packet")
|
||||
if a:
|
||||
out.append(a)
|
||||
a = self._algo("allreduce", "default_allreduce_allpair_packet")
|
||||
if a:
|
||||
out.append(a)
|
||||
a = self._algo("allreduce", "default_allreduce_nvls_packet")
|
||||
if self._nvls and a:
|
||||
out.append(a)
|
||||
out.append(self._algo("allreduce", "default_allreduce_packet"))
|
||||
out.append(self._algo("allreduce", "default_allreduce_allpair_packet"))
|
||||
if self._nvls:
|
||||
out.append(self._algo("allreduce", "default_allreduce_nvls_packet"))
|
||||
if size >= 512 << 10:
|
||||
a = self._algo("allreduce", "default_allreduce_rsag_zero_copy")
|
||||
if a:
|
||||
out.append(a)
|
||||
a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
|
||||
if self._nvls and self.symmetric_memory and a:
|
||||
out.append(a)
|
||||
out.append(self._algo("allreduce", "default_allreduce_rsag_zero_copy"))
|
||||
if self._nvls and self.symmetric_memory:
|
||||
out.append(self._algo("allreduce", "default_allreduce_nvls_zero_copy"))
|
||||
if torch.version.hip is not None:
|
||||
a = self._algo("allreduce", "default_allreduce_fullmesh")
|
||||
if a:
|
||||
out.append(a)
|
||||
out.append(self._algo("allreduce", "default_allreduce_fullmesh"))
|
||||
return out
|
||||
|
||||
def _ag_candidates(self):
|
||||
if self.multi_host_mnnvl:
|
||||
return []
|
||||
a = self._algo("allgather", "default_allgather_fullmesh2")
|
||||
return [a] if a else []
|
||||
return [self._algo("allgather", "default_allgather_fullmesh2")]
|
||||
|
||||
def _run_tune(self, collective, algo, buf, size, nb, nt):
|
||||
"""Single tune invocation for either collective."""
|
||||
@@ -474,11 +449,15 @@ def main():
|
||||
accum_map = {"float32": mscclpp.DataType.float32, "float16": mscclpp.DataType.float16}
|
||||
accum_str = os.environ.get("ACCUM_DTYPE")
|
||||
accum_dtype = accum_map.get(accum_str) if accum_str else None
|
||||
symmetric_memory = os.environ.get("SYMMETRIC_MEMORY", "0") == "1"
|
||||
|
||||
comm_group = init_dist()
|
||||
cc = CustomizedComm(comm_group)
|
||||
cc = CustomizedComm(comm_group, symmetric_memory=symmetric_memory)
|
||||
|
||||
print(f"rank {local} starting benchmarks with dtype={dtype} accum_dtype={accum_dtype}...")
|
||||
print(
|
||||
f"rank {local} starting benchmarks with dtype={dtype} "
|
||||
f"accum_dtype={accum_dtype} symmetric_memory={symmetric_memory}..."
|
||||
)
|
||||
benchmark_allreduce(cc, dtype=dtype, accum_dtype=accum_dtype)
|
||||
cc.barrier()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
Reference in New Issue
Block a user