Mirror of https://github.com/microsoft/mscclpp.git, synced 2026-05-11 17:00:22 +00:00
Run the tuning example with symmetric memory disabled, make allreduce tuning use the same symmetric-memory mode as execution, and narrow the MNNVL small-message candidate set to avoid slower packet/NVLS choices. Increase packet and RSAG channel parallelism so non-symmetric CUDA-IPC paths can use 112-block packet and 128-block RSAG configs.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
625 lines · 23 KiB · Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# torchrun --nnodes=1 --nproc_per_node=8 examples/torch-integration/customized_comm_with_tuning.py
# mpirun -np 2 --hostfile <hostfile> python3 examples/torch-integration/customized_comm_with_tuning.py

import gc
import fcntl
import ipaddress
import os
import socket
import struct
import sys
import traceback


def _get_bootstrap_world_size():
    for name in ("WORLD_SIZE", "OMPI_COMM_WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS"):
        value = os.environ.get(name)
        if value is not None:
            return int(value)
    return None


# Default to treating the whole launch as a single MNNVL domain unless the user has already set
# MSCCLPP_MNNVL_NRANKS_PER_NODE or disabled MNNVL explicitly.
_bootstrap_world_size = _get_bootstrap_world_size()
if (
    _bootstrap_world_size
    and _bootstrap_world_size > 1
    and "MSCCLPP_MNNVL_NRANKS_PER_NODE" not in os.environ
    and os.environ.get("MSCCLPP_ENABLE_MNNVL", "1") != "0"
):
    os.environ["MSCCLPP_MNNVL_NRANKS_PER_NODE"] = str(_bootstrap_world_size)

import torch
import mscclpp
import mscclpp.ext
import mscclpp.utils as mscclpp_utils

# -- Helpers ------------------------------------------------------------------


def _make_tensor(size_bytes: int, dtype: torch.dtype) -> torch.Tensor:
    """Allocate a tensor backed by RawGpuBuffer (symmetric memory)."""
    # PyTorch's from_dlpack does not support certain float8 DLPack type codes.
    # Work around by importing as uint8 and reinterpreting via .view().
    _DLPACK_UNSUPPORTED = (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)
    if dtype in _DLPACK_UNSUPPORTED:
        dlpack = mscclpp.RawGpuBuffer(size_bytes).to_dlpack(data_type=str(torch.uint8))
        return torch.utils.dlpack.from_dlpack(dlpack).view(dtype)
    dlpack = mscclpp.RawGpuBuffer(size_bytes).to_dlpack(data_type=str(dtype))
    return torch.utils.dlpack.from_dlpack(dlpack)
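
# Illustrative use of _make_tensor (a sketch; float8 handling depends on the local PyTorch build):
#   buf = _make_tensor(1 << 20, torch.float16)        # 1 MiB tensor backed by RawGpuBuffer
#   fp8 = _make_tensor(1 << 20, torch.float8_e4m3fn)  # imported as uint8, reinterpreted via .view()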


def _load_algorithms(scratch: torch.Tensor, rank: int):
    return mscclpp.ext.AlgorithmCollectionBuilder().build_default_algorithms(
        scratch_buffer=scratch.data_ptr(),
        scratch_buffer_size=scratch.nbytes,
        rank=rank,
    )


def _interfaces_for_ip(ip: str):
    target = ipaddress.ip_address(ip)
    for iface in os.listdir("/sys/class/net"):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
                req = struct.pack("256s", iface.encode()[:15])
                addr = socket.inet_ntoa(fcntl.ioctl(sock.fileno(), 0x8915, req)[20:24])
        except OSError:
            continue
        if ipaddress.ip_address(addr) == target:
            return iface
    return None


def _resolve_interface(master_addr: str):
    for env_name in ("MSCCLPP_INTERFACE", "MSCCLPP_SOCKET_IFNAME", "NCCL_SOCKET_IFNAME"):
        value = os.environ.get(env_name)
        if value:
            iface = value.split(",")[0].strip()
            if iface in os.listdir("/sys/class/net"):
                return iface
            raise ValueError(f"Interface {iface} from {env_name} does not exist")
    return _interfaces_for_ip(master_addr)
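
# Interface selection can be pinned via environment variables (the value below is illustrative):
#   export MSCCLPP_SOCKET_IFNAME=eth0   # the first comma-separated entry is used and must exist
# Otherwise _interfaces_for_ip() scans /sys/class/net for the NIC that owns MSCCLPP_MASTER_ADDR.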


def _get_env_int(*names: str, default=None):
    for name in names:
        value = os.environ.get(name)
        if value is not None:
            return int(value)
    return default


def _running_under_mpi() -> bool:
    return any(
        name in os.environ
        for name in ("OMPI_COMM_WORLD_RANK", "PMI_RANK", "PMIX_RANK", "MPI_LOCALRANKID", "SLURM_PROCID")
    )


def _to_mscclpp_op(op) -> mscclpp.ReduceOp:
    if op == torch.distributed.ReduceOp.SUM:
        return mscclpp.ReduceOp.SUM
    if op == torch.distributed.ReduceOp.MIN:
        return mscclpp.ReduceOp.MIN
    raise ValueError(f"unsupported op: {op}")


def _round_pow2(size: int) -> int:
    """Round up to the next power of two, clamped to [1 KiB, 256 MiB]."""
    size = max(size, 1024)
    size = min(size, 256 << 20)
    return 1 << (size - 1).bit_length()
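
# _round_pow2 examples (illustrative): 3 KiB -> 4 KiB, 500 B -> 1 KiB (lower clamp),
# 300 MiB -> 256 MiB (upper clamp). Tuned configurations are cached per rounded size.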


# -- CustomizedComm -----------------------------------------------------------


class CustomizedComm:
    """Exposes all_reduce, all_gather, barrier with lazy per-size tuning."""

    _TUNE_N_WARMUP = 5
    _TUNE_N_GRAPH_LAUNCHES = 10
    _TUNE_N_OPS_PER_GRAPH = 100
    _CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 64, 112, 128]
    _CANDIDATE_NTHREADS = [512, 768, 1024]
    _NBLOCKS_LIMIT = {
        "default_allreduce_nvls_packet": 16,
        "default_allreduce_nvls_zero_copy": 32,
        "default_allreduce_packet": 112,
        "default_allreduce_allpair_packet": 56,
        "default_allreduce_rsag": 128,
        "default_allreduce_rsag_zero_copy": 64,
        "default_allreduce_fullmesh": 64,
        "default_allgather_fullmesh2": 32,
    }

    def __init__(self, comm: mscclpp.CommGroup, symmetric_memory: bool = False):
        self.comm = comm
        self.rank = comm.my_rank
        self.world_size = comm.nranks
        self.nranks_per_node = comm.nranks_per_node
        self.mnnvl_domain = self.world_size > 1 and os.environ.get("MSCCLPP_MNNVL_NRANKS_PER_NODE") == str(
            self.world_size
        )
        self.multi_node = self.world_size > self.nranks_per_node and not self.mnnvl_domain
        self.multi_host_mnnvl = self.mnnvl_domain and self.world_size > 1
        self.symmetric_memory = symmetric_memory
        self._nvls = mscclpp.is_nvls_supported()

        self._scratch = _make_tensor(1 << 27, torch.float16)
        self._barrier_tensor = _make_tensor(4096, torch.float32)

        algos = _load_algorithms(self._scratch, self.rank)
        self._algos = {(a.collective, a.name): a for a in algos}

        # {collective: {rounded_size: (algo, nblocks, nthreads)}}
        self._tune_cache: dict[str, dict[int, tuple]] = {"allreduce": {}, "allgather": {}}
        self._tune_buf = None
        self._time_buf = None

    def _algo(self, collective: str, name: str):
        return self._algos.get((collective, name))

    def _nblocks_limit(self, algo_name: str, size: int) -> int:
        if algo_name == "default_allreduce_packet" and size < (1 << 20):
            return 56
        return self._NBLOCKS_LIMIT.get(algo_name, 128)

    def _default_ar_config(self):
        """Fallback allreduce config for barrier / timing sync."""
        pkt = self._algo("allreduce", "default_allreduce_nvls_packet")
        if self._nvls and pkt:
            return (pkt, 0, 0)
        if self.multi_node or self.multi_host_mnnvl:
            rsag = self._algo("allreduce", "default_allreduce_rsag")
            if rsag:
                return (rsag, 0, 0)
        return (self._algo("allreduce", "default_allreduce_packet"), 0, 0)

    # -- low-level execute --

    def _exec_ar(self, tensor, algo, nb, nt, op=mscclpp.ReduceOp.SUM, stream=None, accum_dtype=None, sym=True):
        s = stream.cuda_stream if stream else torch.cuda.current_stream().cuda_stream
        ret = algo.execute(
            comm=self.comm.communicator,
            input_buffer=tensor.data_ptr(),
            output_buffer=tensor.data_ptr(),
            input_size=tensor.nbytes,
            output_size=tensor.nbytes,
            dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype),
            op=op,
            stream=s,
            nblocks=nb,
            nthreads_per_block=nt,
            symmetric_memory=sym,
            accum_dtype=accum_dtype,
        )
        if ret != 0:
            print(f"Rank {self.rank}: {algo.name} failed ({ret})")
        return ret

    def _exec_ag(self, inp, out, algo, nb, nt, stream=None, sym=None):
        if sym is None:
            sym = self.symmetric_memory
        s = stream.cuda_stream if stream else torch.cuda.current_stream().cuda_stream
        ret = algo.execute(
            comm=self.comm.communicator,
            input_buffer=inp.data_ptr(),
            output_buffer=out.data_ptr(),
            input_size=inp.nbytes,
            output_size=out.nbytes,
            dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(inp.dtype),
            op=mscclpp.ReduceOp.NOP,
            stream=s,
            nblocks=nb,
            nthreads_per_block=nt,
            symmetric_memory=sym,
        )
        if ret != 0:
            print(f"Rank {self.rank}: AG {algo.name} failed ({ret})")
        return ret

    def _barrier_internal(self):
        a, nb, nt = self._default_ar_config()
        self._exec_ar(self._barrier_tensor, a, nb, nt, sym=self.symmetric_memory)

    # -- lazy tuning --

    def _ensure_tune_bufs(self):
        if self._tune_buf is None:
            self._tune_buf = _make_tensor(1 << 27, torch.float16)
            self._tune_buf.normal_()
            self._time_buf = _make_tensor(4096, torch.float32)
        return self._tune_buf

    def _ar_candidates(self, size: int):
        out = []
        if self.multi_host_mnnvl:
            if size <= 4 << 20:
                a = self._algo("allreduce", "default_allreduce_allpair_packet")
                if a:
                    out.append(a)
            if size <= 64 << 10:
                a = self._algo("allreduce", "default_allreduce_nvls_packet")
                if self._nvls and a:
                    out.append(a)
            if size > 128 << 10:
                a = self._algo("allreduce", "default_allreduce_packet")
                if a:
                    out.append(a)
            if size >= 512 << 10:
                a = self._algo("allreduce", "default_allreduce_rsag_zero_copy")
                if self.symmetric_memory and a:
                    out.append(a)
                a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
                if self._nvls and self.symmetric_memory and a:
                    out.append(a)
                a = self._algo("allreduce", "default_allreduce_rsag")
                if a:
                    out.append(a)
            return out
        if self.multi_node:
            a = self._algo("allreduce", "default_allreduce_nvls_packet")
            if self._nvls and a:
                out.append(a)
            a = self._algo("allreduce", "default_allreduce_packet")
            if a:
                out.append(a)
            if size >= 512 << 10:
                a = self._algo("allreduce", "default_allreduce_rsag")
                if a:
                    out.append(a)
            return out
        if size <= 4 << 20:
            a = self._algo("allreduce", "default_allreduce_packet")
            if a:
                out.append(a)
            a = self._algo("allreduce", "default_allreduce_allpair_packet")
            if a:
                out.append(a)
            a = self._algo("allreduce", "default_allreduce_nvls_packet")
            if self._nvls and a:
                out.append(a)
        if size >= 512 << 10:
            a = self._algo("allreduce", "default_allreduce_rsag_zero_copy")
            if a:
                out.append(a)
            a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
            if self._nvls and self.symmetric_memory and a:
                out.append(a)
        if torch.version.hip is not None:
            a = self._algo("allreduce", "default_allreduce_fullmesh")
            if a:
                out.append(a)
        return out

    def _ag_candidates(self):
        if self.multi_node or self.multi_host_mnnvl:
            return []
        a = self._algo("allgather", "default_allgather_fullmesh2")
        return [a] if a else []
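
    # Candidate selection in brief (a summary of the logic above): packet-style algorithms are
    # tried for small and medium messages, RSAG / zero-copy variants join the sweep from 512 KiB
    # up, and NVLS variants require NVLS support (the zero-copy ones also need symmetric memory).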

    def _run_tune(self, collective, algo, buf, size, nb, nt):
        """Single tune invocation for either collective."""
        if collective == "allreduce":
            return algo.execute(
                comm=self.comm.communicator,
                input_buffer=buf.data_ptr(),
                output_buffer=buf.data_ptr(),
                input_size=size,
                output_size=size,
                dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(buf.dtype),
                op=mscclpp.ReduceOp.SUM,
                stream=torch.cuda.current_stream().cuda_stream,
                nblocks=nb,
                nthreads_per_block=nt,
                symmetric_memory=self.symmetric_memory,
            )
        else:
            total = size * self.world_size
            out_ptr = buf.data_ptr()
            return algo.execute(
                comm=self.comm.communicator,
                input_buffer=out_ptr + self.rank * size,
                output_buffer=out_ptr,
                input_size=size,
                output_size=total,
                dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(buf.dtype),
                op=mscclpp.ReduceOp.NOP,
                stream=torch.cuda.current_stream().cuda_stream,
                nblocks=nb,
                nthreads_per_block=nt,
                symmetric_memory=False,
            )

    def _tune_size(self, collective: str, target_size: int):
        """Auto-tune one (collective, target_size) pair and cache the result."""
        buf = self._ensure_tune_bufs()
        cands = self._ar_candidates(target_size) if collective == "allreduce" else self._ag_candidates()

        best_time, best_cfg = float("inf"), None
        used = set()
        run = lambda a, nb, nt: self._run_tune(collective, a, buf, target_size, nb, nt)

        for algo in cands:
            nb_limit = self._nblocks_limit(algo.name, target_size)
            for nb in self._CANDIDATE_NBLOCKS:
                if nb > nb_limit:
                    continue
                for nt in self._CANDIDATE_NTHREADS:
                    # Feasibility check: sync the result across ranks so all agree
                    ret = run(algo, nb, nt)
                    torch.cuda.synchronize()
                    self._time_buf[0] = float(ret)
                    self._exec_ar(self._time_buf[:1], *self._default_ar_config(), sym=self.symmetric_memory)
                    if self._time_buf[0].item() != 0:
                        continue
                    used.add(algo)

                    # Warmup
                    for _ in range(self._TUNE_N_WARMUP):
                        run(algo, nb, nt)

                    # CUDA-graph timed benchmark
                    cs = torch.cuda.Stream()
                    cs.wait_stream(torch.cuda.current_stream())
                    g = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(g, stream=cs):
                        for _ in range(self._TUNE_N_OPS_PER_GRAPH):
                            run(algo, nb, nt)

                    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
                    start.record(cs)
                    with torch.cuda.stream(cs):
                        for _ in range(self._TUNE_N_GRAPH_LAUNCHES):
                            g.replay()
                    end.record(cs)
                    end.synchronize()
                    elapsed = start.elapsed_time(end)

                    # Cross-rank timing sync
                    self._time_buf.fill_(elapsed)
                    torch.cuda.current_stream().wait_stream(cs)
                    self._exec_ar(self._time_buf, *self._default_ar_config(), sym=self.symmetric_memory)
                    avg = self._time_buf[self.rank].item() / self.world_size

                    if avg < best_time:
                        best_time, best_cfg = avg, (algo, nb, nt)

        if best_cfg:
            self._tune_cache[collective][target_size] = best_cfg
            if self.rank == 0:
                n = self._TUNE_N_GRAPH_LAUNCHES * self._TUNE_N_OPS_PER_GRAPH
                print(
                    f"[tune] {collective} size={target_size}: {best_cfg[0].name} "
                    f"nb={best_cfg[1]} nt={best_cfg[2]} time={best_time / n * 1000:.2f}us",
                    flush=True,
                )
        else:
            fb = (
                self._default_ar_config()
                if collective == "allreduce"
                else ((self._ag_candidates()[0], 32, 512) if self._ag_candidates() else None)
            )
            self._tune_cache[collective][target_size] = fb

        torch.cuda.synchronize()
        self._barrier_internal()
        for a in used:
            a.reset()

    # -- public API --

    def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, stream=None, accum_dtype=None):
        sz = _round_pow2(tensor.nbytes)
        if sz not in self._tune_cache["allreduce"]:
            self._tune_size("allreduce", sz)
        a, nb, nt = self._tune_cache["allreduce"][sz]
        self._exec_ar(
            tensor, a, nb, nt, op=_to_mscclpp_op(op), stream=stream, accum_dtype=accum_dtype, sym=self.symmetric_memory
        )

    def all_gather(self, output_tensor, input_tensor, stream=None):
        if self.multi_node or self.multi_host_mnnvl:
            raise RuntimeError("all_gather in this example currently supports only single-node runs")
        sz = _round_pow2(input_tensor.nbytes)
        if sz not in self._tune_cache["allgather"]:
            self._tune_size("allgather", sz)
        a, nb, nt = self._tune_cache["allgather"][sz]
        self._exec_ag(input_tensor, output_tensor, a, nb, nt, stream=stream, sym=self.symmetric_memory)

    def barrier(self):
        self._barrier_internal()

    def destroy(self):
        self._algos.clear()
        self._tune_cache = {"allreduce": {}, "allgather": {}}
        self._tune_buf = self._time_buf = self._barrier_tensor = self._scratch = self.comm = None
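
# Minimal usage sketch (assumes a multi-GPU launch as in main(); symmetric_memory=False matches
# how this example constructs the communicator):
#   comm_group = init_dist()
#   cc = CustomizedComm(comm_group, symmetric_memory=False)
#   x = _make_tensor(1 << 20, torch.float16)
#   cc.all_reduce(x)   # first call for this rounded size runs _tune_size(); later calls hit the cache
#   cc.barrier()
#   cc.destroy()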


# -- Benchmarks (standalone) --------------------------------------------------


def _bench_sizes(low=None, high=None):
    if low is None:
        low = _get_env_int("MSCCLPP_BENCH_LOW_SIZE", default=5 * 1024)
    if high is None:
        high = _get_env_int("MSCCLPP_BENCH_HIGH_SIZE", default=80 << 20)
    sizes, c = [], low
    while c <= high:
        sizes.append(c)
        c *= 2
    return sizes
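
# With the defaults (low = 5 KiB, high = 80 MiB, overridable via MSCCLPP_BENCH_LOW_SIZE and
# MSCCLPP_BENCH_HIGH_SIZE), _bench_sizes() yields sizes doubling from 5 KiB up to and including 80 MiB.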


def benchmark_allreduce(
    comm: CustomizedComm, dtype=torch.float16, accum_dtype=None, n_warmup=10, n_graph_launches=10, n_iter=100
):
    sizes = _bench_sizes()
    if comm.rank == 0:
        print(f"\n{'='*60}\nAllreduce Benchmark\n{'='*60}")
        print(f"{'Nelements':<18} {'Size(B)':<18} {'Time(us)':<18} {'AlgoBW(GB/s)':<18}")

    cs = torch.cuda.Stream()
    buf = _make_tensor(1 << 27, dtype)
    buf.normal_() if dtype in (torch.float16, torch.float32, torch.bfloat16) else buf.fill_(0)

    for size in sizes:
        nelems = size // buf.element_size()
        t = buf[: size // buf.element_size()]
        comm.all_reduce(t, accum_dtype=accum_dtype)
        torch.cuda.synchronize()

        cs.wait_stream(torch.cuda.current_stream())
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g, stream=cs):
            for _ in range(n_iter):
                comm.all_reduce(t, accum_dtype=accum_dtype)
        with torch.cuda.stream(cs):
            for _ in range(n_warmup):
                g.replay()
        comm.barrier()
        cs.synchronize()

        s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        s.record(cs)
        with torch.cuda.stream(cs):
            for _ in range(n_graph_launches):
                g.replay()
        e.record(cs)
        e.synchronize()

        ms = s.elapsed_time(e) / (n_graph_launches * n_iter)
        if comm.rank == 0:
            print(f"{nelems:<18} {size:<18} {ms*1000:<18.2f} {size/(ms*1e-3)/1e9:<18.2f}")


def benchmark_allgather(comm: CustomizedComm, dtype=torch.float16, n_warmup=10, n_graph_launches=10, n_iter=100):
    sizes = _bench_sizes()
    if comm.rank == 0:
        print(f"\n{'='*60}\nAllgather Benchmark\n{'='*60}")
        print(f"{'PerRank(B)':<18} {'Total(B)':<18} {'Time(us)':<18} {'AlgoBW(GB/s)':<18}")

    cs = torch.cuda.Stream()
    buf = _make_tensor(1 << 27, dtype)
    buf.normal_() if dtype in (torch.float16, torch.float32, torch.bfloat16) else buf.fill_(0)

    for prs in sizes:
        total = prs * comm.world_size
        if total > buf.nbytes:
            break
        nt = total // buf.element_size()
        npr = prs // buf.element_size()
        out = buf[:nt]
        inp = out[comm.rank * npr : (comm.rank + 1) * npr]

        comm.all_gather(out, inp)
        torch.cuda.synchronize()

        cs.wait_stream(torch.cuda.current_stream())
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g, stream=cs):
            for _ in range(n_iter):
                comm.all_gather(out, inp)
        with torch.cuda.stream(cs):
            for _ in range(n_warmup):
                g.replay()
        comm.barrier()
        cs.synchronize()

        s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        s.record(cs)
        with torch.cuda.stream(cs):
            for _ in range(n_graph_launches):
                g.replay()
        e.record(cs)
        e.synchronize()

        ms = s.elapsed_time(e) / (n_graph_launches * n_iter)
        if comm.rank == 0:
            print(f"{prs:<18} {total:<18} {ms*1000:<18.2f} {total/(ms*1e-3)/1e9:<18.2f}")


# -- Bootstrap & main ---------------------------------------------------------


def init_dist() -> mscclpp.CommGroup:
    addr = os.environ.get("MSCCLPP_MASTER_ADDR")
    rank = _get_env_int("RANK", "OMPI_COMM_WORLD_RANK", "PMI_RANK", "SLURM_PROCID")
    world = _get_env_int("WORLD_SIZE", "OMPI_COMM_WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS")
    if addr and rank is not None and world is not None:
        port = os.environ.get("MSCCLPP_MASTER_PORT", "29500")
        iface = _resolve_interface(addr)
        if not iface:
            raise ValueError(f"No interface for {addr}")
        return mscclpp.CommGroup(interfaceIpPortTrio=f"{iface}:{addr}:{port}", rank=rank, size=world)
    if _running_under_mpi():
        try:
            from mpi4py import MPI
        except ModuleNotFoundError as exc:
            raise RuntimeError("mpi4py is required to launch this example with mpirun") from exc

        return mscclpp.CommGroup(mpi_comm=MPI.COMM_WORLD)
    import torch.distributed as dist

    dist.init_process_group(backend="gloo")
    return mscclpp.CommGroup(torch_group=dist.group.WORLD)
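
# Bootstrap paths handled by init_dist(), in priority order (the address below is illustrative):
#   1. Explicit TCP bootstrap: MSCCLPP_MASTER_ADDR=10.0.0.1 MSCCLPP_MASTER_PORT=29500 plus RANK/WORLD_SIZE
#   2. mpirun/srun launches detected via MPI/PMI environment variables (requires mpi4py)
#   3. torchrun, falling back to a gloo torch.distributed process group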


def main():
    local = _get_env_int("LOCAL_RANK", "OMPI_COMM_WORLD_LOCAL_RANK", "MPI_LOCALRANKID", "SLURM_LOCALID", default=0)
    torch.cuda.set_device(local)

    dtype_str = os.environ.get("DTYPE", "float16")
    dtype = getattr(torch, dtype_str, torch.float16)
    accum_map = {"float32": mscclpp.DataType.float32, "float16": mscclpp.DataType.float16}
    accum_str = os.environ.get("ACCUM_DTYPE")
    accum_dtype = accum_map.get(accum_str) if accum_str else None
    n_warmup = _get_env_int("MSCCLPP_BENCH_WARMUP", default=10)
    n_graph_launches = _get_env_int("MSCCLPP_BENCH_GRAPH_LAUNCHES", default=10)
    n_iter = _get_env_int("MSCCLPP_BENCH_ITERS", default=100)

    comm_group = init_dist()
    cc = CustomizedComm(comm_group, symmetric_memory=False)

    print(f"rank {local} starting benchmarks with dtype={dtype} accum_dtype={accum_dtype}...")
    benchmark_allreduce(
        cc,
        dtype=dtype,
        accum_dtype=accum_dtype,
        n_warmup=n_warmup,
        n_graph_launches=n_graph_launches,
        n_iter=n_iter,
    )
    cc.barrier()
    torch.cuda.synchronize()

    if cc.multi_node or cc.multi_host_mnnvl:
        if cc.rank == 0:
            print("Skipping allgather benchmark on multi-node: this example's allgather path is single-node only.")
    else:
        benchmark_allgather(cc, dtype=dtype, n_warmup=n_warmup, n_graph_launches=n_graph_launches, n_iter=n_iter)
        cc.barrier()
        torch.cuda.synchronize()

    cc.destroy()
    del cc
    del comm_group
    gc.collect()
    print(f"rank {local} completed successfully.")


if __name__ == "__main__":
    exit_code = 0
    try:
        main()
    except Exception:
        exit_code = 1
        traceback.print_exc()
    finally:
        sys.stdout.flush()
        sys.stderr.flush()
        os._exit(exit_code)