Commit Graph

3 Commits

Author SHA1 Message Date
Qinghua Zhou
a6af3a4454 ext/ep: fix multi-rank intranode dispatch+combine
Three issues blocked end-to-end intranode validation across multiple
ranks. This commit fixes them and adds a 2/4/8-rank functional test.

1. Combine receiver: OOB __shared__ read

   In the combine receiver warp, the wait loop evaluated
   `channel_tail_idx[recv_lane_id] <= expected_head` before the
   `expected_head >= 0` guard. `channel_tail_idx` is a shared array
   of size `kNumRanks`, but the loop runs on all 32 lanes of a warp,
   so lanes with `recv_lane_id >= kNumRanks` indexed out of bounds.
   compute-sanitizer reported "Invalid __shared__ read of size 4
   bytes" at combine<bf16,2,768>+0xdd0, surfaced asynchronously as
   cudaErrorIllegalAddress at the kernel launch site. Swap the
   operands so the rank-bounds check short-circuits the shared read.

2. Python bindings: UniqueId ABI

   `mscclpp::UniqueId` is a `std::array<uint8_t, N>` which pybind11
   auto-converts to a Python `list`, silently overriding any
   `py::class_<UniqueId>` wrapper. Expose `create_unique_id` /
   `connect` as lambdas that produce/consume `py::bytes` and memcpy
   into a local `UniqueId`. Also coerce `bytes`->`bytearray` at the
   Python call site for `sync()` whose signature expects
   `pybind11::bytearray`.

3. Python frontend: communicator required for NVL-only sync

   `Buffer::sync()` uses `communicator->connect(ipc_config, ...)` on
   the pure-NVLink path, so the communicator must be initialized
   even when `num_rdma_ranks == 1` and `low_latency_mode == False`.
   Always broadcast the unique id and call `runtime.connect()`
   before `sync()`.

Validation on a single H100x8 node via torchrun:
- 2 ranks: dispatch 195 tokens, combine diff=0
- 4 ranks: dispatch 371 tokens, combine diff=0
- 8 ranks: dispatch 456 tokens, combine diff=0

Test harness added at test/python/ext/ep/test_intranode_multirank.py.
2026-04-21 02:03:55 +00:00
Qinghua Zhou
453160cc06 src/ext/ep: port low-latency dispatch/combine kernels
Port DeepEP's pure-RDMA low-latency (LL) MoE kernels from
csrc/kernels/internode_ll.cu (branch chhwang/dev-atomic-add-cleanup)
into the MSCCL++ EP extension. NVSHMEM / IBGDA device primitives are
replaced with MSCCL++ PortChannelDeviceHandle operations:

  nvshmemx_barrier_all_block()            -> port-channel signal+wait ring
  nvshmemi_ibgda_put_nbi_warp(...)        -> lane-0 PortChannel.put(...)
  nvshmemi_ibgda_amo_nonfetch_add(...)    -> lane-0 PortChannel.atomicAdd(...)

The atomicAdd path relies on the MSCCL++ Connection::atomicAdd /
PortChannelDeviceHandle::atomicAdd API cherry-picked from branch
chhwang/new-atomic-add; the LL dispatch path uses a signed delta
(-num_tokens_sent - 1) which the new int64_t signature supports.

Changes:
* New file src/ext/ep/kernels/internode_ll.cu (~530 lines) with the
  three kernels clean_low_latency_buffer, dispatch<kUseFP8,...>,
  combine<...> plus their launchers. rdma_buffer_ptr is threaded
  through the launchers so the kernel can translate virtual addresses
  into registered-memory offsets expected by MSCCL++.
* kernels/api.cuh: replace the single stub signature with full LL
  launcher prototypes.
* buffer.cc: replace the four LL throw-stubs
  (clean_low_latency_buffer, low_latency_dispatch,
  low_latency_combine, get_next_low_latency_combine_buffer) with
  torch-Tensor implementations ported from DeepEP/csrc/deep_ep.cpp.
* Drop src/ext/ep/internode_stub.cc and its CMake entry.
* python/mscclpp/ext/ep/buffer.py: remove the low_latency_mode=True
  NotImplementedError guard; update docstring.
* test/python/ext/ep/test_ep_smoke.py: rename
  test_low_latency_rejected -> test_low_latency_buffer_construct
  to reflect that LL construction is now accepted.
* src/ext/ep/README.md: update status matrix, document the
  NVSHMEM -> MSCCL++ translation table, and list the known
  limitations.

This is a structural port: the kernels compile, link, and pass the
single-rank smoke tests, but end-to-end behaviour on multi-node H100
is not yet validated. Two known caveats:

  1. Performance will NOT match IBGDA because MSCCL++ port channels
     use a CPU proxy; this port is for functional parity, not latency.
  2. Buffer::sync() in LL mode only connects peers that share the
     same local GPU id (DeepEP convention), so the LL kernels assume
     a one-GPU-per-node topology (num_ranks == num_rdma_ranks).
     Multi-GPU-per-node LL layouts will need a follow-up in sync().

Tested:
  cmake --build build -j --target mscclpp_ep_cpp   # builds clean
  pytest test/python/ext/ep/test_ep_smoke.py        # 3 passed
2026-04-20 21:46:00 +00:00
Qinghua Zhou
88425a6771 Add Expert-Parallel (MoE dispatch/combine) extension under src/ext/ep
Port DeepEP's high-throughput MoE dispatch/combine kernels onto MSCCL++
as an optional build target `mscclpp_ep_cpp`, gated by -DMSCCLPP_BUILD_EXT_EP
(OFF by default). Sources are lifted from DeepEP branch
`chhwang/dev-atomic-add-cleanup` and rebased onto upstream MSCCL++ APIs;
the NVSHMEM / IBGDA dependencies are replaced with `PortChannel` +
`MemoryChannel` + the new `Connection::atomicAdd` primitive.

Scope
-----
Intranode (NVLink-only):
  * `Buffer` ctor/dtor: cudaMalloc nvl workspace, export IPC handle,
    allocate FIFO + peer-pointer tables, start `ProxyService`.
  * `sync()`: import peer IPC handles, upload peer pointer table,
    build `MemoryDevice2DeviceSemaphore` + `MemoryChannel` per peer.
  * `get_dispatch_layout`, `intranode_dispatch`, `intranode_combine`
    ported verbatim (torch::Tensor ABI preserved).

Internode HT (NVLink + RDMA):
  * `sync()` RDMA branch: cudaMalloc RDMA buffer + `bootstrap->barrier()`
    (replacing NVSHMEM symmetric-heap allocation); register with
    `all_transport`, exchange via `sendMemory`/`recvMemory`, build 12 IB
    QPs/peer + 16 semaphores/peer + 16 port channels/peer.
  * Full `internode.cu` port (notify_dispatch / dispatch / cached_notify
    / combine / get_dispatch_layout). The 4 raw `ChannelTrigger` atomic
    sites are rewritten to call the new
    `PortChannelDeviceHandle::atomicAdd(offset, value)` API; the single
    `nvshmem_fence()` is replaced with `__threadfence_system()` (remote
    visibility guaranteed by the subsequent port-channel barrier).
  * `internode_dispatch` / `internode_combine` host code ported, with
    the torch tensor marshalling and CPU spin-wait on mapped counters.

Low-latency (pure RDMA):
  * Not ported. `low_latency_dispatch`, `low_latency_combine`,
    `clean_low_latency_buffer`, `get_next_low_latency_combine_buffer`
    throw `std::runtime_error`; the Python frontend refuses to
    construct a Buffer with `low_latency_mode=True`.

Python layer
------------
* New pybind11 + libtorch Python extension `mscclpp_ep_cpp` (separate
  from the nanobind `_mscclpp` because the EP ABI carries
  `torch::Tensor` / `at::cuda::CUDAStream`).
* `mscclpp.ext.ep.Buffer` mirrors `deep_ep.Buffer`; exchanges device
  IDs, IPC handles and the bootstrap UniqueId over the user's
  `torch.distributed` process group before calling `sync()`.
* `mscclpp.ext` auto-imports `ep` if the extension is built.

Build
-----
* `src/ext/ep/CMakeLists.txt`: finds Python + Torch; warns and skips if
  `CMAKE_PREFIX_PATH` doesn't point at `torch.utils.cmake_prefix_path`.
  Falls back to Torch's bundled pybind11 if a standalone pybind11 is not
  installed. Links `libtorch_python` explicitly (without it, `import
  mscclpp_ep_cpp` fails with `undefined symbol: THPDtypeType`).
* Top-level `CMakeLists.txt` exposes the `MSCCLPP_BUILD_EXT_EP` option
  (default OFF).

Tests
-----
* `test/python/ext/ep/test_ep_smoke.py`: skipped if the extension isn't
  built. Covers Config round-trip, low-latency size hint, and the LL
  construction guard. Multi-rank functional tests still to do on H100.

Notes
-----
* Builds against the preceding "atomic add" commit which adds
  `Connection::atomicAdd` and `PortChannelDeviceHandle::atomicAdd` to
  upstream MSCCL++.
* Intranode path verified end-to-end (build + import + smoke tests).
* Internode HT is code-complete but requires real IB hardware to
  validate; see `src/ext/ep/README.md` for the detailed port plan and
  remaining LL migration.
2026-04-20 20:15:23 +00:00