Mirror of https://github.com/microsoft/mscclpp.git, synced 2026-05-12 01:10:22 +00:00
ext/ep: unfilter LL sync + add LL multirank test (intra-node WIP)
- Buffer::sync no longer drops non-same-GPU-id peers in low_latency_mode. DeepEP's original filter was safe because its LL path used NVSHMEM; this port drives LL via PortChannel, so the kernel indexes port_channel_handles[local_expert*num_ranks + dst_rank] for every dst_rank. Every peer now gets a real memory/connection/semaphore/port-channel entry.
- Add test/python/ext/ep/test_low_latency_multirank.py (LL dispatch + combine functional round-trip, BF16 only). Works cross-node in DeepEP's 1-GPU-per-node topology.
- Known limitation, documented in src/ext/ep/README.md and the test docstring: intra-node 8-GPU LL currently hangs because every peer transfer routes through the CPU proxy over IB loopback between distinct HCAs on the same host, and (separately) CudaIpcConnection::atomicAdd is a 64-bit op that mis-aligns the 32-bit rdma_recv_count slots when used for same-node peers. A proper fix needs a mixed-transport LL variant (MemoryChannel for same-node peers, PortChannel for cross-node peers) or 64-bit counters.
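For illustration only, a minimal Python sketch of the flattened per-(expert, peer) handle table that the unfiltered sync must populate; build_port_channel_handles, handle_for, and make_handle are hypothetical names, and only the index expression port_channel_handles[local_expert*num_ranks + dst_rank] comes from the kernel described above:

# Sketch (hypothetical helper names): one handle per (local_expert, dst_rank) pair,
# stored row-major so the kernel's flat index local_expert * num_ranks + dst_rank
# resolves. This only holds if no dst_rank was filtered out during sync.
def build_port_channel_handles(num_local_experts, num_ranks, make_handle):
    return [make_handle(e, r) for e in range(num_local_experts) for r in range(num_ranks)]

def handle_for(port_channel_handles, local_expert, num_ranks, dst_rank):
    # Same indexing the LL kernel uses for every destination rank.
    return port_channel_handles[local_expert * num_ranks + dst_rank]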
test/python/ext/ep/test_low_latency_multirank.py (new file, 214 lines)
@@ -0,0 +1,214 @@
"""Multi-rank low-latency functional test for mscclpp_ep.

Launch with (intra-node, 8 GPUs):
    torchrun --nproc_per_node=8 test/python/ext/ep/test_low_latency_multirank.py

Launch with (2 nodes, 1 GPU per node -- DeepEP's recommended LL topology):
    # node 0:
    MASTER_ADDR=<master> MASTER_PORT=29600 NODE_RANK=0 \
        torchrun --nnodes=2 --nproc_per_node=1 --rdzv-backend=c10d \
        --rdzv-endpoint=<master>:29600 test/python/ext/ep/test_low_latency_multirank.py
    # node 1:
    MASTER_ADDR=<master> MASTER_PORT=29600 NODE_RANK=1 \
        torchrun --nnodes=2 --nproc_per_node=1 --rdzv-backend=c10d \
        --rdzv-endpoint=<master>:29600 test/python/ext/ep/test_low_latency_multirank.py

Exercises the LL dispatch + combine round-trip. The minimal correctness
checks:
  - dispatch: per-expert received token counts agree with an all-gathered
    reference computed from topk_idx across all ranks;
  - combine: the reconstructed x matches the analytical sum
    ``x * sum(topk_weights, masked by topk_idx == -1)``.

Known limitation (see src/ext/ep/README.md): the LL kernels drive every
peer via MSCCL++ PortChannel. Intra-node IB loopback between two HCAs on
the same host (which an 8-GPU single-node launch exercises) currently hangs
during dispatch; cross-node LL with one GPU per node works as designed.

Adapted from DeepEP/tests/test_low_latency.py, stripped to the bare checks
we need for an LL port smoke test. BF16-only (no FP8 check).
"""
from __future__ import annotations

import os
import random
import sys

import torch
import torch.distributed as dist


def init_dist():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{os.environ.get('MASTER_ADDR','127.0.0.1')}:{os.environ.get('MASTER_PORT','29500')}",
        world_size=world_size,
        rank=rank,
    )
    return rank, world_size, local_rank, dist.new_group(list(range(world_size)))


def main():
    rank, num_ranks, local_rank, group = init_dist()
    from mscclpp.ext import ep

    # Shrink the "bf16 precision" anchor to keep values small.
    rank_offset = 128
    assert num_ranks - rank_offset < 257, "too many ranks for bf16 precision anchor"

    num_tokens = 64
    hidden = 7168  # LL kernels are compiled for a fixed set; see SWITCH_HIDDEN
    num_topk = 4
    num_experts = num_ranks * 4
    assert num_experts % num_ranks == 0
    num_local_experts = num_experts // num_ranks

    torch.manual_seed(0xB3C4 + rank)
    random.seed(0xB3C4 + rank)

    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * (rank - rank_offset)
    # Encode the per-token index into the last 128 elements so the receiver
    # can verify which source token it is looking at.
    x[:, -128:] = (
        torch.arange(num_tokens, device="cuda").to(torch.bfloat16).view(-1, 1)
    )
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs() + 1
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device="cuda").abs()

    # Randomly mask some positions
    for _ in range(min(10, num_tokens)):
        topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1

    num_rdma_bytes = ep.Buffer.get_low_latency_rdma_size_hint(
        num_tokens, hidden, num_ranks, num_experts
    )
    if rank == 0:
        print(
            f"[cfg] num_ranks={num_ranks} num_tokens={num_tokens} hidden={hidden} "
            f"num_experts={num_experts} num_topk={num_topk} "
            f"num_rdma_bytes={num_rdma_bytes}",
            flush=True,
        )

    buf = ep.Buffer(
        group,
        num_nvl_bytes=0,
        num_rdma_bytes=num_rdma_bytes,
        low_latency_mode=True,
        num_qps_per_rank=max(1, num_experts // num_ranks),
    )
    print(
        f"[rank {rank}] Buffer created is_available={buf.is_available()} "
        f"is_internode={buf.is_internode_available()}",
        flush=True,
    )
    assert buf.is_available()

    dist.barrier(group=group)
    torch.cuda.synchronize()
    print(f"[rank {rank}] pre-dispatch", flush=True)

    # --- Dispatch ---
    # Return tuple (7 items):
    #   packed_recv_x, packed_recv_x_scales (optional, FP8-only),
    #   packed_recv_count, packed_recv_src_info, packed_recv_layout_range,
    #   event, hook
    (
        packed_recv_x, _packed_recv_x_scales,
        packed_recv_count, packed_recv_src_info, packed_recv_layout_range,
        _event, recv_hook,
    ) = buf.low_latency_dispatch(
        x, topk_idx, num_tokens, num_experts,
        False, False, True,  # use_fp8, async, return_recv_hook
    )
    # Send phase launched on compute_stream; wait for local launch.
    torch.cuda.synchronize()
    dist.barrier(group=group)
    print(f"[rank {rank}] dispatch-send done, calling hook", flush=True)
    recv_hook()  # Recv phase.
    torch.cuda.synchronize()
    print(f"[rank {rank}] post-dispatch", flush=True)
    handle = (packed_recv_src_info, packed_recv_layout_range)
    # packed_recv_x: [num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]
    # packed_recv_count: [num_local_experts] int32

    # Reference: gather all ranks' topk_idx and count expected tokens per expert.
    all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device="cuda")
    dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)

    int_mask = (1 << 32) - 1
    for i in range(num_local_experts):
        expert_id = rank * num_local_experts + i
        recv_count = int(packed_recv_count[i].item())
        expected_count = int((all_topk_idx == expert_id).sum().item())
        recv_layout_range = handle[1][i]
        layout_sum = int((recv_layout_range & int_mask).sum().item())
        assert recv_count == expected_count, (
            f"rank{rank} expert{expert_id}: recv_count={recv_count} != expected={expected_count}"
        )
        assert layout_sum == recv_count, (
            f"rank{rank} expert{expert_id}: layout range sum {layout_sum} != recv_count {recv_count}"
        )

        if recv_count:
            recv_x = packed_recv_x[i, :recv_count]
            # All columns except the last 128 should share the value (src_rank - rank_offset)
            recv_x_lo = recv_x[:, :-128]
            amin = recv_x_lo.amin(dim=-1)
            amax = recv_x_lo.amax(dim=-1)
            assert torch.equal(amin, amax), f"rank{rank} expert{expert_id}: non-uniform recv block"

    if rank == 0:
        print(f"[dispatch] OK (ranks={num_ranks})", flush=True)

    # --- Combine ---
    # Simulate the downstream GEMM output = identity (bf16 copy) so combine
    # returns sum(x * weight) across experts.
    simulated_gemm_x = packed_recv_x.clone()
    out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
    # Signature: (x, topk_idx, topk_weights, src_info, layout_range,
    #             num_max_dispatch_tokens_per_rank, num_experts,
    #             zero_copy, async, return_recv_hook, out)
    src_info, layout_range = handle[0], handle[1]
    combined_x, _event, _hook = buf.low_latency_combine(
        simulated_gemm_x, topk_idx, topk_weights,
        src_info, layout_range,
        num_tokens, num_experts,
        False, False, False,  # zero_copy, async, return_recv_hook
        out,
    )

    # Analytical expected: each token i, weighted sum over topk entries that
    # are not -1. Every expert returns the original x[i] (since simulated
    # gemm is identity), so the combine output should be
    # x[i] * sum(topk_weights[i, j] for j where topk_idx[i, j] != -1).
    weight_sum = topk_weights.masked_fill(topk_idx == -1, 0.0).sum(dim=1).view(-1, 1)
    expected = (x.float() * weight_sum).to(torch.bfloat16)
    diff = (combined_x.float() - expected.float()).abs().max().item()
    max_exp = expected.float().abs().max().item()
    print(
        f"[combine r{rank}] max|got-expected|={diff:.4e} max|expected|={max_exp:.4e}",
        flush=True,
    )
    assert not torch.isnan(combined_x).any().item()
    assert diff < 1e-2, f"rank{rank}: LL combine mismatch diff={diff}"

    dist.barrier(group=group)
    if rank == 0:
        print("PASS", flush=True)


if __name__ == "__main__":
    try:
        main()
    finally:
        try:
            dist.destroy_process_group()
        except Exception:
            pass