Qinghua Zhou 453160cc06 src/ext/ep: port low-latency dispatch/combine kernels
Port DeepEP's pure-RDMA low-latency (LL) MoE kernels from
csrc/kernels/internode_ll.cu (branch chhwang/dev-atomic-add-cleanup)
into the MSCCL++ EP extension. NVSHMEM / IBGDA device primitives are
replaced with MSCCL++ PortChannelDeviceHandle operations:

  nvshmemx_barrier_all_block()            -> port-channel signal+wait ring
  nvshmemi_ibgda_put_nbi_warp(...)        -> lane-0 PortChannel.put(...)
  nvshmemi_ibgda_amo_nonfetch_add(...)    -> lane-0 PortChannel.atomicAdd(...)

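As a rough host-side illustration of the first row (Python threads and semaphores standing in for GPU blocks and port-channel signal/wait primitives; this is not the device code), a signal+wait ring behaves as a full barrier after num_ranks - 1 rounds:

```python
import threading

def ring_barrier(rank: int, num_ranks: int, sems: list) -> None:
    # Each rank signals its successor's semaphore, then waits on its own.
    # After num_ranks - 1 rounds every rank has transitively observed every
    # other rank's arrival, so leaving the loop is a full barrier.
    for _ in range(num_ranks - 1):
        sems[(rank + 1) % num_ranks].release()  # "signal" next peer
        sems[rank].acquire()                    # "wait" for predecessor

def demo(num_ranks: int = 4) -> bool:
    arrived = [False] * num_ranks
    ok = [False] * num_ranks
    sems = [threading.Semaphore(0) for _ in range(num_ranks)]

    def worker(rank: int) -> None:
        arrived[rank] = True
        ring_barrier(rank, num_ranks, sems)
        # After the barrier, every rank must have arrived.
        ok[rank] = all(arrived)

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(num_ranks)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return all(ok)
```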
The atomicAdd path relies on the MSCCL++ Connection::atomicAdd /
PortChannelDeviceHandle::atomicAdd API cherry-picked from branch
chhwang/new-atomic-add; the LL dispatch path uses a signed delta
(-num_tokens_sent - 1) which the new int64_t signature supports.
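The signed-delta convention is worth spelling out: writing -num_tokens_sent - 1 keeps the flag nonzero even when zero tokens are sent, so a receiver polling for a nonzero value can distinguish "not yet arrived" (0) from "arrived with zero tokens" (-1). A minimal sketch with hypothetical helper names:

```python
def encode_token_count(num_tokens_sent: int) -> int:
    # A count of 0 encodes to -1, so the flag is never 0 ("not arrived").
    return -num_tokens_sent - 1

def decode_token_count(flag: int) -> int:
    # Inverse of the encoding above.
    return -flag - 1
```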

Changes:
* New file src/ext/ep/kernels/internode_ll.cu (~530 lines) with the
  three kernels clean_low_latency_buffer, dispatch<kUseFP8,...>,
  combine<...> plus their launchers. rdma_buffer_ptr is threaded
  through the launchers so the kernel can translate virtual addresses
  into registered-memory offsets expected by MSCCL++.
* kernels/api.cuh: replace the single stub signature with full LL
  launcher prototypes.
* buffer.cc: replace the four LL throw-stubs
  (clean_low_latency_buffer, low_latency_dispatch,
  low_latency_combine, get_next_low_latency_combine_buffer) with
  torch-Tensor implementations ported from DeepEP/csrc/deep_ep.cpp.
* Drop src/ext/ep/internode_stub.cc and its CMake entry.
* python/mscclpp/ext/ep/buffer.py: remove the low_latency_mode=True
  NotImplementedError guard; update docstring.
* test/python/ext/ep/test_ep_smoke.py: rename
  test_low_latency_rejected -> test_low_latency_buffer_construct
  to reflect that LL construction is now accepted.
* src/ext/ep/README.md: update status matrix, document the
  NVSHMEM -> MSCCL++ translation table, and list the known
  limitations.
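The virtual-address translation that rdma_buffer_ptr enables in the launchers amounts to a base-and-bounds computation; a hypothetical host-side sketch (names invented for illustration, not the actual launcher code):

```python
def to_registered_offset(ptr: int, rdma_buffer_ptr: int, num_rdma_bytes: int) -> int:
    # Translate a virtual address inside the LL buffer into the byte offset
    # within the registered-memory region that MSCCL++ put/atomicAdd expect.
    offset = ptr - rdma_buffer_ptr
    assert 0 <= offset < num_rdma_bytes, "pointer escapes the registered RDMA buffer"
    return offset
```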

This is a structural port: the kernels compile, link, and pass the
single-rank smoke tests, but end-to-end behaviour on multi-node H100
is not yet validated. Two known caveats:

  1. Performance will NOT match IBGDA because MSCCL++ port channels
     use a CPU proxy; this port is for functional parity, not latency.
  2. Buffer::sync() in LL mode only connects peers that share the
     same local GPU id (DeepEP convention), so the LL kernels assume
     a one-GPU-per-node topology (num_ranks == num_rdma_ranks).
     Multi-GPU-per-node LL layouts will need a follow-up in sync().
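The sync() peer-selection rule behind caveat 2 can be sketched as follows (hypothetical helper; gpus_per_node stands in for DeepEP's NVL-peer count):

```python
def ll_sync_peers(rank: int, num_ranks: int, gpus_per_node: int) -> list:
    # In LL mode, sync() only connects ranks that share a local GPU id.
    local_gpu_id = rank % gpus_per_node
    return [r for r in range(num_ranks)
            if r != rank and r % gpus_per_node == local_gpu_id]

# With one GPU per node (the supported layout) every remote rank is a peer;
# with 8 GPUs per node, rank 0 would only connect ranks 8, 16, ..., and the
# kernels' num_ranks == num_rdma_ranks assumption no longer holds.
```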

Tested:
  cmake --build build -j --target mscclpp_ep_cpp   # builds clean
  pytest test/python/ext/ep/test_ep_smoke.py        # 3 passed
2026-04-20 21:46:00 +00:00


# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Smoke tests for the EP extension.

These tests only exercise single-rank / pure-Python code paths so they can
run in CI without multi-GPU resources. Multi-rank dispatch/combine tests
belong in ``test/python/ext/ep/test_intranode.py`` and are left as TODO
until the Python frontend is validated on H100.

Run with::

    pytest -xvs test/python/ext/ep/test_ep_smoke.py
"""

from __future__ import annotations

import pytest

try:
    import mscclpp_ep_cpp as _cpp  # type: ignore[import-not-found]
except ImportError:  # pragma: no cover
    pytest.skip("mscclpp_ep_cpp is not built (set -DMSCCLPP_BUILD_EXT_EP=ON)", allow_module_level=True)


def test_config_roundtrip():
    cfg = _cpp.Config(
        num_sms=20,
        num_max_nvl_chunked_send_tokens=6,
        num_max_nvl_chunked_recv_tokens=256,
        num_max_rdma_chunked_send_tokens=6,
        num_max_rdma_chunked_recv_tokens=256,
    )
    hint = cfg.get_nvl_buffer_size_hint(7168 * 2, 8)
    assert hint > 0


def test_low_latency_size_hint():
    assert _cpp.get_low_latency_rdma_size_hint(128, 7168, 8, 256) > 0


def test_low_latency_buffer_construct():
    # Low-latency kernels are structurally ported. At construction time the
    # C++ Buffer must accept low_latency_mode=True; runtime requires a real
    # multi-node setup (see tests in tests/test_low_latency.py when ported).
    import torch

    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    buf = _cpp.Buffer(rank=0, num_ranks=1, num_nvl_bytes=0, num_rdma_bytes=0, low_latency_mode=True)
    assert not buf.is_available()