mirror of https://github.com/microsoft/mscclpp.git, synced 2026-05-11 17:00:22 +00:00
Add optional out_packed_recv_x / out_src_info / out_layout_range / out_count parameters to Buffer::low_latency_dispatch so callers can hoist the four recv-side allocations out of a hot loop, mirroring the existing out= path on low_latency_combine. The bench in test_low_latency_multirank.py preallocates these tensors once and passes them on every iter so the timed loop reflects kernel cost, not torch.empty + caching-allocator overhead.
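A minimal sketch of the hoisted-allocation pattern (argument order and shapes as exercised by the bench below; the out_* keyword spellings come from the commit message, everything else is copied from the test):

    # Allocate the four recv-side outputs once, outside the hot loop.
    recv_x = torch.empty((num_local_experts, num_ranks * num_tokens, hidden),
                         dtype=torch.bfloat16, device="cuda")
    src_info = torch.empty((num_local_experts, num_ranks * num_tokens),
                           dtype=torch.int32, device="cuda")
    layout_range = torch.empty((num_local_experts, num_ranks),
                               dtype=torch.int64, device="cuda")
    count = torch.empty((num_local_experts,), dtype=torch.int32, device="cuda")

    for _ in range(iters):
        # Reuses the buffers above; no torch.empty per iteration.
        buf.low_latency_dispatch(
            x, topk_idx, num_tokens, num_experts,
            False, False, False,   # use_fp8, async, return_recv_hook
            recv_x, None,          # out_packed_recv_x, x_scales (FP8 only)
            src_info, layout_range, count,
        )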
"""Multi-rank low-latency functional test for mscclpp_ep.
|
||
|
||
Launch with (intra-node, 8 GPUs):
|
||
torchrun --nproc_per_node=8 test/python/ext/ep/test_low_latency_multirank.py
|
||
|
||
Launch with (2 nodes, 1 GPU per node -- DeepEP's recommended LL topology):
|
||
# node 0:
|
||
MASTER_ADDR=<master> MASTER_PORT=29600 NODE_RANK=0 \
|
||
torchrun --nnodes=2 --nproc_per_node=1 --rdzv-backend=c10d \
|
||
--rdzv-endpoint=<master>:29600 test/python/ext/ep/test_low_latency_multirank.py
|
||
# node 1:
|
||
MASTER_ADDR=<master> MASTER_PORT=29600 NODE_RANK=1 \
|
||
torchrun --nnodes=2 --nproc_per_node=1 --rdzv-backend=c10d \
|
||
--rdzv-endpoint=<master>:29600 test/python/ext/ep/test_low_latency_multirank.py
|
||
|
||
Exercises the LL dispatch + combine round-trip on a single node. The
|
||
minimal correctness check:
|
||
- dispatch: per-expert received token counts agree with an all-gathered
|
||
reference computed from topk_idx across all ranks;
|
||
- combine: the reconstructed x matches the analytical sum
|
||
``x * sum(topk_weights, masked by topk_idx == -1)``.
|
||
|
||
Known limitation (see src/ext/ep/README.md): the LL kernels drive every
|
||
peer via MSCCL++ PortChannel. Intra-node IB loopback between two HCAs on
|
||
the same host (what an 8-GPU single-node launch exercises) currently hangs
|
||
during dispatch; cross-node LL with one GPU per node works as designed.
|
||
|
||
Adapted from DeepEP/tests/test_low_latency.py stripped to the bare checks
|
||
we need for an LL port smoke test. BF16-only (no FP8 check).
|
||
"""
|
||
|
||
from __future__ import annotations

import os
import random

# Disable ProcessGroupNCCL's HeartbeatMonitor before importing torch.distributed.
# It runs in a background thread polling the TCPStore; under mpirun, rank 0
# (the store server) can exit before non-zero ranks finish teardown, producing
# noisy 'recvValue failed / Connection was likely closed' stack traces.
os.environ.setdefault("TORCH_NCCL_ENABLE_MONITORING", "0")

import torch
import torch.distributed as dist


def init_dist():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{os.environ.get('MASTER_ADDR', '127.0.0.1')}:{os.environ.get('MASTER_PORT', '29500')}",
        world_size=world_size,
        rank=rank,
    )
    return rank, world_size, local_rank, dist.new_group(list(range(world_size)))


def main():
    rank, num_ranks, local_rank, group = init_dist()
    from mscclpp.ext import ep

    # Shrink the "bf16 precision" anchor to keep values small.
    rank_offset = 128
    assert num_ranks - rank_offset < 257, "too many ranks for bf16 precision anchor"
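    # bf16 keeps an 8-bit significand, so integers in [-256, 256] are exact;
    # anchoring at rank - 128 keeps every encoded rank value in that range
    # (hence the < 257 bound above).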

    num_tokens = int(os.environ.get("MSCCLPP_EP_LL_TOKENS", "64"))
    hidden = int(os.environ.get("MSCCLPP_EP_LL_HIDDEN", "7168"))  # LL kernels are compiled for a fixed set; see SWITCH_HIDDEN
    num_topk = int(os.environ.get("MSCCLPP_EP_LL_TOPK", "4"))
    num_experts_per_rank = int(os.environ.get("MSCCLPP_EP_LL_EXPERTS_PER_RANK", "4"))
    num_experts = num_ranks * num_experts_per_rank
    assert num_experts % num_ranks == 0
    num_local_experts = num_experts // num_ranks
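    # All of the above are env-overridable, e.g. (hypothetical values):
    #   MSCCLPP_EP_LL_TOKENS=128 MSCCLPP_EP_LL_TOPK=8 torchrun ...
    # hidden must stay within the compiled SWITCH_HIDDEN set.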

    torch.manual_seed(0xB3C4 + rank)
    random.seed(0xB3C4 + rank)

    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * (rank - rank_offset)
    # Encode the per-token index into the last 128 elements so the receiver
    # can verify which source token it is looking at.
    x[:, -128:] = (
        torch.arange(num_tokens, device="cuda").to(torch.bfloat16).view(-1, 1)
    )
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs() + 1
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device="cuda").abs()

    # Randomly mask some positions
    for _ in range(min(10, num_tokens)):
        topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1

    num_rdma_bytes = ep.Buffer.get_low_latency_rdma_size_hint(
        num_tokens, hidden, num_ranks, num_experts
    )
    if rank == 0:
        print(
            f"[cfg] num_ranks={num_ranks} num_tokens={num_tokens} hidden={hidden} "
            f"num_experts={num_experts} num_topk={num_topk} "
            f"num_rdma_bytes={num_rdma_bytes}",
            flush=True,
        )

    buf = ep.Buffer(
        group,
        num_nvl_bytes=0,
        num_rdma_bytes=num_rdma_bytes,
        low_latency_mode=True,
        num_qps_per_rank=max(1, num_experts // num_ranks),
    )
    print(
        f"[rank {rank}] Buffer created is_available={buf.is_available()} "
        f"is_internode={buf.is_internode_available()}",
        flush=True,
    )
    assert buf.is_available()

    dist.barrier(group=group)
    torch.cuda.synchronize()
    print(f"[rank {rank}] pre-dispatch", flush=True)

    # --- Dispatch ---
    # Return tuple (7 items):
    #   packed_recv_x, packed_recv_x_scales (optional, FP8-only),
    #   packed_recv_count, packed_recv_src_info, packed_recv_layout_range,
    #   event, hook
    (
        packed_recv_x, _packed_recv_x_scales,
        packed_recv_count, packed_recv_src_info, packed_recv_layout_range,
        _event, recv_hook,
    ) = buf.low_latency_dispatch(
        x, topk_idx, num_tokens, num_experts,
        False, False, True,  # use_fp8, async, return_recv_hook
    )
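    # With return_recv_hook=True, the call above only launches the send
    # phase; no data lands until recv_hook() runs the recv phase. The
    # barrier below keeps ranks from entering recv before peers have sent.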
    # Send phase launched on compute_stream; wait for local launch.
    torch.cuda.synchronize()
    dist.barrier(group=group)
    print(f"[rank {rank}] dispatch-send done, calling hook", flush=True)
    recv_hook()  # Recv phase.
    torch.cuda.synchronize()
    print(f"[rank {rank}] post-dispatch", flush=True)
    handle = (packed_recv_src_info, packed_recv_layout_range)
    # packed_recv_x: [num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]
    # packed_recv_count: [num_local_experts] int32

    # Reference: gather all ranks' topk_idx and count expected tokens per expert.
    all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device="cuda")
    dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)

    int_mask = (1 << 32) - 1
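    # layout_range packs one int64 per (local_expert, src_rank): the low 32
    # bits hold that rank's token count (summed below against recv_count);
    # the high 32 bits carry the begin offset into packed_recv_x.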
    for i in range(num_local_experts):
        expert_id = rank * num_local_experts + i
        recv_count = int(packed_recv_count[i].item())
        expected_count = int((all_topk_idx == expert_id).sum().item())
        recv_layout_range = handle[1][i]
        layout_sum = int((recv_layout_range & int_mask).sum().item())
        assert recv_count == expected_count, (
            f"rank{rank} expert{expert_id}: recv_count={recv_count} != expected={expected_count}"
        )
        assert layout_sum == recv_count, (
            f"rank{rank} expert{expert_id}: layout range sum {layout_sum} != recv_count {recv_count}"
        )

        if recv_count:
            recv_x = packed_recv_x[i, :recv_count]
            # All columns except the last 128 should share the value (src_rank - rank_offset)
            recv_x_lo = recv_x[:, :-128]
            amin = recv_x_lo.amin(dim=-1)
            amax = recv_x_lo.amax(dim=-1)
            assert torch.equal(amin, amax), f"rank{rank} expert{expert_id}: non-uniform recv block"

    if rank == 0:
        print(f"[dispatch] OK (ranks={num_ranks})", flush=True)

    # --- Combine ---
    # Simulate the downstream GEMM output = identity (bf16 copy) so combine
    # returns sum(x * weight) across experts.
    simulated_gemm_x = packed_recv_x.clone()
    out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
    # Signature: (x, topk_idx, topk_weights, src_info, layout_range,
    #             num_max_dispatch_tokens_per_rank, num_experts,
    #             zero_copy, async, return_recv_hook, out)
    src_info, layout_range = handle[0], handle[1]
    combined_x, _event, _hook = buf.low_latency_combine(
        simulated_gemm_x, topk_idx, topk_weights,
        src_info, layout_range,
        num_tokens, num_experts,
        False, False, False,  # zero_copy, async, return_recv_hook
        out,
    )
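    # combined_x is written into the preallocated `out` (per the out= path the
    # commit message mirrors, the returned tensor and `out` share storage).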

    # Analytical expected: for each token i, a weighted sum over the topk
    # entries that are not -1. Every expert returns the original x[i] (the
    # simulated gemm is the identity), so the combine output should be
    #   x[i] * sum(topk_weights[i, j] for j where topk_idx[i, j] != -1).
    weight_sum = topk_weights.masked_fill(topk_idx == -1, 0.0).sum(dim=1).view(-1, 1)
    expected = (x.float() * weight_sum).to(torch.bfloat16)
    diff = (combined_x.float() - expected.float()).abs().max().item()
    max_exp = expected.float().abs().max().item()
    print(
        f"[combine r{rank}] max|got-expected|={diff:.4e} max|expected|={max_exp:.4e}",
        flush=True,
    )
    assert not torch.isnan(combined_x).any().item()
    assert diff < 1e-2, f"rank{rank}: LL combine mismatch diff={diff}"

    dist.barrier(group=group)
    if rank == 0:
        print("PASS", flush=True)

    # ------------------------------------------------------------------
    # Optional benchmark (enable with MSCCLPP_EP_BENCH=1). Times dispatch
    # and combine separately, reporting per-iter latency (min/avg/max
    # across ranks) and aggregate effective bandwidth (sum across ranks).
    # ------------------------------------------------------------------
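    # Example launch (hypothetical values; knob names from the reads below):
    #   MSCCLPP_EP_BENCH=1 MSCCLPP_EP_BENCH_ITERS=50 \
    #       torchrun --nproc_per_node=8 test/python/ext/ep/test_low_latency_multirank.py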
    if os.environ.get("MSCCLPP_EP_BENCH", "0") != "1":
        return

    warmup = int(os.environ.get("MSCCLPP_EP_BENCH_WARMUP", "5"))
    iters = int(os.environ.get("MSCCLPP_EP_BENCH_ITERS", "20"))

    # Hoist dispatch's output tensors out of the timed loop. The largest
    # (`packed_recv_x`, ~58 MB at 7K hidden) costs ~10us cumulative across
    # the four torch::empty calls per iter; reusing them brings the bench
    # in line with NCCL-EP `ep_bench` which preallocates output buffers.
    num_local_experts = num_experts // num_ranks
    bench_packed_recv_x = torch.empty(
        (num_local_experts, num_ranks * num_tokens, hidden),
        dtype=torch.bfloat16, device="cuda",
    )
    bench_packed_recv_src_info = torch.empty(
        (num_local_experts, num_ranks * num_tokens),
        dtype=torch.int32, device="cuda",
    )
    bench_packed_recv_layout_range = torch.empty(
        (num_local_experts, num_ranks), dtype=torch.int64, device="cuda",
    )
    bench_packed_recv_count = torch.empty(
        (num_local_experts,), dtype=torch.int32, device="cuda",
    )

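    # The preallocated tensors ride the new out_* parameters of
    # low_latency_dispatch (passed positionally after return_recv_hook), so
    # the timed loop performs no recv-side allocations.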
    def _dispatch():
        return buf.low_latency_dispatch(
            x, topk_idx, num_tokens, num_experts,
            False, False, False,  # use_fp8, async, return_recv_hook
            bench_packed_recv_x,
            None,  # x_scales (FP8 only)
            bench_packed_recv_src_info,
            bench_packed_recv_layout_range,
            bench_packed_recv_count,
        )

    # Hoist combine's output-tensor allocation out of the timed loop so the
    # measurement reflects the kernel cost. (The original test also cloned the
    # ~58 MB dispatch recv buffer on every iter, adding ~20 us of D2D memcpy
    # to each combine sample and masking kernel-level changes.)
    bench_out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")

    def _combine(dout, out_):
        (recv_x, _scales, _cnt, src_info_, layout_range_, _ev, _hk) = dout
        buf.low_latency_combine(
            recv_x, topk_idx, topk_weights,
            src_info_, layout_range_,
            num_tokens, num_experts,
            False, False, False,  # zero_copy, async, return_recv_hook
            out_,
        )

    for _ in range(warmup):
        _combine(_dispatch(), bench_out)
    torch.cuda.synchronize()
    dist.barrier(group=group)

    start_ev = torch.cuda.Event(enable_timing=True)
    end_ev = torch.cuda.Event(enable_timing=True)
    start_ev.record()
    dout = None
    for _ in range(iters):
        dout = _dispatch()
    end_ev.record()
    torch.cuda.synchronize()
    disp_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
    recv_tokens = int(dout[2].sum().item())  # packed_recv_count summed over local experts
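    # Sanity scale: recv_tokens is ~num_tokens * num_topk on average -- each
    # token fans out to num_topk experts, spread uniformly across ranks by
    # the random router (minus the masked -1 entries).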

    dist.barrier(group=group)
    start_ev.record()
    for _ in range(iters):
        _combine(dout, bench_out)
    end_ev.record()
    torch.cuda.synchronize()
    comb_us = start_ev.elapsed_time(end_ev) * 1e3 / iters

    # Dispatch payload: recv_tokens × hidden × bf16 (received on this rank).
    # Combine payload: recv_tokens × hidden × bf16 as well -- each local expert
    # sends one copy per dispatched token back to its owner, so the bytes on
    # the wire match dispatch. Using num_tokens × hidden here would under-count
    # the actual send payload by ~num_topk×.
    disp_bytes = recv_tokens * hidden * 2
    comb_bytes = recv_tokens * hidden * 2

    # Reduce timings: report min/avg/max and base BW on AVG to match NCCL-EP's
    # `ep_bench.cu` convention.
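    # (Three copies of each timing because all_reduce reduces in place with a
    # single op; MIN, SUM, and MAX each need their own tensor.)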
    disp_min_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
    disp_avg_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
    disp_max_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
    comb_min_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
    comb_avg_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
    comb_max_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
    dist.all_reduce(disp_min_t, op=dist.ReduceOp.MIN, group=group)
    dist.all_reduce(disp_avg_t, op=dist.ReduceOp.SUM, group=group)
    dist.all_reduce(disp_max_t, op=dist.ReduceOp.MAX, group=group)
    dist.all_reduce(comb_min_t, op=dist.ReduceOp.MIN, group=group)
    dist.all_reduce(comb_avg_t, op=dist.ReduceOp.SUM, group=group)
    dist.all_reduce(comb_max_t, op=dist.ReduceOp.MAX, group=group)
    disp_avg_us = disp_avg_t.item() / num_ranks
    comb_avg_us = comb_avg_t.item() / num_ranks
    disp_bw_per_rank = disp_bytes / (disp_avg_us * 1e-6) / 1e9
    comb_bw_per_rank = comb_bytes / (comb_avg_us * 1e-6) / 1e9
    if rank == 0:
        print(
            f"[bench LL] num_ranks={num_ranks} tokens={num_tokens} hidden={hidden} "
            f"num_experts={num_experts} num_topk={num_topk} warmup={warmup} iters={iters}",
            flush=True,
        )
        print(
            f" dispatch: avg={disp_avg_us:.1f}us min={disp_min_t.item():.1f}us max={disp_max_t.item():.1f}us "
            f"per_rank_bw={disp_bw_per_rank:.2f} GB/s "
            f"agg_bw={disp_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
            flush=True,
        )
        print(
            f" combine : avg={comb_avg_us:.1f}us min={comb_min_t.item():.1f}us max={comb_max_t.item():.1f}us "
            f"per_rank_bw={comb_bw_per_rank:.2f} GB/s "
            f"agg_bw={comb_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
            flush=True,
        )


if __name__ == "__main__":
    try:
        main()
    finally:
        # Ordered shutdown: barrier so every rank reaches teardown before the
        # TCPStore server (rank 0) exits, then destroy the PG. Avoids noisy
        # "recvValue failed / Connection was likely closed" stack traces from
        # ProcessGroupNCCL's HeartbeatMonitor.
        if dist.is_initialized():
            try:
                dist.barrier()
            except Exception:
                pass
            try:
                dist.destroy_process_group()
            except Exception:
                pass