ext/ep: fix multi-rank intranode dispatch+combine
Three issues blocked end-to-end intranode validation across multiple ranks. This commit fixes them and adds a 2/4/8-rank functional test.

1. Combine receiver: OOB __shared__ read

   In the combine receiver warp, the wait loop evaluated
   `channel_tail_idx[recv_lane_id] <= expected_head` before the
   `expected_head >= 0` guard. `channel_tail_idx` is a shared array of size
   `kNumRanks`, but the loop runs on all 32 lanes of a warp, so lanes with
   `recv_lane_id >= kNumRanks` indexed out of bounds. compute-sanitizer
   reported "Invalid __shared__ read of size 4 bytes" at
   combine<bf16,2,768>+0xdd0, surfaced asynchronously as
   cudaErrorIllegalAddress at the kernel launch site. Swap the operands so
   the rank-bounds check short-circuits the shared read (a sketch of the
   before/after condition follows below).

2. Python bindings: UniqueId ABI

   `mscclpp::UniqueId` is a `std::array<uint8_t, N>`, which pybind11
   auto-converts to a Python `list`, silently overriding any
   `py::class_<UniqueId>` wrapper. Expose `create_unique_id` / `connect` as
   lambdas that produce/consume `py::bytes` and memcpy into a local
   `UniqueId` (see the binding sketch below). Also coerce `bytes` ->
   `bytearray` at the Python call site for `sync()`, whose signature expects
   `pybind11::bytearray`.

3. Python frontend: communicator required for NVL-only sync

   `Buffer::sync()` uses `communicator->connect(ipc_config, ...)` on the
   pure-NVLink path, so the communicator must be initialized even when
   `num_rdma_ranks == 1` and `low_latency_mode == False`. Always broadcast
   the unique id and call `runtime.connect()` before `sync()`.

Validation on a single H100x8 node via torchrun:

- 2 ranks: dispatch 195 tokens, combine diff=0
- 4 ranks: dispatch 371 tokens, combine diff=0
- 8 ranks: dispatch 456 tokens, combine diff=0

Test harness added at test/python/ext/ep/test_intranode_multirank.py.
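For illustration, a standalone CUDA sketch of the operand swap in fix 1. Only `channel_tail_idx`, `recv_lane_id`, `expected_head`, and `kNumRanks` come from the commit message; the kernel shape, the negative-sentinel convention for out-of-range lanes, and the host driver are invented for the demo and are not the actual combine kernel.

    // combine_shortcircuit_sketch.cu -- toy reproduction of the OOB pattern
    #include <cstdio>

    constexpr int kNumRanks = 4;  // shared array size; a warp has 32 lanes

    __global__ void combine_receiver_sketch(const int* tails, int* out) {
        __shared__ int channel_tail_idx[kNumRanks];
        const int recv_lane_id = threadIdx.x;  // one warp: lanes 0..31

        if (recv_lane_id < kNumRanks)
            channel_tail_idx[recv_lane_id] = tails[recv_lane_id];
        __syncwarp();

        // Assumed convention: lanes without a peer rank carry a negative head.
        const int expected_head =
            (recv_lane_id < kNumRanks) ? tails[recv_lane_id] - 1 : -1;

        // BUG (old operand order):
        //   while (channel_tail_idx[recv_lane_id] <= expected_head && expected_head >= 0)
        // every lane performs the shared read, OOB once recv_lane_id >= kNumRanks.
        //
        // FIX (new order): the rank-bounds guard comes first, and &&
        // short-circuits, so out-of-range lanes never touch shared memory.
        while (expected_head >= 0 &&
               channel_tail_idx[recv_lane_id] <= expected_head) {
            // spin until the producer advances the tail (never loops here,
            // since each tail already exceeds its expected_head)
        }

        if (recv_lane_id == 0) out[0] = channel_tail_idx[0];
    }

    int main() {
        int h_tails[kNumRanks] = {3, 5, 2, 7};
        int *d_tails, *d_out;
        cudaMalloc(&d_tails, sizeof(h_tails));
        cudaMalloc(&d_out, sizeof(int));
        cudaMemcpy(d_tails, h_tails, sizeof(h_tails), cudaMemcpyHostToDevice);
        combine_receiver_sketch<<<1, 32>>>(d_tails, d_out);
        cudaDeviceSynchronize();
        int out = 0;
        cudaMemcpy(&out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
        printf("tail[0] = %d\n", out);  // expect 3
        cudaFree(d_tails);
        cudaFree(d_out);
        return 0;
    }

With the buggy order, compute-sanitizer flags the reads from lanes 4..31 immediately; with the guard first, those reads are never issued.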
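Likewise, a minimal sketch of the `py::bytes` round-trip from fix 2, assuming a stand-in `Runtime` class, a hypothetical module name `ep_sketch`, and a 128-byte id (mscclpp defines the real `UniqueId` size); only the lambda-over-bytes pattern mirrors the commit.

    // unique_id_binding_sketch.cpp -- not the actual ext/ep bindings
    #include <pybind11/pybind11.h>
    #include <algorithm>
    #include <array>
    #include <cstdint>
    #include <cstring>
    #include <string>

    namespace py = pybind11;

    using UniqueId = std::array<uint8_t, 128>;  // stand-in for mscclpp::UniqueId

    struct Runtime {  // hypothetical stand-in for the ext/ep runtime object
        UniqueId createUniqueId() { return UniqueId{}; }
        void connect(const UniqueId&) {}
    };

    PYBIND11_MODULE(ep_sketch, m) {
        py::class_<Runtime>(m, "Runtime")
            .def(py::init<>())
            // Return py::bytes: pybind11's STL caster would otherwise turn
            // std::array<uint8_t, N> into a Python list, silently overriding
            // any py::class_<UniqueId> wrapper.
            .def("create_unique_id",
                 [](Runtime& self) {
                     UniqueId id = self.createUniqueId();
                     return py::bytes(
                         reinterpret_cast<const char*>(id.data()), id.size());
                 })
            // Accept py::bytes and memcpy into a local UniqueId on the C++ side.
            .def("connect", [](Runtime& self, py::bytes id_bytes) {
                const std::string buf = static_cast<std::string>(id_bytes);
                UniqueId id{};
                std::memcpy(id.data(), buf.data(), std::min(buf.size(), id.size()));
                self.connect(id);
            });
    }

From Python this round-trips as opaque bytes: uid = rt.create_unique_id(); rt.connect(uid).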
@@ -90,22 +90,23 @@ class Buffer:
         dist.all_gather_object(ipc_handles, local_ipc_handle, group)
 
         root_unique_id: Optional[bytes] = None
-        # RDMA path is guarded above; still plumb the unique-id exchange so
-        # the code is ready to turn on once internode lands.
-        if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
-            if num_qps_per_rank <= 0:
-                raise ValueError("num_qps_per_rank must be > 0 for RDMA")
+        # MSCCL++ requires a bootstrapped Communicator even for pure-NVLink
+        # setups because `Buffer::sync()` uses `communicator->connect(ipc)`
+        # to build MemoryChannels. We always exchange a unique id.
+        if num_qps_per_rank <= 0:
+            raise ValueError("num_qps_per_rank must be > 0")
 
-            if self.rank == 0:
-                unique_id = self.runtime.create_unique_id()
-                root_unique_id = unique_id.bytes()
-            broadcast_list = [root_unique_id]
-            dist.broadcast_object_list(broadcast_list, src=0, group=group)
-            root_unique_id = broadcast_list[0]
-            assert root_unique_id is not None
-            self.runtime.connect(_cpp.UniqueId.from_bytes(root_unique_id))
+        if self.rank == 0:
+            root_unique_id = self.runtime.create_unique_id()
+        broadcast_list = [root_unique_id]
+        dist.broadcast_object_list(broadcast_list, src=0, group=group)
+        root_unique_id = broadcast_list[0]
+        assert root_unique_id is not None
+        self.runtime.connect(root_unique_id)
 
-        self.runtime.sync(device_ids, ipc_handles, root_unique_id)
+        # sync() expects Sequence[bytearray | None] / bytearray | None.
+        ipc_handles_ba = [bytearray(h) if h is not None else None for h in ipc_handles]
+        self.runtime.sync(device_ids, ipc_handles_ba, bytearray(root_unique_id))
 
     # ------------------------------------------------------------------
     # Sanity helpers
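The new harness is driven with torchrun, matching the validation runs above; a plausible 8-rank invocation (the exact flags are an assumption, and the script may take additional arguments):

    torchrun --nproc_per_node=8 test/python/ext/ep/test_intranode_multirank.py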