diff --git a/python/mscclpp/ext/ep/buffer.py b/python/mscclpp/ext/ep/buffer.py
index 4a8ec011..238a95ef 100644
--- a/python/mscclpp/ext/ep/buffer.py
+++ b/python/mscclpp/ext/ep/buffer.py
@@ -90,22 +90,23 @@ class Buffer:
         dist.all_gather_object(ipc_handles, local_ipc_handle, group)
 
         root_unique_id: Optional[bytes] = None
-        # RDMA path is guarded above; still plumb the unique-id exchange so
-        # the code is ready to turn on once internode lands.
-        if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
-            if num_qps_per_rank <= 0:
-                raise ValueError("num_qps_per_rank must be > 0 for RDMA")
+        # MSCCL++ requires a bootstrapped Communicator even for pure-NVLink
+        # setups because `Buffer::sync()` uses `communicator->connect(ipc)`
+        # to build MemoryChannels. We always exchange a unique id.
+        if num_qps_per_rank <= 0:
+            raise ValueError("num_qps_per_rank must be > 0")
 
-            if self.rank == 0:
-                unique_id = self.runtime.create_unique_id()
-                root_unique_id = unique_id.bytes()
-            broadcast_list = [root_unique_id]
-            dist.broadcast_object_list(broadcast_list, src=0, group=group)
-            root_unique_id = broadcast_list[0]
-            assert root_unique_id is not None
-            self.runtime.connect(_cpp.UniqueId.from_bytes(root_unique_id))
+        if self.rank == 0:
+            root_unique_id = self.runtime.create_unique_id()
+        broadcast_list = [root_unique_id]
+        dist.broadcast_object_list(broadcast_list, src=0, group=group)
+        root_unique_id = broadcast_list[0]
+        assert root_unique_id is not None
+        self.runtime.connect(root_unique_id)
 
-        self.runtime.sync(device_ids, ipc_handles, root_unique_id)
+        # sync() takes Sequence[Optional[bytearray]] / Optional[bytearray], not bytes.
+        ipc_handles_ba = [bytearray(h) if h is not None else None for h in ipc_handles]
+        self.runtime.sync(device_ids, ipc_handles_ba, bytearray(root_unique_id))
 
     # ------------------------------------------------------------------
     # Sanity helpers
diff --git a/src/ext/ep/bindings.cpp b/src/ext/ep/bindings.cpp
index 0a7018c8..e721e3d2 100644
--- a/src/ext/ep/bindings.cpp
+++ b/src/ext/ep/bindings.cpp
@@ -34,23 +34,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       .def(py::init<>())
       .def("current_stream_wait", &mscclpp::ep::EventHandle::current_stream_wait);
 
-  // NOTE: `mscclpp::UniqueId` is the bootstrap id used for connecting the
-  // proxy service. We expose it as an opaque bytes-like object so Python can
-  // all-gather it across the user's process group.
-  py::class_<mscclpp::UniqueId>(m, "UniqueId")
-      .def(py::init<>())
-      .def("bytes", [](const mscclpp::UniqueId& self) {
-        return py::bytes(reinterpret_cast<const char*>(self.data()), self.size());
-      })
-      .def_static("from_bytes", [](py::bytes data) {
-        auto s = std::string(data);
-        mscclpp::UniqueId uid;
-        if (s.size() != uid.size()) {
-          throw std::runtime_error("mscclpp.ep.UniqueId.from_bytes: size mismatch");
-        }
-        std::memcpy(uid.data(), s.data(), s.size());
-        return uid;
-      });
+  // NOTE: `mscclpp::UniqueId` is a `std::array` of bytes that pybind11 would
+  // implicitly convert to a Python list. We therefore avoid exposing it as a
+  // py::class_ and convert to/from `py::bytes` at the binding boundary.
 
   py::class_<mscclpp::ep::Buffer>(m, "Buffer")
       .def(py::init<int, int, int64_t, int64_t, bool>(), py::arg("rank"), py::arg("num_ranks"),
@@ -64,8 +50,21 @@
       .def("get_local_ipc_handle", &mscclpp::ep::Buffer::get_local_ipc_handle)
       .def("get_local_nvshmem_unique_id", &mscclpp::ep::Buffer::get_local_nvshmem_unique_id)
       .def("get_local_buffer_tensor", &mscclpp::ep::Buffer::get_local_buffer_tensor)
-      .def("create_unique_id", &mscclpp::ep::Buffer::create_unique_id)
-      .def("connect", &mscclpp::ep::Buffer::connect)
+      .def("create_unique_id",
+           [](const mscclpp::ep::Buffer& self) {
+             auto uid = self.create_unique_id();
+             return py::bytes(reinterpret_cast<const char*>(uid.data()), uid.size());
+           })
+      .def("connect",
+           [](mscclpp::ep::Buffer& self, py::bytes data) {
+             std::string s = data;
+             mscclpp::UniqueId uid;
+             if (s.size() != uid.size()) {
+               throw std::runtime_error("mscclpp_ep_cpp.Buffer.connect: UniqueId size mismatch");
+             }
+             std::memcpy(uid.data(), s.data(), s.size());
+             self.connect(uid);
+           })
      .def("sync", &mscclpp::ep::Buffer::sync)
      .def("get_dispatch_layout", &mscclpp::ep::Buffer::get_dispatch_layout)
      .def("intranode_dispatch", &mscclpp::ep::Buffer::intranode_dispatch)
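Aside (not part of the patch): a minimal sketch of the contract the bytes-based
binding gives the Python side, where `runtime`, `rank`, and `group` are
stand-ins for the objects in buffer.py above. Because the unique id crosses the
boundary as plain `bytes`, torch.distributed's object collectives can pickle it
directly, with no custom type registration:

    import torch.distributed as dist

    # Rank 0 mints the bootstrap id; it arrives from C++ as plain `bytes`.
    root_unique_id = runtime.create_unique_id() if rank == 0 else None
    box = [root_unique_id]
    # broadcast_object_list pickles the bytes object as-is.
    dist.broadcast_object_list(box, src=0, group=group)
    # The binding copies the bytes back into a mscclpp::UniqueId and connects.
    runtime.connect(box[0])
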
diff --git a/src/ext/ep/kernels/intranode_kernel.cu b/src/ext/ep/kernels/intranode_kernel.cu
index b3f23442..b0960662 100644
--- a/src/ext/ep/kernels/intranode_kernel.cu
+++ b/src/ext/ep/kernels/intranode_kernel.cu
@@ -720,7 +720,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
         expected_head = ld_nc_global(send_head + token_idx * kNumRanks + recv_lane_id);
 
       auto start_time = clock64();
-      while (channel_tail_idx[recv_lane_id] <= expected_head and expected_head >= 0) {
+      while (expected_head >= 0 and channel_tail_idx[recv_lane_id] <= expected_head) {
         // Timeout check
         if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
           printf("DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\n", rank, responsible_channel, expected_head);
diff --git a/test/python/ext/ep/test_intranode_multirank.py b/test/python/ext/ep/test_intranode_multirank.py
new file mode 100644
index 00000000..493d5689
--- /dev/null
+++ b/test/python/ext/ep/test_intranode_multirank.py
@@ -0,0 +1,179 @@
+"""Multi-rank intranode functional validation for mscclpp_ep.
+
+Launch with:
+    torchrun --nproc_per_node=<num_gpus> test/python/ext/ep/test_intranode_multirank.py
+
+Tests that Buffer::sync() succeeds across N GPUs on a single node and that
+a round-trip dispatch + combine preserves data (sum of top-k weighted copies).
+
+This is a minimal adaptation of DeepEP's tests/test_intranode.py, stripped
+to exercise only the code paths we've ported.
+"""
+
+from __future__ import annotations
+
+import os
+import sys
+
+import torch
+import torch.distributed as dist
+
+
+def init_dist():
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    local_rank = int(os.environ.get("LOCAL_RANK", rank))
+    torch.cuda.set_device(local_rank)
+    dist.init_process_group(
+        backend="nccl",
+        init_method=f"tcp://{os.environ.get('MASTER_ADDR', '127.0.0.1')}:{os.environ.get('MASTER_PORT', '29500')}",
+        world_size=world_size,
+        rank=rank,
+    )
+    return rank, world_size, local_rank, dist.new_group(list(range(world_size)))
+
+
+def inplace_unique(x: torch.Tensor, num_slots: int):
+    assert x.dim() == 2
+    mask = x < 0
+    x_padded = x.masked_fill(mask, num_slots)
+    bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
+    bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))
+    bin_count = bin_count[:, :num_slots]
+    sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)
+    sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
+    sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values
+    x[:, :].fill_(-1)
+    valid_len = min(num_slots, x.size(1))
+    x[:, :valid_len] = sorted_bin_idx[:, :valid_len]
+
+
+def main():
+    rank, num_ranks, local_rank, group = init_dist()
+    from mscclpp.ext import ep
+
+    # Small settings for functional check
+    num_tokens = 128
+    hidden = 1024
+    num_topk = min(4, num_ranks)
+    num_experts = num_ranks * 4
+
+    torch.manual_seed(0xA1B2 + rank)
+
+    # Build topk layout that maps each token to num_topk distinct ranks/experts
+    scores = torch.randn((num_tokens, num_experts), device="cuda", dtype=torch.float32).abs() + 1
+    topk_idx = torch.topk(scores, num_topk, dim=-1, sorted=False).indices
+    topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda")
+
+    rank_idx = topk_idx // (num_experts // num_ranks)
+    rank_idx.masked_fill_(topk_idx == -1, -1)
+    inplace_unique(rank_idx, num_ranks)
+
+    # Expert / rank meta
+    num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
+    for i in range(num_experts):
+        num_tokens_per_expert[i] = (topk_idx == i).sum()
+
+    num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
+    token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device="cuda")
+    for i in range(num_ranks):
+        num_tokens_per_rank[i] = (rank_idx == i).sum()
+        token_sel = (rank_idx == i).max(dim=-1).values
+        cnt = token_sel.sum().item()
+        tokens = torch.sort(token_sel.to(torch.int), descending=True).indices
+        tokens[:cnt] = torch.sort(tokens[:cnt]).values
+        token_idx_in_rank[i][tokens[:cnt]] = torch.arange(cnt, dtype=torch.long, device="cuda")
+    token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
+    is_token_in_rank = token_idx_in_rank >= 0
+
+    # Token payload = rank id (cast to bf16) so we can check correctness
+    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * float(rank)
+
+    # Allocate Buffer (intranode only: num_rdma_bytes=0)
+    cfg = ep.Config(20, 8, 256)
+    num_nvl_bytes = cfg.get_nvl_buffer_size_hint(hidden * x.element_size(), num_ranks)
+    if rank == 0:
+        print(f"[cfg] num_ranks={num_ranks} num_tokens={num_tokens} hidden={hidden} "
+              f"num_experts={num_experts} num_topk={num_topk} num_nvl_bytes={num_nvl_bytes}",
+              flush=True)
+
+    print(f"[rank {rank}] creating Buffer", flush=True)
+    buf = ep.Buffer(group, num_nvl_bytes=num_nvl_bytes, num_rdma_bytes=0, low_latency_mode=False)
+    print(f"[rank {rank}] Buffer created is_available={buf.is_available()}", flush=True)
+    assert buf.is_available()
+
+    # get_dispatch_layout sanity (exact integer/bool comparisons, so torch.equal)
+    ref_rank, _, ref_exp, ref_in_rank, _ = buf.runtime.get_dispatch_layout(topk_idx, num_experts, None, False, False)
+    assert torch.equal(ref_rank, num_tokens_per_rank)
+    assert torch.equal(ref_exp, num_tokens_per_expert)
+    assert torch.equal(ref_in_rank, is_token_in_rank)
+    if rank == 0:
+        print("[layout] OK", flush=True)
+
+    # Dispatch
+    (recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights,
+     num_recv_tokens_per_expert_list,
+     rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx,
+     send_head, _event) = buf.runtime.intranode_dispatch(
+        x, None, topk_idx, topk_weights,
+        num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert,
+        0, None, None,
+        1, cfg, None, False, False,
+    )
+    dist.barrier(group=group)
+
+    # Validate received payloads: for each source rank i, the block of tokens
+    # we received from it should be filled with `i`.
+    assert recv_x.dim() == 2 and recv_x.size(1) == hidden
+    start = 0
+    for src in range(num_ranks):
+        end = rank_prefix_matrix[src][rank].item()
+        block = recv_x[start:end]
+        if block.numel():
+            actual = block.float().amin().item()
+            assert abs(actual - src) < 1e-3, (
+                f"rank{rank}: block from src={src} has min={actual}, expected {src}"
+            )
+            assert abs(block.float().amax().item() - src) < 1e-3
+        start = end
+    if rank == 0:
+        print(f"[dispatch] OK (recv {recv_x.size(0)} tokens)", flush=True)
+
+    # Combine (scatter-reduce back). We pass the dispatched tokens and their
+    # topk weights through unchanged => every source rank should receive its
+    # contribution back, an unweighted sum across its destination ranks.
+    handle_recv_src_idx = recv_src_idx
+    handle_rank_prefix_matrix = rank_prefix_matrix
+    handle_channel_prefix_matrix = recv_channel_prefix_matrix
+
+    combined_x, combined_topk_weights, _ = buf.runtime.intranode_combine(
+        recv_x, recv_topk_weights,
+        handle_recv_src_idx, handle_rank_prefix_matrix, handle_channel_prefix_matrix,
+        send_head, cfg, None, False, False,
+    )
+
+    # Expected: we dispatched with x = rank * ones, so every destination r
+    # received the value `rank` for our tokens. On combine the destinations
+    # send that value back and we sum: combined[t] = rank * (#destinations).
+    num_dst = is_token_in_rank.sum(dim=1).to(torch.float32)
+    expected = num_dst * float(rank)
+
+    got = combined_x.float().mean(dim=1)
+    diff = (got - expected).abs().max().item()
+    max_exp = expected.abs().max().item()
+    if rank == 0:
+        print(f"[combine] max|got-expected|={diff:.4e} max|expected|={max_exp:.4e}", flush=True)
+    assert diff < 1e-2, f"rank{rank}: combine mismatch max diff {diff}"
+
+    dist.barrier(group=group)
+    if rank == 0:
+        print("PASS", flush=True)
+
+
+if __name__ == "__main__":
+    try:
+        main()
+    except Exception:
+        import traceback
+        traceback.print_exc()
+        sys.exit(1)
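Aside (not part of the patch): the combine invariant the test asserts can be
checked standalone. A minimal sketch with made-up sizes (one token, four ranks,
hidden=8), mirroring the tensor names in the test above:

    import torch

    rank, hidden = 2, 8
    # Token 0 is dispatched to ranks 0, 2, and 3 (three destinations).
    is_token_in_rank = torch.tensor([[True, False, True, True]])
    num_dst = is_token_in_rank.sum(dim=1).to(torch.float32)   # tensor([3.])
    # Each destination echoes the payload (rank * ones) back; combine sums the
    # copies, so every element of the combined token equals rank * num_dst.
    combined = num_dst[:, None] * torch.full((1, hidden), float(rank))
    expected = num_dst * float(rank)                          # tensor([6.])
    assert torch.allclose(combined.mean(dim=1), expected)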