Phase 9: multi-NIC striping (NEGATIVE RESULT)

- Refactor IbgdaSetup to support N NICs per rank: each rank picks
  nic[i] = numa_base + (effective_base + i) % 4 (NUMA-aware).
- Channel c routes via ib_ctxs[c % num_nics]. d_local_mrs becomes a
  vector of size num_nics; d_remote_mrs becomes num_nics * num_ranks.
- New env MSCCLPP_EP_NUM_NICS (default 1 = identical to Phase 8 baseline).
- Debug env MSCCLPP_EP_NIC_DUP forces all ctxs onto the same NIC
  to isolate multi-IbCtx overhead from real multi-NIC routing.

Empirical findings on 2x8xH100 NDv5 with TOKENS=128 / TOPK=8:
- N=1: 38.7/39.4 GB/s (baseline preserved).
- N=2: 15.4/17.0 GB/s; N=4: 9.0/9.6 GB/s. Strictly monotonic regression in N.
- NIC_DUP=1, N=2: 38.6/39.0 GB/s — multi-PD overhead is zero.

Conclusion: regression is purely from PCIe topology — H100 has a single
efficient P2P path to its NUMA-affine NIC. Posting WRs to any other NIC
forces cross-PCIe-switch hops that dominate any bandwidth gain.

Multi-NIC plumbing is left in place behind opt-in env so the path can
be re-evaluated on different hardware. Single-NIC ceiling ~41 GB/s
stands.
This commit is contained in:
Qinghua Zhou
2026-05-09 20:22:00 +00:00
parent 9d729d795e
commit 04f047fc5a
4 changed files with 367 additions and 105 deletions


@@ -513,3 +513,166 @@ Remaining categories of attack, all architectural / non-trivial:
For the current 2×8×H100 NDR(50) setup with TOKENS=128 / TOPK=8 / BF16
hidden=7168, **the LL benchmark is closed**. Future commits should
either pivot to multi-NIC or to a different problem regime.
---
## Phase 9 — Multi-NIC striping (IN PROGRESS)
### 9.0 Latest finding: NIC-saturation diagnostic (May 9, 2026)
Forced all 8 master-node ranks onto a single NIC via
`MSCCLPP_EP_IB_DEVICE_OVERRIDE=0` (worker node ranks kept their NUMA
NICs because mpirun did not propagate the override). Result:
```
PASS
dispatch: avg=4832us per_rank_bw=2.85 GB/s agg_bw=45.6 GB/s
combine : avg=4708us per_rank_bw=2.92 GB/s agg_bw=46.8 GB/s
```
Per-rank BW collapsed 38.9 → 2.85 GB/s when 8 master ranks share 1 NIC.
That is a 13.6× drop and confirms unambiguously that **the per-rank
NIC line-rate IS the binding constraint** in the current 1-NIC-per-rank
configuration. The 8-NIC fabric is fully utilized at ~38.9 GB/s × 8 =
311 GB/s of node-aggregate payload BW, ≈ 78% of the 8 × 50 = 400 GB/s
node ceiling.
Implication: each NIC runs at ~78% of its raw line rate. Remaining headroom
at the node level is ≈ 89 GB/s = 11 GB/s/rank, iff we can perfectly redistribute
WRs across NICs. In a fully symmetric N-NIC stripe the per-NIC
aggregate demand is unchanged, so the gain comes from
**(a) breaking per-rank NIC affinity** (a single rank's last-WR
completion time can drop because its WRs scatter across N NICs and the
NIC scheduler interleaves them), and
**(b) recv-side multi-NIC ingress** (the destination GPU now drains
across N NICs instead of one).
### 9.1 Architectural inventory
- 8 mlx5_ib NICs/node, all PORT_ACTIVE, NDR (50 GB/s each).
- Current setup (`src/ext/ep/ibgda_setup.cc`): single `IbCtx` per rank.
All `num_channels × num_ranks` QPs created on that one NIC. Layout
`qps[c * num_ranks + peer]`, indexed in the kernel by
`(local_expert_idx, dst_rank)`.
- Kernel uses `ibgda_handles[le * num_ranks + dst_rank]` (4 distinct
channels actually exercised: `le ∈ [0..num_local_experts=4)`); the
default `num_ibgda_channels=16` is over-provisioned by 4×.
- `MSCCLPP_EP_IB_DEVICE_OVERRIDE` env exists for forcing a single NIC.
### 9.2 Plan
Stripe each rank's QPs across N NICs by channel index:
- New env `MSCCLPP_EP_NUM_NICS` (default 1 = current behavior).
- Each rank picks `nic[i] = (device_id + i) % 8` for i ∈ [0, N).
- Channel `c`'s QP lives on `nic[c % N]`.
- Refactor `build_ibgda_setup`:
- Allocate `N` `IbCtx` instances, one per NIC.
- Register the rdma buffer MR on each ctx (N lkey/rkey values).
- Register the sig MR on each ctx.
- `d_local_mrs` length becomes `N` (was 1); each handle's `src` =
nic_index_for_channel(c).
- `d_remote_mrs` length becomes `N × num_ranks` (was num_ranks);
handle's `dst` = nic_index_for_channel(c) * num_ranks + peer_rank.
- Allgather rdma/sig MR records of size `(addr, rkey[N])` instead of
`(addr, rkey)`.
- One CQ-poller thread polling all NICs round-robin.
- Kernel changes: NONE expected — the device-handle layout (channel ×
num_ranks) is unchanged, and each handle already references its own
local/remote MR entries via `src`/`dst` indices that the kernel
blindly threads through `port_put`.
### 9.3 Risks
- Cross-NUMA penalty: NIC `(device_id + 1) % 8` is on a different NUMA
domain than the GPU. PCIe traversal cost is real but small for IBGDA
(kernel writes the BF doorbell directly).
- Multiple PDs across IbCtx instances: signal MR's PD must match the
QP's PD per NIC. The current code grabs the PD off any QP — easy
to fix by making sig MR per-NIC.
- Doubled QP count on the host: 2 NICs × current QP count is still
small (16 ch × 16 ranks × 2 = 512 QPs/rank), well under HCA limits.
- Allgather record size grows; bootstrap is one-shot, no bench impact.
### 9.4 Execution
Step-by-step:
1. Refactor `IbgdaSetup` struct + `build_ibgda_setup` to support N NICs.
2. Wire `MSCCLPP_EP_NUM_NICS` env in `buffer.cc`.
3. Forward env in `nccl-tests/run_ll_mpirun.sh`.
4. Build + deploy. Sweep N ∈ {1, 2, 4, 8}. Compare to baseline 38.9/39.6.
5. If PASS, commit and update this doc with results.
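Steps 3-4 amount to a sweep of the following shape. A sketch only: the env names and script path are real, but `RUNNER` and its dry-run default are illustrative so the loop is safe to execute outside the cluster:

```shell
# Sweep MSCCLPP_EP_NUM_NICS over {1, 2, 4, 8}; run_ll_mpirun.sh is assumed
# to forward the env (step 3). Override RUNNER with the real launcher.
RUNNER="${RUNNER:-echo dry-run: nccl-tests/run_ll_mpirun.sh}"
for N in 1 2 4 8; do
  printf '=== MSCCLPP_EP_NUM_NICS=%s ===\n' "$N"
  MSCCLPP_EP_NUM_NICS="$N" $RUNNER
done
```

The same launcher with `MSCCLPP_EP_NIC_DUP=1 MSCCLPP_EP_NUM_NICS=2` gives the isolation run used in 9.5.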
### 9.5 Results — multi-NIC striping is NOT viable on NDv5 (REVERTED)
Implementation done as planned (channel `c` → NIC `(base + c) % 8`).
First sweep (no NUMA awareness):
| N | dispatch GB/s | combine GB/s | notes |
|--:|--:|--:|---|
| 1 | 38.61 | 39.59 | baseline preserved |
| 2 | 13.80 | 15.06 | -64% |
| 4 | (timeout) | (timeout) | hung at 60s |
| 8 | 3.32 | 3.43 | -91% |
Strictly monotonic regression in N. Hypothesis: cross-NUMA PCIe penalty
(NICs 0-3 belong to NUMA 0, 4-7 to NUMA 1, GPUs 0-3 NUMA 0, 4-7 NUMA 1).
Constrained the stripe to within-NUMA (4 NICs/NUMA, modular indexing
inside the NUMA group). Sweep:
| N (NUMA-aware) | dispatch GB/s | combine GB/s |
|---:|--:|--:|
| 1 | 38.73 | 39.41 |
| 2 | 15.39 | 17.06 |
| 3 | 12.22 | 13.13 |
| 4 | 9.06 | 9.56 |
Still bad. Within-NUMA is not enough. Even rank 0 using only ib0+ib1
(both NUMA 0) drops 2.5×.
Isolation test (`MSCCLPP_EP_NIC_DUP=1`, N=2 with both ctxs pointing at
the SAME mlx5_ib0):
```
PASS
dispatch: avg=356us per_rank_bw=38.61 GB/s
combine : avg=353us per_rank_bw=38.96 GB/s
```
**Multi-IbCtx / multi-PD overhead is zero.** Two `IbCtx` instances on
the same NIC perform identically to one. The regression therefore comes
exclusively from the GPU↔NIC PCIe routing when hitting a non-affine
NIC.
**Conclusion: PCIe topology limits multi-NIC striping on this hardware.**
The `nvidia-smi topo` matrix shows every GPU has a single "NODE" path
to its NUMA-affine NIC subset, but in practice the GPU's PCIe stack is
optimized only for its primary NIC. Posting RDMA WRs to any other NIC
forces cross-switch PCIe hops that dominate any bandwidth gain.
This rules out multi-NIC striping on NDv5/H100 with mlx5 + IBGDA. The
single-NIC ceiling of ~41 GB/s/rank stands.
### 9.6 Status
- The multi-NIC plumbing is left in place (`MSCCLPP_EP_NUM_NICS` env,
default 1) so the path can be re-evaluated on different hardware
(NDR2, NVL switches with multi-port HCAs, etc.) without re-implementing.
- Default behavior (N=1) is identical to pre-Phase-9.
- `MSCCLPP_EP_NIC_DUP=1` is left as a debug knob (forces all ctxs onto
the same NIC; useful to isolate multi-PD overhead from multi-NIC PCIe
cost).
### 9.7 Synthesis (post Phase 9)
The single-NIC line rate IS the binding constraint (proved in 9.0:
forcing all 8 master ranks onto ib0 drops per-rank BW to 2.85 GB/s, a
~13.6× fall from baseline). Per-NIC NDR(50) line rate ≈ 41 GB/s of
practical payload.
Multi-NIC striping is the ONLY remaining software lever, and it does
not work on NDv5 due to GPU↔NIC PCIe affinity.
**The LL benchmark on this hardware/problem is closed at ~38.7/39.4 GB/s
(94-96% of the ~41 GB/s practical single-NIC ceiling).** Further wins
require hardware uplift (NDR2 / 100 GB/s NIC, or platforms with multi-NIC
P2P like Grace-Hopper SuperChip).


@@ -551,14 +551,19 @@ void Buffer::sync(const std::vector<int>& device_ids,
int v = std::atoi(e);
if (v > 0) num_ibgda_channels = v;
}
int num_nics = 1;
if (const char* e = std::getenv("MSCCLPP_EP_NUM_NICS")) {
int v = std::atoi(e);
if (v >= 1 && v <= 8) num_nics = v;
}
const int kNumIbgdaChannels = num_ibgda_channels;
try {
ibgda_setup_ = mscclpp::ep::build_ibgda_setup(rank, num_ranks, /*ib_transport_index=*/device_id,
kNumIbgdaChannels, rdma_buffer_ptr,
static_cast<std::size_t>(num_rdma_bytes), bootstrap);
static_cast<std::size_t>(num_rdma_bytes), bootstrap, num_nics);
if (rank == 0) {
printf("[mscclpp_ep] IBGDA setup built: channels=%d num_ranks=%d (per-rank QPs=%d)\n",
kNumIbgdaChannels, num_ranks, kNumIbgdaChannels * (num_ranks - 1));
printf("[mscclpp_ep] IBGDA setup built: channels=%d num_ranks=%d num_nics=%d (per-rank QPs=%d)\n",
kNumIbgdaChannels, num_ranks, num_nics, kNumIbgdaChannels * (num_ranks - 1));
fflush(stdout);
}
// Clear any benign CUDA sticky error left by overlapping host/UAR


@@ -39,12 +39,12 @@ IbgdaSetup::~IbgdaSetup() {
// Tear down in reverse order of construction:
// - device_handles, d_local_mrs, d_remote_mrs: shared_ptr from
// gpuCallocShared, auto-freed via custom deleter.
// - resources / qps / rdma_mr: smart ptrs, auto-freed.
// - sig_mr (raw ibv_mr) + sig_slots / sig_seq (raw cudaMalloc): explicit.
if (sig_mr != nullptr) {
ibv_dereg_mr(sig_mr);
sig_mr = nullptr;
// - resources / qps / rdma_mrs: smart ptrs, auto-freed.
// - sig_mrs (raw ibv_mr) + sig_slots / sig_seq (raw cudaMalloc): explicit.
for (ibv_mr* m : sig_mrs) {
if (m) ibv_dereg_mr(m);
}
sig_mrs.clear();
if (sig_slots != nullptr) {
cudaFree(sig_slots);
sig_slots = nullptr;
@@ -76,52 +76,78 @@ constexpr int kIbgdaMaxSendWr = 8192;
std::unique_ptr<IbgdaSetup> build_ibgda_setup(int rank, int num_ranks, int ib_transport_index, int num_channels,
void* rdma_buffer_ptr, std::size_t num_rdma_bytes,
std::shared_ptr<TcpBootstrap> bootstrap) {
std::shared_ptr<TcpBootstrap> bootstrap, int num_nics) {
EP_HOST_ASSERT(rdma_buffer_ptr != nullptr);
EP_HOST_ASSERT(num_rdma_bytes > 0);
EP_HOST_ASSERT(num_ranks > 1);
EP_HOST_ASSERT(num_channels > 0);
EP_HOST_ASSERT(ib_transport_index >= 0 && ib_transport_index < 8);
EP_HOST_ASSERT(num_nics >= 1 && num_nics <= 8);
auto setup = std::make_unique<IbgdaSetup>();
setup->rank = rank;
setup->num_ranks = num_ranks;
setup->num_channels = num_channels;
setup->num_nics = num_nics;
// 1. Resolve IB device name and build the IbCtx.
// 1. Resolve IB device(s) and build IbCtx per NIC.
// `MSCCLPP_EP_IB_DEVICE_OVERRIDE` may force a specific IB transport
// index (0..7) for diagnostic NIC-affinity sweeps. Default = use the
// NUMA-affine NIC selected by the caller (== local rank on NDv5).
int effective_ib_index = ib_transport_index;
// With num_nics > 1, the additional NICs are picked starting from the
// base index and wrapping mod 8: nic[i] = (base + i) % 8.
int effective_base = ib_transport_index;
if (const char* e = std::getenv("MSCCLPP_EP_IB_DEVICE_OVERRIDE")) {
int v = std::atoi(e);
if (v >= 0 && v < 8) effective_ib_index = v;
if (e[0] != '\0') {
int v = std::atoi(e);
if (v >= 0 && v < 8) effective_base = v;
}
}
setup->ib_ctxs.resize(num_nics);
std::vector<int> nic_idx(num_nics);
// NUMA-aware striping: on NDv5, NICs 0-3 belong to NUMA 0 and 4-7 to NUMA 1.
// Crossing NUMA for IBGDA doorbells/PCIe DMA is ~3× slower (verified
// empirically). Constrain each rank's NIC stripe to its own 4-NIC NUMA
// group: nic[i] = numa_base + (effective_base + i) % 4.
// DEBUG: MSCCLPP_EP_NIC_DUP=1 forces all stripe slots to the SAME NIC
// (same as effective_base) — used to isolate multi-IbCtx overhead from
// actual multi-NIC routing cost.
const int numa_base = (effective_base / 4) * 4;
const int local_off = effective_base % 4;
bool nic_dup = false;
if (const char* e = std::getenv("MSCCLPP_EP_NIC_DUP"); e && e[0] != '\0' && std::atoi(e) > 0) {
nic_dup = true;
}
for (int n = 0; n < num_nics; ++n) {
nic_idx[n] = nic_dup ? effective_base : (numa_base + (local_off + n) % 4);
auto ib_transport = static_cast<Transport>(static_cast<int>(Transport::IB0) + nic_idx[n]);
std::string dev_name = getIBDeviceName(ib_transport);
setup->ib_ctxs[n] = std::make_unique<IbCtx>(dev_name);
EP_HOST_ASSERT(setup->ib_ctxs[n]->isMlx5() && "IBGDA requires an mlx5 NIC");
fprintf(stderr, "[mscclpp_ep] rank %d -> IB device[%d/%d] %s (transport_index=%d, numa_base=%d, dup=%d)\n",
rank, n, num_nics, dev_name.c_str(), nic_idx[n], numa_base, (int)nic_dup);
}
auto ib_transport = static_cast<Transport>(static_cast<int>(Transport::IB0) + effective_ib_index);
std::string dev_name = getIBDeviceName(ib_transport);
setup->ib_ctx = std::make_unique<IbCtx>(dev_name);
EP_HOST_ASSERT(setup->ib_ctx->isMlx5() && "IBGDA requires an mlx5 NIC");
fprintf(stderr, "[mscclpp_ep] rank %d -> IB device %s (transport_index=%d, override=%s)\n",
rank, dev_name.c_str(), effective_ib_index,
effective_ib_index == ib_transport_index ? "no" : "yes");
fflush(stderr);
// 2. Create QPs. Layout: qps[channel * num_ranks + peer].
// Self entries are nullptr.
auto nic_for_channel = [num_nics](int c) { return c % num_nics; };
// 2. Create QPs. Layout: qps[channel * num_ranks + peer]. Each QP lives on
// ib_ctxs[nic_for_channel(channel)]. Self entries are nullptr.
const int total_slots = num_channels * num_ranks;
setup->qps.resize(total_slots);
setup->resources.resize(total_slots);
for (int c = 0; c < num_channels; ++c) {
auto& ctx = setup->ib_ctxs[nic_for_channel(c)];
for (int r = 0; r < num_ranks; ++r) {
if (r == rank) continue;
auto qp = setup->ib_ctx->createQp(/*port=*/kIbgdaPort, /*gidIndex=*/kIbgdaGid,
/*maxSendCqSize=*/kIbgdaMaxSendWr * 2,
/*maxSendCqPollNum=*/64,
/*maxSendWr=*/kIbgdaMaxSendWr,
/*maxRecvWr=*/1,
/*maxWrPerSend=*/1,
/*noAtomic=*/true);
auto qp = ctx->createQp(/*port=*/kIbgdaPort, /*gidIndex=*/kIbgdaGid,
/*maxSendCqSize=*/kIbgdaMaxSendWr * 2,
/*maxSendCqPollNum=*/64,
/*maxSendWr=*/kIbgdaMaxSendWr,
/*maxRecvWr=*/1,
/*maxWrPerSend=*/1,
/*noAtomic=*/true);
setup->qps[c * num_ranks + r] = qp;
}
}
@@ -147,7 +173,8 @@ std::unique_ptr<IbgdaSetup> build_ibgda_setup(int rank, int num_ranks, int ib_tr
// 4. RTR/RTS each local QP using the matching peer-side QP info. The peer
// info we want is "rank r's QP for talking to us" == r's record indexed at
// [c * num_ranks + rank].
// [c * num_ranks + rank]. Both sides agree on the channel index and
// therefore on which NIC pair carries the connection.
for (int c = 0; c < num_channels; ++c) {
for (int r = 0; r < num_ranks; ++r) {
if (r == rank) continue;
@@ -168,73 +195,128 @@ std::unique_ptr<IbgdaSetup> build_ibgda_setup(int rank, int num_ranks, int ib_tr
}
}
// 6. Register the existing rdma_buffer_ptr as an MR on this IbCtx, then
// allgather (addr, rkey) so we can build per-peer remote_mrs entries.
setup->rdma_mr = setup->ib_ctx->registerMr(rdma_buffer_ptr, num_rdma_bytes);
IbgdaSetup::PeerMr my_rdma_mr{};
{
auto info = setup->rdma_mr->getInfo();
my_rdma_mr.addr = info.addr;
my_rdma_mr.rkey = info.rkey;
// 6. Register the rdma buffer as an MR on EACH NIC, then allgather
// (addr, rkey[N]) so we can build per-(nic, peer) remote_mrs entries.
setup->rdma_mrs.resize(num_nics);
std::vector<uint32_t> my_rdma_rkeys(num_nics);
uint64_t my_rdma_addr = 0;
for (int n = 0; n < num_nics; ++n) {
setup->rdma_mrs[n] = setup->ib_ctxs[n]->registerMr(rdma_buffer_ptr, num_rdma_bytes);
auto info = setup->rdma_mrs[n]->getInfo();
my_rdma_addr = info.addr; // same across NICs (single allocation)
my_rdma_rkeys[n] = info.rkey;
}
// Allgather record per rank: addr (8B) + N rkeys (4B each), packed.
const std::size_t rdma_rec_bytes = sizeof(uint64_t) + num_nics * sizeof(uint32_t);
std::vector<uint8_t> rdma_all(num_ranks * rdma_rec_bytes, 0);
{
uint8_t* p = rdma_all.data() + rank * rdma_rec_bytes;
std::memcpy(p, &my_rdma_addr, sizeof(uint64_t));
std::memcpy(p + sizeof(uint64_t), my_rdma_rkeys.data(), num_nics * sizeof(uint32_t));
}
bootstrap->allGather(rdma_all.data(), rdma_rec_bytes);
setup->peer_rdma.assign(num_nics * num_ranks, IbgdaSetup::PeerMr{});
for (int r = 0; r < num_ranks; ++r) {
const uint8_t* p = rdma_all.data() + r * rdma_rec_bytes;
uint64_t addr;
std::memcpy(&addr, p, sizeof(uint64_t));
for (int n = 0; n < num_nics; ++n) {
uint32_t rk;
std::memcpy(&rk, p + sizeof(uint64_t) + n * sizeof(uint32_t), sizeof(uint32_t));
auto& pr = setup->peer_rdma[n * num_ranks + r];
pr.addr = addr;
pr.rkey = rk;
}
}
setup->peer_rdma.assign(num_ranks, IbgdaSetup::PeerMr{});
setup->peer_rdma[rank] = my_rdma_mr;
bootstrap->allGather(setup->peer_rdma.data(), sizeof(IbgdaSetup::PeerMr));
// 7. Allocate signal slots: total_slots * 4 bytes on GPU. Layout:
// sig_slots[c * num_ranks + sender_peer] — when peer P sends signal() to
// us through channel c, it RDMA-WRITEs to *our* sig_slots[c * num_ranks + P].
// Each peer therefore needs to know our sig MR (addr+rkey) AND the offset
// within it that corresponds to (channel, sender=P) == "we are receiving
// from P". We allgather the base addr+rkey only; the offset is derivable.
// 7. Allocate signal slots and register the same buffer as MR on each NIC.
// Layout: sig_slots[c * num_ranks + sender_peer] — when peer P sends signal()
// to us through channel c, it RDMA-WRITEs to *our* sig_slots[c * num_ranks
// + P]. The base address is identical across NICs; rkey differs.
const std::size_t sig_bytes = std::size_t(total_slots) * sizeof(uint32_t);
CUDA_CHECK(cudaMalloc(&setup->sig_slots, sig_bytes));
CUDA_CHECK(cudaMemset(setup->sig_slots, 0, sig_bytes));
CUDA_CHECK(cudaMalloc(&setup->sig_seq, sig_bytes));
CUDA_CHECK(cudaMemset(setup->sig_seq, 0, sig_bytes));
// Use raw verbs for the signal MR (we need the rkey/lkey directly).
ibv_pd* pd = setup->qps[(rank == 0 ? 1 : 0)]->getRawQp()->pd; // any non-null QP shares the same pd
for (int c = 0; c < num_channels && pd == nullptr; ++c)
for (int r = 0; r < num_ranks && pd == nullptr; ++r)
if (auto& q = setup->qps[c * num_ranks + r]; q) pd = q->getRawQp()->pd;
EP_HOST_ASSERT(pd != nullptr);
setup->sig_mr = ibv_reg_mr(pd, setup->sig_slots, sig_bytes,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
if (!setup->sig_mr) {
throw std::runtime_error("ibv_reg_mr(sig_slots) failed errno=" + std::to_string(errno));
setup->sig_mrs.assign(num_nics, nullptr);
std::vector<uint32_t> my_sig_lkeys(num_nics);
std::vector<uint32_t> my_sig_rkeys(num_nics);
for (int n = 0; n < num_nics; ++n) {
// Pick any QP on this NIC to grab its PD.
ibv_pd* pd = nullptr;
for (int c = 0; c < num_channels && pd == nullptr; ++c) {
if (nic_for_channel(c) != n) continue;
for (int r = 0; r < num_ranks && pd == nullptr; ++r) {
if (auto& q = setup->qps[c * num_ranks + r]; q) pd = q->getRawQp()->pd;
}
}
EP_HOST_ASSERT(pd != nullptr);
ibv_mr* m = ibv_reg_mr(pd, setup->sig_slots, sig_bytes,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
if (!m) {
throw std::runtime_error("ibv_reg_mr(sig_slots) failed errno=" + std::to_string(errno));
}
setup->sig_mrs[n] = m;
my_sig_lkeys[n] = m->lkey;
my_sig_rkeys[n] = m->rkey;
}
IbgdaSetup::PeerMr my_sig_mr{};
my_sig_mr.addr = reinterpret_cast<uint64_t>(setup->sig_slots);
my_sig_mr.rkey = setup->sig_mr->rkey;
setup->peer_sig.assign(num_ranks, IbgdaSetup::PeerMr{});
setup->peer_sig[rank] = my_sig_mr;
bootstrap->allGather(setup->peer_sig.data(), sizeof(IbgdaSetup::PeerMr));
// 8. Build the GPU-resident MR tables. We have a single local MR (the
// rdma_buffer_ptr); remote_mrs has one entry per peer rank.
std::vector<IbgdaLocalMr> h_local(1);
h_local[0].addr = reinterpret_cast<uint64_t>(rdma_buffer_ptr);
h_local[0].lkey_be = htonl(setup->rdma_mr->getLkey());
h_local[0].pad = 0;
std::vector<IbgdaRemoteMr> h_remote(num_ranks);
uint64_t my_sig_addr = reinterpret_cast<uint64_t>(setup->sig_slots);
const std::size_t sig_rec_bytes = sizeof(uint64_t) + num_nics * sizeof(uint32_t);
std::vector<uint8_t> sig_all(num_ranks * sig_rec_bytes, 0);
{
uint8_t* p = sig_all.data() + rank * sig_rec_bytes;
std::memcpy(p, &my_sig_addr, sizeof(uint64_t));
std::memcpy(p + sizeof(uint64_t), my_sig_rkeys.data(), num_nics * sizeof(uint32_t));
}
bootstrap->allGather(sig_all.data(), sig_rec_bytes);
setup->peer_sig.assign(num_nics * num_ranks, IbgdaSetup::PeerMr{});
for (int r = 0; r < num_ranks; ++r) {
h_remote[r].addr = setup->peer_rdma[r].addr;
h_remote[r].rkey_be = htonl(setup->peer_rdma[r].rkey);
h_remote[r].pad = 0;
const uint8_t* p = sig_all.data() + r * sig_rec_bytes;
uint64_t addr;
std::memcpy(&addr, p, sizeof(uint64_t));
for (int n = 0; n < num_nics; ++n) {
uint32_t rk;
std::memcpy(&rk, p + sizeof(uint64_t) + n * sizeof(uint32_t), sizeof(uint32_t));
auto& ps = setup->peer_sig[n * num_ranks + r];
ps.addr = addr;
ps.rkey = rk;
}
}
setup->d_local_mrs = mscclpp::detail::gpuCallocShared<IbgdaLocalMr>(1);
setup->d_remote_mrs = mscclpp::detail::gpuCallocShared<IbgdaRemoteMr>(num_ranks);
mscclpp::gpuMemcpy<IbgdaLocalMr>(setup->d_local_mrs.get(), h_local.data(), 1, cudaMemcpyHostToDevice);
mscclpp::gpuMemcpy<IbgdaRemoteMr>(setup->d_remote_mrs.get(), h_remote.data(), num_ranks, cudaMemcpyHostToDevice);
// 8. Build the GPU-resident MR tables.
// d_local_mrs[n] = (rdma_buffer_ptr, lkey on NIC n) — one per NIC.
// d_remote_mrs[n*num_ranks + r] = (peer r's addr, peer r's rkey on NIC n).
std::vector<IbgdaLocalMr> h_local(num_nics);
for (int n = 0; n < num_nics; ++n) {
h_local[n].addr = reinterpret_cast<uint64_t>(rdma_buffer_ptr);
h_local[n].lkey_be = htonl(setup->rdma_mrs[n]->getLkey());
h_local[n].pad = 0;
}
std::vector<IbgdaRemoteMr> h_remote(num_nics * num_ranks);
for (int n = 0; n < num_nics; ++n) {
for (int r = 0; r < num_ranks; ++r) {
auto& pr = setup->peer_rdma[n * num_ranks + r];
auto& hr = h_remote[n * num_ranks + r];
hr.addr = pr.addr;
hr.rkey_be = htonl(pr.rkey);
hr.pad = 0;
}
}
setup->d_local_mrs = mscclpp::detail::gpuCallocShared<IbgdaLocalMr>(num_nics);
setup->d_remote_mrs = mscclpp::detail::gpuCallocShared<IbgdaRemoteMr>(num_nics * num_ranks);
mscclpp::gpuMemcpy<IbgdaLocalMr>(setup->d_local_mrs.get(), h_local.data(), num_nics, cudaMemcpyHostToDevice);
mscclpp::gpuMemcpy<IbgdaRemoteMr>(setup->d_remote_mrs.get(), h_remote.data(),
num_nics * num_ranks, cudaMemcpyHostToDevice);
// 9. Build the device handle array (channel × num_ranks).
std::vector<IbgdaPortChannelDeviceHandle> h_handles(total_slots);
std::memset(h_handles.data(), 0, h_handles.size() * sizeof(IbgdaPortChannelDeviceHandle));
for (int c = 0; c < num_channels; ++c) {
const int n = nic_for_channel(c);
for (int r = 0; r < num_ranks; ++r) {
auto& h = h_handles[c * num_ranks + r];
if (r == rank) continue; // self-slot left zeroed
@@ -244,23 +326,23 @@ std::unique_ptr<IbgdaSetup> build_ibgda_setup(int rank, int num_ranks, int ib_tr
// Local sig slot WE poll for messages from peer r through channel c.
h.sig_local_addr = &setup->sig_slots[c * num_ranks + r];
h.sig_local_lkey = htonl(setup->sig_mr->lkey);
h.sig_local_lkey = htonl(my_sig_lkeys[n]);
// Remote sig slot peer r polls for messages FROM US through channel c.
// Peer r's slot for "rank == us" is at offset (c * num_ranks + rank) * 4
// inside their signal buffer.
h.sig_remote_addr = setup->peer_sig[r].addr +
// inside their signal buffer. The rkey must be the one peer r registered
// on its own NIC n (matching channel-to-NIC pairing).
h.sig_remote_addr = setup->peer_sig[n * num_ranks + r].addr +
static_cast<uint64_t>(c * num_ranks + rank) * sizeof(uint32_t);
h.sig_rkey_be = htonl(setup->peer_sig[r].rkey);
h.sig_rkey_be = htonl(setup->peer_sig[n * num_ranks + r].rkey);
// Outbound seq counter is per-handle; carve out one u32 from sig_seq
// (sig_seq is num_channels × num_ranks u32s so we can reuse the same
// index function — it is unrelated to the inbound sig_slots buffer
// beyond reusing the size).
// Outbound seq counter is per-handle.
h.sig_seq = &setup->sig_seq[c * num_ranks + r];
h.dst = static_cast<uint32_t>(r); // index into remote_mrs[] (peer-rank-based)
h.src = 0; // single-entry local_mrs table
// src/dst index into the local/remote MR tables. With multi-NIC,
// src = nic, dst = nic * num_ranks + peer_rank.
h.src = static_cast<uint32_t>(n);
h.dst = static_cast<uint32_t>(n * num_ranks + r);
h.peer_rank = static_cast<uint32_t>(r);
h._pad = 0;
}
@@ -273,7 +355,8 @@ std::unique_ptr<IbgdaSetup> build_ibgda_setup(int rank, int num_ranks, int ib_tr
// 10. Spawn CQ drain thread. Kernel-issued rdma_write paths set CQ_UPDATE
// on every WR; without this drain, the send CQ (sized 2 × kIbgdaMaxSendWr)
// would fill within a few iterations of LL dispatch+combine and the QP
// would error out. We collect raw send_cq pointers from each QP up front.
// would error out. We collect raw send_cq pointers from each QP up front
// (across ALL NICs).
{
std::vector<ibv_cq*> send_cqs;
send_cqs.reserve(total_slots);


@@ -48,16 +48,20 @@ struct IbgdaSetup {
int num_ranks = 0;
int rank = 0;
// IB context + QPs + Stage-1 GPU mappings.
std::unique_ptr<IbCtx> ib_ctx;
// Phase 9 multi-NIC striping: each rank may use up to N NICs in parallel.
// Channel `c` uses NIC `c % num_nics`. With num_nics=1 (default) the layout
// and behavior are identical to pre-Phase-9.
int num_nics = 1;
// IB contexts (one per NIC) + QPs + Stage-1 GPU mappings.
std::vector<std::unique_ptr<IbCtx>> ib_ctxs; // size = num_nics
// Indexed [channel * num_ranks + peer]; entries with peer == rank are null.
std::vector<std::shared_ptr<IbQp>> qps;
std::vector<std::unique_ptr<IbgdaResources>> resources;
// RDMA buffer MR. We register the *same* `rdma_buffer_ptr` that the
// existing PortChannel path uses, so dst/src offsets in the kernel work
// unchanged.
std::unique_ptr<const IbMr> rdma_mr;
// RDMA buffer MR registered on each NIC. We register the *same*
// `rdma_buffer_ptr` on every IbCtx so that any QP can DMA from/to it.
std::vector<std::unique_ptr<const IbMr>> rdma_mrs; // size = num_nics
// Signal slots (GPU-resident).
// Layout: 4 bytes per (channel, peer) on the receiving side. `sig_slots`
@@ -69,18 +73,19 @@ struct IbgdaSetup {
// rank=A's POV indexed by channel × peer"). See the per-handle wiring
// in `IbgdaSetup::populate_handles`.
uint32_t* sig_slots = nullptr; // GPU mem; size = num_channels * num_ranks * 4
ibv_mr* sig_mr = nullptr; // raw verbs (we keep the IbMr only for rdma_mr)
std::vector<ibv_mr*> sig_mrs; // size = num_nics; raw verbs (one per IbCtx)
uint32_t* sig_seq = nullptr; // GPU mem; per-handle outbound counters
// Per-peer (addr, rkey) for the RDMA buffer and the signal buffer. Filled
// by an allgather over the bootstrap.
// Per-(NIC, peer) (addr, rkey) for the RDMA buffer and the signal buffer.
// Layout: peer_rdma[nic * num_ranks + peer]. The addr is the same across
// NICs (single allocation) but the rkey differs per ctx.
struct PeerMr {
uint64_t addr = 0;
uint32_t rkey = 0;
uint32_t pad = 0;
};
std::vector<PeerMr> peer_rdma; // size = num_ranks
std::vector<PeerMr> peer_sig; // size = num_ranks
std::vector<PeerMr> peer_rdma; // size = num_nics * num_ranks
std::vector<PeerMr> peer_sig; // size = num_nics * num_ranks
// Flat device-side handle array: num_channels * num_ranks entries.
// Self entries (peer == rank) are zeroed and unused.
@@ -88,8 +93,11 @@ struct IbgdaSetup {
// Underlying GPU-side MR-table arrays referenced by every device handle
// (see IbgdaPortChannelDeviceHandle::local_mrs / remote_mrs).
std::shared_ptr<IbgdaLocalMr> d_local_mrs; // length 1 (we have a single MR)
std::shared_ptr<IbgdaRemoteMr> d_remote_mrs; // length num_ranks
// d_local_mrs has one entry per NIC: handle.src = nic_index_for_channel(c).
// d_remote_mrs has num_nics * num_ranks entries:
// handle.dst = nic_index_for_channel(c) * num_ranks + peer_rank.
std::shared_ptr<IbgdaLocalMr> d_local_mrs; // length num_nics
std::shared_ptr<IbgdaRemoteMr> d_remote_mrs; // length num_nics * num_ranks
// CQ drain thread. The kernel-side rdma_write paths set CQ_UPDATE on
// every WR, so without periodic ibv_poll_cq() the send CQ would fill up
@@ -102,13 +110,16 @@ struct IbgdaSetup {
// Build the full IBGDA setup. Bootstrap is used for cross-rank exchange of
// QP info and MR keys; rdma_buffer_ptr/num_rdma_bytes is the same buffer the
// existing PortChannel path uses. `ib_transport_index` is the IB device
// index this rank will use (== `device_id` on NDv5).
// index this rank will use (== `device_id` on NDv5) — interpreted as the
// BASE NIC. With `num_nics > 1`, additional NICs are picked starting from
// `ib_transport_index` and wrapping mod 8: nic[i] = (ib_transport_index + i)
// % 8 for i ∈ [0, num_nics). Channel `c`'s QP is placed on `nic[c % num_nics]`.
//
// Throws on any irrecoverable error. On success returns a fully-initialised
// IbgdaSetup with all QPs in RTS and the device-side handle array populated.
std::unique_ptr<IbgdaSetup> build_ibgda_setup(int rank, int num_ranks, int ib_transport_index, int num_channels,
void* rdma_buffer_ptr, std::size_t num_rdma_bytes,
std::shared_ptr<TcpBootstrap> bootstrap);
std::shared_ptr<TcpBootstrap> bootstrap, int num_nics = 1);
} // namespace ep
} // namespace mscclpp