mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
src/ext/ep: port low-latency dispatch/combine kernels
Port DeepEP's pure-RDMA low-latency (LL) MoE kernels from
csrc/kernels/internode_ll.cu (branch chhwang/dev-atomic-add-cleanup)
into the MSCCL++ EP extension. NVSHMEM / IBGDA device primitives are
replaced with MSCCL++ PortChannelDeviceHandle operations:
nvshmemx_barrier_all_block() -> port-channel signal+wait ring
nvshmemi_ibgda_put_nbi_warp(...) -> lane-0 PortChannel.put(...)
nvshmemi_ibgda_amo_nonfetch_add(...) -> lane-0 PortChannel.atomicAdd(...)
The atomicAdd path relies on the MSCCL++ Connection::atomicAdd /
PortChannelDeviceHandle::atomicAdd API cherry-picked from branch
chhwang/new-atomic-add; the LL dispatch path uses a signed delta
(-num_tokens_sent - 1) which the new int64_t signature supports.
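The signed-delta convention above can be illustrated with a small host-side sketch (the helper names are hypothetical; the kernels inline this arithmetic):

```cpp
#include <cassert>
#include <cstdint>

// DeepEP's LL dispatch publishes the per-expert token count as a strictly
// negative value, so a receiver polling the counter can tell "not yet
// written" (0) apart from "0 tokens sent" (-1), "1 token sent" (-2), etc.
// The int64_t atomicAdd signature is what allows delivering this negative delta.
inline int64_t encode_count(int num_tokens_sent) {
  return -static_cast<int64_t>(num_tokens_sent) - 1;
}

// Receiver side: invert the encoding once the counter becomes non-zero.
inline int decode_count(int64_t flag) {
  return static_cast<int>(-flag - 1);
}
```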
Changes:
* New file src/ext/ep/kernels/internode_ll.cu (595 lines) with the
three kernels clean_low_latency_buffer, dispatch<kUseFP8,...>,
combine<...> plus their launchers. rdma_buffer_ptr is threaded
through the launchers so the kernel can translate virtual addresses
into registered-memory offsets expected by MSCCL++.
* kernels/api.cuh: replace the single stub signature with full LL
launcher prototypes.
* buffer.cc: replace the four LL throw-stubs
(clean_low_latency_buffer, low_latency_dispatch,
low_latency_combine, get_next_low_latency_combine_buffer) with
torch-Tensor implementations ported from DeepEP/csrc/deep_ep.cpp.
* Drop src/ext/ep/internode_stub.cc and its CMake entry.
* python/mscclpp/ext/ep/buffer.py: remove the low_latency_mode=True
NotImplementedError guard; update docstring.
* test/python/ext/ep/test_ep_smoke.py: rename
test_low_latency_rejected -> test_low_latency_buffer_construct
to reflect that LL construction is now accepted.
* src/ext/ep/README.md: update status matrix, document the
NVSHMEM -> MSCCL++ translation table, and list the known
limitations.
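The rdma_buffer_ptr threading mentioned above boils down to a virtual-address-to-offset translation, since MSCCL++ port channels address registered memory by offset while the ported DeepEP kernels compute peer pointers as virtual addresses aliasing the symmetric RDMA buffer. A minimal host-side sketch (mirroring the device-side `rdma_offset_of` helper in internode_ll.cu):

```cpp
#include <cassert>
#include <cstdint>

// Translate a pointer that aliases into the locally registered RDMA buffer
// into the byte offset expected by port-channel put()/atomicAdd(): a single
// subtraction against the local buffer base.
inline uint64_t rdma_offset_of(const void* ptr, const void* rdma_buffer_ptr) {
  return reinterpret_cast<uint64_t>(ptr) - reinterpret_cast<uint64_t>(rdma_buffer_ptr);
}
```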
This is a structural port: the kernels compile, link, and pass the
single-rank smoke tests, but end-to-end behaviour on multi-node H100
is not yet validated. Two known caveats:
1. Performance will NOT match IBGDA because MSCCL++ port channels
use a CPU proxy; this port is for functional parity, not latency.
2. Buffer::sync() in LL mode only connects peers that share the
same local GPU id (DeepEP convention), so the LL kernels assume
a one-GPU-per-node topology (num_ranks == num_rdma_ranks).
Multi-GPU-per-node LL layouts will need a follow-up in sync().
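Under caveat 2's one-GPU-per-node assumption, the flattened handle indexing used by the LL kernels reduces to a simple computation (a sketch of the `handles[qp * num_peers + peer_idx]` layout described in the new file's header comment; the helper name is hypothetical):

```cpp
#include <cassert>

// The LL kernels index the device-handle array as
// handles[qp * num_peers + peer]; with one GPU per node,
// peer == dst_rank (num_ranks == num_rdma_ranks).
inline int channel_index(int qp, int dst_rank, int num_peers) {
  return qp * num_peers + dst_rank;
}
```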
Tested:
cmake --build build -j --target mscclpp_ep_cpp # builds clean
pytest test/python/ext/ep/test_ep_smoke.py # 3 passed
@@ -13,8 +13,9 @@ Current status (see ``src/ext/ep/README.md``):

 * Intranode (NVLink-only) dispatch and combine are fully ported.
 * ``get_dispatch_layout`` is ported.
-* Internode HT and low-latency methods raise from C++ — they still need
-  the NVSHMEM/IBGDA -> MSCCL++ PortChannel migration.
+* Internode HT (MSCCL++ PortChannel + MemoryChannel) is ported.
+* Internode low-latency kernels are ported structurally (NVSHMEM/IBGDA ->
+  MSCCL++ PortChannel) but **untested on multi-node H100**.
 """

 from __future__ import annotations
@@ -49,10 +50,11 @@ class Buffer:
     num_nvl_bytes:
         Size of the NVLink-accessible scratch buffer (shared via CUDA IPC).
     num_rdma_bytes:
-        Size of the RDMA scratch buffer. Must be 0 until internode/LL
-        support is landed.
+        Size of the RDMA scratch buffer. Required (>0) for internode HT and
+        low-latency modes.
     low_latency_mode:
-        Reserved — must be ``False`` until the LL path is ported.
+        Enable the low-latency dispatch/combine path (structural port,
+        untested).
     num_qps_per_rank:
         Ignored for intranode mode.
     """
@@ -68,13 +70,6 @@ class Buffer:
         low_latency_mode: bool = False,
         num_qps_per_rank: int = 12,
     ) -> None:
-        if low_latency_mode:
-            raise NotImplementedError(
-                "mscclpp.ext.ep.Buffer: low-latency mode is not yet ported. "
-                "Set low_latency_mode=False. See src/ext/ep/README.md for the "
-                "migration plan."
-            )
-
         self.rank: int = group.rank()
         self.group_size: int = group.size()
         self.group = group
@@ -41,7 +41,6 @@ endif()
|
||||
|
||||
file(GLOB_RECURSE EP_SOURCES CONFIGURE_DEPENDS
|
||||
buffer.cc
|
||||
internode_stub.cc
|
||||
bindings.cpp
|
||||
kernels/*.cu
|
||||
)
|
||||
|
||||
@@ -18,17 +18,37 @@ A port of DeepEP's MoE `dispatch`/`combine` primitives into MSCCL++, targeting:
 | `intranode_combine` (NVLink) | ✅ ported |
 | `internode_dispatch` (NVLink+RDMA) | ✅ ported (pending H100 test) |
 | `internode_combine` (NVLink+RDMA) | ✅ ported (pending H100 test) |
-| `low_latency_dispatch` (pure RDMA) | ❌ stub |
-| `low_latency_combine` (pure RDMA) | ❌ stub |
-| Python frontend `mscclpp.ext.ep` | ✅ wraps HT paths |
+| `low_latency_dispatch` (pure RDMA) | ⚠️ structural port, untested |
+| `low_latency_combine` (pure RDMA) | ⚠️ structural port, untested |
+| `Connection::atomicAdd` API | ✅ cherry-picked into mscclpp |
+| Python frontend `mscclpp.ext.ep` | ✅ wraps HT + LL paths |
 | pybind11 module `mscclpp_ep_cpp` | ✅ builds conditionally |

 Internode HT is code-complete but unverified on real hardware — the
 `sync()` path replaces DeepEP's NVSHMEM symmetric-heap allocation with
 `cudaMalloc` + `bootstrap->barrier()`, and the kernels use the new
 `PortChannelDeviceHandle::atomicAdd` instead of the old raw-trigger
-pattern. The low-latency path is the only remaining stub.
+pattern.
+
+The low-latency port is **structural**: the DeepEP LL kernels (pure
+IBGDA) have been mechanically translated to MSCCL++ port-channel ops.
+Semantic mapping:
+
+| DeepEP / IBGDA | MSCCL++ replacement |
+|------------------------------------------|--------------------------------------------------------------|
+| `nvshmemx_barrier_all_block()` | signal+wait ring across `port_channel_handles[peer_rank]` |
+| `nvshmemi_ibgda_put_nbi_warp(...)` | lane-0 `port_channel_handles[qp*N+dst].put(dst_off, src_off, n)` |
+| `nvshmemi_ibgda_amo_nonfetch_add(...)` | lane-0 `port_channel_handles[qp*N+dst].atomicAdd(off, int64)` |
+
+**Known limitations**:
+
+* LL performance will NOT match IBGDA — the MSCCL++ port channel uses a
+  CPU proxy. The port is for functional parity, not latency.
+* `Buffer::sync()` in `low_latency_mode=True` only connects peers sharing
+  the same local GPU ID (DeepEP convention). LL kernels therefore assume
+  one-GPU-per-node topology, i.e. `num_ranks == num_rdma_ranks`. Running
+  with >1 GPU per node in LL mode will fail to reach cross-GPU peers.
+* Multi-node H100 validation is still pending.

 ## Build
@@ -58,7 +78,6 @@ src/ext/ep/
 ├── buffer.hpp / buffer.cc — host-side Buffer, sync(), dispatch/combine
 ├── config.hpp / event.hpp — Config, EventHandle
 ├── bindings.cpp — PYBIND11_MODULE definition
-├── internode_stub.cc — stubs for not-yet-ported LL launchers
 └── kernels/
     ├── api.cuh — host-callable kernel prototypes
     ├── configs.cuh — compile-time constants (GPU-only)
@@ -67,8 +86,9 @@ src/ext/ep/
     ├── launch.cuh — SETUP_LAUNCH_CONFIG / SWITCH_* macros
     ├── utils.cuh — device inline helpers
     ├── runtime.cu — intranode::barrier launcher
-    ├── intranode_kernel.cu — notify_dispatch / dispatch / combine kernels
-    └── internode_layout.cu — get_dispatch_layout (CPU-safe subset)
+    ├── intranode_kernel.cu — intranode dispatch/combine kernels
+    ├── internode.cu — internode HT dispatch/combine + layout
+    └── internode_ll.cu — internode LL dispatch/combine (structural)

 python/mscclpp/ext/ep/
 ├── __init__.py — reexports Buffer / Config / EventHandle
@@ -1126,24 +1126,193 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
   return {combined_x, combined_topk_weights, event};
 }

-void Buffer::clean_low_latency_buffer(int, int, int) {
-  throw std::runtime_error("mscclpp::ep::Buffer::clean_low_latency_buffer: not yet ported");
+void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
+  EP_HOST_ASSERT(low_latency_mode);
+
+  auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
+  auto clean_meta_0 = layout.buffers[0].clean_meta();
+  auto clean_meta_1 = layout.buffers[1].clean_meta();
+
+  auto check_boundary = [=](void* ptr, size_t num_bytes) {
+    auto offset = reinterpret_cast<int64_t>(ptr) - reinterpret_cast<int64_t>(rdma_buffer_ptr);
+    EP_HOST_ASSERT(0 <= offset and offset + static_cast<int64_t>(num_bytes) <= num_rdma_bytes);
+  };
+  check_boundary(clean_meta_0.first, clean_meta_0.second * sizeof(int));
+  check_boundary(clean_meta_1.first, clean_meta_1.second * sizeof(int));
+
+  internode_ll::clean_low_latency_buffer(clean_meta_0.first, clean_meta_0.second,
+                                         clean_meta_1.first, clean_meta_1.second,
+                                         rank, num_ranks,
+                                         port_channel_handles_device_ptr.get(),
+                                         at::cuda::getCurrentCUDAStream());
 }

 std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
-Buffer::low_latency_dispatch(const torch::Tensor&, const torch::Tensor&, int, int, bool, bool, bool) {
-  throw std::runtime_error("mscclpp::ep::Buffer::low_latency_dispatch: not yet ported (needs NVSHMEM/IBGDA -> MSCCL++ migration)");
+Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
+                             int num_max_dispatch_tokens_per_rank, int num_experts,
+                             bool use_fp8, bool async, bool return_recv_hook) {
+  EP_HOST_ASSERT(low_latency_mode);
+
+  EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);
+  EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % 128 == 0);
+  EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());
+  EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank);
+  EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64);
+  EP_HOST_ASSERT(num_experts % num_ranks == 0);
+
+  auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
+  auto num_scales = hidden / 128, num_topk = static_cast<int>(topk_idx.size(1));
+  int num_local_experts = num_experts / num_ranks;
+
+  LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
+  EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
+  auto buffer = layout.buffers[low_latency_buffer_idx];
+  auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
+
+  auto compute_stream = at::cuda::getCurrentCUDAStream();
+  auto launch_stream = return_recv_hook ? compute_stream : comm_stream;
+  EP_HOST_ASSERT(not (async and return_recv_hook));
+  if (not return_recv_hook)
+    stream_wait(launch_stream, compute_stream);
+
+  auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
+                                    x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16));
+  auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
+  auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
+  auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
+
+  auto packed_recv_x_scales = std::optional<torch::Tensor>();
+  float* packed_recv_x_scales_ptr = nullptr;
+  if (use_fp8) {
+    EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
+    packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
+    packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
+    packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr<float>();
+  }
+
+  auto next_clean_meta = next_buffer.clean_meta();
+  auto port_handles = port_channel_handles_device_ptr.get();
+  auto rdma_base = rdma_buffer_ptr;
+  auto launcher = [=](int phases) {
+    internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
+                           packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
+                           packed_recv_count.data_ptr<int>(),
+                           buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
+                           buffer.dispatch_rdma_send_buffer,
+                           x.data_ptr(), topk_idx.data_ptr<int64_t>(),
+                           next_clean_meta.first, next_clean_meta.second,
+                           num_tokens, hidden, num_max_dispatch_tokens_per_rank,
+                           num_topk, num_experts, rank, num_ranks, use_fp8,
+                           workspace, launch_stream, phases,
+                           rdma_base, port_handles);
+  };
+  launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
+
+  std::optional<EventHandle> event;
+  if (async) {
+    event = EventHandle(launch_stream);
+  } else if (not return_recv_hook) {
+    stream_wait(compute_stream, launch_stream);
+  }
+
+  std::optional<std::function<void()>> recv_hook = std::nullopt;
+  if (return_recv_hook)
+    recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
+
+  return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
 }

 std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
-Buffer::low_latency_combine(const torch::Tensor&, const torch::Tensor&, const torch::Tensor&,
-                            const torch::Tensor&, const torch::Tensor&,
-                            int, int, bool, bool, bool, const std::optional<torch::Tensor>&) {
-  throw std::runtime_error("mscclpp::ep::Buffer::low_latency_combine: not yet ported");
+Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
+                            const torch::Tensor& src_info, const torch::Tensor& layout_range,
+                            int num_max_dispatch_tokens_per_rank, int num_experts,
+                            bool zero_copy, bool async, bool return_recv_hook,
+                            const std::optional<torch::Tensor>& out) {
+  EP_HOST_ASSERT(low_latency_mode);
+
+  EP_HOST_ASSERT(x.dim() == 3 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);
+  EP_HOST_ASSERT(x.size(0) == num_experts / num_ranks);
+  EP_HOST_ASSERT(x.size(1) == num_ranks * num_max_dispatch_tokens_per_rank);
+  EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 and x.size(2) % 128 == 0);
+  EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());
+  EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) and topk_idx.size(1) == topk_weights.size(1));
+  EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64);
+  EP_HOST_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous());
+  EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank);
+  EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32);
+  EP_HOST_ASSERT(src_info.dim() == 2 and src_info.is_contiguous());
+  EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt32 and x.size(0) == src_info.size(0));
+  EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous());
+  EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64);
+  EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks);
+  auto hidden = static_cast<int>(x.size(2));
+  auto num_topk = static_cast<int>(topk_weights.size(1));
+  auto num_combined_tokens = static_cast<int>(topk_weights.size(0));
+
+  LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
+  EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
+  auto buffer = layout.buffers[low_latency_buffer_idx];
+  auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
+
+  auto compute_stream = at::cuda::getCurrentCUDAStream();
+  auto launch_stream = return_recv_hook ? compute_stream : comm_stream;
+  EP_HOST_ASSERT(not (async and return_recv_hook));
+  if (not return_recv_hook)
+    stream_wait(launch_stream, compute_stream);
+
+  torch::Tensor combined_x;
+  if (out.has_value()) {
+    EP_HOST_ASSERT(out->dim() == 2 and out->is_contiguous());
+    EP_HOST_ASSERT(out->size(0) == num_combined_tokens and out->size(1) == hidden);
+    EP_HOST_ASSERT(out->scalar_type() == x.scalar_type());
+    combined_x = out.value();
+  } else {
+    combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
+  }
+
+  auto next_clean_meta = next_buffer.clean_meta();
+  auto port_handles = port_channel_handles_device_ptr.get();
+  auto rdma_base = rdma_buffer_ptr;
+  auto launcher = [=](int phases) {
+    internode_ll::combine(combined_x.data_ptr(),
+                          buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer,
+                          buffer.combine_rdma_send_buffer,
+                          x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
+                          src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(),
+                          next_clean_meta.first, next_clean_meta.second,
+                          num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
+                          num_topk, num_experts, rank, num_ranks,
+                          workspace, launch_stream,
+                          phases, zero_copy,
+                          rdma_base, port_handles);
+  };
+  launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
+
+  std::optional<EventHandle> event;
+  if (async) {
+    event = EventHandle(launch_stream);
+  } else if (not return_recv_hook) {
+    stream_wait(compute_stream, launch_stream);
+  }
+
+  std::optional<std::function<void()>> recv_hook = std::nullopt;
+  if (return_recv_hook)
+    recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
+
+  return {combined_x, event, recv_hook};
 }

-torch::Tensor Buffer::get_next_low_latency_combine_buffer(int, int, int) {
-  throw std::runtime_error("mscclpp::ep::Buffer::get_next_low_latency_combine_buffer: not yet ported");
+torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
+  LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
+  auto buffer = layout.buffers[low_latency_buffer_idx];
+  auto dtype = torch::kBFloat16;
+  auto num_msg_elems = static_cast<int>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16));
+
+  EP_HOST_ASSERT(buffer.num_bytes_per_combine_msg % elementSize(torch::kBFloat16) == 0);
+  return torch::from_blob(buffer.combine_rdma_send_buffer_data_start,
+                          {num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
+                          {num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1},
+                          torch::TensorOptions().dtype(dtype).device(torch::kCUDA));
 }

 } // namespace ep
@@ -1,27 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
//
|
||||
// Placeholder launchers for the not-yet-ported internode HT and low-latency
|
||||
// kernels. `get_dispatch_layout` and `get_source_meta_bytes` ARE ported in
|
||||
// `kernels/internode_layout.cu`.
|
||||
|
||||
#include <stdexcept>
|
||||
|
||||
#include "kernels/api.cuh"
|
||||
|
||||
namespace mscclpp {
|
||||
namespace ep {
|
||||
|
||||
namespace internode_ll {
|
||||
|
||||
void clean_low_latency_buffer(int* /*clean_0*/, int /*n0*/, int* /*clean_1*/, int /*n1*/, cudaStream_t /*stream*/) {
|
||||
throw std::runtime_error(
|
||||
"mscclpp::ep::internode_ll::clean_low_latency_buffer: not yet ported. "
|
||||
"See nccl/contrib/nccl_ep/device/low_latency.cu and DeepEP "
|
||||
"csrc/kernels/internode_ll.cu for the reference implementation.");
|
||||
}
|
||||
|
||||
} // namespace internode_ll
|
||||
|
||||
} // namespace ep
|
||||
} // namespace mscclpp
|
||||
@@ -129,13 +129,42 @@ void combine(cudaDataType_t type,
 } // namespace internode

 // ===========================================================================
-// Internode low-latency (pure RDMA) kernels. Not ported yet.
+// Internode low-latency (pure RDMA) kernels. Ported from DeepEP
+// `csrc/kernels/internode_ll.cu` with NVSHMEM/IBGDA device ops replaced by
+// MSCCL++ port-channel primitives (`put`, `atomicAdd`, signal/wait barrier).
 // ===========================================================================
 namespace internode_ll {

-void clean_low_latency_buffer(int* clean_0, int num_clean_int_0, int* clean_1, int num_clean_int_1,
+void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
+                              int* clean_1, int num_clean_int_1,
+                              int rank, int num_ranks,
+                              mscclpp::PortChannelDeviceHandle* port_channel_handles,
                               cudaStream_t stream);

+void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
+              int* packed_recv_src_info, int64_t* packed_recv_layout_range,
+              int* packed_recv_count,
+              void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
+              const void* x, const int64_t* topk_idx,
+              int* next_clean, int num_next_clean_int,
+              int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
+              int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
+              void* workspace, cudaStream_t stream, int phases,
+              void* rdma_buffer_ptr,
+              mscclpp::PortChannelDeviceHandle* port_channel_handles);
+
+void combine(void* combined_x,
+             void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
+             const void* x, const int64_t* topk_idx, const float* topk_weights,
+             const int* src_info, const int64_t* layout_range,
+             int* next_clean, int num_next_clean_int,
+             int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
+             int num_topk, int num_experts, int rank, int num_ranks,
+             void* workspace, cudaStream_t stream,
+             int phases, bool zero_copy,
+             void* rdma_buffer_ptr,
+             mscclpp::PortChannelDeviceHandle* port_channel_handles);
+
 } // namespace internode_ll

 } // namespace ep
src/ext/ep/kernels/internode_ll.cu — new file, 595 lines
@@ -0,0 +1,595 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
//
|
||||
// Low-latency internode dispatch/combine kernels ported from DeepEP
|
||||
// `csrc/kernels/internode_ll.cu` (branch `chhwang/dev-atomic-add-cleanup`).
|
||||
//
|
||||
// NVSHMEM/IBGDA device calls are replaced with MSCCL++ PortChannel device
|
||||
// operations:
|
||||
//
|
||||
// nvshmemx_barrier_all_block() -> port-channel signal/wait ring
|
||||
// nvshmemi_ibgda_put_nbi_warp(...) -> port_channel.put(...) (lane 0)
|
||||
// nvshmemi_ibgda_amo_nonfetch_add(...) -> port_channel.atomicAdd(...)
|
||||
//
|
||||
// Addressing convention:
|
||||
// - `rdma_buffer_ptr` is the base of the locally-registered RDMA buffer.
|
||||
// - Remote counter/buffer pointers written by the kernel are virtual
|
||||
// addresses that alias the corresponding offset inside each peer's
|
||||
// symmetric RDMA buffer. MSCCL++ needs those as offsets; we derive them
|
||||
// via `ptr - rdma_buffer_ptr`.
|
||||
// - Port-channel layout built by `Buffer::sync()` in low-latency mode is
|
||||
// `handles[qp * num_peers + peer_idx]` where `peer_idx` is the dst rank's
|
||||
// position in the connected-peer map. In the recommended 1-GPU-per-node
|
||||
// LL topology, `peer_idx == dst_rank`; see src/ext/ep/README.md.
|
||||
//
|
||||
// WARNING: This port is untested on multi-node H100; performance will NOT
|
||||
// match IBGDA (host-proxy adds latency). Functional correctness needs
|
||||
// validation on real hardware.
|
||||
|
||||
#include "configs.cuh"
|
||||
#include "exception.cuh"
|
||||
#include "launch.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <mscclpp/port_channel_device.hpp>
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
namespace mscclpp {
|
||||
namespace ep {
|
||||
|
||||
namespace internode_ll {
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Device helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Pointer-to-offset helper for MSCCL++ port channels. Both src and dst
|
||||
// pointers passed to the DeepEP LL kernels are virtual addresses aliased into
|
||||
// the caller's symmetric RDMA buffer; MSCCL++ expects offsets.
|
||||
__device__ __forceinline__ uint64_t rdma_offset_of(uint64_t ptr, void* rdma_buffer_ptr) {
|
||||
return ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr);
|
||||
}
|
||||
|
||||
// Cross-rank barrier via port-channel signal/wait ring.
|
||||
// Uses port channel `qp=0` across all connected peers.
|
||||
__device__ __forceinline__ void port_channel_barrier_block(
|
||||
mscclpp::PortChannelDeviceHandle* port_channel_handles,
|
||||
int rank, int num_ranks) {
|
||||
const int tid = threadIdx.x;
|
||||
if (tid < num_ranks && tid != rank) {
|
||||
// Index: qp 0, peer = tid's rank (assumes peer_idx == rank in LL topology).
|
||||
port_channel_handles[tid].signal();
|
||||
port_channel_handles[tid].wait();
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// clean_low_latency_buffer
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
template <int kNumThreads> __launch_bounds__(kNumThreads, 1)
|
||||
__global__ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
|
||||
int* clean_1, int num_clean_int_1,
|
||||
mscclpp::PortChannelDeviceHandle* port_channel_handles,
|
||||
int rank, int num_ranks) {
|
||||
// Barrier before cleaning (in case of unfinished chunked EP)
|
||||
port_channel_barrier_block(port_channel_handles, rank, num_ranks);
|
||||
|
||||
// Clean
|
||||
auto thread_id = static_cast<int>(threadIdx.x);
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < num_clean_int_0; i += kNumThreads)
|
||||
clean_0[i] = 0;
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
|
||||
clean_1[i] = 0;
|
||||
|
||||
// Barrier after cleaning (make sure low-latency mode work fine)
|
||||
port_channel_barrier_block(port_channel_handles, rank, num_ranks);
|
||||
}
|
||||
|
||||
void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
|
||||
int* clean_1, int num_clean_int_1,
|
||||
int rank, int num_ranks,
|
||||
mscclpp::PortChannelDeviceHandle* port_channel_handles,
|
||||
cudaStream_t stream) {
|
||||
constexpr int kNumThreads = 256;
|
||||
|
||||
SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);
|
||||
LAUNCH_KERNEL(&cfg, clean_low_latency_buffer<kNumThreads>,
|
||||
clean_0, num_clean_int_0, clean_1, num_clean_int_1,
|
||||
port_channel_handles, rank, num_ranks);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// dispatch
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
template <bool kUseFP8, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
|
||||
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
|
||||
dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
|
||||
int* packed_recv_count,
|
||||
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
|
||||
const void* x, const int64_t* topk_idx,
|
||||
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_tokens, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
int phases,
|
||||
void* rdma_buffer_ptr,
|
||||
mscclpp::PortChannelDeviceHandle* port_channel_handles) {
|
||||
const auto sm_id = static_cast<int>(blockIdx.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
|
||||
const auto num_sms = static_cast<int>(gridDim.x);
|
||||
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
|
||||
const auto num_local_experts = num_experts / num_ranks;
|
||||
const auto warp_group_id = warp_id / kNumWarpsPerGroup;
|
||||
const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
|
||||
const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
|
||||
|
||||
// FP8 staffs
|
||||
constexpr int kNumPerChannels = 128;
|
||||
constexpr float kFP8Margin = 1e-4, kFP8Amax = 448, kFP8AmaxInv = 1.0f / 448.0f;
|
||||
const int num_scales = kHidden / kNumPerChannels;
|
||||
const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
|
||||
const size_t hidden_int4 = hidden_bytes / sizeof(int4);
|
||||
|
||||
// Message package: hidden data, FP8 scales, index at source
|
||||
using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
|
||||
const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
|
||||
const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
|
||||
EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
|
||||
|
||||
// Sending phase
|
||||
if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
|
||||
goto LOW_LATENCY_DISPATCH_RECV;
|
||||
|
    // Expert counts
    __shared__ int shared_num_tokens_sent_per_expert[kNumWarpGroups];

    if (warp_id < num_warps - 1) {
        constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
        EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
        EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization");
        const auto num_threads = (num_warps - 1) * 32;
        const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;

        for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
            const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
            const auto rdma_x_src_idx = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
            const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
            const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);

            auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
            thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;

            // FP8 cast
            #pragma unroll
            for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
                auto int4_value = __ldg(x_int4 + i);

                if (kUseFP8) {
                    auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
                    float fp32_values[kNumElemsPerRead];
                    float amax = kFP8Margin, scale, scale_inv;
                    #pragma unroll
                    for (int j = 0; j < kNumElemsPerRead; ++ j) {
                        fp32_values[j] = static_cast<float>(bf16_values[j]);
                        amax = fmaxf(amax, fabsf(fp32_values[j]));
                    }

                    EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
                    amax = half_warp_reduce_max(amax), scale = kFP8Amax / amax, scale_inv = amax * kFP8AmaxInv;
                    if (lane_id == 0 or lane_id == 16)
                        rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;

                    vec_t int2_value;
                    auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value);
                    #pragma unroll
                    for (int j = 0; j < kNumElemsPerRead; j += 2) {
                        float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
                        fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3);
                    }
                    rdma_x_vec[i] = int2_value;
                } else {
                    rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
                }
            }
            asm volatile("bar.sync 1, %0;" :: "r"(num_threads));

            // Issue sends
            if (dst_expert_idx >= 0) {
                int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
                slot_idx = __shfl_sync(0xffffffff, slot_idx, 0);
                const auto dst_rank = dst_expert_idx / num_local_experts;
                const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
                const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
                const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
                                     dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                                     rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                                     slot_idx * num_bytes_per_msg;
                if (dst_rank != rank) {
                    // MSCCL++ port-channel PUT (lane 0 issues one request).
                    if (lane_id == 0) {
                        const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
                        const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr);
                        port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank]
                            .put(dst_off, src_off, num_bytes_per_msg);
                    }
                    __syncwarp();
                } else {
                    const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
                    const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
                    UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
                }

                __syncwarp();
                lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
            }
        }
    } else if (warp_id == num_warps - 1) {
        EP_DEVICE_ASSERT(num_sms > 1);
        if (sm_id == 0) {
            // NOTE: DeepEP asserts `ibgda_get_state()->num_rc_per_pe >= num_local_experts`
            // here. The MSCCL++ port relies on Buffer::sync() provisioning enough QPs
            // (see `num_ib_connections_per_rank` / `num_port_channels_per_rank`).

            #pragma unroll
            for (int i = lane_id; i < num_next_clean_int; i += 32)
                next_clean[i] = 0;

            __syncwarp();
            #pragma unroll
            for (int i = lane_id; i < num_experts; i += 32)
                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
        }

        int expert_count[kNumWarpGroups] = {0};
        const auto expert_begin_idx = sm_id * kNumWarpGroups;
        const auto expert_end_idx = min(expert_begin_idx + kNumWarpGroups, num_experts);

        #pragma unroll 8
        for (int i = lane_id; i < num_tokens * num_topk; i += 32) {
            auto idx = static_cast<int>(__ldg(topk_idx + i));
            if (idx >= expert_begin_idx and idx < expert_end_idx)
                expert_count[idx - expert_begin_idx] ++;
        }

        #pragma unroll
        for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
            auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
            if (lane_id == 0) {
                shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
            }
        }
    }
    __syncthreads();

    // Issue count sends
    if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
        const auto dst_rank = responsible_expert_idx / num_local_experts;
        const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
        const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * kNumWarpGroups];

        while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
        if (dst_rank != rank) {
            auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank;
            const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(counter_ptr), rdma_buffer_ptr);
            port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank]
                .atomicAdd(off, static_cast<int64_t>(-num_tokens_sent - 1));
        } else {
            st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
        }

        atomic_counter_per_expert[responsible_expert_idx] = 0;
        atomic_finish_counter_per_expert[responsible_expert_idx] = 0;

        if (dst_rank == 0)
            packed_recv_count[dst_expert_local_idx] = 0;
    }
    __syncwarp();

    // Receiving phase
    LOW_LATENCY_DISPATCH_RECV:
    if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
        return;

    if (phases & LOW_LATENCY_SEND_PHASE)
        cg::this_grid().sync();

    if (responsible_expert_idx < num_experts) {
        const auto src_rank = responsible_expert_idx / num_local_experts;
        const auto local_expert_idx = responsible_expert_idx % num_local_experts;
        const auto rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) +
            local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
            src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
        const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
            local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
        const auto recv_x_scales = packed_recv_x_scales + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_scales;
        const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
        const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;

        __shared__ int shared_num_recv_tokens[kNumWarpGroups], shared_recv_token_begin_idx[kNumWarpGroups];

        int num_recv_tokens, recv_token_begin_idx;
        EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
        if (sub_warp_id == 1 and lane_id == 0) {
            while ((num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
            num_recv_tokens = -num_recv_tokens - 1;
            recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
            shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
            shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
            recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
        }
        asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32));
        num_recv_tokens = shared_num_recv_tokens[warp_group_id];
        recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];

        EP_DEVICE_ASSERT(num_scales <= 64);
        for (int i = sub_warp_id; i < num_recv_tokens; i += kNumWarpsPerGroup) {
            const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
            if (lane_id == 0)
                recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
            __syncwarp();

            const auto src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
            const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
            UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);

            if (kUseFP8) {
                const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
                const auto dst_scales = reinterpret_cast<float*>(recv_x_scales + recv_token_begin_idx + i);
                const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank;
                auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0;
                auto scale_1 = (lane_id + 32) < num_scales ? ld_nc_global(src_scales + lane_id + 32) : 0;
                lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f;
                (lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f;
            }
        }
    }
}

void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
              int* packed_recv_src_info, int64_t* packed_recv_layout_range,
              int* packed_recv_count,
              void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
              const void* x, const int64_t* topk_idx,
              int* next_clean, int num_next_clean_int,
              int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
              int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
              void* workspace, cudaStream_t stream, int phases,
              void* rdma_buffer_ptr,
              mscclpp::PortChannelDeviceHandle* port_channel_handles) {
    constexpr int kNumMaxTopK = 9;
    constexpr int kNumWarpsPerGroup = 10;
    constexpr int kNumWarpGroups = 3;
    EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");

    const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
    const auto num_sms = ceil_div(num_experts, kNumWarpGroups);
    EP_HOST_ASSERT(num_topk <= kNumMaxTopK);

    auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
    auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
    EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);

#define DISPATCH_LAUNCH_CASE(hidden_case) { \
    auto dispatch_func = use_fp8 ? dispatch<true, kNumWarpGroups, kNumWarpsPerGroup, hidden_case> : \
                                   dispatch<false, kNumWarpGroups, kNumWarpsPerGroup, hidden_case>; \
    LAUNCH_KERNEL(&cfg, dispatch_func, \
                  packed_recv_x, packed_recv_x_scales, \
                  packed_recv_src_info, packed_recv_layout_range, \
                  packed_recv_count, \
                  rdma_recv_x, rdma_recv_count, rdma_x, \
                  x, topk_idx, \
                  atomic_counter_per_expert, atomic_finish_counter_per_expert, \
                  next_clean, num_next_clean_int, \
                  num_tokens, num_max_dispatch_tokens_per_rank, \
                  num_topk, num_experts, rank, num_ranks, phases, \
                  rdma_buffer_ptr, port_channel_handles); } break

    SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
    SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}

// ---------------------------------------------------------------------------
// combine
// ---------------------------------------------------------------------------

template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk>
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
combine(void* combined_x,
        void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
        const void* x, const int64_t* topk_idx, const float* topk_weights,
        const int* src_info, const int64_t* layout_range,
        int* next_clean, int num_next_clean_int,
        int* atomic_clean_flag,
        int num_combined_tokens, int hidden, int num_topk,
        int num_max_dispatch_tokens_per_rank,
        int num_experts, int rank, int num_ranks,
        int phases, bool zero_copy,
        void* rdma_buffer_ptr,
        mscclpp::PortChannelDeviceHandle* port_channel_handles) {
    const auto sm_id = static_cast<int>(blockIdx.x);
    const auto num_sms = static_cast<int>(gridDim.x);
    const auto thread_id = static_cast<int>(threadIdx.x);
    const auto num_threads = static_cast<int>(blockDim.x);
    const auto warp_id = thread_id / 32, lane_id = get_lane_id();
    const auto num_local_experts = num_experts / num_ranks;
    const auto warp_group_id = warp_id / kNumWarpsPerGroup;
    const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
    const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;

    constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16);
    const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;

    constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16);
    EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");

    if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
        goto LOW_LATENCY_COMBINE_RECV;

    if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
        #pragma unroll
        for (int i = lane_id; i < num_next_clean_int; i += 32)
            next_clean[i] = 0;

        __syncwarp();
        if (lane_id == 0)
            atomic_add_release_global(atomic_clean_flag, num_experts);
    }

    if (responsible_expert_idx < num_experts) {
        const auto dst_rank = responsible_expert_idx / num_local_experts;
        const auto local_expert_idx = responsible_expert_idx % num_local_experts;
        const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
        const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
        const auto local_x = reinterpret_cast<const int4*>(x) +
            local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
        const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
        const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
            local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;

        int offset, num_tokens_to_send;
        unpack2(layout, num_tokens_to_send, offset);

        for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) {
            const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
            const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
            const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row);

            auto src_idx = __ldg(local_src_info + token_idx);
            const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
            const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
            if (dst_rank == rank) {
                const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
                UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
            } else {
                const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
                if (not zero_copy)
                    UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
                // MSCCL++ port-channel PUT.
                if (lane_id == 0) {
                    const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
                    const auto src_off = rdma_offset_of(static_cast<uint64_t>(buf_ptr), rdma_buffer_ptr);
                    port_channel_handles[local_expert_idx * num_ranks + dst_rank]
                        .put(dst_off, src_off, hidden * sizeof(nv_bfloat16));
                }
                __syncwarp();
            }
        }

        EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
        asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(kNumWarpsPerGroup * 32));
        if (sub_warp_id == 1 and lane_id == 0) {
            while (ld_acquire_global(atomic_clean_flag) == 0);
            if (dst_rank != rank) {
                auto* flag_ptr = rdma_recv_flag + global_expert_idx;
                const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(flag_ptr), rdma_buffer_ptr);
                port_channel_handles[local_expert_idx * num_ranks + dst_rank]
                    .atomicAdd(off, static_cast<int64_t>(1));
            } else {
                st_na_release(rdma_recv_flag + global_expert_idx, 1);
            }
            atomic_add_release_global(atomic_clean_flag, -1);
        }
        __syncwarp();
    }

LOW_LATENCY_COMBINE_RECV:
    if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
        return;

    if (responsible_expert_idx < num_experts) {
        EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
        if (sub_warp_id == 0 and lane_id == 0)
            while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0);
    }
    cg::this_grid().sync();

    EP_DEVICE_ASSERT(num_topk <= 32 and hidden_bf16_int4 <= num_threads);
    EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization");
    if (thread_id < hidden_bf16_int4) {
        for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
            int reg_topk_idx[kNumMaxTopk];
            float reg_topk_weights[kNumMaxTopk];
            #pragma unroll
            for (int i = 0; i < num_topk; ++ i) {
                reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
                reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
            }

            float combined_values[kNumElemsPerInt4] = {0.0f};
            #pragma unroll
            for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
                auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
                auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type);

                auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
                const auto x_bf16 = reinterpret_cast<nv_bfloat16*>(&x_vec);
                #pragma unroll
                for (int j = 0; j < kNumElemsPerInt4; ++ j)
                    combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
            }

            int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
            auto combined_bf16 = reinterpret_cast<nv_bfloat16*>(&combined_values);
            #pragma unroll
            for (int j = 0; j < kNumElemsPerInt4; ++ j)
                combined_bf16[j] = static_cast<nv_bfloat16>(combined_values[j]);
            (reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
        }
    }
}

void combine(void* combined_x,
             void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
             const void* x, const int64_t* topk_idx, const float* topk_weights,
             const int* src_info, const int64_t* layout_range,
             int* next_clean, int num_next_clean_int,
             int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
             int num_topk, int num_experts, int rank, int num_ranks,
             void* workspace, cudaStream_t stream,
             int phases, bool zero_copy,
             void* rdma_buffer_ptr,
             mscclpp::PortChannelDeviceHandle* port_channel_handles) {
    constexpr int kNumWarpsPerGroup = 10;
    constexpr int kNumWarpGroups = 3;
    constexpr int kNumMaxTopk = 9;

    const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
    const auto num_sms = ceil_div(num_experts, kNumWarpGroups);

    auto atomic_clean_flag = reinterpret_cast<int*>(workspace);
    EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
    EP_HOST_ASSERT(num_topk <= kNumMaxTopk);

#define COMBINE_LAUNCH_CASE(hidden_case) { \
    auto combine_func = combine<kNumWarpGroups, kNumWarpsPerGroup, hidden_case, kNumMaxTopk>; \
    LAUNCH_KERNEL(&cfg, combine_func, \
                  combined_x, \
                  rdma_recv_x, rdma_recv_flag, rdma_send_x, \
                  x, topk_idx, topk_weights, src_info, layout_range, \
                  next_clean, num_next_clean_int, \
                  atomic_clean_flag, \
                  num_combined_tokens, hidden, num_topk, \
                  num_max_dispatch_tokens_per_rank, \
                  num_experts, rank, num_ranks, \
                  phases, zero_copy, \
                  rdma_buffer_ptr, port_channel_handles); } break

    SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
    SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
}

} // namespace internode_ll
} // namespace ep
} // namespace mscclpp
@@ -33,19 +33,14 @@ def test_low_latency_size_hint():
     assert _cpp.get_low_latency_rdma_size_hint(128, 7168, 8, 256) > 0


-def test_low_latency_rejected():
-    # Low-latency (pure RDMA) path is not ported yet; Python frontend must
-    # refuse to construct a Buffer with low_latency_mode=True. We test the
-    # underlying C++ constructor directly so this does not depend on the
-    # full `mscclpp` Python package being installed.
+def test_low_latency_buffer_construct():
+    # Low-latency kernels are structurally ported. At construction time the
+    # C++ Buffer must accept low_latency_mode=True; runtime requires a real
+    # multi-node setup (see tests in tests/test_low_latency.py when ported).
     import torch

     if not torch.cuda.is_available():
         pytest.skip("CUDA not available")

-    # The C++ Buffer allows low_latency_mode at construction; the enforcement
-    # lives in the Python frontend (`mscclpp.ext.ep.buffer.Buffer.__init__`).
-    # Verify the C++ side does NOT reject it, so the guarantee sits at the
-    # Python layer where it belongs.
-    buf = _cpp.Buffer(rank=0, num_ranks=1, num_nvl_bytes=0, num_rdma_bytes=0, low_latency_mode=True)
-    assert not buf.is_available()
+    buf = _cpp.Buffer(rank=0, num_ranks=1, num_nvl_bytes=0, num_rdma_bytes=0, low_latency_mode=True)
+    assert not buf.is_available()