src/ext/ep: port low-latency dispatch/combine kernels

Port DeepEP's pure-RDMA low-latency (LL) MoE kernels from
csrc/kernels/internode_ll.cu (branch chhwang/dev-atomic-add-cleanup)
into the MSCCL++ EP extension. NVSHMEM / IBGDA device primitives are
replaced with MSCCL++ PortChannelDeviceHandle operations:

  nvshmemx_barrier_all_block()            -> port-channel signal+wait ring
  nvshmemi_ibgda_put_nbi_warp(...)        -> lane-0 PortChannel.put(...)
  nvshmemi_ibgda_amo_nonfetch_add(...)    -> lane-0 PortChannel.atomicAdd(...)

The atomicAdd path relies on the MSCCL++ Connection::atomicAdd /
PortChannelDeviceHandle::atomicAdd API cherry-picked from branch
chhwang/new-atomic-add; the LL dispatch path uses a signed delta
(-num_tokens_sent - 1) which the new int64_t signature supports.

Changes:
* New file src/ext/ep/kernels/internode_ll.cu (~530 lines) with the
  three kernels clean_low_latency_buffer, dispatch<kUseFP8,...>,
  combine<...> plus their launchers. rdma_buffer_ptr is threaded
  through the launchers so the kernel can translate virtual addresses
  into registered-memory offsets expected by MSCCL++.
* kernels/api.cuh: replace the single stub signature with full LL
  launcher prototypes.
* buffer.cc: replace the four LL throw-stubs
  (clean_low_latency_buffer, low_latency_dispatch,
  low_latency_combine, get_next_low_latency_combine_buffer) with
  torch-Tensor implementations ported from DeepEP/csrc/deep_ep.cpp.
* Drop src/ext/ep/internode_stub.cc and its CMake entry.
* python/mscclpp/ext/ep/buffer.py: remove the low_latency_mode=True
  NotImplementedError guard; update docstring.
* test/python/ext/ep/test_ep_smoke.py: rename
  test_low_latency_rejected -> test_low_latency_buffer_construct
  to reflect that LL construction is now accepted.
* src/ext/ep/README.md: update status matrix, document the
  NVSHMEM -> MSCCL++ translation table, and list the known
  limitations.

This is a structural port: the kernels compile, link, and pass the
single-rank smoke tests, but end-to-end behaviour on multi-node H100
is not yet validated. Two known caveats:

  1. Performance will NOT match IBGDA because MSCCL++ port channels
     use a CPU proxy; this port is for functional parity, not latency.
  2. Buffer::sync() in LL mode only connects peers that share the
     same local GPU id (DeepEP convention), so the LL kernels assume
     a one-GPU-per-node topology (num_ranks == num_rdma_ranks).
     Multi-GPU-per-node LL layouts will need a follow-up in sync().

Tested:
  cmake --build build -j --target mscclpp_ep_cpp   # builds clean
  pytest test/python/ext/ep/test_ep_smoke.py        # 3 passed
This commit is contained in:
Qinghua Zhou
2026-04-20 21:46:00 +00:00
parent 88425a6771
commit 453160cc06
8 changed files with 843 additions and 68 deletions

View File

@@ -13,8 +13,9 @@ Current status (see ``src/ext/ep/README.md``):
* Intranode (NVLink-only) dispatch and combine are fully ported.
* ``get_dispatch_layout`` is ported.
* Internode HT and low-latency methods raise from C++ — they still need
the NVSHMEM/IBGDA -> MSCCL++ PortChannel migration.
* Internode HT (MSCCL++ PortChannel + MemoryChannel) is ported.
* Internode low-latency kernels are ported structurally (NVSHMEM/IBGDA ->
MSCCL++ PortChannel) but **untested on multi-node H100**.
"""
from __future__ import annotations
@@ -49,10 +50,11 @@ class Buffer:
num_nvl_bytes:
Size of the NVLink-accessible scratch buffer (shared via CUDA IPC).
num_rdma_bytes:
Size of the RDMA scratch buffer. Must be 0 until internode/LL
support is landed.
Size of the RDMA scratch buffer. Required (>0) for internode HT and
low-latency modes.
low_latency_mode:
Reserved — must be ``False`` until the LL path is ported.
Enable the low-latency dispatch/combine path (structural port,
untested).
num_qps_per_rank:
Ignored for intranode mode.
"""
@@ -68,13 +70,6 @@ class Buffer:
low_latency_mode: bool = False,
num_qps_per_rank: int = 12,
) -> None:
if low_latency_mode:
raise NotImplementedError(
"mscclpp.ext.ep.Buffer: low-latency mode is not yet ported. "
"Set low_latency_mode=False. See src/ext/ep/README.md for the "
"migration plan."
)
self.rank: int = group.rank()
self.group_size: int = group.size()
self.group = group