diff --git a/python/mscclpp/ext/ep/buffer.py b/python/mscclpp/ext/ep/buffer.py index b50538a1..4a8ec011 100644 --- a/python/mscclpp/ext/ep/buffer.py +++ b/python/mscclpp/ext/ep/buffer.py @@ -13,8 +13,9 @@ Current status (see ``src/ext/ep/README.md``): * Intranode (NVLink-only) dispatch and combine are fully ported. * ``get_dispatch_layout`` is ported. -* Internode HT and low-latency methods raise from C++ — they still need - the NVSHMEM/IBGDA -> MSCCL++ PortChannel migration. +* Internode HT (MSCCL++ PortChannel + MemoryChannel) is ported. +* Internode low-latency kernels are ported structurally (NVSHMEM/IBGDA -> + MSCCL++ PortChannel) but **untested on multi-node H100**. """ from __future__ import annotations @@ -49,10 +50,11 @@ class Buffer: num_nvl_bytes: Size of the NVLink-accessible scratch buffer (shared via CUDA IPC). num_rdma_bytes: - Size of the RDMA scratch buffer. Must be 0 until internode/LL - support is landed. + Size of the RDMA scratch buffer. Required (>0) for internode HT and + low-latency modes. low_latency_mode: - Reserved — must be ``False`` until the LL path is ported. + Enable the low-latency dispatch/combine path (structural port, + untested). num_qps_per_rank: Ignored for intranode mode. """ @@ -68,13 +70,6 @@ class Buffer: low_latency_mode: bool = False, num_qps_per_rank: int = 12, ) -> None: - if low_latency_mode: - raise NotImplementedError( - "mscclpp.ext.ep.Buffer: low-latency mode is not yet ported. " - "Set low_latency_mode=False. See src/ext/ep/README.md for the " - "migration plan." - ) - self.rank: int = group.rank() self.group_size: int = group.size() self.group = group diff --git a/src/ext/ep/CMakeLists.txt b/src/ext/ep/CMakeLists.txt index 7f45aba1..bdd493df 100644 --- a/src/ext/ep/CMakeLists.txt +++ b/src/ext/ep/CMakeLists.txt @@ -41,7 +41,6 @@ endif() file(GLOB_RECURSE EP_SOURCES CONFIGURE_DEPENDS buffer.cc - internode_stub.cc bindings.cpp kernels/*.cu ) diff --git a/src/ext/ep/README.md b/src/ext/ep/README.md index 46d1d8df..bb6fad0a 100644 --- a/src/ext/ep/README.md +++ b/src/ext/ep/README.md @@ -18,17 +18,37 @@ A port of DeepEP's MoE `dispatch`/`combine` primitives into MSCCL++, targeting: | `intranode_combine` (NVLink) | ✅ ported | | `internode_dispatch` (NVLink+RDMA) | ✅ ported (pending H100 test) | | `internode_combine` (NVLink+RDMA) | ✅ ported (pending H100 test) | -| `low_latency_dispatch` (pure RDMA) | ❌ stub | -| `low_latency_combine` (pure RDMA) | ❌ stub | +| `low_latency_dispatch` (pure RDMA) | ⚠️ structural port, untested | +| `low_latency_combine` (pure RDMA) | ⚠️ structural port, untested | | `Connection::atomicAdd` API | ✅ cherry-picked into mscclpp | -| Python frontend `mscclpp.ext.ep` | ✅ wraps HT paths | +| Python frontend `mscclpp.ext.ep` | ✅ wraps HT + LL paths | | pybind11 module `mscclpp_ep_cpp` | ✅ builds conditionally | Internode HT is code-complete but unverified on real hardware — the `sync()` path replaces DeepEP's NVSHMEM symmetric-heap allocation with `cudaMalloc` + `bootstrap->barrier()`, and the kernels use the new `PortChannelDeviceHandle::atomicAdd` instead of the old raw-trigger -pattern. The low-latency path is the only remaining stub. +pattern. + +The low-latency port is **structural**: the DeepEP LL kernels (pure +IBGDA) have been mechanically translated to MSCCL++ port-channel ops. 
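+
+For orientation, a minimal sketch (hypothetical helper, not part of the
+build) of one translated send; it assumes the
+`handles[expert_local_idx * num_ranks + dst_rank]` layout that
+`Buffer::sync()` builds in low-latency mode:
+
+```cuda
+// Sketch only: one DeepEP IBGDA warp-put expressed as an MSCCL++
+// port-channel put. Assumes src/dst pointers alias the locally
+// registered RDMA buffer, so offsets are derived by subtracting its base.
+__device__ void send_token(mscclpp::PortChannelDeviceHandle* handles,
+                           char* rdma_base, char* dst_ptr, char* src_ptr,
+                           size_t num_bytes, int expert_local_idx,
+                           int dst_rank, int num_ranks, int lane_id) {
+  if (lane_id == 0) {
+    uint64_t dst_off = dst_ptr - rdma_base;  // offsets replace raw IBGDA remote VAs
+    uint64_t src_off = src_ptr - rdma_base;
+    handles[expert_local_idx * num_ranks + dst_rank].put(dst_off, src_off, num_bytes);
+  }
+  __syncwarp();  // only lane 0 issues the request; keep the warp converged
+}
+```
+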
+Semantic mapping: + +| DeepEP / IBGDA | MSCCL++ replacement | +|------------------------------------------|--------------------------------------------------------------| +| `nvshmemx_barrier_all_block()` | signal+wait ring across `port_channel_handles[peer_rank]` | +| `nvshmemi_ibgda_put_nbi_warp(...)` | lane-0 `port_channel_handles[qp*N+dst].put(dst_off, src_off, n)` | +| `nvshmemi_ibgda_amo_nonfetch_add(...)` | lane-0 `port_channel_handles[qp*N+dst].atomicAdd(off, int64)` | + +**Known limitations**: + +* LL performance will NOT match IBGDA — the MSCCL++ port channel uses a + CPU proxy. The port is for functional parity, not latency. +* `Buffer::sync()` in `low_latency_mode=True` only connects peers sharing + the same local GPU ID (DeepEP convention). LL kernels therefore assume + one-GPU-per-node topology, i.e. `num_ranks == num_rdma_ranks`. Running + with >1 GPU per node in LL mode will fail to reach cross-GPU peers. +* Multi-node H100 validation is still pending. ## Build @@ -58,7 +78,6 @@ src/ext/ep/ ├── buffer.hpp / buffer.cc — host-side Buffer, sync(), dispatch/combine ├── config.hpp / event.hpp — Config, EventHandle ├── bindings.cpp — PYBIND11_MODULE definition -├── internode_stub.cc — stubs for not-yet-ported LL launchers └── kernels/ ├── api.cuh — host-callable kernel prototypes ├── configs.cuh — compile-time constants (GPU-only) @@ -67,8 +86,9 @@ src/ext/ep/ ├── launch.cuh — SETUP_LAUNCH_CONFIG / SWITCH_* macros ├── utils.cuh — device inline helpers ├── runtime.cu — intranode::barrier launcher - ├── intranode_kernel.cu — notify_dispatch / dispatch / combine kernels - └── internode_layout.cu — get_dispatch_layout (CPU-safe subset) + ├── intranode_kernel.cu — intranode dispatch/combine kernels + ├── internode.cu — internode HT dispatch/combine + layout + └── internode_ll.cu — internode LL dispatch/combine (structural) python/mscclpp/ext/ep/ ├── __init__.py — reexports Buffer / Config / EventHandle diff --git a/src/ext/ep/buffer.cc b/src/ext/ep/buffer.cc index c9bf5c04..1544c93b 100644 --- a/src/ext/ep/buffer.cc +++ b/src/ext/ep/buffer.cc @@ -1126,24 +1126,193 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional(ptr) - reinterpret_cast(rdma_buffer_ptr); + EP_HOST_ASSERT(0 <= offset and offset + static_cast(num_bytes) <= num_rdma_bytes); + }; + check_boundary(clean_meta_0.first, clean_meta_0.second * sizeof(int)); + check_boundary(clean_meta_1.first, clean_meta_1.second * sizeof(int)); + + internode_ll::clean_low_latency_buffer(clean_meta_0.first, clean_meta_0.second, + clean_meta_1.first, clean_meta_1.second, + rank, num_ranks, + port_channel_handles_device_ptr.get(), + at::cuda::getCurrentCUDAStream()); } std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, std::optional, std::optional>> -Buffer::low_latency_dispatch(const torch::Tensor&, const torch::Tensor&, int, int, bool, bool, bool) { - throw std::runtime_error("mscclpp::ep::Buffer::low_latency_dispatch: not yet ported (needs NVSHMEM/IBGDA -> MSCCL++ migration)"); +Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, + int num_max_dispatch_tokens_per_rank, int num_experts, + bool use_fp8, bool async, bool return_recv_hook) { + EP_HOST_ASSERT(low_latency_mode); + + EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16); + EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % 128 == 0); + EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous()); + EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= 
num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); + EP_HOST_ASSERT(num_experts % num_ranks == 0); + + auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); + auto num_scales = hidden / 128, num_topk = static_cast(topk_idx.size(1)); + int num_local_experts = num_experts / num_ranks; + + LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); + EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); + auto buffer = layout.buffers[low_latency_buffer_idx]; + auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; + + auto compute_stream = at::cuda::getCurrentCUDAStream(); + auto launch_stream = return_recv_hook ? compute_stream : comm_stream; + EP_HOST_ASSERT(not (async and return_recv_hook)); + if (not return_recv_hook) + stream_wait(launch_stream, compute_stream); + + auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, + x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16)); + auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA)); + auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + + auto packed_recv_x_scales = std::optional(); + float* packed_recv_x_scales_ptr = nullptr; + if (use_fp8) { + EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4"); + packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); + packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2); + packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr(); + } + + auto next_clean_meta = next_buffer.clean_meta(); + auto port_handles = port_channel_handles_device_ptr.get(); + auto rdma_base = rdma_buffer_ptr; + auto launcher = [=](int phases) { + internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, + packed_recv_src_info.data_ptr(), packed_recv_layout_range.data_ptr(), + packed_recv_count.data_ptr(), + buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer, + buffer.dispatch_rdma_send_buffer, + x.data_ptr(), topk_idx.data_ptr(), + next_clean_meta.first, next_clean_meta.second, + num_tokens, hidden, num_max_dispatch_tokens_per_rank, + num_topk, num_experts, rank, num_ranks, use_fp8, + workspace, launch_stream, phases, + rdma_base, port_handles); + }; + launcher(return_recv_hook ? 
LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); + + std::optional event; + if (async) { + event = EventHandle(launch_stream); + } else if (not return_recv_hook) { + stream_wait(compute_stream, launch_stream); + } + + std::optional> recv_hook = std::nullopt; + if (return_recv_hook) + recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; + + return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook}; } std::tuple, std::optional>> -Buffer::low_latency_combine(const torch::Tensor&, const torch::Tensor&, const torch::Tensor&, - const torch::Tensor&, const torch::Tensor&, - int, int, bool, bool, bool, const std::optional&) { - throw std::runtime_error("mscclpp::ep::Buffer::low_latency_combine: not yet ported"); +Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, + const torch::Tensor& src_info, const torch::Tensor& layout_range, + int num_max_dispatch_tokens_per_rank, int num_experts, + bool zero_copy, bool async, bool return_recv_hook, + const std::optional& out) { + EP_HOST_ASSERT(low_latency_mode); + + EP_HOST_ASSERT(x.dim() == 3 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16); + EP_HOST_ASSERT(x.size(0) == num_experts / num_ranks); + EP_HOST_ASSERT(x.size(1) == num_ranks * num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 and x.size(2) % 128 == 0); + EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous()); + EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) and topk_idx.size(1) == topk_weights.size(1)); + EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); + EP_HOST_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous()); + EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32); + EP_HOST_ASSERT(src_info.dim() == 2 and src_info.is_contiguous()); + EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt32 and x.size(0) == src_info.size(0)); + EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous()); + EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64); + EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks); + auto hidden = static_cast(x.size(2)); + auto num_topk = static_cast(topk_weights.size(1)); + auto num_combined_tokens = static_cast(topk_weights.size(0)); + + LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); + EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); + auto buffer = layout.buffers[low_latency_buffer_idx]; + auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; + + auto compute_stream = at::cuda::getCurrentCUDAStream(); + auto launch_stream = return_recv_hook ? 
compute_stream : comm_stream; + EP_HOST_ASSERT(not (async and return_recv_hook)); + if (not return_recv_hook) + stream_wait(launch_stream, compute_stream); + + torch::Tensor combined_x; + if (out.has_value()) { + EP_HOST_ASSERT(out->dim() == 2 and out->is_contiguous()); + EP_HOST_ASSERT(out->size(0) == num_combined_tokens and out->size(1) == hidden); + EP_HOST_ASSERT(out->scalar_type() == x.scalar_type()); + combined_x = out.value(); + } else { + combined_x = torch::empty({num_combined_tokens, hidden}, x.options()); + } + + auto next_clean_meta = next_buffer.clean_meta(); + auto port_handles = port_channel_handles_device_ptr.get(); + auto rdma_base = rdma_buffer_ptr; + auto launcher = [=](int phases) { + internode_ll::combine(combined_x.data_ptr(), + buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer, + buffer.combine_rdma_send_buffer, + x.data_ptr(), topk_idx.data_ptr(), topk_weights.data_ptr(), + src_info.data_ptr(), layout_range.data_ptr(), + next_clean_meta.first, next_clean_meta.second, + num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank, + num_topk, num_experts, rank, num_ranks, + workspace, launch_stream, + phases, zero_copy, + rdma_base, port_handles); + }; + launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); + + std::optional event; + if (async) { + event = EventHandle(launch_stream); + } else if (not return_recv_hook) { + stream_wait(compute_stream, launch_stream); + } + + std::optional> recv_hook = std::nullopt; + if (return_recv_hook) + recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; + + return {combined_x, event, recv_hook}; } -torch::Tensor Buffer::get_next_low_latency_combine_buffer(int, int, int) { - throw std::runtime_error("mscclpp::ep::Buffer::get_next_low_latency_combine_buffer: not yet ported"); +torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) { + LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); + auto buffer = layout.buffers[low_latency_buffer_idx]; + auto dtype = torch::kBFloat16; + auto num_msg_elems = static_cast(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16)); + + EP_HOST_ASSERT(buffer.num_bytes_per_combine_msg % elementSize(torch::kBFloat16) == 0); + return torch::from_blob(buffer.combine_rdma_send_buffer_data_start, + {num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, + {num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1}, + torch::TensorOptions().dtype(dtype).device(torch::kCUDA)); } } // namespace ep diff --git a/src/ext/ep/internode_stub.cc b/src/ext/ep/internode_stub.cc deleted file mode 100644 index 329e31ee..00000000 --- a/src/ext/ep/internode_stub.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. -// -// Placeholder launchers for the not-yet-ported internode HT and low-latency -// kernels. `get_dispatch_layout` and `get_source_meta_bytes` ARE ported in -// `kernels/internode_layout.cu`. - -#include - -#include "kernels/api.cuh" - -namespace mscclpp { -namespace ep { - -namespace internode_ll { - -void clean_low_latency_buffer(int* /*clean_0*/, int /*n0*/, int* /*clean_1*/, int /*n1*/, cudaStream_t /*stream*/) { - throw std::runtime_error( - "mscclpp::ep::internode_ll::clean_low_latency_buffer: not yet ported. 
" - "See nccl/contrib/nccl_ep/device/low_latency.cu and DeepEP " - "csrc/kernels/internode_ll.cu for the reference implementation."); -} - -} // namespace internode_ll - -} // namespace ep -} // namespace mscclpp diff --git a/src/ext/ep/kernels/api.cuh b/src/ext/ep/kernels/api.cuh index 479b6a17..294e1043 100644 --- a/src/ext/ep/kernels/api.cuh +++ b/src/ext/ep/kernels/api.cuh @@ -129,13 +129,42 @@ void combine(cudaDataType_t type, } // namespace internode // =========================================================================== -// Internode low-latency (pure RDMA) kernels. Not ported yet. +// Internode low-latency (pure RDMA) kernels. Ported from DeepEP +// `csrc/kernels/internode_ll.cu` with NVSHMEM/IBGDA device ops replaced by +// MSCCL++ port-channel primitives (`put`, `atomicAdd`, signal/wait barrier). // =========================================================================== namespace internode_ll { -void clean_low_latency_buffer(int* clean_0, int num_clean_int_0, int* clean_1, int num_clean_int_1, +void clean_low_latency_buffer(int* clean_0, int num_clean_int_0, + int* clean_1, int num_clean_int_1, + int rank, int num_ranks, + mscclpp::PortChannelDeviceHandle* port_channel_handles, cudaStream_t stream); +void dispatch(void* packed_recv_x, float* packed_recv_x_scales, + int* packed_recv_src_info, int64_t* packed_recv_layout_range, + int* packed_recv_count, + void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, + const void* x, const int64_t* topk_idx, + int* next_clean, int num_next_clean_int, + int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, + int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8, + void* workspace, cudaStream_t stream, int phases, + void* rdma_buffer_ptr, + mscclpp::PortChannelDeviceHandle* port_channel_handles); + +void combine(void* combined_x, + void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, + const void* x, const int64_t* topk_idx, const float* topk_weights, + const int* src_info, const int64_t* layout_range, + int* next_clean, int num_next_clean_int, + int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, + int num_topk, int num_experts, int rank, int num_ranks, + void* workspace, cudaStream_t stream, + int phases, bool zero_copy, + void* rdma_buffer_ptr, + mscclpp::PortChannelDeviceHandle* port_channel_handles); + } // namespace internode_ll } // namespace ep diff --git a/src/ext/ep/kernels/internode_ll.cu b/src/ext/ep/kernels/internode_ll.cu new file mode 100644 index 00000000..b75c18c7 --- /dev/null +++ b/src/ext/ep/kernels/internode_ll.cu @@ -0,0 +1,595 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +// +// Low-latency internode dispatch/combine kernels ported from DeepEP +// `csrc/kernels/internode_ll.cu` (branch `chhwang/dev-atomic-add-cleanup`). +// +// NVSHMEM/IBGDA device calls are replaced with MSCCL++ PortChannel device +// operations: +// +// nvshmemx_barrier_all_block() -> port-channel signal/wait ring +// nvshmemi_ibgda_put_nbi_warp(...) -> port_channel.put(...) (lane 0) +// nvshmemi_ibgda_amo_nonfetch_add(...) -> port_channel.atomicAdd(...) +// +// Addressing convention: +// - `rdma_buffer_ptr` is the base of the locally-registered RDMA buffer. +// - Remote counter/buffer pointers written by the kernel are virtual +// addresses that alias the corresponding offset inside each peer's +// symmetric RDMA buffer. MSCCL++ needs those as offsets; we derive them +// via `ptr - rdma_buffer_ptr`. 
+// - Port-channel layout built by `Buffer::sync()` in low-latency mode is +// `handles[qp * num_peers + peer_idx]` where `peer_idx` is the dst rank's +// position in the connected-peer map. In the recommended 1-GPU-per-node +// LL topology, `peer_idx == dst_rank`; see src/ext/ep/README.md. +// +// WARNING: This port is untested on multi-node H100; performance will NOT +// match IBGDA (host-proxy adds latency). Functional correctness needs +// validation on real hardware. + +#include "configs.cuh" +#include "exception.cuh" +#include "launch.cuh" +#include "utils.cuh" + +#include +#include + +namespace cg = cooperative_groups; + +namespace mscclpp { +namespace ep { + +namespace internode_ll { + +// --------------------------------------------------------------------------- +// Device helpers +// --------------------------------------------------------------------------- + +// Pointer-to-offset helper for MSCCL++ port channels. Both src and dst +// pointers passed to the DeepEP LL kernels are virtual addresses aliased into +// the caller's symmetric RDMA buffer; MSCCL++ expects offsets. +__device__ __forceinline__ uint64_t rdma_offset_of(uint64_t ptr, void* rdma_buffer_ptr) { + return ptr - reinterpret_cast(rdma_buffer_ptr); +} + +// Cross-rank barrier via port-channel signal/wait ring. +// Uses port channel `qp=0` across all connected peers. +__device__ __forceinline__ void port_channel_barrier_block( + mscclpp::PortChannelDeviceHandle* port_channel_handles, + int rank, int num_ranks) { + const int tid = threadIdx.x; + if (tid < num_ranks && tid != rank) { + // Index: qp 0, peer = tid's rank (assumes peer_idx == rank in LL topology). + port_channel_handles[tid].signal(); + port_channel_handles[tid].wait(); + } + __syncthreads(); +} + +// --------------------------------------------------------------------------- +// clean_low_latency_buffer +// --------------------------------------------------------------------------- + +template __launch_bounds__(kNumThreads, 1) +__global__ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0, + int* clean_1, int num_clean_int_1, + mscclpp::PortChannelDeviceHandle* port_channel_handles, + int rank, int num_ranks) { + // Barrier before cleaning (in case of unfinished chunked EP) + port_channel_barrier_block(port_channel_handles, rank, num_ranks); + + // Clean + auto thread_id = static_cast(threadIdx.x); + #pragma unroll + for (int i = thread_id; i < num_clean_int_0; i += kNumThreads) + clean_0[i] = 0; + #pragma unroll + for (int i = thread_id; i < num_clean_int_1; i += kNumThreads) + clean_1[i] = 0; + + // Barrier after cleaning (make sure low-latency mode work fine) + port_channel_barrier_block(port_channel_handles, rank, num_ranks); +} + +void clean_low_latency_buffer(int* clean_0, int num_clean_int_0, + int* clean_1, int num_clean_int_1, + int rank, int num_ranks, + mscclpp::PortChannelDeviceHandle* port_channel_handles, + cudaStream_t stream) { + constexpr int kNumThreads = 256; + + SETUP_LAUNCH_CONFIG(1, kNumThreads, stream); + LAUNCH_KERNEL(&cfg, clean_low_latency_buffer, + clean_0, num_clean_int_0, clean_1, num_clean_int_1, + port_channel_handles, rank, num_ranks); +} + +// --------------------------------------------------------------------------- +// dispatch +// --------------------------------------------------------------------------- + +template +__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void +dispatch(void* packed_recv_x, float* packed_recv_x_scales, + int* packed_recv_src_info, int64_t* 
packed_recv_layout_range, + int* packed_recv_count, + void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, + const void* x, const int64_t* topk_idx, + int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert, + int* next_clean, int num_next_clean_int, + int num_tokens, int num_max_dispatch_tokens_per_rank, + int num_topk, int num_experts, int rank, int num_ranks, + int phases, + void* rdma_buffer_ptr, + mscclpp::PortChannelDeviceHandle* port_channel_handles) { + const auto sm_id = static_cast(blockIdx.x); + const auto thread_id = static_cast(threadIdx.x); + const auto warp_id = thread_id / 32, lane_id = get_lane_id(); + const auto num_sms = static_cast(gridDim.x); + const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; + const auto num_local_experts = num_experts / num_ranks; + const auto warp_group_id = warp_id / kNumWarpsPerGroup; + const auto sub_warp_id = warp_id % kNumWarpsPerGroup; + const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id; + + // FP8 staffs + constexpr int kNumPerChannels = 128; + constexpr float kFP8Margin = 1e-4, kFP8Amax = 448, kFP8AmaxInv = 1.0f / 448.0f; + const int num_scales = kHidden / kNumPerChannels; + const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16)); + const size_t hidden_int4 = hidden_bytes / sizeof(int4); + + // Message package: hidden data, FP8 scales, index at source + using vec_t = typename std::conditional::type; + const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16))); + const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4); + EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0); + + // Sending phase + if ((phases & LOW_LATENCY_SEND_PHASE) == 0) + goto LOW_LATENCY_DISPATCH_RECV; + + // Expert counts + __shared__ int shared_num_tokens_sent_per_expert[kNumWarpGroups]; + + if (warp_id < num_warps - 1) { + constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16); + EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0); + EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization"); + const auto num_threads = (num_warps - 1) * 32; + const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead; + + for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) { + const auto x_int4 = reinterpret_cast(x) + token_idx * hidden_bf16_int4; + const auto rdma_x_src_idx = reinterpret_cast(reinterpret_cast(rdma_x) + token_idx * num_bytes_per_msg); + const auto rdma_x_vec = reinterpret_cast(reinterpret_cast(rdma_x_src_idx) + sizeof(int4)); + const auto rdma_x_scales = reinterpret_cast(reinterpret_cast(rdma_x_vec) + hidden_bytes); + + auto dst_expert_idx = warp_id < num_topk ? static_cast(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1; + thread_id == 0 ? 
(*rdma_x_src_idx = token_idx) : 0; + + // FP8 cast + #pragma unroll + for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) { + auto int4_value = __ldg(x_int4 + i); + + if (kUseFP8) { + auto bf16_values = reinterpret_cast(&int4_value); + float fp32_values[kNumElemsPerRead]; + float amax = kFP8Margin, scale, scale_inv; + #pragma unroll + for (int j = 0; j < kNumElemsPerRead; ++ j) { + fp32_values[j] = static_cast(bf16_values[j]); + amax = fmaxf(amax, fabsf(fp32_values[j])); + } + + EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization"); + amax = half_warp_reduce_max(amax), scale = kFP8Amax / amax, scale_inv = amax * kFP8AmaxInv; + if (lane_id == 0 or lane_id == 16) + rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv; + + vec_t int2_value; + auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value); + #pragma unroll + for (int j = 0; j < kNumElemsPerRead; j += 2) { + float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale}; + fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3); + } + rdma_x_vec[i] = int2_value; + } else { + rdma_x_vec[i] = *reinterpret_cast(&int4_value); + } + } + asm volatile("bar.sync 1, %0;" :: "r"(num_threads)); + + // Issue sends + if (dst_expert_idx >= 0) { + int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0; + slot_idx = __shfl_sync(0xffffffff, slot_idx, 0); + const auto dst_rank = dst_expert_idx / num_local_experts; + const auto dst_expert_local_idx = dst_expert_idx % num_local_experts; + const auto src_ptr = reinterpret_cast(rdma_x_src_idx); + const auto dst_ptr = reinterpret_cast(rdma_recv_x) + + dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + + rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + + slot_idx * num_bytes_per_msg; + if (dst_rank != rank) { + // MSCCL++ port-channel PUT (lane 0 issues one request). + if (lane_id == 0) { + const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr); + const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr); + port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank] + .put(dst_off, src_off, num_bytes_per_msg); + } + __syncwarp(); + } else { + const auto* src_int4_ptr = reinterpret_cast(src_ptr); + const auto* dst_int4_ptr = reinterpret_cast(dst_ptr); + UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global); + } + + __syncwarp(); + lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0; + } + } + } else if (warp_id == num_warps - 1) { + EP_DEVICE_ASSERT(num_sms > 1); + if (sm_id == 0) { + // NOTE: DeepEP asserts `ibgda_get_state()->num_rc_per_pe >= num_local_experts` + // here. The MSCCL++ port relies on Buffer::sync() provisioning enough QPs + // (see `num_ib_connections_per_rank` / `num_port_channels_per_rank`). 
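+      // In this port that translates to an assumption on Buffer::sync():
+      // since sends index port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank],
+      // the handle array must provide at least num_local_experts channels
+      // per destination rank.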
+ + #pragma unroll + for (int i = lane_id; i < num_next_clean_int; i += 32) + next_clean[i] = 0; + + __syncwarp(); + #pragma unroll + for (int i = lane_id; i < num_experts; i += 32) + atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG); + } + + int expert_count[kNumWarpGroups] = {0}; + const auto expert_begin_idx = sm_id * kNumWarpGroups; + const auto expert_end_idx = min(expert_begin_idx + kNumWarpGroups, num_experts); + + #pragma unroll 8 + for (int i = lane_id; i < num_tokens * num_topk; i += 32) { + auto idx = static_cast(__ldg(topk_idx + i)); + if (idx >= expert_begin_idx and idx < expert_end_idx) + expert_count[idx - expert_begin_idx] ++; + } + + #pragma unroll + for (int i = expert_begin_idx; i < expert_end_idx; ++ i) { + auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]); + if (lane_id == 0) { + shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum; + atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum); + } + } + } + __syncthreads(); + + // Issue count sends + if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) { + const auto dst_rank = responsible_expert_idx / num_local_experts; + const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts; + const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * kNumWarpGroups]; + + while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2); + if (dst_rank != rank) { + auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank; + const auto off = rdma_offset_of(reinterpret_cast(counter_ptr), rdma_buffer_ptr); + port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank] + .atomicAdd(off, static_cast(-num_tokens_sent - 1)); + } else { + st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1); + } + + atomic_counter_per_expert[responsible_expert_idx] = 0; + atomic_finish_counter_per_expert[responsible_expert_idx] = 0; + + if (dst_rank == 0) + packed_recv_count[dst_expert_local_idx] = 0; + } + __syncwarp(); + + // Receiving phase + LOW_LATENCY_DISPATCH_RECV: + if ((phases & LOW_LATENCY_RECV_PHASE) == 0) + return; + + if (phases & LOW_LATENCY_SEND_PHASE) + cg::this_grid().sync(); + + if (responsible_expert_idx < num_experts) { + const auto src_rank = responsible_expert_idx / num_local_experts; + const auto local_expert_idx = responsible_expert_idx % num_local_experts; + const auto rdma_recv_x_uint8 = reinterpret_cast(rdma_recv_x) + + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + + src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg; + const auto recv_x_int4 = reinterpret_cast(packed_recv_x) + + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4; + const auto recv_x_scales = packed_recv_x_scales + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_scales; + const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank; + const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks; + + __shared__ int shared_num_recv_tokens[kNumWarpGroups], shared_recv_token_begin_idx[kNumWarpGroups]; + + int num_recv_tokens, recv_token_begin_idx; + EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group"); + if (sub_warp_id == 1 and lane_id == 0) { + while ((num_recv_tokens = 
ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0); + num_recv_tokens = -num_recv_tokens - 1; + recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens); + shared_num_recv_tokens[warp_group_id] = num_recv_tokens; + shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx; + recv_range[src_rank] = pack2(num_recv_tokens, recv_token_begin_idx); + } + asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32)); + num_recv_tokens = shared_num_recv_tokens[warp_group_id]; + recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id]; + + EP_DEVICE_ASSERT(num_scales <= 64); + for (int i = sub_warp_id; i < num_recv_tokens; i += kNumWarpsPerGroup) { + const auto src_src_idx = reinterpret_cast(rdma_recv_x_uint8 + i * num_bytes_per_msg); + if (lane_id == 0) + recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx); + __syncwarp(); + + const auto src_data = reinterpret_cast(reinterpret_cast(src_src_idx) + sizeof(int4)); + const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4; + UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global); + + if (kUseFP8) { + const auto src_scales = reinterpret_cast(reinterpret_cast(src_data) + hidden_bytes); + const auto dst_scales = reinterpret_cast(recv_x_scales + recv_token_begin_idx + i); + const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank; + auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0; + auto scale_1 = (lane_id + 32) < num_scales ? ld_nc_global(src_scales + lane_id + 32) : 0; + lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f; + (lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f; + } + } + } +} + +void dispatch(void* packed_recv_x, float* packed_recv_x_scales, + int* packed_recv_src_info, int64_t* packed_recv_layout_range, + int* packed_recv_count, + void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, + const void* x, const int64_t* topk_idx, + int* next_clean, int num_next_clean_int, + int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, + int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8, + void* workspace, cudaStream_t stream, int phases, + void* rdma_buffer_ptr, + mscclpp::PortChannelDeviceHandle* port_channel_handles) { + constexpr int kNumMaxTopK = 9; + constexpr int kNumWarpsPerGroup = 10; + constexpr int kNumWarpGroups = 3; + EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections"); + + const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; + const auto num_sms = cell_div(num_experts, kNumWarpGroups); + EP_HOST_ASSERT(num_topk <= kNumMaxTopK); + + auto atomic_counter_per_expert = reinterpret_cast(workspace); + auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts; + EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES); + +#define DISPATCH_LAUNCH_CASE(hidden_case) { \ +auto dispatch_func = use_fp8 ? 
dispatch : \ + dispatch; \ +LAUNCH_KERNEL(&cfg, dispatch_func, \ + packed_recv_x, packed_recv_x_scales, \ + packed_recv_src_info, packed_recv_layout_range, \ + packed_recv_count, \ + rdma_recv_x, rdma_recv_count, rdma_x, \ + x, topk_idx, \ + atomic_counter_per_expert, atomic_finish_counter_per_expert, \ + next_clean, num_next_clean_int, \ + num_tokens, num_max_dispatch_tokens_per_rank, \ + num_topk, num_experts, rank, num_ranks, phases, \ + rdma_buffer_ptr, port_channel_handles); } break + + SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream); + SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE); +#undef DISPATCH_LAUNCH_CASE +} + +// --------------------------------------------------------------------------- +// combine +// --------------------------------------------------------------------------- + +template +__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void +combine(void* combined_x, + void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, + const void* x, const int64_t* topk_idx, const float* topk_weights, + const int* src_info, const int64_t* layout_range, + int* next_clean, int num_next_clean_int, + int* atomic_clean_flag, + int num_combined_tokens, int hidden, int num_topk, + int num_max_dispatch_tokens_per_rank, + int num_experts, int rank, int num_ranks, + int phases, bool zero_copy, + void* rdma_buffer_ptr, + mscclpp::PortChannelDeviceHandle* port_channel_handles) { + const auto sm_id = static_cast(blockIdx.x); + const auto num_sms = static_cast(gridDim.x); + const auto thread_id = static_cast(threadIdx.x); + const auto num_threads = static_cast(blockDim.x); + const auto warp_id = thread_id / 32, lane_id = get_lane_id(); + const auto num_local_experts = num_experts / num_ranks; + const auto warp_group_id = warp_id / kNumWarpsPerGroup; + const auto sub_warp_id = warp_id % kNumWarpsPerGroup; + const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id; + + constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16); + const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4; + + constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16); + EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization"); + + if ((phases & LOW_LATENCY_SEND_PHASE) == 0) + goto LOW_LATENCY_COMBINE_RECV; + + if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) { + #pragma unroll + for (int i = lane_id; i < num_next_clean_int; i += 32) + next_clean[i] = 0; + + __syncwarp(); + if (lane_id == 0) + atomic_add_release_global(atomic_clean_flag, num_experts); + } + + if (responsible_expert_idx < num_experts) { + const auto dst_rank = responsible_expert_idx / num_local_experts; + const auto local_expert_idx = responsible_expert_idx % num_local_experts; + const auto global_expert_idx = rank * num_local_experts + local_expert_idx; + const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank); + const auto local_x = reinterpret_cast(x) + + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4; + const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank; + const auto rdma_send_x_vec = reinterpret_cast(rdma_send_x) + + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot; + + int offset, num_tokens_to_send; + unpack2(layout, num_tokens_to_send, offset); + + for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) { + const auto x_int4 = local_x + token_idx 
* hidden_bf16_int4; + const auto rdma_send_type_row = reinterpret_cast(rdma_send_x_vec + token_idx * num_bytes_per_slot); + const auto rdma_send_x_vec_row = reinterpret_cast(rdma_send_type_row); + + auto src_idx = __ldg(local_src_info + token_idx); + const auto buf_ptr = reinterpret_cast(rdma_send_x_vec_row); + const auto dst_ptr = reinterpret_cast(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot; + if (dst_rank == rank) { + const auto dst_int4_ptr = reinterpret_cast(dst_ptr); + UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global); + } else { + const auto buf_int4_ptr = reinterpret_cast(buf_ptr); + if (not zero_copy) + UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global); + // MSCCL++ port-channel PUT. + if (lane_id == 0) { + const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr); + const auto src_off = rdma_offset_of(static_cast(buf_ptr), rdma_buffer_ptr); + port_channel_handles[local_expert_idx * num_ranks + dst_rank] + .put(dst_off, src_off, hidden * sizeof(nv_bfloat16)); + } + __syncwarp(); + } + } + + EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group"); + asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(kNumWarpsPerGroup * 32)); + if (sub_warp_id == 1 and lane_id == 0) { + while (ld_acquire_global(atomic_clean_flag) == 0); + if (dst_rank != rank) { + auto* flag_ptr = rdma_recv_flag + global_expert_idx; + const auto off = rdma_offset_of(reinterpret_cast(flag_ptr), rdma_buffer_ptr); + port_channel_handles[local_expert_idx * num_ranks + dst_rank] + .atomicAdd(off, static_cast(1)); + } else { + st_na_release(rdma_recv_flag + global_expert_idx, 1); + } + atomic_add_release_global(atomic_clean_flag, -1); + } + __syncwarp(); + } + + LOW_LATENCY_COMBINE_RECV: + if ((phases & LOW_LATENCY_RECV_PHASE) == 0) + return; + + if (responsible_expert_idx < num_experts) { + EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group"); + if (sub_warp_id == 0 and lane_id == 0) + while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0); + } + cg::this_grid().sync(); + + EP_DEVICE_ASSERT(num_topk <= 32 and hidden_bf16_int4 <= num_threads); + EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization"); + if (thread_id < hidden_bf16_int4) { + for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) { + int reg_topk_idx[kNumMaxTopk]; + float reg_topk_weights[kNumMaxTopk]; + #pragma unroll + for (int i = 0; i < num_topk; ++ i) { + reg_topk_idx[i] = static_cast(__ldg(topk_idx + token_idx * num_topk + i)); + reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i); + } + + float combined_values[kNumElemsPerInt4] = {0.0f}; + #pragma unroll + for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) { + auto rdma_buffer_type = reinterpret_cast(reinterpret_cast(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot); + auto rdma_buffer_row = reinterpret_cast(rdma_buffer_type); + + auto x_vec = ld_nc_global(reinterpret_cast(rdma_buffer_row) + thread_id); + const auto x_bf16 = reinterpret_cast(&x_vec); + #pragma unroll + for (int j = 0; j < kNumElemsPerInt4; ++ j) + combined_values[j] += static_cast(x_bf16[j]) * reg_topk_weights[i]; + } + + int4& combined_int4 = *reinterpret_cast(combined_values); + auto combined_bf16 = reinterpret_cast(&combined_values); + #pragma unroll + for (int j 
= 0; j < kNumElemsPerInt4; ++ j) + combined_bf16[j] = static_cast(combined_values[j]); + (reinterpret_cast(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4; + } + } +} + +void combine(void* combined_x, + void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, + const void* x, const int64_t* topk_idx, const float* topk_weights, + const int* src_info, const int64_t* layout_range, + int* next_clean, int num_next_clean_int, + int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, + int num_topk, int num_experts, int rank, int num_ranks, + void* workspace, cudaStream_t stream, + int phases, bool zero_copy, + void* rdma_buffer_ptr, + mscclpp::PortChannelDeviceHandle* port_channel_handles) { + constexpr int kNumWarpsPerGroup = 10; + constexpr int kNumWarpGroups = 3; + constexpr int kNumMaxTopk = 9; + + const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; + const auto num_sms = cell_div(num_experts, kNumWarpGroups); + + auto atomic_clean_flag = reinterpret_cast(workspace); + EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES); + EP_HOST_ASSERT(num_topk <= kNumMaxTopk); + +#define COMBINE_LAUNCH_CASE(hidden_case) { \ +auto combine_func = combine; \ +LAUNCH_KERNEL(&cfg, combine_func, \ + combined_x, \ + rdma_recv_x, rdma_recv_flag, rdma_send_x, \ + x, topk_idx, topk_weights, src_info, layout_range, \ + next_clean, num_next_clean_int, \ + atomic_clean_flag, \ + num_combined_tokens, hidden, num_topk, \ + num_max_dispatch_tokens_per_rank, \ + num_experts, rank, num_ranks, \ + phases, zero_copy, \ + rdma_buffer_ptr, port_channel_handles); } break + + SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream); + SWITCH_HIDDEN(COMBINE_LAUNCH_CASE); +#undef COMBINE_LAUNCH_CASE +} + +} // namespace internode_ll +} // namespace ep +} // namespace mscclpp diff --git a/test/python/ext/ep/test_ep_smoke.py b/test/python/ext/ep/test_ep_smoke.py index b300f041..3002e20f 100644 --- a/test/python/ext/ep/test_ep_smoke.py +++ b/test/python/ext/ep/test_ep_smoke.py @@ -33,19 +33,14 @@ def test_low_latency_size_hint(): assert _cpp.get_low_latency_rdma_size_hint(128, 7168, 8, 256) > 0 -def test_low_latency_rejected(): - # Low-latency (pure RDMA) path is not ported yet; Python frontend must - # refuse to construct a Buffer with low_latency_mode=True. We test the - # underlying C++ constructor directly so this does not depend on the - # full `mscclpp` Python package being installed. +def test_low_latency_buffer_construct(): + # Low-latency kernels are structurally ported. At construction time the + # C++ Buffer must accept low_latency_mode=True; runtime requires a real + # multi-node setup (see tests in tests/test_low_latency.py when ported). import torch if not torch.cuda.is_available(): pytest.skip("CUDA not available") - # The C++ Buffer allows low_latency_mode at construction; the enforcement - # lives in the Python frontend (`mscclpp.ext.ep.buffer.Buffer.__init__`). - # Verify the C++ side does NOT reject it, so the guarantee sits at the - # Python layer where it belongs. buf = _cpp.Buffer(rank=0, num_ranks=1, num_nvl_bytes=0, num_rdma_bytes=0, low_latency_mode=True) assert not buf.is_available()