mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
Merge multinode branch
This commit is contained in:
@@ -62,6 +62,7 @@ void register_algorithm(nb::module_& m) {
|
||||
.def_prop_ro("buffer_mode", &Algorithm::bufferMode)
|
||||
.def_prop_ro("constraint", &Algorithm::constraint)
|
||||
.def_prop_ro("type", &Algorithm::type)
|
||||
.def("reset", &Algorithm::reset)
|
||||
.def(
|
||||
"execute",
|
||||
[](Algorithm& self, std::shared_ptr<Communicator> comm, uintptr_t input, uintptr_t output,
|
||||
|
||||
@@ -248,6 +248,10 @@ class MscclppAlltoAllV:
|
||||
# Fast path: skip GPU copies + bootstrap exchange if split sizes unchanged
|
||||
splits_key = (tuple(send_counts_bytes), tuple(recv_counts_bytes))
|
||||
if splits_key != self._cached_splits_key:
|
||||
# Clear cached contexts to free RegisteredMemory for old (possibly freed) tensors.
|
||||
# Without this, stale CUDA IPC handles accumulate and eventually SIGSEGV.
|
||||
if hasattr(self._algo, 'reset'):
|
||||
self._algo.reset()
|
||||
# Copy counts/displacements to GPU
|
||||
self._d_send_counts.copy_(torch.tensor(send_counts_bytes, dtype=torch.int64))
|
||||
self._d_send_displs.copy_(torch.tensor(send_displs_bytes, dtype=torch.int64))
|
||||
@@ -268,13 +272,16 @@ class MscclppAlltoAllV:
|
||||
stream = torch.cuda.current_stream()
|
||||
cuda_stream = stream.cuda_stream
|
||||
|
||||
# Use full buffer sizes (not actual data sizes) so the C++ context
|
||||
# key (input_ptr, output_ptr, inputSize, outputSize) is always the
|
||||
# same when using persistent buffers. This ensures only ONE context
|
||||
# is ever created, avoiding bootstrap TCP on every unique size combo.
|
||||
# The kernel uses per-peer sendCounts/recvCounts for actual data bounds.
|
||||
input_size = input.numel() * elem_size
|
||||
output_size = output.numel() * elem_size
|
||||
# Use the full underlying storage size (not just the view's active data)
|
||||
# for the context key, so that reusing views of the same tensor with
|
||||
# different split sizes doesn't create new contexts (which leak
|
||||
# RegisteredMemory for stale buffers).
|
||||
try:
|
||||
input_alloc_size = input.untyped_storage().size()
|
||||
output_alloc_size = output.untyped_storage().size()
|
||||
except Exception:
|
||||
input_alloc_size = input.nelement() * input.element_size()
|
||||
output_alloc_size = output.nelement() * output.element_size()
|
||||
|
||||
self._a2av_call_count += 1
|
||||
_cid = self._a2av_call_count
|
||||
@@ -297,8 +304,8 @@ class MscclppAlltoAllV:
|
||||
self._comm,
|
||||
input.data_ptr(),
|
||||
output.data_ptr(),
|
||||
input_size,
|
||||
output_size,
|
||||
input_alloc_size,
|
||||
output_alloc_size,
|
||||
_torch_dtype_to_mscclpp(dtype),
|
||||
ReduceOp.NOP,
|
||||
cuda_stream,
|
||||
|
||||
@@ -12,30 +12,113 @@ Usage:
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import random
|
||||
import socket
|
||||
import struct
|
||||
import pickle
|
||||
from typing import Callable, List, Tuple
|
||||
|
||||
# Must init torch.distributed before importing mscclpp modules
|
||||
# to set rank/world_size environment variables
|
||||
|
||||
|
||||
def _get_routable_ip() -> str:
|
||||
"""Get a routable IP address for this host (not 127.0.0.1)."""
|
||||
try:
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
s.connect(("8.8.8.8", 80)) # doesn't actually send data
|
||||
ip = s.getsockname()[0]
|
||||
s.close()
|
||||
return ip
|
||||
except Exception:
|
||||
return socket.gethostbyname(socket.gethostname())
|
||||
|
||||
|
||||
def _tcp_broadcast_unique_id(unique_id_bytes: bytes, rank: int, world_size: int,
|
||||
master_addr: str, port: int = 18515) -> bytes:
|
||||
"""Broadcast UniqueId bytes from rank 0 to all other ranks via TCP."""
|
||||
if rank == 0:
|
||||
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
server.bind(("", port))
|
||||
server.listen(world_size - 1)
|
||||
for _ in range(world_size - 1):
|
||||
conn, _ = server.accept()
|
||||
length = len(unique_id_bytes)
|
||||
conn.sendall(struct.pack("!I", length) + unique_id_bytes)
|
||||
conn.close()
|
||||
server.close()
|
||||
return unique_id_bytes
|
||||
else:
|
||||
# Retry connection to rank 0
|
||||
for attempt in range(120):
|
||||
try:
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.connect((master_addr, port))
|
||||
break
|
||||
except (ConnectionRefusedError, OSError):
|
||||
time.sleep(0.5)
|
||||
if attempt == 119:
|
||||
raise RuntimeError(f"Rank {rank}: failed to connect to {master_addr}:{port}")
|
||||
raw_len = b""
|
||||
while len(raw_len) < 4:
|
||||
raw_len += s.recv(4 - len(raw_len))
|
||||
length = struct.unpack("!I", raw_len)[0]
|
||||
data = b""
|
||||
while len(data) < length:
|
||||
data += s.recv(length - len(data))
|
||||
s.close()
|
||||
return data
|
||||
|
||||
|
||||
def main():
|
||||
# Get rank/world from MPI environment
|
||||
rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", os.environ.get("PMI_RANK", 0)))
|
||||
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", os.environ.get("PMI_SIZE", 1)))
|
||||
|
||||
# Set CUDA device
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))
|
||||
# Set CUDA device — prefer MPI-provided local rank to handle any rank mapping
|
||||
local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK",
|
||||
os.environ.get("MPI_LOCALRANKID",
|
||||
os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))))
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
# Initialize torch.distributed with NCCL (need MASTER_ADDR/PORT)
|
||||
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
|
||||
# Disable UCX in OpenMPI to avoid version mismatch crashes
|
||||
os.environ.setdefault("OMPI_MCA_pml", "ob1")
|
||||
os.environ.setdefault("OMPI_MCA_btl", "tcp,vader,self")
|
||||
|
||||
# Initialize torch.distributed — use NCCL when torch_fn benchmarks are needed,
|
||||
# otherwise gloo avoids IB configuration issues on some clusters.
|
||||
# Set ALLTOALLV_BACKEND=nccl to enable torch baseline comparison.
|
||||
backend = os.environ.get("ALLTOALLV_BACKEND", "gloo")
|
||||
# For multi-node: MASTER_ADDR must be set to rank 0's routable IP.
|
||||
# Single-node auto-detects; multi-node requires it from the launcher.
|
||||
if "MASTER_ADDR" not in os.environ:
|
||||
if rank == 0:
|
||||
os.environ["MASTER_ADDR"] = _get_routable_ip()
|
||||
else:
|
||||
# Check if we're single-node (all ranks on same host)
|
||||
n_gpus = torch.cuda.device_count()
|
||||
if world_size <= n_gpus:
|
||||
# Likely single-node – 127.0.0.1 works
|
||||
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Rank {rank}: MASTER_ADDR not set for multi-node run "
|
||||
f"(world_size={world_size} > local GPUs={n_gpus}). "
|
||||
f"Set it in your launcher, e.g.:\n"
|
||||
f" mpirun -x MASTER_ADDR=<node0_ip> -x MASTER_PORT=29500 ..."
|
||||
)
|
||||
os.environ.setdefault("MASTER_PORT", "29500")
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size,
|
||||
device_id=torch.device(f"cuda:{local_rank}"))
|
||||
if backend == "nccl":
|
||||
# Don't use device_id= eager init — it triggers an immediate NCCL allreduce
|
||||
# that fails on some platforms (e.g. GB200 with NCCL 2.28.9).
|
||||
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
||||
else:
|
||||
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
|
||||
|
||||
if rank == 0:
|
||||
print(f"Testing MscclppAlltoAllV with {world_size} ranks")
|
||||
@@ -48,33 +131,51 @@ def main():
|
||||
UniqueId,
|
||||
)
|
||||
from mscclpp.ext.alltoallv_single import MscclppAlltoAllV
|
||||
import pickle
|
||||
|
||||
# Create mscclpp communicator with TcpBootstrap
|
||||
# Use torch.distributed to share the unique ID via pickle
|
||||
# Broadcast UniqueId via TCP sockets (avoids mpi4py/UCX issues)
|
||||
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
|
||||
uid_port = int(os.environ.get("MSCCLPP_UID_PORT", "18515"))
|
||||
bootstrap = TcpBootstrap(rank, world_size)
|
||||
|
||||
if rank == 0:
|
||||
unique_id = bootstrap.create_unique_id()
|
||||
# Serialize UniqueId via pickle and broadcast
|
||||
pickled = pickle.dumps(unique_id)
|
||||
id_tensor = torch.zeros(256, dtype=torch.uint8, device='cuda')
|
||||
id_tensor[:len(pickled)] = torch.tensor(list(pickled), dtype=torch.uint8)
|
||||
# Also send length
|
||||
len_tensor = torch.tensor([len(pickled)], dtype=torch.int64, device='cuda')
|
||||
uid_bytes = pickle.dumps(unique_id)
|
||||
else:
|
||||
id_tensor = torch.zeros(256, dtype=torch.uint8, device='cuda')
|
||||
len_tensor = torch.zeros(1, dtype=torch.int64, device='cuda')
|
||||
|
||||
dist.broadcast(len_tensor, src=0)
|
||||
dist.broadcast(id_tensor, src=0)
|
||||
uid_bytes = b""
|
||||
|
||||
uid_bytes = _tcp_broadcast_unique_id(uid_bytes, rank, world_size, master_addr, uid_port)
|
||||
if rank != 0:
|
||||
pickled_len = int(len_tensor.item())
|
||||
pickled = bytes(id_tensor[:pickled_len].cpu().tolist())
|
||||
unique_id = pickle.loads(pickled)
|
||||
unique_id = pickle.loads(uid_bytes)
|
||||
|
||||
bootstrap.initialize(unique_id)
|
||||
|
||||
# ── Multi-node diagnostics ─────────────────────────────────────────
|
||||
import subprocess, platform
|
||||
hostname = platform.node()
|
||||
n_ranks_per_node = bootstrap.get_n_ranks_per_node()
|
||||
is_multi_node = (world_size > n_ranks_per_node)
|
||||
|
||||
# Check IB device availability
|
||||
try:
|
||||
ib_out = subprocess.check_output(["ibv_devinfo", "-l"], stderr=subprocess.DEVNULL, timeout=5).decode().strip()
|
||||
ib_devices = [l.strip() for l in ib_out.splitlines() if l.strip() and "device" not in l.lower()]
|
||||
except Exception:
|
||||
ib_devices = []
|
||||
|
||||
if rank == 0:
|
||||
print(f" Hostname: {hostname}")
|
||||
print(f" nRanksPerNode: {n_ranks_per_node}, isMultiNode: {is_multi_node}")
|
||||
print(f" IB devices: {ib_devices if ib_devices else 'NONE FOUND'}")
|
||||
print(f" MSCCLPP_SOCKET_IFNAME: {os.environ.get('MSCCLPP_SOCKET_IFNAME', '<not set>')}")
|
||||
if is_multi_node and not ib_devices:
|
||||
print(f" WARNING: Multi-node detected but no IB devices! Cross-node will fail.")
|
||||
# Also print from rank n_ranks_per_node (first rank on node 1) for comparison
|
||||
if is_multi_node and rank == n_ranks_per_node:
|
||||
print(f" [Node 1] Hostname: {hostname}, rank={rank}")
|
||||
print(f" [Node 1] IB devices: {ib_devices if ib_devices else 'NONE FOUND'}")
|
||||
# ── End diagnostics ────────────────────────────────────────────────
|
||||
|
||||
comm = Communicator(bootstrap)
|
||||
|
||||
# Create MscclppAlltoAllV with existing communicator
|
||||
@@ -188,6 +289,21 @@ def main():
|
||||
input_split_sizes=in_splits,
|
||||
output_split_sizes=out_splits)
|
||||
|
||||
# Detect whether torch comparison is possible:
|
||||
# - Need NCCL backend (gloo doesn't support all_to_all_single)
|
||||
# - Need native NCCL (mscclpp shim doesn't implement all collectives)
|
||||
use_torch_baseline = (backend == "nccl")
|
||||
if use_torch_baseline:
|
||||
try:
|
||||
tiny_in = torch.zeros(world_size, dtype=torch.float32, device='cuda')
|
||||
tiny_out = torch.zeros(world_size, dtype=torch.float32, device='cuda')
|
||||
dist.all_to_all_single(tiny_out, tiny_in)
|
||||
torch.cuda.synchronize()
|
||||
except Exception:
|
||||
use_torch_baseline = False
|
||||
if rank == 0:
|
||||
print(" [INFO] torch all_to_all_single unavailable, skipping torch baseline")
|
||||
|
||||
def torch_fn(inp, out, in_splits, out_splits):
|
||||
dist.all_to_all_single(out, inp,
|
||||
output_split_sizes=out_splits,
|
||||
@@ -202,22 +318,32 @@ def main():
|
||||
|
||||
def print_header():
|
||||
if rank == 0:
|
||||
print(f" {'Avg Size':>10s} "
|
||||
f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s} "
|
||||
f"{'torch Lat':>10s} {'torch BW':>9s} "
|
||||
f"{'Speedup':>7s}")
|
||||
print(f" {'-'*10} "
|
||||
f"{'-'*12} {'-'*11} "
|
||||
f"{'-'*10} {'-'*9} "
|
||||
f"{'-'*7}")
|
||||
if use_torch_baseline:
|
||||
print(f" {'Avg Size':>10s} "
|
||||
f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s} "
|
||||
f"{'torch Lat':>10s} {'torch BW':>9s} "
|
||||
f"{'Speedup':>7s}")
|
||||
print(f" {'-'*10} "
|
||||
f"{'-'*12} {'-'*11} "
|
||||
f"{'-'*10} {'-'*9} "
|
||||
f"{'-'*7}")
|
||||
else:
|
||||
print(f" {'Avg Size':>10s} "
|
||||
f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s}")
|
||||
print(f" {'-'*10} "
|
||||
f"{'-'*12} {'-'*11}")
|
||||
|
||||
def print_row(size_str, m_lat, m_bw, t_lat, t_bw):
|
||||
def print_row(size_str, m_lat, m_bw, t_lat=None, t_bw=None):
|
||||
if rank == 0:
|
||||
speedup = m_bw / t_bw if t_bw > 0 else float('inf')
|
||||
print(f" {size_str:>10s} "
|
||||
f"{m_lat:>10.1f}us {m_bw:>9.2f}GB "
|
||||
f"{t_lat:>8.1f}us {t_bw:>7.2f}GB "
|
||||
f"{speedup:>6.2f}x")
|
||||
if t_bw is not None and t_bw > 0:
|
||||
speedup = m_bw / t_bw
|
||||
print(f" {size_str:>10s} "
|
||||
f"{m_lat:>10.1f}us {m_bw:>9.2f}GB "
|
||||
f"{t_lat:>8.1f}us {t_bw:>7.2f}GB "
|
||||
f"{speedup:>6.2f}x")
|
||||
else:
|
||||
print(f" {size_str:>10s} "
|
||||
f"{m_lat:>10.1f}us {m_bw:>9.2f}GB")
|
||||
|
||||
# ── Test 3: Synthetic variable-size sweep ─────────────────────────────
|
||||
if rank == 0:
|
||||
@@ -227,6 +353,13 @@ def main():
|
||||
msg_sizes = [1 << s for s in range(10, 28) if s % 2 == 0]
|
||||
msg_sizes.append(128 * 1024 * 1024)
|
||||
|
||||
# Pre-compute max split sizes across all sweep iterations to allocate
|
||||
# fixed-size tensors. Reusing the same tensors keeps the NativeAlgorithm
|
||||
# context key stable (same ptrs + sizes) and avoids the context cache
|
||||
# leak that causes SIGSEGV when stale RegisteredMemory accumulates.
|
||||
max_in_elems = 0
|
||||
max_out_elems = 0
|
||||
sweep_params = [] # (avg_msg_size, in_splits, out_splits)
|
||||
for avg_msg_size in msg_sizes:
|
||||
random.seed(12345)
|
||||
avg_elems = avg_msg_size // 4
|
||||
@@ -234,19 +367,41 @@ def main():
|
||||
for i in range(world_size):
|
||||
row = [max(1, int(avg_elems * (0.5 + random.random()))) for _ in range(world_size)]
|
||||
send_matrix.append(row)
|
||||
|
||||
in_splits = send_matrix[rank]
|
||||
out_splits = [send_matrix[j][rank] for j in range(world_size)]
|
||||
max_in_elems = max(max_in_elems, sum(in_splits))
|
||||
max_out_elems = max(max_out_elems, sum(out_splits))
|
||||
sweep_params.append((avg_msg_size, in_splits, out_splits))
|
||||
|
||||
inp = torch.randn(sum(in_splits), dtype=torch.float32, device='cuda')
|
||||
out = torch.empty(sum(out_splits), dtype=torch.float32, device='cuda')
|
||||
# Allocate once at max size
|
||||
inp = torch.randn(max_in_elems, dtype=torch.float32, device='cuda')
|
||||
out = torch.empty(max_out_elems, dtype=torch.float32, device='cuda')
|
||||
|
||||
for avg_msg_size, in_splits, out_splits in sweep_params:
|
||||
n_warmup = 3 if avg_msg_size >= 16 * 1024 * 1024 else 5
|
||||
n_iters = 5 if avg_msg_size >= 64 * 1024 * 1024 else (10 if avg_msg_size >= 4 * 1024 * 1024 else 20)
|
||||
|
||||
m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
|
||||
# Use views into the fixed buffers (same data_ptr → same context key)
|
||||
inp_view = inp[:sum(in_splits)]
|
||||
out_view = out[:sum(out_splits)]
|
||||
|
||||
m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp_view, out_view, in_splits, out_splits, n_warmup, n_iters)
|
||||
if use_torch_baseline:
|
||||
try:
|
||||
t_lat, t_bw = bench_alltoallv(torch_fn, inp_view, out_view, in_splits, out_splits, n_warmup, n_iters)
|
||||
print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
print(f" [WARN] torch baseline failed: {e}")
|
||||
print(f" [INFO] Disabling torch baseline for remaining sizes")
|
||||
use_torch_baseline = False
|
||||
try:
|
||||
torch.cuda.synchronize()
|
||||
except Exception:
|
||||
pass
|
||||
print_row(fmt_size(avg_msg_size), m_lat, m_bw)
|
||||
else:
|
||||
print_row(fmt_size(avg_msg_size), m_lat, m_bw)
|
||||
|
||||
# ── Test 4: Real MoE workloads ───────────────────────────────────────
|
||||
# Token counts from real MoE training runs (rank 0's view, 8 GPUs).
|
||||
@@ -257,19 +412,30 @@ def main():
|
||||
# per rank so every rank has the same total send and each NVLink
|
||||
# carries a realistically imbalanced load.
|
||||
|
||||
# 10 workloads picked from 3M dispatch records in a real MoE training run,
|
||||
# covering the full imbalance spectrum from nearly uniform (1.05×) to
|
||||
# extremely skewed (10×). Each has 32768 total tokens → 167.8MB.
|
||||
MOE_WORKLOADS = [
|
||||
{
|
||||
"name": "MoE-A",
|
||||
# input_splits=[3976,3916,4497,4838,2888,3839,4355,4459]
|
||||
# total_send=167,772,160 total_recv=148,316,160
|
||||
"input_tokens": [3976, 3916, 4497, 4838, 2888, 3839, 4355, 4459],
|
||||
},
|
||||
{
|
||||
"name": "MoE-B",
|
||||
# input_splits=[3009,7161,2719,2766,3428,3010,6290,4385]
|
||||
# total_send=167,772,160 total_recv=163,722,240
|
||||
"input_tokens": [3009, 7161, 2719, 2766, 3428, 3010, 6290, 4385],
|
||||
},
|
||||
{"name": "MoE-A", # imbalance ≈ 1.05× (near-uniform)
|
||||
"input_tokens": [4122, 4115, 4000, 4200, 4126, 4046, 4035, 4124]},
|
||||
{"name": "MoE-B", # imbalance ≈ 1.20×
|
||||
"input_tokens": [3770, 4236, 3966, 4046, 4524, 4132, 3825, 4269]},
|
||||
{"name": "MoE-C", # imbalance ≈ 1.35×
|
||||
"input_tokens": [4142, 4489, 4563, 3380, 3957, 4133, 3958, 4146]},
|
||||
{"name": "MoE-D", # imbalance ≈ 1.50× (median)
|
||||
"input_tokens": [4232, 3697, 4619, 4788, 4420, 3192, 3971, 3849]},
|
||||
{"name": "MoE-E", # imbalance ≈ 1.75×
|
||||
"input_tokens": [4178, 3209, 4678, 5085, 3108, 3365, 5439, 3706]},
|
||||
{"name": "MoE-F", # imbalance ≈ 2.00×
|
||||
"input_tokens": [4582, 3903, 3949, 3727, 4823, 5106, 2553, 4125]},
|
||||
{"name": "MoE-G", # imbalance ≈ 2.50×
|
||||
"input_tokens": [4036, 4438, 4804, 6180, 2913, 2472, 4105, 3820]},
|
||||
{"name": "MoE-H", # imbalance ≈ 3.50×
|
||||
"input_tokens": [3152, 1722, 4406, 4027, 5365, 6027, 4895, 3174]},
|
||||
{"name": "MoE-I", # imbalance ≈ 5.00×
|
||||
"input_tokens": [4384, 4194, 7840, 3079, 3460, 3506, 1568, 4737]},
|
||||
{"name": "MoE-J", # imbalance ≈ 10.00× (extreme skew)
|
||||
"input_tokens": [2710, 7661, 3354, 4457, 4609, 766, 3423, 5788]},
|
||||
]
|
||||
ELEMS_PER_TOKEN = 2560 # 5120 bytes / 2 bytes-per-bfloat16
|
||||
|
||||
@@ -304,10 +470,23 @@ def main():
|
||||
n_warmup, n_iters = 5, 20
|
||||
|
||||
m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
|
||||
avg_bytes = total_bytes // world_size
|
||||
print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
|
||||
if use_torch_baseline:
|
||||
try:
|
||||
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
print(f" [WARN] torch baseline failed: {e}")
|
||||
print(f" [INFO] Disabling torch baseline for remaining workloads")
|
||||
use_torch_baseline = False
|
||||
try:
|
||||
torch.cuda.synchronize()
|
||||
except Exception:
|
||||
pass
|
||||
print_row(fmt_size(avg_bytes), m_lat, m_bw)
|
||||
else:
|
||||
print_row(fmt_size(avg_bytes), m_lat, m_bw)
|
||||
else:
|
||||
if rank == 0:
|
||||
print("\n[Test 4] Skipped (real MoE workloads require exactly 8 ranks)")
|
||||
|
||||
Reference in New Issue
Block a user