ext/ep: LL intra-node fast path via CUDA IPC + MemoryChannel

When all ranks live on the same host (num_rdma_ranks == 1), the LL
kernels now bypass PortChannel/IB-loopback entirely. In Buffer::sync()
we additionally:
  - allGather CUDA IPC handles for each rank's rdma_buffer_ptr and
    cudaIpcOpenMemHandle them into peer_rdma_bases[] (exchange pattern
    sketched after this list)
  - build per-peer MemoryChannels over CUDA IPC connections (tag=2)
    used only for the LL barrier ring
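
For reference, a minimal stand-alone sketch of the handle-exchange pattern
(not the Buffer code; `allgather_bytes` is an assumed bootstrap-style helper,
error checks omitted):

#include <cuda_runtime.h>
#include <cstddef>
#include <vector>

// Sketch only: export an IPC handle for the local buffer, all-gather the
// handles, then map every peer's buffer into this process.
std::vector<void*> map_peer_buffers(void* local_buf, int rank, int num_ranks,
                                    void (*allgather_bytes)(void* data, size_t bytes_per_rank)) {
  std::vector<cudaIpcMemHandle_t> handles(num_ranks);
  cudaIpcGetMemHandle(&handles[rank], local_buf);
  allgather_bytes(handles.data(), sizeof(cudaIpcMemHandle_t));

  std::vector<void*> peer_bases(num_ranks, nullptr);
  peer_bases[rank] = local_buf;  // self: keep the local pointer
  for (int r = 0; r < num_ranks; ++r) {
    if (r == rank) continue;
    // Same-host only; pair with cudaIpcCloseMemHandle() at teardown.
    cudaIpcOpenMemHandle(&peer_bases[r], handles[r], cudaIpcMemLazyEnablePeerAccess);
  }
  return peer_bases;
}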

The three LL kernels (clean / dispatch / combine) gain a kIpcPath
template parameter and two extra args (peer_rdma_bases,
memory_channel_handles). At each peer op (see the sketch after this list):
  - put -> peer-mapped warp copy over NVLink
  - atomicAdd-like flag store -> single-writer st_na_release on peer ptr
  - signal/wait barrier -> MemoryChannel signal/wait
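
A rough device-side sketch of the flag-store mapping (second item above),
with illustrative names: `publish_flag` is not in the tree, and a fence plus
plain store stands in for the repo's release-store helper.

#include <cstdint>

// Illustrative only. The single writer per flag slot makes a plain store
// sufficient; the fence orders the preceding NVLink payload copy before it.
template <bool kIpcPath>
__device__ __forceinline__ void publish_flag(int64_t* flag_in_local_layout, int64_t value,
                                             void* const* peer_bases, void* local_base,
                                             int peer_rank) {
  if constexpr (kIpcPath) {
    // Every rank lays out its buffer identically, so the local offset is
    // valid inside the peer's mapping as well.
    const auto off = reinterpret_cast<uintptr_t>(flag_in_local_layout)
                   - reinterpret_cast<uintptr_t>(local_base);
    auto* peer_flag = reinterpret_cast<volatile int64_t*>(
        reinterpret_cast<uintptr_t>(peer_bases[peer_rank]) + off);
    __threadfence_system();
    *peer_flag = value;
  } else {
    // Cross-node fallback: proxy-backed PortChannel atomicAdd (not shown).
  }
}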

Cross-node LL (num_rdma_ranks > 1) is untouched; the IPC setup block is
a no-op. The host launch wrappers select the variant via use_ipc_path.
Author: Qinghua Zhou
Date: 2026-04-23 21:10:39 +00:00
parent 906fa3c48f
commit b0eb5da53d
4 changed files with 306 additions and 66 deletions

@@ -118,6 +118,18 @@ Buffer::~Buffer() noexcept(false) {
// failed, so there is nothing to tear down.
}
// Intra-node LL IPC fast-path teardown.
if (ll_ipc_ready) {
for (int i = 0; i < num_ranks; ++i) {
if (i == rank or peer_rdma_bases[i] == nullptr) continue;
CUDA_CHECK(cudaIpcCloseMemHandle(peer_rdma_bases[i]));
}
if (peer_rdma_bases_gpu != nullptr) {
CUDA_CHECK(cudaFree(peer_rdma_bases_gpu));
peer_rdma_bases_gpu = nullptr;
}
}
proxy_service->stopProxy();
// Free cuBLAS handle, workspace and MoE counter
@@ -372,6 +384,80 @@ void Buffer::sync(const std::vector<int> &device_ids,
mscclpp::gpuMemcpy<mscclpp::PortChannelDeviceHandle>(
port_channel_handles_device_ptr.get(), port_channel_handles.data(), port_channel_handles.size(),
cudaMemcpyHostToDevice);
// ------------------------------------------------------------------
// Intra-node LL fast path setup.
//
// When all ranks sit on the same host (num_rdma_ranks == 1), LL dispatch
// and combine still go through `PortChannel` above — which internally
// uses the proxy service over IB loopback between different HCAs on
// this platform. That path is correct but slow (caps at ~170 GB/s vs.
// NVLink's multi-TB/s). We additionally set up CUDA-IPC peer pointers
// to each peer's `rdma_buffer_ptr` plus a set of per-peer MemoryChannels
// for a barrier ring. The LL kernels select this path at launch time.
// Cross-node LL is unaffected: this block is a no-op there.
// ------------------------------------------------------------------
if (low_latency_mode and num_rdma_ranks == 1) {
EP_HOST_ASSERT(num_ranks == num_nvl_ranks);
EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS);
// 1. Exchange CUDA IPC handles for rdma_buffer_ptr via bootstrap.
CUDA_CHECK(cudaIpcGetMemHandle(&rdma_ipc_handles[rank], rdma_buffer_ptr));
std::vector<cudaIpcMemHandle_t> all_rdma_handles(num_ranks);
all_rdma_handles[rank] = rdma_ipc_handles[rank];
bootstrap->allGather(all_rdma_handles.data(), sizeof(cudaIpcMemHandle_t));
peer_rdma_bases[rank] = rdma_buffer_ptr;
for (int r = 0; r < num_ranks; ++r) {
if (r == rank) continue;
rdma_ipc_handles[r] = all_rdma_handles[r];
CUDA_CHECK(cudaIpcOpenMemHandle(&peer_rdma_bases[r], rdma_ipc_handles[r],
cudaIpcMemLazyEnablePeerAccess));
}
CUDA_CHECK(cudaMalloc(&peer_rdma_bases_gpu, sizeof(void*) * NUM_MAX_NVL_PEERS));
CUDA_CHECK(cudaMemcpy(peer_rdma_bases_gpu, peer_rdma_bases,
sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));
// 2. Build MemoryChannels for the per-peer barrier ring. These use
// CUDA IPC connections (distinct tag from the existing port-channel
// machinery) so setup does not interfere with cross-node fallback.
constexpr int kLlIpcTag = 2;
auto rdma_mem_ipc = communicator->registerMemory(rdma_buffer_ptr, num_rdma_bytes, ipc_transport);
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remote_futures(num_ranks);
for (int r = 0; r < num_ranks; ++r) {
if (r == rank) continue;
communicator->sendMemory(rdma_mem_ipc, r, kLlIpcTag);
remote_futures[r] = communicator->recvMemory(r, kLlIpcTag);
}
std::vector<std::shared_ptr<mscclpp::Connection>> ll_ipc_conns(num_ranks);
{
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> conn_futures(num_ranks);
mscclpp::EndpointConfig cfg(ipc_transport);
for (int r = 0; r < num_ranks; ++r) {
if (r == rank) continue;
conn_futures[r] = communicator->connect(cfg, r, kLlIpcTag);
}
for (int r = 0; r < num_ranks; ++r) {
if (r == rank) continue;
ll_ipc_conns[r] = conn_futures[r].get();
}
}
std::vector<mscclpp::MemoryChannelDeviceHandle> ll_handles(num_ranks);
for (int r = 0; r < num_ranks; ++r) {
if (r == rank) continue;
auto sema = std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*communicator, ll_ipc_conns[r]);
ll_memory_channels.emplace_back(sema, remote_futures[r].get(), rdma_buffer_ptr);
ll_handles[r] = ll_memory_channels.back().deviceHandle();
}
ll_memory_channel_handles_device_ptr =
mscclpp::detail::gpuCallocShared<mscclpp::MemoryChannelDeviceHandle>(num_ranks);
mscclpp::gpuMemcpy<mscclpp::MemoryChannelDeviceHandle>(
ll_memory_channel_handles_device_ptr.get(), ll_handles.data(), num_ranks,
cudaMemcpyHostToDevice);
ll_ipc_ready = true;
}
}
// Ready to use
@@ -1175,6 +1261,9 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
clean_meta_1.first, clean_meta_1.second,
rank, num_ranks,
port_channel_handles_device_ptr.get(),
ll_memory_channel_handles_device_ptr ?
ll_memory_channel_handles_device_ptr.get() : nullptr,
ll_ipc_ready,
at::cuda::getCurrentCUDAStream());
}
@@ -1223,6 +1312,10 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto next_clean_meta = next_buffer.clean_meta();
auto port_handles = port_channel_handles_device_ptr.get();
auto mem_handles = ll_memory_channel_handles_device_ptr ?
ll_memory_channel_handles_device_ptr.get() : nullptr;
auto peer_bases = peer_rdma_bases_gpu;
const bool use_ipc = ll_ipc_ready;
auto rdma_base = rdma_buffer_ptr;
auto launcher = [=](int phases) {
internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
@@ -1235,7 +1328,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks, use_fp8,
workspace, launch_stream, phases,
rdma_base, port_handles);
rdma_base, port_handles,
peer_bases, mem_handles, use_ipc);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
@@ -1303,6 +1397,10 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
auto next_clean_meta = next_buffer.clean_meta();
auto port_handles = port_channel_handles_device_ptr.get();
auto mem_handles = ll_memory_channel_handles_device_ptr ?
ll_memory_channel_handles_device_ptr.get() : nullptr;
auto peer_bases = peer_rdma_bases_gpu;
const bool use_ipc = ll_ipc_ready;
auto rdma_base = rdma_buffer_ptr;
auto launcher = [=](int phases) {
internode_ll::combine(combined_x.data_ptr(),
@@ -1315,7 +1413,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
num_topk, num_experts, rank, num_ranks,
workspace, launch_stream,
phases, zero_copy,
rdma_base, port_handles);
rdma_base, port_handles,
peer_bases, mem_handles, use_ipc);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));

@@ -82,6 +82,18 @@ private:
std::shared_ptr<mscclpp::PortChannelDeviceHandle> port_channel_handles_device_ptr;
std::shared_ptr<mscclpp::MemoryChannelDeviceHandle> memory_channel_handles_device_ptr;
// Intra-node LL only: peer-mapped RDMA buffer pointers (CUDA IPC).
// ``peer_rdma_bases[r]`` aliases rank ``r``'s ``rdma_buffer_ptr`` via
// ``cudaIpcOpenMemHandle`` (lazy peer access). Populated in ``sync()`` when
// ``low_latency_mode && num_rdma_ranks == 1``; null otherwise.
cudaIpcMemHandle_t rdma_ipc_handles[NUM_MAX_NVL_PEERS];
void* peer_rdma_bases[NUM_MAX_NVL_PEERS] = {nullptr};
void** peer_rdma_bases_gpu = nullptr;
// MemoryChannels over CUDA IPC used only for the LL barrier ring.
std::vector<mscclpp::MemoryChannel> ll_memory_channels;
std::shared_ptr<mscclpp::MemoryChannelDeviceHandle> ll_memory_channel_handles_device_ptr;
bool ll_ipc_ready = false;
private:
void move_fifo_slots(int num_slots = 1);

@@ -139,6 +139,8 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
int64_t* clean_1, int num_clean_int_1,
int rank, int num_ranks,
mscclpp::PortChannelDeviceHandle* port_channel_handles,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
bool use_ipc_path,
cudaStream_t stream);
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
@@ -151,7 +153,10 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
void* workspace, cudaStream_t stream, int phases,
void* rdma_buffer_ptr,
mscclpp::PortChannelDeviceHandle* port_channel_handles);
mscclpp::PortChannelDeviceHandle* port_channel_handles,
void* const* peer_rdma_bases,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
bool use_ipc_path);
void combine(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
@@ -163,7 +168,10 @@ void combine(void* combined_x,
void* workspace, cudaStream_t stream,
int phases, bool zero_copy,
void* rdma_buffer_ptr,
mscclpp::PortChannelDeviceHandle* port_channel_handles);
mscclpp::PortChannelDeviceHandle* port_channel_handles,
void* const* peer_rdma_bases,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
bool use_ipc_path);
} // namespace internode_ll

@@ -33,6 +33,7 @@
#include <cooperative_groups.h>
#include <mscclpp/port_channel_device.hpp>
#include <mscclpp/memory_channel_device.hpp>
namespace cg = cooperative_groups;
@@ -52,6 +53,15 @@ __device__ __forceinline__ uint64_t rdma_offset_of(uint64_t ptr, void* rdma_buff
return ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr);
}
// Translate a local pointer aliased into our own rdma_buffer_ptr into the
// peer-mapped pointer for rank `peer_rank`. Only valid when the IPC fast path
// is active (`peer_rdma_bases != nullptr`).
__device__ __forceinline__ uint64_t peer_ptr_of(uint64_t local_ptr, void* const* peer_rdma_bases,
void* rdma_buffer_ptr, int peer_rank) {
const auto off = local_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr);
return reinterpret_cast<uint64_t>(peer_rdma_bases[peer_rank]) + off;
}
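// Example: a counter the destination expects at offset 0x100 into its buffer
// is addressed locally as rdma_buffer_ptr + 0x100 (identical layout on every
// rank) and written through peer_rdma_bases[peer_rank] + 0x100 on the IPC path.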
// Cross-rank barrier via port-channel signal/wait ring.
// Uses port channel `qp=0` across all connected peers.
__device__ __forceinline__ void port_channel_barrier_block(
@@ -66,17 +76,43 @@ __device__ __forceinline__ void port_channel_barrier_block(
__syncthreads();
}
// Same barrier but using MemoryChannels (CUDA IPC, NVLink). Used on the
// intra-node LL fast path.
__device__ __forceinline__ void memory_channel_barrier_block(
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
int rank, int num_ranks) {
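// One thread per peer: signal that peer's semaphore, then wait for the
// matching signal from it, so no rank leaves the barrier before every rank
// has entered.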
const int tid = threadIdx.x;
if (tid < num_ranks && tid != rank) {
memory_channel_handles[tid].signal();
memory_channel_handles[tid].wait();
}
__syncthreads();
}
template <bool kIpcPath>
__device__ __forceinline__ void ll_barrier_block(
mscclpp::PortChannelDeviceHandle* port_channel_handles,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
int rank, int num_ranks) {
if constexpr (kIpcPath) {
memory_channel_barrier_block(memory_channel_handles, rank, num_ranks);
} else {
port_channel_barrier_block(port_channel_handles, rank, num_ranks);
}
}
// ---------------------------------------------------------------------------
// clean_low_latency_buffer
// ---------------------------------------------------------------------------
template <int kNumThreads> __launch_bounds__(kNumThreads, 1)
template <int kNumThreads, bool kIpcPath> __launch_bounds__(kNumThreads, 1)
__global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
int64_t* clean_1, int num_clean_int_1,
mscclpp::PortChannelDeviceHandle* port_channel_handles,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
int rank, int num_ranks) {
// Barrier before cleaning (in case of unfinished chunked EP)
port_channel_barrier_block(port_channel_handles, rank, num_ranks);
ll_barrier_block<kIpcPath>(port_channel_handles, memory_channel_handles, rank, num_ranks);
// Clean
auto thread_id = static_cast<int>(threadIdx.x);
@@ -88,27 +124,35 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
clean_1[i] = 0;
// Barrier after cleaning (make sure low-latency mode work fine)
port_channel_barrier_block(port_channel_handles, rank, num_ranks);
ll_barrier_block<kIpcPath>(port_channel_handles, memory_channel_handles, rank, num_ranks);
}
void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
int64_t* clean_1, int num_clean_int_1,
int rank, int num_ranks,
mscclpp::PortChannelDeviceHandle* port_channel_handles,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
bool use_ipc_path,
cudaStream_t stream) {
constexpr int kNumThreads = 256;
SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);
LAUNCH_KERNEL(&cfg, clean_low_latency_buffer<kNumThreads>,
clean_0, num_clean_int_0, clean_1, num_clean_int_1,
port_channel_handles, rank, num_ranks);
if (use_ipc_path) {
LAUNCH_KERNEL(&cfg, (clean_low_latency_buffer<kNumThreads, true>),
clean_0, num_clean_int_0, clean_1, num_clean_int_1,
port_channel_handles, memory_channel_handles, rank, num_ranks);
} else {
LAUNCH_KERNEL(&cfg, (clean_low_latency_buffer<kNumThreads, false>),
clean_0, num_clean_int_0, clean_1, num_clean_int_1,
port_channel_handles, memory_channel_handles, rank, num_ranks);
}
}
// ---------------------------------------------------------------------------
// dispatch
// ---------------------------------------------------------------------------
template <bool kUseFP8, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
template <bool kUseFP8, bool kIpcPath, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
@@ -121,7 +165,9 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int num_topk, int num_experts, int rank, int num_ranks,
int phases,
void* rdma_buffer_ptr,
mscclpp::PortChannelDeviceHandle* port_channel_handles) {
mscclpp::PortChannelDeviceHandle* port_channel_handles,
void* const* peer_rdma_bases,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
@@ -214,14 +260,22 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
slot_idx * num_bytes_per_msg;
if (dst_rank != rank) {
// MSCCL++ port-channel PUT (lane 0 issues one request).
if (lane_id == 0) {
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr);
port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank]
.put(dst_off, src_off, num_bytes_per_msg);
if constexpr (kIpcPath) {
// Peer-mapped warp copy over NVLink (CUDA IPC).
const auto peer_dst = peer_ptr_of(dst_ptr, peer_rdma_bases, rdma_buffer_ptr, dst_rank);
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(peer_dst);
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
} else {
// MSCCL++ port-channel PUT (lane 0 issues one request).
if (lane_id == 0) {
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr);
port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank]
.put(dst_off, src_off, num_bytes_per_msg);
}
__syncwarp();
}
__syncwarp();
} else {
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
@@ -279,10 +333,19 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
if (dst_rank != rank) {
auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank;
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(counter_ptr), rdma_buffer_ptr);
port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank]
.atomicAdd(off, static_cast<int64_t>(-num_tokens_sent - 1));
if constexpr (kIpcPath) {
// Single writer per (dst_expert_local_idx, rank) slot, so a
// release-ordered store on the peer-mapped counter is enough.
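// The receiver polls this counter with an acquire load, which pairs with
// this release store to publish the preceding payload writes.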
auto peer_counter = reinterpret_cast<int64_t*>(peer_ptr_of(
reinterpret_cast<uint64_t>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank),
peer_rdma_bases, rdma_buffer_ptr, dst_rank));
st_na_release(peer_counter, static_cast<int64_t>(-num_tokens_sent - 1));
} else {
auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank;
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(counter_ptr), rdma_buffer_ptr);
port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank]
.atomicAdd(off, static_cast<int64_t>(-num_tokens_sent - 1));
}
} else {
st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
static_cast<int64_t>(-num_tokens_sent - 1));
@@ -367,7 +430,10 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
void* workspace, cudaStream_t stream, int phases,
void* rdma_buffer_ptr,
mscclpp::PortChannelDeviceHandle* port_channel_handles) {
mscclpp::PortChannelDeviceHandle* port_channel_handles,
void* const* peer_rdma_bases,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
bool use_ipc_path) {
constexpr int kNumMaxTopK = 9;
constexpr int kNumWarpsPerGroup = 10;
constexpr int kNumWarpGroups = 3;
@@ -382,19 +448,37 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
#define DISPATCH_LAUNCH_CASE(hidden_case) { \
auto dispatch_func = use_fp8 ? dispatch<true, kNumWarpGroups, kNumWarpsPerGroup, hidden_case> : \
dispatch<false, kNumWarpGroups, kNumWarpsPerGroup, hidden_case>; \
LAUNCH_KERNEL(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_count, \
rdma_recv_x, rdma_recv_count, rdma_x, \
x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, phases, \
rdma_buffer_ptr, port_channel_handles); } break
if (use_ipc_path) { \
auto dispatch_func = use_fp8 ? dispatch<true, true, kNumWarpGroups, kNumWarpsPerGroup, hidden_case> \
: dispatch<false, true, kNumWarpGroups, kNumWarpsPerGroup, hidden_case>; \
LAUNCH_KERNEL(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_count, \
rdma_recv_x, rdma_recv_count, rdma_x, \
x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, phases, \
rdma_buffer_ptr, port_channel_handles, \
peer_rdma_bases, memory_channel_handles); \
} else { \
auto dispatch_func = use_fp8 ? dispatch<true, false, kNumWarpGroups, kNumWarpsPerGroup, hidden_case> \
: dispatch<false, false, kNumWarpGroups, kNumWarpsPerGroup, hidden_case>; \
LAUNCH_KERNEL(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_count, \
rdma_recv_x, rdma_recv_count, rdma_x, \
x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, phases, \
rdma_buffer_ptr, port_channel_handles, \
peer_rdma_bases, memory_channel_handles); \
} } break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
@@ -405,7 +489,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
// combine
// ---------------------------------------------------------------------------
template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk>
template <bool kIpcPath, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk>
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
combine(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
@@ -418,7 +502,9 @@ combine(void* combined_x,
int num_experts, int rank, int num_ranks,
int phases, bool zero_copy,
void* rdma_buffer_ptr,
mscclpp::PortChannelDeviceHandle* port_channel_handles) {
mscclpp::PortChannelDeviceHandle* port_channel_handles,
void* const* peer_rdma_bases,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_sms = static_cast<int>(gridDim.x);
const auto thread_id = static_cast<int>(threadIdx.x);
@@ -474,17 +560,25 @@ combine(void* combined_x,
const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
} else {
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
if (not zero_copy)
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
// MSCCL++ port-channel PUT.
if (lane_id == 0) {
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
const auto src_off = rdma_offset_of(static_cast<uint64_t>(buf_ptr), rdma_buffer_ptr);
port_channel_handles[local_expert_idx * num_ranks + dst_rank]
.put(dst_off, src_off, hidden * sizeof(nv_bfloat16));
if constexpr (kIpcPath) {
// Peer-mapped warp copy over NVLink. `zero_copy` is irrelevant
// on this path because we skip the rdma_send staging buffer.
const auto peer_dst = peer_ptr_of(dst_ptr, peer_rdma_bases, rdma_buffer_ptr, dst_rank);
const auto peer_dst_int4 = reinterpret_cast<int4*>(peer_dst);
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, peer_dst_int4, x_int4, ld_nc_global, st_na_global);
} else {
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
if (not zero_copy)
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
// MSCCL++ port-channel PUT.
if (lane_id == 0) {
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
const auto src_off = rdma_offset_of(static_cast<uint64_t>(buf_ptr), rdma_buffer_ptr);
port_channel_handles[local_expert_idx * num_ranks + dst_rank]
.put(dst_off, src_off, hidden * sizeof(nv_bfloat16));
}
__syncwarp();
}
__syncwarp();
}
}
@@ -493,10 +587,17 @@ combine(void* combined_x,
if (sub_warp_id == 1 and lane_id == 0) {
while (ld_acquire_global(atomic_clean_flag) == 0);
if (dst_rank != rank) {
auto* flag_ptr = rdma_recv_flag + global_expert_idx;
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(flag_ptr), rdma_buffer_ptr);
port_channel_handles[local_expert_idx * num_ranks + dst_rank]
.atomicAdd(off, static_cast<int64_t>(1));
if constexpr (kIpcPath) {
auto peer_flag = reinterpret_cast<int64_t*>(peer_ptr_of(
reinterpret_cast<uint64_t>(rdma_recv_flag + global_expert_idx),
peer_rdma_bases, rdma_buffer_ptr, dst_rank));
st_na_release(peer_flag, static_cast<int64_t>(1));
} else {
auto* flag_ptr = rdma_recv_flag + global_expert_idx;
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(flag_ptr), rdma_buffer_ptr);
port_channel_handles[local_expert_idx * num_ranks + dst_rank]
.atomicAdd(off, static_cast<int64_t>(1));
}
} else {
st_na_release(rdma_recv_flag + global_expert_idx, static_cast<int64_t>(1));
}
@@ -561,7 +662,10 @@ void combine(void* combined_x,
void* workspace, cudaStream_t stream,
int phases, bool zero_copy,
void* rdma_buffer_ptr,
mscclpp::PortChannelDeviceHandle* port_channel_handles) {
mscclpp::PortChannelDeviceHandle* port_channel_handles,
void* const* peer_rdma_bases,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
bool use_ipc_path) {
constexpr int kNumWarpsPerGroup = 10;
constexpr int kNumWarpGroups = 3;
constexpr int kNumMaxTopk = 9;
@@ -574,18 +678,35 @@ void combine(void* combined_x,
EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
#define COMBINE_LAUNCH_CASE(hidden_case) { \
auto combine_func = combine<kNumWarpGroups, kNumWarpsPerGroup, hidden_case, kNumMaxTopk>; \
LAUNCH_KERNEL(&cfg, combine_func, \
combined_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
next_clean, num_next_clean_int, \
atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, \
num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
phases, zero_copy, \
rdma_buffer_ptr, port_channel_handles); } break
if (use_ipc_path) { \
auto combine_func = combine<true, kNumWarpGroups, kNumWarpsPerGroup, hidden_case, kNumMaxTopk>; \
LAUNCH_KERNEL(&cfg, combine_func, \
combined_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
next_clean, num_next_clean_int, \
atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, \
num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
phases, zero_copy, \
rdma_buffer_ptr, port_channel_handles, \
peer_rdma_bases, memory_channel_handles); \
} else { \
auto combine_func = combine<false, kNumWarpGroups, kNumWarpsPerGroup, hidden_case, kNumMaxTopk>; \
LAUNCH_KERNEL(&cfg, combine_func, \
combined_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
next_clean, num_next_clean_int, \
atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, \
num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
phases, zero_copy, \
rdma_buffer_ptr, port_channel_handles, \
peer_rdma_bases, memory_channel_handles); \
} } break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);