Files
mscclpp/test/python/ext/ep/test_low_latency_multirank.py
Qinghua Zhou fdf7d579dc ext/ep: optional preallocated outputs for low_latency_dispatch
Add optional out_packed_recv_x / out_src_info / out_layout_range /
out_count parameters to Buffer::low_latency_dispatch so callers can
hoist the four recv-side allocations out of a hot loop, mirroring the
existing out= path on low_latency_combine.

The bench in test_low_latency_multirank.py preallocates these tensors
once and passes them on every iter so the timed loop reflects kernel
cost, not torch.empty + caching-allocator overhead.
2026-04-30 18:45:44 +00:00

358 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Multi-rank low-latency functional test for mscclpp_ep.
Launch with (intra-node, 8 GPUs):
torchrun --nproc_per_node=8 test/python/ext/ep/test_low_latency_multirank.py
Launch with (2 nodes, 1 GPU per node -- DeepEP's recommended LL topology):
# node 0:
MASTER_ADDR=<master> MASTER_PORT=29600 NODE_RANK=0 \
torchrun --nnodes=2 --nproc_per_node=1 --rdzv-backend=c10d \
--rdzv-endpoint=<master>:29600 test/python/ext/ep/test_low_latency_multirank.py
# node 1:
MASTER_ADDR=<master> MASTER_PORT=29600 NODE_RANK=1 \
torchrun --nnodes=2 --nproc_per_node=1 --rdzv-backend=c10d \
--rdzv-endpoint=<master>:29600 test/python/ext/ep/test_low_latency_multirank.py
Exercises the LL dispatch + combine round-trip on a single node. The
minimal correctness check:
- dispatch: per-expert received token counts agree with an all-gathered
reference computed from topk_idx across all ranks;
- combine: the reconstructed x matches the analytical sum
``x * sum(topk_weights, masked by topk_idx == -1)``.
Known limitation (see src/ext/ep/README.md): the LL kernels drive every
peer via MSCCL++ PortChannel. Intra-node IB loopback between two HCAs on
the same host (what an 8-GPU single-node launch exercises) currently hangs
during dispatch; cross-node LL with one GPU per node works as designed.
Adapted from DeepEP/tests/test_low_latency.py stripped to the bare checks
we need for an LL port smoke test. BF16-only (no FP8 check).
"""
from __future__ import annotations
import os
import random
import sys
# Disable ProcessGroupNCCL's HeartbeatMonitor before importing torch.distributed.
# It runs in a background thread polling the TCPStore; under mpirun, rank 0
# (the store server) can exit before non-zero ranks finish teardown, producing
# noisy 'recvValue failed / Connection was likely closed' stack traces.
os.environ.setdefault("TORCH_NCCL_ENABLE_MONITORING", "0")
import torch
import torch.distributed as dist
def init_dist():
    """Bring up the NCCL process group from torchrun-style env variables.

    Pins this process to its local GPU via ``LOCAL_RANK`` (falling back to
    the global ``RANK`` when unset, e.g. under plain mpirun), then returns
    ``(rank, world_size, local_rank, group)`` where ``group`` is a freshly
    created process group spanning every rank.
    """
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # Plain mpirun launches may not set LOCAL_RANK; the global rank is the
    # best available fallback for single-node runs.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(local_rank)
    addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
    port = os.environ.get("MASTER_PORT", "29500")
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{addr}:{port}",
        world_size=world_size,
        rank=rank,
    )
    group = dist.new_group(list(range(world_size)))
    return rank, world_size, local_rank, group
def main() -> None:
    """LL dispatch + combine round-trip test, plus an optional benchmark.

    Correctness phase: dispatch tokens per `topk_idx`, check per-expert recv
    counts against an all-gathered reference, then combine and compare to the
    analytical weighted sum. Benchmark phase (MSCCLPP_EP_BENCH=1) times
    dispatch and combine separately with all output tensors preallocated.
    Requires torchrun/mpirun (RANK/WORLD_SIZE set), CUDA, and mscclpp.ext.ep.
    """
    rank, num_ranks, local_rank, group = init_dist()
    from mscclpp.ext import ep
    # Shrink the "bf16 precision" anchor to keep values small.
    rank_offset = 128
    assert num_ranks - rank_offset < 257, "too many ranks for bf16 precision anchor"
    # Problem size is env-tunable; defaults match the smoke-test configuration.
    num_tokens = int(os.environ.get("MSCCLPP_EP_LL_TOKENS", "64"))
    hidden = int(os.environ.get("MSCCLPP_EP_LL_HIDDEN", "7168")) # LL kernels are compiled for a fixed set; see SWITCH_HIDDEN
    num_topk = int(os.environ.get("MSCCLPP_EP_LL_TOPK", "4"))
    num_experts_per_rank = int(os.environ.get("MSCCLPP_EP_LL_EXPERTS_PER_RANK", "4"))
    num_experts = num_ranks * num_experts_per_rank
    assert num_experts % num_ranks == 0
    num_local_experts = num_experts // num_ranks
    # Per-rank seeds so every rank generates distinct tokens/routings.
    torch.manual_seed(0xB3C4 + rank)
    random.seed(0xB3C4 + rank)
    # Each rank's tokens carry the uniform value (rank - rank_offset), so a
    # receiver can identify the source rank from any payload element.
    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * (rank - rank_offset)
    # Encode the per-token index into the last 128 elements so the receiver
    # can verify which source token it is looking at.
    x[:, -128:] = (
        torch.arange(num_tokens, device="cuda").to(torch.bfloat16).view(-1, 1)
    )
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs() + 1
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device="cuda").abs()
    # Randomly mask some positions
    for _ in range(min(10, num_tokens)):
        topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1
    num_rdma_bytes = ep.Buffer.get_low_latency_rdma_size_hint(
        num_tokens, hidden, num_ranks, num_experts
    )
    if rank == 0:
        print(
            f"[cfg] num_ranks={num_ranks} num_tokens={num_tokens} hidden={hidden} "
            f"num_experts={num_experts} num_topk={num_topk} "
            f"num_rdma_bytes={num_rdma_bytes}",
            flush=True,
        )
    buf = ep.Buffer(
        group,
        num_nvl_bytes=0,
        num_rdma_bytes=num_rdma_bytes,
        low_latency_mode=True,
        num_qps_per_rank=max(1, num_experts // num_ranks),
    )
    print(
        f"[rank {rank}] Buffer created is_available={buf.is_available()} "
        f"is_internode={buf.is_internode_available()}",
        flush=True,
    )
    assert buf.is_available()
    dist.barrier(group=group)
    torch.cuda.synchronize()
    print(f"[rank {rank}] pre-dispatch", flush=True)
    # --- Dispatch ---
    # Return tuple (7 items):
    # packed_recv_x, packed_recv_x_scales (optional, FP8-only),
    # packed_recv_count, packed_recv_src_info, packed_recv_layout_range,
    # event, hook
    (
        packed_recv_x, _packed_recv_x_scales,
        packed_recv_count, packed_recv_src_info, packed_recv_layout_range,
        _event, recv_hook,
    ) = buf.low_latency_dispatch(
        x, topk_idx, num_tokens, num_experts,
        False, False, True, # use_fp8, async, return_recv_hook
    )
    # Send phase launched on compute_stream; wait for local launch.
    torch.cuda.synchronize()
    dist.barrier(group=group)
    print(f"[rank {rank}] dispatch-send done, calling hook", flush=True)
    recv_hook() # Recv phase.
    torch.cuda.synchronize()
    print(f"[rank {rank}] post-dispatch", flush=True)
    handle = (packed_recv_src_info, packed_recv_layout_range)
    # packed_recv_x: [num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]
    # packed_recv_count: [num_local_experts] int32
    # Reference: gather all ranks' topk_idx and count expected tokens per expert.
    all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device="cuda")
    dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
    # NOTE(review): the & int_mask sum below implies the low 32 bits of each
    # layout-range entry hold a per-source-rank token count -- confirm against
    # the dispatch kernel's packing.
    int_mask = (1 << 32) - 1
    for i in range(num_local_experts):
        expert_id = rank * num_local_experts + i
        recv_count = int(packed_recv_count[i].item())
        expected_count = int((all_topk_idx == expert_id).sum().item())
        recv_layout_range = handle[1][i]
        layout_sum = int((recv_layout_range & int_mask).sum().item())
        assert recv_count == expected_count, (
            f"rank{rank} expert{expert_id}: recv_count={recv_count} != expected={expected_count}"
        )
        assert layout_sum == recv_count, (
            f"rank{rank} expert{expert_id}: layout range sum {layout_sum} != recv_count {recv_count}"
        )
        if recv_count:
            recv_x = packed_recv_x[i, :recv_count]
            # All columns except the last 128 should share the value (src_rank - rank_offset)
            recv_x_lo = recv_x[:, :-128]
            amin = recv_x_lo.amin(dim=-1)
            amax = recv_x_lo.amax(dim=-1)
            assert torch.equal(amin, amax), f"rank{rank} expert{expert_id}: non-uniform recv block"
    if rank == 0:
        print(f"[dispatch] OK (ranks={num_ranks})", flush=True)
    # --- Combine ---
    # Simulate the downstream GEMM output = identity (bf16 copy) so combine
    # returns sum(x * weight) across experts.
    simulated_gemm_x = packed_recv_x.clone()
    out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
    # Signature: (x, topk_idx, topk_weights, src_info, layout_range,
    # num_max_dispatch_tokens_per_rank, num_experts,
    # zero_copy, async, return_recv_hook, out)
    src_info, layout_range = handle[0], handle[1]
    combined_x, _event, _hook = buf.low_latency_combine(
        simulated_gemm_x, topk_idx, topk_weights,
        src_info, layout_range,
        num_tokens, num_experts,
        False, False, False, # zero_copy, async, return_recv_hook
        out,
    )
    # Analytical expected: each token i, weighted sum over topk entries that
    # are not -1. Every expert returns the original x[i] (since simulated
    # gemm is identity), so the combine output should be
    # x[i] * sum(topk_weights[i, j] for j where topk_idx[i,j] != -1).
    weight_sum = topk_weights.masked_fill(topk_idx == -1, 0.0).sum(dim=1).view(-1, 1)
    expected = (x.float() * weight_sum).to(torch.bfloat16)
    diff = (combined_x.float() - expected.float()).abs().max().item()
    max_exp = expected.float().abs().max().item()
    print(
        f"[combine r{rank}] max|got-expected|={diff:.4e} max|expected|={max_exp:.4e}",
        flush=True,
    )
    assert torch.isnan(combined_x).any().item() is False
    assert diff < 1e-2, f"rank{rank}: LL combine mismatch diff={diff}"
    dist.barrier(group=group)
    if rank == 0:
        print("PASS", flush=True)
    # ------------------------------------------------------------------
    # Optional benchmark (enable with MSCCLPP_EP_BENCH=1). Times dispatch
    # and combine separately, reporting per-iter latency (max across ranks)
    # and aggregate effective bandwidth (sum across ranks).
    # ------------------------------------------------------------------
    if os.environ.get("MSCCLPP_EP_BENCH", "0") != "1":
        return
    warmup = int(os.environ.get("MSCCLPP_EP_BENCH_WARMUP", "5"))
    iters = int(os.environ.get("MSCCLPP_EP_BENCH_ITERS", "20"))
    # Hoist dispatch's output tensors out of the timed loop. The largest
    # (`packed_recv_x`, ~58 MB at 7K hidden) costs ~10us cumulative across
    # the four torch::empty calls per iter; reusing them brings the bench
    # in line with NCCL-EP `ep_bench` which preallocates output buffers.
    num_local_experts = num_experts // num_ranks
    bench_packed_recv_x = torch.empty(
        (num_local_experts, num_ranks * num_tokens, hidden),
        dtype=torch.bfloat16, device="cuda",
    )
    bench_packed_recv_src_info = torch.empty(
        (num_local_experts, num_ranks * num_tokens),
        dtype=torch.int32, device="cuda",
    )
    bench_packed_recv_layout_range = torch.empty(
        (num_local_experts, num_ranks), dtype=torch.int64, device="cuda",
    )
    bench_packed_recv_count = torch.empty(
        (num_local_experts,), dtype=torch.int32, device="cuda",
    )
    def _dispatch():
        # One timed dispatch iteration; all recv-side outputs preallocated.
        return buf.low_latency_dispatch(
            x, topk_idx, num_tokens, num_experts,
            False, False, False, # use_fp8, async, return_recv_hook
            bench_packed_recv_x,
            None, # x_scales (FP8 only)
            bench_packed_recv_src_info,
            bench_packed_recv_layout_range,
            bench_packed_recv_count,
        )
    # Hoist combine's output-tensor allocation out of the timed loop so the
    # measurement reflects the kernel cost. (The original test also cloned the
    # ~58 MB dispatch recv buffer on every iter, adding ~20 us of D2D memcpy
    # to each combine sample and masking kernel-level changes.)
    bench_out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
    def _combine(dout, out_):
        # One timed combine iteration fed from a dispatch result tuple.
        (recv_x, _scales, _cnt, src_info_, layout_range_, _ev, _hk) = dout
        buf.low_latency_combine(
            recv_x, topk_idx, topk_weights,
            src_info_, layout_range_,
            num_tokens, num_experts,
            False, False, False,
            out_,
        )
    for _ in range(warmup):
        _combine(_dispatch(), bench_out)
    torch.cuda.synchronize()
    dist.barrier(group=group)
    start_ev = torch.cuda.Event(enable_timing=True)
    end_ev = torch.cuda.Event(enable_timing=True)
    start_ev.record()
    dout = None
    for _ in range(iters):
        dout = _dispatch()
    end_ev.record()
    torch.cuda.synchronize()
    # elapsed_time is in ms; convert to per-iteration microseconds.
    disp_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
    recv_tokens = int(dout[2].sum().item()) # packed_recv_count summed over local experts
    dist.barrier(group=group)
    start_ev.record()
    for _ in range(iters):
        _combine(dout, bench_out)
    end_ev.record()
    torch.cuda.synchronize()
    comb_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
    # Dispatch payload: recv_tokens × hidden × bf16 (received on this rank).
    # Combine payload: recv_tokens × hidden × bf16 as well -- each local expert
    # sends one copy per dispatched token back to its owner, so the bytes on
    # the wire match dispatch. Using num_tokens × hidden here would under-count
    # the actual send payload by ~num_topk×.
    disp_bytes = recv_tokens * hidden * 2
    comb_bytes = recv_tokens * hidden * 2
    # Reduce timings: report min/avg/max and base BW on AVG to match NCCL-EP's
    # `ep_bench.cu` convention.
    disp_min_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
    disp_avg_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
    disp_max_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
    comb_min_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
    comb_avg_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
    comb_max_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
    dist.all_reduce(disp_min_t, op=dist.ReduceOp.MIN, group=group)
    dist.all_reduce(disp_avg_t, op=dist.ReduceOp.SUM, group=group)
    dist.all_reduce(disp_max_t, op=dist.ReduceOp.MAX, group=group)
    dist.all_reduce(comb_min_t, op=dist.ReduceOp.MIN, group=group)
    dist.all_reduce(comb_avg_t, op=dist.ReduceOp.SUM, group=group)
    dist.all_reduce(comb_max_t, op=dist.ReduceOp.MAX, group=group)
    disp_avg_us = disp_avg_t.item() / num_ranks
    comb_avg_us = comb_avg_t.item() / num_ranks
    disp_bw_per_rank = disp_bytes / (disp_avg_us * 1e-6) / 1e9
    comb_bw_per_rank = comb_bytes / (comb_avg_us * 1e-6) / 1e9
    if rank == 0:
        print(
            f"[bench LL] num_ranks={num_ranks} tokens={num_tokens} hidden={hidden} "
            f"num_experts={num_experts} num_topk={num_topk} warmup={warmup} iters={iters}",
            flush=True,
        )
        print(
            f" dispatch: avg={disp_avg_us:.1f}us min={disp_min_t.item():.1f}us max={disp_max_t.item():.1f}us "
            f"per_rank_bw={disp_bw_per_rank:.2f} GB/s "
            f"agg_bw={disp_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
            flush=True,
        )
        print(
            f" combine : avg={comb_avg_us:.1f}us min={comb_min_t.item():.1f}us max={comb_max_t.item():.1f}us "
            f"per_rank_bw={comb_bw_per_rank:.2f} GB/s "
            f"agg_bw={comb_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
            flush=True,
        )
if __name__ == "__main__":
    try:
        main()
    finally:
        # Ordered shutdown: barrier so every rank reaches teardown before the
        # TCPStore server (rank 0) exits, then destroy the PG. Avoids noisy
        # "recvValue failed / Connection was likely closed" stack traces from
        # ProcessGroupNCCL's HeartbeatMonitor.
        if dist.is_initialized():
            try:
                dist.barrier()
            except Exception:
                # Best-effort: a peer may already have exited; still attempt
                # to destroy the process group below.
                pass
            try:
                dist.destroy_process_group()
            except Exception:
                # Best-effort teardown; suppress so the test's own exit
                # status (from main) is what the launcher sees.
                pass