Phase 11: hybrid NVLink + RDMA LL dispatch (+70% throughput)

Inside the IBGDA template branch, runtime-check whether the host has
opened a CUDA IPC peer pointer for the destination rank. If yes, do
the send via NVLink (warp copy / st_na_release on the peer-mapped
pointer); else fall through to the existing port_put / rdma_write_inl8.

Host: in sync(), when low_latency_mode && num_rdma_ranks > 1 && IBGDA
is up, allgather rdma_buffer_ptr IPC handles and cudaIpcOpenMemHandle
only for same-node peers. Sparse pointer table is mirrored to GPU and
threaded into the launchers as peer_bases.

Kernel: per-peer branch added at all four RDMA send sites (dispatch
send-data, dispatch send-count, combine send-data, combine send-flag).
Recv-side polling is transport-agnostic and unchanged.

Result on 16-rank/2-node LL bench:
  baseline (IBGDA only):   38.7 / 39.4 GB/s
  Phase 11 hybrid:         65.9 / 67.0 GB/s   (+70%)
Now matches nccl-ep default-mode numbers (63-71 / 62-72 GB/s).
Validation max diff = 0.

Gated by MSCCLPP_EP_HYBRID_LL env (default on). Single-node LL is
untouched (num_rdma_ranks>1 gate).
Qinghua Zhou
2026-05-09 23:04:15 +00:00
parent 5f219b5cda
commit 4569c4e751
4 changed files with 275 additions and 34 deletions

View File

@@ -749,3 +749,81 @@ math becomes a weighted average dominated by the cheaper hops.
- Future work flagged: hybrid NVLink + RDMA LL dispatch (would close
the remaining gap to nccl-ep's default-mode numbers and is
architecturally well-defined).
---
## Phase 11 — Hybrid NVLink + RDMA LL Dispatch [POSITIVE, +70%]
**Hypothesis (from Phase 10):** the remaining gap to nccl-ep's default-mode
numbers is *architectural*: our LL kernel selects a single transport for
the entire kernel via the `kIpcPath` / `kIbgdaPath` template parameters,
so a 16-rank/2-node run forces all 15 destination peers through IBGDA
even though 7 of them sit on the same host. nccl-ep mixes per peer:
NVLink for intranode, RDMA for cross-node.
### Implementation
Goal: keep the kernel's IBGDA template branch (because the barrier and
recv-side logic are already wired for it) but, *inside* every per-peer
RDMA send site, runtime-check whether the host has supplied a
peer-mapped pointer for that destination rank. If yes, use a warp copy
over NVLink; if no, fall through to the existing `port_put` /
`rdma_write_inl8` code.
**Host (`buffer.cc` / `buffer.hpp`):**
- New buffer fields: `hybrid_ipc_handles`, `hybrid_peer_bases`
(`std::vector<void*>`, size `num_ranks`, sparse), `hybrid_peer_bases_gpu`
(GPU mirror), `hybrid_ll_ready`.
- After IBGDA setup succeeds in `sync()`, when
`low_latency_mode && num_rdma_ranks > 1`:
1. `cudaIpcGetMemHandle(rdma_buffer_ptr)` and `bootstrap->allGather`
across all ranks (cross-node entries are ignored).
2. `cudaIpcOpenMemHandle` only for peers `r` where
`r / NUM_MAX_NVL_PEERS == rank / NUM_MAX_NVL_PEERS && r != rank`.
3. Mirror the sparse pointer table to GPU.
- Launchers (`low_latency_dispatch` and `low_latency_combine`): when
`use_ibgda && hybrid_ll_ready`, override the otherwise-null `peer_bases`
argument with `hybrid_peer_bases_gpu` so the kernel can see the table.
- Gated by env `MSCCLPP_EP_HYBRID_LL` (default ON; set to `0` to disable).
- IPC handles cleaned up in destructor.
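
As a cross-check on the "7 intranode peers per rank" message mentioned in the
notes below, here is a minimal sketch of the step-2 same-node selection rule,
assuming the benchmark layout (16 ranks, `NUM_MAX_NVL_PEERS == 8`, so ranks 0-7
share node 0 and ranks 8-15 share node 1). The function name is illustrative
and not part of the diff:
```
// Sketch only: the same-node test from sync() step 2, evaluated for the
// 16-rank / 2-node bench with NUM_MAX_NVL_PEERS == 8.
int count_intranode_peers(int rank, int num_ranks, int num_max_nvl_peers) {
  const int my_node = rank / num_max_nvl_peers;
  int opened = 0;
  for (int r = 0; r < num_ranks; ++r) {
    if (r == rank) continue;                         // self: keep local rdma_buffer_ptr
    if (r / num_max_nvl_peers != my_node) continue;  // cross-node: entry stays nullptr
    ++opened;                                        // same-node: cudaIpcOpenMemHandle
  }
  return opened;  // == 7 for every rank in the layout above
}
```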
**Kernel (`internode_ll.cu`):** at all four RDMA send sites
(dispatch send-data, dispatch send-count, combine send-data,
combine send-flag), the following per-peer branch is prepended inside the
`if constexpr (kIbgdaPath)` block:
```
const bool use_ipc_for_peer =
    (peer_rdma_bases != nullptr) && (peer_rdma_bases[dst_rank] != nullptr);
if (use_ipc_for_peer) {
  <IPC warp copy / st_na_release on peer-mapped pointer>
} else {
  <existing port_put / rdma_write_inl8>
}
```
The recv-side spin-wait (`ld_acquire_sys_global`) is transport-agnostic
and unchanged. Branch is uniform across the warp — `dst_rank` is
broadcast-derived per warp group.
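One piece of plumbing the hunks below lean on but do not define is
`peer_ptr_of`. A minimal sketch of what it is assumed to compute, rebasing the
symmetric-buffer offset onto the destination rank's IPC mapping (signature
matches the call sites; the body is an assumption, not the actual helper):
```
// Assumed shape of peer_ptr_of (definition not shown in this diff): take the
// byte offset of `ptr` within the local symmetric rdma_buffer_ptr region and
// rebase it onto the destination rank's IPC-mapped copy of the same buffer.
__device__ __forceinline__ void* peer_ptr_of(uint64_t ptr, void* const* peer_rdma_bases,
                                             const void* rdma_buffer_ptr, int dst_rank) {
  const uint64_t off = ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr);
  return reinterpret_cast<void*>(reinterpret_cast<uint64_t>(peer_rdma_bases[dst_rank]) + off);
}
```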
### Result (LL bench, 16 ranks across 2 nodes)
| config | dispatch GB/s | combine GB/s | notes |
| --- | --- | --- | --- |
| Phase 10 baseline (IBGDA only) | 38.7-39.4 | 39.4 | All 15 peers via RDMA |
| nccl-ep `--disable-nvlink` (RDMA only) | 42.3 | 42.5 | Apples-to-apples ceiling |
| nccl-ep default (mixed) | 63-71 | 62-72 | Mixed transports |
| **mscclpp_ep Phase 11 hybrid** | **65.85-65.97** | **66.78-67.06** | 7/15 peers NVLink |
Three consecutive runs all land in this band. **+~70% over baseline**, now
matching nccl-ep's default-mode numbers within run-to-run noise.
Validation `max|got-expected| = 0.0000e+00` on every run.
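A rough plausibility check on the magnitude, assuming the cross-node NIC is
the binding resource and the NVLink copies are comparatively cheap: the
baseline pushes all 15 peers' payload through RDMA, the hybrid only 8 of 15,
so the NIC-bound send time shrinks to about 8/15 of baseline, an upper bound
of roughly 15/8 ≈ 1.9x. The measured ~1.7x sits just under that bound, which
is consistent with the NVLink copies and fixed kernel overheads not being
entirely free.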
### Notes
- The `[mscclpp_ep] Hybrid LL ready: 7 intranode peers per rank …` line
is printed once on rank 0 during `sync()` — useful as a confidence
marker in benchmark logs.
- Single-node LL is unaffected: the `num_rdma_ranks > 1` gate keeps the
existing `ll_ipc_ready` fast path in charge there.
- C8 (intranode LL audit) and C9 (health metric in bench output) remain
open but are no longer urgent — the architectural gap that motivated
them is closed.

View File

@@ -194,6 +194,18 @@ Buffer::~Buffer() noexcept(false) {
}
}
// Hybrid LL IPC fast-path teardown (Phase 11).
if (hybrid_ll_ready) {
for (int i = 0; i < (int)hybrid_peer_bases.size(); ++i) {
if (i == rank || hybrid_peer_bases[i] == nullptr) continue;
CUDA_CHECK(cudaIpcCloseMemHandle(hybrid_peer_bases[i]));
}
if (hybrid_peer_bases_gpu != nullptr) {
CUDA_CHECK(cudaFree(hybrid_peer_bases_gpu));
hybrid_peer_bases_gpu = nullptr;
}
}
for (auto& ps : proxy_services) ps->stopProxy();
// Free cuBLAS handle, workspace and MoE counter
@@ -579,6 +591,81 @@ void Buffer::sync(const std::vector<int>& device_ids,
}
#endif // MSCCLPP_EP_HAVE_IBGDA
// ------------------------------------------------------------------
// Phase 11 — Hybrid LL fast path setup (multi-node, IBGDA + IPC).
//
// When LL is running cross-node with IBGDA enabled, also open CUDA IPC
// peer pointers for each same-node neighbor's rdma_buffer_ptr. The LL
// kernel will then prefer NVLink for intranode peers and IBGDA for
// internode peers, matching nccl-ep's per-peer transport selection.
//
// The IPC ring is restricted to same-node peers (peer / NUM_MAX_NVL_PEERS
// == rdma_rank); cross-node entries stay nullptr in `hybrid_peer_bases`.
// The kernel checks for nullptr to decide IPC vs IBGDA per peer.
// ------------------------------------------------------------------
if (low_latency_mode && num_rdma_ranks > 1 && num_rdma_bytes > 0) {
bool hybrid_enabled = true;
if (const char* e = std::getenv("MSCCLPP_EP_HYBRID_LL")) {
if (e[0] != '\0' && std::atoi(e) == 0) hybrid_enabled = false;
}
#ifdef MSCCLPP_EP_HAVE_IBGDA
const bool have_ibgda = (ibgda_setup_ != nullptr);
#else
const bool have_ibgda = false;
#endif
if (hybrid_enabled && have_ibgda) {
try {
hybrid_ipc_handles.assign(num_ranks, cudaIpcMemHandle_t{});
hybrid_peer_bases.assign(num_ranks, nullptr);
// 1. Allgather rdma_buffer_ptr IPC handles to all ranks. Cross-node
// handles are useless to us but the bootstrap allgather is
// symmetric and easier than building a per-node sub-bootstrap.
CUDA_CHECK(cudaIpcGetMemHandle(&hybrid_ipc_handles[rank], rdma_buffer_ptr));
bootstrap->allGather(hybrid_ipc_handles.data(), sizeof(cudaIpcMemHandle_t));
// 2. Open IPC handles only for same-node peers.
const int my_node = rank / NUM_MAX_NVL_PEERS;
int opened = 0;
for (int r = 0; r < num_ranks; ++r) {
if (r == rank) {
hybrid_peer_bases[r] = rdma_buffer_ptr;
continue;
}
if (r / NUM_MAX_NVL_PEERS != my_node) continue; // cross-node
CUDA_CHECK(cudaIpcOpenMemHandle(&hybrid_peer_bases[r], hybrid_ipc_handles[r],
cudaIpcMemLazyEnablePeerAccess));
++opened;
}
// 3. Mirror to GPU.
CUDA_CHECK(cudaMalloc(&hybrid_peer_bases_gpu, sizeof(void*) * num_ranks));
CUDA_CHECK(cudaMemcpy(hybrid_peer_bases_gpu, hybrid_peer_bases.data(),
sizeof(void*) * num_ranks, cudaMemcpyHostToDevice));
hybrid_ll_ready = true;
if (rank == 0) {
printf("[mscclpp_ep] Hybrid LL ready: %d intranode peers per rank "
"(NVLink for same-node, IBGDA for cross-node)\n", opened);
fflush(stdout);
}
} catch (const std::exception& e) {
fprintf(stderr, "[mscclpp_ep][rank=%d] Hybrid LL setup failed, kernel will use IBGDA only: %s\n",
rank, e.what());
for (int i = 0; i < (int)hybrid_peer_bases.size(); ++i) {
if (i != rank && hybrid_peer_bases[i] != nullptr) {
(void)cudaIpcCloseMemHandle(hybrid_peer_bases[i]);
}
}
hybrid_peer_bases.clear();
hybrid_ipc_handles.clear();
if (hybrid_peer_bases_gpu) { cudaFree(hybrid_peer_bases_gpu); hybrid_peer_bases_gpu = nullptr; }
hybrid_ll_ready = false;
(void)cudaGetLastError();
}
}
}
// Ready to use
available = true;
}
@@ -1520,6 +1607,12 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_dev = nullptr;
const bool use_ibgda = false;
#endif
// Phase 11 hybrid: when IBGDA is active and same-node IPC pointers are
// available, hand them to the kernel so it can prefer NVLink for
// intranode peers. Single-node (use_ipc) keeps its existing path.
if (use_ibgda && hybrid_ll_ready && hybrid_peer_bases_gpu != nullptr) {
peer_bases = hybrid_peer_bases_gpu;
}
auto launcher = [=](int phases) {
internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, packed_recv_src_info.data_ptr<int>(),
packed_recv_layout_range.data_ptr<int64_t>(), packed_recv_count.data_ptr<int>(),
@@ -1685,6 +1778,10 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_dev = nullptr;
const bool use_ibgda = false;
#endif
// Phase 11 hybrid: see low_latency_dispatch().
if (use_ibgda && hybrid_ll_ready && hybrid_peer_bases_gpu != nullptr) {
peer_bases = hybrid_peer_bases_gpu;
}
auto launcher = [=](int phases) {
internode_ll::combine(
combined_x.data_ptr(), buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer,

View File

@@ -106,6 +106,28 @@ struct Buffer {
std::shared_ptr<mscclpp::MemoryChannelDeviceHandle> ll_memory_channel_handles_device_ptr;
bool ll_ipc_ready = false;
// ------------------------------------------------------------------
// Phase 11 — Hybrid LL fast path.
//
// In multi-node LL with IBGDA, also open CUDA IPC peer pointers for
// same-node neighbors so the kernel can prefer NVLink for intranode
// peers and IBGDA for internode peers (matching nccl-ep's behavior).
//
// `hybrid_peer_bases` is sparse: indexed by global rank, populated
// only for same-node peers (rank' / NUM_MAX_NVL_PEERS == rdma_rank
// && rank' != rank). Cross-node and self entries are nullptr; the
// kernel checks for nullptr to decide IPC vs IBGDA per peer.
//
// Built lazily in `sync()` when:
// - low_latency_mode && num_rdma_ranks > 1
// - env MSCCLPP_EP_USE_IBGDA=1 && IBGDA setup succeeds
// - env MSCCLPP_EP_HYBRID_LL is not set to "0"
// ------------------------------------------------------------------
std::vector<cudaIpcMemHandle_t> hybrid_ipc_handles;
std::vector<void*> hybrid_peer_bases; // size num_ranks; same-node entries non-null
void** hybrid_peer_bases_gpu = nullptr; // GPU array of size num_ranks
bool hybrid_ll_ready = false;
// ------------------------------------------------------------------
// Native IBGDA path (Stage 4b). Built lazily in `sync()` when env
// `MSCCLPP_EP_USE_IBGDA=1` is set AND the run is cross-node.

View File

@@ -291,15 +291,28 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis
} else {
#ifdef MSCCLPP_EP_LL_HAS_IBGDA
if constexpr (kIbgdaPath) {
// Native IBGDA: lane 0 issues one RDMA WRITE WQE directly.
// UNSIGNALED + no DB ring — the trailing count write below
// (signaled, rings DB) flushes the per-QP queue.
if (lane_id == 0) {
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr);
mscclpp::ibgda::port_put(ibgda_handles[dst_expert_local_idx * num_ranks + dst_rank], dst_off, src_off,
num_bytes_per_msg,
/*signal_cqe=*/false, /*ring_db=*/false);
// Phase 11 hybrid: prefer NVLink for same-node peers when
// host has supplied per-peer IPC bases. Cross-node peers fall
// through to the IBGDA RDMA WRITE path below.
const bool use_ipc_for_peer =
(peer_rdma_bases != nullptr) && (peer_rdma_bases[dst_rank] != nullptr);
if (use_ipc_for_peer) {
const auto peer_dst = peer_ptr_of(dst_ptr, peer_rdma_bases, rdma_buffer_ptr, dst_rank);
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(peer_dst);
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global,
st_na_global);
} else {
// Native IBGDA: lane 0 issues one RDMA WRITE WQE directly.
// UNSIGNALED + no DB ring — the trailing count write below
// (signaled, rings DB) flushes the per-QP queue.
if (lane_id == 0) {
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr);
mscclpp::ibgda::port_put(ibgda_handles[dst_expert_local_idx * num_ranks + dst_rank], dst_off,
src_off, num_bytes_per_msg,
/*signal_cqe=*/false, /*ring_db=*/false);
}
}
__syncwarp();
} else
@@ -381,16 +394,27 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis
} else {
#ifdef MSCCLPP_EP_LL_HAS_IBGDA
if constexpr (kIbgdaPath) {
// Single writer per (dst_expert_local_idx, rank) slot, so an
// 8-byte inline RDMA WRITE delivering the encoded count is
// semantically equivalent to atomicAdd from a zero-initialised
// remote slot. The receiver polls with ld_acquire_sys_global.
auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank;
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(counter_ptr), rdma_buffer_ptr);
const auto& ch = ibgda_handles[dst_expert_local_idx * num_ranks + dst_rank];
mscclpp::IbgdaRemoteMr r = ch.remote_mrs[ch.dst];
mscclpp::ibgda::rdma_write_inl8(ch.qp, static_cast<uint64_t>(static_cast<int64_t>(-num_tokens_sent - 1)),
r.addr + off, r.rkey_be);
// Phase 11 hybrid: same-node peer → store the count via the
// peer-mapped pointer (NVLink); cross-node peer → inline RDMA WRITE.
const bool use_ipc_for_peer =
(peer_rdma_bases != nullptr) && (peer_rdma_bases[dst_rank] != nullptr);
if (use_ipc_for_peer) {
auto peer_counter = reinterpret_cast<int64_t*>(
peer_ptr_of(reinterpret_cast<uint64_t>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank),
peer_rdma_bases, rdma_buffer_ptr, dst_rank));
st_na_release(peer_counter, static_cast<int64_t>(-num_tokens_sent - 1));
} else {
// Single writer per (dst_expert_local_idx, rank) slot, so an
// 8-byte inline RDMA WRITE delivering the encoded count is
// semantically equivalent to atomicAdd from a zero-initialised
// remote slot. The receiver polls with ld_acquire_sys_global.
auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank;
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(counter_ptr), rdma_buffer_ptr);
const auto& ch = ibgda_handles[dst_expert_local_idx * num_ranks + dst_rank];
mscclpp::IbgdaRemoteMr r = ch.remote_mrs[ch.dst];
mscclpp::ibgda::rdma_write_inl8(ch.qp, static_cast<uint64_t>(static_cast<int64_t>(-num_tokens_sent - 1)),
r.addr + off, r.rkey_be);
}
} else
#endif
{
@@ -686,16 +710,25 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com
} else {
#ifdef MSCCLPP_EP_LL_HAS_IBGDA
if constexpr (kIbgdaPath) {
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
if (not zero_copy)
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
if (lane_id == 0) {
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
const auto src_off = rdma_offset_of(static_cast<uint64_t>(buf_ptr), rdma_buffer_ptr);
// UNSIGNALED + no DB ring — trailing flag write drives the doorbell.
mscclpp::ibgda::port_put(ibgda_handles[local_expert_idx * num_ranks + dst_rank], dst_off, src_off,
hidden * sizeof(nv_bfloat16),
/*signal_cqe=*/false, /*ring_db=*/false);
// Phase 11 hybrid: same-node peer → NVLink warp copy.
const bool use_ipc_for_peer =
(peer_rdma_bases != nullptr) && (peer_rdma_bases[dst_rank] != nullptr);
if (use_ipc_for_peer) {
const auto peer_dst = peer_ptr_of(dst_ptr, peer_rdma_bases, rdma_buffer_ptr, dst_rank);
const auto peer_dst_int4 = reinterpret_cast<int4*>(peer_dst);
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, peer_dst_int4, x_int4, ld_nc_global, st_na_global);
} else {
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
if (not zero_copy)
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
if (lane_id == 0) {
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
const auto src_off = rdma_offset_of(static_cast<uint64_t>(buf_ptr), rdma_buffer_ptr);
// UNSIGNALED + no DB ring — trailing flag write drives the doorbell.
mscclpp::ibgda::port_put(ibgda_handles[local_expert_idx * num_ranks + dst_rank], dst_off, src_off,
hidden * sizeof(nv_bfloat16),
/*signal_cqe=*/false, /*ring_db=*/false);
}
}
__syncwarp();
} else
@@ -731,11 +764,22 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com
} else {
#ifdef MSCCLPP_EP_LL_HAS_IBGDA
if constexpr (kIbgdaPath) {
auto* flag_ptr = rdma_recv_flag + global_expert_idx;
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(flag_ptr), rdma_buffer_ptr);
const auto& ch = ibgda_handles[local_expert_idx * num_ranks + dst_rank];
mscclpp::IbgdaRemoteMr r = ch.remote_mrs[ch.dst];
mscclpp::ibgda::rdma_write_inl8(ch.qp, static_cast<uint64_t>(1), r.addr + off, r.rkey_be);
// Phase 11 hybrid: same-node peer → store the flag via the
// peer-mapped pointer (NVLink); cross-node peer → inline RDMA WRITE.
const bool use_ipc_for_peer =
(peer_rdma_bases != nullptr) && (peer_rdma_bases[dst_rank] != nullptr);
if (use_ipc_for_peer) {
auto peer_flag = reinterpret_cast<int64_t*>(peer_ptr_of(
reinterpret_cast<uint64_t>(rdma_recv_flag + global_expert_idx),
peer_rdma_bases, rdma_buffer_ptr, dst_rank));
st_na_release(peer_flag, static_cast<int64_t>(1));
} else {
auto* flag_ptr = rdma_recv_flag + global_expert_idx;
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(flag_ptr), rdma_buffer_ptr);
const auto& ch = ibgda_handles[local_expert_idx * num_ranks + dst_rank];
mscclpp::IbgdaRemoteMr r = ch.remote_mrs[ch.dst];
mscclpp::ibgda::rdma_write_inl8(ch.qp, static_cast<uint64_t>(1), r.addr + off, r.rkey_be);
}
} else
#endif
{