Files
mscclpp/test/python/ext/ep/test_low_latency_multirank.py
Qinghua Zhou 04ebba7563 ext/ep: GPU-initiated IBGDA path for low-latency dispatch/combine
Add a GPU-initiated RDMA WRITE path for the LL dispatch/combine kernels
based on mlx5dv direct verbs, alongside the existing IPC and host-FIFO
PortChannel paths. Selected at runtime via MSCCLPP_EP_USE_IBGDA when
num_rdma_ranks > 1.

Core (src/core, include/mscclpp):
  - New ibgda module (ibgda.{hpp,cc}, ibgda_device.cuh): per-peer mlx5
    QP/MR/CQ setup, device-side WQE writers (write_rdma_wqe,
    write_rdma_write_inl_wqe for 4B/8B), submit_requests / submit_no_db
    ring helpers, and a poller thread for send CQs.
  - ibgda_port_channel_device.{hpp,cuh}: thin port_put() wrapper over
    rdma_write with signal_cqe / ring_db flags so callers can issue
    UNSIGNALED batched WRs and ring the doorbell once at the tail (see
    the sketch after this list).
  - mlx5dv_wrapper: expose extra symbols needed for direct WQE
    construction; minor connection.cc / proxy.cc / port_channel.cc
    plumbing to surface QP / MR handles and rkeys to the EP layer.
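
A minimal sketch of the batching pattern the signal_cqe / ring_db flags
enable (illustrative only: the handle type, argument order, and the helper
name put_batch are assumptions here, not the actual
ibgda_port_channel_device.cuh API):

    #include <cstdint>

    struct GpuQpHandle;  // stand-in for the real per-peer QP handle
    __device__ void port_put(GpuQpHandle& qp, uint64_t dst_off, uint64_t src_off,
                             uint32_t bytes, bool signal_cqe, bool ring_db);

    // One peer QP's share of a dispatch batch: queue every data WR unsignaled
    // with the doorbell deferred, then let the trailing write (the per-QP count
    // on dispatch, the flag on combine) signal the CQE and ring the doorbell
    // once, advancing prod_idx past all queued data WQEs.
    __device__ void put_batch(GpuQpHandle& qp, uint64_t dst, uint64_t src,
                              uint32_t token_bytes, int n,
                              uint64_t count_dst, uint64_t count_src) {
      for (int i = 0; i < n; ++i)
        port_put(qp, dst + i * token_bytes, src + i * token_bytes, token_bytes,
                 /*signal_cqe=*/false, /*ring_db=*/false);
      port_put(qp, count_dst, count_src, sizeof(uint32_t),
               /*signal_cqe=*/true, /*ring_db=*/true);
    }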

EP layer (src/ext/ep):
  - ibgda_setup.{hpp,cc}: build per-(local_expert, peer_rank) GpuQp
    handles, exchange remote MR addr/rkey via the bootstrap, own the
    CQ poller. h.dst is set to the per-peer remote_mrs index.
  - buffer.{hpp,cc}: gate IBGDA path with use_ibgda_path_ &&
    ibgda_setup_ != nullptr && !use_ipc; pass device_handles to the
    kernel launchers.
  - kernels/internode_ll.cu: 3-way DISPATCH_LAUNCH_CASE /
    COMBINE_LAUNCH_CASE (IPC / IBGDA / port-FIFO), templated on
    kIbgdaPath. Data PUTs are issued UNSIGNALED with ring_db=false;
    the trailing per-QP count write (dispatch) and flag write
    (combine) keep the defaults so each QP gets a single signaled
    WR that advances prod_idx past all queued data WRs and rings
    the doorbell once.
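
The 3-way selection has roughly this shape (illustrative only: apart from
kIbgdaPath, the enum, kernel, and launcher names below are made up; the real
DISPATCH_LAUNCH_CASE / COMBINE_LAUNCH_CASE macros in internode_ll.cu carry
the full argument lists):

    #include <cuda_runtime.h>

    enum class LlPath { Ipc, Ibgda, PortFifo };

    template <bool kIbgdaPath, bool kUseIpc>
    __global__ void dispatch_ll(/* buffers, handles, counts elided */);

    inline void launch_dispatch(LlPath path, dim3 grid, dim3 block, cudaStream_t stream) {
      if (path == LlPath::Ibgda)          // GPU-initiated RDMA WRITE path
        dispatch_ll<true, false><<<grid, block, 0, stream>>>();
      else if (path == LlPath::Ipc)       // same-node peers over IPC
        dispatch_ll<false, true><<<grid, block, 0, stream>>>();
      else                                // host-FIFO PortChannel fallback
        dispatch_ll<false, false><<<grid, block, 0, stream>>>();
    }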

Test (test/python/ext/ep): extend test_low_latency_multirank.py with
env-driven config knobs (MSCCLPP_EP_LL_TOKENS / _HIDDEN / _TOPK /
_EXPERTS_PER_RANK) for sweeping the new path.
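
For example, a 2-node, 1-GPU-per-node sweep of the IBGDA path could be
launched as below on node 0 (assuming the flag is enabled by setting it to 1;
the knob values are illustrative, the test's defaults are TOKENS=64,
HIDDEN=7168, TOPK=4, EXPERTS_PER_RANK=4):

    MSCCLPP_EP_USE_IBGDA=1 MSCCLPP_EP_LL_TOKENS=128 MSCCLPP_EP_LL_TOPK=8 \
    MSCCLPP_EP_LL_EXPERTS_PER_RANK=8 \
    MASTER_ADDR=<master> MASTER_PORT=29600 NODE_RANK=0 \
    torchrun --nnodes=2 --nproc_per_node=1 --rdzv-backend=c10d \
      --rdzv-endpoint=<master>:29600 test/python/ext/ep/test_low_latency_multirank.py
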
2026-05-07 05:14:15 +00:00

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Multi-rank low-latency functional test for mscclpp_ep.
Launch with (intra-node, 8 GPUs):
torchrun --nproc_per_node=8 test/python/ext/ep/test_low_latency_multirank.py
Launch with (2 nodes, 1 GPU per node -- DeepEP's recommended LL topology):
# node 0:
MASTER_ADDR=<master> MASTER_PORT=29600 NODE_RANK=0 \
torchrun --nnodes=2 --nproc_per_node=1 --rdzv-backend=c10d \
--rdzv-endpoint=<master>:29600 test/python/ext/ep/test_low_latency_multirank.py
# node 1:
MASTER_ADDR=<master> MASTER_PORT=29600 NODE_RANK=1 \
torchrun --nnodes=2 --nproc_per_node=1 --rdzv-backend=c10d \
--rdzv-endpoint=<master>:29600 test/python/ext/ep/test_low_latency_multirank.py
Exercises the LL dispatch + combine round-trip on a single node. The
minimal correctness check:
- dispatch: per-expert received token counts agree with an all-gathered
reference computed from topk_idx across all ranks;
- combine: the reconstructed x matches the analytical sum
``x * sum(topk_weights, masked by topk_idx == -1)``.
Known limitation (see src/ext/ep/README.md): the LL kernels drive every
peer via MSCCL++ PortChannel. Intra-node IB loopback between two HCAs on
the same host (what an 8-GPU single-node launch exercises) currently hangs
during dispatch; cross-node LL with one GPU per node works as designed.
Adapted from DeepEP/tests/test_low_latency.py stripped to the bare checks
we need for an LL port smoke test. BF16-only (no FP8 check).
"""
from __future__ import annotations
import os
import random
import sys
# Disable ProcessGroupNCCL's HeartbeatMonitor before importing torch.distributed.
# It runs in a background thread polling the TCPStore; under mpirun, rank 0
# (the store server) can exit before non-zero ranks finish teardown, producing
# noisy 'recvValue failed / Connection was likely closed' stack traces.
os.environ.setdefault("TORCH_NCCL_ENABLE_MONITORING", "0")
import ctypes
import psutil
import torch
import torch.distributed as dist
# Load libnuma for NUMA-aware memory binding (mirrors DeepEP/tests/utils.py).
try:
    _libnuma = ctypes.CDLL("libnuma.so")
    _libnuma.numa_available.restype = ctypes.c_int
    _libnuma.numa_run_on_node.argtypes = [ctypes.c_int]
    _libnuma.numa_set_preferred.argtypes = [ctypes.c_int]
except OSError:
    _libnuma = None
def set_numa_affinity(local_rank: int):
    cores_per_rank = 12
    numa_node = local_rank // 4
    core_start = local_rank * cores_per_rank
    core_end = core_start + cores_per_rank
    p = psutil.Process(os.getpid())
    p.cpu_affinity(list(range(core_start, core_end)))
    print(f"Rank {local_rank} numa node {numa_node} bound to cores {core_start}-{core_end - 1}")
    # Bind memory to NUMA node
    if _libnuma is not None and _libnuma.numa_available() != -1:
        _libnuma.numa_set_preferred(numa_node)
        print(f"Rank {local_rank}: CPU affinity → cores {core_start}-{core_end - 1}, memory NUMA → node {numa_node}")
    else:
        print(f"Rank {local_rank}: libnuma not available")
def init_dist():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    set_numa_affinity(local_rank)
    torch.cuda.set_device(local_rank)
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{os.environ.get('MASTER_ADDR','127.0.0.1')}:{os.environ.get('MASTER_PORT','29500')}",
        world_size=world_size,
        rank=rank,
    )
    return rank, world_size, local_rank, dist.new_group(list(range(world_size)))
def main():
    rank, num_ranks, local_rank, group = init_dist()
    from mscclpp.ext import ep

    # Shrink the "bf16 precision" anchor to keep values small.
    rank_offset = 128
    assert num_ranks - rank_offset < 257, "too many ranks for bf16 precision anchor"

    num_tokens = int(os.environ.get("MSCCLPP_EP_LL_TOKENS", "64"))
    hidden = int(
        os.environ.get("MSCCLPP_EP_LL_HIDDEN", "7168")
    )  # LL kernels are compiled for a fixed set; see SWITCH_HIDDEN
    num_topk = int(os.environ.get("MSCCLPP_EP_LL_TOPK", "4"))
    num_experts_per_rank = int(os.environ.get("MSCCLPP_EP_LL_EXPERTS_PER_RANK", "4"))
    num_experts = num_ranks * num_experts_per_rank
    assert num_experts % num_ranks == 0
    num_local_experts = num_experts // num_ranks

    torch.manual_seed(0xB3C4 + rank)
    random.seed(0xB3C4 + rank)
    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * (rank - rank_offset)
    # Encode the per-token index into the last 128 elements so the receiver
    # can verify which source token it is looking at.
    x[:, -128:] = torch.arange(num_tokens, device="cuda").to(torch.bfloat16).view(-1, 1)
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs() + 1
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device="cuda").abs()
    # Randomly mask some positions
    for _ in range(min(10, num_tokens)):
        topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1

    num_rdma_bytes = ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
    if rank == 0:
        print(
            f"[cfg] num_ranks={num_ranks} num_tokens={num_tokens} hidden={hidden} "
            f"num_experts={num_experts} num_topk={num_topk} "
            f"num_rdma_bytes={num_rdma_bytes}",
            flush=True,
        )
    buf = ep.Buffer(
        group,
        num_nvl_bytes=0,
        num_rdma_bytes=num_rdma_bytes,
        low_latency_mode=True,
        num_qps_per_rank=max(1, num_experts // num_ranks),
    )
    print(
        f"[rank {rank}] Buffer created is_available={buf.is_available()} "
        f"is_internode={buf.is_internode_available()}",
        flush=True,
    )
    assert buf.is_available()
    dist.barrier(group=group)
    torch.cuda.synchronize()
    print(f"[rank {rank}] pre-dispatch", flush=True)

    # --- Dispatch ---
    # Return tuple (7 items):
    #   packed_recv_x, packed_recv_x_scales (optional, FP8-only),
    #   packed_recv_count, packed_recv_src_info, packed_recv_layout_range,
    #   event, hook
    (
        packed_recv_x,
        _packed_recv_x_scales,
        packed_recv_count,
        packed_recv_src_info,
        packed_recv_layout_range,
        _event,
        recv_hook,
    ) = buf.low_latency_dispatch(
        x,
        topk_idx,
        num_tokens,
        num_experts,
        False,
        False,
        True,  # use_fp8, async, return_recv_hook
    )
    # Send phase launched on compute_stream; wait for local launch.
    torch.cuda.synchronize()
    dist.barrier(group=group)
    print(f"[rank {rank}] dispatch-send done, calling hook", flush=True)
    recv_hook()  # Recv phase.
    torch.cuda.synchronize()
    print(f"[rank {rank}] post-dispatch", flush=True)

    handle = (packed_recv_src_info, packed_recv_layout_range)
    # packed_recv_x: [num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]
    # packed_recv_count: [num_local_experts] int32
    # Reference: gather all ranks' topk_idx and count expected tokens per expert.
    all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device="cuda")
    dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
    int_mask = (1 << 32) - 1
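    # The loop below checks, per local expert i (global id rank * num_local_experts + i):
    #   * packed_recv_count[i] equals the number of occurrences of that expert id
    #     in the all-gathered topk_idx, and
    #   * the low 32 bits of each per-source-rank layout_range entry (extracted
    #     with int_mask) sum to the same count.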
    for i in range(num_local_experts):
        expert_id = rank * num_local_experts + i
        recv_count = int(packed_recv_count[i].item())
        expected_count = int((all_topk_idx == expert_id).sum().item())
        recv_layout_range = handle[1][i]
        layout_sum = int((recv_layout_range & int_mask).sum().item())
        assert (
            recv_count == expected_count
        ), f"rank{rank} expert{expert_id}: recv_count={recv_count} != expected={expected_count}"
        assert (
            layout_sum == recv_count
        ), f"rank{rank} expert{expert_id}: layout range sum {layout_sum} != recv_count {recv_count}"
        if recv_count:
            recv_x = packed_recv_x[i, :recv_count]
            # All columns except the last 128 should share the value (src_rank - rank_offset)
            recv_x_lo = recv_x[:, :-128]
            amin = recv_x_lo.amin(dim=-1)
            amax = recv_x_lo.amax(dim=-1)
            assert torch.equal(amin, amax), f"rank{rank} expert{expert_id}: non-uniform recv block"
    if rank == 0:
        print(f"[dispatch] OK (ranks={num_ranks})", flush=True)

    # --- Combine ---
    # Simulate the downstream GEMM output = identity (bf16 copy) so combine
    # returns sum(x * weight) across experts.
    simulated_gemm_x = packed_recv_x.clone()
    out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
    # Signature: (x, topk_idx, topk_weights, src_info, layout_range,
    #             num_max_dispatch_tokens_per_rank, num_experts,
    #             zero_copy, async, return_recv_hook, out)
    src_info, layout_range = handle[0], handle[1]
    combined_x, _event, _hook = buf.low_latency_combine(
        simulated_gemm_x,
        topk_idx,
        topk_weights,
        src_info,
        layout_range,
        num_tokens,
        num_experts,
        False,
        False,
        False,  # zero_copy, async, return_recv_hook
        out,
    )
    # Analytical expected: each token i, weighted sum over topk entries that
    # are not -1. Every expert returns the original x[i] (since simulated
    # gemm is identity), so the combine output should be
    # x[i] * sum(topk_weights[i, j] for j where topk_idx[i,j] != -1).
    weight_sum = topk_weights.masked_fill(topk_idx == -1, 0.0).sum(dim=1).view(-1, 1)
    expected = (x.float() * weight_sum).to(torch.bfloat16)
    diff = (combined_x.float() - expected.float()).abs().max().item()
    max_exp = expected.float().abs().max().item()
    print(
        f"[combine r{rank}] max|got-expected|={diff:.4e} max|expected|={max_exp:.4e}",
        flush=True,
    )
    assert not torch.isnan(combined_x).any().item()
    assert diff < 1e-2, f"rank{rank}: LL combine mismatch diff={diff}"
    dist.barrier(group=group)
    if rank == 0:
        print("PASS", flush=True)
    # ------------------------------------------------------------------
    # Optional benchmark (enable with MSCCLPP_EP_BENCH=1). Times dispatch
    # and combine separately, reporting per-iter latency (min/avg/max across
    # ranks) and aggregate effective bandwidth (sum across ranks).
    # ------------------------------------------------------------------
    if os.environ.get("MSCCLPP_EP_BENCH", "0") != "1":
        return
    warmup = int(os.environ.get("MSCCLPP_EP_BENCH_WARMUP", "5"))
    iters = int(os.environ.get("MSCCLPP_EP_BENCH_ITERS", "20"))

    # Hoist dispatch's output tensors out of the timed loop. The largest
    # (`packed_recv_x`, ~58 MB at 7K hidden) costs ~10us cumulative across
    # the four torch::empty calls per iter; reusing them brings the bench
    # in line with NCCL-EP `ep_bench`, which preallocates output buffers.
    num_local_experts = num_experts // num_ranks
    bench_packed_recv_x = torch.empty(
        (num_local_experts, num_ranks * num_tokens, hidden),
        dtype=torch.bfloat16,
        device="cuda",
    )
    bench_packed_recv_src_info = torch.empty(
        (num_local_experts, num_ranks * num_tokens),
        dtype=torch.int32,
        device="cuda",
    )
    bench_packed_recv_layout_range = torch.empty(
        (num_local_experts, num_ranks),
        dtype=torch.int64,
        device="cuda",
    )
    bench_packed_recv_count = torch.empty(
        (num_local_experts,),
        dtype=torch.int32,
        device="cuda",
    )

    def _dispatch():
        return buf.low_latency_dispatch(
            x,
            topk_idx,
            num_tokens,
            num_experts,
            False,
            False,
            False,  # use_fp8, async, return_recv_hook
            bench_packed_recv_x,
            None,  # x_scales (FP8 only)
            bench_packed_recv_src_info,
            bench_packed_recv_layout_range,
            bench_packed_recv_count,
        )

    # Hoist combine's output-tensor allocation out of the timed loop so the
    # measurement reflects the kernel cost. (The original test also cloned the
    # ~58 MB dispatch recv buffer on every iter, adding ~20 us of D2D memcpy
    # to each combine sample and masking kernel-level changes.)
    bench_out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")

    def _combine(dout, out_):
        recv_x, _scales, _cnt, src_info_, layout_range_, _ev, _hk = dout
        buf.low_latency_combine(
            recv_x,
            topk_idx,
            topk_weights,
            src_info_,
            layout_range_,
            num_tokens,
            num_experts,
            False,
            False,
            False,
            out_,
        )

    for _ in range(warmup):
        _combine(_dispatch(), bench_out)
    torch.cuda.synchronize()
    dist.barrier(group=group)

    start_ev = torch.cuda.Event(enable_timing=True)
    end_ev = torch.cuda.Event(enable_timing=True)
    start_ev.record()
    dout = None
    for _ in range(iters):
        dout = _dispatch()
    end_ev.record()
    torch.cuda.synchronize()
    disp_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
    recv_tokens = int(dout[2].sum().item())  # packed_recv_count summed over local experts

    dist.barrier(group=group)
    start_ev.record()
    for _ in range(iters):
        _combine(dout, bench_out)
    end_ev.record()
    torch.cuda.synchronize()
    comb_us = start_ev.elapsed_time(end_ev) * 1e3 / iters

    # Dispatch payload: recv_tokens × hidden × bf16 (received on this rank).
    # Combine payload: recv_tokens × hidden × bf16 as well -- each local expert
    # sends one copy per dispatched token back to its owner, so the bytes on
    # the wire match dispatch. Using num_tokens × hidden here would under-count
    # the actual send payload by ~num_topk×.
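    # Worked example with the default knobs (illustrative; ignores the handful
    # of masked -1 entries): 64 tokens x topk 4 => ~256 token routings leave
    # each rank and, by symmetry, ~256 land on it, so recv_tokens ~= 256 and
    # disp_bytes ~= 256 * 7168 * 2 B ~= 3.7 MB per rank per iteration.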
    disp_bytes = recv_tokens * hidden * 2
    comb_bytes = recv_tokens * hidden * 2

    # Reduce timings: report min/avg/max and base BW on AVG to match NCCL-EP's
    # `ep_bench.cu` convention.
    disp_min_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
    disp_avg_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
    disp_max_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
    comb_min_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
    comb_avg_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
    comb_max_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
    dist.all_reduce(disp_min_t, op=dist.ReduceOp.MIN, group=group)
    dist.all_reduce(disp_avg_t, op=dist.ReduceOp.SUM, group=group)
    dist.all_reduce(disp_max_t, op=dist.ReduceOp.MAX, group=group)
    dist.all_reduce(comb_min_t, op=dist.ReduceOp.MIN, group=group)
    dist.all_reduce(comb_avg_t, op=dist.ReduceOp.SUM, group=group)
    dist.all_reduce(comb_max_t, op=dist.ReduceOp.MAX, group=group)
    disp_avg_us = disp_avg_t.item() / num_ranks
    comb_avg_us = comb_avg_t.item() / num_ranks
    disp_bw_per_rank = disp_bytes / (disp_avg_us * 1e-6) / 1e9
    comb_bw_per_rank = comb_bytes / (comb_avg_us * 1e-6) / 1e9

    if rank == 0:
        print(
            f"[bench LL] num_ranks={num_ranks} tokens={num_tokens} hidden={hidden} "
            f"num_experts={num_experts} num_topk={num_topk} warmup={warmup} iters={iters}",
            flush=True,
        )
        print(
            f" dispatch: avg={disp_avg_us:.1f}us min={disp_min_t.item():.1f}us max={disp_max_t.item():.1f}us "
            f"per_rank_bw={disp_bw_per_rank:.2f} GB/s "
            f"agg_bw={disp_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
            flush=True,
        )
        print(
            f" combine : avg={comb_avg_us:.1f}us min={comb_min_t.item():.1f}us max={comb_max_t.item():.1f}us "
            f"per_rank_bw={comb_bw_per_rank:.2f} GB/s "
            f"agg_bw={comb_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
            flush=True,
        )
if __name__ == "__main__":
    try:
        main()
    finally:
        # Ordered shutdown: barrier so every rank reaches teardown before the
        # TCPStore server (rank 0) exits, then destroy the PG. Avoids noisy
        # "recvValue failed / Connection was likely closed" stack traces from
        # ProcessGroupNCCL's HeartbeatMonitor.
        if dist.is_initialized():
            try:
                dist.barrier()
            except Exception:
                pass
            try:
                dist.destroy_process_group()
            except Exception:
                pass