diff --git a/src/ext/ep/buffer.cc b/src/ext/ep/buffer.cc
index 1544c93b..a299d582 100644
--- a/src/ext/ep/buffer.cc
+++ b/src/ext/ep/buffer.cc
@@ -241,7 +241,6 @@ void Buffer::sync(const std::vector<int> &device_ids,
   }
   for (int i = 0; i < num_nvl_ranks; ++i) {
     if (i == nvl_rank) continue;
-    auto r = i + rdma_rank * num_nvl_ranks;
     auto sema = std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*communicator, connections[i]);
     memory_channels.emplace_back(sema, remote_mem_futures[i].get(), buffer_mem);
   }
@@ -282,17 +281,21 @@ void Buffer::sync(const std::vector<int> &device_ids,
   memory_ids[rank] = proxy_service->addMemory(local_rdma_buffer_mem);

   // Send local memory to other ranks. If low_latency_mode == true, only send to ranks with the same GPU ID.
+  // Use tag=1 to disambiguate from the NVL phase's tag=0 traffic with same-node peers.
+  constexpr int kRdmaTag = 1;
   for (int r = 0; r < num_ranks; ++r) {
     if (r == rank) continue;
     if (low_latency_mode && ((r % NUM_MAX_NVL_PEERS) != (rank % NUM_MAX_NVL_PEERS))) continue;
-    communicator->sendMemory(local_rdma_buffer_mem, r, 0);
+    communicator->sendMemory(local_rdma_buffer_mem, r, kRdmaTag);
   }

   // Receive remote memory from other ranks.
   for (int r = 0; r < num_ranks; ++r) {
     if (r == rank) continue;
     if (low_latency_mode && ((r % NUM_MAX_NVL_PEERS) != (rank % NUM_MAX_NVL_PEERS))) continue;
-    memory_ids[r] = proxy_service->addMemory(communicator->recvMemory(r, 0).get());
+    auto f = communicator->recvMemory(r, kRdmaTag);
+    auto mem = f.get();
+    memory_ids[r] = proxy_service->addMemory(std::move(mem));
   }

   // Rank -> vector of connections
@@ -301,7 +304,7 @@ void Buffer::sync(const std::vector<int> &device_ids,
   const mscclpp::EndpointConfig ib_cfg(ib_transport);

   // Self connection for local memory (CUDA IPC).
-  connections[rank].emplace_back(communicator->connect(ipc_cfg, rank, 0).get());
+  connections[rank].emplace_back(communicator->connect(ipc_cfg, rank, kRdmaTag).get());

   // Remote IB connections (multi-QP per peer).
   const int num_ib_connections_per_rank = 12;  // #QPs per rank (mirrors DeepEP).
@@ -310,16 +313,21 @@ void Buffer::sync(const std::vector<int> &device_ids,
     std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> futures;
     futures.reserve(num_ib_connections_per_rank);
     for (int i = 0; i < num_ib_connections_per_rank; ++i) {
-      futures.emplace_back(communicator->connect(ib_cfg, r, 0));
+      futures.emplace_back(communicator->connect(ib_cfg, r, kRdmaTag));
     }
     for (auto& f : futures) connections[r].emplace_back(f.get());
   }

-  // Rank -> vector of semaphore IDs
+  // Rank -> vector of semaphore IDs. Iterate peers in sorted rank order so
+  // semaphore pairings between nodes line up deterministically.
   std::unordered_map<int, std::vector<mscclpp::SemaphoreId>> sema_ids;
   const int num_semaphores_per_rank = 16;
   for (int i = 0; i < num_semaphores_per_rank; ++i) {
-    for (auto& [r, conns] : connections) {
+    for (int r = 0; r < num_ranks; ++r) {
+      if (low_latency_mode && ((r % NUM_MAX_NVL_PEERS) != (rank % NUM_MAX_NVL_PEERS))) continue;
+      auto conn_it = connections.find(r);
+      EP_HOST_ASSERT(conn_it != connections.end());
+      auto& conns = conn_it->second;
       auto& conn = conns[i % conns.size()];
       auto sema_id = proxy_service->buildAndAddSemaphore(*communicator, conn);
       sema_ids[r].emplace_back(sema_id);
     }
   }
@@ -327,10 +335,23 @@ void Buffer::sync(const std::vector<int> &device_ids,
   // Create port channels + device handles.
+  //
+  // The kernels index `port_channel_handles[channel_id * num_ranks + peer_rank]`
+  // where peer_rank is a GLOBAL rank in [0..num_ranks). So the outer stride must
+  // be num_ranks with peers in ascending rank order. Iterating `memory_ids` (an
+  // `unordered_map`) yields hash order and would misroute signals, deadlocking.
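+  // Worked example (hypothetical num_ranks = 4): the flat table reads
+  //   [ch0: r0 r1 r2 r3 | ch1: r0 r1 r2 r3 | ...]
+  // so channel 1's handle for peer rank 2 sits at 1 * 4 + 2 = 6. Any other
+  // peer ordering silently pairs kernels with the wrong remote semaphores.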
   const int num_port_channels_per_rank = num_semaphores_per_rank;
   std::vector<mscclpp::PortChannelDeviceHandle> port_channel_handles;
   for (int i = 0; i < num_port_channels_per_rank; ++i) {
-    for (auto& [r, memory_id] : memory_ids) {
+    for (int r = 0; r < num_ranks; ++r) {
+      if (low_latency_mode && ((r % NUM_MAX_NVL_PEERS) != (rank % NUM_MAX_NVL_PEERS))) {
+        // Not connected in LL mode; push a default handle as placeholder.
+        port_channel_handles.emplace_back(mscclpp::PortChannelDeviceHandle{});
+        continue;
+      }
+      auto mem_it = memory_ids.find(r);
+      EP_HOST_ASSERT(mem_it != memory_ids.end());
+      auto memory_id = mem_it->second;
       auto sema_id = sema_ids[r][i % sema_ids[r].size()];
       auto port_channel = proxy_service->portChannel(sema_id, memory_id, memory_ids[rank]);
       port_channels.emplace_back(std::move(port_channel));
diff --git a/test/python/ext/ep/test_internode_multirank.py b/test/python/ext/ep/test_internode_multirank.py
new file mode 100644
index 00000000..9ac9639c
--- /dev/null
+++ b/test/python/ext/ep/test_internode_multirank.py
@@ -0,0 +1,200 @@
+"""Multi-rank internode (HT) functional validation for mscclpp_ep.
+
+Launch on each node with (example: 2 nodes x 8 GPUs = 16 ranks):
+
+    # on master (NODE_RANK=0):
+    MASTER_ADDR=<master_ip> MASTER_PORT=29600 NODE_RANK=0 \
+    torchrun --nnodes=2 --nproc_per_node=8 \
+        --rdzv-backend=c10d --rdzv-endpoint=<master_ip>:29600 \
+        test/python/ext/ep/test_internode_multirank.py
+
+    # on worker (NODE_RANK=1):
+    MASTER_ADDR=<master_ip> MASTER_PORT=29600 NODE_RANK=1 \
+    torchrun --nnodes=2 --nproc_per_node=8 \
+        --rdzv-backend=c10d --rdzv-endpoint=<master_ip>:29600 \
+        test/python/ext/ep/test_internode_multirank.py
+
+Round-trip dispatch + combine using internode HT kernels across nodes.
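+
+On success, rank 0 prints "[layout] OK", "[dispatch] OK (...)" and a "[combine]"
+summary line, then "PASS"; a failing rank prints a traceback and exits non-zero.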
+"""
+
+from __future__ import annotations
+
+import os
+import sys
+
+import torch
+import torch.distributed as dist
+
+
+def init_dist():
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    local_rank = int(os.environ.get("LOCAL_RANK", rank % 8))
+    torch.cuda.set_device(local_rank)
+    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
+    return rank, world_size, local_rank, dist.new_group(list(range(world_size)))
+
+
+def inplace_unique(x: torch.Tensor, num_slots: int):
+    """Deduplicate each row of x in place: keep its unique non-negative values,
+    sorted descending, and pad the remainder with -1."""
+    assert x.dim() == 2
+    mask = x < 0
+    x_padded = x.masked_fill(mask, num_slots)
+    bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
+    bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))
+    bin_count = bin_count[:, :num_slots]
+    sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)
+    sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
+    sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values
+    x[:, :].fill_(-1)
+    valid_len = min(num_slots, x.size(1))
+    x[:, :valid_len] = sorted_bin_idx[:, :valid_len]
+
+
+def main():
+    rank, num_ranks, local_rank, group = init_dist()
+    from mscclpp.ext import ep
+
+    NUM_MAX_NVL_PEERS = 8
+    assert num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS, \
+        f"expected >1 node with 8 GPUs each, got num_ranks={num_ranks}"
+    num_nodes = num_ranks // NUM_MAX_NVL_PEERS
+    num_local_ranks = NUM_MAX_NVL_PEERS
+
+    # Small settings for functional check
+    num_tokens = 128
+    hidden = 1024
+    num_topk = min(4, num_ranks)
+    num_experts = (num_ranks * 4)  # multiple of num_ranks
+
+    torch.manual_seed(0xA1B2 + rank)
+
+    scores = torch.randn((num_tokens, num_experts), device="cuda", dtype=torch.float32).abs() + 1
+    topk_idx = torch.topk(scores, num_topk, dim=-1, sorted=False).indices
+    topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda")
+
+    # Per-token destination ranks and RDMA (node) ranks, deduplicated.
+    rank_idx = topk_idx // (num_experts // num_ranks)
+    rank_idx.masked_fill_(topk_idx == -1, -1)
+    inplace_unique(rank_idx, num_ranks)
+
+    rdma_rank_idx = rank_idx // num_local_ranks
+    rdma_rank_idx.masked_fill_(rank_idx == -1, -1)
+    inplace_unique(rdma_rank_idx, num_nodes)
+
+    num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
+    for i in range(num_experts):
+        num_tokens_per_expert[i] = (topk_idx == i).sum()
+
+    num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
+    num_tokens_per_rdma_rank = torch.empty((num_nodes,), dtype=torch.int, device="cuda")
+    token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device="cuda")
+    for i in range(num_ranks):
+        num_tokens_per_rank[i] = (rank_idx == i).sum()
+        token_sel = (rank_idx == i).max(dim=-1).values
+        cnt = token_sel.sum().item()
+        tokens = torch.sort(token_sel.to(torch.int), descending=True).indices
+        tokens[:cnt] = torch.sort(tokens[:cnt]).values
+        token_idx_in_rank[i][tokens[:cnt]] = torch.arange(cnt, dtype=torch.long, device="cuda")
+    for i in range(num_nodes):
+        num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum()
+    token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
+    is_token_in_rank = token_idx_in_rank >= 0
+
+    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * float(rank)
+
+    # Buffer config for internode HT: needs num_rdma_bytes > 0.
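+    # Positional args are assumed to mirror DeepEP's Config ordering:
+    #   Config(num_sms,
+    #          num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens,
+    #          num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens)
+    # Verify against the ep.Config binding if construction fails.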
+    cfg = ep.Config(20, 8, 256, 16, 128)
+    num_nvl_bytes = cfg.get_nvl_buffer_size_hint(hidden * x.element_size(), num_ranks)
+    num_rdma_bytes = cfg.get_rdma_buffer_size_hint(hidden * x.element_size(), num_ranks)
+    if rank == 0:
+        print(f"[cfg] num_nodes={num_nodes} num_ranks={num_ranks} num_tokens={num_tokens} "
+              f"hidden={hidden} num_experts={num_experts} num_topk={num_topk} "
+              f"num_nvl_bytes={num_nvl_bytes} num_rdma_bytes={num_rdma_bytes}",
+              flush=True)
+
+    print(f"[rank {rank}] creating Buffer", flush=True)
+    buf = ep.Buffer(group, num_nvl_bytes=num_nvl_bytes, num_rdma_bytes=num_rdma_bytes, low_latency_mode=False)
+    print(f"[rank {rank}] Buffer created is_available={buf.is_available()} "
+          f"is_internode={buf.is_internode_available()}", flush=True)
+    assert buf.is_available() and buf.is_internode_available()
+
+    ref_rank, ref_rdma_rank, ref_exp, ref_in_rank, _ = \
+        buf.runtime.get_dispatch_layout(topk_idx, num_experts, None, False, False)
+    assert torch.allclose(ref_rank, num_tokens_per_rank)
+    assert torch.allclose(ref_rdma_rank, num_tokens_per_rdma_rank)
+    assert torch.allclose(ref_exp, num_tokens_per_expert)
+    assert torch.equal(ref_in_rank, is_token_in_rank)  # bool tensors: exact compare
+    if rank == 0:
+        print("[layout] OK", flush=True)
+    dist.barrier(group=group)
+
+    # internode_dispatch signature (non-cached mode):
+    #   (x, x_scales, topk_idx, topk_weights,
+    #    num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
+    #    cached_num_recv_tokens=0, cached_num_rdma_recv_tokens=0,
+    #    cached_rdma_channel_prefix_matrix=None, cached_recv_rdma_rank_prefix_sum=None,
+    #    cached_gbl_channel_prefix_matrix=None, cached_recv_gbl_rank_prefix_sum=None,
+    #    expert_alignment, config, previous_event, async, allocate_on_comm_stream)
+    (recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights,
+     num_recv_tokens_per_expert_list,
+     rdma_channel_prefix_matrix, gbl_channel_prefix_matrix,
+     recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
+     recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
+     recv_src_meta, send_rdma_head, send_nvl_head, _event) = buf.runtime.internode_dispatch(
+        x, None, topk_idx, topk_weights,
+        num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
+        0, 0,
+        None, None, None, None,
+        1, cfg, None, False, False,
+    )
+    dist.barrier(group=group)
+
+    # Validate recv buffer: for each source rank i, the block carries value i.
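+    # recv_gbl_rank_prefix_sum[s] is the cumulative token count over source
+    # ranks 0..s, so the block received from rank `src` occupies rows
+    # [recv_gbl_rank_prefix_sum[src - 1], recv_gbl_rank_prefix_sum[src]).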
+    assert recv_x.dim() == 2 and recv_x.size(1) == hidden
+    start = 0
+    for src in range(num_ranks):
+        end = recv_gbl_rank_prefix_sum[src].item()
+        block = recv_x[start:end]
+        if block.numel():
+            lo = block.float().amin().item()
+            hi = block.float().amax().item()
+            assert abs(lo - src) < 1e-3 and abs(hi - src) < 1e-3, (
+                f"rank{rank}: block from src={src} has range=[{lo}, {hi}], expected {src}"
+            )
+        start = end
+    if rank == 0:
+        print(f"[dispatch] OK (recv {recv_x.size(0)} tokens)", flush=True)
+
+    # internode_combine signature:
+    #   (x, topk_weights,
+    #    src_meta, is_combined_token_in_rank,
+    #    rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
+    #    combined_rdma_head, combined_nvl_head, config, previous_event, async, allocate_on_comm_stream)
+    combined_x, combined_topk_weights, _ = buf.runtime.internode_combine(
+        recv_x, recv_topk_weights,
+        recv_src_meta, is_token_in_rank,
+        rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
+        send_rdma_head, send_nvl_head,
+        cfg, None, False, False,
+    )
+
+    # Each token was dispatched to num_dst ranks and summed on combine, so every
+    # element should come back as num_dst * rank.
+    num_dst = is_token_in_rank.sum(dim=1).to(torch.float32)
+    expected = num_dst * float(rank)
+    got = combined_x.float().mean(dim=1)
+    diff = (got - expected).abs().max().item()
+    max_exp = expected.abs().max().item()
+    if rank == 0:
+        print(f"[combine] max|got-expected|={diff:.4e} max|expected|={max_exp:.4e}", flush=True)
+    assert diff < 1e-2, f"rank{rank}: combine mismatch max diff {diff}"
+
+    dist.barrier(group=group)
+    if rank == 0:
+        print("PASS", flush=True)
+
+
+if __name__ == "__main__":
+    try:
+        main()
+    except Exception:
+        import traceback
+        traceback.print_exc()
+        sys.exit(1)