Files
mscclpp/examples/torch-integration/customized_comm_with_tuning.py
Binyang Li 865c2bc795 Optimize MNNVL allreduce without symmetric memory
Run the tuning example with symmetric memory disabled, make allreduce tuning use the same symmetric-memory mode as execution, and narrow the MNNVL small-message candidate set to avoid slower packet/NVLS choices. Increase packet and RSAG channel parallelism so non-symmetric CUDA-IPC paths can use 112-block packet and 128-block RSAG configs.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-04-28 07:55:52 +00:00

625 lines
23 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# torchrun --nnodes=1 --nproc_per_node=8 examples/torch-integration/customized_comm_with_tuning.py
# mpirun -np 2 --hostfile <hostfile> python3 examples/torch-integration/customized_comm_with_tuning.py
import gc
import fcntl
import ipaddress
import os
import socket
import struct
import sys
import traceback
def _get_bootstrap_world_size():
for name in ("WORLD_SIZE", "OMPI_COMM_WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS"):
value = os.environ.get(name)
if value is not None:
return int(value)
return None
_bootstrap_world_size = _get_bootstrap_world_size()
# Before importing mscclpp, default the MNNVL domain to the whole job unless
# the user configured it explicitly or disabled MNNVL altogether.
_mnnvl_enabled = os.environ.get("MSCCLPP_ENABLE_MNNVL", "1") != "0"
if (
    _bootstrap_world_size is not None
    and _bootstrap_world_size > 1
    and _mnnvl_enabled
    and "MSCCLPP_MNNVL_NRANKS_PER_NODE" not in os.environ
):
    os.environ["MSCCLPP_MNNVL_NRANKS_PER_NODE"] = str(_bootstrap_world_size)
import torch
import mscclpp
import mscclpp.ext
import mscclpp.utils as mscclpp_utils
# -- Helpers ------------------------------------------------------------------
def _make_tensor(size_bytes: int, dtype: torch.dtype) -> torch.Tensor:
    """Allocate a tensor backed by RawGpuBuffer (symmetric memory)."""
    # PyTorch's from_dlpack does not support certain float8 DLPack type codes.
    # Work around by importing as uint8 and reinterpreting via .view().
    fp8_dtypes = (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)
    raw = mscclpp.RawGpuBuffer(size_bytes)
    if dtype in fp8_dtypes:
        capsule = raw.to_dlpack(data_type=str(torch.uint8))
        return torch.utils.dlpack.from_dlpack(capsule).view(dtype)
    capsule = raw.to_dlpack(data_type=str(dtype))
    return torch.utils.dlpack.from_dlpack(capsule)
def _load_algorithms(scratch: torch.Tensor, rank: int):
    """Build the default mscclpp algorithm collection bound to *scratch*."""
    builder = mscclpp.ext.AlgorithmCollectionBuilder()
    return builder.build_default_algorithms(
        scratch_buffer=scratch.data_ptr(),
        scratch_buffer_size=scratch.nbytes,
        rank=rank,
    )
def _interfaces_for_ip(ip: str):
target = ipaddress.ip_address(ip)
for iface in os.listdir("/sys/class/net"):
try:
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
req = struct.pack("256s", iface.encode()[:15])
addr = socket.inet_ntoa(fcntl.ioctl(sock.fileno(), 0x8915, req)[20:24])
except OSError:
continue
if ipaddress.ip_address(addr) == target:
return iface
return None
def _resolve_interface(master_addr: str):
for env_name in ("MSCCLPP_INTERFACE", "MSCCLPP_SOCKET_IFNAME", "NCCL_SOCKET_IFNAME"):
value = os.environ.get(env_name)
if value:
iface = value.split(",")[0].strip()
if iface in os.listdir("/sys/class/net"):
return iface
raise ValueError(f"Interface {iface} from {env_name} does not exist")
return _interfaces_for_ip(master_addr)
def _get_env_int(*names: str, default=None):
for name in names:
value = os.environ.get(name)
if value is not None:
return int(value)
return default
def _running_under_mpi() -> bool:
return any(
name in os.environ
for name in ("OMPI_COMM_WORLD_RANK", "PMI_RANK", "PMIX_RANK", "MPI_LOCALRANKID", "SLURM_PROCID")
)
def _to_mscclpp_op(op) -> mscclpp.ReduceOp:
    """Translate a torch.distributed reduce op to its mscclpp counterpart.

    Only SUM and MIN are supported; anything else raises ValueError.
    """
    if op == torch.distributed.ReduceOp.SUM:
        result = mscclpp.ReduceOp.SUM
    elif op == torch.distributed.ReduceOp.MIN:
        result = mscclpp.ReduceOp.MIN
    else:
        raise ValueError(f"unsupported op: {op}")
    return result
def _round_pow2(size: int) -> int:
"""Round up to next power-of-2, clamped to [1024, 256 MB]."""
size = max(size, 1024)
size = min(size, 256 << 20)
return 1 << (size - 1).bit_length()
# -- CustomizedComm -----------------------------------------------------------
class CustomizedComm:
    """Exposes all_reduce, all_gather, barrier with lazy per-size tuning."""

    # Tuning benchmark shape: warmup launches, timed graph replays, and
    # collective invocations captured per CUDA graph.
    _TUNE_N_WARMUP = 5
    _TUNE_N_GRAPH_LAUNCHES = 10
    _TUNE_N_OPS_PER_GRAPH = 100
    # Launch-geometry candidates swept during tuning.
    _CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 64, 112, 128]
    _CANDIDATE_NTHREADS = [512, 768, 1024]
    # Per-algorithm upper bound on nblocks tried during the sweep.
    _NBLOCKS_LIMIT = {
        "default_allreduce_nvls_packet": 16,
        "default_allreduce_nvls_zero_copy": 32,
        "default_allreduce_packet": 112,
        "default_allreduce_allpair_packet": 56,
        "default_allreduce_rsag": 128,
        "default_allreduce_rsag_zero_copy": 64,
        "default_allreduce_fullmesh": 64,
        "default_allgather_fullmesh2": 32,
    }

    def __init__(self, comm: mscclpp.CommGroup, symmetric_memory: bool = False):
        """Wrap *comm*, detect the topology, and pre-load default algorithms."""
        self.comm = comm
        self.rank = comm.my_rank
        self.world_size = comm.nranks
        self.nranks_per_node = comm.nranks_per_node
        # MNNVL domain: the env var (defaulted at import time) spans the job.
        self.mnnvl_domain = self.world_size > 1 and os.environ.get("MSCCLPP_MNNVL_NRANKS_PER_NODE") == str(
            self.world_size
        )
        self.multi_node = self.world_size > self.nranks_per_node and not self.mnnvl_domain
        self.multi_host_mnnvl = self.mnnvl_domain and self.world_size > 1
        # Whether collectives execute (and tune) with symmetric memory.
        self.symmetric_memory = symmetric_memory
        self._nvls = mscclpp.is_nvls_supported()
        # 128 MiB scratch shared by all algorithms; small tensor for barriers.
        self._scratch = _make_tensor(1 << 27, torch.float16)
        self._barrier_tensor = _make_tensor(4096, torch.float32)
        algos = _load_algorithms(self._scratch, self.rank)
        self._algos = {(a.collective, a.name): a for a in algos}
        # {collective: {rounded_size: (algo, nblocks, nthreads)}}
        self._tune_cache: dict[str, dict[int, tuple]] = {"allreduce": {}, "allgather": {}}
        # Tuning buffers are allocated lazily on first use (_ensure_tune_bufs).
        self._tune_buf = None
        self._time_buf = None

    def _algo(self, collective: str, name: str):
        """Look up a loaded algorithm by (collective, name); None if absent."""
        return self._algos.get((collective, name))

    def _nblocks_limit(self, algo_name: str, size: int) -> int:
        """Max nblocks worth sweeping for *algo_name* at message *size*."""
        # Sub-1MB packet allreduce caps lower than the table value.
        if algo_name == "default_allreduce_packet" and size < (1 << 20):
            return 56
        return self._NBLOCKS_LIMIT.get(algo_name, 128)

    def _default_ar_config(self):
        """Fallback allreduce config for barrier / timing sync."""
        # (algo, 0, 0): nblocks/nthreads of 0 defer the geometry to the
        # library default — TODO confirm against mscclpp execute() docs.
        pkt = self._algo("allreduce", "default_allreduce_nvls_packet")
        if self._nvls and pkt:
            return (pkt, 0, 0)
        if self.multi_node or self.multi_host_mnnvl:
            rsag = self._algo("allreduce", "default_allreduce_rsag")
            if rsag:
                return (rsag, 0, 0)
        return (self._algo("allreduce", "default_allreduce_packet"), 0, 0)

    # -- low-level execute --
    def _exec_ar(self, tensor, algo, nb, nt, op=mscclpp.ReduceOp.SUM, stream=None, accum_dtype=None, sym=True):
        """Run one in-place allreduce of *tensor* with the given algorithm/geometry.

        Returns the algorithm's status code (0 on success); failures are
        printed, not raised.
        """
        s = stream.cuda_stream if stream else torch.cuda.current_stream().cuda_stream
        ret = algo.execute(
            comm=self.comm.communicator,
            input_buffer=tensor.data_ptr(),
            output_buffer=tensor.data_ptr(),
            input_size=tensor.nbytes,
            output_size=tensor.nbytes,
            dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype),
            op=op,
            stream=s,
            nblocks=nb,
            nthreads_per_block=nt,
            symmetric_memory=sym,
            accum_dtype=accum_dtype,
        )
        if ret != 0:
            print(f"Rank {self.rank}: {algo.name} failed ({ret})")
        return ret

    def _exec_ag(self, inp, out, algo, nb, nt, stream=None, sym=None):
        """Run one allgather from *inp* into *out*; sym=None means use self.symmetric_memory."""
        if sym is None:
            sym = self.symmetric_memory
        s = stream.cuda_stream if stream else torch.cuda.current_stream().cuda_stream
        ret = algo.execute(
            comm=self.comm.communicator,
            input_buffer=inp.data_ptr(),
            output_buffer=out.data_ptr(),
            input_size=inp.nbytes,
            output_size=out.nbytes,
            dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(inp.dtype),
            op=mscclpp.ReduceOp.NOP,
            stream=s,
            nblocks=nb,
            nthreads_per_block=nt,
            symmetric_memory=sym,
        )
        if ret != 0:
            print(f"Rank {self.rank}: AG {algo.name} failed ({ret})")
        return ret

    def _barrier_internal(self):
        """Barrier implemented as a tiny allreduce over the dedicated tensor."""
        a, nb, nt = self._default_ar_config()
        self._exec_ar(self._barrier_tensor, a, nb, nt, sym=self.symmetric_memory)

    # -- lazy tuning --
    def _ensure_tune_bufs(self):
        """Allocate the 128 MiB tuning buffer and timing tensor on first call."""
        if self._tune_buf is None:
            self._tune_buf = _make_tensor(1 << 27, torch.float16)
            self._tune_buf.normal_()
            self._time_buf = _make_tensor(4096, torch.float32)
        return self._tune_buf

    def _ar_candidates(self, size: int):
        """Allreduce algorithms worth benchmarking for message *size*, by topology."""
        out = []
        if self.multi_host_mnnvl:
            # Small messages: all-pair packet; NVLS packet only up to 64 KiB.
            if size <= 4 << 20:
                a = self._algo("allreduce", "default_allreduce_allpair_packet")
                if a:
                    out.append(a)
            if size <= 64 << 10:
                a = self._algo("allreduce", "default_allreduce_nvls_packet")
                if self._nvls and a:
                    out.append(a)
            if size > 128 << 10:
                a = self._algo("allreduce", "default_allreduce_packet")
                if a:
                    out.append(a)
            if size >= 512 << 10:
                # Zero-copy variants require symmetric memory.
                a = self._algo("allreduce", "default_allreduce_rsag_zero_copy")
                if self.symmetric_memory and a:
                    out.append(a)
                a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
                if self._nvls and self.symmetric_memory and a:
                    out.append(a)
                a = self._algo("allreduce", "default_allreduce_rsag")
                if a:
                    out.append(a)
            return out
        if self.multi_node:
            a = self._algo("allreduce", "default_allreduce_nvls_packet")
            if self._nvls and a:
                out.append(a)
            a = self._algo("allreduce", "default_allreduce_packet")
            if a:
                out.append(a)
            if size >= 512 << 10:
                a = self._algo("allreduce", "default_allreduce_rsag")
                if a:
                    out.append(a)
            return out
        # Single-node CUDA-IPC path.
        if size <= 4 << 20:
            a = self._algo("allreduce", "default_allreduce_packet")
            if a:
                out.append(a)
            a = self._algo("allreduce", "default_allreduce_allpair_packet")
            if a:
                out.append(a)
            a = self._algo("allreduce", "default_allreduce_nvls_packet")
            if self._nvls and a:
                out.append(a)
        if size >= 512 << 10:
            a = self._algo("allreduce", "default_allreduce_rsag_zero_copy")
            if a:
                out.append(a)
            a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
            if self._nvls and self.symmetric_memory and a:
                out.append(a)
        if torch.version.hip is not None:
            # ROCm-only fullmesh variant.
            a = self._algo("allreduce", "default_allreduce_fullmesh")
            if a:
                out.append(a)
        return out

    def _ag_candidates(self):
        """Allgather candidates; only the single-node fullmesh2 path exists here."""
        if self.multi_node or self.multi_host_mnnvl:
            return []
        a = self._algo("allgather", "default_allgather_fullmesh2")
        return [a] if a else []

    def _run_tune(self, collective, algo, buf, size, nb, nt):
        """Single tune invocation for either collective."""
        if collective == "allreduce":
            # Allreduce tuning mirrors execution's symmetric-memory mode.
            return algo.execute(
                comm=self.comm.communicator,
                input_buffer=buf.data_ptr(),
                output_buffer=buf.data_ptr(),
                input_size=size,
                output_size=size,
                dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(buf.dtype),
                op=mscclpp.ReduceOp.SUM,
                stream=torch.cuda.current_stream().cuda_stream,
                nblocks=nb,
                nthreads_per_block=nt,
                symmetric_memory=self.symmetric_memory,
            )
        else:
            # Allgather: this rank's slice of buf is the input; output is the
            # whole concatenation. Tuned with symmetric memory off.
            total = size * self.world_size
            out_ptr = buf.data_ptr()
            return algo.execute(
                comm=self.comm.communicator,
                input_buffer=out_ptr + self.rank * size,
                output_buffer=out_ptr,
                input_size=size,
                output_size=total,
                dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(buf.dtype),
                op=mscclpp.ReduceOp.NOP,
                stream=torch.cuda.current_stream().cuda_stream,
                nblocks=nb,
                nthreads_per_block=nt,
                symmetric_memory=False,
            )

    def _tune_size(self, collective: str, target_size: int):
        """Auto-tune one (collective, target_size) pair and cache result."""
        buf = self._ensure_tune_bufs()
        cands = self._ar_candidates(target_size) if collective == "allreduce" else self._ag_candidates()
        best_time, best_cfg = float("inf"), None
        used = set()
        run = lambda a, nb, nt: self._run_tune(collective, a, buf, target_size, nb, nt)
        for algo in cands:
            nb_limit = self._nblocks_limit(algo.name, target_size)
            for nb in self._CANDIDATE_NBLOCKS:
                if nb > nb_limit:
                    continue
                for nt in self._CANDIDATE_NTHREADS:
                    # Feasibility — sync result across ranks so all agree
                    ret = run(algo, nb, nt)
                    torch.cuda.synchronize()
                    # SUM-allreduce of the status codes: nonzero means at
                    # least one rank failed, so every rank skips this config.
                    self._time_buf[0] = float(ret)
                    self._exec_ar(self._time_buf[:1], *self._default_ar_config(), sym=self.symmetric_memory)
                    if self._time_buf[0].item() != 0:
                        continue
                    used.add(algo)
                    # Warmup
                    for _ in range(self._TUNE_N_WARMUP):
                        run(algo, nb, nt)
                    # CUDA-graph timed benchmark
                    cs = torch.cuda.Stream()
                    cs.wait_stream(torch.cuda.current_stream())
                    g = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(g, stream=cs):
                        for _ in range(self._TUNE_N_OPS_PER_GRAPH):
                            run(algo, nb, nt)
                    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
                    start.record(cs)
                    with torch.cuda.stream(cs):
                        for _ in range(self._TUNE_N_GRAPH_LAUNCHES):
                            g.replay()
                    end.record(cs)
                    end.synchronize()
                    elapsed = start.elapsed_time(end)
                    # Cross-rank timing sync
                    self._time_buf.fill_(elapsed)
                    torch.cuda.current_stream().wait_stream(cs)
                    self._exec_ar(self._time_buf, *self._default_ar_config(), sym=self.symmetric_memory)
                    # After the SUM-allreduce every element holds the sum of
                    # all ranks' times; dividing by world_size gives the mean.
                    avg = self._time_buf[self.rank].item() / self.world_size
                    if avg < best_time:
                        best_time, best_cfg = avg, (algo, nb, nt)
        if best_cfg:
            self._tune_cache[collective][target_size] = best_cfg
            if self.rank == 0:
                n = self._TUNE_N_GRAPH_LAUNCHES * self._TUNE_N_OPS_PER_GRAPH
                print(
                    f"[tune] {collective} size={target_size}: {best_cfg[0].name} "
                    f"nb={best_cfg[1]} nt={best_cfg[2]} time={best_time / n * 1000:.2f}us",
                    flush=True,
                )
        else:
            # No feasible candidate: fall back to the default config.
            fb = (
                self._default_ar_config()
                if collective == "allreduce"
                else ((self._ag_candidates()[0], 32, 512) if self._ag_candidates() else None)
            )
            self._tune_cache[collective][target_size] = fb
        torch.cuda.synchronize()
        self._barrier_internal()
        # Reset internal state of every algorithm that actually ran.
        for a in used:
            a.reset()

    # -- public API --
    def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, stream=None, accum_dtype=None):
        """In-place allreduce of *tensor*; tunes lazily per rounded byte size."""
        sz = _round_pow2(tensor.nbytes)
        if sz not in self._tune_cache["allreduce"]:
            self._tune_size("allreduce", sz)
        a, nb, nt = self._tune_cache["allreduce"][sz]
        self._exec_ar(
            tensor, a, nb, nt, op=_to_mscclpp_op(op), stream=stream, accum_dtype=accum_dtype, sym=self.symmetric_memory
        )

    def all_gather(self, output_tensor, input_tensor, stream=None):
        """Allgather *input_tensor* into *output_tensor*; single-node only."""
        if self.multi_node or self.multi_host_mnnvl:
            raise RuntimeError("all_gather in this example currently supports only single-node runs")
        sz = _round_pow2(input_tensor.nbytes)
        if sz not in self._tune_cache["allgather"]:
            self._tune_size("allgather", sz)
        a, nb, nt = self._tune_cache["allgather"][sz]
        self._exec_ag(input_tensor, output_tensor, a, nb, nt, stream=stream, sym=self.symmetric_memory)

    def barrier(self):
        """Block until all ranks reach this point (allreduce-based)."""
        self._barrier_internal()

    def destroy(self):
        """Drop algorithm handles and buffers; the object is unusable afterwards."""
        self._algos.clear()
        self._tune_cache = {"allreduce": {}, "allgather": {}}
        self._tune_buf = self._time_buf = self._barrier_tensor = self._scratch = self.comm = None
# -- Benchmarks (standalone) --------------------------------------------------
def _bench_sizes(low=None, high=None):
if low is None:
low = _get_env_int("MSCCLPP_BENCH_LOW_SIZE", default=5 * 1024)
if high is None:
high = _get_env_int("MSCCLPP_BENCH_HIGH_SIZE", default=80 << 20)
sizes, c = [], low
while c <= high:
sizes.append(c)
c *= 2
return sizes
def benchmark_allreduce(
    comm: CustomizedComm, dtype=torch.float16, accum_dtype=None, n_warmup=10, n_graph_launches=10, n_iter=100
):
    """Time allreduce across the _bench_sizes sweep and print per-size bandwidth.

    Each size is captured as a CUDA graph of n_iter allreduces, warmed up
    n_warmup replays, then timed over n_graph_launches replays. Only rank 0
    prints.
    """
    sizes = _bench_sizes()
    if comm.rank == 0:
        print(f"\n{'='*60}\nAllreduce Benchmark\n{'='*60}")
        print(f"{'Nelements':<18} {'Size(B)':<18} {'Time(us)':<18} {'AlgoBW(GB/s)':<18}")
    cs = torch.cuda.Stream()
    buf = _make_tensor(1 << 27, dtype)
    # Random data for float dtypes; zeros otherwise (e.g. float8 via view).
    buf.normal_() if dtype in (torch.float16, torch.float32, torch.bfloat16) else buf.fill_(0)
    for size in sizes:
        nelems = size // buf.element_size()
        t = buf[: size // buf.element_size()]
        # First call triggers lazy tuning outside the capture.
        comm.all_reduce(t, accum_dtype=accum_dtype)
        torch.cuda.synchronize()
        cs.wait_stream(torch.cuda.current_stream())
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g, stream=cs):
            for _ in range(n_iter):
                comm.all_reduce(t, accum_dtype=accum_dtype)
        with torch.cuda.stream(cs):
            for _ in range(n_warmup):
                g.replay()
        # Align ranks before timing so stragglers don't skew the measurement.
        comm.barrier()
        cs.synchronize()
        s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        s.record(cs)
        with torch.cuda.stream(cs):
            for _ in range(n_graph_launches):
                g.replay()
        e.record(cs)
        e.synchronize()
        # Mean milliseconds per single allreduce.
        ms = s.elapsed_time(e) / (n_graph_launches * n_iter)
        if comm.rank == 0:
            print(f"{nelems:<18} {size:<18} {ms*1000:<18.2f} {size/(ms*1e-3)/1e9:<18.2f}")
def benchmark_allgather(comm: CustomizedComm, dtype=torch.float16, n_warmup=10, n_graph_launches=10, n_iter=100):
    """Time allgather across per-rank sizes and print total bandwidth.

    Mirrors benchmark_allreduce: CUDA-graph capture of n_iter allgathers,
    warmup replays, then timed replays. Stops once the concatenated output
    no longer fits in the 128 MiB buffer. Only rank 0 prints.
    """
    sizes = _bench_sizes()
    if comm.rank == 0:
        print(f"\n{'='*60}\nAllgather Benchmark\n{'='*60}")
        print(f"{'PerRank(B)':<18} {'Total(B)':<18} {'Time(us)':<18} {'AlgoBW(GB/s)':<18}")
    cs = torch.cuda.Stream()
    buf = _make_tensor(1 << 27, dtype)
    buf.normal_() if dtype in (torch.float16, torch.float32, torch.bfloat16) else buf.fill_(0)
    for prs in sizes:  # prs = per-rank size in bytes
        total = prs * comm.world_size
        if total > buf.nbytes:
            break
        nt = total // buf.element_size()
        npr = prs // buf.element_size()
        # Input is this rank's slice of the shared output view.
        out = buf[:nt]
        inp = out[comm.rank * npr : (comm.rank + 1) * npr]
        # First call triggers lazy tuning outside the capture.
        comm.all_gather(out, inp)
        torch.cuda.synchronize()
        cs.wait_stream(torch.cuda.current_stream())
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g, stream=cs):
            for _ in range(n_iter):
                comm.all_gather(out, inp)
        with torch.cuda.stream(cs):
            for _ in range(n_warmup):
                g.replay()
        comm.barrier()
        cs.synchronize()
        s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        s.record(cs)
        with torch.cuda.stream(cs):
            for _ in range(n_graph_launches):
                g.replay()
        e.record(cs)
        e.synchronize()
        # Mean milliseconds per single allgather.
        ms = s.elapsed_time(e) / (n_graph_launches * n_iter)
        if comm.rank == 0:
            print(f"{prs:<18} {total:<18} {ms*1000:<18.2f} {total/(ms*1e-3)/1e9:<18.2f}")
# -- Bootstrap & main ---------------------------------------------------------
def init_dist() -> mscclpp.CommGroup:
    """Bootstrap a CommGroup via one of three paths, in priority order.

    1. Explicit TCP bootstrap when MSCCLPP_MASTER_ADDR plus rank/world-size
       env vars are set.
    2. MPI (mpi4py) when launched by mpirun/SLURM.
    3. torch.distributed gloo group otherwise (e.g. torchrun).
    """
    addr = os.environ.get("MSCCLPP_MASTER_ADDR")
    rank = _get_env_int("RANK", "OMPI_COMM_WORLD_RANK", "PMI_RANK", "SLURM_PROCID")
    world = _get_env_int("WORLD_SIZE", "OMPI_COMM_WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS")
    if addr and rank is not None and world is not None:
        port = os.environ.get("MSCCLPP_MASTER_PORT", "29500")
        iface = _resolve_interface(addr)
        if not iface:
            raise ValueError(f"No interface for {addr}")
        return mscclpp.CommGroup(interfaceIpPortTrio=f"{iface}:{addr}:{port}", rank=rank, size=world)
    if _running_under_mpi():
        try:
            from mpi4py import MPI
        except ModuleNotFoundError as exc:
            raise RuntimeError("mpi4py is required to launch this example with mpirun") from exc
        return mscclpp.CommGroup(mpi_comm=MPI.COMM_WORLD)
    import torch.distributed as dist
    dist.init_process_group(backend="gloo")
    return mscclpp.CommGroup(torch_group=dist.group.WORLD)
def main():
    """Entry point: bootstrap comms, run the allreduce (and, on single-node,
    allgather) benchmarks, then tear everything down.

    Configuration comes from env vars: DTYPE, ACCUM_DTYPE, and the
    MSCCLPP_BENCH_* iteration counts.
    """
    local = _get_env_int("LOCAL_RANK", "OMPI_COMM_WORLD_LOCAL_RANK", "MPI_LOCALRANKID", "SLURM_LOCALID", default=0)
    torch.cuda.set_device(local)
    dtype_str = os.environ.get("DTYPE", "float16")
    # Unknown DTYPE strings silently fall back to float16.
    dtype = getattr(torch, dtype_str, torch.float16)
    accum_map = {"float32": mscclpp.DataType.float32, "float16": mscclpp.DataType.float16}
    accum_str = os.environ.get("ACCUM_DTYPE")
    accum_dtype = accum_map.get(accum_str) if accum_str else None
    n_warmup = _get_env_int("MSCCLPP_BENCH_WARMUP", default=10)
    n_graph_launches = _get_env_int("MSCCLPP_BENCH_GRAPH_LAUNCHES", default=10)
    n_iter = _get_env_int("MSCCLPP_BENCH_ITERS", default=100)
    comm_group = init_dist()
    # This example runs with symmetric memory disabled.
    cc = CustomizedComm(comm_group, symmetric_memory=False)
    print(f"rank {local} starting benchmarks with dtype={dtype} accum_dtype={accum_dtype}...")
    benchmark_allreduce(
        cc,
        dtype=dtype,
        accum_dtype=accum_dtype,
        n_warmup=n_warmup,
        n_graph_launches=n_graph_launches,
        n_iter=n_iter,
    )
    cc.barrier()
    torch.cuda.synchronize()
    if cc.multi_node or cc.multi_host_mnnvl:
        if cc.rank == 0:
            print("Skipping allgather benchmark on multi-node: this example's allgather path is single-node only.")
    else:
        benchmark_allgather(cc, dtype=dtype, n_warmup=n_warmup, n_graph_launches=n_graph_launches, n_iter=n_iter)
    cc.barrier()
    torch.cuda.synchronize()
    # Explicit teardown before interpreter exit.
    cc.destroy()
    del cc
    del comm_group
    gc.collect()
    print(f"rank {local} completed successfully.")
if __name__ == "__main__":
    exit_code = 0
    try:
        main()
    except Exception:
        exit_code = 1
        traceback.print_exc()
    finally:
        # Flush before os._exit, which skips atexit handlers and buffered-I/O
        # teardown — presumably chosen to avoid hangs in library destructors
        # on abnormal exit (confirm against mscclpp shutdown behavior).
        sys.stdout.flush()
        sys.stderr.flush()
        os._exit(exit_code)