diff --git a/CMakeLists.txt b/CMakeLists.txt
index ef8b785a..0b62523b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -51,6 +51,7 @@ option(MSCCLPP_BUILD_TESTS "Build tests" OFF)
 option(MSCCLPP_BUILD_PYTHON_BINDINGS "Build Python bindings" ON)
 option(MSCCLPP_BUILD_EXT_NCCL "Build NCCL interfaces" ON)
 option(MSCCLPP_BUILD_EXT_COLLECTIVES "Build collective algorithms" ON)
+option(MSCCLPP_BUILD_EXT_EP "Build Expert-Parallel (MoE dispatch/combine) extension" OFF)
 option(MSCCLPP_USE_CUDA "Use NVIDIA/CUDA." OFF)
 option(MSCCLPP_USE_ROCM "Use AMD/ROCm." OFF)
 option(MSCCLPP_USE_IB "Use InfiniBand." ON)
diff --git a/python/mscclpp/ext/__init__.py b/python/mscclpp/ext/__init__.py
index 08a96ecd..4c8aef5b 100644
--- a/python/mscclpp/ext/__init__.py
+++ b/python/mscclpp/ext/__init__.py
@@ -2,3 +2,9 @@
 # Licensed under the MIT license.
 
 from .algorithm_collection_builder import *
+
+try:
+    from . import ep  # noqa: F401
+except ImportError:
+    # EP extension not built; leave `mscclpp.ext.ep` undefined.
+    pass
diff --git a/python/mscclpp/ext/ep/__init__.py b/python/mscclpp/ext/ep/__init__.py
new file mode 100644
index 00000000..1db824e0
--- /dev/null
+++ b/python/mscclpp/ext/ep/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""MSCCL++ Expert-Parallel (MoE dispatch/combine) extension.
+
+See ``src/ext/ep/README.md`` in the repository for migration status. The
+``Buffer`` class mirrors :class:`deep_ep.Buffer`. Intranode and internode
+HT dispatch/combine are ported (internode pending multi-node validation);
+low-latency paths raise until the NVSHMEM -> MSCCL++ port is completed.
+"""
+
+from .buffer import Buffer, Config, EventHandle  # noqa: F401
+
+__all__ = ["Buffer", "Config", "EventHandle"]
diff --git a/python/mscclpp/ext/ep/buffer.py b/python/mscclpp/ext/ep/buffer.py
new file mode 100644
index 00000000..b50538a1
--- /dev/null
+++ b/python/mscclpp/ext/ep/buffer.py
@@ -0,0 +1,189 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+#
+# Portions adapted from DeepEP (https://github.com/deepseek-ai/DeepEP),
+# branch ``chhwang/dev-atomic-add-cleanup``. Licensed under the MIT License.
+"""Python frontend for the MSCCL++ Expert-Parallel extension.
+
+This is a thin wrapper around the pybind11 extension ``mscclpp_ep_cpp``.
+The shape of :class:`Buffer` mirrors :class:`deep_ep.Buffer` so existing
+DeepEP users can port with minimal changes.
+
+Current status (see ``src/ext/ep/README.md``):
+
+* Intranode (NVLink-only) dispatch and combine are fully ported.
+* ``get_dispatch_layout`` is ported.
+* Internode HT is ported but pending multi-node validation; low-latency
+  methods raise from C++ until the NVSHMEM/IBGDA -> PortChannel port lands.
+"""
+
+from __future__ import annotations
+
+import os
+from typing import List, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+
+try:
+    import mscclpp_ep_cpp as _cpp  # type: ignore[import-not-found]
+except ImportError as exc:  # pragma: no cover
+    raise ImportError(
+        "mscclpp_ep_cpp is not available. Build mscclpp with "
+        "-DMSCCLPP_BUILD_EXT_EP=ON (and ensure PyTorch's CMake prefix is on "
+        "CMAKE_PREFIX_PATH) or install via `pip install` after the build."
+    ) from exc
+
+Config = _cpp.Config
+EventHandle = _cpp.EventHandle
+
+
+class Buffer:
+    """Core expert-parallel (EP) communication buffer.
+
+    Parameters
+    ----------
+    group:
+        The ``torch.distributed`` process group. Used only for out-of-band
+        exchange of IPC handles and the MSCCL++ unique id.
+    num_nvl_bytes:
+        Size of the NVLink-accessible scratch buffer (shared via CUDA IPC).
+    num_rdma_bytes:
+        Size of the RDMA scratch buffer. Non-zero only when the job spans
+        more than one NVLink domain (internode HT); 0 for NVLink-only use.
+    low_latency_mode:
+        Reserved — must be ``False`` until the LL path is ported.
+    num_qps_per_rank:
+        Ignored for intranode mode.
+    """
+
+    #: Default number of SMs reserved for comms kernels. Matches DeepEP.
+    num_sms: int = 20
+
+    def __init__(
+        self,
+        group: dist.ProcessGroup,
+        num_nvl_bytes: int = 0,
+        num_rdma_bytes: int = 0,
+        low_latency_mode: bool = False,
+        num_qps_per_rank: int = 12,
+    ) -> None:
+        if low_latency_mode:
+            raise NotImplementedError(
+                "mscclpp.ext.ep.Buffer: low-latency mode is not yet ported. "
+                "Set low_latency_mode=False. See src/ext/ep/README.md for the "
+                "migration plan."
+            )
+
+        self.rank: int = group.rank()
+        self.group_size: int = group.size()
+        self.group = group
+        self.num_nvl_bytes = num_nvl_bytes
+        self.num_rdma_bytes = num_rdma_bytes
+        self.low_latency_mode = low_latency_mode
+        self.num_qps_per_rank = num_qps_per_rank
+
+        self.runtime = _cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode)
+
+        # Exchange device IDs + IPC handles + (for RDMA) the MSCCL++ unique id.
+        device_ids: List[Optional[int]] = [None] * self.group_size
+        local_device_id = self.runtime.get_local_device_id()
+        dist.all_gather_object(device_ids, local_device_id, group)
+
+        ipc_handles: List[Optional[bytes]] = [None] * self.group_size
+        local_ipc_handle = self.runtime.get_local_ipc_handle()
+        dist.all_gather_object(ipc_handles, local_ipc_handle, group)
+
+        root_unique_id: Optional[bytes] = None
+        # The LL path is guarded above; for internode HT the unique-id
+        # exchange below bootstraps the MSCCL++ connections built in `sync`.
+        if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
+            if num_qps_per_rank <= 0:
+                raise ValueError("num_qps_per_rank must be > 0 for RDMA")
+
+            if self.rank == 0:
+                unique_id = self.runtime.create_unique_id()
+                root_unique_id = unique_id.bytes()
+            broadcast_list = [root_unique_id]
+            dist.broadcast_object_list(broadcast_list, src=0, group=group)
+            root_unique_id = broadcast_list[0]
+            assert root_unique_id is not None
+            self.runtime.connect(_cpp.UniqueId.from_bytes(root_unique_id))
+
+        self.runtime.sync(device_ids, ipc_handles, root_unique_id)
+
+    # ------------------------------------------------------------------
+    # Sanity helpers
+    # ------------------------------------------------------------------
+
+    def is_available(self) -> bool:
+        return self.runtime.is_available()
+
+    def is_internode_available(self) -> bool:
+        return self.runtime.is_internode_available()
+
+    def get_local_device_id(self) -> int:
+        return self.runtime.get_local_device_id()
+
+    def get_num_rdma_ranks(self) -> int:
+        return self.runtime.get_num_rdma_ranks()
+
+    def get_rdma_rank(self) -> int:
+        return self.runtime.get_rdma_rank()
+
+    def get_root_rdma_rank(self, global_: bool) -> int:
+        return self.runtime.get_root_rdma_rank(global_)
+
+    # ------------------------------------------------------------------
+    # Layout / dispatch / combine (thin pass-through wrappers).
+    # Signatures mirror deep_ep.Buffer so DeepEP test harnesses can be reused.
+ # ------------------------------------------------------------------ + + def get_dispatch_layout( + self, + topk_idx: torch.Tensor, + num_experts: int, + previous_event: Optional[EventHandle] = None, + async_finish: bool = False, + allocate_on_comm_stream: bool = False, + ): + return self.runtime.get_dispatch_layout( + topk_idx, num_experts, previous_event, async_finish, allocate_on_comm_stream + ) + + def intranode_dispatch(self, *args, **kwargs): + return self.runtime.intranode_dispatch(*args, **kwargs) + + def intranode_combine(self, *args, **kwargs): + return self.runtime.intranode_combine(*args, **kwargs) + + def internode_dispatch(self, *args, **kwargs): + return self.runtime.internode_dispatch(*args, **kwargs) + + def internode_combine(self, *args, **kwargs): + return self.runtime.internode_combine(*args, **kwargs) + + def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> None: + self.runtime.clean_low_latency_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts) + + def low_latency_dispatch(self, *args, **kwargs): + return self.runtime.low_latency_dispatch(*args, **kwargs) + + def low_latency_combine(self, *args, **kwargs): + return self.runtime.low_latency_combine(*args, **kwargs) + + def get_next_low_latency_combine_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int): + return self.runtime.get_next_low_latency_combine_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts) + + def get_local_buffer_tensor(self, dtype: torch.dtype, offset: int = 0, use_rdma_buffer: bool = False) -> torch.Tensor: + return self.runtime.get_local_buffer_tensor(dtype, offset, use_rdma_buffer) + + # ------------------------------------------------------------------ + # Static helpers + # ------------------------------------------------------------------ + + @staticmethod + def get_low_latency_rdma_size_hint( + num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int + ) -> int: + return _cpp.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts) diff --git a/src/ext/CMakeLists.txt b/src/ext/CMakeLists.txt index a9ed8b54..74ea9c56 100644 --- a/src/ext/CMakeLists.txt +++ b/src/ext/CMakeLists.txt @@ -8,3 +8,7 @@ endif() if(MSCCLPP_BUILD_EXT_NCCL) add_subdirectory(nccl) endif() + +if(MSCCLPP_BUILD_EXT_EP) + add_subdirectory(ep) +endif() diff --git a/src/ext/ep/CMakeLists.txt b/src/ext/ep/CMakeLists.txt new file mode 100644 index 00000000..7f45aba1 --- /dev/null +++ b/src/ext/ep/CMakeLists.txt @@ -0,0 +1,147 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +# Builds `mscclpp_ep_cpp`, a pybind11 + PyTorch C++ extension that exposes the +# EP (Mixture-of-Experts dispatch/combine) Buffer to Python. This module is +# separate from the nanobind-based `_mscclpp` because EP carries `torch::Tensor` +# through its ABI (ported verbatim from DeepEP). +# +# Requires: PyTorch with its CMake integration available on CMAKE_PREFIX_PATH. +# Easiest invocation: +# cmake -S . -B build \ +# -DMSCCLPP_BUILD_EXT_EP=ON \ +# -DCMAKE_PREFIX_PATH="$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')" + +find_package(Python 3.10 COMPONENTS Interpreter Development.Module REQUIRED) +find_package(Torch QUIET) +if(NOT Torch_FOUND) + message(WARNING + "MSCCLPP_BUILD_EXT_EP=ON but PyTorch CMake package was not found. 
" + "Set CMAKE_PREFIX_PATH to `$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')` " + "or disable MSCCLPP_BUILD_EXT_EP. Skipping the EP extension.") + return() +endif() + +find_package(pybind11 QUIET CONFIG) +if(NOT pybind11_FOUND) + # PyTorch ships a bundled pybind11 we can fall back on. + get_target_property(_torch_include_dirs torch INTERFACE_INCLUDE_DIRECTORIES) + foreach(d ${_torch_include_dirs}) + if(EXISTS "${d}/pybind11/pybind11.h") + set(PYBIND11_INCLUDE_DIR "${d}") + break() + endif() + endforeach() + if(NOT PYBIND11_INCLUDE_DIR) + message(FATAL_ERROR + "pybind11 not found and not bundled with Torch. Install it (pip install pybind11) " + "or provide -Dpybind11_DIR=...") + endif() +endif() + +file(GLOB_RECURSE EP_SOURCES CONFIGURE_DEPENDS + buffer.cc + internode_stub.cc + bindings.cpp + kernels/*.cu +) + +if(MSCCLPP_USE_ROCM) + # ROCm port of the EP kernels is not supported yet; see src/ext/ep/README.md. + message(WARNING "mscclpp_ep: ROCm build path not implemented, falling back to CXX compile.") + file(GLOB_RECURSE EP_CU_SOURCES kernels/*.cu) + set_source_files_properties(${EP_CU_SOURCES} PROPERTIES LANGUAGE CXX) +endif() + +# Build as a Python extension module (shared object with Python ABI suffix). +# The name `mscclpp_ep_cpp` matches the `TORCH_EXTENSION_NAME` hard-coded in +# `bindings.cpp` / `buffer.hpp`. +Python_add_library(mscclpp_ep_cpp MODULE ${EP_SOURCES}) + +target_compile_definitions(mscclpp_ep_cpp PRIVATE TORCH_EXTENSION_NAME=mscclpp_ep_cpp) +target_include_directories(mscclpp_ep_cpp PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src/core/include + ${PROJECT_SOURCE_DIR}/src/ext/include + ${GPU_INCLUDE_DIRS} +) +if(pybind11_FOUND) + target_link_libraries(mscclpp_ep_cpp PRIVATE pybind11::module) +else() + target_include_directories(mscclpp_ep_cpp SYSTEM PRIVATE ${PYBIND11_INCLUDE_DIR}) +endif() + +target_link_libraries(mscclpp_ep_cpp PRIVATE ${TORCH_LIBRARIES}) +# Torch's CUDA interop library (ATen CUDAStream helpers used in buffer.cc). +if(TARGET torch::torch) + target_link_libraries(mscclpp_ep_cpp PRIVATE torch::torch) +endif() +# libtorch_python contains pybind11 bindings for torch::Tensor / torch::dtype +# (symbols like THPDtypeType). It is not listed in TORCH_LIBRARIES. +find_library(TORCH_PYTHON_LIBRARY torch_python + HINTS "${TORCH_INSTALL_PREFIX}/lib") +if(TORCH_PYTHON_LIBRARY) + target_link_libraries(mscclpp_ep_cpp PRIVATE ${TORCH_PYTHON_LIBRARY}) +else() + message(WARNING "libtorch_python not found; `import mscclpp_ep_cpp` may fail at runtime.") +endif() +target_link_libraries(mscclpp_ep_cpp PRIVATE mscclpp ${GPU_LIBRARIES} Threads::Threads) + +set_target_properties(mscclpp_ep_cpp PROPERTIES + PREFIX "" + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + CXX_VISIBILITY_PRESET default + INSTALL_RPATH "\$ORIGIN/../lib" +) + +if(MSCCLPP_USE_CUDA) + target_compile_definitions(mscclpp_ep_cpp PRIVATE MSCCLPP_USE_CUDA) +elseif(MSCCLPP_USE_ROCM) + target_compile_definitions(mscclpp_ep_cpp PRIVATE MSCCLPP_USE_ROCM) +endif() + +install(TARGETS mscclpp_ep_cpp + LIBRARY DESTINATION ${INSTALL_PREFIX}/lib) +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
diff --git a/src/ext/ep/README.md b/src/ext/ep/README.md
new file mode 100644
index 00000000..46d1d8df
--- /dev/null
+++ b/src/ext/ep/README.md
@@ -0,0 +1,226 @@
+# MSCCL++ Expert-Parallel (EP) extension
+
+A port of DeepEP's MoE `dispatch`/`combine` primitives into MSCCL++, targeting:
+
+- **High-Throughput (HT) mode** from [DeepEP](https://github.com/deepseek-ai/DeepEP),
+  branch `chhwang/dev-atomic-add-cleanup` — which already swaps NVSHMEM for
+  `mscclpp::PortChannel`/`MemoryChannel`.
+- **Low-Latency (LL) mode** from [`nccl/contrib/nccl_ep`](https://github.com/NVIDIA/nccl/tree/master/contrib/nccl_ep),
+  which implements pure-RDMA dispatch/combine on top of the NCCL Device API.
+
+## Status
+
+| Feature                            | Status                        |
+|------------------------------------|-------------------------------|
+| `Buffer` construction + IPC + sync | ✅ ported (NVLink + RDMA)     |
+| `get_dispatch_layout`              | ✅ ported                     |
+| `intranode_dispatch` (NVLink)      | ✅ ported                     |
+| `intranode_combine` (NVLink)       | ✅ ported                     |
+| `internode_dispatch` (NVLink+RDMA) | ✅ ported (pending H100 test) |
+| `internode_combine` (NVLink+RDMA)  | ✅ ported (pending H100 test) |
+| `low_latency_dispatch` (pure RDMA) | ❌ stub                       |
+| `low_latency_combine` (pure RDMA)  | ❌ stub                       |
+| `Connection::atomicAdd` API        | ✅ cherry-picked into mscclpp |
+| Python frontend `mscclpp.ext.ep`   | ✅ wraps HT paths             |
+| pybind11 module `mscclpp_ep_cpp`   | ✅ builds conditionally       |
+
+Internode HT is code-complete but unverified on real hardware — the
+`sync()` path replaces DeepEP's NVSHMEM symmetric-heap allocation with
+`cudaMalloc` + `bootstrap->barrier()`, and the kernels use the new
+`PortChannelDeviceHandle::atomicAdd` instead of the old raw-trigger
+pattern. The low-latency path is the only remaining stub.
+
+## Build
+
+The extension is **off by default** and requires PyTorch's CMake package:
+
+```bash
+TORCH_CMAKE=$(python -c 'import torch; print(torch.utils.cmake_prefix_path)')
+cmake -S . -B build \
+  -DMSCCLPP_BUILD_EXT_EP=ON \
+  -DCMAKE_PREFIX_PATH="${TORCH_CMAKE}"
+cmake --build build -j
+```
+
+This produces `mscclpp_ep_cpp.so` — a pybind11 PyTorch extension module.
+The Python frontend picks it up automatically:
+
+```python
+from mscclpp.ext import ep
+buf = ep.Buffer(group, num_nvl_bytes=..., num_rdma_bytes=0)
+```
+
+## Layout
+
+```
+src/ext/ep/
+├── CMakeLists.txt          — builds mscclpp_ep_cpp (Torch + pybind11)
+├── buffer.hpp / buffer.cc  — host-side Buffer, sync(), dispatch/combine
+├── config.hpp / event.hpp  — Config, EventHandle
+├── bindings.cpp            — PYBIND11_MODULE definition
+├── internode_stub.cc       — stubs for not-yet-ported LL launchers
+└── kernels/
+    ├── api.cuh             — host-callable kernel prototypes
+    ├── configs.cuh         — compile-time constants (GPU-only)
+    ├── buffer.cuh          — Buffer/AsymBuffer/SymBuffer helpers
+    ├── exception.cuh       — EP_HOST/DEVICE_ASSERT + CUDA_CHECK
+    ├── launch.cuh          — SETUP_LAUNCH_CONFIG / SWITCH_* macros
+    ├── utils.cuh           — device inline helpers
+    ├── runtime.cu          — intranode::barrier launcher
+    ├── intranode_kernel.cu — notify_dispatch / dispatch / combine kernels
+    ├── internode.cu        — internode HT notify/dispatch/combine kernels
+    └── internode_layout.cu — get_dispatch_layout (CPU-safe subset)
+
+python/mscclpp/ext/ep/
+├── __init__.py             — reexports Buffer / Config / EventHandle
+└── buffer.py               — torch.distributed-aware frontend
+
+test/python/ext/ep/
+└── test_ep_smoke.py        — size-hint + rejection smoke test
+```
+
+## Migration plan
+
+### Phase 1 — DONE
+
+- [x] Copy DeepEP kernel headers (configs / buffer / utils / launch / exception).
+- [x] Port intranode kernels + runtime (NVLink only).
+- [x] Port `get_dispatch_layout` (host-safe subset of internode kernels).
+- [x] Port host Buffer: ctor, sync, get_dispatch_layout, intranode dispatch/combine.
+- [x] pybind11 `mscclpp_ep_cpp` module + Python frontend.
+
+### Phase 2 — internode HT (NVLink + RDMA)
+
+Status: code-complete, pending H100 validation (see the Status table above).
+The rest of `DeepEP/csrc/kernels/internode.cu` (`notify_dispatch`,
+`dispatch`, `cached_notify`, `combine`) is ported. Because the port starts
+from the `chhwang/dev-atomic-add-cleanup` branch, the NVSHMEM -> MSCCL++
+substitution was already done upstream; the kernel bodies were copied as-is
+and wired through `api.cuh`. The launchers take the `PortChannelDeviceHandle*` /
+`MemoryChannelDeviceHandle*` arrays that `Buffer::sync()` builds in its
+`num_rdma_bytes > 0` branch (`port_channel_handles_device_ptr` and
+`memory_channel_handles_device_ptr`), and the former stubs in `buffer.cc`
+(`internode_dispatch`, `internode_combine`) now hold the real DeepEP bodies.
+
+### Phase 3 — Low-Latency (pure RDMA)
+
+Port `DeepEP/csrc/kernels/internode_ll.cu` and cross-reference
+`nccl/contrib/nccl_ep/device/low_latency.cu`. The nccl_ep reference is
+modular (see `device_primitives.cuh`, `hybrid_ep.cuh`) and uses NCCL Device
+API; the translation table is:
+
+| nccl_ep / DeepEP primitive               | MSCCL++ replacement                              |
+|------------------------------------------|--------------------------------------------------|
+| `nvshmemi_ibgda_put_nbi_warp`            | `PortChannelDeviceHandle::put` + `signal`        |
+| `nvshmem_signal_wait_until`              | `PortChannelDeviceHandle::wait`                  |
+| `ncclGinPutSignal`                       | `PortChannelDeviceHandle::put` + `signal`        |
+| `ncclGinWaitSignal`                      | `PortChannelDeviceHandle::wait`                  |
+| `ncclGetPeerPointer` / IPC               | offset into `buffer_ptrs_gpu[peer]`              |
+| `ncclTeamLsa` locality check             | per-rank `rank / NUM_MAX_NVL_PEERS` comparison   |
+| NVSHMEM symmetric heap                   | `cudaMalloc` + `ProxyService::addMemory`         |
+| NVSHMEM barrier                          | `bootstrap->barrier()` or `intranode::barrier`   |
+
+Finally fill in `buffer.cc::low_latency_dispatch` / `low_latency_combine`
+from the DeepEP bodies (already translated on the `chhwang/...` branch).
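+
+As a hedged sketch of how the first four table rows translate (this is not
+the ported kernel — the handle layout, offsets and sizes are hypothetical
+placeholders, and it assumes the current MSCCL++ device-handle API and
+header name):
+
+```cuda
+#include <mscclpp/port_channel_device.hpp>
+
+// One PortChannelDeviceHandle per remote peer, as Buffer::sync() builds.
+__global__ void put_signal_wait_sketch(mscclpp::PortChannelDeviceHandle* channels,
+                                       int peer, uint64_t dst_offset,
+                                       uint64_t src_offset, uint64_t num_bytes) {
+  if (blockIdx.x == 0 && threadIdx.x == 0) {
+    // nvshmemi_ibgda_put_nbi_warp / ncclGinPutSignal:
+    // enqueue a proxy RDMA write to the peer, then raise its semaphore.
+    channels[peer].putWithSignal(dst_offset, src_offset, num_bytes);
+    // nvshmem_signal_wait_until / ncclGinWaitSignal:
+    // spin until the peer signals back on this channel.
+    channels[peer].wait();
+  }
+}
+```
+
+A real launcher would pass `port_channel_handles_device_ptr` (built in
+`Buffer::sync()`) as `channels`.
+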
+### Phase 4 — Validation
+
+- Port `DeepEP/tests/test_{intranode,internode,low_latency}.py` into
+  `test/python/ext/ep/`.
+- Run on the same H100/H800 reference rig DeepEP uses; compare throughput.
diff --git a/src/ext/ep/bindings.cpp b/src/ext/ep/bindings.cpp
new file mode 100644
index 00000000..0a7018c8
--- /dev/null
+++ b/src/ext/ep/bindings.cpp
@@ -0,0 +1,79 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+//
+// Portions adapted from DeepEP (https://github.com/deepseek-ai/DeepEP)
+// branch `chhwang/dev-atomic-add-cleanup`. Licensed under the MIT License.
+//
+// pybind11 module definition for the MSCCL++ EP extension. Mirrors
+// DeepEP's `PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)` so call sites port
+// with minimal changes.
+
+#include <cstring>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <torch/python.h>
+
+#include "buffer.hpp"
+#include "config.hpp"
+
+namespace py = pybind11;
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.doc() = "MSCCL++ Expert-Parallel (MoE dispatch/combine) extension";
+
+    py::class_<mscclpp::ep::Config>(m, "Config")
+        .def(py::init<int, int, int, int, int>(), py::arg("num_sms") = 20,
+             py::arg("num_max_nvl_chunked_send_tokens") = 6, py::arg("num_max_nvl_chunked_recv_tokens") = 256,
+             py::arg("num_max_rdma_chunked_send_tokens") = 6, py::arg("num_max_rdma_chunked_recv_tokens") = 256)
+        .def("get_nvl_buffer_size_hint", &mscclpp::ep::Config::get_nvl_buffer_size_hint)
+        .def("get_rdma_buffer_size_hint", &mscclpp::ep::Config::get_rdma_buffer_size_hint);
+
+    m.def("get_low_latency_rdma_size_hint", &mscclpp::ep::get_low_latency_rdma_size_hint);
+
+    py::class_<mscclpp::ep::EventHandle>(m, "EventHandle")
+        .def(py::init<>())
+        .def("current_stream_wait", &mscclpp::ep::EventHandle::current_stream_wait);
+
+    // NOTE: `mscclpp::UniqueId` is the bootstrap id used for connecting the
+    // proxy service. We expose it as an opaque bytes-like object so Python can
+    // all-gather it across the user's process group.
+    py::class_<mscclpp::UniqueId>(m, "UniqueId")
+        .def(py::init<>())
+        .def("bytes", [](const mscclpp::UniqueId& self) {
+            return py::bytes(reinterpret_cast<const char*>(self.data()), self.size());
+        })
+        .def_static("from_bytes", [](py::bytes data) {
+            auto s = std::string(data);
+            mscclpp::UniqueId uid;
+            if (s.size() != uid.size()) {
+                throw std::runtime_error("mscclpp.ep.UniqueId.from_bytes: size mismatch");
+            }
+            std::memcpy(uid.data(), s.data(), s.size());
+            return uid;
+        });
+
+    py::class_<mscclpp::ep::Buffer>(m, "Buffer")
+        .def(py::init<int, int, int64_t, int64_t, bool>(), py::arg("rank"), py::arg("num_ranks"),
+             py::arg("num_nvl_bytes"), py::arg("num_rdma_bytes"), py::arg("low_latency_mode"))
+        .def("is_available", &mscclpp::ep::Buffer::is_available)
+        .def("is_internode_available", &mscclpp::ep::Buffer::is_internode_available)
+        .def("get_num_rdma_ranks", &mscclpp::ep::Buffer::get_num_rdma_ranks)
+        .def("get_rdma_rank", &mscclpp::ep::Buffer::get_rdma_rank)
+        .def("get_root_rdma_rank", &mscclpp::ep::Buffer::get_root_rdma_rank)
+        .def("get_local_device_id", &mscclpp::ep::Buffer::get_local_device_id)
+        .def("get_local_ipc_handle", &mscclpp::ep::Buffer::get_local_ipc_handle)
+        .def("get_local_nvshmem_unique_id", &mscclpp::ep::Buffer::get_local_nvshmem_unique_id)
+        .def("get_local_buffer_tensor", &mscclpp::ep::Buffer::get_local_buffer_tensor)
+        .def("create_unique_id", &mscclpp::ep::Buffer::create_unique_id)
+        .def("connect", &mscclpp::ep::Buffer::connect)
+        .def("sync", &mscclpp::ep::Buffer::sync)
+        .def("get_dispatch_layout", &mscclpp::ep::Buffer::get_dispatch_layout)
+        .def("intranode_dispatch", &mscclpp::ep::Buffer::intranode_dispatch)
+        .def("intranode_combine", &mscclpp::ep::Buffer::intranode_combine)
+        .def("internode_dispatch", &mscclpp::ep::Buffer::internode_dispatch)
+        .def("internode_combine", &mscclpp::ep::Buffer::internode_combine)
+        .def("clean_low_latency_buffer", &mscclpp::ep::Buffer::clean_low_latency_buffer)
+        .def("low_latency_dispatch", &mscclpp::ep::Buffer::low_latency_dispatch)
+        .def("low_latency_combine", &mscclpp::ep::Buffer::low_latency_combine)
+        .def("get_next_low_latency_combine_buffer", &mscclpp::ep::Buffer::get_next_low_latency_combine_buffer);
+}
diff --git a/src/ext/ep/buffer.cc b/src/ext/ep/buffer.cc
new file mode 100644
index 00000000..c9bf5c04
--- /dev/null
+++ b/src/ext/ep/buffer.cc
@@ -0,0 +1,1150 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <chrono>
+#include <cstring>
+#include <limits>
+#include <map>
+#include <memory>
+#include <optional>
+#include <unordered_map>
+#include <vector>
+#include <mscclpp/gpu_utils.hpp>
+#include <torch/python.h>
+
+#include "buffer.hpp"
"kernels/api.cuh" +#include "kernels/configs.cuh" + +namespace mscclpp { namespace ep { + +// Upstream MSCCL++ now exposes `Connection::atomicAdd` and +// `PortChannelDeviceHandle::atomicAdd` natively (see commit "atomic add" +// on branch chhwang/new-atomic-add, merged into this tree). The stock +// `ProxyService` recognises `ChannelTrigger.type == 0` as an atomic-add +// request, so no subclass or private-member access is required anymore. +using EPProxyService = mscclpp::ProxyService; + +Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode): + rank(rank), num_ranks(num_ranks), + num_nvl_bytes(num_nvl_bytes), num_rdma_bytes(num_rdma_bytes), + low_latency_mode(low_latency_mode), + comm_stream(at::cuda::getStreamFromPool(true)), + bootstrap(std::make_shared(rank, num_ranks)), + proxy_service(std::make_shared()) { + // Task fifo memory + int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS; + int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS; + int64_t task_ptr_bytes = sizeof(int*) * NUM_MAX_NVL_PEERS; + + // Common checks + EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (num_nvl_bytes <= std::numeric_limits::max() or num_rdma_bytes == 0)); + EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (low_latency_mode or num_rdma_bytes <= std::numeric_limits::max())); + EP_HOST_ASSERT(0 <= rank and rank < num_ranks and (num_ranks <= NUM_MAX_NVL_PEERS * NUM_MAX_RDMA_PEERS or low_latency_mode)); + EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); + if (num_rdma_bytes > 0) + EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS or low_latency_mode); + + // Get ranks + CUDA_CHECK(cudaGetDevice(&device_id)); + rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); + + // Get device info + cudaDeviceProp device_prop = {}; + CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id)); + + if (num_nvl_bytes > 0) { + // Local IPC: alloc local memory and set local IPC handle + CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes)); + CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank])); + buffer_ptrs_gpu = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes); + + // Set task fifo + EP_HOST_ASSERT(NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0); + task_fifo_ptrs[nvl_rank] = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes); + task_fifo_ptrs_gpu = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes + buffer_ptr_bytes); + + // No need to synchronize, will do a full device sync during `sync` + CUDA_CHECK(cudaMemsetAsync(task_fifo_ptrs[nvl_rank], 0, fifo_bytes, comm_stream)); + } + + // Create 32 MiB workspace + CUDA_CHECK(cudaMalloc(&workspace, NUM_WORKSPACE_BYTES)); + CUDA_CHECK(cudaMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream)); + + // MoE counter + CUDA_CHECK(cudaMallocHost(&moe_recv_counter, sizeof(int64_t), cudaHostAllocMapped)); + CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_counter_mapped, const_cast(moe_recv_counter), 0)); + *moe_recv_counter = -1; + + // MoE expert-level counter + CUDA_CHECK(cudaMallocHost(&moe_recv_expert_counter, sizeof(int) * NUM_MAX_LOCAL_EXPERTS, cudaHostAllocMapped)); + 
+    // MoE counter
+    CUDA_CHECK(cudaMallocHost(&moe_recv_counter, sizeof(int64_t), cudaHostAllocMapped));
+    CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_counter_mapped, const_cast<int64_t*>(moe_recv_counter), 0));
+    *moe_recv_counter = -1;
+
+    // MoE expert-level counter
+    CUDA_CHECK(cudaMallocHost(&moe_recv_expert_counter, sizeof(int) * NUM_MAX_LOCAL_EXPERTS, cudaHostAllocMapped));
+    CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_expert_counter_mapped, const_cast<int*>(moe_recv_expert_counter), 0));
+    for (int i = 0; i < NUM_MAX_LOCAL_EXPERTS; ++ i)
+        moe_recv_expert_counter[i] = -1;
+
+    // MoE RDMA-level counter
+    if (num_rdma_ranks > 0) {
+        CUDA_CHECK(cudaMallocHost(&moe_recv_rdma_counter, sizeof(int), cudaHostAllocMapped));
+        CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_rdma_counter_mapped, const_cast<int*>(moe_recv_rdma_counter), 0));
+        *moe_recv_rdma_counter = -1;
+    }
+
+    proxy_service->startProxy();
+}
+
+Buffer::~Buffer() noexcept(false) {
+    // Synchronize
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    if (num_nvl_bytes > 0) {
+        // Barrier
+        intranode::barrier(task_fifo_ptrs_gpu, head, nvl_rank, num_nvl_ranks, comm_stream);
+        move_fifo_slots();
+        CUDA_CHECK(cudaDeviceSynchronize());
+
+        // Close remote IPC
+        if (is_available()) {
+            for (int i = 0; i < num_nvl_ranks; ++ i) if (i != nvl_rank)
+                CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i]));
+        }
+
+        // Free local buffer and error flag
+        CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank]));
+    }
+
+    // Free the RDMA buffer. `sync()` allocates it with a plain `cudaMalloc`
+    // (there is no NVSHMEM symmetric heap to tear down); the null check
+    // guards against a Buffer that was constructed but never synced.
+    if (num_rdma_bytes > 0 and rdma_buffer_ptr != nullptr)
+        CUDA_CHECK(cudaFree(rdma_buffer_ptr));
+
+    proxy_service->stopProxy();
+
+    // Free workspace and MoE counters
+    CUDA_CHECK(cudaFree(workspace));
+    CUDA_CHECK(cudaFreeHost(const_cast<int64_t*>(moe_recv_counter)));
+
+    // Free chunked-mode buffers
+    CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_expert_counter)));
+    if (num_rdma_ranks > 0)
+        CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_rdma_counter)));
+}
+
+void Buffer::move_fifo_slots(int num_slots) {
+    head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS;
+}
+
+bool Buffer::is_available() const {
+    return available;
+}
+
+bool Buffer::is_internode_available() const {
+    return is_available() and num_ranks > NUM_MAX_NVL_PEERS;
+}
+
+int Buffer::get_num_rdma_ranks() const {
+    return num_rdma_ranks;
+}
+
+int Buffer::get_rdma_rank() const {
+    return rdma_rank;
+}
+
+int Buffer::get_root_rdma_rank(bool global) const {
+    return global ? nvl_rank : 0;
+}
+
+int Buffer::get_local_device_id() const {
+    return device_id;
+}
+
+pybind11::bytearray Buffer::get_local_ipc_handle() const {
+    return {ipc_handles[nvl_rank].reserved, CUDA_IPC_HANDLE_SIZE};
+}
+
+pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {
+    // NVSHMEM support is not yet ported; see stub at bottom of this file.
+    throw std::runtime_error("mscclpp::ep::Buffer::get_local_nvshmem_unique_id: NVSHMEM support not yet ported");
+}
+
+torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const {
+    torch::ScalarType casted_dtype = torch::python::detail::py_object_to_dtype(dtype);
+    auto element_bytes = static_cast<int64_t>(elementSize(casted_dtype));
+    auto base_ptr = static_cast<uint8_t*>(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset;
+    auto num_bytes = use_rdma_buffer ? num_rdma_bytes : num_nvl_bytes;
+    return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA));
+}
+
+mscclpp::UniqueId Buffer::create_unique_id() const {
+    return bootstrap->createUniqueId();
+}
+
+void Buffer::connect(mscclpp::UniqueId root_id) {
+    bootstrap->initialize(root_id);
+    communicator = std::make_shared<mscclpp::Communicator>(bootstrap);
+}
+
+void Buffer::sync(const std::vector<int>& device_ids,
+                  const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles,
+                  const std::optional<pybind11::bytearray>& root_unique_id_opt) {
+    EP_HOST_ASSERT(not is_available());
+
+    const std::vector<mscclpp::Transport> ib_transports = {mscclpp::Transport::IB0, mscclpp::Transport::IB1,
+        mscclpp::Transport::IB2, mscclpp::Transport::IB3, mscclpp::Transport::IB4,
+        mscclpp::Transport::IB5, mscclpp::Transport::IB6, mscclpp::Transport::IB7};
+    const auto ipc_transport = mscclpp::Transport::CudaIpc;
+    const auto ib_transport = ib_transports[device_id];
+    const mscclpp::TransportFlags all_transport = ipc_transport | ib_transport;
+
+    // Sync IPC handles
+    if (num_nvl_bytes > 0) {
+        EP_HOST_ASSERT(num_ranks == device_ids.size());
+        EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size());
+        for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++ i) {
+            EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value());
+            auto handle_str = std::string(all_gathered_handles[offset + i].value());
+            EP_HOST_ASSERT(handle_str.size() == CUDA_IPC_HANDLE_SIZE);
+            if (offset + i != rank) {
+                std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
+                CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess));
+                task_fifo_ptrs[i] = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
+            } else {
+                EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0);
+            }
+        }
+
+        // Copy all buffer and task pointers to GPU
+        CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu, buffer_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));
+        CUDA_CHECK(cudaMemcpy(task_fifo_ptrs_gpu, task_fifo_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));
+        CUDA_CHECK(cudaDeviceSynchronize());
+
+        // Create connections
+        std::vector<std::shared_ptr<mscclpp::Connection>> connections;
+        {
+            std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connection_futures;
+            mscclpp::EndpointConfig local_config(ipc_transport);
+            for (int i = 0; i < num_nvl_ranks; ++i) {
+                auto r = i + rdma_rank * num_nvl_ranks;
+                connection_futures.emplace_back(communicator->connect(local_config, r, 0));
+            }
+            for (auto& future : connection_futures) {
+                connections.emplace_back(future.get());
+            }
+        }
+
+        auto buffer_mem = communicator->registerMemory(buffer_ptrs[nvl_rank], num_nvl_bytes, ipc_transport);
+
+        std::vector<std::shared_future<mscclpp::RegisteredMemory>> remote_mem_futures(num_nvl_ranks);
+        for (int i = 0; i < num_nvl_ranks; ++i) {
+            if (i == nvl_rank) continue;
+            auto r = i + rdma_rank * num_nvl_ranks;
+            communicator->sendMemory(buffer_mem, r, 0);
+            remote_mem_futures[i] = communicator->recvMemory(r, 0);
+        }
+        // Build one memory channel per NVLink peer and record its device
+        // handle at the peer's index (slot `nvl_rank` stays default-initialized).
+        std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handles(num_nvl_ranks);
+        for (int i = 0; i < num_nvl_ranks; ++i) {
+            if (i == nvl_rank) continue;
+            auto sema = std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*communicator, connections[i]);
+            memory_channels.emplace_back(sema, remote_mem_futures[i].get(), buffer_mem);
+            memory_channel_handles[i] = memory_channels.back().deviceHandle();
+        }
+
+        memory_channel_handles_device_ptr = mscclpp::detail::gpuCallocShared<mscclpp::MemoryChannelDeviceHandle>(num_nvl_ranks);
+        mscclpp::gpuMemcpy(
+            memory_channel_handles_device_ptr.get(), memory_channel_handles.data(), num_nvl_ranks,
+            cudaMemcpyHostToDevice);
+    }
+
+    // RDMA buffer setup (replaces DeepEP's NVSHMEM symmetric-heap allocation).
+    //
+    // Unlike DeepEP which used `nvshmem_align` to place the RDMA buffer on the
+    // symmetric heap, all our internode communication goes through MSCCL++
+    // `PortChannel` (proxy-based RDMA), so a plain `cudaMalloc` + IB memory
+    // registration is sufficient. The bootstrap barrier replaces
+    // `nvshmem_barrier_all`.
+    if (num_rdma_bytes > 0) {
+        EP_HOST_ASSERT(communicator != nullptr);
+        EP_HOST_ASSERT(bootstrap != nullptr);
+
+        // Allocate the RDMA buffer
+        CUDA_CHECK(cudaMalloc(&rdma_buffer_ptr, num_rdma_bytes));
+        CUDA_CHECK(cudaMemset(rdma_buffer_ptr, 0, num_rdma_bytes));
+        bootstrap->barrier();
+        CUDA_CHECK(cudaDeviceSynchronize());
+
+        // Rank -> RDMA buffer IDs
+        std::map<int, mscclpp::MemoryId> memory_ids;
+
+        // Register local memory
+        auto local_rdma_buffer_mem = communicator->registerMemory(rdma_buffer_ptr, num_rdma_bytes, all_transport);
+        memory_ids[rank] = proxy_service->addMemory(local_rdma_buffer_mem);
+
+        // Send local memory to other ranks. If low_latency_mode == true, only send to ranks with the same GPU ID.
+        for (int r = 0; r < num_ranks; ++r) {
+            if (r == rank) continue;
+            if (low_latency_mode && ((r % NUM_MAX_NVL_PEERS) != (rank % NUM_MAX_NVL_PEERS))) continue;
+            communicator->sendMemory(local_rdma_buffer_mem, r, 0);
+        }
+
+        // Receive remote memory from other ranks.
+        for (int r = 0; r < num_ranks; ++r) {
+            if (r == rank) continue;
+            if (low_latency_mode && ((r % NUM_MAX_NVL_PEERS) != (rank % NUM_MAX_NVL_PEERS))) continue;
+            memory_ids[r] = proxy_service->addMemory(communicator->recvMemory(r, 0).get());
+        }
+
+        // Rank -> vector of connections
+        std::unordered_map<int, std::vector<std::shared_ptr<mscclpp::Connection>>> connections;
+        const mscclpp::EndpointConfig ipc_cfg(ipc_transport);
+        const mscclpp::EndpointConfig ib_cfg(ib_transport);
+
+        // Self connection for local memory (CUDA IPC).
+        connections[rank].emplace_back(communicator->connect(ipc_cfg, rank, 0).get());
+
+        // Remote IB connections (multi-QP per peer).
+        const int num_ib_connections_per_rank = 12;  // #QPs per rank (mirrors DeepEP).
+        for (auto& [r, memory_id] : memory_ids) {
+            if (r == rank) continue;
+            std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> futures;
+            futures.reserve(num_ib_connections_per_rank);
+            for (int i = 0; i < num_ib_connections_per_rank; ++i) {
+                futures.emplace_back(communicator->connect(ib_cfg, r, 0));
+            }
+            for (auto& f : futures) connections[r].emplace_back(f.get());
+        }
+
+        // Rank -> vector of semaphore IDs
+        std::unordered_map<int, std::vector<mscclpp::SemaphoreId>> sema_ids;
+        const int num_semaphores_per_rank = 16;
+        for (int i = 0; i < num_semaphores_per_rank; ++i) {
+            for (auto& [r, conns] : connections) {
+                auto& conn = conns[i % conns.size()];
+                auto sema_id = proxy_service->buildAndAddSemaphore(*communicator, conn);
+                sema_ids[r].emplace_back(sema_id);
+            }
+        }
+
+        // Create port channels + device handles.
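+        // The flattened handle array is laid out channel-set-major: for each
+        // of the `num_port_channels_per_rank` sets there is one handle per
+        // rank (ascending rank order via std::map, self included), with each
+        // set reusing that rank's semaphores round-robin. Internode kernels
+        // index the array with the same convention.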
+        const int num_port_channels_per_rank = num_semaphores_per_rank;
+        std::vector<mscclpp::PortChannelDeviceHandle> port_channel_handles;
+        for (int i = 0; i < num_port_channels_per_rank; ++i) {
+            for (auto& [r, memory_id] : memory_ids) {
+                auto sema_id = sema_ids[r][i % sema_ids[r].size()];
+                auto port_channel = proxy_service->portChannel(sema_id, memory_id, memory_ids[rank]);
+                port_channels.emplace_back(std::move(port_channel));
+                port_channel_handles.emplace_back(port_channels.back().deviceHandle());
+            }
+        }
+
+        port_channel_handles_device_ptr = mscclpp::detail::gpuCallocShared<mscclpp::PortChannelDeviceHandle>(
+            port_channel_handles.size());
+        mscclpp::gpuMemcpy(
+            port_channel_handles_device_ptr.get(), port_channel_handles.data(), port_channel_handles.size(),
+            cudaMemcpyHostToDevice);
+    }
+
+    // Ready to use
+    available = true;
+}
+
+std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
+Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts,
+                            std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
+    EP_HOST_ASSERT(topk_idx.dim() == 2);
+    EP_HOST_ASSERT(topk_idx.is_contiguous());
+    EP_HOST_ASSERT(num_experts > 0);
+
+    // Allocate all tensors on comm stream if set
+    // NOTES: do not allocate tensors upfront!
+    auto compute_stream = at::cuda::getCurrentCUDAStream();
+    if (allocate_on_comm_stream) {
+        EP_HOST_ASSERT(previous_event.has_value() and async);
+        at::cuda::setCurrentCUDAStream(comm_stream);
+    }
+
+    // Wait previous tasks to be finished
+    if (previous_event.has_value()) {
+        stream_wait(comm_stream, previous_event.value());
+    } else {
+        stream_wait(comm_stream, compute_stream);
+    }
+
+    auto num_tokens = static_cast<int>(topk_idx.size(0)), num_topk = static_cast<int>(topk_idx.size(1));
+    auto num_tokens_per_rank = torch::empty({num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
+    auto num_tokens_per_rdma_rank = std::optional<torch::Tensor>();
+    auto num_tokens_per_expert = torch::empty({num_experts}, dtype(torch::kInt32).device(torch::kCUDA));
+    auto is_token_in_rank = torch::empty({num_tokens, num_ranks}, dtype(torch::kBool).device(torch::kCUDA));
+    if (is_internode_available())
+        num_tokens_per_rdma_rank = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
+
+    internode::get_dispatch_layout(topk_idx.data_ptr<int64_t>(),
+                                   num_tokens_per_rank.data_ptr<int>(),
+                                   num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr<int>() : nullptr,
+                                   num_tokens_per_expert.data_ptr<int>(),
+                                   is_token_in_rank.data_ptr<bool>(),
+                                   num_tokens, num_topk, num_ranks, num_experts,
+                                   comm_stream);
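+
+    // Async protocol (shared by every dispatch/combine entry point below):
+    // when `async` is set we return an EventHandle recorded on `comm_stream`
+    // instead of blocking, and call `Tensor::record_stream` so PyTorch's
+    // caching allocator keeps each tensor alive until the comm stream is
+    // done with it; otherwise the compute stream waits on the comm stream.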
+    // Wait streams
+    std::optional<EventHandle> event;
+    if (async) {
+        event = EventHandle(comm_stream);
+        for (auto& t: {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) {
+            t.record_stream(comm_stream);
+            if (allocate_on_comm_stream)
+                t.record_stream(compute_stream);
+        }
+        for (auto& to: {num_tokens_per_rdma_rank}) {
+            to.has_value() ? to->record_stream(comm_stream) : void();
+            if (allocate_on_comm_stream)
+                to.has_value() ? to->record_stream(compute_stream) : void();
+        }
+    } else {
+        stream_wait(compute_stream, comm_stream);
+    }
+
+    // Switch back compute stream
+    if (allocate_on_comm_stream)
+        at::cuda::setCurrentCUDAStream(compute_stream);
+
+    return {num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event};
+}
+
+std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
+           std::vector<int>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
+           std::optional<EventHandle>>
+Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
+                           const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
+                           const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
+                           int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
+                           int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
+    // Cached mode reuses the rank/channel prefix matrices produced by a
+    // previous dispatch with the same routing, skipping the notify round.
+    bool cached_mode = cached_rank_prefix_matrix.has_value();
+
+    // One channel uses two blocks: even-numbered blocks for sending, odd-numbered blocks for receiving.
+    EP_HOST_ASSERT(config.num_sms % 2 == 0);
+    int num_channels = config.num_sms / 2;
+    if (cached_mode) {
+        EP_HOST_ASSERT(cached_rank_prefix_matrix.has_value());
+        EP_HOST_ASSERT(cached_channel_prefix_matrix.has_value());
+    } else {
+        EP_HOST_ASSERT(num_tokens_per_rank.has_value());
+        EP_HOST_ASSERT(num_tokens_per_expert.has_value());
+    }
+
+    // Type checks
+    EP_HOST_ASSERT(is_token_in_rank.scalar_type() == torch::kBool);
+    if (cached_mode) {
+        EP_HOST_ASSERT(cached_rank_prefix_matrix->scalar_type() == torch::kInt32);
+        EP_HOST_ASSERT(cached_channel_prefix_matrix->scalar_type() == torch::kInt32);
+    } else {
+        EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32);
+        EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32);
+    }
+
+    // Shape and contiguous checks
+    EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
+    EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0);
+    EP_HOST_ASSERT(is_token_in_rank.dim() == 2 and is_token_in_rank.is_contiguous());
+    EP_HOST_ASSERT(is_token_in_rank.size(0) == x.size(0) and is_token_in_rank.size(1) == num_ranks);
+    if (cached_mode) {
+        EP_HOST_ASSERT(cached_rank_prefix_matrix->dim() == 2 and cached_rank_prefix_matrix->is_contiguous());
+        EP_HOST_ASSERT(cached_rank_prefix_matrix->size(0) == num_ranks and cached_rank_prefix_matrix->size(1) == num_ranks);
+        EP_HOST_ASSERT(cached_channel_prefix_matrix->dim() == 2 and cached_channel_prefix_matrix->is_contiguous());
+        EP_HOST_ASSERT(cached_channel_prefix_matrix->size(0) == num_ranks and cached_channel_prefix_matrix->size(1) == num_channels);
+    } else {
+        EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous());
+        EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0);
+        EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS);
+        EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous());
+        EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks);
+    }
+
+    auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
+    auto num_experts = cached_mode ? 0 : static_cast<int>(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks;
+
+    // Top-k checks
+    int num_topk = 0;
+    int64_t* topk_idx_ptr = nullptr;
+    float* topk_weights_ptr = nullptr;
+    EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value());
+    if (topk_idx.has_value()) {
+        num_topk = static_cast<int>(topk_idx->size(1));
+        EP_HOST_ASSERT(num_experts > 0);
+        EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous());
+        EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
+        EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0));
+        EP_HOST_ASSERT(num_topk == topk_weights->size(1));
+        EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
+        topk_idx_ptr = topk_idx->data_ptr<int64_t>();
+        topk_weights_ptr = topk_weights->data_ptr<float>();
+    }
+
+    // FP8 scales checks
+    float* x_scales_ptr = nullptr;
+    int num_scales = 0;
+    if (x_scales.has_value()) {
+        EP_HOST_ASSERT(x.element_size() == 1);
+        EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32);
+        EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous());
+        EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
+        num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
+        x_scales_ptr = x_scales->data_ptr<float>();
+    }
+
+    // Allocate all tensors on comm stream if set
+    // NOTES: do not allocate tensors upfront!
+    auto compute_stream = at::cuda::getCurrentCUDAStream();
+    if (allocate_on_comm_stream) {
+        EP_HOST_ASSERT(previous_event.has_value() and async);
+        at::cuda::setCurrentCUDAStream(comm_stream);
+    }
+
+    // Wait previous tasks to be finished
+    if (previous_event.has_value()) {
+        stream_wait(comm_stream, previous_event.value());
+    } else {
+        stream_wait(comm_stream, compute_stream);
+    }
+
+    // Create handles (only return for non-cached mode)
+    int num_recv_tokens = -1;
+    auto rank_prefix_matrix = torch::Tensor();
+    auto channel_prefix_matrix = torch::Tensor();
+    std::vector<int> num_recv_tokens_per_expert_list;
+
+    // Barrier or send sizes
+    // To clean: channel start/end offset, head and tail
+    int num_memset_int = num_channels * num_ranks * 4;
+    if (cached_mode) {
+        num_recv_tokens = cached_num_recv_tokens;
+        rank_prefix_matrix = cached_rank_prefix_matrix.value();
+        channel_prefix_matrix = cached_channel_prefix_matrix.value();
+
+        // Copy rank prefix matrix and clean flags
+        intranode::cached_notify_dispatch(rank_prefix_matrix.data_ptr<int>(), num_memset_int,
+                                          buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank, num_ranks,
+                                          comm_stream);
+        move_fifo_slots(2);
+    } else {
+        rank_prefix_matrix = torch::empty({num_ranks, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
+        channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
+
+        // Send sizes
+        // Meta information:
+        //   - Size prefix by ranks, shaped as `[num_ranks, num_ranks]`
+        //   - Size prefix by experts (not used later), shaped as `[num_ranks, num_local_experts]`
+        // NOTES: no more token dropping in this version
+        *moe_recv_counter = -1;
+        for (int i = 0; i < num_local_experts; ++ i)
+            moe_recv_expert_counter[i] = -1;
+        EP_HOST_ASSERT(num_ranks * (num_ranks + num_local_experts) * sizeof(int) <= num_nvl_bytes);
+        intranode::notify_dispatch(num_tokens_per_rank->data_ptr<int>(), moe_recv_counter_mapped, num_ranks,
+                                   num_tokens_per_expert->data_ptr<int>(), moe_recv_expert_counter_mapped, num_experts,
+                                   num_tokens, is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
+                                   rank_prefix_matrix.data_ptr<int>(),
+                                   num_memset_int, expert_alignment,
+                                   buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank,
+                                   comm_stream, num_channels);
+        move_fifo_slots(3);
+
+        // Synchronize total received tokens and tokens per expert
+        auto start_time = std::chrono::high_resolution_clock::now();
+        while (true) {
+            // Read total count
+            num_recv_tokens = static_cast<int>(*moe_recv_counter);
+
+            // Read per-expert count
+            bool ready = (num_recv_tokens >= 0);
+            for (int i = 0; i < num_local_experts and ready; ++i)
+                ready &= moe_recv_expert_counter[i] >= 0;
+
+            if (ready)
+                break;
+
+            // Timeout check
+            if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS)
+                throw std::runtime_error("DeepEP error: CPU recv timeout");
+        }
+        num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
+    }
+
+    // Allocate new tensors
+    auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
+    auto recv_src_idx = torch::empty({num_recv_tokens}, dtype(torch::kInt32).device(torch::kCUDA));
+    auto recv_topk_idx = std::optional<torch::Tensor>(), recv_topk_weights = std::optional<torch::Tensor>(), recv_x_scales = std::optional<torch::Tensor>();
+    auto recv_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
+    auto send_head = torch::empty({num_tokens, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
+
+    // Assign pointers
+    int64_t* recv_topk_idx_ptr = nullptr;
+    float* recv_topk_weights_ptr = nullptr;
+    float* recv_x_scales_ptr = nullptr;
+    if (topk_idx.has_value()) {
+        recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options());
+        recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
+        recv_topk_idx_ptr = recv_topk_idx->data_ptr<int64_t>();
+        recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
+    }
+    if (x_scales.has_value()) {
+        recv_x_scales = x_scales->dim() == 1 ?
+            torch::empty({num_recv_tokens}, x_scales->options()) :
+            torch::empty({num_recv_tokens, num_scales}, x_scales->options());
+        recv_x_scales_ptr = recv_x_scales->data_ptr<float>();
+    }
+
+    // Dispatch
+    EP_HOST_ASSERT(num_ranks * num_ranks * sizeof(int) +                                                                // Size prefix matrix
+                   num_channels * num_ranks * sizeof(int) +                                                             // Channel start offset
+                   num_channels * num_ranks * sizeof(int) +                                                             // Channel end offset
+                   num_channels * num_ranks * sizeof(int) * 2 +                                                         // Queue head and tail
+                   num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * recv_x.element_size() + // Data buffer
+                   num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) +                    // Source index buffer
+                   num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(int64_t) +     // Top-k index buffer
+                   num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) +       // Top-k weight buffer
+                   num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) * num_scales       // FP8 scale buffer
+                   <= num_nvl_bytes);
+    intranode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_src_idx.data_ptr<int>(), recv_topk_idx_ptr, recv_topk_weights_ptr, recv_channel_prefix_matrix.data_ptr<int>(),
+                        send_head.data_ptr<int>(),
+                        x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr,
+                        is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
+                        num_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales,
+                        buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms,
+                        config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens);
+
+    // Wait streams
+    std::optional<EventHandle> event;
+    if (async) {
+        event = EventHandle(comm_stream);
+        for (auto& t: {x, is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_x, recv_src_idx, recv_channel_prefix_matrix, send_head}) {
+            t.record_stream(comm_stream);
+            if (allocate_on_comm_stream)
+                t.record_stream(compute_stream);
+        }
+        for (auto& to: {x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_expert, cached_channel_prefix_matrix, cached_rank_prefix_matrix, recv_topk_idx, recv_topk_weights, recv_x_scales}) {
+            to.has_value() ? to->record_stream(comm_stream) : void();
+            if (allocate_on_comm_stream)
+                to.has_value() ? to->record_stream(compute_stream) : void();
+        }
+    } else {
+        stream_wait(compute_stream, comm_stream);
+    }
+
+std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
+Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
+                          const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
+                          const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
+  EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
+  EP_HOST_ASSERT(src_idx.dim() == 1 and src_idx.is_contiguous() and src_idx.scalar_type() == torch::kInt32);
+  EP_HOST_ASSERT(send_head.dim() == 2 and send_head.is_contiguous() and send_head.scalar_type() == torch::kInt32);
+  EP_HOST_ASSERT(rank_prefix_matrix.dim() == 2 and rank_prefix_matrix.is_contiguous() and rank_prefix_matrix.scalar_type() == torch::kInt32);
+  EP_HOST_ASSERT(channel_prefix_matrix.dim() == 2 and channel_prefix_matrix.is_contiguous() and channel_prefix_matrix.scalar_type() == torch::kInt32);
+
+  // One channel uses two blocks: even-numbered blocks send, odd-numbered blocks receive
+  EP_HOST_ASSERT(config.num_sms % 2 == 0);
+  int num_channels = config.num_sms / 2;
+
+  auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
+  auto num_recv_tokens = static_cast<int>(send_head.size(0));
+  EP_HOST_ASSERT(src_idx.size(0) == num_tokens);
+  EP_HOST_ASSERT(send_head.size(1) == num_ranks);
+  EP_HOST_ASSERT(rank_prefix_matrix.size(0) == num_ranks and rank_prefix_matrix.size(1) == num_ranks);
+  EP_HOST_ASSERT(channel_prefix_matrix.size(0) == num_ranks and channel_prefix_matrix.size(1) == num_channels);
+  EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0);
+
+  // Allocate all tensors on comm stream if set
+  // NOTES: do not allocate tensors upfront!
+  auto compute_stream = at::cuda::getCurrentCUDAStream();
+  if (allocate_on_comm_stream) {
+    EP_HOST_ASSERT(previous_event.has_value() and async);
+    at::cuda::setCurrentCUDAStream(comm_stream);
+  }
+
+  // Wait previous tasks to be finished
+  if (previous_event.has_value()) {
+    stream_wait(comm_stream, previous_event.value());
+  } else {
+    stream_wait(comm_stream, compute_stream);
+  }
+
+  int num_topk = 0;
+  auto recv_topk_weights = std::optional<torch::Tensor>();
+  float* topk_weights_ptr = nullptr;
+  float* recv_topk_weights_ptr = nullptr;
+  if (topk_weights.has_value()) {
+    EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
+    EP_HOST_ASSERT(topk_weights->size(0) == num_tokens);
+    EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
+    num_topk = static_cast<int>(topk_weights->size(1));
+    topk_weights_ptr = topk_weights->data_ptr<float>();
+    recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
+    recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
+  }
+
+  // Launch barrier and reset queue head and tail
+  EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 <= num_nvl_bytes);
+  intranode::cached_notify_combine(buffer_ptrs_gpu, send_head.data_ptr<int>(),
+                                   num_channels, num_recv_tokens, num_channels * num_ranks * 2,
+                                   task_fifo_ptrs_gpu, head, rank, num_ranks,
+                                   comm_stream);
+
+  // NOTES: this function uses two FIFO slots (barrier before and after)
+  move_fifo_slots(2);
+
+  // Combine data
+  auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
+  EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 +                                                    // Queue head and tail
+                 num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * x.element_size() + // Data buffer
+                 num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) +               // Source index buffer
+                 num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float)    // Top-k weight buffer
+                 <= num_nvl_bytes);
+  intranode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()),
+                     recv_x.data_ptr(), recv_topk_weights_ptr,
+                     x.data_ptr(), topk_weights_ptr,
+                     src_idx.data_ptr<int>(), rank_prefix_matrix.data_ptr<int>(), channel_prefix_matrix.data_ptr<int>(),
+                     send_head.data_ptr<int>(), num_tokens, num_recv_tokens, hidden, num_topk,
+                     buffer_ptrs_gpu, rank, num_ranks,
+                     comm_stream, config.num_sms,
+                     config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens);
+
+  // Wait streams
+  std::optional<EventHandle> event;
+  if (async) {
+    event = EventHandle(comm_stream);
+    for (auto& t: {x, src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) {
+      t.record_stream(comm_stream);
+      if (allocate_on_comm_stream)
+        t.record_stream(compute_stream);
+    }
+    for (auto& to: {topk_weights, recv_topk_weights}) {
+      to.has_value() ? to->record_stream(comm_stream) : void();
+      if (allocate_on_comm_stream)
+        to.has_value() ? to->record_stream(compute_stream) : void();
+    }
+  } else {
+    stream_wait(compute_stream, comm_stream);
+  }
+
+  // Switch back compute stream
+  if (allocate_on_comm_stream)
+    at::cuda::setCurrentCUDAStream(compute_stream);
+
+  return {recv_x, recv_topk_weights, event};
+}
+
+// -----------------------------------------------------------------------------
+// Internode (NVLink + RDMA) high-throughput path. Ported verbatim from
+// DeepEP `csrc/deep_ep.cpp`; the kernels it drives are in
+// `src/ext/ep/kernels/internode.cu`. Low-latency (pure RDMA) paths below
+// are still stubbed.
+// -----------------------------------------------------------------------------
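+
+// NOTES: a sketch of the host/GPU size rendezvous the dispatch paths rely on
+// (names as in this file). The host resets pinned, device-mapped counters to -1,
+// the notify kernels publish the received sizes through the mapped pointers, and
+// the CPU busy-waits (with the GIL released) until every counter turns
+// non-negative or the timeout trips:
+//
+//   *moe_recv_counter = -1;                                         // pinned host memory
+//   internode::notify_dispatch(..., moe_recv_counter_mapped, ...);  // GPU publishes sizes
+//   while (*moe_recv_counter < 0) { /* poll, bounded by NUM_CPU_TIMEOUT_SECS */ }
+//
+// Only after this wait can the receive-side tensors be allocated, because their
+// shapes depend on the published counters.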
+
+std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<EventHandle>>
+Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
+                           const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
+                           const std::optional<torch::Tensor>& num_tokens_per_rank, const std::optional<torch::Tensor>& num_tokens_per_rdma_rank,
+                           const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
+                           int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
+                           const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum,
+                           const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
+                           int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
+  // In dispatch, the CPU busy-waits until the GPU has received the tensor-size metadata from other ranks, which can take quite long.
+  pybind11::gil_scoped_release release;
+
+  const int num_channels = config.num_sms / 2;
+  EP_HOST_ASSERT(config.num_sms % 2 == 0);
+  EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS);
+
+  bool cached_mode = cached_rdma_channel_prefix_matrix.has_value();
+  if (cached_mode) {
+    EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix.has_value());
+    EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum.has_value());
+    EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix.has_value());
+    EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum.has_value());
+  } else {
+    EP_HOST_ASSERT(num_tokens_per_rank.has_value());
+    EP_HOST_ASSERT(num_tokens_per_rdma_rank.has_value());
+    EP_HOST_ASSERT(num_tokens_per_expert.has_value());
+  }
+
+  // Type checks
+  if (cached_mode) {
+    EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->scalar_type() == torch::kInt32);
+    EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->scalar_type() == torch::kInt32);
+    EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->scalar_type() == torch::kInt32);
+    EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->scalar_type() == torch::kInt32);
+  } else {
+    EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32);
+    EP_HOST_ASSERT(num_tokens_per_rdma_rank->scalar_type() == torch::kInt32);
+    EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32);
+  }
+
+  // Shape and contiguous checks
+  EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
+  EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0);
+  if (cached_mode) {
+    EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->dim() == 2 and cached_rdma_channel_prefix_matrix->is_contiguous());
+    EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->size(0) == num_rdma_ranks and cached_rdma_channel_prefix_matrix->size(1) == num_channels);
+    EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->dim() == 1 and cached_recv_rdma_rank_prefix_sum->is_contiguous());
+    EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->size(0) == num_rdma_ranks);
+    EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->dim() == 2 and cached_gbl_channel_prefix_matrix->is_contiguous());
+    EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->size(0) == num_ranks and cached_gbl_channel_prefix_matrix->size(1) == num_channels);
+    EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->dim() == 1 and cached_recv_gbl_rank_prefix_sum->is_contiguous());
+    EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->size(0) == num_ranks);
+  } else {
+    EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous());
+    EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 1 and num_tokens_per_rdma_rank->is_contiguous());
+    EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous());
+    EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks);
+    EP_HOST_ASSERT(num_tokens_per_rdma_rank->size(0) == num_rdma_ranks);
+    EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0);
+    EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS);
+  }
+
+  auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1)), hidden_int4 = static_cast<int>(x.size(1) * x.element_size() / sizeof(int4));
+  auto num_experts = cached_mode ? 0 : static_cast<int>(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks;
+
+  // Top-k checks
+  int num_topk = 0;
+  int64_t* topk_idx_ptr = nullptr;
+  float* topk_weights_ptr = nullptr;
+  EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value());
+  if (topk_idx.has_value()) {
+    num_topk = static_cast<int>(topk_idx->size(1));
+    EP_HOST_ASSERT(num_experts > 0);
+    EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous());
+    EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
+    EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0));
+    EP_HOST_ASSERT(num_topk == topk_weights->size(1));
+    EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
+    topk_idx_ptr = topk_idx->data_ptr<int64_t>();
+    topk_weights_ptr = topk_weights->data_ptr<float>();
+  }
+
+  // FP8 scales checks
+  float* x_scales_ptr = nullptr;
+  int num_scales = 0;
+  if (x_scales.has_value()) {
+    EP_HOST_ASSERT(x.element_size() == 1);
+    EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32);
+    EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous());
+    EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
+    num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
+    x_scales_ptr = x_scales->data_ptr<float>();
+  }
+
+  // Allocate all tensors on comm stream if set
+  auto compute_stream = at::cuda::getCurrentCUDAStream();
+  if (allocate_on_comm_stream) {
+    EP_HOST_ASSERT(previous_event.has_value() and async);
+    at::cuda::setCurrentCUDAStream(comm_stream);
+  }
+
+  // Wait previous tasks to be finished
+  if (previous_event.has_value()) {
+    stream_wait(comm_stream, previous_event.value());
+  } else {
+    stream_wait(comm_stream, compute_stream);
+  }
+
+  // Create handles (only return for non-cached mode)
+  int num_recv_tokens = -1, num_rdma_recv_tokens = -1;
+  auto rdma_channel_prefix_matrix = torch::Tensor();
+  auto recv_rdma_rank_prefix_sum = torch::Tensor();
+  auto gbl_channel_prefix_matrix = torch::Tensor();
+  auto recv_gbl_rank_prefix_sum = torch::Tensor();
+  std::vector<int> num_recv_tokens_per_expert_list;
+
+  // Barrier or send sizes
+  if (cached_mode) {
+    num_recv_tokens = cached_num_recv_tokens;
+    num_rdma_recv_tokens = cached_num_rdma_recv_tokens;
+    rdma_channel_prefix_matrix = cached_rdma_channel_prefix_matrix.value();
+    recv_rdma_rank_prefix_sum = cached_recv_rdma_rank_prefix_sum.value();
+    gbl_channel_prefix_matrix = cached_gbl_channel_prefix_matrix.value();
+    recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value();
+
+    // Just a barrier and clean flags
+    internode::cached_notify(hidden_int4, num_scales, num_topk, num_topk,
+                             num_ranks, num_channels, 0, nullptr,
+                             nullptr, nullptr, nullptr,
+                             rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
+                             buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens,
+                             task_fifo_ptrs_gpu, head, rank, comm_stream,
+                             config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
+                             num_nvl_bytes, true, low_latency_mode,
+                             port_channel_handles_device_ptr.get(),
+                             memory_channel_handles_device_ptr.get());
+    move_fifo_slots(2);
+  } else {
+    rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
+    recv_rdma_rank_prefix_sum = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
+    gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
+    recv_gbl_rank_prefix_sum = torch::empty({num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
+
+    // Send sizes
+    *moe_recv_counter = -1, *moe_recv_rdma_counter = -1;
+    for (int i = 0; i < num_local_experts; ++i)
+      moe_recv_expert_counter[i] = -1;
+    internode::notify_dispatch(num_tokens_per_rank->data_ptr<int>(), moe_recv_counter_mapped, num_ranks,
+                               num_tokens_per_rdma_rank->data_ptr<int>(), moe_recv_rdma_counter_mapped,
+                               num_tokens_per_expert->data_ptr<int>(), moe_recv_expert_counter_mapped, num_experts,
+                               is_token_in_rank.data_ptr<bool>(), num_tokens, num_channels,
+                               hidden_int4, num_scales, num_topk, expert_alignment,
+                               rdma_channel_prefix_matrix.data_ptr<int>(), recv_rdma_rank_prefix_sum.data_ptr<int>(),
+                               gbl_channel_prefix_matrix.data_ptr<int>(), recv_gbl_rank_prefix_sum.data_ptr<int>(),
+                               rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
+                               buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens,
+                               task_fifo_ptrs_gpu, head, rank, comm_stream,
+                               config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
+                               num_nvl_bytes, low_latency_mode, port_channel_handles_device_ptr.get(),
+                               memory_channel_handles_device_ptr.get());
+    move_fifo_slots(3);
+
+    // Synchronize total received tokens and tokens per expert
+    auto start_time = std::chrono::high_resolution_clock::now();
+    while (true) {
+      num_recv_tokens = static_cast<int>(*moe_recv_counter);
+      num_rdma_recv_tokens = static_cast<int>(*moe_recv_rdma_counter);
+
+      bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0);
+      for (int i = 0; i < num_local_experts and ready; ++i)
+        ready &= moe_recv_expert_counter[i] >= 0;
+
+      if (ready) break;
+
+      if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS) {
+        printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", rank, num_recv_tokens, num_rdma_recv_tokens);
+        for (int i = 0; i < num_local_experts; ++i)
+          printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]);
+        throw std::runtime_error("mscclpp::ep error: timeout (internode_dispatch CPU)");
+      }
+    }
+    num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
+  }
+
+  // Allocate new tensors
+  auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
+  auto recv_topk_idx = std::optional<torch::Tensor>(), recv_topk_weights = std::optional<torch::Tensor>(), recv_x_scales = std::optional<torch::Tensor>();
+  auto recv_src_meta = std::optional<torch::Tensor>();
+  auto recv_rdma_channel_prefix_matrix = std::optional<torch::Tensor>();
+  auto recv_gbl_channel_prefix_matrix = std::optional<torch::Tensor>();
+  auto send_rdma_head = std::optional<torch::Tensor>();
+  auto send_nvl_head = std::optional<torch::Tensor>();
+  if (not cached_mode) {
+    recv_src_meta = torch::empty({num_recv_tokens, internode::get_source_meta_bytes()}, dtype(torch::kByte).device(torch::kCUDA));
+    recv_rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
+    recv_gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
+    send_rdma_head = torch::empty({num_tokens, num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
+    send_nvl_head = torch::empty({num_rdma_recv_tokens, NUM_MAX_NVL_PEERS}, dtype(torch::kInt32).device(torch::kCUDA));
+  }
+
+  int64_t* recv_topk_idx_ptr = nullptr;
+  float* recv_topk_weights_ptr = nullptr;
+  float* recv_x_scales_ptr = nullptr;
+  if (topk_idx.has_value()) {
+    recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options());
+    recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
+    recv_topk_idx_ptr = recv_topk_idx->data_ptr<int64_t>();
+    recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
+  }
+  if (x_scales.has_value()) {
+    recv_x_scales = x_scales->dim() == 1 ?
+      torch::empty({num_recv_tokens}, x_scales->options()) :
+      torch::empty({num_recv_tokens, num_scales}, x_scales->options());
+    recv_x_scales_ptr = recv_x_scales->data_ptr<float>();
+  }
+
+  // Launch data dispatch
+  internode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_topk_idx_ptr, recv_topk_weights_ptr,
+                      cached_mode ? nullptr : recv_src_meta->data_ptr(),
+                      x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr,
+                      cached_mode ? nullptr : send_rdma_head->data_ptr<int>(), cached_mode ? nullptr : send_nvl_head->data_ptr<int>(),
+                      cached_mode ? nullptr : recv_rdma_channel_prefix_matrix->data_ptr<int>(),
+                      cached_mode ? nullptr : recv_gbl_channel_prefix_matrix->data_ptr<int>(),
+                      rdma_channel_prefix_matrix.data_ptr<int>(), recv_rdma_rank_prefix_sum.data_ptr<int>(),
+                      gbl_channel_prefix_matrix.data_ptr<int>(), recv_gbl_rank_prefix_sum.data_ptr<int>(),
+                      num_tokens, hidden_int4, num_scales, num_topk, num_experts,
+                      is_token_in_rank.data_ptr<bool>(),
+                      rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens,
+                      buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens,
+                      rank, num_ranks, cached_mode,
+                      comm_stream, num_channels, low_latency_mode,
+                      port_channel_handles_device_ptr.get(),
+                      memory_channel_handles_device_ptr.get());
+
+  // Wait streams
+  std::optional<EventHandle> event;
+  if (async) {
+    event = EventHandle(comm_stream);
+    for (auto& t: {x, is_token_in_rank, recv_x,
+                   rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum}) {
+      t.record_stream(comm_stream);
+      if (allocate_on_comm_stream)
+        t.record_stream(compute_stream);
+    }
+    for (auto& to: {x_scales, topk_idx, topk_weights,
+                    num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert,
+                    cached_rdma_channel_prefix_matrix, cached_recv_rdma_rank_prefix_sum,
+                    cached_gbl_channel_prefix_matrix, cached_recv_gbl_rank_prefix_sum,
+                    recv_topk_idx, recv_topk_weights, recv_x_scales,
+                    recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, send_rdma_head, send_nvl_head,
+                    recv_src_meta}) {
+      to.has_value() ? to->record_stream(comm_stream) : void();
+      if (allocate_on_comm_stream)
+        to.has_value() ? to->record_stream(compute_stream) : void();
+    }
+  } else {
+    stream_wait(compute_stream, comm_stream);
+  }
+
+  if (allocate_on_comm_stream)
+    at::cuda::setCurrentCUDAStream(compute_stream);
+
+  return {recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list,
+          rdma_channel_prefix_matrix, gbl_channel_prefix_matrix,
+          recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
+          recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
+          recv_src_meta, send_rdma_head, send_nvl_head, event};
+}
+
+std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
+Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
+                          const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
+                          const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
+                          const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
+                          const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
+  const int num_channels = config.num_sms / 2;
+  EP_HOST_ASSERT(config.num_sms % 2 == 0);
+
+  // Shape and contiguous checks
+  EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
+  EP_HOST_ASSERT(src_meta.dim() == 2 and src_meta.is_contiguous() and src_meta.scalar_type() == torch::kByte);
+  EP_HOST_ASSERT(is_combined_token_in_rank.dim() == 2 and is_combined_token_in_rank.is_contiguous() and is_combined_token_in_rank.scalar_type() == torch::kBool);
+  EP_HOST_ASSERT(rdma_channel_prefix_matrix.dim() == 2 and rdma_channel_prefix_matrix.is_contiguous() and rdma_channel_prefix_matrix.scalar_type() == torch::kInt32);
+  EP_HOST_ASSERT(rdma_rank_prefix_sum.dim() == 1 and rdma_rank_prefix_sum.is_contiguous() and rdma_rank_prefix_sum.scalar_type() == torch::kInt32);
+  EP_HOST_ASSERT(gbl_channel_prefix_matrix.dim() == 2 and gbl_channel_prefix_matrix.is_contiguous() and gbl_channel_prefix_matrix.scalar_type() == torch::kInt32);
+  EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.is_contiguous() and combined_rdma_head.scalar_type() == torch::kInt32);
+  EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.is_contiguous() and combined_nvl_head.scalar_type() == torch::kInt32);
+
+  auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1)), hidden_int4 = static_cast<int>(x.size(1) * x.element_size() / sizeof(int4));
+  auto num_combined_tokens = static_cast<int>(is_combined_token_in_rank.size(0));
+  EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0);
+  EP_HOST_ASSERT(src_meta.size(1) == internode::get_source_meta_bytes());
+  EP_HOST_ASSERT(is_combined_token_in_rank.size(1) == num_ranks);
+  EP_HOST_ASSERT(rdma_channel_prefix_matrix.size(0) == num_rdma_ranks and rdma_channel_prefix_matrix.size(1) == num_channels);
+  EP_HOST_ASSERT(rdma_rank_prefix_sum.size(0) == num_rdma_ranks);
+  EP_HOST_ASSERT(gbl_channel_prefix_matrix.size(0) == num_ranks and gbl_channel_prefix_matrix.size(1) == num_channels);
+  EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.size(0) == num_combined_tokens and combined_rdma_head.size(1) == num_rdma_ranks);
+  EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.size(1) == NUM_MAX_NVL_PEERS);
+
+  auto compute_stream = at::cuda::getCurrentCUDAStream();
+  if (allocate_on_comm_stream) {
+    EP_HOST_ASSERT(previous_event.has_value() and async);
+    at::cuda::setCurrentCUDAStream(comm_stream);
+  }
+
+  if (previous_event.has_value()) {
+    stream_wait(comm_stream, previous_event.value());
+  } else {
+    stream_wait(comm_stream, compute_stream);
+  }
+
+  int num_topk = 0;
+  auto combined_topk_weights = std::optional<torch::Tensor>();
+  float* topk_weights_ptr = nullptr;
+  float* combined_topk_weights_ptr = nullptr;
+  if (topk_weights.has_value()) {
+    EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
+    EP_HOST_ASSERT(topk_weights->size(0) == num_tokens);
+    EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
+    num_topk = static_cast<int>(topk_weights->size(1));
+    topk_weights_ptr = topk_weights->data_ptr<float>();
+    combined_topk_weights = torch::empty({num_combined_tokens, num_topk}, topk_weights->options());
+    combined_topk_weights_ptr = combined_topk_weights->data_ptr<float>();
+  }
+
+  EP_HOST_ASSERT(config.num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
+  EP_HOST_ASSERT(config.num_max_nvl_chunked_send_tokens <= config.num_max_nvl_chunked_recv_tokens / num_rdma_ranks);
+
+  internode::cached_notify(hidden_int4, 0, 0, num_topk,
+                           num_ranks, num_channels,
+                           num_combined_tokens, combined_rdma_head.data_ptr<int>(),
+                           rdma_channel_prefix_matrix.data_ptr<int>(), rdma_rank_prefix_sum.data_ptr<int>(), combined_nvl_head.data_ptr<int>(),
+                           rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
+                           buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens,
+                           task_fifo_ptrs_gpu, head, rank, comm_stream,
+                           config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
+                           num_nvl_bytes, false, low_latency_mode,
+                           port_channel_handles_device_ptr.get(),
+                           memory_channel_handles_device_ptr.get());
+  move_fifo_slots(2);
+
+  auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
+  internode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()),
+                     combined_x.data_ptr(), combined_topk_weights_ptr,
+                     is_combined_token_in_rank.data_ptr<bool>(),
+                     x.data_ptr(), topk_weights_ptr,
+                     combined_rdma_head.data_ptr<int>(), combined_nvl_head.data_ptr<int>(),
+                     src_meta.data_ptr(), rdma_channel_prefix_matrix.data_ptr<int>(), rdma_rank_prefix_sum.data_ptr<int>(),
+                     gbl_channel_prefix_matrix.data_ptr<int>(),
+                     num_tokens, num_combined_tokens, hidden, num_topk,
+                     rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens,
+                     buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens,
+                     rank, num_ranks, comm_stream, num_channels, low_latency_mode,
+                     port_channel_handles_device_ptr.get(),
+                     memory_channel_handles_device_ptr.get());
+
+  std::optional<EventHandle> event;
+  if (async) {
+    event = EventHandle(comm_stream);
+    for (auto& t: {x, src_meta,
+                   is_combined_token_in_rank, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
+                   combined_x, combined_rdma_head, combined_nvl_head}) {
+      t.record_stream(comm_stream);
+      if (allocate_on_comm_stream)
+        t.record_stream(compute_stream);
+    }
+    for (auto& to: {topk_weights, combined_topk_weights}) {
+      to.has_value() ? to->record_stream(comm_stream) : void();
+      if (allocate_on_comm_stream)
+        to.has_value() ? to->record_stream(compute_stream) : void();
+    }
+  } else {
+    stream_wait(compute_stream, comm_stream);
+  }
+
+  if (allocate_on_comm_stream)
+    at::cuda::setCurrentCUDAStream(compute_stream);
+
+  return {combined_x, combined_topk_weights, event};
+}
+
+void Buffer::clean_low_latency_buffer(int, int, int) {
+  throw std::runtime_error("mscclpp::ep::Buffer::clean_low_latency_buffer: not yet ported");
+}
+
+std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
+Buffer::low_latency_dispatch(const torch::Tensor&, const torch::Tensor&, int, int, bool, bool, bool) {
+  throw std::runtime_error("mscclpp::ep::Buffer::low_latency_dispatch: not yet ported (needs NVSHMEM/IBGDA -> MSCCL++ migration)");
+}
+
+std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
+Buffer::low_latency_combine(const torch::Tensor&, const torch::Tensor&, const torch::Tensor&,
+                            const torch::Tensor&, const torch::Tensor&,
+                            int, int, bool, bool, bool, const std::optional<torch::Tensor>&) {
+  throw std::runtime_error("mscclpp::ep::Buffer::low_latency_combine: not yet ported");
+}
+
+torch::Tensor Buffer::get_next_low_latency_combine_buffer(int, int, int) {
+  throw std::runtime_error("mscclpp::ep::Buffer::get_next_low_latency_combine_buffer: not yet ported");
+}
+
+} // namespace ep
+} // namespace mscclpp
diff --git a/src/ext/ep/buffer.hpp b/src/ext/ep/buffer.hpp
new file mode 100644
index 00000000..1b0399ef
--- /dev/null
+++ b/src/ext/ep/buffer.hpp
@@ -0,0 +1,169 @@
+#pragma once
+
+// Forcibly disable NDEBUG
+#ifdef NDEBUG
+#undef NDEBUG
+#endif
+
+#include <functional>
+#include <memory>
+#include <optional>
+#include <vector>
+#include <cuda_runtime.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include <mscclpp/core.hpp>
+#include <mscclpp/memory_channel.hpp>
+#include <mscclpp/port_channel.hpp>
+
+#include "config.hpp"
+#include "event.hpp"
+#include "kernels/configs.cuh"
+#include "kernels/exception.cuh"
+
+#ifndef TORCH_EXTENSION_NAME
+#define TORCH_EXTENSION_NAME mscclpp_ep_cpp
+#endif
+
+namespace mscclpp { namespace ep {
+
+struct Buffer {
+  EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8");
+
+private:
+  // Low-latency mode buffer
+  int low_latency_buffer_idx = 0;
+  bool low_latency_mode = false;
+
+  // NVLink Buffer
+  int64_t num_nvl_bytes;
+  void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
+  void** buffer_ptrs_gpu = nullptr;
+
+  // NVSHMEM Buffer
+  int64_t num_rdma_bytes;
+  void* rdma_buffer_ptr = nullptr;
+
+  // Device info and communication
+  int device_id;
+  int rank, rdma_rank, nvl_rank;
+  int num_ranks, num_rdma_ranks, num_nvl_ranks;
+  cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
+
+  // Stream for communication
+  at::cuda::CUDAStream comm_stream;
+
+  // After IPC/NVSHMEM synchronization, this flag will be true
+  bool available = false;
+
+  // Task fifo
+  int head = 0;
+  int* task_fifo_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
+  int** task_fifo_ptrs_gpu = nullptr;
+
+  // Workspace
+  void* workspace = nullptr;
+
+  // Host-side MoE info
+  volatile int* moe_recv_counter = nullptr;
+  int* moe_recv_counter_mapped = nullptr;
+
+  // Host-side expert-level MoE info
+  volatile int* moe_recv_expert_counter = nullptr;
+  int* moe_recv_expert_counter_mapped = nullptr;
+
+  // Host-side RDMA-level MoE info
+  volatile int* moe_recv_rdma_counter = nullptr;
+  int* moe_recv_rdma_counter_mapped = nullptr;
+
+  std::shared_ptr<mscclpp::TcpBootstrap> bootstrap;
+  std::shared_ptr<mscclpp::ProxyService> proxy_service;
+  std::shared_ptr<mscclpp::Communicator> communicator;
+  std::vector<mscclpp::PortChannel> port_channels;
+  std::vector<mscclpp::MemoryChannel> memory_channels;
+  std::shared_ptr<mscclpp::PortChannelDeviceHandle> port_channel_handles_device_ptr;
+  std::shared_ptr<mscclpp::MemoryChannelDeviceHandle> memory_channel_handles_device_ptr;
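+
+  // NOTES: rough mapping of the MSCCL++ members above to the constructs the
+  // DeepEP code used (a sketch; see src/ext/ep/README.md for the authoritative
+  // translation table): PortChannels drive proxy-based transfers and stand in
+  // for the NVSHMEM/IBGDA RDMA paths, while MemoryChannels expose peer GPU
+  // memory for direct NVLink loads/stores. The *_handles_device_ptr members
+  // are device-side handle arrays passed straight into the internode kernels.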
+
+private:
+  void move_fifo_slots(int num_slots = 1);
+
+public:
+  Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode);
+
+  ~Buffer() noexcept(false);
+
+  bool is_available() const;
+
+  bool is_internode_available() const;
+
+  int get_num_rdma_ranks() const;
+
+  int get_rdma_rank() const;
+
+  int get_root_rdma_rank(bool global) const;
+
+  int get_local_device_id() const;
+
+  pybind11::bytearray get_local_ipc_handle() const;
+
+  pybind11::bytearray get_local_nvshmem_unique_id() const;
+
+  torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const;
+
+  mscclpp::UniqueId create_unique_id() const;
+
+  void connect(mscclpp::UniqueId root_id);
+
+  void sync(const std::vector<int>& device_ids, const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles, const std::optional<pybind11::bytearray>& root_unique_id_opt);
+
+  std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
+  get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional<EventHandle>& previous_event,
+                      bool async, bool allocate_on_comm_stream);
+
+  std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
+  intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
+                     const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
+                     const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
+                     int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
+                     int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
+
+  std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
+  intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
+                    const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
+                    const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
+
+  std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<EventHandle>>
+  internode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
+                     const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
+                     const std::optional<torch::Tensor>& num_tokens_per_rank, const std::optional<torch::Tensor>& num_tokens_per_rdma_rank,
+                     const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
+                     int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
+                     const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum,
+                     const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
+                     int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
+
+  std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
+  internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
+                    const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
+                    const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
+                    const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
+                    const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
+
+  void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
+
+  std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
+  low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
+                       int num_max_dispatch_tokens_per_rank, int num_experts,
+                       bool use_fp8, bool async, bool return_recv_hook);
+
+  std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
+  low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
+                      const torch::Tensor& src_info, const torch::Tensor& layout_range,
+                      int num_max_dispatch_tokens_per_rank, int num_experts,
+                      bool zero_copy, bool async, bool return_recv_hook,
+                      const std::optional<torch::Tensor>& out = std::nullopt);
+
+  torch::Tensor
+  get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
+};
+
+} // namespace ep
+} // namespace mscclpp
diff --git a/src/ext/ep/config.hpp b/src/ext/ep/config.hpp
new file mode 100644
index 00000000..8acdf0c8
--- /dev/null
+++ b/src/ext/ep/config.hpp
@@ -0,0 +1,183 @@
+#pragma once
+
+#include "kernels/api.cuh"
+#include "kernels/exception.cuh"
+
+namespace mscclpp { namespace ep {
+
+template <typename dtype_t>
+dtype_t cell_div(dtype_t a, dtype_t b) {
+  return (a + b - 1) / b;
+}
+
+template <typename dtype_t>
+dtype_t align(dtype_t a, dtype_t b) {
+  return cell_div(a, b) * b;
+}
+
+struct Config {
+  int num_sms;
+  int num_max_nvl_chunked_send_tokens;
+  int num_max_nvl_chunked_recv_tokens;
+  int num_max_rdma_chunked_send_tokens;
+  int num_max_rdma_chunked_recv_tokens;
+
+  Config(int num_sms,
+         int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
+         int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens) :
+      num_sms(num_sms),
+      num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens),
+      num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens),
+      num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens),
+      num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) {
+    EP_HOST_ASSERT(num_sms >= 0);
+    EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and num_max_nvl_chunked_recv_tokens > 0);
+    EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens);
+    EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0);
+
+    // Ceil up RDMA buffer size
+    this->num_max_rdma_chunked_recv_tokens = align(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);
+    EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);
+    // NOTES: this assertion is related to RDMA lazy head update; we must ensure senders always have space to push
+    EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2);
+  }
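+
+  // NOTES: a small worked example of the ceil-up above (illustrative values, not
+  // defaults from this PR): with num_max_rdma_chunked_send_tokens = 128 and a
+  // requested num_max_rdma_chunked_recv_tokens = 320,
+  //   align(320, 128) = cell_div(320, 128) * 128 = 3 * 128 = 384,
+  // and the lazy-head-update assertion then checks 128 <= 384 / 2 = 192.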
+
+  size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {
+    // Below are some assumptions
+    // TODO: add assertions
+    constexpr int kNumMaxTopK = 128;
+    constexpr int kNumMaxScales = 128;
+    EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
+    EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0);
+    const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
+    const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
+    const int num_channels = num_sms / 2;
+
+    size_t num_bytes = 0;
+    num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);
+    num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes;
+    num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes();
+    num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t);
+    num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float);
+    num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float);
+    num_bytes = ((num_bytes + 127) / 128) * 128;
+    return num_bytes;
+  }
+
+  size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {
+    // Legacy mode
+    if (num_ranks <= NUM_MAX_NVL_PEERS)
+      return 0;
+
+    // Below are some assumptions
+    // TODO: add assertions
+    constexpr int kNumMaxTopK = 128;
+    constexpr int kNumMaxScales = 128;
+    EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
+    EP_HOST_ASSERT(num_sms % 2 == 0);
+    const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
+    const int num_channels = num_sms / 2;
+
+    size_t num_bytes = 0;
+    num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
+    num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2;
+    num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2;
+    num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t) * 2;
+    num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2;
+    num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2;
+    num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;
+    num_bytes = ((num_bytes + 127) / 128) * 128;
+    return num_bytes;
+  }
+};
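+
+// NOTES: illustrative sizing flow (the values are placeholders, not tuned
+// defaults from this PR):
+//
+//   Config cfg(/*num_sms=*/20, /*nvl send=*/6, /*nvl recv=*/256, /*rdma send=*/6, /*rdma recv=*/256);
+//   size_t nvl_bytes  = cfg.get_nvl_buffer_size_hint(hidden * elem_size, num_ranks);
+//   size_t rdma_bytes = cfg.get_rdma_buffer_size_hint(hidden * elem_size, num_ranks);
+//
+// These hints are what a frontend would pass as num_nvl_bytes/num_rdma_bytes when
+// constructing the Buffer; the EP_HOST_ASSERTs in buffer.cc then re-check them
+// against the exact per-kernel layout.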
+
+struct LowLatencyBuffer {
+  int num_clean_int = 0;
+
+  void* dispatch_rdma_send_buffer = nullptr;
+  void* dispatch_rdma_recv_data_buffer = nullptr;
+  int* dispatch_rdma_recv_count_buffer = nullptr;
+
+  void* combine_rdma_send_buffer = nullptr;
+  void* combine_rdma_recv_data_buffer = nullptr;
+  int* combine_rdma_recv_flag_buffer = nullptr;
+
+  void* combine_rdma_send_buffer_data_start = nullptr;
+  size_t num_bytes_per_combine_msg = 0;
+
+  std::pair<int*, int> clean_meta() {
+    EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
+    return {dispatch_rdma_recv_count_buffer, num_clean_int};
+  }
+};
+
+struct LowLatencyLayout {
+  size_t total_bytes = 0;
+  LowLatencyBuffer buffers[2];
+
+  template <typename out_ptr_t = void*, typename in_ptr_t = void*>
+  out_ptr_t advance(const in_ptr_t& ptr, size_t count) {
+    return reinterpret_cast<out_ptr_t>(reinterpret_cast<uint8_t*>(ptr) + count);
+  }
+
+  LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
+    const int num_scales = hidden / 128;
+
+    // Dispatch and combine layout:
+    // - 2 symmetric odd/even send buffers
+    // - 2 symmetric odd/even receive buffers
+    // - 2 symmetric odd/even signaling buffers
+
+    // Message sizes
+    // NOTES: you should add a control `int4` for combine messages if you want to do data transformation
+    EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);
+    size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float));
+    size_t num_bytes_per_combine_msg = hidden * sizeof(nv_bfloat16);
+
+    // Send buffer
+    size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
+    size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
+    size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes);
+    EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0);
+    total_bytes += send_buffer_bytes * 2;
+
+    // Symmetric receive buffers
+    // TODO: optimize memory usages
+    size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
+    size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
+    size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes);
+    EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0);
+    total_bytes += recv_buffer_bytes * 2;
+
+    // Symmetric signaling buffers
+    size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
+    size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
+    size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
+    total_bytes += signaling_buffer_bytes * 2;
+
+    // Assign pointers
+    // NOTES: we still leave some space for distinguishing dispatch/combine buffers,
+    // so you may see some parameters are duplicated
+    for (int i = 0; i < 2; ++ i) {
+      buffers[i] = {
+        static_cast<int>(signaling_buffer_bytes / sizeof(int)),
+        advance(rdma_buffer, send_buffer_bytes * i),
+        advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
+        advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
+        advance(rdma_buffer, send_buffer_bytes * i),
+        advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
+        advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
+        advance(rdma_buffer, send_buffer_bytes * i),
+        num_bytes_per_combine_msg
+      };
+    }
+  }
+};
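+
+// NOTES: the resulting arena layout, with S = send_buffer_bytes,
+// R = recv_buffer_bytes and G = signaling_buffer_bytes (even/odd are the two
+// symmetric buffer sets selected by `i` above):
+//
+//   [ send(even) | send(odd) | recv(even) | recv(odd) | signal(even) | signal(odd) ]
+//   0            S           2S           2S+R        2S+2R          2S+2R+G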
+
+inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
+  auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes;
+  return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES;
+}
+
+} // namespace ep
+} // namespace mscclpp
diff --git a/src/ext/ep/event.hpp b/src/ext/ep/event.hpp
new file mode 100644
index 00000000..947ec549
--- /dev/null
+++ b/src/ext/ep/event.hpp
@@ -0,0 +1,44 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/torch.h>
+
+#include "kernels/exception.cuh"
+
+namespace mscclpp { namespace ep {
+
+struct EventHandle {
+  std::shared_ptr<torch::Event> event;
+
+  EventHandle() {
+    event = std::make_shared<torch::Event>(torch::kCUDA);
+    event->record(at::cuda::getCurrentCUDAStream());
+  }
+
+  explicit EventHandle(const at::cuda::CUDAStream& stream) {
+    event = std::make_shared<torch::Event>(torch::kCUDA);
+    event->record(stream);
+  }
+
+  EventHandle(const EventHandle& other) = default;
+
+  void current_stream_wait() const {
+    at::cuda::getCurrentCUDAStream().unwrap().wait(*event);
+  }
+};
+
+inline torch::Event create_event(const at::cuda::CUDAStream &s) {
+  auto event = torch::Event(torch::kCUDA);
+  event.record(s);
+  return event;
+}
+
+inline void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) {
+  EP_HOST_ASSERT(s_0.id() != s_1.id());
+  s_0.unwrap().wait(create_event(s_1));
+}
+
+inline void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) {
+  s.unwrap().wait(*event.event);
+}
+
+} // namespace ep
+} // namespace mscclpp
diff --git a/src/ext/ep/internode_stub.cc b/src/ext/ep/internode_stub.cc
new file mode 100644
index 00000000..329e31ee
--- /dev/null
+++ b/src/ext/ep/internode_stub.cc
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+//
+// Placeholder launchers for the not-yet-ported internode HT and low-latency
+// kernels. `get_dispatch_layout` and `get_source_meta_bytes` ARE ported in
+// `kernels/internode_layout.cu`.
+
+#include <stdexcept>
+
+#include "kernels/api.cuh"
+
+namespace mscclpp {
+namespace ep {
+
+namespace internode_ll {
+
+void clean_low_latency_buffer(int* /*clean_0*/, int /*n0*/, int* /*clean_1*/, int /*n1*/, cudaStream_t /*stream*/) {
+  throw std::runtime_error(
+      "mscclpp::ep::internode_ll::clean_low_latency_buffer: not yet ported. "
+      "See nccl/contrib/nccl_ep/device/low_latency.cu and DeepEP "
+      "csrc/kernels/internode_ll.cu for the reference implementation.");
+}
+
+} // namespace internode_ll
+
+} // namespace ep
+} // namespace mscclpp
diff --git a/src/ext/ep/kernels/api.cuh b/src/ext/ep/kernels/api.cuh
new file mode 100644
index 00000000..479b6a17
--- /dev/null
+++ b/src/ext/ep/kernels/api.cuh
@@ -0,0 +1,142 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+//
+// Portions adapted from DeepEP (https://github.com/deepseek-ai/DeepEP),
+// branch `chhwang/dev-atomic-add-cleanup`. Licensed under the MIT License.
+//
+// Private host-callable API exposed by the EP CUDA kernels. One-to-one port of
+// DeepEP `csrc/kernels/api.cuh` minus the NVSHMEM-only internode entrypoints,
+// which are still to be migrated.
+
+#pragma once
+
+#include <cstdint>
+#include <cuda_runtime.h>
+#include <library_types.h>
+
+#include <mscclpp/memory_channel_device.hpp>
+#include <mscclpp/port_channel_device.hpp>
+
+namespace mscclpp {
+namespace ep {
+
+// ===========================================================================
+// Intranode (NVLink) runtime barrier.
+// ===========================================================================
+namespace intranode {
+
+void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
+
+void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
+                     const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
+                     int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
+                     int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
+                     void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank,
+                     cudaStream_t stream, int num_sms);
+
+void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
+                            void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank, int num_ranks,
+                            cudaStream_t stream);
+
+void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights,
+              int* recv_channel_offset, int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx,
+              const float* topk_weights, const bool* is_token_in_rank, const int* channel_prefix_matrix,
+              int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
+              void** buffer_ptrs, int rank, int num_ranks, cudaStream_t stream, int num_sms,
+              int num_max_send_tokens, int num_recv_buffer_tokens);
+
+void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens,
+                           int num_memset_int, int** task_fifo_ptrs, int head, int rank, int num_ranks,
+                           cudaStream_t stream);
+
+void combine(cudaDataType_t type, void* recv_x, float* recv_topk_weights, const void* x, const float* topk_weights,
+             const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix, int* send_head,
+             int num_tokens, int num_recv_tokens, int hidden, int num_topk,
+             void** buffer_ptrs, int rank, int num_ranks, cudaStream_t stream, int num_sms,
+             int num_max_send_tokens, int num_recv_buffer_tokens);
+
+} // namespace intranode
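+
+// NOTES: call order used by buffer.cc for the intranode path (a sketch):
+//   dispatch:  notify_dispatch (size exchange + barrier)   -> dispatch (payload)
+//   cached:    cached_notify_dispatch (barrier + reset)    -> dispatch
+//   combine:   cached_notify_combine (barrier + head setup) -> combine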
+
+// ===========================================================================
+// Internode (NVLink + RDMA) high-throughput kernels. Ported from DeepEP
+// `csrc/kernels/internode.cu` on branch `chhwang/dev-atomic-add-cleanup`.
+// NVSHMEM dependencies in the kernel were replaced with MSCCL++ port-channel
+// atomic adds (see src/ext/ep/README.md for the translation table).
+// ===========================================================================
+namespace internode {
+
+int get_source_meta_bytes();
+
+void get_dispatch_layout(const int64_t* topk_idx, int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
+                         int* num_tokens_per_expert, bool* is_token_in_rank, int num_tokens, int num_topk,
+                         int num_ranks, int num_experts, cudaStream_t stream);
+
+void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
+                     const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
+                     const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
+                     const bool* is_token_in_rank, int num_tokens, int num_channels,
+                     int hidden_int4, int num_scales, int num_topk, int expert_alignment,
+                     int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum,
+                     int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
+                     void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
+                     void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
+                     int** task_fifo_ptrs, int head, int rank,
+                     cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
+                     bool low_latency_mode,
+                     mscclpp::PortChannelDeviceHandle* port_channel_handles,
+                     mscclpp::MemoryChannelDeviceHandle* memory_channel_handles);
+
+void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
+              const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
+              int* send_rdma_head, int* send_nvl_head,
+              int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
+              const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
+              const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
+              int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
+              const bool* is_token_in_rank,
+              void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
+              void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
+              int rank, int num_ranks, bool is_cached_dispatch,
+              cudaStream_t stream, int num_channels, bool low_latency_mode,
+              mscclpp::PortChannelDeviceHandle* port_channel_handles,
+              mscclpp::MemoryChannelDeviceHandle* memory_channel_handles);
+
+void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
+                   int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head,
+                   const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
+                   void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
+                   void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
+                   int** task_fifo_ptrs, int head, int rank, cudaStream_t stream,
+                   int64_t num_rdma_bytes, int64_t num_nvl_bytes,
+                   bool is_cached_dispatch, bool low_latency_mode,
+                   mscclpp::PortChannelDeviceHandle* port_channel_handles,
+                   mscclpp::MemoryChannelDeviceHandle* memory_channel_handles);
+
+void combine(cudaDataType_t type,
+             void* combined_x, float* combined_topk_weights,
+             const bool* is_combined_token_in_rank,
+             const void* x, const float* topk_weights,
+             const int* combined_rdma_head, const int* combined_nvl_head,
+             const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum,
+             const int* gbl_channel_prefix_matrix,
+             int num_tokens, int num_combined_tokens, int hidden, int num_topk,
+             void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
+             void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
+             int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode,
+             mscclpp::PortChannelDeviceHandle* port_channel_handles,
+             mscclpp::MemoryChannelDeviceHandle* memory_channel_handles);
+
+} // namespace internode
+
+// ===========================================================================
+// Internode low-latency (pure RDMA) kernels. Not ported yet.
+// ===========================================================================
+namespace internode_ll {
+
+void clean_low_latency_buffer(int* clean_0, int num_clean_int_0, int* clean_1, int num_clean_int_1,
+                              cudaStream_t stream);
+
+} // namespace internode_ll
+
+} // namespace ep
+} // namespace mscclpp
diff --git a/src/ext/ep/kernels/buffer.cuh b/src/ext/ep/kernels/buffer.cuh
new file mode 100644
index 00000000..de1fb4ee
--- /dev/null
+++ b/src/ext/ep/kernels/buffer.cuh
@@ -0,0 +1,139 @@
+#pragma once
+
+#include "configs.cuh"
+#include "exception.cuh"
+
+namespace mscclpp { namespace ep {
+
+template <typename dtype_t>
+struct Buffer {
+private:
+  uint8_t* ptr;
+
+public:
+  int total_bytes;
+
+  __device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {}
+
+  __device__ __forceinline__ Buffer(void* &gbl_ptr, int num_elems, int offset = 0) {
+    total_bytes = num_elems * sizeof(dtype_t);
+    ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t);
+    gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
+  }
+
+  __device__ __forceinline__ Buffer advance_also(void* &gbl_ptr) {
+    gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
+    return *this;
+  }
+
+  __device__ __forceinline__ dtype_t* buffer() {
+    return reinterpret_cast<dtype_t*>(ptr);
+  }
+
+  __device__ __forceinline__ dtype_t& operator[](int idx) {
+    return buffer()[idx];
+  }
+};
+
+template <typename dtype_t, int kNumRanks = 1>
+struct AsymBuffer {
+private:
+  uint8_t* ptrs[kNumRanks];
+  int num_bytes;
+
+public:
+  int total_bytes;
+
+  __device__ __forceinline__ AsymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
+                                        int sm_id = 0, int num_sms = 1, int offset = 0) {
+    EP_STATIC_ASSERT(kNumRanks == 1, "");
+    num_bytes = num_elems * sizeof(dtype_t);
+
+    int per_channel_bytes = num_bytes * num_ranks;
+    total_bytes = per_channel_bytes * num_sms;
+    ptrs[0] = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
+    gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
+  }
+
+  __device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks,
+                                        int sm_id = 0, int num_sms = 1, int offset = 0) {
+    EP_STATIC_ASSERT(kNumRanks > 1, "");
+    num_bytes = num_elems * sizeof(dtype_t);
+
+    int per_channel_bytes = num_bytes * num_ranks;
+    total_bytes = per_channel_bytes * num_sms;
+    for (int i = 0; i < kNumRanks; ++ i) {
+      ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;
+      gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
+    }
+  }
+
+  __device__ __forceinline__ void advance(int shift) {
+    #pragma unroll
+    for (int i = 0; i < kNumRanks; ++ i)
+      ptrs[i] = ptrs[i] + shift * sizeof(dtype_t);
+  }
+
+  __device__ __forceinline__ AsymBuffer advance_also(void* &gbl_ptr) {
+    gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
+    return *this;
+  }
+
+  template <int kNumAlsoRanks>
+  __device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) {
+    for (int i = 0; i < kNumAlsoRanks; ++ i)
+      gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
+    return *this;
+  }
+
+  __device__ __forceinline__ dtype_t* buffer(int idx = 0) {
+    EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for the single-rank case");
+    return reinterpret_cast<dtype_t*>(ptrs[0] + num_bytes * idx);
+  }
+
+  __device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) {
+    EP_STATIC_ASSERT(kNumRanks > 1, "`buffer_by` is only available for the multiple-rank case");
+    return reinterpret_cast<dtype_t*>(ptrs[rank_idx] + num_bytes * idx);
+  }
+};
+
+template <typename dtype_t, bool kDecoupled = true>
+struct SymBuffer {
+private:
+  // NOTES: for the non-decoupled case, `recv_ptr` is not used
+  uint8_t* send_ptr;
+  uint8_t* recv_ptr;
+  int num_bytes;
+
+public:
+  int total_bytes;
+
+  __device__ __forceinline__ SymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
+                                       int sm_id = 0, int num_sms = 1) {
+    num_bytes = num_elems * sizeof(dtype_t);
+
+    int per_channel_bytes = num_bytes * num_ranks;
+    total_bytes = per_channel_bytes * num_sms * (static_cast<int>(kDecoupled) + 1);
+    send_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id;
+    recv_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
+    gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
+  }
+
+  __device__ __forceinline__ dtype_t* send_buffer(int idx = 0) {
+    EP_STATIC_ASSERT(kDecoupled, "`send_buffer` is only available for the decoupled case");
+    return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);
+  }
+
+  __device__ __forceinline__ dtype_t* recv_buffer(int idx = 0) {
+    EP_STATIC_ASSERT(kDecoupled, "`recv_buffer` is only available for the decoupled case");
+    return reinterpret_cast<dtype_t*>(recv_ptr + num_bytes * idx);
+  }
+
+  __device__ __forceinline__ dtype_t* buffer(int idx = 0) {
+    EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for the non-decoupled case");
+    return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);
+  }
+};
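+
+// NOTES: illustrative carving of a raw symmetric region inside a kernel (the
+// names are hypothetical). Each helper bumps the running pointer by its own
+// size, so the layout is fixed by construction order:
+//
+//   void* ptr = buffer_ptrs[nvl_rank];
+//   auto heads = Buffer<int>(ptr, num_ranks);                        // queue heads
+//   auto data  = Buffer<int4>(ptr, num_recv_tokens * hidden_int4);   // token payload
+//   auto meta  = SymBuffer<int>(ptr, num_elems, num_ranks, sm_id, num_sms);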
+#ifdef __CUDA_NO_HALF_CONVERSIONS__
+#undef __CUDA_NO_HALF_CONVERSIONS__
+#endif
+#ifdef __CUDA_NO_HALF_OPERATORS__
+#undef __CUDA_NO_HALF_OPERATORS__
+#endif
+#ifdef __CUDA_NO_HALF2_OPERATORS__
+#undef __CUDA_NO_HALF2_OPERATORS__
+#endif
+#ifdef __CUDA_NO_BFLOAT16_CONVERSIONS__
+#undef __CUDA_NO_BFLOAT16_CONVERSIONS__
+#endif
+#ifdef __CUDA_NO_BFLOAT162_OPERATORS__
+#undef __CUDA_NO_BFLOAT162_OPERATORS__
+#endif
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
+// NVSHMEM / IBGDA / mlx5dv are only required for the RDMA internode paths and
+// are not included here. The internode/low-latency kernels that need them
+// will include them directly under `#ifdef MSCCLPP_EP_HAVE_NVSHMEM`.
diff --git a/src/ext/ep/kernels/exception.cuh b/src/ext/ep/kernels/exception.cuh
new file mode 100644
index 00000000..5f5a651a
--- /dev/null
+++ b/src/ext/ep/kernels/exception.cuh
@@ -0,0 +1,46 @@
+#pragma once
+
+#include <exception>
+#include <string>
+
+#include "configs.cuh"
+
+#ifndef EP_STATIC_ASSERT
+#define EP_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
+#endif
+
+class EPException: public std::exception {
+private:
+    std::string message = {};
+
+public:
+    explicit EPException(const char *name, const char* file, const int line, const std::string& error) {
+        message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'";
+    }
+
+    const char *what() const noexcept override { return message.c_str(); }
+};
+
+#ifndef CUDA_CHECK
+#define CUDA_CHECK(cmd) \
+do { \
+    cudaError_t e = (cmd); \
+    if (e != cudaSuccess) { \
+        throw EPException("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \
+    } \
+} while (0)
+#endif
+
+#ifndef EP_HOST_ASSERT
+#define EP_HOST_ASSERT(cond) \
+do { \
+    if (not (cond)) { \
+        throw EPException("Assertion", __FILE__, __LINE__, #cond); \
+    } \
+} while (0)
+#endif
+
+#ifndef EP_DEVICE_ASSERT
+// #define EP_DEVICE_ASSERT(cond) do { if (not (cond)) { printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); asm("trap;"); } } while (0)
+#define EP_DEVICE_ASSERT(cond)
+#endif
diff --git a/src/ext/ep/kernels/internode.cu b/src/ext/ep/kernels/internode.cu
new file mode 100644
index 00000000..5c704513
--- /dev/null
+++ b/src/ext/ep/kernels/internode.cu
@@ -0,0 +1,1814 @@
+#include "configs.cuh"
+#include "buffer.cuh"
+#include "exception.cuh"
+#include "launch.cuh"
+#include "utils.cuh"
+#include <mscclpp/memory_channel_device.hpp>
+#include <mscclpp/port_channel_device.hpp>
+
+namespace mscclpp { namespace ep {
+
+namespace internode {
+
+template <int kNumThreads, int kNumExpertsPerSM, int kNumRanksPerSM>
+__global__ void __launch_bounds__(kNumThreads, 1)
+get_dispatch_layout(const int64_t* topk_idx,
+                    int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
+                    int* num_tokens_per_expert, bool* is_token_in_rank,
+                    int num_tokens, int num_topk, int num_ranks, int num_experts) {
+    auto sm_id = static_cast<int>(blockIdx.x);
+    auto thread_id = static_cast<int>(threadIdx.x);
+
+    // Count expert statistics
+    __shared__ int num_tokens_per_expert_per_thread[kNumThreads][kNumExpertsPerSM];
+    int expert_begin_idx = sm_id * kNumExpertsPerSM, expert_end_idx = min(expert_begin_idx + kNumExpertsPerSM, num_experts);
+    if (expert_begin_idx < expert_end_idx) {
+        // Per-thread count
+        #pragma unroll
+        for (int i = 0; i < kNumExpertsPerSM; ++ i)
+            num_tokens_per_expert_per_thread[thread_id][i] = 0;
+        #pragma unroll
+        for (int i = thread_id; i < num_tokens; i += kNumThreads) {
+            auto shifted_topk_idx = topk_idx + i * num_topk;
+            #pragma unroll
+            for (int j = 0, expert_idx; j < num_topk; ++ j) {
+                expert_idx = static_cast<int>(shifted_topk_idx[j]);
+                if (expert_begin_idx <= expert_idx and expert_idx <
expert_end_idx) + ++ num_tokens_per_expert_per_thread[thread_id][expert_idx - expert_begin_idx]; + } + } + __syncthreads(); + + // Sum up + EP_STATIC_ASSERT(kNumExpertsPerSM <= kNumThreads, "Too many experts per SM"); + if (expert_begin_idx + thread_id < expert_end_idx) { + int sum = 0; + #pragma unroll + for (int i = 0; i < kNumThreads; ++ i) + sum += num_tokens_per_expert_per_thread[i][thread_id]; + num_tokens_per_expert[expert_begin_idx + thread_id] = sum; + } + return; + } + + if (num_tokens_per_rdma_rank != nullptr) + EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS); + + // Count rank statistics + constexpr int kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS; + __shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM]; + __shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads][kNumRDMARanksPerSM]; + auto sm_begin = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM; + int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM, rank_end_idx = min(rank_begin_idx + kNumRanksPerSM, num_ranks); + int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS, rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS; + if (rank_begin_idx < rank_end_idx) { + const auto num_expert_per_rank = num_experts / num_ranks; + auto expert_begin = rank_begin_idx * num_expert_per_rank; + auto expert_end = rank_end_idx * num_expert_per_rank; + + // Per-thread count + #pragma unroll + for (int i = 0; i < kNumRanksPerSM; ++ i) + num_tokens_per_rank_per_thread[thread_id][i] = 0; + #pragma unroll + for (int i = 0; i < kNumRDMARanksPerSM; ++ i) + num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0; + #pragma unroll + for (int i = thread_id; i < num_tokens; i += kNumThreads) { + auto shifted_topk_idx = topk_idx + i * num_topk; + int is_in_rank[kNumRanksPerSM] = {0}, is_in_rdma_rank[kNumRDMARanksPerSM] = {0}; + #pragma unroll + for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) { + expert_idx = static_cast(shifted_topk_idx[j]); + if (expert_begin <= expert_idx and expert_idx < expert_end) { + // Count single rank + rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx; + is_in_rank[rank_idx] ++, is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS] ++; + } + } + + auto shifted_is_token_in_rank = is_token_in_rank + i * num_ranks; + #pragma unroll + for (int j = 0; j + rank_begin_idx < rank_end_idx; ++ j) { + shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0); + num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0); + } + + #pragma unroll + for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++ j) + num_tokens_per_rdma_rank_per_thread[thread_id][j] += (is_in_rdma_rank[j] > 0); + } + __syncthreads(); + + // Sum up + EP_STATIC_ASSERT(kNumRanksPerSM <= kNumThreads, "Too many ranks per SM"); + if (rank_begin_idx + thread_id < rank_end_idx) { + int sum = 0; + #pragma unroll + for (int i = 0; i < kNumThreads; ++ i) + sum += num_tokens_per_rank_per_thread[i][thread_id]; + num_tokens_per_rank[rank_begin_idx + thread_id] = sum; + } + + if (num_tokens_per_rdma_rank != nullptr and rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) { + int sum = 0; + #pragma unroll + for (int i = 0; i < kNumThreads; ++ i) + sum += num_tokens_per_rdma_rank_per_thread[i][thread_id]; + num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum; + } + } +} + +void get_dispatch_layout(const int64_t* topk_idx, + int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, + int* num_tokens_per_expert, bool* is_token_in_rank, + int 
num_tokens, int num_topk, int num_ranks, int num_experts,
+                         cudaStream_t stream) {
+    constexpr int kNumThreads = 256, kNumExpertsPerSM = 32, kNumRanksPerSM = 8;
+    int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM;
+    EP_STATIC_ASSERT(kNumExpertsPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of experts per SM");
+
+    SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
+    LAUNCH_KERNEL(&cfg, (get_dispatch_layout<kNumThreads, kNumExpertsPerSM, kNumRanksPerSM>),
+                  topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank,
+                  num_tokens, num_topk, num_ranks, num_experts);
+}
+
+struct SourceMeta {
+    int src_rdma_rank, is_token_in_nvl_rank_bits;
+
+    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers");
+
+    __forceinline__ SourceMeta() = default;
+
+    // TODO: faster encoding
+    __device__ __forceinline__ SourceMeta(int rdma_rank, const bool* is_token_in_nvl_ranks) {
+        src_rdma_rank = rdma_rank;
+        is_token_in_nvl_rank_bits = is_token_in_nvl_ranks[0];
+        #pragma unroll
+        for (int i = 1; i < NUM_MAX_NVL_PEERS; ++ i)
+            is_token_in_nvl_rank_bits |= is_token_in_nvl_ranks[i] << i;
+    }
+
+    __device__ __forceinline__ bool is_token_in_nvl_rank(int nvl_rank) const {
+        return (is_token_in_nvl_rank_bits >> nvl_rank) & 1;
+    }
+};
+
+EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`");
+
+int get_source_meta_bytes() {
+    return sizeof(SourceMeta);
+}
+
+__host__ __device__ __forceinline__
+int get_num_bytes_per_rdma_token(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights) {
+    return static_cast<int>(align(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float), sizeof(int4)));
+}
+
+__host__ __device__ __forceinline__
+std::pair<int, int> get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) {
+    // Return `int32_t` offset and count to clean
+    return {
+        (get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms) / sizeof(int),
+        (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms
+    };
+}
+
+__host__ __device__ __forceinline__
+std::pair<int, int> get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, int num_rdma_ranks, int num_nvl_ranks, int num_nvl_recv_buffer_tokens, int num_sms) {
+    // Return `int32_t` offset and count to clean
+    EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`");
+    return {
+        (num_nvl_recv_buffer_tokens * (hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float) + sizeof(SourceMeta)) * num_nvl_ranks * num_sms) / sizeof(int),
+        num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms,
+    };
+}
+
+template <bool kLowLatencyMode, int kNumRDMARanks>
+__global__ void
+notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
+                const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
+                const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
+                const bool* is_token_in_rank, int num_tokens, int num_channels, int expert_alignment,
+                const int rdma_clean_offset, const int rdma_num_int_clean,
+                const int nvl_clean_offset, const int nvl_num_int_clean,
+                int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum,
+                int*
gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, + void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank, + mscclpp::PortChannelDeviceHandle* port_channel_handles, + mscclpp::MemoryChannelDeviceHandle *memory_channel_handles) { + auto sm_id = static_cast(blockIdx.x); + auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id(); + auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; + + auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + auto num_rdma_experts = num_experts / kNumRDMARanks, num_nvl_experts = num_rdma_experts / NUM_MAX_NVL_PEERS; + + if (sm_id == 0) { + // Communication with others + // Global barrier: the first warp do intra-node sync, the second warp do internode sync + EP_DEVICE_ASSERT(num_warps > 1); + EP_DEVICE_ASSERT(kNumRDMARanks + 32 <= num_threads); + const auto barrier_thread_id = thread_id - 32; + const bool run_barrier = (barrier_thread_id >= 0) && (barrier_thread_id < kNumRDMARanks) && (barrier_thread_id != rdma_rank); + const auto barrier_channel_idx = kLowLatencyMode ? barrier_thread_id : (barrier_thread_id * NUM_MAX_NVL_PEERS + nvl_rank); + if (run_barrier) { + port_channel_handles[barrier_channel_idx].signal(); + port_channel_handles[barrier_channel_idx].wait(); + } + if constexpr (!kLowLatencyMode) { + // kLowLatencyMode==false requires sync of all ranks, which can be done by running intra-node sync + // after the inter-node sync is done. + __syncthreads(); + } +#if 1 + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); +#else + // TODO(chhwang): make memory channels work + if (thread_id < NUM_MAX_NVL_PEERS && thread_id != nvl_rank) { + memory_channel_handles[thread_id].relaxedSignal(); + memory_channel_handles[thread_id].relaxedWait(); + } +#endif + __syncthreads(); + + // Send numbers of tokens per rank/expert to RDMA ranks + auto rdma_buffer_ptr_int = reinterpret_cast(rdma_buffer_ptr); + auto num_elems = NUM_MAX_NVL_PEERS + num_rdma_experts + 1; + auto num_bytes = num_elems * sizeof(int); + auto per_channel_bytes = num_bytes * kNumRDMARanks; + auto rdma_recv_num_tokens_mixed = SymBuffer(rdma_buffer_ptr, num_elems, kNumRDMARanks); + + // Clean up for later data dispatch + EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <= rdma_clean_offset * sizeof(int)); + #pragma unroll + for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) + rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; + + // Copy to send buffer + #pragma unroll + for (int i = thread_id; i < num_ranks; i += num_threads) + rdma_recv_num_tokens_mixed.send_buffer(i / NUM_MAX_NVL_PEERS)[i % NUM_MAX_NVL_PEERS] = num_tokens_per_rank[i]; + #pragma unroll + for (int i = thread_id; i < num_experts; i += num_threads) + rdma_recv_num_tokens_mixed.send_buffer(i / num_rdma_experts)[NUM_MAX_NVL_PEERS + i % num_rdma_experts] = num_tokens_per_expert[i]; + if (thread_id < kNumRDMARanks) + rdma_recv_num_tokens_mixed.send_buffer(thread_id)[NUM_MAX_NVL_PEERS + num_rdma_experts] = num_tokens_per_rdma_rank[thread_id]; + __syncthreads(); + + // Issue send + // TODO: more light fence or barrier or signaling + // TODO: overlap EP barrier and NVL cleaning + if (thread_id < kNumRDMARanks) { + auto dst_offset = rdma_rank * num_bytes + per_channel_bytes; + auto src_offset = thread_id * num_bytes; + auto peer_rank = kLowLatencyMode ? 
thread_id : (thread_id * NUM_MAX_NVL_PEERS + nvl_rank); + port_channel_handles[peer_rank].putWithSignal(dst_offset, src_offset, num_bytes); + port_channel_handles[peer_rank].wait(); + } + __syncthreads(); + + // NVL buffers + auto nvl_send_buffer = thread_id < NUM_MAX_NVL_PEERS ? buffer_ptrs[thread_id] : nullptr; + auto nvl_recv_buffer = buffer_ptrs[nvl_rank]; + auto nvl_reduced_num_tokens_per_expert = Buffer(nvl_recv_buffer, num_rdma_experts).advance_also(nvl_send_buffer); + auto nvl_send_num_tokens_per_rank = AsymBuffer(nvl_send_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS); + auto nvl_send_num_tokens_per_expert = AsymBuffer(nvl_send_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS); + auto nvl_recv_num_tokens_per_rank = AsymBuffer(nvl_recv_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS); + auto nvl_recv_num_tokens_per_expert = AsymBuffer(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS); + + // Clean up for later data dispatch + auto nvl_buffer_ptr_int = reinterpret_cast(buffer_ptrs[nvl_rank]); + EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes + nvl_send_num_tokens_per_rank.total_bytes + + nvl_send_num_tokens_per_expert.total_bytes <= nvl_clean_offset * sizeof(int)); + #pragma unroll + for (int i = thread_id; i < nvl_num_int_clean; i += num_threads) + nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; + + // Reduce number of tokens per expert into the NVL send buffer + // TODO: may use NVSHMEM reduction + EP_DEVICE_ASSERT(num_rdma_experts <= num_threads); + if (thread_id < num_rdma_experts) { + int sum = 0; + #pragma unroll + for (int i = 0; i < kNumRDMARanks; ++ i) + sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + thread_id]; + nvl_reduced_num_tokens_per_expert[thread_id] = sum; + } + __syncthreads(); + + // Reduce RDMA received tokens + if (thread_id == 0) { + int sum = 0; + #pragma unroll + for (int i = 0; i < kNumRDMARanks; ++ i) { + sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + num_rdma_experts]; + recv_rdma_rank_prefix_sum[i] = sum; + } + while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1); + *moe_recv_rdma_counter_mapped = sum; + } + + // Send numbers of tokens per rank/expert to NVL ranks + EP_DEVICE_ASSERT(NUM_MAX_NVL_PEERS <= num_threads); + if (thread_id < NUM_MAX_NVL_PEERS) { + #pragma unroll + for (int i = 0; i < kNumRDMARanks; ++ i) + nvl_send_num_tokens_per_rank.buffer(nvl_rank)[i] = rdma_recv_num_tokens_mixed.recv_buffer(i)[thread_id]; + #pragma unroll + for (int i = 0; i < num_nvl_experts; ++ i) + nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] = nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i]; + } + memory_fence(); + __syncthreads(); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + __syncthreads(); + + // Reduce number of tokens per rank/expert + EP_DEVICE_ASSERT(num_nvl_experts <= num_threads); + if (thread_id == 0) { + int sum = 0; + #pragma unroll + for (int i = 0; i < num_ranks; ++ i) { + int src_rdma_rank = i / NUM_MAX_NVL_PEERS, src_nvl_rank = i % NUM_MAX_NVL_PEERS; + sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank]; + recv_gbl_rank_prefix_sum[i] = sum; + } + while (ld_volatile_global(moe_recv_counter_mapped) != -1); + *moe_recv_counter_mapped = sum; + } + if (thread_id < num_nvl_experts) { + int sum = 0; + #pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) + sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id]; + sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment; + while 
(ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1); + moe_recv_expert_counter_mapped[thread_id] = sum; + } + + // Finally barrier + __syncthreads(); + + if (run_barrier) { + port_channel_handles[barrier_channel_idx].signal(); + port_channel_handles[barrier_channel_idx].wait(); + } + if constexpr (!kLowLatencyMode) { + // kLowLatencyMode==false requires sync of all ranks, which can be done by running intra-node sync + // after the inter-node sync is done. + __syncthreads(); + } + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + } else { + // Calculate meta data + int dst_rdma_rank = sm_id - 1; + for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) { + int token_start_idx, token_end_idx; + get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // Iterate over tokens + int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0}; + for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32) { + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers"); + auto is_token_in_rank_uint64 = *reinterpret_cast(is_token_in_rank + i * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS); + auto is_token_in_rank_values = reinterpret_cast(&is_token_in_rank_uint64); + #pragma unroll + for (int j = 0; j < NUM_MAX_NVL_PEERS; ++ j) + per_nvl_rank_count[j] += is_token_in_rank_values[j]; + total_count += (is_token_in_rank_uint64 != 0); + } + + // Warp reduce + total_count = warp_reduce_sum(total_count); + #pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) + per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]); + + // Write into channel matrix + if (lane_id == 0) { + #pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) + gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + i) * num_channels + channel_id] = per_nvl_rank_count[i]; + rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] = total_count; + } + } + + // Calculate prefix sum + __syncthreads(); + if (thread_id == 0) { + auto prefix_row = rdma_channel_prefix_matrix + dst_rdma_rank * num_channels; + #pragma unroll + for (int i = 1; i < num_channels; ++ i) + prefix_row[i] += prefix_row[i - 1]; + } + + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); + if (thread_id < NUM_MAX_NVL_PEERS) { + auto prefix_row = gbl_channel_prefix_matrix + (dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id) * num_channels; + #pragma unroll + for (int i = 1; i < num_channels; ++ i) + prefix_row[i] += prefix_row[i - 1]; + } + } +} + +void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, + const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, + const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, + const bool* is_token_in_rank, int num_tokens, int num_channels, + int hidden_int4, int num_scales, int num_topk, int expert_alignment, + int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, + int** task_fifo_ptrs, int head, int rank, + cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, + bool low_latency_mode, + mscclpp::PortChannelDeviceHandle *port_channel_handles, + mscclpp::MemoryChannelDeviceHandle *memory_channel_handles) { 
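+    // Launch summary: the grid has `1 + num_rdma_ranks` blocks. Block 0 runs the
+    // global barrier and exchanges the per-rank / per-expert token counters over
+    // the port and memory channels; block `1 + r` computes the per-channel token
+    // counts and prefix sums destined for RDMA rank `r` (`dst_rdma_rank = sm_id - 1`
+    // in the kernel above).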
+#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \
+    auto notify_dispatch_func = low_latency_mode ? \
+        notify_dispatch<true, num_rdma_ranks> : notify_dispatch<false, num_rdma_ranks>; \
+    LAUNCH_KERNEL(&cfg, notify_dispatch_func, \
+                  num_tokens_per_rank, moe_recv_counter_mapped, num_ranks, \
+                  num_tokens_per_rdma_rank, moe_recv_rdma_counter_mapped, \
+                  num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \
+                  is_token_in_rank, num_tokens, num_channels, expert_alignment, \
+                  rdma_clean_meta.first, rdma_clean_meta.second, \
+                  nvl_clean_meta.first, nvl_clean_meta.second, \
+                  rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
+                  gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
+                  rdma_buffer_ptr, \
+                  buffer_ptrs, task_fifo_ptrs, head, rank, \
+                  port_channel_handles, memory_channel_handles); } break
+
+    constexpr int kNumThreads = 512;
+    const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
+
+    // Get clean meta
+    auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels);
+    auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);
+    EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes);
+    EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes);
+    EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
+    EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
+
+    // Launch kernel
+    SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream);
+    SWITCH_RDMA_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);
+#undef NOTIFY_DISPATCH_LAUNCH_CASE
+}
+
+// At most 8 RDMA ranks to be sent
+constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) {
+    return num_rdma_ranks < 8 ?
num_rdma_ranks : 8; +} + +template +__global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1) +dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, SourceMeta* recv_src_meta, + const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, + int* send_rdma_head, int* send_nvl_head, + int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, + const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, + int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, + const bool* is_token_in_rank, + void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + int rank, int num_ranks, + mscclpp::PortChannelDeviceHandle *port_channel_handles, + mscclpp::MemoryChannelDeviceHandle *memory_channel_handles) { + enum class WarpRole { + kRDMASender, + kRDMASenderCoordinator, + kRDMAAndNVLForwarder, + kForwarderCoordinator, + kNVLReceivers + }; + + const auto sm_id = static_cast(blockIdx.x); + const auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id(); + const auto num_channels = static_cast(gridDim.x) / 2, channel_id = sm_id / 2; + const bool is_forwarder = sm_id % 2 == 0; + const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + + const auto role_meta = [=]() -> std::pair { + if (is_forwarder) { + if (warp_id < NUM_MAX_NVL_PEERS) { + return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS}; + } else { + return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS}; + } + } else if (warp_id < kNumDispatchRDMASenderWarps) { + return {WarpRole::kRDMASender, -1}; + } else if (warp_id == kNumDispatchRDMASenderWarps) { + return {WarpRole::kRDMASenderCoordinator, -1}; + } else { + return {WarpRole::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS}; + } + }(); + auto warp_role = role_meta.first; + auto target_rank = role_meta.second; // Not applicable for RDMA senders + EP_DEVICE_ASSERT(num_warps == kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS); + + // Data checks + EP_DEVICE_ASSERT(num_topk <= 32); + + // RDMA symmetric layout + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers"); + auto hidden_bytes = hidden_int4 * sizeof(int4); + auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk); + auto rdma_channel_data = SymBuffer(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_meta = SymBuffer(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_head = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_tail = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + + auto data_send_offset = sizeof(int8_t) * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * kNumRDMARanks * channel_id; + auto data_recv_offset = sizeof(int8_t) * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * kNumRDMARanks * (channel_id + num_channels); + auto meta_offset = sizeof(int8_t) * 
(num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * kNumRDMARanks * num_channels * 2; + auto meta_send_offset = meta_offset + sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2) * kNumRDMARanks * channel_id; + auto meta_recv_offset = meta_offset + sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2) * kNumRDMARanks * (channel_id + num_channels); + auto head_offset = meta_offset + sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2) * kNumRDMARanks * num_channels * 2; + auto head_send_offset = head_offset + sizeof(uint64_t) * kNumRDMARanks * channel_id; + auto tail_offset = head_offset + sizeof(uint64_t) * kNumRDMARanks * num_channels; + auto tail_send_offset = tail_offset + sizeof(uint64_t) * kNumRDMARanks * channel_id; + + // NVL buffer layouts + // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers" + void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr; + int rs_wr_rank = 0, ws_rr_rank = 0; + if (warp_role == WarpRole::kRDMAAndNVLForwarder) + rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, ws_rr_rank = target_rank; + if (warp_role == WarpRole::kNVLReceivers) + rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank; + + // Allocate buffers + auto nvl_channel_x = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); + auto nvl_channel_src_meta = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); + auto nvl_channel_x_scales = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); + auto nvl_channel_topk_idx = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); + auto nvl_channel_topk_weights = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); + auto nvl_channel_prefix_start = AsymBuffer(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); + auto nvl_channel_prefix_end = AsymBuffer(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); + auto nvl_channel_head = AsymBuffer(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr); + auto nvl_channel_tail = AsymBuffer(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); + + // RDMA sender warp synchronization + __shared__ volatile int rdma_send_next_token_idx; + __shared__ volatile int rdma_send_channel_tail[kNumRDMARanks]; + __shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks]; + auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); }; + + // Forward warp synchronization + __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks]; + __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS]; + auto sync_forwarder_smem = []() { asm volatile("bar.sync 1, 
%0;" :: "r"((NUM_MAX_NVL_PEERS + 1) * 32)); }; + + if (warp_role == WarpRole::kRDMASender) { + // Get tasks + int token_start_idx, token_end_idx; + get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // Clean shared memory + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks"); + (warp_id == 0 and lane_id == 0) ? (rdma_send_next_token_idx = token_start_idx) : 0; + (warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0; + (warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_next_tail[lane_id] = 0) : 0; + + // Send number of tokens in this channel by `-value - 1` + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers"); + for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) { + auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank); + if (lane_id < NUM_MAX_NVL_PEERS) { + dst_ptr[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1; + } else if (lane_id < NUM_MAX_NVL_PEERS * 2) { + dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1; + } else if (lane_id == NUM_MAX_NVL_PEERS * 2) { + dst_ptr[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1; + } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) { + dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1; + } + __syncwarp(); + + if (dst_rdma_rank == rdma_rank) continue; + + // Issue RDMA for non-local ranks + if (lane_id == 0) { + auto num_bytes = sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2); + auto dst_offset = rdma_rank * num_bytes + meta_recv_offset; + auto src_offset = dst_rdma_rank * num_bytes + meta_send_offset; + auto port_channel_idx = kLowLatencyMode ? (channel_id * kNumRDMARanks + dst_rdma_rank) : (channel_id * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank); + port_channel_handles[port_channel_idx].put(dst_offset, src_offset, num_bytes); + // port_channel_handles[port_channel_idx].flush(); + } + __syncwarp(); + } + sync_rdma_sender_smem(); + + // Iterate over tokens and copy into buffer + int64_t token_idx; + int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1; + auto send_buffer = lane_id == rdma_rank ? 
rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id); + for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) { + // Read RDMA rank existence + uint64_t is_token_in_rank_uint64 = 0; + if (lane_id < kNumRDMARanks) + is_token_in_rank_uint64 = *reinterpret_cast(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS); + + // Acquire sequential lock + while (lane_id == 0 and rdma_send_next_token_idx != token_idx); + __syncwarp(); + + // Acquire next tail + int rdma_tail_idx = -1; + if (is_token_in_rank_uint64 != 0) { + rdma_tail_idx = rdma_send_channel_next_tail[lane_id] ++; + while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) + cached_rdma_channel_head = static_cast(ld_volatile_global(rdma_channel_head.buffer(lane_id))); + } + __syncwarp(); + + // Store RDMA head for combine + if (lane_id < kNumRDMARanks and not kCachedMode) + send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx; + + // Update last token tail + if (last_rdma_tail_idx >= 0) + st_release_cta(const_cast(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1); + last_rdma_tail_idx = rdma_tail_idx; + + // Release sequential lock + lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0; + + // Broadcast tails + SourceMeta src_meta; + int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks]; + void* dst_send_buffers[kNumTopkRDMARanks]; + #pragma unroll + for (int i = 0, slot_idx; i < kNumRDMARanks; ++ i) if ((slot_idx = __shfl_sync(0xffffffff, rdma_tail_idx, i)) >= 0) { + slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens; + topk_ranks[num_topk_ranks] = i; + auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i); + auto recv_is_token_in_rank_values = reinterpret_cast(&recv_is_token_in_rank_uint64); + if (lane_id == num_topk_ranks) + src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values); + dst_send_buffers[num_topk_ranks ++] = reinterpret_cast(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_rdma_token; + } + EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks); + + // Copy `x` into symmetric send buffer + auto st_broadcast = [=](const int key, const int4& value) { + #pragma unroll + for (int j = 0; j < num_topk_ranks; ++ j) + st_na_global(reinterpret_cast(dst_send_buffers[j]) + key, value); + }; + UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast); + #pragma unroll + for (int i = 0; i < num_topk_ranks; ++ i) + dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + hidden_int4; + + // Copy source metadata into symmetric send buffer + if (lane_id < num_topk_ranks) + st_na_global(reinterpret_cast(dst_send_buffers[lane_id]), src_meta); + #pragma unroll + for (int i = 0; i < num_topk_ranks; ++ i) + dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + 1; + + // Copy `x_scales` into symmetric send buffer + #pragma unroll + for (int i = lane_id; i < num_scales; i += 32) { + auto value = ld_nc_global(x_scales + token_idx * num_scales + i); + #pragma unroll + for (int j = 0; j < num_topk_ranks; ++ j) + st_na_global(reinterpret_cast(dst_send_buffers[j]) + i, value); + } + #pragma unroll + for (int i = 0; i < num_topk_ranks; ++ i) + dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + num_scales; + + // Copy `topk_idx` and `topk_weights` into symmetric send buffer + #pragma unroll + for (int i = lane_id; i < num_topk * num_topk_ranks; i += 32) { + auto rank_idx = i / num_topk, copy_idx = 
i % num_topk; + auto idx_value = static_cast(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx)); + auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx); + st_na_global(reinterpret_cast(dst_send_buffers[rank_idx]) + copy_idx, idx_value); + st_na_global(reinterpret_cast(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value); + } + } + + // Epilogue + // Acquire sequential lock + while (lane_id == 0 and rdma_send_next_token_idx != token_idx); + __syncwarp(); + + // Update last token tail + if (last_rdma_tail_idx >= 0) + st_release_cta(const_cast(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1); + + // Release sequential lock + lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0; + } else if (warp_role == WarpRole::kRDMASenderCoordinator) { + // NOTES: in case of splitting the issued put at the end of the buffer + EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0); + + // Synchronize shared memory + sync_rdma_sender_smem(); + + // Get number of tokens to send for each RDMA rank + int num_tokens_to_send = 0; + if (lane_id < kNumRDMARanks) { + num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id]; + if (channel_id > 0) + num_tokens_to_send -= rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1]; + } + + // Iterate all RDMA ranks + int last_issued_tail = 0; + while (__any_sync(0xffffffff, num_tokens_to_send > 0)) { + #pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i, __syncwarp()) { + // To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels + const int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks; + if (lane_id != dst_rdma_rank) continue; + if (num_tokens_to_send == 0) continue; + + // Read progress + auto processed_tail = ld_acquire_cta(const_cast(rdma_send_channel_tail + dst_rdma_rank)); + auto num_tokens_processed = processed_tail - last_issued_tail; + if (num_tokens_processed != num_tokens_to_send && num_tokens_processed < num_max_rdma_chunked_send_tokens) + continue; + + // Issue RDMA send + int num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens); + EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 && num_tokens_to_issue <= num_tokens_to_send); + if (num_tokens_to_issue == 0) continue; + + if (dst_rdma_rank == rdma_rank) { + // Update tails + mscclpp::atomicFetchAdd(reinterpret_cast(rdma_channel_tail.buffer(rdma_rank)), (uint64_t)num_tokens_to_issue, mscclpp::memoryOrderRelease); + } else { + const auto dst_slot_idx = last_issued_tail % num_max_rdma_chunked_recv_tokens; + const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue; + const auto dst_offset = rdma_rank * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) + dst_slot_idx * num_bytes_per_rdma_token + data_recv_offset; + const auto src_offset = dst_rdma_rank * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) + dst_slot_idx * num_bytes_per_rdma_token + data_send_offset; + const auto port_channel_idx = kLowLatencyMode ? (channel_id * kNumRDMARanks + dst_rdma_rank) : (channel_id * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank); + auto& handle = port_channel_handles[port_channel_idx]; + handle.put(dst_offset, src_offset, num_bytes_per_msg); + + // Remote atomic add on the peer's tail counter: +num_tokens_to_issue. 
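+                    // NOTES: the add targets the 64-bit tail at offset
+                    // `rdma_rank * sizeof(uint64_t) + tail_send_offset`, issued on the same
+                    // port channel as the preceding `put`, so the proxy applies it only after
+                    // the payload write; the consumer's acquire-load of the tail (see the
+                    // forwarder warps) then guarantees the data is visible. DeepEP issued an
+                    // IBGDA non-fetching atomic add at this point.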
+ handle.atomicAdd(rdma_rank * sizeof(uint64_t) + tail_send_offset, (int64_t)num_tokens_to_issue); + // handle.flush(); + } + last_issued_tail += num_tokens_to_issue; + num_tokens_to_send -= num_tokens_to_issue; + } + } + } else if (warp_role == WarpRole::kRDMAAndNVLForwarder) { + // RDMA consumers and NVL producers + const auto dst_nvl_rank = target_rank; + const auto dst_rank = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank; + const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks); + const auto dst_rank_expert_end = dst_rank_expert_begin + (num_experts / num_ranks); + + // Wait counters to arrive + int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0; + EP_DEVICE_ASSERT(kNumRDMARanks <= 32); + auto start_time = clock64(); + if (lane_id < kNumRDMARanks) { + while (true) { + auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank); + auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank); + auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2); + auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1); + if (meta_0 < 0 and meta_1 < 0 and meta_2 < 0 and meta_3 < 0) { + // Notify NVL ranks + int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1; + EP_DEVICE_ASSERT(start_sum >= 0 and end_sum >= 0 and end_sum >= start_sum); + st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1); + st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1); + + // Save RDMA channel received token count + src_rdma_channel_prefix = -meta_2 - 1; + auto src_rdma_channel_prefix_1 = -meta_3 - 1; + num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix; + if (not kCachedMode) + recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1; + src_rdma_channel_prefix += lane_id == 0 ? 
0 : recv_rdma_rank_prefix_sum[lane_id - 1]; + EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0); + break; + } + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d\n", + channel_id, rdma_rank, nvl_rank, lane_id, dst_nvl_rank, meta_0, meta_1, meta_2, meta_3); + trap(); + } + } + } + __syncwarp(); + + // Shift cached head + send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank; + + // Wait shared memory to be cleaned + sync_forwarder_smem(); + + // Forward tokens from RDMA buffer + // NOTES: always start from the local rank + int src_rdma_rank = sm_id % kNumRDMARanks; + int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0; + int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0; + while (__any_sync(0xffffffff, num_tokens_to_recv_from_rdma > 0)) { + // Check destination queue emptiness, or wait a buffer to be released + start_time = clock64(); + while (lane_id == 0) { + int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head; + if (num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens) + break; + cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer()); + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n", + channel_id, rdma_rank, nvl_rank, dst_nvl_rank, ld_volatile_global(nvl_channel_head.buffer()), cached_nvl_channel_tail); + trap(); + } + } + __syncwarp(); + + // Find next source RDMA rank (round-robin) + start_time = clock64(); + while (true) { + src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks; + if (__shfl_sync(0xffffffff, num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) { + if (lane_id == src_rdma_rank and cached_rdma_channel_head == cached_rdma_channel_tail) + cached_rdma_channel_tail = static_cast(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank))); + if (__shfl_sync(0xffffffff, cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank)) + break; + } + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { + printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: %d\n", + channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma); + trap(); + } + } + auto src_rdma_head = __shfl_sync(0xffffffff, cached_rdma_channel_head, src_rdma_rank); + auto src_rdma_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, src_rdma_rank); + + // Iterate over every token from the RDMA buffer + for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++ i) { + auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens; + void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token; + auto src_meta = ld_nc_global(reinterpret_cast(reinterpret_cast(shifted) + hidden_bytes)); + lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0; + bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank); + if (lane_id == src_rdma_rank) { + auto cached_head = is_in_dst_nvl_rank ? 
rdma_nvl_token_idx : -1; + rdma_nvl_token_idx += is_in_dst_nvl_rank; + if (not kCachedMode) + send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head; + } + if (not is_in_dst_nvl_rank) + continue; + + // Get an empty slot + int dst_slot_idx = (cached_nvl_channel_tail ++) % num_max_nvl_chunked_recv_tokens; + + // Copy data + UNROLLED_WARP_COPY(5, lane_id, hidden_int4, + nvl_channel_x.buffer() + dst_slot_idx * hidden_int4, + reinterpret_cast(shifted), + ld_nc_global, st_na_global); + shifted = reinterpret_cast(shifted) + hidden_int4; + + // Copy source meta + if (lane_id == 0) + st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta); + shifted = reinterpret_cast(shifted) + 1; + + // Copy `x_scales` + UNROLLED_WARP_COPY(1, lane_id, num_scales, + nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales, + reinterpret_cast(shifted), + ld_nc_global, st_na_global); + shifted = reinterpret_cast(shifted) + num_scales; + + // Copy `topk_idx` and `topk_weights` + // NOTES: do not use `shifted` after this `if`, because only several lanes are shifted + if (lane_id < num_topk) { + // Read + auto idx_value = ld_nc_global(reinterpret_cast(shifted) + lane_id); + shifted = reinterpret_cast(shifted) + num_topk; + auto weight_value = ld_nc_global(reinterpret_cast(shifted) + lane_id); + + // Transform and write + idx_value = (idx_value >= dst_rank_expert_begin and idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1; + st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id, idx_value); + weight_value = idx_value >= 0 ? weight_value : 0.0f; + st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, weight_value); + } + + // In case of insufficient NVL buffers, early stopping + if ((++ num_tokens_sent) == num_max_nvl_chunked_send_tokens) + src_rdma_tail = i + 1; + } + + // Sync head index + if (lane_id == src_rdma_rank) + forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail); + + // Move tail index + __syncwarp(); + if (lane_id == 0) + st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail); + } + + // Retired + __syncwarp(); + if (lane_id == 0) + forward_channel_retired[dst_nvl_rank] = true; + } else if (warp_role == WarpRole::kForwarderCoordinator) { + // Extra warps for forwarder coordinator should exit directly + if (target_rank > 0) + return; + + // Forward warp coordinator + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + + // Clean shared memory + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); + #pragma unroll + for (int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += 32) + forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0; + if (lane_id < NUM_MAX_NVL_PEERS) + forward_channel_retired[lane_id] = false; + sync_forwarder_smem(); + + int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? 
lane_id : 0; + while (true) { + // Find minimum head + int min_head = std::numeric_limits::max(); + #pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) if (not forward_channel_retired[i]) + min_head = min(min_head, forward_channel_head[i][target_rdma]); + if (__all_sync(0xffffffff, min_head == std::numeric_limits::max())) + break; + + // Update remote head + if (min_head != std::numeric_limits::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { + if (lane_id == rdma_rank) { + mscclpp::atomicFetchAdd(static_cast(rdma_channel_head.buffer(rdma_rank)), (uint64_t)(min_head - last_head), mscclpp::memoryOrderRelease); + } else { + auto dst_offset = rdma_rank * sizeof(uint64_t) + head_send_offset; + auto port_channel_idx = kLowLatencyMode ? (channel_id * kNumRDMARanks + lane_id) : (channel_id * num_ranks + lane_id * NUM_MAX_NVL_PEERS + nvl_rank); + auto& handle = port_channel_handles[port_channel_idx]; + // Remote atomic add on the peer's head counter. + handle.atomicAdd(dst_offset, (int64_t)(min_head - last_head)); + } + last_head = min_head; + } + + // Nanosleep and let other warps work + __nanosleep(NUM_WAIT_NANOSECONDS); + } + } else { + // NVL consumers + // Retrieve rank offset from barrier results (each lane's register stores an RDMA rank) + int src_nvl_rank = target_rank, total_offset = 0; + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + if (lane_id < kNumRDMARanks and lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0) + total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1]; + + // Receive channel offsets + int start_offset = 0, end_offset = 0, num_tokens_to_recv; + auto start_time = clock64(); + while (lane_id < kNumRDMARanks) { + start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id); + end_offset = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id); + if (start_offset < 0 and end_offset < 0) { + start_offset = -start_offset - 1, end_offset = -end_offset - 1; + total_offset += start_offset; + break; + } + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d\n", + channel_id, rdma_rank, nvl_rank, lane_id, src_nvl_rank, start_offset, end_offset); + trap(); + } + } + num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset); + + // Save for combine usage + if (lane_id < kNumRDMARanks and not kCachedMode) + recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * num_channels + channel_id] = total_offset; + __syncwarp(); + + int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; + while (num_tokens_to_recv > 0) { + // Check channel status by lane 0 + start_time = clock64(); + while (lane_id == 0) { + // Ready to copy + if (cached_channel_head_idx != cached_channel_tail_idx) + break; + cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer()); + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n", + channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx); + trap(); + } + } + + // Sync queue tail + cached_channel_tail_idx = __shfl_sync(0xffffffff, cached_channel_tail_idx, 0); + + // Copy data + int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx; + for (int 
chunk_idx = 0; chunk_idx < num_recv_tokens; ++ chunk_idx, -- num_tokens_to_recv) { + int token_idx_in_buffer = (cached_channel_head_idx ++) % num_max_nvl_chunked_recv_tokens; + auto meta = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer); + int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank); + (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0; + + // Copy data + UNROLLED_WARP_COPY(5, lane_id, hidden_int4, + recv_x + recv_token_idx * hidden_int4, + nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4, + ld_nc_global, st_na_global); + + // Copy source meta + if (lane_id == 0 and not kCachedMode) + st_na_global(recv_src_meta + recv_token_idx, meta); + + // Copy scales + UNROLLED_WARP_COPY(1, lane_id, num_scales, + recv_x_scales + recv_token_idx * num_scales, + nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales, + ld_nc_global, st_na_global); + + // Copy `topk_idx` and `topk_weights` + if (lane_id < num_topk) { + auto recv_idx = recv_token_idx * num_topk + lane_id; + auto buffer_idx = token_idx_in_buffer * num_topk + lane_id; + st_na_global(recv_topk_idx + recv_idx, static_cast(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx))); + st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx)); + } + } + + // Move queue + __syncwarp(); + if (lane_id == 0) + st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx); + } + } +} + +void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta, + const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, + int* send_rdma_head, int* send_nvl_head, + int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, + const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, + int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, + const bool* is_token_in_rank, + void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + int rank, int num_ranks, bool is_cached_dispatch, + cudaStream_t stream, int num_channels, bool low_latency_mode, + mscclpp::PortChannelDeviceHandle *port_channel_handles, + mscclpp::MemoryChannelDeviceHandle *memory_channel_handles) { + constexpr int kNumDispatchRDMASenderWarps = 7; + +#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \ + auto dispatch_func = low_latency_mode ? \ + (is_cached_dispatch ? dispatch : dispatch) : \ + (is_cached_dispatch ? 
dispatch : dispatch); \ + LAUNCH_KERNEL(&cfg, dispatch_func, \ + reinterpret_cast(recv_x), recv_x_scales, recv_topk_idx, recv_topk_weights, reinterpret_cast(recv_src_meta), \ + reinterpret_cast(x), x_scales, topk_idx, topk_weights, \ + send_rdma_head, send_nvl_head, \ + recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, \ + rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \ + gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \ + num_tokens, hidden_int4, num_scales, num_topk, num_experts, \ + is_token_in_rank, \ + rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \ + buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \ + rank, num_ranks, \ + port_channel_handles, memory_channel_handles); } break + + EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr)); + EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr)); + + SETUP_LAUNCH_CONFIG(num_channels * 2, (kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32, stream); + SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE); +#undef DISPATCH_LAUNCH_CASE +} + +template +__global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, + const int nvl_clean_offset, const int nvl_num_int_clean, + int* combined_rdma_head, int num_combined_tokens, int num_channels, + const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, + void* rdma_buffer_ptr, + void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank, int num_ranks, + bool is_cached_dispatch, + mscclpp::PortChannelDeviceHandle *port_channel_handles, + mscclpp::MemoryChannelDeviceHandle *memory_channel_handles) { + auto sm_id = static_cast(blockIdx.x); + auto thread_id = static_cast(threadIdx.x); + auto num_threads = static_cast(blockDim.x); + auto warp_id = thread_id / 32; + auto lane_id = get_lane_id(); + + auto rdma_rank = rank / NUM_MAX_NVL_PEERS; + auto nvl_rank = rank % NUM_MAX_NVL_PEERS; + auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + + // Using two SMs, which clean the RDMA/NVL buffer respectively + if (sm_id == 0) { + // Barrier for RDMA + + // TODO(chhwang): it should be a global barrier when kLowLatencyMode is false + const bool run_barrier = (threadIdx.x < num_rdma_ranks) && (threadIdx.x != rdma_rank); + const auto barrier_channel_idx = kLowLatencyMode ? threadIdx.x : (threadIdx.x * NUM_MAX_NVL_PEERS + nvl_rank); + if (run_barrier) { + port_channel_handles[barrier_channel_idx].signal(); + port_channel_handles[barrier_channel_idx].wait(); + } + __syncthreads(); + + // Clean + auto rdma_buffer_ptr_int = reinterpret_cast(rdma_buffer_ptr); + #pragma unroll + for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) + rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; + // Make the cleanup visible to the proxy + remote peers before the barrier. + // DeepEP used `nvshmem_fence()` here; we fall back to a system-scope + // threadfence because the actual remote visibility is provided by the + // subsequent port-channel barrier (signal + flush + wait). 
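+        // System scope (rather than the device-scope `memory_fence()` used on the
+        // NVL branch below) because the cleaned RDMA buffer must become visible to
+        // the proxy thread and the NIC, not just to other blocks on this GPU.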
+ __threadfence_system(); + __syncthreads(); + + // Barrier again + if (run_barrier) { + port_channel_handles[barrier_channel_idx].signal(); + port_channel_handles[barrier_channel_idx].flush(); + port_channel_handles[barrier_channel_idx].wait(); + } + } else if (sm_id == 1) { + // Barrier for NVL + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + __syncthreads(); + + // Clean + auto nvl_buffer_ptr_int = reinterpret_cast(buffer_ptrs[nvl_rank]); + #pragma unroll + for (int i = thread_id; i < nvl_num_int_clean; i += num_threads) + nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; + memory_fence(); + __syncthreads(); + + // Barrier again + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + } else if (sm_id == 2) { + if (is_cached_dispatch) + return; + + EP_DEVICE_ASSERT(num_warps >= num_channels); + EP_DEVICE_ASSERT(num_rdma_ranks <= 32); + + // Iterate in reverse order + if (lane_id < num_rdma_ranks and warp_id < num_channels) { + int token_start_idx, token_end_idx; + get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx, token_end_idx); + + // NOTES: `1 << 25` is a heuristic large number + int last_head = 1 << 25; + for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; -- token_idx) { + auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id); + if (current_head < 0) { + combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1; + } else { + last_head = current_head; + } + } + } + } else { + if (is_cached_dispatch) + return; + + EP_DEVICE_ASSERT(num_warps >= num_channels); + EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr); + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Too many NVL peers"); + + if (lane_id < NUM_MAX_NVL_PEERS and warp_id < num_channels) { + for (int dst_rdma_rank = sm_id - 3; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_channels * 2 - 3) { + // Iterate in reverse order + int token_start_idx = warp_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1]; + int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id]; + int shift = dst_rdma_rank == 0 ? 
0 : rdma_rank_prefix_sum[dst_rdma_rank - 1]; + token_start_idx += shift, token_end_idx += shift; + + // NOTES: `1 << 25` is a heuristic large number + int last_head = 1 << 25; + #pragma unroll + for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; -- token_idx) { + auto current_head = __ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id); + if (current_head < 0) { + combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1; + } else { + last_head = current_head; + } + } + } + } + } +} + +void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, + int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head, + const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, + void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, + int** task_fifo_ptrs, int head, int rank, cudaStream_t stream, + int64_t num_rdma_bytes, int64_t num_nvl_bytes, + bool is_cached_dispatch, bool low_latency_mode, + mscclpp::PortChannelDeviceHandle *port_channel_handles, + mscclpp::MemoryChannelDeviceHandle *memory_channel_handles) { + const int num_threads = std::max(128, 32 * num_channels); + const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + + // Get clean meta + auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels); + auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels); + EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes); + EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes); + EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); + EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); + EP_HOST_ASSERT(num_channels * 2 > 3); + + // Launch kernel + auto cached_notify_func = low_latency_mode ? 
cached_notify<true> : cached_notify<false>;
+  SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream);
+  LAUNCH_KERNEL(&cfg, cached_notify_func,
+                rdma_clean_meta.first, rdma_clean_meta.second,
+                nvl_clean_meta.first, nvl_clean_meta.second,
+                combined_rdma_head, num_combined_tokens, num_channels,
+                rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head,
+                rdma_buffer_ptr,
+                buffer_ptrs, task_fifo_ptrs, head, rank, num_ranks,
+                is_cached_dispatch,
+                port_channel_handles, memory_channel_handles);
+}
+
+template <typename dtype_t, int kNumRanks, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
+__device__ int combine_token(bool is_token_in_rank, int head_idx,
+                             int lane_id, int hidden_int4, int num_topk,
+                             int4* combined_row, float* combined_topk_weights,
+                             int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) {
+  constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
+
+  // Broadcast current heads
+  // Lane `i` holds the head of rank `i` and `is_token_in_rank`
+  EP_STATIC_ASSERT(kMaxNumRanks <= 32, "Too many ranks");
+  int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks];
+  #pragma unroll
+  for (int i = 0; i < kNumRanks; ++ i) if (__shfl_sync(0xffffffff, is_token_in_rank, i)) {
+    slot_indices[num_topk_ranks] = __shfl_sync(0xffffffff, head_idx, i) % num_max_recv_tokens;
+    topk_ranks[num_topk_ranks ++] = i;
+  }
+  EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks);
+
+  // Reduce data
+  #pragma unroll
+  for (int i = lane_id; i < hidden_int4; i += 32) {
+    // Read buffers
+    // TODO: maybe too many registers here
+    int4 recv_value_int4[kMaxNumRanks];
+    #pragma unroll
+    for (int j = 0; j < num_topk_ranks; ++ j)
+      recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i);
+
+    // Reduce all-to-all results
+    float values[kDtypePerInt4] = {0};
+    #pragma unroll
+    for (int j = 0; j < num_topk_ranks; ++ j) {
+      auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
+      #pragma unroll
+      for (int k = 0; k < kDtypePerInt4; ++ k)
+        values[k] += static_cast<float>(recv_value_dtypes[k]);
+    }
+
+    // Cast back to `dtype_t` and write
+    int4 out_int4;
+    auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
+    #pragma unroll
+    for (int j = 0; j < kDtypePerInt4; ++ j)
+      out_dtypes[j] = static_cast<dtype_t>(values[j]);
+    st_na_global(combined_row + i, out_int4);
+  }
+
+  // Reduce `topk_weights`
+  if (lane_id < num_topk) {
+    float value = 0;
+    #pragma unroll
+    for (int i = 0; i < num_topk_ranks; ++ i)
+      value += recv_tw_fn(topk_ranks[i], slot_indices[i], lane_id);
+    st_na_global(combined_topk_weights + lane_id, value);
+  }
+
+  // Return the minimum top-k rank
+  return topk_ranks[0];
+}
+
+template <bool kLowLatencyMode, int kNumRDMARanks, typename dtype_t,
+          int kNumCombineForwarderWarps,
+          int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? 
kNumCombineForwarderWarps / kNumRDMARanks : 1, + int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder, + int kNumRDMAReceivers = kNumForwarders + NUM_MAX_NVL_PEERS> +__global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, 1) +combine(int4* combined_x, float* combined_topk_weights, + const bool* is_combined_token_in_rank, + const int4* x, const float* topk_weights, + const int* combined_rdma_head, const int* combined_nvl_head, + const SourceMeta* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, + int num_tokens, int num_combined_tokens, int hidden, int num_topk, + void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + int rank, int num_ranks, + mscclpp::PortChannelDeviceHandle *port_channel_handles, + mscclpp::MemoryChannelDeviceHandle *memory_channel_handles) { + enum class WarpRole { + kNVLSender, + kNVLAndRDMAForwarder, + kRDMAReceiver, + kCoordinator + }; + + const auto sm_id = static_cast(blockIdx.x); + const auto thread_id = static_cast(threadIdx.x), lane_id = get_lane_id(); + const auto num_channels = static_cast(gridDim.x) / 2, channel_id = sm_id / 2; + const bool is_rdma_receiver_sm = sm_id % 2 == 1; + + EP_DEVICE_ASSERT(num_topk <= 32); + EP_DEVICE_ASSERT(hidden % (sizeof(int4) / sizeof(dtype_t)) == 0); + const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t)); + + // NOTES: we decouple a channel into 2 SMs + const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + auto role_meta = [=]() -> std::pair { + auto warp_id = thread_id / 32; + if (not is_rdma_receiver_sm) { + if (warp_id < NUM_MAX_NVL_PEERS) { + auto shuffled_warp_id = warp_id; + shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS; + return {WarpRole::kNVLSender, shuffled_warp_id}; + } else if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) { + auto shuffled_warp_id = warp_id - NUM_MAX_NVL_PEERS; + shuffled_warp_id = (shuffled_warp_id + channel_id) % kNumForwarders; + return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id}; + } else { + return {WarpRole::kCoordinator, 0}; + } + } else { + if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) { + return {WarpRole::kRDMAReceiver, warp_id}; + } else { + return {WarpRole::kCoordinator, 0}; + } + } + }(); + auto warp_role = role_meta.first; + auto warp_id = role_meta.second; + + EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + kNumForwarders + 1); + auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks; + + if (warp_role == WarpRole::kNVLSender) { + // NVL producers + const auto dst_nvl_rank = warp_id; + + // NVL layouts + // NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources + auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank]; + auto nvl_channel_x = AsymBuffer(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr); + auto nvl_channel_src_meta = AsymBuffer(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr); + auto nvl_channel_topk_weights = AsymBuffer(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, 
nvl_rank).advance_also(local_buffer_ptr); + auto nvl_channel_head = AsymBuffer(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank).advance_also(dst_buffer_ptr); + auto nvl_channel_tail = AsymBuffer(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr); + + // Get tasks for each RDMA lane + int token_start_idx = 0, token_end_idx = 0; + if (lane_id < kNumRDMARanks) { + int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id; + token_start_idx = gbl_channel_prefix_matrix[prefix_idx]; + token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1]; + } + __syncwarp(); + + // NOTES: here the cached value of each lane is only responsible for a single RDMA buffer + int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + + // Iterate over all tokens and send by chunks + while (true) { + // Exit if possible + if (__all_sync(0xffffffff, token_start_idx >= token_end_idx)) + break; + + // Decide next RDMA buffer to send + bool is_lane_ready = false; + auto start_time = clock64(); + while (true) { + int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx; + is_lane_ready = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens; + if (__any_sync(0xffffffff, is_lane_ready)) + break; + + // Retry + if (lane_id < kNumRDMARanks and token_start_idx < token_end_idx) + cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id); + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { + printf("DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d\n", + channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, ld_volatile_global(nvl_channel_head.buffer() + lane_id), cached_channel_tail_idx, + token_start_idx, token_end_idx); + trap(); + } + } + + // Sync token start index and count + for (int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++ current_rdma_idx) { + if (__shfl_sync(0xffffffff, (token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx)) + continue; + + // Sync token start index + auto token_idx = static_cast(__shfl_sync(0xffffffff, token_start_idx, current_rdma_idx)); + int num_tokens_in_chunk = __shfl_sync(0xffffffff, min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx); + + // Send by chunk + for (int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++ chunk_idx, ++ token_idx) { + // Get an empty slot + int dst_slot_idx = 0; + if (lane_id == current_rdma_idx) { + dst_slot_idx = (cached_channel_tail_idx ++) % num_max_nvl_chunked_recv_tokens_per_rdma; + dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx; + } + dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx, current_rdma_idx); + + // Copy data + auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4; + auto shifted_x = x + token_idx * hidden_int4; + UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global); + + // Copy source meta + if (lane_id == 0) + st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx)); + + // 
Copy `topk_weights` + if (lane_id < num_topk) + st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, ld_nc_global(topk_weights + token_idx * num_topk + lane_id)); + } + lane_id == current_rdma_idx ? (token_start_idx = static_cast(token_idx)) : 0; + } + + // Move queue tail + __syncwarp(); + if (lane_id < kNumRDMARanks and is_lane_ready) + st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx); + } + } else { + // Combiners and coordinators + // RDMA symmetric layout + auto hidden_bytes = hidden_int4 * sizeof(int4); + auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk); + auto rdma_channel_data = SymBuffer(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_head = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_tail = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + + auto data_send_offset = sizeof(int8_t) * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * kNumRDMARanks * channel_id; + auto data_recv_offset = sizeof(int8_t) * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * kNumRDMARanks * (channel_id + num_channels); + auto head_offset = sizeof(int8_t) * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * kNumRDMARanks * num_channels * 2; + auto head_send_offset = head_offset + sizeof(uint64_t) * kNumRDMARanks * channel_id; + auto tail_offset = head_offset + sizeof(uint64_t) * kNumRDMARanks * num_channels; + auto tail_send_offset = tail_offset + sizeof(uint64_t) * kNumRDMARanks * channel_id; + + // NVL layouts + void* local_nvl_buffer = buffer_ptrs[nvl_rank]; + void* nvl_buffers[NUM_MAX_NVL_PEERS]; + #pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) + nvl_buffers[i] = buffer_ptrs[i]; + auto nvl_channel_x = AsymBuffer(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also(nvl_buffers); + auto nvl_channel_src_meta = AsymBuffer(local_nvl_buffer, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also(nvl_buffers); + auto nvl_channel_topk_weights = AsymBuffer(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also(nvl_buffers); + auto nvl_channel_head = AsymBuffer(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_nvl_buffer); + auto nvl_channel_tail = AsymBuffer(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also(nvl_buffers); + + // Combiner warp synchronization + __shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS]; + __shared__ volatile bool forwarder_retired[kNumForwarders]; + __shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks]; + __shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers]; + auto sync_forwarder_smem = [=]() { asm volatile("bar.sync 0, %0;" :: "r"((kNumForwarders + 1) * 32)); }; + auto sync_rdma_receiver_smem = [=]() { asm volatile("bar.sync 1, %0;" :: "r"((kNumRDMAReceivers + 1) * 32)); }; + + if (warp_role == WarpRole::kNVLAndRDMAForwarder) { + // Receive from NVL ranks and forward to RDMA ranks + // NOTES: this part is using "large warps" for each RDMA ranks + const auto dst_rdma_rank = warp_id / kNumWarpsPerForwarder; + const auto sub_warp_id = warp_id 
% kNumWarpsPerForwarder; + auto send_buffer = dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank); + auto sync_large_warp = [=]() { + if (kNumWarpsPerForwarder == 1) { + __syncwarp(); + } else { + asm volatile("bar.sync %0, %1;" :: "r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * 32)); + } + }; + EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= 16, "Barriers are not enough"); + + // Advance to the corresponding NVL buffer + nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4); + nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma); + nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk); + nvl_channel_head.advance(dst_rdma_rank); + nvl_channel_tail.advance(dst_rdma_rank); + + // Clean shared memory and sync + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); + lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[warp_id][lane_id] = 0) : 0; + lane_id == 0 ? (forwarder_retired[warp_id] = false) : false; + sync_forwarder_smem(); + + // Get count and cached head + int cached_nvl_channel_tail_idx = 0; + int num_tokens_to_combine = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id]; + int num_tokens_prefix = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]; + num_tokens_to_combine -= num_tokens_prefix; + num_tokens_prefix += dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1]; + combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS; + + // Iterate over all tokens and combine by chunks + for (int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) { + // Check destination queue emptiness, or wait a buffer to be released + auto token_end_idx = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine); + auto num_chunked_tokens = token_end_idx - token_start_idx; + auto start_time = clock64(); + while (sub_warp_id == 0 and lane_id == 0) { + // Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens` + // Here, `token_start_idx` is the actual tail + int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)); + if (num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens) + break; + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d\n", + channel_id, rdma_rank, nvl_rank, dst_rdma_rank, ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens); + trap(); + } + } + sync_large_warp(); + + // Combine and write to the RDMA buffer + for (int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) { + // Read expected head + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + int expected_head = -1; + if (lane_id < NUM_MAX_NVL_PEERS) + expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id); + + // Wait lanes to be ready + start_time = clock64(); + while (cached_nvl_channel_tail_idx <= expected_head) { + cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id)); + + // Timeout check + if (clock64() - 
start_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) { + printf("DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n", + channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine, sub_warp_id, kNumWarpsPerForwarder, expected_head); + trap(); + } + } + + // Combine current token + auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens; + void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token; + auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); }; + auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); }; + combine_token(expected_head >= 0, + expected_head, lane_id, + hidden_int4, num_topk, + reinterpret_cast(shifted), + reinterpret_cast(reinterpret_cast(shifted) + hidden_bytes + sizeof(SourceMeta)), + num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn); + + // Update head + if (lane_id < NUM_MAX_NVL_PEERS) + expected_head < 0 ? (forwarder_nvl_head[warp_id][lane_id] = -expected_head - 1) : (forwarder_nvl_head[warp_id][lane_id] = expected_head + 1); + } + sync_large_warp(); + + // Issue RDMA send + if (sub_warp_id == kNumWarpsPerForwarder - 1) { + if (lane_id == 0) { + if (dst_rdma_rank == rdma_rank) { + mscclpp::atomicFetchAdd(reinterpret_cast(rdma_channel_tail.buffer(rdma_rank)), (uint64_t)num_chunked_tokens, mscclpp::memoryOrderRelease); + } else { + auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens; + const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token; + auto dst_offset = rdma_rank * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) + rdma_slot_idx * num_bytes_per_rdma_token + data_recv_offset; + auto src_offset = dst_rdma_rank * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) + rdma_slot_idx * num_bytes_per_rdma_token + data_send_offset; + auto port_channel_idx = kLowLatencyMode ? (channel_id * kNumRDMARanks + dst_rdma_rank) : (channel_id * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank); + auto& handle = port_channel_handles[port_channel_idx]; + handle.put(dst_offset, src_offset, num_bytes_per_msg); + + // Remote atomic add on the peer's tail counter: +num_chunked_tokens. + handle.atomicAdd(rdma_rank * sizeof(uint64_t) + tail_send_offset, (int64_t)num_chunked_tokens); + } + } + __syncwarp(); + } + } + + // Retired + __syncwarp(); + if (lane_id == 0) + forwarder_retired[warp_id] = true; + } else if (warp_role == WarpRole::kRDMAReceiver) { + // Receive from RDMA ranks and write to the output tensor + // Clean shared memory and sync + EP_DEVICE_ASSERT(kNumRDMARanks <= 32); + lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[warp_id][lane_id] = 0) : 0; + lane_id == 0 ? 
(rdma_receiver_retired[warp_id] = false) : 0; + sync_rdma_receiver_smem(); + + // The same tokens as the dispatch process + int token_start_idx, token_end_idx; + get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // Iterate over all tokens and combine + int cached_channel_tail_idx = 0; + for (int64_t token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) { + // Read expected head + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + int expected_head = -1; + if (lane_id < kNumRDMARanks) { + expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id); + (expected_head < 0) ? (rdma_receiver_rdma_head[warp_id][lane_id] = -expected_head - 1) : (rdma_receiver_rdma_head[warp_id][lane_id] = expected_head); + } + + // Wait lanes to be ready + auto start_time = clock64(); + while (cached_channel_tail_idx <= expected_head) { + cached_channel_tail_idx = static_cast(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id))); + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d\n", + channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx, token_idx, expected_head); + trap(); + } + } + __syncwarp(); + + // Combine current token + auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);}; + auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);}; + combine_token(expected_head >= 0, + expected_head, lane_id, + hidden_int4, num_topk, + combined_x + token_idx * hidden_int4, + combined_topk_weights + token_idx * num_topk, + num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn); + } + + // Retired + __syncwarp(); + if (lane_id == 0) + rdma_receiver_retired[warp_id] = true; + } else { + // Coordinator + // Sync shared memory status + is_rdma_receiver_sm ? sync_rdma_receiver_smem() : sync_forwarder_smem(); + const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks; + + int last_rdma_head = 0; + int last_nvl_head[kNumRDMARanks] = {0}; + int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0; + int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? 
lane_id : 0; + EP_STATIC_ASSERT(kNumCombineForwarderWarps <= 32, "Invalid number of forwarder warps"); + while (true) { + // Retired + if (is_rdma_receiver_sm and __all_sync(0xffffffff, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id])) + break; + if (not is_rdma_receiver_sm and __all_sync(0xffffffff, lane_id >= kNumForwarders or forwarder_retired[lane_id])) + break; + + // Find minimum head for RDMA ranks + if (is_rdma_receiver_sm) { + int min_head = std::numeric_limits::max(); + #pragma unroll + for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i]) + min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]); + if (min_head != std::numeric_limits::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { + if (dst_rdma_rank == rdma_rank) { + mscclpp::atomicFetchAdd(static_cast(rdma_channel_head.buffer(rdma_rank)), (uint64_t)(min_head - last_rdma_head), mscclpp::memoryOrderRelease); + } else { + auto dst_offset = rdma_rank * sizeof(uint64_t) + head_send_offset; + auto port_channel_idx = kLowLatencyMode ? (channel_id * kNumRDMARanks + dst_rdma_rank) : (channel_id * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank); + auto& handle = port_channel_handles[port_channel_idx]; + // Remote atomic add on the peer's head counter. + handle.atomicAdd(dst_offset, (int64_t)(min_head - last_rdma_head)); + } + last_rdma_head = min_head; + } + } else { + // Find minimum head for NVL ranks + #pragma unroll + for (int i = 0; i < kNumRDMARanks; ++ i) { + int min_head = std::numeric_limits::max(); + #pragma unroll + for (int j = 0; j < num_warps_per_rdma_rank; ++ j) if (not forwarder_retired[i * num_warps_per_rdma_rank + j]) + min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]); + if (min_head != std::numeric_limits::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS) + st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head); + } + } + + // Nanosleep and let other warps work + __nanosleep(NUM_WAIT_NANOSECONDS); + } + } + } +} + +void combine(cudaDataType_t type, + void* combined_x, float* combined_topk_weights, + const bool* is_combined_token_in_rank, + const void* x, const float* topk_weights, + const int* combined_rdma_head, const int* combined_nvl_head, + const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, + int num_tokens, int num_combined_tokens, int hidden, int num_topk, + void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode, + mscclpp::PortChannelDeviceHandle *port_channel_handles, + mscclpp::MemoryChannelDeviceHandle *memory_channel_handles) { + constexpr int kNumCombineForwarderWarps = 16; + +#define COMBINE_LAUNCH_CASE(num_rdma_ranks) { \ + auto combine_func = low_latency_mode ? 
\
+    combine<true, num_rdma_ranks, nv_bfloat16, kNumCombineForwarderWarps> : combine<false, num_rdma_ranks, nv_bfloat16, kNumCombineForwarderWarps>; \
+  LAUNCH_KERNEL(&cfg, combine_func, \
+    reinterpret_cast<int4*>(combined_x), combined_topk_weights, is_combined_token_in_rank, \
+    reinterpret_cast<const int4*>(x), topk_weights, \
+    combined_rdma_head, combined_nvl_head, \
+    reinterpret_cast<const SourceMeta*>(src_meta), rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
+    num_tokens, num_combined_tokens, hidden, num_topk, \
+    rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \
+    buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \
+    rank, num_ranks, \
+    port_channel_handles, memory_channel_handles); } break
+
+  int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
+  auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1);
+  int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder;
+  EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0);
+  EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
+  EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens));
+  EP_HOST_ASSERT(type == CUDA_R_16BF);
+
+  SETUP_LAUNCH_CONFIG(num_channels * 2, (NUM_MAX_NVL_PEERS + num_forwarder_warps + 1) * 32, stream);
+  SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE);
+#undef COMBINE_LAUNCH_CASE
+}
+
+} // namespace internode
+
+} // namespace ep
+} // namespace mscclpp
diff --git a/src/ext/ep/kernels/intranode_kernel.cu b/src/ext/ep/kernels/intranode_kernel.cu
new file mode 100644
index 00000000..b3f23442
--- /dev/null
+++ b/src/ext/ep/kernels/intranode_kernel.cu
@@ -0,0 +1,826 @@
+#include "configs.cuh"
+#include "buffer.cuh"
+#include "exception.cuh"
+#include "launch.cuh"
+#include "utils.cuh"
+
+namespace mscclpp { namespace ep {
+
+namespace intranode {
+
+template <int kNumRanks>
+__global__ void
+notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
+                const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
+                int num_tokens, int num_channels, const bool* is_token_in_rank, int* channel_prefix_matrix,
+                int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
+                void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank) {
+  auto sm_id = static_cast<int>(blockIdx.x);
+  auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
+  auto lane_id = thread_id % 32, warp_id = thread_id / 32, num_warps = num_threads / 32;
+
+  if (sm_id == 0) {
+    // Barrier first
+    barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
+    move_fifo_slots<kNumRanks>(head);
+    __syncthreads();
+
+    int *per_rank_buffer, *per_expert_buffer;
+    if (thread_id < kNumRanks) {
+      per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[thread_id]);
+      per_expert_buffer = per_rank_buffer + kNumRanks * kNumRanks;
+    }
+
+    // After this loop:
+    //  - `per_rank_buffer[rank][i, j]` means the number of tokens from rank i to rank j
+    //  - `per_expert_buffer[rank][i, j]` means the number of tokens from rank i to local expert j
+    int num_experts_per_rank = num_experts / kNumRanks;
+    if (thread_id < kNumRanks) {
+      #pragma unroll
+      for (int i = 0; i < kNumRanks; ++ i)
+        per_rank_buffer[rank * kNumRanks + i] = num_tokens_per_rank[i];
+      #pragma unroll
+      for (int i = 0; i < num_experts_per_rank; ++ i)
+        per_expert_buffer[rank * num_experts_per_rank + i] = num_tokens_per_expert[thread_id * num_experts_per_rank + i];
+    }
+    __syncthreads();
+
+    // Wait for all ranks to be finished
+    barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
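+    // NOTES: each `barrier_device` round consumes one task-FIFO slot per rank;
+    // `move_fifo_slots` then advances the local head past the consumed slots so
+    // the next barrier spins on fresh ones (our reading of the FIFO protocol in
+    // the shared CUDA utilities).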
move_fifo_slots(head); + __syncthreads(); + + // Sum per-rank counts and return to CPU + // Also pre-compute the prefix sum for data sending + auto local_per_rank_buffer = reinterpret_cast(buffer_ptrs[rank]); + if (thread_id < kNumRanks) { + #pragma unroll + for (int i = 1; i < kNumRanks; ++ i) + local_per_rank_buffer[i * kNumRanks + thread_id] += local_per_rank_buffer[(i - 1) * kNumRanks + thread_id]; + if (thread_id == rank) + *moe_recv_counter_mapped = local_per_rank_buffer[(kNumRanks - 1) * kNumRanks + rank]; + } + + // Sum per-experts counts and return to CPU + auto local_per_expert_buffer = local_per_rank_buffer + kNumRanks * kNumRanks; + if (thread_id < num_experts_per_rank) { + int sum = 0; + #pragma unroll + for (int i = 0; i < kNumRanks; ++ i) + sum += local_per_expert_buffer[i * num_experts_per_rank + thread_id]; + sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment; + moe_recv_expert_counter_mapped[thread_id] = sum; + } + __syncthreads(); + + // Copy rank size prefix matrix to another tensor + #pragma unroll + for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads) + rank_prefix_matrix_copy[i] = local_per_rank_buffer[i]; + + // Extra memset for later communication queue + #pragma unroll + for (int i = thread_id; i < num_memset_int; i += num_threads) + local_per_expert_buffer[i] = 0; + + // Barrier + memory_fence(); + __syncthreads(); + barrier_device(task_fifo_ptrs, head, rank); + } else { + int dst_rank = sm_id - 1; + for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) { + int token_start_idx, token_end_idx; + get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // Iterate over tokens + int count = 0; + for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32) + count += is_token_in_rank[i * kNumRanks + dst_rank]; + count = warp_reduce_sum(count); + if (lane_id == 0) + channel_prefix_matrix[dst_rank * num_channels + channel_id] = count; + } + __syncthreads(); + + // Pre-compute prefix sum for all channels + if (thread_id == 0) { + #pragma unroll + for (int i = 1; i < num_channels; ++ i) + channel_prefix_matrix[dst_rank * num_channels + i] += channel_prefix_matrix[dst_rank * num_channels + i - 1]; + } + } +} + +void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, + const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, + int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix, + int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment, + void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank, + cudaStream_t stream, int num_channels) { +#define NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \ + LAUNCH_KERNEL(&cfg, notify_dispatch, \ + num_tokens_per_rank, moe_recv_counter_mapped, \ + num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \ + num_tokens, num_channels, is_token_in_rank, channel_prefix_matrix, \ + rank_prefix_matrix_copy, num_memset_int, expert_alignment, \ + buffer_ptrs, task_fifo_ptrs, head, rank); \ + break + + constexpr int kNumThreads = 128; + EP_HOST_ASSERT(num_experts % num_ranks == 0); + EP_HOST_ASSERT(num_experts / num_ranks <= kNumThreads and num_ranks <= kNumThreads); + + SETUP_LAUNCH_CONFIG(1 + num_ranks, kNumThreads, stream); + SWITCH_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE); +#undef NOTIFY_DISPATCH_LAUNCH_CASE +} + +template +__global__ void +cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, + 
void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank) {
+  // A simplified version for cached handles
+  barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
+  move_fifo_slots<kNumRanks>(head);
+  __syncthreads();
+
+  // Copy and clean
+  auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
+  auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
+  #pragma unroll
+  for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)
+    ptr[i] = rank_prefix_matrix[i];
+  #pragma unroll
+  for (int i = thread_id; i < num_memset_int; i += num_threads)
+    ptr[kNumRanks * kNumRanks + i] = 0;
+  memory_fence();
+  __syncthreads();
+
+  // Barrier after cleaning
+  barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
+}
+
+void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
+                            void** buffer_ptrs, int** task_fifo_ptrs,
+                            int head, int rank, int num_ranks, cudaStream_t stream) {
+#define CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
+  LAUNCH_KERNEL(&cfg, cached_notify_dispatch<ranks>, \
+    rank_prefix_matrix, num_memset_int, buffer_ptrs, task_fifo_ptrs, head, rank); \
+  break
+
+  SETUP_LAUNCH_CONFIG(1, 128, stream);
+  SWITCH_RANKS(CACHED_NOTIFY_DISPATCH_LAUNCH_CASE);
+#undef CACHED_NOTIFY_DISPATCH_LAUNCH_CASE
+}
+
+template <int kNumRanks, int kNumThreads>
+__global__ void __launch_bounds__(kNumThreads, 1)
+dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
+         int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
+         const bool* is_token_in_rank, const int* channel_prefix_matrix,
+         int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
+         void **buffer_ptrs, int rank,
+         int num_max_send_tokens, int num_recv_buffer_tokens) {
+  const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
+  const auto thread_id = static_cast<int>(threadIdx.x);
+  const bool is_sender = sm_id % 2 == 0;
+  EP_DEVICE_ASSERT(num_sms % 2 == 0);
+
+  // Several warps are responsible for a single rank
+  const auto num_threads_per_rank = kNumThreads / kNumRanks;
+  const auto num_channels = num_sms / 2;
+  const auto responsible_rank = (static_cast<int>(thread_id)) / num_threads_per_rank;
+  // Even-numbered blocks for sending, odd-numbered blocks for receiving.
+  const auto responsible_channel = sm_id / 2;
+
+  int num_experts_per_rank = num_experts / kNumRanks;
+  EP_DEVICE_ASSERT(num_experts_per_rank > 0 or num_topk == 0);
+  EP_DEVICE_ASSERT(num_topk <= 32);
+  EP_DEVICE_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
+  EP_DEVICE_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));
+
+  // Calculate pointers by the specific layout
+  // `rank_prefix_matrix`: kNumRanks * kNumRanks * sizeof(int)
+  auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[is_sender ? responsible_rank : rank]) + kNumRanks * kNumRanks * sizeof(int));
+  int target_rank = is_sender ? 
rank : responsible_rank; + auto num_channels_total = num_channels * kNumRanks; + auto channel_rank_offset = responsible_channel * kNumRanks + target_rank; + + // Channel buffer metadata + // Senders are responsible for tails, and receivers are responsible for heads + // Stored on the receiver side + // The retired signals are actually boolean flags, but to align with 16 bytes, we make it `int64_t` + // `start_offset`: kNumChannels * kNumRanks * sizeof(int) + // `end_offset`: kNumChannels * kNumRanks * sizeof(int) + // `head_idx`: kNumChannels * kNumRanks * sizeof(int) + // `tail_idx`: kNumChannels * kNumRanks * sizeof(int) + auto channel_start_offset = Buffer(ptr, num_channels_total, channel_rank_offset); + auto channel_end_offset = Buffer(ptr, num_channels_total, channel_rank_offset); + auto channel_head_idx = Buffer(ptr, num_channels_total, channel_rank_offset); + auto channel_tail_idx = Buffer(ptr, num_channels_total, channel_rank_offset); + + // Channel data buffers, stored on the receiver side + // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4) + // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int) + // `topk_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(int64_t) + // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float) + // `x_scales_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_scales * sizeof(float) + auto channel_x_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4); + auto channel_src_idx_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens); + auto channel_topk_idx_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk); + auto channel_topk_weights_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk); + auto channel_x_scales_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales); + + if (is_sender) { + // Workers for sending + constexpr int num_send_warps = kNumThreads / 32; + constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks; + const auto send_thread_id = thread_id; + const auto send_lane_id = send_thread_id % 32; + const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32; + EP_DEVICE_ASSERT(kNumRanks <= 32); + EP_DEVICE_ASSERT(num_send_warps % kNumRanks == 0); + + // Send offset by `-value - 1`, e.g. 0 -> -1, 1 -> -2 + // NOTES: this is for distinguishing zero tokens + if (send_lane_id == 0 and send_warp_id_in_rank == 0) { + int value = responsible_channel > 0 ? 
channel_prefix_matrix[responsible_rank * num_channels + responsible_channel - 1] : 0; + st_relaxed_sys_global(channel_start_offset.buffer(), -value - 1); + value = channel_prefix_matrix[responsible_rank * num_channels + responsible_channel]; + st_relaxed_sys_global(channel_end_offset.buffer(), -value - 1); + } + __syncwarp(); + + // Get tasks + int token_start_idx, token_end_idx; + get_channel_task_range(num_tokens, num_channels, responsible_channel, token_start_idx, token_end_idx); + + // Iterate over all tokens and send by chunks + int cached_channel_tail_idx = 0; + for (int64_t token_idx = token_start_idx; token_idx < token_end_idx; ) { + // Check destination queue emptiness, or wait a buffer to be released (rare cases) + // NOTES: the head index received by different warps may not be the same + auto start_time = clock64(); + while (send_lane_id == 0) { + // NOTES: we only consider the worst case, because counting the real numbers are time-consuming + int num_used_slots = cached_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer()); + if (num_recv_buffer_tokens - num_used_slots >= num_max_send_tokens) + break; + + // Rare cases to loop again + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP timeout for dispatch senders, rank %d, responsible_channel = %d\n", rank, responsible_channel); + trap(); + } + } + __syncwarp(); + + int chunk_token_idx = 0; + while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) { + // NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send subsequent data + if (send_lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank) + send_head[token_idx * kNumRanks + responsible_rank] = is_token_in_rank[token_idx * kNumRanks + responsible_rank] ? cached_channel_tail_idx : -1; + + // Skip if not selected + if (not is_token_in_rank[token_idx * kNumRanks + responsible_rank]) { + token_idx ++; + continue; + } + + // Get an empty slot + int dst_slot_idx = (cached_channel_tail_idx ++) % num_recv_buffer_tokens; + if (cached_channel_tail_idx % num_send_warps_per_rank == send_warp_id_in_rank) { + // Copy data + auto shifted_channel_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4; + auto shifted_x = x + token_idx * hidden_int4; + UNROLLED_WARP_COPY(5, send_lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x, + __ldg, st_na_global); + + // Copy source index + if (send_lane_id == 0) + channel_src_idx_buffers[dst_slot_idx] = static_cast(token_idx); + + // Copy `topk_idx` and `topk_weights` with transformed index + if (send_lane_id < num_topk) { + // Top-k index + int recv_expert_begin = responsible_rank * num_experts_per_rank, recv_expert_end = (responsible_rank + 1) * num_experts_per_rank; + auto idx_value = __ldg(topk_idx + token_idx * num_topk + send_lane_id); + idx_value = (idx_value >= recv_expert_begin and idx_value < recv_expert_end) ? idx_value - recv_expert_begin : -1; + channel_topk_idx_buffers[dst_slot_idx * num_topk + send_lane_id] = idx_value; + + // Top-k weights + auto weight_value = __ldg(topk_weights + token_idx * num_topk + send_lane_id); + weight_value = (idx_value >= 0) ? 
weight_value : 0.0f; + channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = weight_value; + } + + // Copy `x_scales` + #pragma unroll + for (int i = send_lane_id; i < num_scales; i += 32) + channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i); + } + + // Move token index + chunk_token_idx ++, token_idx ++; + } + + // Move tail index + // NOTES: here all warps should share the same new tail + asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank)); + if (send_warp_id_in_rank == 0 and send_lane_id == 0) + st_release_sys_global(channel_tail_idx.buffer(), cached_channel_tail_idx); + } + } else { + // Workers for receiving and copying into buffer + constexpr int num_recv_warps = kNumThreads / 32; + constexpr int num_recv_warps_per_rank = num_recv_warps / kNumRanks; + const auto recv_thread_id = thread_id; + const auto recv_lane_id = recv_thread_id % 32; + const auto recv_thread_id_in_rank = recv_thread_id % num_threads_per_rank; + const auto recv_warp_id_in_rank = recv_thread_id_in_rank / 32; + EP_DEVICE_ASSERT(kNumRanks <= 32); + EP_DEVICE_ASSERT(recv_thread_id >= 0 and num_recv_warps % kNumRanks == 0); + + // Calculate offset first + auto rank_prefix_matrix = reinterpret_cast(buffer_ptrs[rank]); + int rank_offset = responsible_rank > 0 ? rank_prefix_matrix[(responsible_rank - 1) * kNumRanks + rank] : 0; + + // Receive channel offset + int total_offset, num_tokens_to_recv; + while (recv_lane_id == 0 and (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0); + while (recv_lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0); + if (recv_lane_id == 0) { + total_offset = -total_offset - 1, num_tokens_to_recv = -num_tokens_to_recv - 1; + if (recv_warp_id_in_rank == 0) + recv_channel_offset[responsible_rank * num_channels + responsible_channel] = total_offset; + num_tokens_to_recv -= total_offset; + } + total_offset = __shfl_sync(0xffffffff, total_offset, 0); + total_offset += rank_offset; + num_tokens_to_recv = __shfl_sync(0xffffffff, num_tokens_to_recv, 0); + + // Shared tail indices for different warps + __shared__ volatile int shared_channel_tail_idx[kNumRanks]; + + auto start_time = clock64(); + int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; + while (num_tokens_to_recv > 0) { + // NOTES: unlike the sender, the receiver must ensure that the tail indices hold by different warps are same + while (recv_thread_id_in_rank == 0) { + cached_channel_tail_idx = ld_acquire_sys_global(channel_tail_idx.buffer());; + + // Ready to copy + if (cached_channel_head_idx != cached_channel_tail_idx) { + shared_channel_tail_idx[responsible_rank] = cached_channel_tail_idx; + break; + } + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP timeout for dispatch receivers, rank %d, responsible_channel = %d, tokens remained: %d\n", rank, responsible_channel, num_tokens_to_recv); + trap(); + } + } + + // Synchronize queue tail + asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank)); + cached_channel_tail_idx = shared_channel_tail_idx[responsible_rank]; + + // Copy data + int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx; + for (int chunk_idx = recv_warp_id_in_rank; chunk_idx < num_recv_tokens; chunk_idx += num_recv_warps_per_rank) { + int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens; + auto shifted_buffer_x_int4 = 
channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4; + auto shifted_recv_x_int4 = recv_x + static_cast(total_offset + chunk_idx) * hidden_int4; + UNROLLED_WARP_COPY(5, recv_lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4, + ld_nc_global, st_na_global); + } + + // Copy `src_idx` + #pragma unroll 4 + for (int chunk_idx = cached_channel_head_idx + recv_thread_id_in_rank; chunk_idx < cached_channel_tail_idx; chunk_idx += 32 * num_recv_warps_per_rank) + recv_src_idx[total_offset + chunk_idx - cached_channel_head_idx] = ld_nc_global(channel_src_idx_buffers.buffer() + chunk_idx % num_recv_buffer_tokens); + + // Copy `topk_idx` and `topk_weights` + #pragma unroll 4 + for (int idx = recv_thread_id_in_rank; idx < num_recv_tokens * num_topk; idx += 32 * num_recv_warps_per_rank) { + int chunk_idx = idx / num_topk, token_topk_idx = idx % num_topk; + int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens; + auto recv_idx = static_cast(total_offset + chunk_idx) * num_topk + token_topk_idx; + auto buffer_idx = token_idx_in_buffer * num_topk + token_topk_idx; + recv_topk_idx[recv_idx] = ld_nc_global(channel_topk_idx_buffers.buffer() + buffer_idx); + recv_topk_weights[recv_idx] = ld_nc_global(channel_topk_weights_buffers.buffer() + buffer_idx); + } + + // Copy `x_scales` + #pragma unroll 4 + for (int i = recv_thread_id_in_rank; i < num_recv_tokens * num_scales; i += 32 * num_recv_warps_per_rank) { + int chunk_idx = i / num_scales, scales_idx = i % num_scales; + int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens; + recv_x_scales[static_cast(total_offset + chunk_idx) * num_scales + scales_idx] = + ld_nc_global(channel_x_scales_buffers.buffer() + token_idx_in_buffer * num_scales + scales_idx); + } + + // Move queue + cached_channel_head_idx += num_recv_tokens; + total_offset += num_recv_tokens; + asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank)); + if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and recv_lane_id == 0) + st_relaxed_sys_global(channel_head_idx.buffer(), cached_channel_head_idx); + + // Exit + num_tokens_to_recv -= num_recv_tokens; + } + } +} + +void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset, + int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, + const bool* is_token_in_rank, const int* channel_prefix_matrix, + int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales, + void** buffer_ptrs, int rank, int num_ranks, + cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) { + constexpr int kNumThreads = 512; + +#define DISPATCH_LAUNCH_CASE(ranks) \ +LAUNCH_KERNEL(&cfg, dispatch, \ + reinterpret_cast(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \ + send_head, reinterpret_cast(x), x_scales, topk_idx, topk_weights, \ + is_token_in_rank, channel_prefix_matrix, \ + num_tokens, hidden_int4, num_topk, num_experts, num_scales, \ + buffer_ptrs, rank, \ + num_max_send_tokens, num_recv_buffer_tokens); \ +break + + // Even-numbered blocks for sending, odd-numbered blocks for receiving. 
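+  // (So `num_channels = num_sms / 2`: channel `c` maps to block `2c` as the
+  // sender and block `2c + 1` as the receiver, and `get_channel_task_range`
+  // statically partitions the token range across channels. For example, with
+  // `num_sms = 8` there are 4 channels, each owning a contiguous quarter of
+  // the tokens.)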
+  EP_HOST_ASSERT(num_sms % 2 == 0);
+  SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
+  SWITCH_RANKS(DISPATCH_LAUNCH_CASE);
+#undef DISPATCH_LAUNCH_CASE
+}
+
+template <int kNumRanks>
+__global__ void
+cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
+                      int** task_fifo_ptrs, int head, int rank) {
+  const auto sm_id = static_cast<int>(blockIdx.x);
+  if (sm_id == 0) {
+    // Barrier before cleaning
+    barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
+    move_fifo_slots<kNumRanks>(head);
+    __syncthreads();
+
+    // Clean
+    auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
+    auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
+    #pragma unroll
+    for (int i = thread_id; i < num_memset_int; i += num_threads)
+      ptr[i] = 0;
+    memory_fence();
+    __syncthreads();
+
+    // Barrier after cleaning
+    barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
+  } else {
+    const auto channel_id = sm_id - 1;
+    const auto thread_id = static_cast<int>(threadIdx.x);
+    const auto rank_id = thread_id / 32;
+    const auto lane_id = thread_id % 32;
+    if (rank_id >= kNumRanks)
+      return;
+
+    int token_start_idx, token_end_idx;
+    get_channel_task_range(num_recv_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
+
+    // NOTES: `1 << 25` is a heuristic large number
+    int last_head = 1 << 25;
+    #pragma unroll
+    for (int token_idx_tail = token_end_idx - 1; token_idx_tail >= token_start_idx; token_idx_tail -= 32) {
+      int token_idx = token_idx_tail - lane_id, expected_head = 0;
+      auto current_head = (token_idx >= token_start_idx) ? __ldg(send_head + token_idx * kNumRanks + rank_id) : -1;
+      for (int i = 0; i < min(32, token_idx_tail - token_start_idx + 1); ++ i) {
+        head = __shfl_sync(0xffffffff, current_head, i);
+        if (head < 0) {
+          if (lane_id == i)
+            expected_head = -last_head - 1;
+        } else {
+          last_head = head;
+        }
+      }
+      if (current_head < 0 and token_idx >= token_start_idx)
+        send_head[token_idx * kNumRanks + rank_id] = expected_head;
+    }
+  }
+}
+
+void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
+                           int num_recv_tokens, int num_memset_int,
+                           int** task_fifo_ptrs, int head, int rank, int num_ranks,
+                           cudaStream_t stream) {
+#define CACHED_NOTIFY_COMBINE(ranks) \
+  LAUNCH_KERNEL(&cfg, cached_notify_combine<ranks>, \
+    buffer_ptrs, send_head, num_channels, num_recv_tokens, num_memset_int, task_fifo_ptrs, head, rank); \
+  break
+
+  const int num_threads = std::max(128, 32 * num_ranks);
+  EP_HOST_ASSERT(num_ranks <= num_threads);
+  EP_HOST_ASSERT(num_threads <= 1024);
+  EP_HOST_ASSERT(1 + num_channels <= num_channels * 2);
+  SETUP_LAUNCH_CONFIG(1 + num_channels, num_threads, stream);
+  SWITCH_RANKS(CACHED_NOTIFY_COMBINE);
+#undef CACHED_NOTIFY_COMBINE
+}
+
+template <typename dtype_t, int kNumRanks, int kNumThreads>
+__global__ void __launch_bounds__(kNumThreads, 1)
+combine(dtype_t* recv_x, float* recv_topk_weights,
+        const dtype_t* x, const float* topk_weights,
+        const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
+        int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
+        void **buffer_ptrs, int rank,
+        int num_max_send_tokens, int num_recv_buffer_tokens) {
+  const auto num_sms = static_cast<int>(gridDim.x);
+  const auto thread_id = static_cast<int>(threadIdx.x);
+  const auto sm_id = static_cast<int>(blockIdx.x);
+  const auto num_channels = num_sms / 2;
+  const bool is_sender = sm_id % 2 == 0;
+  const int responsible_channel = sm_id / 2;
+  EP_DEVICE_ASSERT(num_topk <= 32);
+
+  constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
+  int hidden_int4 = hidden 
* sizeof(dtype_t) / sizeof(int4); + auto x_int4 = reinterpret_cast(x); + auto recv_int4 = reinterpret_cast(recv_x); + + if (is_sender) { + // Workers for sending + // Several warps are responsible for a single rank + constexpr int num_send_warps = kNumThreads / 32; + constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks; + const auto num_threads_per_rank = num_send_warps_per_rank * 32; + const auto send_thread_id = thread_id; + const auto send_lane_id = send_thread_id % 32; + const auto send_rank_id = thread_id / num_threads_per_rank; + const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32; + + // Calculate pointers by the specific layout + auto ptr = reinterpret_cast(reinterpret_cast(buffer_ptrs[send_rank_id])); + auto num_channels_total = num_channels * kNumRanks; + auto channel_rank_offset = responsible_channel * kNumRanks + rank; + + // Channel meta data + // `head_idx`: kNumChannels * kNumRanks * sizeof(int) + // `tail_idx`: kNumChannels * kNumRanks * sizeof(int) + // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4) + // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int) + // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float) + auto channel_head_idx = Buffer(ptr, num_channels_total, channel_rank_offset); + auto channel_tail_idx = Buffer(ptr, num_channels_total, channel_rank_offset); + auto channel_x_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4); + auto channel_src_idx_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens); + auto channel_topk_weights_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk); + + // Get tasks + // NOTES: `channel_offset` is already shifted + int rank_offset = send_rank_id > 0 ? rank_prefix_matrix[(send_rank_id - 1) * kNumRanks + rank] : 0; + int num_rank_tokens = rank_prefix_matrix[send_rank_id * kNumRanks + rank] - rank_offset; + int channel_offset = channel_prefix_matrix[send_rank_id * num_channels + responsible_channel]; + int num_channel_tokens = (responsible_channel == num_channels - 1 ? 
+
+        // Get tasks
+        // NOTES: `channel_offset` is already shifted
+        int rank_offset = send_rank_id > 0 ? rank_prefix_matrix[(send_rank_id - 1) * kNumRanks + rank] : 0;
+        int num_rank_tokens = rank_prefix_matrix[send_rank_id * kNumRanks + rank] - rank_offset;
+        int channel_offset = channel_prefix_matrix[send_rank_id * num_channels + responsible_channel];
+        int num_channel_tokens = (responsible_channel == num_channels - 1 ?
+                                  num_rank_tokens : channel_prefix_matrix[send_rank_id * num_channels + responsible_channel + 1]) - channel_offset;
+        int token_start_idx = rank_offset + channel_offset, token_end_idx = rank_offset + channel_offset + num_channel_tokens;
+
+        // Iterate over all tokens and send by chunks
+        int current_channel_tail_idx = 0;
+        for (int64_t token_idx = token_start_idx; token_idx < token_end_idx; ) {
+            // Check destination queue emptiness, or wait for a buffer to be released (rare cases)
+            auto start_time = clock64();
+            int num_round_tokens = min(num_max_send_tokens, token_end_idx - static_cast<int>(token_idx));
+            while (send_lane_id == 0) {
+                // NOTES: we only consider the worst case, because counting the real number is time-consuming
+                int num_used_slots = current_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
+                if (num_recv_buffer_tokens - num_used_slots >= num_round_tokens)
+                    break;
+
+                // Rare cases to loop again
+                if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
+                    printf("DeepEP timeout for combine senders, rank %d, responsible_channel = %d\n", rank, responsible_channel);
+                    trap();
+                }
+            }
+            __syncwarp();
+
+            // Send by chunk
+            #pragma unroll
+            for (int i = send_warp_id_in_rank; i < num_round_tokens; i += num_send_warps_per_rank) {
+                // Get an empty slot
+                int dst_slot_idx = (current_channel_tail_idx + i) % num_recv_buffer_tokens;
+
+                // Copy data
+                auto shifted_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
+                auto shifted_x = x_int4 + (token_idx + i) * hidden_int4;
+                UNROLLED_WARP_COPY(4, send_lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
+
+                // Send source index
+                if (send_lane_id == 0)
+                    channel_src_idx_buffers[dst_slot_idx] = __ldg(src_idx + token_idx + i);
+
+                // Send `topk_weights`
+                if (num_topk > 0 and send_lane_id < num_topk)
+                    channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = __ldg(topk_weights + (token_idx + i) * num_topk + send_lane_id);
+            }
+            token_idx += num_round_tokens;
+            current_channel_tail_idx += num_round_tokens;
+
+            // Move tail index
+            asm volatile("bar.sync %0, %1;" :: "r"(send_rank_id), "r"(num_threads_per_rank));
+            if (send_lane_id == 0 and send_warp_id_in_rank == 0)
+                st_release_sys_global(channel_tail_idx.buffer(), current_channel_tail_idx);
+        }
+    } else {
+        // Workers for receiving
+        // One warp for moving the queue head, others for reduction
+        constexpr int num_recv_warps = kNumThreads / 32;
+        const auto recv_warp_id = thread_id / 32;
+        const auto recv_lane_id = thread_id % 32;
+        EP_DEVICE_ASSERT(kNumRanks <= 32 and kNumThreads > 32);
+        EP_DEVICE_ASSERT(thread_id >= 0 and kNumThreads % 32 == 0);
+
+        // Shared head, tail and retired flags for receiver warps
+        __shared__ volatile int warp_channel_head_idx[num_recv_warps][kNumRanks];
+        __shared__ volatile int channel_tail_idx[kNumRanks];
+        __shared__ volatile bool warp_retired[num_recv_warps];
+        if (thread_id < num_recv_warps)
+            warp_retired[thread_id] = false;
+        if (recv_lane_id < kNumRanks)
+            warp_channel_head_idx[recv_warp_id][recv_lane_id] = 0;
+        if (thread_id < kNumRanks)
+            channel_tail_idx[thread_id] = 0;
+        asm volatile("bar.sync 0, %0;" :: "r"(kNumThreads));
+
+        if (thread_id < 32) {
+            int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + recv_lane_id;
+            int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks;
+
+            // Queue head updater
+            int last_head = 0;
+            while (recv_lane_id < kNumRanks) {
+                // Check retired
+                bool retired = true;
+                #pragma unroll
+                for (int i = 1; i < num_recv_warps; ++ i)
+                    retired = retired and warp_retired[i];
+                if (retired)
+                    break;
+
+                // Update queue tail
+                channel_tail_idx[recv_lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr);
+
+                // Update minimum head
+                int min_head = std::numeric_limits<int>::max();
+                #pragma unroll
+                for (int i = 1; i < num_recv_warps; ++ i) if (not warp_retired[i])
+                    min_head = min(min_head, warp_channel_head_idx[i][recv_lane_id]);
+                if (min_head != std::numeric_limits<int>::max() and min_head > last_head)
+                    st_relaxed_sys_global(channel_head_idx_ptr, last_head = min_head);
+            }
+        } else {
+            // Receivers
+            // Channel metadata
+            // All lanes will use the data buffers, but only the first kNumRanks lanes will use `head/tail/src_idx`
+            Buffer<int4> channel_x_buffers[kNumRanks];
+            Buffer<float> channel_topk_weights_buffers[kNumRanks];
+
+            // Calculate pointers by the specific layout
+            #pragma unroll
+            for (int i = 0; i < kNumRanks; ++ i) {
+                auto channel_rank_offset = responsible_channel * kNumRanks + i;
+                auto num_channels_total = num_channels * kNumRanks;
+                // `head_idx` & `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
+                auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[rank]) + 2 * num_channels * kNumRanks * sizeof(int));
+
+                // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
+                channel_x_buffers[i] = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
+
+                // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
+                ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(ptr) + num_channels_total * num_recv_buffer_tokens * sizeof(int));
+
+                // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
+                channel_topk_weights_buffers[i] = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
+            }
+
+            // The same tokens as the dispatch process
+            int token_start_idx, token_end_idx;
+            get_channel_task_range(num_recv_tokens, num_channels, responsible_channel, token_start_idx, token_end_idx);
+
+            // Iterate over all tokens and combine
+            for (int64_t token_idx = token_start_idx + recv_warp_id - 1; token_idx < token_end_idx; token_idx += num_recv_warps - 1) {
+                // Read expected head
+                int expected_head = -1;
+                if (recv_lane_id < kNumRanks)
+                    expected_head = ld_nc_global(send_head + token_idx * kNumRanks + recv_lane_id);
+
+                auto start_time = clock64();
+                while (channel_tail_idx[recv_lane_id] <= expected_head and expected_head >= 0) {
+                    // Timeout check
+                    if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
+                        printf("DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\n", rank, responsible_channel, expected_head);
+                        trap();
+                    }
+                }
+                __syncwarp();
+
+                // Broadcast current heads
+                int num_topk_ranks = 0, topk_ranks[kNumRanks], slot_indices[kNumRanks];
+                #pragma unroll
+                for (int i = 0; i < kNumRanks; ++ i) {
+                    auto expected_head_i = __shfl_sync(0xffffffff, expected_head, i);
+                    if (expected_head_i >= 0) {
+                        slot_indices[num_topk_ranks] = expected_head_i % num_recv_buffer_tokens;
+                        topk_ranks[num_topk_ranks ++] = i;
+                    }
+                }
+
+                // Reduce data
+                #pragma unroll
+                for (int i = recv_lane_id; i < hidden_int4; i += 32) {
+                    // Read buffers
+                    int4 recv_value_int4[kNumRanks];
+                    #pragma unroll
+                    for (int j = 0; j < num_topk_ranks; ++ j)
+                        recv_value_int4[j] = ld_nc_global(channel_x_buffers[topk_ranks[j]].buffer() + slot_indices[j] * hidden_int4 + i);
+
+                    // Reduce all-to-all results
+                    float values[kDtypePerInt4] = {0};
+                    #pragma unroll
+                    for (int j = 0; j < num_topk_ranks; ++ j) {
+                        auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
+                        #pragma unroll
+                        for (int k = 0; k < kDtypePerInt4; ++ k)
+                            values[k] += static_cast<float>(recv_value_dtypes[k]);
+                    }
+
+                    // Cast back to `dtype_t` and write
+                    int4 out_int4;
+                    auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
+                    #pragma unroll
+                    for (int j = 0; j < kDtypePerInt4; ++ j)
+                        out_dtypes[j] = static_cast<dtype_t>(values[j]);
+                    recv_int4[token_idx * hidden_int4 + i] = out_int4;
+                }
+
+                // Reduce `topk_weights`
+                if (recv_lane_id < num_topk) {
+                    float value = 0;
+                    #pragma unroll
+                    for (int i = 0; i < num_topk_ranks; ++ i)
+                        value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + recv_lane_id);
+                    recv_topk_weights[token_idx * num_topk + recv_lane_id] = value;
+                }
+
+                // Update head
+                if (recv_lane_id < kNumRanks)
+                    warp_channel_head_idx[recv_warp_id][recv_lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
+            }
+
+            // Retired
+            __syncwarp();
+            if (recv_lane_id == 0)
+                warp_retired[recv_warp_id] = true;
+        }
+    }
+}
+
+void combine(cudaDataType_t type,
+             void* recv_x, float* recv_topk_weights,
+             const void* x, const float* topk_weights,
+             const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
+             int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
+             void** buffer_ptrs, int rank, int num_ranks,
+             cudaStream_t stream, int num_sms,
+             int num_max_send_tokens, int num_recv_buffer_tokens) {
+    constexpr int kNumThreads = 768;
+
+#define COMBINE_LAUNCH_CASE(dtype, ranks) \
+    LAUNCH_KERNEL(&cfg, (combine<dtype, ranks, kNumThreads>), \
+                  reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
+                  reinterpret_cast<const dtype*>(x), topk_weights, \
+                  src_idx, rank_prefix_matrix, channel_prefix_matrix, \
+                  send_head, num_tokens, num_recv_tokens, hidden, num_topk, \
+                  buffer_ptrs, rank, \
+                  num_max_send_tokens, num_recv_buffer_tokens); \
+    break
+#define COMBINE_DTYPE_LAUNCH_CASE(dtype) SWITCH_RANKS_WITH_DTYPE(dtype, COMBINE_LAUNCH_CASE); break
+
+    // Even-numbered blocks for sending, odd-numbered blocks for receiving
+    EP_HOST_ASSERT(num_sms % 2 == 0);
+    EP_HOST_ASSERT(kNumThreads >= num_ranks * 32);
+    SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
+    SWITCH_TYPES(COMBINE_DTYPE_LAUNCH_CASE);
+#undef COMBINE_DTYPE_LAUNCH_CASE
+#undef COMBINE_LAUNCH_CASE
+}
+
+} // namespace intranode
+
+} // namespace ep
+} // namespace mscclpp
diff --git a/src/ext/ep/kernels/launch.cuh b/src/ext/ep/kernels/launch.cuh
new file mode 100644
index 00000000..3d7574ce
--- /dev/null
+++ b/src/ext/ep/kernels/launch.cuh
@@ -0,0 +1,61 @@
+#pragma once
+
+#include "configs.cuh"
+
+#ifndef SETUP_LAUNCH_CONFIG
+#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \
+    cudaLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream, nullptr, 0}; \
+    cudaLaunchAttribute attr[1]; \
+    attr[0].id = cudaLaunchAttributeCooperative; \
+    attr[0].val.cooperative = 1; \
+    cfg.attrs = attr; \
+    cfg.numAttrs = 1
+#endif
+
+#ifndef LAUNCH_KERNEL
+#define LAUNCH_KERNEL(config, kernel, ...) \
+    CUDA_CHECK(cudaLaunchKernelEx(config, kernel, ##__VA_ARGS__))
+#endif
+
+#define SWITCH_RANKS(case_macro) \
+    switch (num_ranks) { \
+        case 2: case_macro(2); \
+        case 4: case_macro(4); \
+        case 8: case_macro(8); \
+        default: EP_HOST_ASSERT(false and "Unsupported ranks"); \
+    } while (false)
+
+#define SWITCH_RDMA_RANKS(case_macro) \
+    switch (num_ranks / NUM_MAX_NVL_PEERS) { \
+        case 2: case_macro(2); \
+        case 3: case_macro(3); \
+        case 4: case_macro(4); \
+        case 8: case_macro(8); \
+        case 16: case_macro(16); \
+        case 18: case_macro(18); \
+        case 20: case_macro(20); \
+        default: EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
+    } while (false)
+
+#define SWITCH_RANKS_WITH_DTYPE(dtype, case_macro) \
+    switch (num_ranks) { \
+        case 2: case_macro(dtype, 2); \
+        case 4: case_macro(dtype, 4); \
+        case 8: case_macro(dtype, 8); \
+        default: EP_HOST_ASSERT(false && "Unsupported ranks"); \
+    } while (false)
+
+#define SWITCH_TYPES(case_macro) \
+    switch (type) { \
+        case CUDA_R_16BF: case_macro(nv_bfloat16); \
+        case CUDA_R_32F: case_macro(float); \
+        default: EP_HOST_ASSERT(false && "Unsupported type"); \
+    } while (false)
+
+#define SWITCH_HIDDEN(case_macro) \
+    switch (hidden) { \
+        case 2560: case_macro(2560); \
+        case 4096: case_macro(4096); \
+        case 5120: case_macro(5120); \
+        case 7168: case_macro(7168); \
+        default: EP_HOST_ASSERT(false && "Unsupported hidden"); \
+    } while (false)
diff --git a/src/ext/ep/kernels/runtime.cu b/src/ext/ep/kernels/runtime.cu
new file mode 100644
index 00000000..4526fac1
--- /dev/null
+++ b/src/ext/ep/kernels/runtime.cu
@@ -0,0 +1,38 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+//
+// Portions adapted from DeepEP (https://github.com/deepseek-ai/DeepEP),
+// branch `chhwang/dev-atomic-add-cleanup`. Licensed under the MIT License.
+//
+// Intranode runtime helpers. Only the NVLink barrier launcher is ported here
+// (see DeepEP `csrc/kernels/runtime.cu::intranode::barrier`). The
+// internode/NVSHMEM init helpers are deliberately omitted; the MSCCL++ port
+// uses `mscclpp::Bootstrap`/`ProxyService` instead of NVSHMEM.
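+//
+// Illustrative host-side usage (annotation added in this port, not DeepEP
+// text): with `num_ranks` NVLink peers sharing a task fifo at offset `head`,
+//
+//     intranode::barrier(task_fifo_ptrs, head, rank, num_ranks, stream);
+//     head = (head + num_ranks) % NUM_MAX_FIFO_SLOTS;  // caller advances the fifo
+//
+// which mirrors what `move_fifo_slots<kNumRanks>` in `utils.cuh` does on the
+// device side.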
+ +#include "configs.cuh" +#include "exception.cuh" +#include "launch.cuh" +#include "utils.cuh" + +namespace mscclpp { +namespace ep { +namespace intranode { + +template +__global__ void barrier(int** task_fifo_ptrs, int head, int rank) { + barrier_device(task_fifo_ptrs, head, rank); +} + +void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) { +#define BARRIER_LAUNCH_CASE(ranks) \ + LAUNCH_KERNEL(&cfg, barrier, task_fifo_ptrs, head, rank); \ + break + + SETUP_LAUNCH_CONFIG(1, 32, stream); + SWITCH_RANKS(BARRIER_LAUNCH_CASE); +#undef BARRIER_LAUNCH_CASE +} + +} // namespace intranode +} // namespace ep +} // namespace mscclpp diff --git a/src/ext/ep/kernels/utils.cuh b/src/ext/ep/kernels/utils.cuh new file mode 100644 index 00000000..1a2512a0 --- /dev/null +++ b/src/ext/ep/kernels/utils.cuh @@ -0,0 +1,379 @@ +#pragma once + +#include "exception.cuh" + +#define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \ +{ \ + constexpr int kLoopStride = 32 * (UNROLL_FACTOR); \ + typename std::remove_reference::type unrolled_values[(UNROLL_FACTOR)]; \ + auto __src = (SRC); \ + auto __dst = (DST); \ + for (int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) { \ + _Pragma("unroll") \ + for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \ + unrolled_values[__j] = LD_FUNC(__src + __i + __j * 32); \ + _Pragma("unroll") \ + for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \ + ST_FUNC(__dst + __i + __j * 32, unrolled_values[__j]); \ + } \ + for (int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID); __i < (N); __i += 32) \ + ST_FUNC(__dst + __i, LD_FUNC(__src + __i)); \ +} + +namespace mscclpp { namespace ep { + +template +struct VecInt {}; +template<> struct VecInt<1> { using vec_t = int8_t; }; +template<> struct VecInt<2> { using vec_t = int16_t; }; +template<> struct VecInt<4> { using vec_t = int; }; +template<> struct VecInt<8> { using vec_t = int64_t; }; +template<> struct VecInt<16> { using vec_t = int4; }; + +__device__ __forceinline__ void trap() { + asm("trap;"); +} + +__device__ __forceinline__ void memory_fence() { + asm volatile("fence.acq_rel.sys;":: : "memory"); +} + +__device__ __forceinline__ void memory_fence_gpu() { + asm volatile("fence.acq_rel.gpu;":: : "memory"); +} + +__device__ __forceinline__ void memory_fence_cta() { + asm volatile("fence.acq_rel.cta;":: : "memory"); +} + +__device__ __forceinline__ void st_relaxed_sys_global(const int *ptr, int val) { + asm volatile("st.relaxed.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory"); +} + +__device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) { + asm volatile("st.release.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory"); +} + +__device__ __forceinline__ void st_release_cta(const int *ptr, int val) { + asm volatile("st.release.cta.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory"); +} + +__device__ __forceinline__ int ld_acquire_sys_global(const int *ptr) { + int ret; + asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ uint64_t ld_acquire_sys_global(const uint64_t *ptr) { + uint64_t ret; + asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ int ld_acquire_global(const int *ptr) { + int ret; + asm volatile("ld.acquire.gpu.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ int atomic_add_release_sys_global(const int* ptr, int 
+
+namespace mscclpp { namespace ep {
+
+template <int kBytes>
+struct VecInt {};
+template<> struct VecInt<1> { using vec_t = int8_t; };
+template<> struct VecInt<2> { using vec_t = int16_t; };
+template<> struct VecInt<4> { using vec_t = int; };
+template<> struct VecInt<8> { using vec_t = int64_t; };
+template<> struct VecInt<16> { using vec_t = int4; };
+
+__device__ __forceinline__ void trap() {
+    asm("trap;");
+}
+
+__device__ __forceinline__ void memory_fence() {
+    asm volatile("fence.acq_rel.sys;":: : "memory");
+}
+
+__device__ __forceinline__ void memory_fence_gpu() {
+    asm volatile("fence.acq_rel.gpu;":: : "memory");
+}
+
+__device__ __forceinline__ void memory_fence_cta() {
+    asm volatile("fence.acq_rel.cta;":: : "memory");
+}
+
+__device__ __forceinline__ void st_relaxed_sys_global(const int *ptr, int val) {
+    asm volatile("st.relaxed.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
+}
+
+__device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) {
+    asm volatile("st.release.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
+}
+
+__device__ __forceinline__ void st_release_cta(const int *ptr, int val) {
+    asm volatile("st.release.cta.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
+}
+
+__device__ __forceinline__ int ld_acquire_sys_global(const int *ptr) {
+    int ret;
+    asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
+    return ret;
+}
+
+__device__ __forceinline__ uint64_t ld_acquire_sys_global(const uint64_t *ptr) {
+    uint64_t ret;
+    asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
+    return ret;
+}
+
+__device__ __forceinline__ int ld_acquire_global(const int *ptr) {
+    int ret;
+    asm volatile("ld.acquire.gpu.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
+    return ret;
+}
+
+__device__ __forceinline__ int atomic_add_release_sys_global(const int* ptr, int value) {
+    int ret;
+    asm volatile("atom.add.release.sys.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value));
+    return ret;
+}
+
+__device__ __forceinline__ int atomic_add_release_global(const int* ptr, int value) {
+    int ret;
+    asm volatile("atom.add.release.gpu.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value));
+    return ret;
+}
+
+__device__ __forceinline__ int ld_acquire_cta(const int *ptr) {
+    int ret;
+    asm volatile("ld.acquire.cta.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
+    return ret;
+}
+
+__device__ __forceinline__ uint8_t ld_na_relaxed(const uint8_t *ptr) {
+    uint16_t ret;
+    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b8 %0, [%1];" : "=h"(ret) : "l"(ptr));
+    return static_cast<uint8_t>(ret);
+}
+
+__device__ __forceinline__ uint16_t ld_na_relaxed(const uint16_t *ptr) {
+    uint16_t ret;
+    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b16 %0, [%1];" : "=h"(ret) : "l"(ptr));
+    return ret;
+}
+
+__device__ __forceinline__ uint32_t ld_na_relaxed(const uint32_t *ptr) {
+    uint32_t ret;
+    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b32 %0, [%1];" : "=r"(ret) : "l"(ptr));
+    return ret;
+}
+
+__device__ __forceinline__ uint64_t ld_na_relaxed(const uint64_t *ptr) {
+    uint64_t ret;
+    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b64 %0, [%1];" : "=l"(ret) : "l"(ptr));
+    return ret;
+}
+
+__device__ __forceinline__ int ld_volatile_global(const int *ptr) {
+    int ret;
+    asm volatile("ld.volatile.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
+    return ret;
+}
+
+__device__ __forceinline__ float ld_volatile_global(const float *ptr) {
+    float ret;
+    asm volatile("ld.volatile.global.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
+    return ret;
+}
+
+__device__ __forceinline__ int64_t ld_volatile_global(const int64_t *ptr) {
+    int64_t ret;
+    asm volatile("ld.volatile.global.s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
+    return ret;
+}
+
+__device__ __forceinline__ int64_t ld_volatile_global(const uint64_t *ptr) {
+    int64_t ret;
+    asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
+    return ret;
+}
+
+#ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
+#define LD_NC_FUNC "ld.global.nc.L1::no_allocate.L2::256B"
+#else
+#define LD_NC_FUNC "ld.volatile.global"
+#endif
+
+// `ld.global.nc.L1::no_allocate` will be translated into `LDG.E.NA.[width].CONSTANT` in SASS
+template <typename dtype_t>
+__device__ __forceinline__ dtype_t ld_nc_global(const dtype_t *ptr) {
+    auto ret = ld_nc_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr));
+    return *reinterpret_cast<dtype_t*>(&ret);
+}
+
+template <>
+__device__ __forceinline__ uint8_t ld_nc_global(const uint8_t *ptr) {
+    uint16_t ret;
+    // NOTES: we must use `uint16_t` as inline ASM does not support 8-bit constraint letter (`h` below means unsigned 16-bit)
+    asm volatile(LD_NC_FUNC ".u8 %0, [%1];" : "=h"(ret) : "l"(ptr));
+    return static_cast<uint8_t>(ret);
+}
+
+template <>
+__device__ __forceinline__ int ld_nc_global(const int *ptr) {
+    int ret;
+    asm volatile(LD_NC_FUNC ".s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
+    return ret;
+}
+
+template <>
+__device__ __forceinline__ int64_t ld_nc_global(const int64_t *ptr) {
+    int64_t ret;
+    asm volatile(LD_NC_FUNC ".s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
+    return ret;
+}
+
+template <>
+__device__ __forceinline__ float ld_nc_global(const float *ptr) {
+    float ret;
+    asm volatile(LD_NC_FUNC ".f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
+    return ret;
+}
+
+template <>
+__device__ __forceinline__ int2 ld_nc_global(const int2 *ptr) {
+    int2 ret;
+    asm volatile(LD_NC_FUNC ".v2.s32 {%0, %1}, [%2];" : "=r"(ret.x), "=r"(ret.y) : "l"(ptr));
+    return ret;
+}
+
+template <>
+__device__ __forceinline__ int4 ld_nc_global(const int4 *ptr) {
+    int4 ret;
+    asm volatile(LD_NC_FUNC ".v4.s32 {%0, %1, %2, %3}, [%4];"
+                 : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
+    return ret;
+}
+
+__device__ __forceinline__ void st_na_relaxed(const uint8_t *ptr, uint8_t val) {
+    asm volatile("st.relaxed.gpu.global.L1::no_allocate.b8 [%0], %1;" : : "l"(ptr), "h"(static_cast<uint16_t>(val)));
+}
+
+__device__ __forceinline__ void st_na_relaxed(const uint16_t *ptr, uint16_t val) {
+    asm volatile("st.relaxed.gpu.global.L1::no_allocate.b16 [%0], %1;" : : "l"(ptr), "h"(val));
+}
+
+__device__ __forceinline__ void st_na_relaxed(const uint32_t *ptr, uint32_t val) {
+    asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
+}
+
+__device__ __forceinline__ void st_na_relaxed(const int *ptr, int val) {
+    asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
+}
+
+__device__ __forceinline__ void st_na_relaxed(const int4 *ptr, int4 val) {
+    asm volatile("st.relaxed.gpu.global.L1::no_allocate.v4.s32 [%0], {%1, %2, %3, %4};"
+                 : : "l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
+}
+
+__device__ __forceinline__ void st_na_release(const int *ptr, int val) {
+    asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
+}
+
+__device__ __forceinline__ void st_na_release(const uint32_t *ptr, uint32_t val) {
+    asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
+}
+
+__device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val) {
+    asm volatile("st.release.gpu.global.L1::no_allocate.b64 [%0], %1;" : : "l"(ptr), "l"(val));
+}
+
+// `st.global.L1::no_allocate` will be translated into `ST.E.NA.[width]` in SASS
+#ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
+#define ST_NA_FUNC "st.global.L1::no_allocate"
+#else
+#define ST_NA_FUNC "st.global"
+#endif
+
+template <typename dtype_t>
+__device__ __forceinline__ void st_na_global(const dtype_t *ptr, const dtype_t& value) {
+    st_na_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr),
+                 *reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(&value));
+}
+
+template <>
+__device__ __forceinline__ void st_na_global(const int *ptr, const int& value) {
+    asm volatile(ST_NA_FUNC ".s32 [%0], %1;" ::"l"(ptr), "r"(value));
+}
+
+template <>
+__device__ __forceinline__ void st_na_global(const int64_t *ptr, const int64_t& value) {
+    asm volatile(ST_NA_FUNC ".s64 [%0], %1;" ::"l"(ptr), "l"(value));
+}
+
+template <>
+__device__ __forceinline__ void st_na_global(const float *ptr, const float& value) {
+    asm volatile(ST_NA_FUNC ".f32 [%0], %1;" ::"l"(ptr), "f"(value));
+}
+
+template <>
+__device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value) {
+    asm volatile(ST_NA_FUNC ".v4.s32 [%0], {%1, %2, %3, %4};"
+                 ::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
+}
+
+template <typename dtype_t>
+__host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
+    return (a + b - 1) / b;
+}
+
+template <typename dtype_t>
+__host__ __device__ dtype_t align(dtype_t a, dtype_t b) {
+    return cell_div(a, b) * b;
+}
+
+__forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_sms, int sm_id,
+                                                       int& token_start_idx, int& token_end_idx) {
+    int num_tokens_per_sm = cell_div(num_tokens, num_sms);
+    token_start_idx = min(num_tokens_per_sm * sm_id, num_tokens);
+    token_end_idx = min(token_start_idx + num_tokens_per_sm, num_tokens);
+}
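+
+// Worked example (annotation added in this port): num_tokens = 1000 and
+// num_sms = 3 give num_tokens_per_sm = cell_div(1000, 3) = 334, so sm 0
+// covers [0, 334), sm 1 covers [334, 668), and sm 2 the shorter tail
+// [668, 1000).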
+
+template <typename dtype_a_t, typename dtype_b_t>
+__device__ __forceinline__ dtype_b_t pack2(const dtype_a_t& x, const dtype_a_t& y) {
+    EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
+    dtype_b_t packed;
+    auto unpacked_ptr = reinterpret_cast<dtype_a_t*>(&packed);
+    unpacked_ptr[0] = x, unpacked_ptr[1] = y;
+    return packed;
+}
+
+template <typename dtype_a_t, typename dtype_b_t>
+__device__ __forceinline__ void unpack2(const dtype_b_t& packed, dtype_a_t& x, dtype_a_t& y) {
+    EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
+    auto unpacked_ptr = reinterpret_cast<const dtype_a_t*>(&packed);
+    x = unpacked_ptr[0], y = unpacked_ptr[1];
+}
+
+template <typename dtype_t>
+__device__ __forceinline__ dtype_t broadcast(dtype_t& ptr, int src_lane_idx) {
+    EP_STATIC_ASSERT(sizeof(dtype_t) % sizeof(int) == 0, "");
+    auto send_int_values = reinterpret_cast<int*>(&ptr);
+    int recv_int_values[sizeof(dtype_t) / sizeof(int)];
+    #pragma unroll
+    for (int i = 0; i < sizeof(dtype_t) / sizeof(int); ++ i)
+        recv_int_values[i] = __shfl_sync(0xffffffff, send_int_values[i], src_lane_idx);
+    return *reinterpret_cast<dtype_t*>(recv_int_values);
+}
+
+__forceinline__ __device__ int warp_reduce_sum(int value) {
+    value += __shfl_xor_sync(0xffffffff, value, 16);
+    value += __shfl_xor_sync(0xffffffff, value, 8);
+    value += __shfl_xor_sync(0xffffffff, value, 4);
+    value += __shfl_xor_sync(0xffffffff, value, 2);
+    value += __shfl_xor_sync(0xffffffff, value, 1);
+    return value;
+}
+
+__forceinline__ __device__ float half_warp_reduce_max(float value) {
+    auto mask = __activemask();
+    // The mask must be in `{0xffffffff, 0xffff}`
+    value = max(value, __shfl_xor_sync(mask, value, 8));
+    value = max(value, __shfl_xor_sync(mask, value, 4));
+    value = max(value, __shfl_xor_sync(mask, value, 2));
+    value = max(value, __shfl_xor_sync(mask, value, 1));
+    return value;
+}
+
+__forceinline__ __device__ int get_lane_id() {
+    int lane_id;
+    asm("mov.s32 %0, %%laneid;" : "=r"(lane_id));
+    return lane_id;
+}
+
+template <int kNumRanks>
+__forceinline__ __device__ void move_fifo_slots(int& head) {
+    head = (head + kNumRanks) % NUM_MAX_FIFO_SLOTS;
+}
+
+template <int kNumRanks>
+__device__ __forceinline__ bool not_finished(int* task, int expected) {
+    auto result = false;
+    auto lane_id = threadIdx.x % 32;
+    if (lane_id < kNumRanks)
+        result = ld_volatile_global(task + lane_id) != expected;
+    return __any_sync(0xffffffff, result);
+}
+
+template <int kNumRanks>
+__forceinline__ __device__ void
+timeout_check(int** task_fifo_ptrs, int head, int rank, int expected, int tag = 0) {
+    auto start_time = clock64();
+    while (not_finished<kNumRanks>(task_fifo_ptrs[rank] + head, expected)) {
+        if (clock64() - start_time > NUM_TIMEOUT_CYCLES and threadIdx.x == 0) {
+            printf("DeepEP timeout check failed: %d (rank = %d)\n", tag, rank);
+            trap();
+        }
+    }
+}
+
+template <int kNumRanks>
+__forceinline__ __device__ void
+barrier_device(int** task_fifo_ptrs, int head, int rank, int tag = 0) {
+    auto thread_id = static_cast<int>(threadIdx.x);
+    EP_DEVICE_ASSERT(kNumRanks <= 32);
+
+    if (thread_id < kNumRanks) {
+        atomicAdd_system(task_fifo_ptrs[rank] + head + thread_id, FINISHED_SUM_TAG);
+        memory_fence();
+        atomicSub_system(task_fifo_ptrs[thread_id] + head + rank, FINISHED_SUM_TAG);
+    }
+    timeout_check<kNumRanks>(task_fifo_ptrs, head, rank, 0, tag);
+}
+
+} // namespace ep
+} // namespace mscclpp
diff --git a/test/python/ext/ep/test_ep_smoke.py b/test/python/ext/ep/test_ep_smoke.py
new file mode 100644
index 00000000..b300f041
--- /dev/null
+++ b/test/python/ext/ep/test_ep_smoke.py
@@ -0,0 +1,51 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Smoke tests for the EP extension.
+
+These tests only exercise single-rank / pure-Python code paths so they can
+run in CI without multi-GPU resources. Multi-rank dispatch/combine tests
+belong in ``test/python/ext/ep/test_intranode.py`` and are left as TODO
+until the Python frontend is validated on H100.
+
+Run with::
+
+    pytest -xvs test/python/ext/ep/test_ep_smoke.py
+"""
+
+from __future__ import annotations
+
+import pytest
+
+try:
+    import mscclpp_ep_cpp as _cpp  # type: ignore[import-not-found]
+except ImportError:  # pragma: no cover
+    pytest.skip("mscclpp_ep_cpp is not built (set -DMSCCLPP_BUILD_EXT_EP=ON)", allow_module_level=True)
+
+
+def test_config_roundtrip():
+    cfg = _cpp.Config(num_sms=20, num_max_nvl_chunked_send_tokens=6, num_max_nvl_chunked_recv_tokens=256,
+                      num_max_rdma_chunked_send_tokens=6, num_max_rdma_chunked_recv_tokens=256)
+    hint = cfg.get_nvl_buffer_size_hint(7168 * 2, 8)
+    assert hint > 0
+
+
+def test_low_latency_size_hint():
+    assert _cpp.get_low_latency_rdma_size_hint(128, 7168, 8, 256) > 0
+
+
+def test_low_latency_accepted_by_cpp():
+    # The low-latency (pure-RDMA) path is not ported yet. The refusal to
+    # construct with ``low_latency_mode=True`` lives in the Python frontend
+    # (``mscclpp.ext.ep.buffer.Buffer.__init__``), not in the C++ bindings,
+    # so this test calls the C++ constructor directly (no dependency on the
+    # full ``mscclpp`` Python package) and verifies that it does NOT reject
+    # the flag: the guarantee sits at the Python layer where it belongs.
+    import torch
+
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+
+    buf = _cpp.Buffer(rank=0, num_ranks=1, num_nvl_bytes=0, num_rdma_bytes=0, low_latency_mode=True)
+    assert not buf.is_available()
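+
+
+def test_buffer_single_rank_nvl():
+    # Added illustrative smoke test (not adapted from DeepEP): construct a
+    # plain single-rank NVLink buffer with an arbitrary 1 MiB scratch size
+    # and poke the introspection helpers that the Python frontend relies on.
+    # Assumes a visible CUDA device and that a single rank maps to a single
+    # RDMA rank.
+    import torch
+
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+
+    buf = _cpp.Buffer(rank=0, num_ranks=1, num_nvl_bytes=1 << 20, num_rdma_bytes=0, low_latency_mode=False)
+    assert buf.get_local_device_id() >= 0
+    assert buf.get_num_rdma_ranks() == 1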