diff --git a/src/ext/ep/README.md b/src/ext/ep/README.md index b5184ebd..9ba5e315 100644 --- a/src/ext/ep/README.md +++ b/src/ext/ep/README.md @@ -15,12 +15,13 @@ targeting: |------------------------------------|---------------------------------------------| | `Buffer` construction + IPC + sync | ✅ ported (NVLink + RDMA) | | `get_dispatch_layout` | ✅ ported | -| `intranode_dispatch` (NVLink) | ✅ validated (8 ranks, 1 node) | -| `intranode_combine` (NVLink) | ✅ validated (8 ranks, 1 node) | -| `internode_dispatch` (NVLink+RDMA) | ✅ validated (16 ranks, 2×H100×8) | -| `internode_combine` (NVLink+RDMA) | ✅ validated (16 ranks, 2×H100×8) | -| `low_latency_dispatch` (RDMA+IPC) | ✅ validated (8 ranks intra-node; 16 ranks 2×H100) | -| `low_latency_combine` (RDMA+IPC) | ✅ validated (8 ranks intra-node; 16 ranks 2×H100) | +| `intranode_dispatch` (NVLink) | ✅ validated (8 ranks H100; 4 ranks GB200) | +| `intranode_combine` (NVLink) | ✅ validated (8 ranks H100; 4 ranks GB200) | +| `internode_dispatch` (NVLink+RDMA) | ✅ validated (16 ranks 2×H100×8; 64 ranks 16×GB200) | +| `internode_combine` (NVLink+RDMA) | ✅ validated (16 ranks 2×H100×8; 64 ranks 16×GB200) | +| `low_latency_dispatch` (RDMA+IPC) | ✅ validated (8 ranks H100; 16 ranks 2×H100×8; 64 ranks 16×GB200 via NVLS fabric IPC) | +| `low_latency_combine` (RDMA+IPC) | ✅ validated (8 ranks H100; 16 ranks 2×H100×8; 64 ranks 16×GB200 via NVLS fabric IPC) | +| GB200 NVLS multimem fast path | ✅ runtime-gated by `mscclpp::isNvlsSupported()` | | Multi-`ProxyService` sharding | ✅ env-tunable, arch-aware default | | `Connection::atomicAdd` API | ✅ cherry-picked into mscclpp | | Python frontend `mscclpp.ext.ep` | ✅ wraps HT + LL paths | @@ -31,20 +32,45 @@ Infiniband using [`test/python/ext/ep/test_internode_multirank.py`](../../../tes All 16 ranks complete dispatch followed by combine with exact (zero-diff) recovery of the per-rank token payloads. +On Azure GB200 NVL72 (4 GPUs / NUMA host, CX-7 RoCE), HT and LL were +validated at 16 nodes × 4 GPUs = **64 ranks** with HIDDEN=7168, +tokens=4096, experts=256, top-k=8: + +- HT internode: dispatch ~**2 006 GB/s** agg, combine ~**2 011 GB/s** agg + (`NVL_SEND=8 NVL_RECV=256 RDMA_SEND=8 RDMA_RECV=32`). +- LL internode: dispatch ~**16 817 GB/s** agg (~262 GB/s per rank), + combine ~**21 148 GB/s** agg (defaults). + +The GB200 path bypasses Azure CX-7 RoCE's broken `IBV_ATOMIC_*` by +routing peer pointers through cuMem fabric IPC over the NVL72 fabric +(`nvidia-imex`) and emitting NVLink-SHARP `multimem.*` atomics from the +kernels. The legacy RDMA-atomic PortChannel path is retained as a +fallback when `mscclpp::isNvlsSupported()` returns `false`. + The low-latency (LL) path uses a mixed transport: `MemoryChannel` (CUDA IPC) for same-node peers and `PortChannel` (CPU proxy + IB verbs) for -remote peers. The DeepEP LL kernels were translated as follows: +remote peers. On GB200 NVL72, cross-node peers are *also* reached via +CUDA-IPC — peer pointers are exchanged as cuMem fabric handles over +`nvidia-imex` and the kernels emit NVLink-SHARP `multimem.*` atomics +directly on the NVL72 fabric, bypassing CX-7 RoCE's broken IB atomics. +The DeepEP LL kernels were translated as follows: | DeepEP / IBGDA | MSCCL++ replacement | |------------------------------------------|------------------------------------------------------------------| | `nvshmemx_barrier_all_block()` | signal + wait ring across per-peer channel handles | | `nvshmemi_ibgda_put_nbi_warp(...)` (intra-node) | `MemoryChannelDeviceHandle::put` (CUDA IPC, no proxy) | -| `nvshmemi_ibgda_put_nbi_warp(...)` (inter-node) | lane-0 `PortChannelDeviceHandle::put(dst_off, src_off, n)` | -| `nvshmemi_ibgda_amo_nonfetch_add(...)` | lane-0 `atomicAdd` on the corresponding channel handle | +| `nvshmemi_ibgda_put_nbi_warp(...)` (inter-node, NVLS) | direct `st.global` on imported cuMem-fabric peer base (GB200) | +| `nvshmemi_ibgda_put_nbi_warp(...)` (inter-node, fallback) | lane-0 `PortChannelDeviceHandle::put(dst_off, src_off, n)` | +| `nvshmemi_ibgda_amo_nonfetch_add(...)` (NVLS) | `multimem.red.add.u64` on NVL72 multicast counter (GB200) | +| `nvshmemi_ibgda_amo_nonfetch_add(...)` (fallback) | lane-0 `atomicAdd` on the corresponding channel handle | LL was validated on: - 8 ranks × 1 H100 node (NVLink + CUDA-IPC fast path). - 16 ranks × 2 H100×8 nodes (mixed CUDA-IPC intra-node + IB inter-node). +- 4 ranks × 1 Azure GB200 NVL72 node (NVLink + CUDA-IPC). +- 64 ranks × 16 Azure GB200 NVL72 nodes (intra-node CUDA-IPC + cross-node + cuMem fabric IPC via `nvidia-imex` + NVLS `multimem.*` atomics); + dispatch ~16.8 TB/s agg, combine ~21.1 TB/s agg. ### `num_proxy_services` / proxy sharding @@ -96,11 +122,116 @@ from mscclpp.ext import ep buf = ep.Buffer(group, num_nvl_bytes=..., num_rdma_bytes=...) ``` +### Build-time CMake options + +| Variable | Default | Meaning | +|---------------------------------------|---------|---------------------------------------------------------------| +| `MSCCLPP_BUILD_EXT_EP` | `OFF` | Build the EP extension at all | +| `MSCCLPP_EP_NUM_MAX_NVL_PEERS` | `8` | Compile-time `NUM_MAX_NVL_PEERS` — set to `4` for GB200 NVL72 | +| `MSCCLPP_EP_KERNEL_DEBUG_TIMEOUT` | `OFF` | Use a short ~10s kernel spin timeout (default is ~100s) | + +### Azure GB200 (NVL72, 4 GPUs / NUMA host) + +GB200 NVL72 nodes expose **4 GPUs per NUMA host** (not 8 like HGX +H100), and the cross-node atomic fast-path uses NVLink-SHARP +multicast (`multimem.*` PTX) routed over the NVL72 fabric via +nvidia-imex instead of broken IB atomics on Azure CX-7 RoCE. Two +build-time settings are required: + +```bash +# Option 1: plain CMake. +cmake -S . -B build \ + -DMSCCLPP_BUILD_EXT_EP=ON \ + -DMSCCLPP_EP_NUM_MAX_NVL_PEERS=4 \ + -DCMAKE_PREFIX_PATH="$(python -c 'import torch; print(torch.utils.cmake_prefix_path)')" +cmake --build build -j + +# Option 2: wheel-based install (pyproject.toml already sets +# MSCCLPP_BUILD_EXT_EP=ON). +CMAKE_PREFIX_PATH="$(python -c 'import torch; print(torch.utils.cmake_prefix_path)')" \ +python3 -m pip install --no-build-isolation \ + --config-settings=cmake.define.MSCCLPP_EP_NUM_MAX_NVL_PEERS=4 \ + . +``` + +Add `-DMSCCLPP_EP_KERNEL_DEBUG_TIMEOUT=ON` only when triaging hangs (it +shortens the kernel-side spin timeout from ~100s to ~10s). + +Runtime prerequisites on GB200: + +- CUDA Toolkit ≥ 12.5 (the `cuCtxCreate` proxy-context path uses the + 4-arg signature added in 12.5; older toolkits compile against the + 3-arg fallback automatically). +- Driver ≥ 555 with nvidia-imex configured so cuMem fabric handles + (`POSIX_FD | FABRIC`) can be exchanged across nodes. +- NVLink-SHARP / multicast support enabled (`nvidia-smi mig … --imex` + reachable). `mscclpp::isNvlsSupported()` must return `true` at + Buffer construction; otherwise the kernels fall back to the legacy + PortChannel + RDMA path (and on Azure CX-7 RoCE the broken IB + atomics will hang). +- `MSCCLPP_EP_LOCAL_WORLD_SIZE` partitions ranks into NUMA hosts; it + defaults to the build-time `NUM_MAX_NVL_PEERS`, so a GB200 build + (`-DMSCCLPP_EP_NUM_MAX_NVL_PEERS=4`) auto-uses 4 and **does not + require** setting this env var. Only set `MSCCLPP_EP_LOCAL_WORLD_SIZE=4` + if you are running on GB200 against a stock build that still has + `NUM_MAX_NVL_PEERS=8` (otherwise host code mis-classifies cross-node + peers as local and `cudaIpcOpenMemHandle` fails). +- RT priority is required by NCCL/glibc. On each node: + + ```bash + sudo tee -a /etc/security/limits.conf > /dev/null <<'EOF' + * soft rtprio 99 + * hard rtprio 99 + EOF + # Re-login so `ulimit -r` reports 99. + ``` +- `nvidia-imex` must be active on every node with an identical + `/etc/nvidia-imex/nodes_config.cfg` listing all node IPs. Verify: + + ```bash + sudo systemctl status nvidia-imex + sudo cat /etc/nvidia-imex/nodes_config.cfg + ls /dev/nvidia-caps-imex-channels/ # channel0 must exist + ``` + +GB200 runtime env (export on every node before launching the tests): + +```bash +export NCCL_IB_DISABLE=1 # mscclpp owns IB +export NCCL_MNNVL_ENABLE=0 # Azure GB200 is NOT one MNNVL fabric across nodes +export MSCCLPP_HCA_DEVICES=mlx5_0,mlx5_1,mlx5_2,mlx5_3 # avoid mlx5_bond_0 (PORT_DOWN) +# Bootstrap NIC (only if auto-detect picks wrong) — on Azure GB200: +export NCCL_SOCKET_IFNAME=enP22p1s0f1 +export MSCCLPP_SOCKET_IFNAME=$NCCL_SOCKET_IFNAME +export GLOO_SOCKET_IFNAME=$NCCL_SOCKET_IFNAME +``` + +Runtime knobs (env vars, exposed by `test_intranode_multirank.py` / +`test_internode_multirank.py` — defaults below are the test-script +defaults, **not** the `ep.Config(...)` constructor defaults): + +| Variable | Maps to (`ep.Config` field) | Default | Notes | +|---------------------------|--------------------------------------|--------:|-----------------------------------------| +| `MSCCLPP_EP_NUM_SMS` | `num_sms` | `152` | `20` on the intranode test. Try `64` on GB200 intranode for `dispatch` BW. | +| `MSCCLPP_EP_NVL_SEND` | `num_max_nvl_chunked_send_tokens` | `8` | Must be `<` `MSCCLPP_EP_NVL_RECV`. | +| `MSCCLPP_EP_NVL_RECV` | `num_max_nvl_chunked_recv_tokens` | `256` | Scales NVL ring buffer linearly. | +| `MSCCLPP_EP_RDMA_SEND` | `num_max_rdma_chunked_send_tokens` | `16` | Internode only. | +| `MSCCLPP_EP_RDMA_RECV` | `num_max_rdma_chunked_recv_tokens` | `128` | Scale **down** as `num_rdma_ranks` grows (4n→128, 8n→64, 16n→32) to keep the RDMA buffer under the 2 GiB `INT_MAX` limit. | + +Validated 16-node (64-rank) configs on Azure GB200 NVL72 (HIDDEN=7168, +tokens=4096, experts=256, topk=8): + +- HT internode: `NVL_SEND=8 NVL_RECV=256 RDMA_SEND=8 RDMA_RECV=32` → + dispatch ~**2 006 GB/s** agg, combine ~**2 011 GB/s** agg. +- LL internode: defaults → dispatch ~**16 817 GB/s** agg + (262 GB/s per rank), combine ~**21 148 GB/s** agg. + ## Layout ``` src/ext/ep/ ├── CMakeLists.txt — builds mscclpp_ep_cpp (Torch + pybind11) +├── README.md — this file ├── buffer.hpp / buffer.cc — host-side Buffer, sync(), dispatch/combine ├── config.hpp / event.hpp — Config, EventHandle ├── bindings.cpp — PYBIND11_MODULE definition @@ -114,24 +245,57 @@ src/ext/ep/ ├── runtime.cu — intranode::barrier launcher ├── intranode_kernel.cu — intranode dispatch/combine kernels ├── internode.cu — internode HT dispatch/combine + layout - └── internode_ll.cu — internode LL dispatch/combine (structural) + │ (incl. NVLS multimem fast path for GB200) + └── internode_ll.cu — internode LL dispatch/combine python/mscclpp/ext/ep/ ├── __init__.py — reexports Buffer / Config / EventHandle └── buffer.py — torch.distributed-aware frontend test/python/ext/ep/ -├── test_ep_smoke.py — size-hint + rejection smoke test -├── test_intranode_multirank.py — NVLink HT dispatch+combine, 8 ranks -├── test_internode_multirank.py — HT dispatch+combine, 16 ranks (2×8) -└── test_low_latency_multirank.py — LL dispatch+combine (intra-node + cross-node) +├── test_intranode_multirank.py — intranode HT dispatch+combine +├── test_internode_multirank.py — internode HT dispatch+combine +└── test_low_latency_multirank.py — LL dispatch+combine ``` ## Running the tests +### Test prerequisites + +The Python tests are launched through `torchrun` / `mpirun` and require +PyTorch + a few support packages in the active environment. A minimal +install (matches the GB200 reference setup): + +```bash +# Conda env (any Python >= 3.10). Use the appropriate Miniconda variant +# for the host arch (`aarch64` shown; use `x86_64` on x86 clusters). +wget -O /tmp/Miniconda3-latest-Linux-aarch64.sh \ + https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh +bash /tmp/Miniconda3-latest-Linux-aarch64.sh -b -p $HOME/miniconda3 +source $HOME/miniconda3/etc/profile.d/conda.sh +conda create -n torch python=3.14 -y +conda activate torch + +# Runtime libs used by the tests / launcher. +conda install -c conda-forge -y cupy mpi4py pybind11 blake3 sortedcontainers + +# PyTorch (pulls a matching cuda-toolkit + NCCL). +pip3 install torch + +# mscclpp build deps (used by `pip install .` of this repo). +pip install scikit-build-core nanobind setuptools_scm +``` + +Then build the EP extension (see [Build](#build)) — `pip install .` from +the repo root installs `mscclpp_ep_cpp.so` into the active env so the +test scripts can `from mscclpp.ext import ep`. + Intranode (single node, 8 GPUs) — HT: ```bash +MSCCLPP_EP_BENCH=1 \ +MSCCLPP_EP_BENCH_TOKENS=4096 MSCCLPP_EP_BENCH_HIDDEN=7168 \ +MSCCLPP_EP_BENCH_EXPERTS=256 MSCCLPP_EP_BENCH_TOPK=8 \ torchrun --nnodes=1 --nproc_per_node=8 \ test/python/ext/ep/test_intranode_multirank.py ``` @@ -139,6 +303,9 @@ torchrun --nnodes=1 --nproc_per_node=8 \ Intranode LL (single node, 8 GPUs): ```bash +MSCCLPP_EP_BENCH=1 \ +MSCCLPP_EP_BENCH_TOKENS=128 MSCCLPP_EP_BENCH_HIDDEN=7168 \ +MSCCLPP_EP_BENCH_EXPERTS=256 MSCCLPP_EP_BENCH_TOPK=8 \ torchrun --nnodes=1 --nproc_per_node=8 \ test/python/ext/ep/test_low_latency_multirank.py ``` @@ -147,27 +314,39 @@ Internode HT (2 nodes × 8 GPUs), torchrun: ```bash # node 0 (master) -NCCL_SOCKET_IFNAME=eth0 MSCCLPP_SOCKET_IFNAME=eth0 GLOO_SOCKET_IFNAME=eth0 \ +MSCCLPP_EP_BENCH=1 \ +MSCCLPP_EP_BENCH_TOKENS=4096 MSCCLPP_EP_BENCH_HIDDEN=7168 \ +MSCCLPP_EP_BENCH_EXPERTS=256 MSCCLPP_EP_BENCH_TOPK=8 \ torchrun --nnodes=2 --nproc_per_node=8 --node_rank=0 \ --master_addr= --master_port=29600 \ test/python/ext/ep/test_internode_multirank.py # node 1 (worker) -NCCL_SOCKET_IFNAME=eth0 MSCCLPP_SOCKET_IFNAME=eth0 GLOO_SOCKET_IFNAME=eth0 \ +MSCCLPP_EP_BENCH=1 \ +MSCCLPP_EP_BENCH_TOKENS=4096 MSCCLPP_EP_BENCH_HIDDEN=7168 \ +MSCCLPP_EP_BENCH_EXPERTS=256 MSCCLPP_EP_BENCH_TOPK=8 \ torchrun --nnodes=2 --nproc_per_node=8 --node_rank=1 \ --master_addr= --master_port=29600 \ test/python/ext/ep/test_internode_multirank.py ``` -Internode HT/LL via mpirun (matches the NCCL-EP launch convention with -NUMA binding and an explicit topology file): +If the bootstrap NIC is mis-detected (e.g. multi-homed hosts), pin +it explicitly: + +```bash +export NCCL_SOCKET_IFNAME= +export MSCCLPP_SOCKET_IFNAME=$NCCL_SOCKET_IFNAME +export GLOO_SOCKET_IFNAME=$NCCL_SOCKET_IFNAME +``` + +Internode HT via mpirun (NCCL-EP convention with NUMA binding): ```bash mpirun -np 16 --allow-run-as-root --hostfile \ - --mca pml ob1 --mca btl tcp,vader,self --mca btl_tcp_if_include eth0 \ --bind-to numa \ - -x NCCL_SOCKET_IFNAME=eth0 -x MSCCLPP_SOCKET_IFNAME=eth0 -x GLOO_SOCKET_IFNAME=eth0 \ - -x NCCL_IB_DISABLE=0 -x NCCL_TOPO_FILE= \ + -x MSCCLPP_EP_BENCH=1 \ + -x MSCCLPP_EP_BENCH_TOKENS=4096 -x MSCCLPP_EP_BENCH_HIDDEN=7168 \ + -x MSCCLPP_EP_BENCH_EXPERTS=256 -x MSCCLPP_EP_BENCH_TOPK=8 \ -x MASTER_ADDR= -x MASTER_PORT=29600 \ bash -c 'export RANK=$OMPI_COMM_WORLD_RANK \ WORLD_SIZE=$OMPI_COMM_WORLD_SIZE \ @@ -175,6 +354,34 @@ mpirun -np 16 --allow-run-as-root --hostfile \ exec python3 test/python/ext/ep/test_internode_multirank.py' ``` +Internode LL via mpirun — same launch wrapper, swap the test script: + +```bash +mpirun -np 16 --allow-run-as-root --hostfile \ + --bind-to numa \ + -x MSCCLPP_EP_BENCH=1 \ + -x MSCCLPP_EP_BENCH_TOKENS=128 -x MSCCLPP_EP_BENCH_HIDDEN=7168 \ + -x MSCCLPP_EP_BENCH_EXPERTS=256 -x MSCCLPP_EP_BENCH_TOPK=8 \ + -x MASTER_ADDR= -x MASTER_PORT=29600 \ + bash -c 'export RANK=$OMPI_COMM_WORLD_RANK \ + WORLD_SIZE=$OMPI_COMM_WORLD_SIZE \ + LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK; \ + exec python3 test/python/ext/ep/test_low_latency_multirank.py' +``` + +Add `-x NCCL_SOCKET_IFNAME= -x MSCCLPP_SOCKET_IFNAME= +-x GLOO_SOCKET_IFNAME=` to the `mpirun` lines above only if the +default bootstrap NIC is wrong. `NCCL_IB_DISABLE` / `NCCL_TOPO_FILE` +are not required — EP traffic goes through mscclpp, not NCCL. + +If Open MPI's own bootstrap misbehaves (e.g. UCX is mis-configured or +the host is multi-homed), force its control channel onto plain TCP +over the local management NIC by adding +`--mca pml ob1 --mca btl tcp,vader,self --mca btl_tcp_if_include ` +to `mpirun`. `` is the management NIC (e.g. `enP22p1s0f1` on +Azure GB200, `eno1`/`bond0` elsewhere) — it is **not** `eth0` on +GB200. + ### Benchmark mode All three multirank tests double as micro-benchmarks when @@ -216,6 +423,8 @@ Env knobs: `memory_channel_handles_device_ptr`, with port channels ordered by peer rank (so kernel-side indexing by `peer_rank` is consistent). - [x] Validated on 2×H100×8 with `test_internode_multirank.py`. +- [x] Validated on 16×GB200 NVL72 (64 ranks, Azure CX-7 RoCE) — see + Phase 5. ### Phase 3 — Low-Latency (RDMA + CUDA-IPC) — DONE @@ -247,3 +456,34 @@ reference dispatch/combine. - [x] `test_low_latency_multirank.py` — LL round-trip validated intra-node (8 ranks) and cross-node (2×H100×8). - [x] In-tree micro-benchmark harness (`MSCCLPP_EP_BENCH=1`) reporting min/avg/max + BW@avg, aligned with NCCL-EP `ep_bench`. - [ ] Throughput benchmarks against DeepEP upstream. + +### Phase 5 — Azure GB200 NVL72 port — DONE + +GB200 (4 GPUs / NUMA host, CX-7 RoCE) needed three independent fixes +on top of the H100 baseline: + +- [x] `NUM_MAX_NVL_PEERS` made CMake-configurable + (`-DMSCCLPP_EP_NUM_MAX_NVL_PEERS=4`); runtime + `MSCCLPP_EP_LOCAL_WORLD_SIZE` defaults to the compile-time value. +- [x] Portability fixes for older toolchains/arches: multimem PTX + guarded by `__CUDA_ARCH__ >= 900`, `cuCtxCreate` 4-arg form + version-guarded, `NUM_TIMEOUT_CYCLES` restored, CMake install + destination dual-mode. +- [x] **LL bypass for broken CX-7 IB atomics** (Proposal A): + `Buffer::rdma_buffer_ptr` allocated via + `mscclpp::detail::gpuCallocPhysical` (POSIX_FD | FABRIC handles); + LL IPC fast-path gate lifted from `num_rdma_ranks==1` to + `low_latency_mode`; peer bases resolved through + `RegisteredMemory::data()` so mscclpp `CudaIpc` transport imports + cross-node cuMem fabric handles via `nvidia-imex`. +- [x] **NVLS multimem fast path for HT atomics** (Proposal B): + runtime-gated by `mscclpp::isNvlsSupported()`. Cross-node + `port_channel.signal/wait` + `putWithSignal` replaced by + `multimem.red.add.u64` on NVL72 multicast counters; legacy IB + path retained as fallback. +- [x] Validated on 16 × Azure GB200 NVL72 (64 ranks), HIDDEN=7168, + tokens=4096, experts=256, top-k=8: + - HT: dispatch ~2 006 GB/s agg, combine ~2 011 GB/s agg + (`NVL_SEND=8 NVL_RECV=256 RDMA_SEND=8 RDMA_RECV=32`). + - LL: dispatch ~16 817 GB/s agg (~262 GB/s/rank), + combine ~21 148 GB/s agg (defaults).