ext/ep: GPU-initiated IBGDA path for low-latency dispatch/combine

Add a GPU-initiated RDMA WRITE path for the LL dispatch/combine kernels
based on mlx5dv direct verbs, alongside the existing IPC and host-FIFO
PortChannel paths. Selected at runtime via MSCCLPP_EP_USE_IBGDA when
num_rdma_ranks > 1.

Core (src/core, include/mscclpp):
  - New ibgda module (ibgda.{hpp,cc}, ibgda_device.cuh): per-peer mlx5
    QP/MR/CQ setup, device-side WQE writers (write_rdma_write_wqe,
    write_rdma_write_inl4_wqe / write_rdma_write_inl8_wqe for 4B/8B
    inline payloads), submit_requests / submit_no_db
    ring helpers, and a poller thread for send CQs.
  - ibgda_port_channel_device.{hpp,cuh}: thin port_put() wrapper over
    rdma_write with signal_cqe / ring_db flags so callers can issue
    UNSIGNALED batched WRs and ring the doorbell once at the tail.
  - mlx5dv_wrapper: expose extra symbols needed for direct WQE
    construction; minor connection.cc / proxy.cc / port_channel.cc
    plumbing to surface QP / MR handles and rkeys to the EP layer.

EP layer (src/ext/ep):
  - ibgda_setup.{hpp,cc}: build per-(local_expert, peer_rank) GpuQp
    handles, exchange remote MR addr/rkey via the bootstrap, own the
    CQ poller. h.dst is set to the per-peer remote_mrs index.
  - buffer.{hpp,cc}: gate IBGDA path with use_ibgda_path_ &&
    ibgda_setup_ != nullptr && !use_ipc; pass device_handles to the
    kernel launchers.
  - kernels/internode_ll.cu: 3-way DISPATCH_LAUNCH_CASE /
    COMBINE_LAUNCH_CASE (IPC / IBGDA / port-FIFO), templated on
    kIbgdaPath. Data PUTs are issued UNSIGNALED with ring_db=false;
    the trailing per-QP count write (dispatch) and flag write
    (combine) keep the defaults so each QP gets a single signaled
    WR that advances prod_idx past all queued data WRs and rings
    the doorbell once (see the sketch below).
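
    Illustrative shape of that per-QP tail (placeholder names, not actual
    kernel code from internode_ll.cu):
      for (int i = 0; i < num_msgs; ++i)
        ibgda::port_put(ch, dst_off[i], src_off[i], msg_bytes,
                        /*signal_cqe=*/false, /*ring_db=*/false);  // no CQE, no doorbell
      // The trailing count write keeps the defaults (signaled, ring_db=true),
      // so submit_requests<true> bumps prod_idx past every queued data WQE
      // and rings the BlueFlame doorbell once for the whole batch.
      ibgda::rdma_write_inl8(ch.qp, count_value, count_raddr, count_rkey_be);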

Test (test/python/ext/ep): extend test_low_latency_multirank.py with
env-driven config knobs (MSCCLPP_EP_LL_TOKENS / _HIDDEN / _TOPK /
_EXPERTS_PER_RANK) for sweeping the new path.
Qinghua Zhou
2026-05-07 05:14:15 +00:00
parent e87c66a85d
commit 04ebba7563
22 changed files with 1851 additions and 80 deletions

View File

@@ -0,0 +1,101 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
// Stage 3: device-side IBGDA "port channel" handle.
//
// Mirrors the device-facing API of mscclpp::PortChannelDeviceHandle (see
// include/mscclpp/port_channel_device.hpp) but issues RDMA traffic directly
// from the GPU via mscclpp::ibgda::rdma_write — no host proxy / FIFO.
//
// put(dstOffset, srcOffset, size) — issue 1 RDMA WRITE WQE
// signal() — issue 1 inline 4-byte RDMA WRITE to a
// remote 4B counter slot
// wait() — spin on the local-side mirror of that
// counter
//
// Memory model (each rank constructs):
// - For every registered local memory region M, a flat row
// local_mrs[M] = { base_addr, lkey_be }
// - For every (peer rank R, registered region M) pair, a flat row
// remote_mrs[M] = { base_addr_on_R, rkey_be_on_R }
// (a channel is bound to a single peer R, so the rank dimension is
// flattened away inside the handle).
//
// The signal counter:
// - Each rank cudaMallocs a 4-byte "signal slot" registered with this NIC,
// and exchanges its address+rkey with peers. The handle's
// {sig_remote_addr, sig_rkey_be} points at the *peer's* slot, while
// sig_local_addr is the local mirror used by wait() to compare against.
//
// - signal() atomically bumps sig_seq (CTA-shared-by-channel u32 in GPU
// memory), writes that value as a 4-byte inline RDMA WRITE to peer's
// slot. Receiver polls its local slot.
#ifndef MSCCLPP_IBGDA_PORT_CHANNEL_DEVICE_HPP_
#define MSCCLPP_IBGDA_PORT_CHANNEL_DEVICE_HPP_
#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM)
#include <cstdint>
#include "ibgda.hpp"
namespace mscclpp {
using IbgdaMemoryId = uint32_t;
struct IbgdaLocalMr {
uint64_t addr; // host-order base virtual address (GPU-mapped)
uint32_t lkey_be; // BE-encoded local key
uint32_t pad;
};
struct IbgdaRemoteMr {
uint64_t addr; // host-order base virtual address on the peer
uint32_t rkey_be; // BE-encoded remote key
uint32_t pad;
};
// POD copied to GPU. Fields kept dense and 8-byte aligned.
struct IbgdaPortChannelDeviceHandle {
IbgdaQpHandle qp;
// Tables: indexed by IbgdaMemoryId. local_mrs[id] is for THIS rank;
// remote_mrs[id] is for the peer rank this channel is bound to.
// Both pointers must dereference from device code.
const IbgdaLocalMr* local_mrs;
const IbgdaRemoteMr* remote_mrs;
// Signal slot:
// sig_local_addr — pointer to a 4-byte slot in this rank's GPU memory
// used by wait() to poll. NIC writes here from the peer.
// sig_local_lkey — BE-encoded lkey of the local 4B inline staging
// buffer (same MR as sig_local_addr; we use it as both
// the receive slot for wait() and as a backing MR for
// completeness — but inline WRs don't actually need
// a local MR. Kept for symmetry / future non-inline
// fallback).
// sig_remote_addr — peer's 4-byte slot
// sig_rkey_be — peer's rkey for sig_remote_addr
// sig_seq — pointer to a 4-byte GPU-resident counter incremented
// by signal(); the value sent is the post-increment.
// Distinct from sig_local_addr.
uint32_t* sig_local_addr;
uint32_t sig_local_lkey; // unused for inline; kept for layout stability
uint64_t sig_remote_addr;
uint32_t sig_rkey_be;
uint32_t* sig_seq;
// Bound peer (informational); MemoryIds default to (0,0) in the simple
// case but we keep them for parity with PortChannelDeviceHandle.
IbgdaMemoryId dst;
IbgdaMemoryId src;
uint32_t peer_rank;
uint32_t _pad;
};
} // namespace mscclpp
#endif // USE_IBVERBS && MSCCLPP_USE_MLX5DV && !MSCCLPP_USE_ROCM
#endif // MSCCLPP_IBGDA_PORT_CHANNEL_DEVICE_HPP_
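
A device-side usage sketch of this handle (illustrative; port_put / port_signal /
port_wait are the device helpers defined in ibgda_port_channel_device.cuh further
down, and send_ch / recv_ch / expected are placeholder names):
  // sender: push `bytes` at `offset`, then publish a 4-byte sequence number
  ibgda::port_put(send_ch, /*offset=*/offset, /*size=*/bytes);
  uint32_t seq = ibgda::port_signal(send_ch);  // inline 4B WRITE to the peer's slot
  // receiver (its own handle bound back to the sender): spin on the local
  // mirror until the expected sequence number lands
  ibgda::port_wait(recv_ch, expected);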

View File

@@ -4,6 +4,9 @@
#ifndef MSCCLPP_PORT_CHANNEL_HPP_
#define MSCCLPP_PORT_CHANNEL_HPP_
#include <unordered_set>
#include <unordered_map>
#include "core.hpp"
#include "port_channel_device.hpp"
#include "proxy.hpp"
@@ -84,6 +87,14 @@ class ProxyService : public BaseProxyService {
std::vector<RegisteredMemory> memories_;
std::shared_ptr<Proxy> proxy_;
std::unordered_map<std::shared_ptr<BaseConnection>, int> inflightRequests_;
// Connections with WRs staged but not yet posted. Mapped to the count of
// WRs staged since the last postPending() call (used to trip the batch
// threshold). Distinct from `inflightRequests_` which counts QP-inflight
// WRs since the last flush() (used for QP overflow detection).
std::unordered_map<std::shared_ptr<BaseConnection>, int> stagedConns_;
std::unordered_set<std::shared_ptr<BaseConnection>> dirtyConns_;
void postPendingAll();
ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw);
};

View File

@@ -53,6 +53,13 @@ class Proxy {
/// @return Shared pointer to FIFO.
std::shared_ptr<Fifo> fifo();
/// Set a callback invoked by the proxy thread when the FIFO transitions
/// from busy to idle (i.e., a poll returns no trigger after at least one
/// trigger was processed). Useful for batching: handlers can defer
/// expensive system calls (e.g., ibv_post_send) until the FIFO drains.
/// Must be called before start().
void setOnIdle(std::function<void()> onIdle);
private:
struct Impl;
std::unique_ptr<Impl> pimpl_;
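
A minimal wiring sketch for this hook (hedged; handlerFunc, initFunc, fifoSize
and conn are placeholders; the real consumer is ProxyService below, which defers
ibv_post_send until the FIFO drains):
  auto proxy = std::make_shared<mscclpp::Proxy>(handlerFunc, initFunc, fifoSize);
  proxy->setOnIdle([conn]() { conn->postPending(); });  // drain staged WRs at idle
  proxy->start(/*blocking=*/false);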

View File

@@ -396,7 +396,10 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem
qp_.lock()->stageSendWrite(srcMr, dstMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset,
/*dstOffset=*/dstOffset, /*signaled=*/true);
qp_.lock()->postSend();
// NOTE: postSend() is intentionally deferred. The proxy service batches
// many triggers and calls postPending() once at idle. External callers
// that pair write() with flush() still observe correct semantics because
// flush() now drains staged WRs before polling the CQ.
INFO(CONN, "IBConnection write: from ", (uint8_t*)srcMr->getBuff() + srcOffset, " to ",
(uint8_t*)dstMrInfo.addr + dstOffset, ", size ", size);
@@ -438,12 +441,12 @@ void IBConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint6
/*size=*/0, /*wrId=*/0,
/*srcOffset=*/0, /*dstOffset=*/0,
/*signaled=*/true, /*immData=*/immData);
qp_.lock()->postSend();
// postSend deferred; see IBConnection::write.
INFO(CONN, "IBConnection signal forwarding: value ", oldValue, " -> ", newValue);
} else {
qp_.lock()->stageSendAtomicAdd(atomicSrcTransportInfo_.ibMr, dstMrInfo, /*wrId=*/0, dstOffset, newValue - oldValue,
/*signaled=*/true);
qp_.lock()->postSend();
// postSend deferred; see IBConnection::write.
INFO(CONN, "IBConnection atomic write: from ", src, " to ", (uint8_t*)dstMrInfo.addr + dstOffset, ", ", oldValue,
" -> ", newValue);
}
@@ -458,6 +461,9 @@ void IBConnection::flush(int64_t timeoutUsec) {
NpKit::CollectCpuEvent(NPKIT_EVENT_CONN_IB_FLUSH_ENTRY, 0, 0, *NpKit::GetCpuTimestamp(), 0);
#endif
// Drain any staged WRs that were deferred by the staging-only write path.
qp_.lock()->postSend();
// Check if the recv thread has already reported an error (e.g., QP entered error state).
if (recvThreadError_.load(std::memory_order_acquire)) {
THROW(CONN, Error, ErrorCode::SystemError, "IBConnection recv thread failed: ", recvThreadErrorMsg_);
@@ -503,10 +509,16 @@ void IBConnection::atomicAdd(RegisteredMemory dst, uint64_t dstOffset, int64_t v
qp_.lock()->stageSendAtomicAdd(atomicSrcTransportInfo_.ibMr, dstMrInfo, /*wrId=*/0, dstOffset,
static_cast<uint64_t>(value), /*signaled=*/true);
qp_.lock()->postSend();
// postSend deferred; see IBConnection::write.
INFO(CONN, "IBConnection atomicAdd: dst ", (uint8_t*)dstMrInfo.addr + dstOffset, ", value ", value);
}
void IBConnection::postPending() {
// Drain all staged WRs to the NIC in a single ibv_post_send call.
// Cheap no-op when nothing is staged (IbQp::postSend early-returns).
qp_.lock()->postSend();
}
// EthernetConnection
EthernetConnection::EthernetConnection(std::shared_ptr<Context> context, const Endpoint& localEndpoint,

src/core/ibgda.cc (new file, 177 lines)

@@ -0,0 +1,177 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include "ibgda.hpp"
#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM)
#include <cuda.h>
#include <cuda_runtime.h>
#include <infiniband/mlx5dv.h>
#include <unistd.h>
#include <cstdint>
#include <cstring>
#include <mscclpp/errors.hpp>
#include <mscclpp/gpu_utils.hpp>
#include "logger.hpp"
#include "mlx5dv_wrapper.hpp"
namespace mscclpp {
struct IbgdaResources::Impl {
void* sq_buf_host = nullptr;
size_t sq_bytes = 0;
void* dbrec_host = nullptr; // not necessarily aligned to anything > 8B
void* dbrec_register_addr = nullptr;
size_t dbrec_register_bytes = 0;
void* uar_page_host = nullptr; // page-aligned base of the UAR
void* state_dev = nullptr; // device-resident state
bool sq_registered = false;
bool dbrec_registered = false;
bool uar_registered = false;
};
namespace {
inline uintptr_t pageMask() {
static const uintptr_t mask = static_cast<uintptr_t>(::sysconf(_SC_PAGESIZE)) - 1;
return mask;
}
} // namespace
IbgdaResources::IbgdaResources(ibv_qp* qp) : pimpl_(std::make_unique<Impl>()) {
if (qp == nullptr) {
THROW(NET, Error, ErrorCode::InvalidUsage, "IbgdaResources: qp is null");
}
if (!MLX5DV::isAvailable()) {
THROW(NET, Error, ErrorCode::InvalidUsage, "IbgdaResources: libmlx5 not available");
}
// 1. Extract sq.buf / dbrec / bf via mlx5dv_init_obj.
struct mlx5dv_qp dvQp{};
if (MLX5DV::mlx5dv_init_obj_qp(qp, &dvQp) != 0) {
THROW(NET, IbError, errno, "mlx5dv_init_obj(QP) failed (errno ", errno, ")");
}
pimpl_->sq_buf_host = dvQp.sq.buf;
pimpl_->sq_bytes = static_cast<size_t>(dvQp.sq.wqe_cnt) * dvQp.sq.stride;
pimpl_->dbrec_host = dvQp.dbrec;
if (pimpl_->sq_buf_host == nullptr || pimpl_->dbrec_host == nullptr || dvQp.bf.reg == nullptr) {
THROW(NET, Error, ErrorCode::InvalidUsage, "mlx5dv_qp returned null pointers");
}
if (dvQp.sq.wqe_cnt == 0 || (dvQp.sq.wqe_cnt & (dvQp.sq.wqe_cnt - 1)) != 0) {
THROW(NET, Error, ErrorCode::InvalidUsage, "sq.wqe_cnt must be a power of 2, got ", dvQp.sq.wqe_cnt);
}
// 2. Register the SQ buffer with CUDA (host-mapped). We register the
// enclosing whole pages because cudaHostRegister requires that.
{
uintptr_t base = reinterpret_cast<uintptr_t>(pimpl_->sq_buf_host);
uintptr_t pageBase = base & ~pageMask();
size_t pad = static_cast<size_t>(base - pageBase);
size_t regBytes = (pad + pimpl_->sq_bytes + pageMask()) & ~pageMask();
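// Worked example with 4 KiB pages: base = 0x...6440 -> pageBase = 0x...6000,
// pad = 0x440; with sq_bytes = 0x20000 the registration spans
// (0x440 + 0x20000) rounded up to the next page = 0x21000 bytes from pageBase.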
void* regAddr = reinterpret_cast<void*>(pageBase);
cudaError_t e = cudaHostRegister(regAddr, regBytes, cudaHostRegisterDefault);
if (e != cudaSuccess) {
THROW(NET, SysError, static_cast<int>(e), "cudaHostRegister(sq.buf) failed: ",
cudaGetErrorString(e));
}
pimpl_->sq_registered = true;
void* dev = nullptr;
MSCCLPP_CUDATHROW(cudaHostGetDevicePointer(&dev, pimpl_->sq_buf_host, 0));
handle_.sq_buf = dev;
}
// 3. Register the DBR record. mlx5 hands us a pointer into a small block
// of DBR records (8 B each); register a whole page starting from its
// page base.
{
uintptr_t base = reinterpret_cast<uintptr_t>(pimpl_->dbrec_host);
uintptr_t pageBase = base & ~pageMask();
size_t regBytes = pageMask() + 1; // 1 page
pimpl_->dbrec_register_addr = reinterpret_cast<void*>(pageBase);
pimpl_->dbrec_register_bytes = regBytes;
cudaError_t e = cudaHostRegister(pimpl_->dbrec_register_addr, regBytes, cudaHostRegisterDefault);
if (e == cudaErrorHostMemoryAlreadyRegistered) {
// The DBR page may overlap with another QP we already registered.
// Treat as success but skip unregister on shutdown. Clear sticky error.
pimpl_->dbrec_registered = false;
(void)cudaGetLastError();
} else if (e != cudaSuccess) {
THROW(NET, SysError, static_cast<int>(e), "cudaHostRegister(dbrec) failed: ",
cudaGetErrorString(e));
} else {
pimpl_->dbrec_registered = true;
}
void* dev = nullptr;
MSCCLPP_CUDATHROW(cudaHostGetDevicePointer(&dev, pimpl_->dbrec_host, 0));
// dbrec[0] is the RQ counter and dbrec[1] the SQ counter; point the handle
// at the SQ slot so the device-side update_dbr() can write it directly
// (see ibgda_device.cuh).
handle_.dbrec = static_cast<uint32_t*>(dev) + 1;
}
// 4. Map the UAR page (NIC MMIO) into GPU VA via cuMemHostRegister IOMEMORY.
{
uintptr_t bfAddr = reinterpret_cast<uintptr_t>(dvQp.bf.reg);
uintptr_t pageAddr = bfAddr & ~uintptr_t(4095); // UAR pages are 4 KB
uintptr_t bfOffset = bfAddr - pageAddr;
pimpl_->uar_page_host = reinterpret_cast<void*>(pageAddr);
CUresult cuRes =
cuMemHostRegister(pimpl_->uar_page_host, 4096, CU_MEMHOSTREGISTER_IOMEMORY);
if (cuRes == CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED) {
// The UAR page is shared across QPs on the same context; if a sibling
// QP already registered it, treat as success and skip unregister.
pimpl_->uar_registered = false;
(void)cudaGetLastError();
} else if (cuRes != CUDA_SUCCESS) {
const char* s = nullptr;
cuGetErrorString(cuRes, &s);
THROW(NET, SysError, static_cast<int>(cuRes),
"cuMemHostRegister(UAR, IOMEMORY) failed: ", s ? s : "?",
". Ensure 'options nvidia NVreg_RegistryDwords=\"PeerMappingOverride=1;\"' "
"is set and nvidia_peermem is loaded.");
} else {
pimpl_->uar_registered = true;
}
CUdeviceptr dPage = 0;
MSCCLPP_CUTHROW(cuMemHostGetDevicePointer(&dPage, pimpl_->uar_page_host, 0));
handle_.bf_reg = reinterpret_cast<uint64_t*>(dPage + bfOffset);
handle_.bf_offset = static_cast<uint32_t>(bfOffset);
}
// 5. Allocate GPU-resident state (zero-initialized).
{
constexpr size_t kStateBytes = 4 * sizeof(uint64_t); // resv/ready/prod/lock
MSCCLPP_CUDATHROW(cudaMalloc(&pimpl_->state_dev, kStateBytes));
MSCCLPP_CUDATHROW(cudaMemset(pimpl_->state_dev, 0, kStateBytes));
handle_.state = static_cast<uint64_t*>(pimpl_->state_dev);
}
// 6. Fill the rest of the handle.
handle_.qpn = qp->qp_num;
handle_.wqe_cnt = dvQp.sq.wqe_cnt;
handle_.stride = dvQp.sq.stride;
}
IbgdaResources::~IbgdaResources() {
if (!pimpl_) return;
if (pimpl_->state_dev) {
cudaFree(pimpl_->state_dev);
}
if (pimpl_->uar_registered && pimpl_->uar_page_host) {
cuMemHostUnregister(pimpl_->uar_page_host);
}
if (pimpl_->dbrec_registered && pimpl_->dbrec_register_addr) {
cudaHostUnregister(pimpl_->dbrec_register_addr);
}
if (pimpl_->sq_registered && pimpl_->sq_buf_host) {
uintptr_t base = reinterpret_cast<uintptr_t>(pimpl_->sq_buf_host);
void* regAddr = reinterpret_cast<void*>(base & ~pageMask());
cudaHostUnregister(regAddr);
}
}
} // namespace mscclpp
#endif // USE_IBVERBS && MSCCLPP_USE_MLX5DV && !MSCCLPP_USE_ROCM

View File

@@ -37,6 +37,12 @@ class BaseConnection {
virtual void atomicAdd(RegisteredMemory dst, uint64_t dstOffset, int64_t value) = 0;
/// Post any locally-staged work requests to the NIC. Called by the proxy
/// service to amortize ibv_post_send across many FIFO triggers. The default
/// no-op is correct for connections whose write/updateAndSync/atomicAdd
/// already post immediately (e.g., CudaIpcConnection).
virtual void postPending() {}
virtual void flush(int64_t timeoutUsec = -1) = 0;
/// Start signal forwarding to the given memory address.
@@ -152,6 +158,8 @@ class IBConnection : public BaseConnection {
void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override;
void atomicAdd(RegisteredMemory dst, uint64_t dstOffset, int64_t value) override;
void postPending() override;
void flush(int64_t timeoutUsec) override;
};

View File

@@ -82,6 +82,9 @@ class IbQp {
int pollRecvCq();
IbQpInfo& getInfo() { return info_; }
/// Raw ibv_qp pointer. Owned by this IbQp. Provided so other components
/// (e.g. IbgdaResources) can call mlx5dv_init_obj on the same QP.
ibv_qp* getRawQp() const { return qp_; }
int getSendWcStatus(int idx) const;
std::string getSendWcStatusString(int idx) const;
int getNumSendCqItems() const;

View File

@@ -0,0 +1,71 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
// IBGDA-style GPU-direct work-request submission for an mlx5 QP.
//
// Given an existing ibv_qp (RC, plain ibv_create_qp), IbgdaResources extracts
// the SQ ring buffer, doorbell record, and BlueFlame UAR via mlx5dv_init_obj
// and maps them into GPU virtual address space:
// - SQ buf and DBR are host RAM, mapped via cudaHostRegister.
// - The BF UAR is device MMIO, mapped via cuMemHostRegister(IOMEMORY).
// It also allocates a small GPU-resident state struct (resv_head / ready_head
// / prod_idx / lock) used by the device-side WQE writer.
//
// The host retains ownership of the QP and never touches sq.buf/dbrec/bf
// itself once GPU posting starts — completion polling on the send CQ remains
// a host-side ibv_poll_cq() call.
#ifndef MSCCLPP_IBGDA_HPP_
#define MSCCLPP_IBGDA_HPP_
#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM)
#include <cstdint>
#include <memory>
struct ibv_qp;
namespace mscclpp {
// POD copied to the GPU; consumed by the device-side WQE writers in
// ibgda_device.cuh (Stage 2), which includes this header directly.
struct IbgdaQpHandle {
// Device-mapped pointers (only dereference these from device code).
void* sq_buf; // SQ WQE ring (host RAM, cudaHostRegister)
uint32_t* dbrec; // 32-bit doorbell record (host RAM, cudaHostRegister)
uint64_t* bf_reg; // BlueFlame doorbell (NIC MMIO, cuMemHostRegister IOMEMORY)
// GPU-resident bookkeeping state (allocated by IbgdaResources).
// Layout: { uint64_t resv_head; uint64_t ready_head; uint64_t prod_idx; int post_send_lock; }.
uint64_t* state;
// Constants (host-order).
uint32_t qpn;
uint32_t wqe_cnt; // SQ length in WQEBBs (power of 2)
uint32_t stride; // bytes per WQEBB (typically 64)
uint32_t bf_offset; // byte offset of BF doorbell within its UAR page
};
// Wraps an existing ibv_qp and prepares the GPU mappings + GPU state. Owns
// the cudaHostRegister/cuMemHostRegister registrations + GPU state buffer.
// Lifetime must enclose any kernel launch that uses getHandle().
class IbgdaResources {
public:
// The qp must already be modified to RTS before kernel posting; that is
// the caller's responsibility.
explicit IbgdaResources(ibv_qp* qp);
~IbgdaResources();
IbgdaResources(const IbgdaResources&) = delete;
IbgdaResources& operator=(const IbgdaResources&) = delete;
const IbgdaQpHandle& getHandle() const { return handle_; }
private:
struct Impl;
std::unique_ptr<Impl> pimpl_;
IbgdaQpHandle handle_{};
};
} // namespace mscclpp
#endif // USE_IBVERBS && MSCCLPP_USE_MLX5DV && !MSCCLPP_USE_ROCM
#endif // MSCCLPP_IBGDA_HPP_
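
A host-side lifecycle sketch for this class (illustrative only; rawQp is an
ibv_qp* already transitioned to RTS, e.g. obtained via IbQp::getRawQp(), and
postKernel is a hypothetical kernel that posts WQEs through the handle):
  mscclpp::IbgdaResources res(rawQp);          // maps SQ/DBR/UAR, allocates GPU state
  mscclpp::IbgdaQpHandle h = res.getHandle();  // POD, passed by value as a kernel arg
  postKernel<<<grid, block, 0, stream>>>(h);   // device code writes WQEs via h
  // keep `res` alive until the kernel and the NIC are done with the SQ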

View File

@@ -0,0 +1,426 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
// Stage 2: device-side IBGDA primitives.
//
// Operates on an `IbgdaQpHandle` (see ibgda.hpp). All pointers in the handle
// must already be GPU-mapped by `IbgdaResources` on the host. This header is
// device-only and should be included from .cu translation units.
//
// State buffer layout (allocated by IbgdaResources, zero-initialised):
// offset 0: uint64_t resv_head // next free WQEBB slot (atomic)
// offset 8: uint64_t ready_head // next slot whose WQE write is done
// offset 16: uint64_t prod_idx // last value rung on the doorbell
// offset 24: int post_send_lock // cta-level lock for ringing
//
// Conventions copied from NVSHMEM/DeepEP:
// - wqe_idx is a 16-bit-truncated cyclic index into the SQ ring.
// - one RDMA WRITE WQE = 1 WQEBB (64B): ctrl(16) + raddr(16) + data(16) + pad(16).
// ctrl_seg.qpn_ds encodes ds=3 (three 16-byte segments).
// - DBR record holds BE32(prod_idx & 0xFFFF) in the SQ slot (dbrec[1] for
// mlx5; IbgdaResources sets handle.dbrec to that slot directly).
// - BF doorbell: 64-bit write of {opmod_idx_opcode, qpn_ds} to bf_reg.
#ifndef MSCCLPP_IBGDA_DEVICE_CUH_
#define MSCCLPP_IBGDA_DEVICE_CUH_
#include <cstdint>
#include <cuda_runtime.h>
#include "ibgda.hpp"
namespace mscclpp {
namespace ibgda {
#ifndef MSCCLPP_IBGDA_OPCODE_RDMA_WRITE
#define MSCCLPP_IBGDA_OPCODE_RDMA_WRITE 0x08
#endif
#ifndef MSCCLPP_IBGDA_CTRL_CQ_UPDATE
#define MSCCLPP_IBGDA_CTRL_CQ_UPDATE uint8_t(0x8)
#endif
// ---- BE swap helpers (PTX prmt) -------------------------------------------
__device__ static __forceinline__ uint32_t HtoBE32(uint32_t x) {
uint32_t r;
asm("{ .reg .b32 ig; prmt.b32 %0, %1, ig, 0x0123; }" : "=r"(r) : "r"(x));
return r;
}
__device__ static __forceinline__ uint64_t HtoBE64(uint64_t x) {
uint64_t r;
asm("{\n\t"
".reg .b32 ig;\n\t"
".reg .b32 lo, hi, nl, nh;\n\t"
"mov.b64 {lo,hi}, %1;\n\t"
"prmt.b32 nh, lo, ig, 0x0123;\n\t"
"prmt.b32 nl, hi, ig, 0x0123;\n\t"
"mov.b64 %0, {nl,nh};\n\t" "}"
: "=l"(r) : "l"(x));
return r;
}
// ---- Memory ordering helpers ----------------------------------------------
// st.relaxed.sys + st.release.sys equivalents. We use `volatile` writes plus
// __threadfence_system() at the points the caller needs ordering, mirroring
// the NVSHMEM `st_na_*` style without depending on its internal headers.
__device__ static __forceinline__ void store_relaxed_u32(uint32_t* p, uint32_t v) {
*reinterpret_cast<volatile uint32_t*>(p) = v;
}
__device__ static __forceinline__ void store_relaxed_u64(uint64_t* p, uint64_t v) {
*reinterpret_cast<volatile uint64_t*>(p) = v;
}
__device__ static __forceinline__ void store_relaxed_int4(int4* p, int4 v) {
asm volatile("st.relaxed.sys.v4.b32 [%0], {%1, %2, %3, %4};" ::
"l"(p), "r"(v.x), "r"(v.y), "r"(v.z), "r"(v.w) : "memory");
}
__device__ static __forceinline__ void store_release_u32(uint32_t* p, uint32_t v) {
asm volatile("st.release.sys.b32 [%0], %1;" :: "l"(p), "r"(v) : "memory");
}
__device__ static __forceinline__ void store_release_u64(uint64_t* p, uint64_t v) {
asm volatile("st.release.sys.b64 [%0], %1;" :: "l"(p), "l"(v) : "memory");
}
// ---- WQE segment layout (mirrors mlx5_wqe_*; redefined to avoid pulling
// libmlx5 headers into device code) ----------------------------------------
struct __align__(8) CtrlSeg {
uint32_t opmod_idx_opcode; // BE32: (wqe_idx << 8) | opcode
uint32_t qpn_ds; // BE32: (qpn << 8) | ds
uint8_t signature;
uint8_t rsvd0;
uint8_t rsvd1;
uint8_t fm_ce_se; // CQ_UPDATE in bit 3 (0x8)
uint32_t imm; // BE32 immediate (or 0)
};
struct __align__(8) RaddrSeg {
uint64_t raddr; // BE64
uint32_t rkey; // BE32 (caller pre-swaps)
uint32_t reserved;
};
struct __align__(8) DataSeg {
uint32_t byte_count; // BE32
uint32_t lkey; // BE32 (caller pre-swaps)
uint64_t addr; // BE64
};
static_assert(sizeof(CtrlSeg) == 16, "CtrlSeg must be 16B");
static_assert(sizeof(RaddrSeg) == 16, "RaddrSeg must be 16B");
static_assert(sizeof(DataSeg) == 16, "DataSeg must be 16B");
// ---- State accessors ------------------------------------------------------
struct StateView {
unsigned long long* resv_head;
unsigned long long* ready_head;
unsigned long long* prod_idx;
int* post_send_lock;
};
__device__ static __forceinline__ StateView stateView(const IbgdaQpHandle& h) {
auto* base = reinterpret_cast<uint64_t*>(h.state);
StateView v;
v.resv_head = reinterpret_cast<unsigned long long*>(&base[0]);
v.ready_head = reinterpret_cast<unsigned long long*>(&base[1]);
v.prod_idx = reinterpret_cast<unsigned long long*>(&base[2]);
v.post_send_lock = reinterpret_cast<int*>(&base[3]);
return v;
}
// ---- WQE ring helpers -----------------------------------------------------
__device__ static __forceinline__
uint64_t reserve_wqe_slots(const IbgdaQpHandle& h, uint32_t num_wqes) {
auto v = stateView(h);
return atomicAdd(v.resv_head, static_cast<unsigned long long>(num_wqes));
}
__device__ static __forceinline__
void* get_wqe_ptr(const IbgdaQpHandle& h, uint16_t wqe_idx) {
uint16_t mask = static_cast<uint16_t>(h.wqe_cnt - 1);
uint16_t idx = wqe_idx & mask;
return reinterpret_cast<uint8_t*>(h.sq_buf) +
(static_cast<size_t>(idx) * h.stride);
}
// ---- Doorbell record + BF ring -------------------------------------------
__device__ static __forceinline__
void update_dbr(const IbgdaQpHandle& h, uint32_t dbrec_head) {
// BE32(dbrec_head & 0xFFFF) — see DeepEP/NVSHMEM.
uint32_t v;
asm("{\n\t"
".reg .b32 lo16; .reg .b32 ig;\n\t"
"and.b32 lo16, %1, 0xffff;\n\t"
"prmt.b32 %0, lo16, ig, 0x123;\n\t"
"}" : "=r"(v) : "r"(dbrec_head));
// dbrec is the SQ slot directly (set by IbgdaResources).
store_release_u32(reinterpret_cast<uint32_t*>(h.dbrec), v);
}
__device__ static __forceinline__
void ring_db(const IbgdaQpHandle& h, uint16_t prod_idx) {
// 64-bit BF write of {opmod_idx_opcode, qpn_ds}.
// opmod_idx_opcode = BE32(prod_idx << 8); qpn_ds = BE32(qpn << 8).
uint32_t lo = HtoBE32(static_cast<uint32_t>(prod_idx) << 8);
uint32_t hi = HtoBE32(h.qpn << 8);
uint64_t v = static_cast<uint64_t>(lo) | (static_cast<uint64_t>(hi) << 32);
store_release_u64(h.bf_reg, v);
}
__device__ static __forceinline__
void post_send(const IbgdaQpHandle& h, uint64_t new_prod_idx) {
auto v = stateView(h);
// Global lock (any thread in any CTA) so concurrent posters serialise their DBR/BF writes.
while (atomicCAS(v.post_send_lock, 0, 1) == 1) {}
__threadfence();
unsigned long long old_prod_idx =
atomicMax(v.prod_idx, static_cast<unsigned long long>(new_prod_idx));
if (new_prod_idx > old_prod_idx) {
update_dbr(h, static_cast<uint32_t>(new_prod_idx));
ring_db(h, static_cast<uint16_t>(new_prod_idx));
}
__threadfence();
store_release_u32(reinterpret_cast<uint32_t*>(v.post_send_lock), 0);
}
// ---- Submit (publish ready_head, optionally ring DB) ----------------------
template <bool kAlwaysDoPostSend>
__device__ static __forceinline__
void submit_requests(const IbgdaQpHandle& h, uint64_t base_wqe_idx,
uint32_t num_wqes, int message_idx = 0) {
auto v = stateView(h);
uint64_t new_wqe_idx = base_wqe_idx + num_wqes;
// The WQE writes themselves must be globally visible before we publish.
__threadfence_system();
// Wait for any earlier reservations to publish first, preserving order.
auto base_ull = static_cast<unsigned long long>(base_wqe_idx);
auto new_ull = static_cast<unsigned long long>(new_wqe_idx);
while (atomicCAS(v.ready_head, base_ull, new_ull) != base_ull) {}
constexpr int kBatch = 4;
if (kAlwaysDoPostSend || ((message_idx + 1) % kBatch == 0)) {
post_send(h, new_wqe_idx);
}
}
// Publish-only variant: advance ready_head in WQE-issue order (preserving the
// in-order chain that submit_requests builds) but DO NOT ring the doorbell.
// Use this for the data WRs in a coalesced "unsignaled body + signaled tail"
// pattern; the trailing WR's submit_requests<true> will atomicMax prod_idx
// past all earlier WRs and ring the doorbell once for the whole batch.
__device__ static __forceinline__
void submit_no_db(const IbgdaQpHandle& h, uint64_t base_wqe_idx,
uint32_t num_wqes) {
auto v = stateView(h);
uint64_t new_wqe_idx = base_wqe_idx + num_wqes;
__threadfence_system();
auto base_ull = static_cast<unsigned long long>(base_wqe_idx);
auto new_ull = static_cast<unsigned long long>(new_wqe_idx);
while (atomicCAS(v.ready_head, base_ull, new_ull) != base_ull) {}
// No post_send: the next signaled WR on this QP (or any concurrent ringer)
// will atomicMax prod_idx past us and the NIC will pick up our slot.
}
// ---- WQE writers ----------------------------------------------------------
// All addresses BE-encoded inside; lkey/rkey expected pre-BE-swapped.
// `signal_cqe`: if true, set CQ_UPDATE so the NIC posts a CQE for this WR;
// if false, the WR is UNSIGNALED — useful for batched data WRs whose
// completion is implicitly bounded by a later signaled WR on the same QP.
__device__ static __forceinline__
void write_rdma_write_wqe(const IbgdaQpHandle& h, void* wqe_slot,
uint64_t laddr, uint32_t lkey_be,
uint64_t raddr, uint32_t rkey_be,
uint32_t bytes, uint16_t wqe_idx,
bool signal_cqe = true) {
auto* ctrl = reinterpret_cast<CtrlSeg*>(wqe_slot);
auto* raddrp = reinterpret_cast<RaddrSeg*>(reinterpret_cast<uint8_t*>(wqe_slot) + 16);
auto* datap = reinterpret_cast<DataSeg*>(reinterpret_cast<uint8_t*>(wqe_slot) + 32);
CtrlSeg c{};
c.opmod_idx_opcode = HtoBE32((static_cast<uint32_t>(wqe_idx) << 8) |
MSCCLPP_IBGDA_OPCODE_RDMA_WRITE);
c.qpn_ds = HtoBE32((h.qpn << 8) | 3u);
c.fm_ce_se = signal_cqe ? MSCCLPP_IBGDA_CTRL_CQ_UPDATE : 0u;
c.imm = 0;
RaddrSeg r{};
r.raddr = HtoBE64(raddr);
r.rkey = rkey_be;
r.reserved = 0;
DataSeg d{};
d.byte_count = HtoBE32(bytes);
d.lkey = lkey_be;
d.addr = HtoBE64(laddr);
store_relaxed_int4(reinterpret_cast<int4*>(ctrl), *reinterpret_cast<int4*>(&c));
store_relaxed_int4(reinterpret_cast<int4*>(raddrp), *reinterpret_cast<int4*>(&r));
store_relaxed_int4(reinterpret_cast<int4*>(datap), *reinterpret_cast<int4*>(&d));
}
// One-shot helper: reserve a slot, write WQE, submit + ring doorbell.
// Returns the wqe_idx of the issued WR (caller can match against CQE wr_id
// if it embeds the index there; here we don't use wr_id).
__device__ static __forceinline__
uint64_t rdma_write(const IbgdaQpHandle& h,
uint64_t laddr, uint32_t lkey_be,
uint64_t raddr, uint32_t rkey_be,
uint32_t bytes,
bool signal_cqe = true, bool ring_db = true) {
uint64_t base = reserve_wqe_slots(h, 1);
void* slot = get_wqe_ptr(h, static_cast<uint16_t>(base));
write_rdma_write_wqe(h, slot, laddr, lkey_be, raddr, rkey_be, bytes,
static_cast<uint16_t>(base), signal_cqe);
if (ring_db) submit_requests<true>(h, base, 1);
else submit_no_db(h, base, 1);
return base;
}
// ---- Inline 4-byte RDMA WRITE --------------------------------------------
// Single WQEBB carrying (ctrl=16) + (raddr=16) + (inline_hdr=4 + data=4)
// padded to 48B. ds=3 in ctrl_seg.
//
// MLX5_INLINE_SEG = 0x80000000 in inline_seg.byte_count signals "inline".
//
// The 4-byte payload `value` is written verbatim to remote `raddr`. The
// caller does NOT need a local MR for this op.
#ifndef MSCCLPP_IBGDA_INLINE_SEG
#define MSCCLPP_IBGDA_INLINE_SEG 0x80000000u
#endif
__device__ static __forceinline__
void write_rdma_write_inl4_wqe(const IbgdaQpHandle& h, void* wqe_slot,
uint32_t value,
uint64_t raddr, uint32_t rkey_be,
uint16_t wqe_idx,
bool signal_cqe = true) {
auto* base8 = reinterpret_cast<uint8_t*>(wqe_slot);
auto* ctrl = reinterpret_cast<CtrlSeg*>(base8 + 0);
auto* raddrp = reinterpret_cast<RaddrSeg*>(base8 + 16);
// inline header is a 4-byte field (byte_count) at offset 32, followed by
// 4 bytes of payload at offset 36.
uint32_t* inl_hdr = reinterpret_cast<uint32_t*>(base8 + 32);
uint32_t* inl_dat = reinterpret_cast<uint32_t*>(base8 + 36);
CtrlSeg c{};
c.opmod_idx_opcode = HtoBE32((static_cast<uint32_t>(wqe_idx) << 8) |
MSCCLPP_IBGDA_OPCODE_RDMA_WRITE);
// ds=3 (3 * 16B = 48B; ctrl + raddr + 8B of inline hdr+data, padded).
c.qpn_ds = HtoBE32((h.qpn << 8) | 3u);
c.fm_ce_se = signal_cqe ? MSCCLPP_IBGDA_CTRL_CQ_UPDATE : 0u;
c.imm = 0;
RaddrSeg r{};
r.raddr = HtoBE64(raddr);
r.rkey = rkey_be;
r.reserved = 0;
store_relaxed_int4(reinterpret_cast<int4*>(ctrl), *reinterpret_cast<int4*>(&c));
store_relaxed_int4(reinterpret_cast<int4*>(raddrp), *reinterpret_cast<int4*>(&r));
// inline byte_count: 4 | INLINE_SEG, BE32.
store_relaxed_u32(inl_hdr, HtoBE32(4u | MSCCLPP_IBGDA_INLINE_SEG));
// The 4-byte payload is written verbatim (no byte swap); the NIC treats the
// inline data block as opaque bytes copied to remote memory.
store_relaxed_u32(inl_dat, value);
}
__device__ static __forceinline__
uint64_t rdma_write_inl4(const IbgdaQpHandle& h,
uint32_t value,
uint64_t raddr, uint32_t rkey_be,
bool signal_cqe = true, bool ring_db = true) {
uint64_t base = reserve_wqe_slots(h, 1);
void* slot = get_wqe_ptr(h, static_cast<uint16_t>(base));
write_rdma_write_inl4_wqe(h, slot, value, raddr, rkey_be,
static_cast<uint16_t>(base), signal_cqe);
if (ring_db) submit_requests<true>(h, base, 1);
else submit_no_db(h, base, 1);
return base;
}
// ---- Inline 8-byte RDMA WRITE --------------------------------------------
// Same WQE shape as inl4 (ds=3 / 48B): ctrl(16) + raddr(16) + inline_hdr(4)
// + payload(8) + pad(4). Used to publish 8B counters (e.g. dispatch
// per-(local_expert, src_rank) recv_count slot polled with int64 acquire).
__device__ static __forceinline__
void write_rdma_write_inl8_wqe(const IbgdaQpHandle& h, void* wqe_slot,
uint64_t value,
uint64_t raddr, uint32_t rkey_be,
uint16_t wqe_idx,
bool signal_cqe = true) {
auto* base8 = reinterpret_cast<uint8_t*>(wqe_slot);
auto* ctrl = reinterpret_cast<CtrlSeg*>(base8 + 0);
auto* raddrp = reinterpret_cast<RaddrSeg*>(base8 + 16);
uint32_t* inl_hdr = reinterpret_cast<uint32_t*>(base8 + 32);
// Two 32-bit halves of the 8B payload at offsets 36 / 40.
uint32_t* inl_dat_lo = reinterpret_cast<uint32_t*>(base8 + 36);
uint32_t* inl_dat_hi = reinterpret_cast<uint32_t*>(base8 + 40);
CtrlSeg c{};
c.opmod_idx_opcode = HtoBE32((static_cast<uint32_t>(wqe_idx) << 8) |
MSCCLPP_IBGDA_OPCODE_RDMA_WRITE);
c.qpn_ds = HtoBE32((h.qpn << 8) | 3u);
c.fm_ce_se = signal_cqe ? MSCCLPP_IBGDA_CTRL_CQ_UPDATE : 0u;
c.imm = 0;
RaddrSeg r{};
r.raddr = HtoBE64(raddr);
r.rkey = rkey_be;
r.reserved = 0;
store_relaxed_int4(reinterpret_cast<int4*>(ctrl), *reinterpret_cast<int4*>(&c));
store_relaxed_int4(reinterpret_cast<int4*>(raddrp), *reinterpret_cast<int4*>(&r));
// inline byte_count: 8 | INLINE_SEG, BE32.
store_relaxed_u32(inl_hdr, HtoBE32(8u | MSCCLPP_IBGDA_INLINE_SEG));
// 8B payload as two 32-bit words; treated as opaque bytes by the NIC.
store_relaxed_u32(inl_dat_lo, static_cast<uint32_t>(value & 0xffffffffull));
store_relaxed_u32(inl_dat_hi, static_cast<uint32_t>(value >> 32));
}
__device__ static __forceinline__
uint64_t rdma_write_inl8(const IbgdaQpHandle& h,
uint64_t value,
uint64_t raddr, uint32_t rkey_be,
bool signal_cqe = true, bool ring_db = true) {
uint64_t base = reserve_wqe_slots(h, 1);
void* slot = get_wqe_ptr(h, static_cast<uint16_t>(base));
write_rdma_write_inl8_wqe(h, slot, value, raddr, rkey_be,
static_cast<uint16_t>(base), signal_cqe);
if (ring_db) submit_requests<true>(h, base, 1);
else submit_no_db(h, base, 1);
return base;
}
// ---- Warp-coalesced burst (lane 0 issues N WRs as a single batch) --------
// This is the "warp granularity" pattern used by NVSHMEM IBGDA: a single
// thread reserves N contiguous slots (one atomic), writes N WQEs in a tight
// loop, then publishes ready_head and rings the doorbell exactly once. This
// amortises the ready_head CAS-spin and the BF doorbell MMIO across N WRs.
//
// Caller is responsible for ensuring that only one thread per warp invokes
// this with a given (laddr/raddr) layout — typically `if (lane_id == 0)`.
// All N WRs share the same lkey_be/rkey_be and have stride `bytes` (i.e. a
// contiguous chunk laddr_base..laddr_base+N*bytes -> raddr_base..).
__device__ static __forceinline__
uint64_t rdma_write_strided_burst(const IbgdaQpHandle& h,
uint64_t laddr_base, uint32_t lkey_be,
uint64_t raddr_base, uint32_t rkey_be,
uint32_t bytes, uint32_t num_wrs) {
if (num_wrs == 0) return 0;
uint64_t base = reserve_wqe_slots(h, num_wrs);
for (uint32_t i = 0; i < num_wrs; ++i) {
uint16_t idx = static_cast<uint16_t>(base + i);
void* slot = get_wqe_ptr(h, idx);
write_rdma_write_wqe(h, slot,
laddr_base + size_t(i) * bytes, lkey_be,
raddr_base + size_t(i) * bytes, rkey_be,
bytes, idx);
}
submit_requests<true>(h, base, num_wrs);
return base;
}
} // namespace ibgda
} // namespace mscclpp
#endif // MSCCLPP_IBGDA_DEVICE_CUH_
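
A hedged usage sketch of the warp-coalesced burst above (inside a kernel;
variable names are illustrative): lane 0 posts one contiguous chunk per warp,
so the reservation atomic and the doorbell MMIO are paid once per num_chunks WRs.
  int lane_id = threadIdx.x % 32;
  if (lane_id == 0) {
    ibgda::rdma_write_strided_burst(h, laddr_base, lkey_be, raddr_base, rkey_be,
                                    /*bytes=*/chunk_bytes, /*num_wrs=*/num_chunks);
  }
  __syncwarp();  // let the other lanes continue only after lane 0 has posted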

View File

@@ -0,0 +1,100 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
// Stage 3: device-side methods that act on `IbgdaPortChannelDeviceHandle`.
// Split from the POD header so that the POD can be included by host code
// (which doesn't want PTX inline asm or CUDA-only types) while these inline
// device functions only appear in .cu translation units.
#ifndef MSCCLPP_IBGDA_PORT_CHANNEL_DEVICE_CUH_
#define MSCCLPP_IBGDA_PORT_CHANNEL_DEVICE_CUH_
#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM)
#include <cuda_runtime.h>
#include <mscclpp/ibgda_port_channel_device.hpp>
#include "ibgda_device.cuh"
namespace mscclpp {
namespace ibgda {
// ---------------- put -----------------------------------------------------
// `signal_cqe` / `ring_db` default to true (preserves prior call sites);
// pass false/false for batched data WRs whose completion is implicitly
// bounded by a later signaled WR on the same QP (e.g., the trailing count
// or flag write at the end of LL dispatch / combine).
__device__ static __forceinline__
void port_put(const IbgdaPortChannelDeviceHandle& ch,
IbgdaMemoryId dst, uint64_t dstOffset,
IbgdaMemoryId src, uint64_t srcOffset,
uint64_t size,
bool signal_cqe = true, bool ring_db = true) {
IbgdaLocalMr l = ch.local_mrs[src];
IbgdaRemoteMr r = ch.remote_mrs[dst];
rdma_write(ch.qp,
l.addr + srcOffset, l.lkey_be,
r.addr + dstOffset, r.rkey_be,
static_cast<uint32_t>(size),
signal_cqe, ring_db);
}
__device__ static __forceinline__
void port_put(const IbgdaPortChannelDeviceHandle& ch,
uint64_t dstOffset, uint64_t srcOffset, uint64_t size,
bool signal_cqe = true, bool ring_db = true) {
port_put(ch, ch.dst, dstOffset, ch.src, srcOffset, size, signal_cqe, ring_db);
}
__device__ static __forceinline__
void port_put(const IbgdaPortChannelDeviceHandle& ch,
uint64_t offset, uint64_t size,
bool signal_cqe = true, bool ring_db = true) {
port_put(ch, offset, offset, size, signal_cqe, ring_db);
}
// ---------------- signal --------------------------------------------------
// Increments the local sequence counter (device-scope atomic, so multiple
// CTAs sharing this handle on the same GPU serialise) and writes the new
// value as an inline 4B RDMA WRITE to the peer's signal slot.
__device__ static __forceinline__
uint32_t port_signal(const IbgdaPortChannelDeviceHandle& ch) {
// A device-scope atomicAdd keeps the counter consistent across CTAs sharing
// this handle (rare but possible); system scope isn't needed on one GPU.
uint32_t prev = atomicAdd(ch.sig_seq, 1u);
uint32_t v = prev + 1u;
rdma_write_inl4(ch.qp, v, ch.sig_remote_addr, ch.sig_rkey_be);
return v;
}
// ---------------- wait ----------------------------------------------------
// Returns when the local signal slot's value reaches `expected` (sticky;
// callers track the expected counter externally — same pattern as
// PortChannelDeviceHandle::wait()).
__device__ static __forceinline__
void port_wait(const IbgdaPortChannelDeviceHandle& ch, uint32_t expected,
int64_t maxSpinCount = 10000000) {
volatile uint32_t* p = ch.sig_local_addr;
if (maxSpinCount < 0) {
while (*p < expected) { /* spin */ }
return;
}
for (int64_t i = 0; i < maxSpinCount; ++i) {
if (*p >= expected) return;
}
// Out of patience — caller can detect with a follow-up poll(); we do NOT
// assert here to keep the device path lean.
}
__device__ static __forceinline__
bool port_poll(const IbgdaPortChannelDeviceHandle& ch, uint32_t expected) {
return *(volatile uint32_t*)ch.sig_local_addr >= expected;
}
} // namespace ibgda
} // namespace mscclpp
#endif // USE_IBVERBS && MSCCLPP_USE_MLX5DV && !MSCCLPP_USE_ROCM
#endif // MSCCLPP_IBGDA_PORT_CHANNEL_DEVICE_CUH_

View File

@@ -28,6 +28,11 @@ struct MLX5DV {
/// Returns 0 on success (device supports Data Direct), non-zero otherwise.
static int mlx5dv_get_data_direct_sysfs_path(struct ibv_context* context, char* buf, size_t buf_len);
/// Wraps mlx5dv_init_obj(MLX5DV_OBJ_QP). Returns 0 on success.
/// `out` must be a pointer to an mlx5dv_qp; we keep this typeless to avoid
/// pulling <infiniband/mlx5dv.h> into the public-ish wrapper header.
static int mlx5dv_init_obj_qp(struct ibv_qp* qp, void* out);
private:
static void* dlsym(const std::string& symbol, bool allowReturnNull = false);
};

View File

@@ -121,6 +121,20 @@ int MLX5DV::mlx5dv_get_data_direct_sysfs_path(struct ibv_context* context, char*
return impl(context, buf, buf_len);
}
int MLX5DV::mlx5dv_init_obj_qp(struct ibv_qp* qp, void* out_qp) {
using FuncType = int (*)(struct mlx5dv_obj*, uint64_t);
static FuncType impl = nullptr;
if (!impl) {
void* ptr = MLX5DV::dlsym("mlx5dv_init_obj", /*allowReturnNull=*/true);
if (!ptr) return -1;
impl = reinterpret_cast<FuncType>(ptr);
}
struct mlx5dv_obj obj{};
obj.qp.in = qp;
obj.qp.out = static_cast<struct mlx5dv_qp*>(out_qp);
return impl(&obj, MLX5DV_OBJ_QP);
}
} // namespace mscclpp
#endif // defined(MSCCLPP_USE_MLX5DV)

View File

@@ -3,12 +3,75 @@
#include <mscclpp/numa.hpp>
#include <mscclpp/port_channel.hpp>
#include <unistd.h>
#include <atomic>
#include <mutex>
#include <unordered_map>
#include "api.h"
#include "connection.hpp"
#include "debug.h"
namespace mscclpp {
// Lightweight diagnostic counters kept in a side-table keyed by ProxyService*
// so the ProxyService class layout stays ABI-compatible with prebuilt
// extensions (e.g. mscclpp/_mscclpp.cpython-*.so) that were compiled against
// the older header. Populated only when MSCCLPP_PROXY_STATS=1.
namespace {
struct ProxyStats {
bool enabled = false;
bool printed = false;
uint64_t triggers = 0;
uint64_t trigData = 0;
uint64_t trigFlag = 0;
uint64_t trigAtomic = 0;
uint64_t trigSync = 0;
uint64_t postCalls = 0;
uint64_t idleDrains = 0;
};
// Process-wide flag: 0 unless any ProxyService has stats enabled. Avoids
// taking the side-table mutex on the proxy hot path when nobody asked for
// stats.
std::atomic<bool>& statsAnyEnabled() {
static std::atomic<bool> v{false};
return v;
}
std::mutex& statsMu() { static std::mutex m; return m; }
std::unordered_map<const ProxyService*, ProxyStats>& statsTable() {
static std::unordered_map<const ProxyService*, ProxyStats> t;
return t;
}
ProxyStats* getStats(const ProxyService* self) {
if (!statsAnyEnabled().load(std::memory_order_relaxed)) return nullptr;
std::lock_guard<std::mutex> lk(statsMu());
auto it = statsTable().find(self);
if (it == statsTable().end()) return nullptr;
return &it->second;
}
void printAndEraseStats(const ProxyService* self) {
std::lock_guard<std::mutex> lk(statsMu());
auto it = statsTable().find(self);
if (it == statsTable().end()) return;
ProxyStats& s = it->second;
if (s.enabled && !s.printed) {
s.printed = true;
uint64_t posts = s.postCalls + s.idleDrains;
double wrPerPost = posts ? static_cast<double>(s.triggers) / static_cast<double>(posts) : 0.0;
fprintf(stderr,
"[mscclpp proxy stats] triggers=%llu (data=%llu flag=%llu atomic=%llu sync=%llu) "
"postCalls=%llu idleDrains=%llu triggersPerPost=%.2f\n",
(unsigned long long)s.triggers, (unsigned long long)s.trigData,
(unsigned long long)s.trigFlag, (unsigned long long)s.trigAtomic,
(unsigned long long)s.trigSync, (unsigned long long)s.postCalls,
(unsigned long long)s.idleDrains, wrPerPost);
fflush(stderr);
}
statsTable().erase(it);
}
} // namespace
MSCCLPP_API_CPP BasePortChannel::BasePortChannel(SemaphoreId semaphoreId,
std::shared_ptr<Host2DeviceSemaphore> semaphore,
std::shared_ptr<Proxy> proxy)
@@ -30,6 +93,15 @@ MSCCLPP_API_CPP ProxyService::ProxyService(int fifoSize) {
int cudaDevice;
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
int deviceNumaNode = getDeviceNumaNode(cudaDevice);
if (const char* env = std::getenv("MSCCLPP_PROXY_STATS")) {
if (std::atoi(env) > 0) {
{
std::lock_guard<std::mutex> lk(statsMu());
statsTable()[this].enabled = true;
}
statsAnyEnabled().store(true, std::memory_order_relaxed);
}
}
auto initFunc = [cudaDevice, deviceNumaNode]() {
MSCCLPP_CUDATHROW(cudaSetDevice(cudaDevice));
if (deviceNumaNode >= 0) {
@@ -39,6 +111,21 @@ MSCCLPP_API_CPP ProxyService::ProxyService(int fifoSize) {
};
auto handlerFunc = [&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); };
proxy_ = std::make_shared<Proxy>(handlerFunc, initFunc, fifoSize);
// Drain any deferred ibv_post_send work the moment the FIFO empties out, so
// we batch as many triggers as possible into a single syscall while still
// keeping latency within one FIFO drain.
proxy_->setOnIdle([this]() { this->postPendingAll(); });
}
void ProxyService::postPendingAll() {
if (!stagedConns_.empty()) {
if (auto* s = getStats(this); s && s->enabled) s->idleDrains++;
}
for (auto& kv : stagedConns_) {
kv.first->postPending();
}
stagedConns_.clear();
dirtyConns_.clear();
}
MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& communicator,
@@ -84,14 +171,45 @@ MSCCLPP_API_CPP PortChannel ProxyService::portChannel(SemaphoreId id, MemoryId d
MSCCLPP_API_CPP void ProxyService::startProxy(bool blocking) { proxy_->start(blocking); }
MSCCLPP_API_CPP void ProxyService::stopProxy() { proxy_->stop(); }
MSCCLPP_API_CPP void ProxyService::stopProxy() {
proxy_->stop();
printAndEraseStats(this);
}
ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger trigger) {
std::shared_ptr<Host2DeviceSemaphore> semaphore = semaphores_[trigger.fields.semaphoreId];
ProxyStats* stats = getStats(this);
if (stats && stats->enabled) {
stats->triggers++;
if (trigger.fields.type == 0) stats->trigAtomic++;
if (trigger.fields.type & TriggerData) stats->trigData++;
if (trigger.fields.type & TriggerFlag) stats->trigFlag++;
if (trigger.fields.type & TriggerSync) stats->trigSync++;
}
auto& conn = semaphore->connection();
int maxWriteQueueSize = conn.getMaxWriteQueueSize();
auto& numRequests = inflightRequests_[conn.impl_];
auto connImpl = BaseConnection::getImpl(conn);
auto& numRequests = inflightRequests_[connImpl];
// Batch threshold: how many staged WRs we let accumulate on one connection
// before forcing an ibv_post_send. Lower values reduce tail latency for
// signalling triggers; higher values reduce syscalls.
// Override via MSCCLPP_PROXY_BATCH_THRESHOLD (1 = post every trigger,
// matching the historical behavior; the empirical sweep on H100/IB shows
// LL traffic is wire-latency bound rather than syscall bound, so deeper
// batching does not help and may regress dispatch tail latency).
static const int kPostBatchThreshold = []() {
if (const char* env = std::getenv("MSCCLPP_PROXY_BATCH_THRESHOLD")) {
int v = std::atoi(env);
if (v >= 1) return v;
}
return 1;
}();
bool stagedWork = false;
// Atomic/signal triggers convey completion semantics; never defer them.
bool isSignal = (trigger.fields.type == 0) || (trigger.fields.type & TriggerFlag);
if (trigger.fields.type == 0) {
// type == 0 indicates an atomic add operation.
@@ -100,6 +218,7 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger trigger) {
int64_t value = static_cast<int64_t>(trigger.fst);
conn.atomicAdd(dst, trigger.fields.dstOffset, value);
numRequests++;
stagedWork = true;
}
if (trigger.fields.type & TriggerData) {
@@ -107,17 +226,38 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger trigger) {
RegisteredMemory& src = memories_[trigger.fields.srcMemoryId];
conn.write(dst, trigger.fields.dstOffset, src, trigger.fields.srcOffset, trigger.fields.size);
numRequests++;
stagedWork = true;
}
if (trigger.fields.type & TriggerFlag) {
semaphore->signal();
numRequests++;
stagedWork = true;
}
if (((trigger.fields.type & TriggerSync) && numRequests > 0) ||
(maxWriteQueueSize != -1 && numRequests >= maxWriteQueueSize)) {
if (stagedWork) {
stagedConns_[connImpl]++;
dirtyConns_.insert(connImpl);
}
bool needFlush = (trigger.fields.type & TriggerSync) && numRequests > 0;
bool needFlush2 = maxWriteQueueSize != -1 && numRequests >= maxWriteQueueSize;
if (needFlush || needFlush2) {
// flush() drains staged WRs and waits for completion.
conn.flush();
numRequests = 0;
stagedConns_.erase(connImpl);
dirtyConns_.erase(connImpl);
} else if (isSignal || stagedConns_[connImpl] >= kPostBatchThreshold) {
// Post the staged batch now: either we just queued a completion-signal
// WR (don't make the receiver wait for the next idle drain), or we've
// accumulated enough WRs that the QP send queue might overflow.
if (stats && stats->enabled) {
stats->postCalls++;
}
connImpl->postPending();
stagedConns_.erase(connImpl);
dirtyConns_.erase(connImpl);
}
return ProxyHandlerResult::Continue;

View File

@@ -20,6 +20,7 @@ constexpr int ProxyStartWarnPeriod = 1000;
struct Proxy::Impl {
ProxyHandler handler;
std::function<void()> threadInit;
std::function<void()> onIdle;
std::shared_ptr<Fifo> fifo;
std::atomic_bool threadStarted;
std::thread service;
@@ -28,6 +29,7 @@ struct Proxy::Impl {
Impl(ProxyHandler handler, std::function<void()> threadInit, int fifoSize)
: handler(handler),
threadInit(threadInit),
onIdle(nullptr),
fifo(std::make_shared<Fifo>(fifoSize)),
threadStarted(false),
running(false) {}
@@ -71,9 +73,11 @@ MSCCLPP_API_CPP void Proxy::start(bool blocking) {
pimpl_->threadStarted.store(true, std::memory_order_release);
ProxyHandler handler = this->pimpl_->handler;
auto onIdle = this->pimpl_->onIdle;
auto fifo = this->pimpl_->fifo;
ProxyTrigger trigger;
bool wasBusy = false;
int runCnt = ProxyStopCheckPeriod;
for (;;) {
if (runCnt-- == 0) {
@@ -85,8 +89,13 @@ MSCCLPP_API_CPP void Proxy::start(bool blocking) {
// Poll to see if we are ready to send anything
trigger = fifo->poll();
if (trigger.fst == 0 || trigger.snd == 0) { // TODO: this check is a potential pitfall for custom triggers
if (wasBusy && onIdle) {
onIdle();
}
wasBusy = false;
continue; // there is one in progress
}
wasBusy = true;
trigger.snd ^= (uint64_t{1} << uint64_t{63}); // this is where the last bit of snd is reverted.
ProxyHandlerResult result = handler(trigger);
@@ -95,6 +104,7 @@ MSCCLPP_API_CPP void Proxy::start(bool blocking) {
fifo->pop();
if (result == ProxyHandlerResult::Stop) {
if (onIdle) onIdle();
break;
}
}
@@ -123,4 +133,6 @@ MSCCLPP_API_CPP void Proxy::stop() {
MSCCLPP_API_CPP std::shared_ptr<Fifo> Proxy::fifo() { return pimpl_->fifo; }
MSCCLPP_API_CPP void Proxy::setOnIdle(std::function<void()> onIdle) { pimpl_->onIdle = std::move(onIdle); }
} // namespace mscclpp

View File

@@ -42,6 +42,7 @@ endif()
file(GLOB_RECURSE EP_SOURCES CONFIGURE_DEPENDS
buffer.cc
bindings.cpp
ibgda_setup.cc
kernels/*.cu
)
@@ -58,6 +59,17 @@ endif()
Python_add_library(mscclpp_ep_cpp MODULE ${EP_SOURCES})
target_compile_definitions(mscclpp_ep_cpp PRIVATE TORCH_EXTENSION_NAME=mscclpp_ep_cpp)
# Inherit ibverbs / mlx5dv defs from the core library so the IBGDA host-side
# plumbing (see ibgda_setup.cc, gated by MSCCLPP_EP_HAVE_IBGDA) compiles in.
if(IBVERBS_FOUND)
target_compile_definitions(mscclpp_ep_cpp PRIVATE USE_IBVERBS)
target_link_libraries(mscclpp_ep_cpp PRIVATE ${IBVERBS_LIBRARIES})
if(MLX5_FOUND)
target_compile_definitions(mscclpp_ep_cpp PRIVATE MSCCLPP_USE_MLX5DV)
target_include_directories(mscclpp_ep_cpp SYSTEM PRIVATE ${MLX5_INCLUDE_DIRS})
target_link_libraries(mscclpp_ep_cpp PRIVATE ${MLX5_LIBRARIES})
endif()
endif()
target_include_directories(mscclpp_ep_cpp PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
${PROJECT_SOURCE_DIR}/include

View File

@@ -15,6 +15,10 @@
#include "kernels/api.cuh"
#include "kernels/configs.cuh"
#ifdef MSCCLPP_EP_HAVE_IBGDA
#include "ibgda_setup.hpp"
#endif
namespace mscclpp {
namespace ep {
@@ -76,6 +80,15 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
printf("[mscclpp_ep] num_proxy_services=%d (set MSCCLPP_EP_NUM_PROXIES to override)\n", num_proxy_services);
fflush(stdout);
}
// Resolve the native-IBGDA opt-in flag. The setup itself is built in sync();
// the LL dispatch/combine launchers consume it when the path resolves to
// active (see buffer.hpp and the launchers below).
if (const char* e = std::getenv("MSCCLPP_EP_USE_IBGDA")) {
use_ibgda_path_ = (std::atoi(e) != 0);
}
if (rank == 0) {
printf("[mscclpp_ep] MSCCLPP_EP_USE_IBGDA=%d\n", use_ibgda_path_ ? 1 : 0);
fflush(stdout);
}
// Task fifo memory
int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS;
int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS;
@@ -518,6 +531,44 @@ void Buffer::sync(const std::vector<int>& device_ids,
}
}
#ifdef MSCCLPP_EP_HAVE_IBGDA
// ------------------------------------------------------------------
// Stage 4b.1: native IBGDA host-side plumbing.
//
// Built only when:
// - MSCCLPP_EP_USE_IBGDA=1
// - The run is cross-node (num_rdma_ranks > 1) — intra-node sticks with
// the NVLink IPC fast path already established above.
// - We have an RDMA buffer (num_rdma_bytes > 0).
//
// The setup's device handles are consumed by the LL dispatch/combine
// launchers below (kIbgdaPath kernels). With the flag unset, existing
// test outcomes are unchanged.
// ------------------------------------------------------------------
if (use_ibgda_path_ && num_rdma_bytes > 0 && num_rdma_ranks > 1) {
constexpr int kNumIbgdaChannels = 16; // mirrors num_port_channels_per_rank above
try {
ibgda_setup_ = mscclpp::ep::build_ibgda_setup(rank, num_ranks, /*ib_transport_index=*/device_id,
kNumIbgdaChannels, rdma_buffer_ptr,
static_cast<std::size_t>(num_rdma_bytes), bootstrap);
if (rank == 0) {
printf("[mscclpp_ep] IBGDA setup built: channels=%d num_ranks=%d (per-rank QPs=%d)\n",
kNumIbgdaChannels, num_ranks, kNumIbgdaChannels * (num_ranks - 1));
fflush(stdout);
}
// Clear any benign CUDA sticky error left by overlapping host/UAR
// registrations across sibling QPs (e.g., cudaErrorHostMemoryAlreadyRegistered).
(void)cudaGetLastError();
} catch (const std::exception& e) {
// Don't take down the run; just log and fall back: the launchers only
// take the IBGDA path when ibgda_setup_ is non-null.
fprintf(stderr, "[mscclpp_ep][rank=%d] IBGDA setup failed, falling back to host FIFO: %s\n", rank, e.what());
ibgda_setup_.reset();
(void)cudaGetLastError();
}
}
#endif // MSCCLPP_EP_HAVE_IBGDA
// Ready to use
available = true;
}
@@ -1452,6 +1503,13 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto peer_bases = peer_rdma_bases_gpu;
const bool use_ipc = ll_ipc_ready;
auto rdma_base = rdma_buffer_ptr;
#ifdef MSCCLPP_EP_HAVE_IBGDA
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_dev = ibgda_setup_ ? ibgda_setup_->device_handles.get() : nullptr;
const bool use_ibgda = use_ibgda_path_ && ibgda_dev != nullptr && !use_ipc;
#else
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_dev = nullptr;
const bool use_ibgda = false;
#endif
auto launcher = [=](int phases) {
internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, packed_recv_src_info.data_ptr<int>(),
packed_recv_layout_range.data_ptr<int64_t>(), packed_recv_count.data_ptr<int>(),
@@ -1459,7 +1517,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
buffer.dispatch_rdma_send_buffer, x.data_ptr(), topk_idx.data_ptr<int64_t>(),
next_clean_meta.first, next_clean_meta.second, num_tokens, hidden,
num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, use_fp8, workspace,
launch_stream, phases, rdma_base, port_handles, peer_bases, mem_handles, use_ipc);
launch_stream, phases, rdma_base, port_handles, peer_bases, mem_handles, use_ipc,
ibgda_dev, use_ibgda);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
@@ -1528,13 +1587,21 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
auto peer_bases = peer_rdma_bases_gpu;
const bool use_ipc = ll_ipc_ready;
auto rdma_base = rdma_buffer_ptr;
#ifdef MSCCLPP_EP_HAVE_IBGDA
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_dev = ibgda_setup_ ? ibgda_setup_->device_handles.get() : nullptr;
const bool use_ibgda = use_ibgda_path_ && ibgda_dev != nullptr && !use_ipc;
#else
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_dev = nullptr;
const bool use_ibgda = false;
#endif
auto launcher = [=](int phases) {
internode_ll::combine(
combined_x.data_ptr(), buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer,
buffer.combine_rdma_send_buffer, x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(), next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks,
workspace, launch_stream, phases, zero_copy, rdma_base, port_handles, peer_bases, mem_handles, use_ipc);
workspace, launch_stream, phases, zero_copy, rdma_base, port_handles, peer_bases, mem_handles, use_ipc,
ibgda_dev, use_ibgda);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));


@@ -17,6 +17,11 @@
#include "kernels/configs.cuh"
#include "kernels/exception.cuh"
#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM)
#define MSCCLPP_EP_HAVE_IBGDA 1
namespace mscclpp { namespace ep { struct IbgdaSetup; } }
#endif
#ifndef TORCH_EXTENSION_NAME
#define TORCH_EXTENSION_NAME mscclpp_ep_cpp
#endif
@@ -101,6 +106,17 @@ struct Buffer {
std::shared_ptr<mscclpp::MemoryChannelDeviceHandle> ll_memory_channel_handles_device_ptr;
bool ll_ipc_ready = false;
// ------------------------------------------------------------------
// Native IBGDA path (Stage 4b). Built lazily in `sync()` when env
// `MSCCLPP_EP_USE_IBGDA=1` is set AND the run is cross-node.
// The LL kernels consume `ibgda_setup_` only when the launch-time gating
// (use_ibgda_path_ && ibgda_setup_ && !use_ipc) selects the IBGDA path,
// so existing IPC / host-FIFO tests are unaffected.
// ------------------------------------------------------------------
bool use_ibgda_path_ = false;
#ifdef MSCCLPP_EP_HAVE_IBGDA
std::unique_ptr<mscclpp::ep::IbgdaSetup> ibgda_setup_;
#endif
private:
void move_fifo_slots(int num_slots = 1);

src/ext/ep/ibgda_setup.cc Normal file

@@ -0,0 +1,341 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
//
// Stage 4b.1 implementation. See ibgda_setup.hpp for the contract.
#include "ibgda_setup.hpp"
#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM)
#include <arpa/inet.h>
#include <cuda_runtime.h>
#include <infiniband/verbs.h>
#include <cstdio>
#include <cstring>
#include <stdexcept>
#include <string>
#include <vector>
#include <thread>
#include <chrono>
#include <mscclpp/core.hpp>
#include <mscclpp/errors.hpp>
#include <mscclpp/gpu_utils.hpp>
#include <mscclpp/utils.hpp>
#include "kernels/exception.cuh"
namespace mscclpp {
namespace ep {
IbgdaSetup::~IbgdaSetup() {
// Stop the CQ poller first so it doesn't race with QP teardown.
if (cq_poller_thread.joinable()) {
cq_poller_stop.store(true, std::memory_order_release);
cq_poller_thread.join();
}
// Tear down in reverse order of construction:
// - device_handles, d_local_mrs, d_remote_mrs: shared_ptr from
// gpuCallocShared, auto-freed via custom deleter.
// - resources / qps / rdma_mr: smart ptrs, auto-freed.
// - sig_mr (raw ibv_mr) + sig_slots / sig_seq (raw cudaMalloc): explicit.
if (sig_mr != nullptr) {
ibv_dereg_mr(sig_mr);
sig_mr = nullptr;
}
if (sig_slots != nullptr) {
cudaFree(sig_slots);
sig_slots = nullptr;
}
if (sig_seq != nullptr) {
cudaFree(sig_seq);
sig_seq = nullptr;
}
}
namespace {
// Tunables that mirror the existing PortChannel path:
// - 16 channels (== num_port_channels_per_rank in Buffer::sync)
// - port=1, gid_index=3 (matches the value used in Stage 0..4a probes and
// the typical mscclpp/NDv5 default)
constexpr int kIbgdaPort = 1;
constexpr int kIbgdaGid = 3;
// SQ depth per QP. Each LL dispatch+combine iteration posts up to a few
// dozen WRs per QP at default LL configs; the bench loop runs 50 iters
// without explicit per-iter drain, so we size the SQ at 8192 to avoid
// wrap-around overwriting in-flight WQEs (we have no device-side
// back-pressure check in reserve_wqe_slots — the CQ poller drains
// asynchronously). 8k entries × 64B stride = 512 KiB SQ buf per QP, well
// under typical mlx5 HCA per-QP caps.
constexpr int kIbgdaMaxSendWr = 8192;
} // namespace
std::unique_ptr<IbgdaSetup> build_ibgda_setup(int rank, int num_ranks, int ib_transport_index, int num_channels,
void* rdma_buffer_ptr, std::size_t num_rdma_bytes,
std::shared_ptr<TcpBootstrap> bootstrap) {
EP_HOST_ASSERT(rdma_buffer_ptr != nullptr);
EP_HOST_ASSERT(num_rdma_bytes > 0);
EP_HOST_ASSERT(num_ranks > 1);
EP_HOST_ASSERT(num_channels > 0);
EP_HOST_ASSERT(ib_transport_index >= 0 && ib_transport_index < 8);
auto setup = std::make_unique<IbgdaSetup>();
setup->rank = rank;
setup->num_ranks = num_ranks;
setup->num_channels = num_channels;
// 1. Resolve IB device name and build the IbCtx.
auto ib_transport = static_cast<Transport>(static_cast<int>(Transport::IB0) + ib_transport_index);
std::string dev_name = getIBDeviceName(ib_transport);
setup->ib_ctx = std::make_unique<IbCtx>(dev_name);
EP_HOST_ASSERT(setup->ib_ctx->isMlx5() && "IBGDA requires an mlx5 NIC");
// 2. Create QPs. Layout: qps[channel * num_ranks + peer].
// Self entries are nullptr.
const int total_slots = num_channels * num_ranks;
setup->qps.resize(total_slots);
setup->resources.resize(total_slots);
for (int c = 0; c < num_channels; ++c) {
for (int r = 0; r < num_ranks; ++r) {
if (r == rank) continue;
auto qp = setup->ib_ctx->createQp(/*port=*/kIbgdaPort, /*gidIndex=*/kIbgdaGid,
/*maxSendCqSize=*/kIbgdaMaxSendWr * 2,
/*maxSendCqPollNum=*/64,
/*maxSendWr=*/kIbgdaMaxSendWr,
/*maxRecvWr=*/1,
/*maxWrPerSend=*/1,
/*noAtomic=*/true);
setup->qps[c * num_ranks + r] = qp;
}
}
// 3. AllGather IbQpInfos so every rank can RTR every QP.
// Layout per-rank: total_slots IbQpInfo records, in [c * num_ranks + peer]
// order. Self entries (peer == rank) are zeroed; remote entries describe
// the QP that THIS rank uses to TALK TO peer == r.
std::vector<IbQpInfo> my_infos(total_slots);
std::memset(my_infos.data(), 0, total_slots * sizeof(IbQpInfo));
for (int c = 0; c < num_channels; ++c) {
for (int r = 0; r < num_ranks; ++r) {
if (r == rank) continue;
my_infos[c * num_ranks + r] = setup->qps[c * num_ranks + r]->getInfo();
}
}
// bootstrap->allGather expects each rank to fill its own slot in a
// contiguous buffer of size num_ranks * record_bytes.
const std::size_t record_bytes = total_slots * sizeof(IbQpInfo);
std::vector<IbQpInfo> all_infos(num_ranks * total_slots);
std::memcpy(&all_infos[rank * total_slots], my_infos.data(), record_bytes);
bootstrap->allGather(all_infos.data(), record_bytes);
// 4. RTR/RTS each local QP using the matching peer-side QP info. The peer
// info we want is "rank r's QP for talking to us" == r's record indexed at
// [c * num_ranks + rank].
for (int c = 0; c < num_channels; ++c) {
for (int r = 0; r < num_ranks; ++r) {
if (r == rank) continue;
const IbQpInfo& peer_info = all_infos[r * total_slots + c * num_ranks + rank];
auto& qp = setup->qps[c * num_ranks + r];
qp->rtr(peer_info);
qp->rts();
}
}
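// Worked example of the pairing above (sketch only): with num_ranks == 2 and
// channel c == 0, rank 0 keeps its QP for peer 1 at qps[0 * 2 + 1] and brings
// it RTR against rank 1's published record all_infos[1 * total_slots + 0 * 2 + 0];
// rank 1 does the mirror image, so each QP pair is connected exactly once.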
// 5. Wrap each QP with IbgdaResources (Stage 1) — produces GPU-mapped
// sq_buf / dbrec / bf_reg / state pointers usable from the kernel.
for (int c = 0; c < num_channels; ++c) {
for (int r = 0; r < num_ranks; ++r) {
if (r == rank) continue;
ibv_qp* raw = setup->qps[c * num_ranks + r]->getRawQp();
setup->resources[c * num_ranks + r] = std::make_unique<IbgdaResources>(raw);
}
}
// 6. Register the existing rdma_buffer_ptr as an MR on this IbCtx, then
// allgather (addr, rkey) so we can build per-peer remote_mrs entries.
setup->rdma_mr = setup->ib_ctx->registerMr(rdma_buffer_ptr, num_rdma_bytes);
IbgdaSetup::PeerMr my_rdma_mr{};
{
auto info = setup->rdma_mr->getInfo();
my_rdma_mr.addr = info.addr;
my_rdma_mr.rkey = info.rkey;
}
setup->peer_rdma.assign(num_ranks, IbgdaSetup::PeerMr{});
setup->peer_rdma[rank] = my_rdma_mr;
bootstrap->allGather(setup->peer_rdma.data(), sizeof(IbgdaSetup::PeerMr));
// 7. Allocate signal slots: total_slots * 4 bytes on GPU. Layout:
// sig_slots[c * num_ranks + sender_peer] — when peer P sends signal() to
// us through channel c, it RDMA-WRITEs to *our* sig_slots[c * num_ranks + P].
// Each peer therefore needs to know our sig MR (addr+rkey) AND the offset
// within it that corresponds to (channel, sender=P) == "we are receiving
// from P". We allgather the base addr+rkey only; the offset is derivable.
const std::size_t sig_bytes = std::size_t(total_slots) * sizeof(uint32_t);
CUDA_CHECK(cudaMalloc(&setup->sig_slots, sig_bytes));
CUDA_CHECK(cudaMemset(setup->sig_slots, 0, sig_bytes));
CUDA_CHECK(cudaMalloc(&setup->sig_seq, sig_bytes));
CUDA_CHECK(cudaMemset(setup->sig_seq, 0, sig_bytes));
// Use raw verbs for the signal MR (we need the rkey/lkey directly).
// All QPs were created on the same IbCtx, so any non-null QP shares the same pd.
ibv_pd* pd = nullptr;
for (int c = 0; c < num_channels && pd == nullptr; ++c)
for (int r = 0; r < num_ranks && pd == nullptr; ++r)
if (auto& q = setup->qps[c * num_ranks + r]; q) pd = q->getRawQp()->pd;
EP_HOST_ASSERT(pd != nullptr);
setup->sig_mr = ibv_reg_mr(pd, setup->sig_slots, sig_bytes,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
if (!setup->sig_mr) {
throw std::runtime_error("ibv_reg_mr(sig_slots) failed errno=" + std::to_string(errno));
}
IbgdaSetup::PeerMr my_sig_mr{};
my_sig_mr.addr = reinterpret_cast<uint64_t>(setup->sig_slots);
my_sig_mr.rkey = setup->sig_mr->rkey;
setup->peer_sig.assign(num_ranks, IbgdaSetup::PeerMr{});
setup->peer_sig[rank] = my_sig_mr;
bootstrap->allGather(setup->peer_sig.data(), sizeof(IbgdaSetup::PeerMr));
// 8. Build the GPU-resident MR tables. We have a single local MR (the
// rdma_buffer_ptr); remote_mrs has one entry per peer rank.
std::vector<IbgdaLocalMr> h_local(1);
h_local[0].addr = reinterpret_cast<uint64_t>(rdma_buffer_ptr);
h_local[0].lkey_be = htonl(setup->rdma_mr->getLkey());
h_local[0].pad = 0;
std::vector<IbgdaRemoteMr> h_remote(num_ranks);
for (int r = 0; r < num_ranks; ++r) {
h_remote[r].addr = setup->peer_rdma[r].addr;
h_remote[r].rkey_be = htonl(setup->peer_rdma[r].rkey);
h_remote[r].pad = 0;
}
setup->d_local_mrs = mscclpp::detail::gpuCallocShared<IbgdaLocalMr>(1);
setup->d_remote_mrs = mscclpp::detail::gpuCallocShared<IbgdaRemoteMr>(num_ranks);
mscclpp::gpuMemcpy<IbgdaLocalMr>(setup->d_local_mrs.get(), h_local.data(), 1, cudaMemcpyHostToDevice);
mscclpp::gpuMemcpy<IbgdaRemoteMr>(setup->d_remote_mrs.get(), h_remote.data(), num_ranks, cudaMemcpyHostToDevice);
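// Sketch of how the device-side writers consume these tables (the real code
// lives in ibgda_device.cuh / the kIbgdaPath branch of internode_ll.cu):
//   IbgdaRemoteMr r = handle.remote_mrs[handle.dst];  // handle.dst == peer rank
//   uint64_t raddr  = r.addr + dst_offset;            // peer-side VA in its rdma buffer
//   uint32_t rkey   = r.rkey_be;                      // already byte-swapped for the WQE
// so kernels keep passing offsets relative to rdma_buffer_ptr unchanged.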
// 9. Build the device handle array (channel × num_ranks).
std::vector<IbgdaPortChannelDeviceHandle> h_handles(total_slots);
std::memset(h_handles.data(), 0, h_handles.size() * sizeof(IbgdaPortChannelDeviceHandle));
for (int c = 0; c < num_channels; ++c) {
for (int r = 0; r < num_ranks; ++r) {
auto& h = h_handles[c * num_ranks + r];
if (r == rank) continue; // self-slot left zeroed
h.qp = setup->resources[c * num_ranks + r]->getHandle();
h.local_mrs = setup->d_local_mrs.get();
h.remote_mrs = setup->d_remote_mrs.get();
// Local sig slot WE poll for messages from peer r through channel c.
h.sig_local_addr = &setup->sig_slots[c * num_ranks + r];
h.sig_local_lkey = htonl(setup->sig_mr->lkey);
// Remote sig slot peer r polls for messages FROM US through channel c.
// Peer r's slot for "rank == us" is at offset (c * num_ranks + rank) * 4
// inside their signal buffer.
h.sig_remote_addr = setup->peer_sig[r].addr +
static_cast<uint64_t>(c * num_ranks + rank) * sizeof(uint32_t);
h.sig_rkey_be = htonl(setup->peer_sig[r].rkey);
// Outbound seq counter is per-handle; carve out one u32 from sig_seq
// (sig_seq is num_channels × num_ranks u32s so we can reuse the same
// index function — it is unrelated to the inbound sig_slots buffer
// beyond reusing the size).
h.sig_seq = &setup->sig_seq[c * num_ranks + r];
h.dst = static_cast<uint32_t>(r); // index into remote_mrs[] (peer-rank-based)
h.src = 0; // single-entry local_mrs table
h.peer_rank = static_cast<uint32_t>(r);
h._pad = 0;
}
}
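// Symmetry check for the wiring above (sketch only): when rank A signals rank B
// on channel c, A's handle for (c, B) targets
//   peer_sig[B].addr + (c * num_ranks + A) * sizeof(uint32_t),
// which is exactly the slot &sig_slots[c * num_ranks + A] that B's handle for
// (c, A) polls through sig_local_addr.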
setup->device_handles = mscclpp::detail::gpuCallocShared<IbgdaPortChannelDeviceHandle>(total_slots);
mscclpp::gpuMemcpy<IbgdaPortChannelDeviceHandle>(setup->device_handles.get(), h_handles.data(), total_slots,
cudaMemcpyHostToDevice);
// 10. Spawn CQ drain thread. Kernel-issued rdma_write paths set CQ_UPDATE
// on every WR; without this drain, the send CQ (sized 2 × kIbgdaMaxSendWr)
// would fill within a few iterations of LL dispatch+combine and the QP
// would error out. We collect raw send_cq pointers from each QP up front.
{
std::vector<ibv_cq*> send_cqs;
send_cqs.reserve(total_slots);
for (int idx = 0; idx < total_slots; ++idx) {
auto& qp = setup->qps[idx];
if (!qp) continue;
ibv_cq* cq = qp->getRawQp()->send_cq;
if (cq) send_cqs.push_back(cq);
}
IbgdaSetup* raw = setup.get();
int dbg_rank = rank;
setup->cq_poller_thread = std::thread([raw, send_cqs, dbg_rank]() {
// Tight loop polling all CQs round-robin. Each ibv_poll_cq is cheap
// (one read of the CQE buffer's valid bit). Completions are only counted
// and checked for errors: a status != IBV_WC_SUCCESS indicates a fatal QP
// error, which is logged here and otherwise surfaces through later
// failures (the test path hangs and the user sees the error).
constexpr int kBatch = 16;
ibv_wc wc[kBatch];
uint64_t total_polled = 0;
uint64_t total_errors = 0;
auto t_last_log = std::chrono::steady_clock::now();
while (!raw->cq_poller_stop.load(std::memory_order_acquire)) {
bool any = false;
for (ibv_cq* cq : send_cqs) {
int n = ibv_poll_cq(cq, kBatch, wc);
if (n > 0) {
any = true;
total_polled += n;
for (int i = 0; i < n; ++i) {
if (wc[i].status != IBV_WC_SUCCESS) {
total_errors++;
fprintf(stderr,
"[mscclpp_ep][ibgda][rank=%d] CQE error: status=%d (%s) opcode=%d vendor_err=0x%x wr_id=%llu\n",
dbg_rank, wc[i].status, ibv_wc_status_str(wc[i].status), wc[i].opcode, wc[i].vendor_err,
static_cast<unsigned long long>(wc[i].wr_id));
fflush(stderr);
}
}
}
for (int rep = 0; rep < 3 && n == kBatch; ++rep) {
n = ibv_poll_cq(cq, kBatch, wc);
if (n > 0) {
any = true;
total_polled += n;
for (int i = 0; i < n; ++i) {
if (wc[i].status != IBV_WC_SUCCESS) {
total_errors++;
fprintf(stderr,
"[mscclpp_ep][ibgda][rank=%d] CQE error: status=%d (%s) opcode=%d vendor_err=0x%x\n",
dbg_rank, wc[i].status, ibv_wc_status_str(wc[i].status), wc[i].opcode, wc[i].vendor_err);
fflush(stderr);
}
}
}
}
}
auto now = std::chrono::steady_clock::now();
if (std::chrono::duration_cast<std::chrono::seconds>(now - t_last_log).count() >= 5) {
fprintf(stderr, "[mscclpp_ep][ibgda][rank=%d] poller: polled=%llu errors=%llu\n", dbg_rank,
static_cast<unsigned long long>(total_polled), static_cast<unsigned long long>(total_errors));
fflush(stderr);
t_last_log = now;
}
if (!any) {
std::this_thread::sleep_for(std::chrono::microseconds(20));
}
}
});
}
return setup;
}
} // namespace ep
} // namespace mscclpp
#endif // USE_IBVERBS && MSCCLPP_USE_MLX5DV && !MSCCLPP_USE_ROCM

src/ext/ep/ibgda_setup.hpp Normal file

@@ -0,0 +1,116 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
//
// Stage 4b.1: native IBGDA host-side plumbing for src/ext/ep.
//
// Owned by `Buffer` and built lazily in `Buffer::sync()` when the env var
// `MSCCLPP_EP_USE_IBGDA=1` is set AND the run is cross-node
// (num_rdma_ranks > 1). Builds a parallel "shadow" of the existing
// PortChannel layout: a flat (channel × num_ranks) array of
// `IbgdaPortChannelDeviceHandle` POD structs sitting on the GPU, plus the
// host-side resources (IbCtx, QPs, IbgdaResources, MRs, signal slots) that
// keep them valid for the lifetime of the Buffer.
//
// 4b.1 constructs this state; 4b.2 consumes it from the kernels behind the
// `kIbgdaPath` template branch in internode_ll.cu.
#pragma once
#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM)
#include <cstdint>
#include <memory>
#include <vector>
#include <atomic>
#include <thread>
#include <mscclpp/core.hpp>
#include <mscclpp/ibgda_port_channel_device.hpp>
#include "../../core/include/ib.hpp"
#include "../../core/include/ibgda.hpp"
struct ibv_mr;
namespace mscclpp {
namespace ep {
// One owning bundle of host-side resources backing the device handle array.
struct IbgdaSetup {
IbgdaSetup() = default;
~IbgdaSetup();
IbgdaSetup(const IbgdaSetup&) = delete;
IbgdaSetup& operator=(const IbgdaSetup&) = delete;
// Layout constants (must match the existing port_channel_handles layout in
// Buffer::sync, so kernels can reuse the same channel × num_ranks index
// arithmetic in the IBGDA branch).
int num_channels = 0; // == num_port_channels_per_rank in Buffer::sync
int num_ranks = 0;
int rank = 0;
// IB context + QPs + Stage-1 GPU mappings.
std::unique_ptr<IbCtx> ib_ctx;
// Indexed [channel * num_ranks + peer]; entries with peer == rank are null.
std::vector<std::shared_ptr<IbQp>> qps;
std::vector<std::unique_ptr<IbgdaResources>> resources;
// RDMA buffer MR. We register the *same* `rdma_buffer_ptr` that the
// existing PortChannel path uses, so dst/src offsets in the kernel work
// unchanged.
std::unique_ptr<const IbMr> rdma_mr;
// Signal slots (GPU-resident).
// Layout: 4 bytes per (channel, sender peer) on the receiving side.
// `sig_slots` is the local buffer polled by `port_wait()`; remote peers
// RDMA-WRITE into *our* signal MR. The slot we poll for messages from
// peer A on channel c sits at
//   offset_in_buf(c, A) = (c * num_ranks + A) * 4
// within our signal buffer; symmetrically, when we signal peer r on
// channel c we write to offset (c * num_ranks + rank) * 4 inside r's
// signal buffer. See the per-handle wiring in `build_ibgda_setup`.
uint32_t* sig_slots = nullptr; // GPU mem; size = num_channels * num_ranks * 4
ibv_mr* sig_mr = nullptr; // raw verbs (we keep the IbMr only for rdma_mr)
uint32_t* sig_seq = nullptr; // GPU mem; per-handle outbound counters
// Per-peer (addr, rkey) for the RDMA buffer and the signal buffer. Filled
// by an allgather over the bootstrap.
struct PeerMr {
uint64_t addr = 0;
uint32_t rkey = 0;
uint32_t pad = 0;
};
std::vector<PeerMr> peer_rdma; // size = num_ranks
std::vector<PeerMr> peer_sig; // size = num_ranks
// Flat device-side handle array: num_channels * num_ranks entries.
// Self entries (peer == rank) are zeroed and unused.
std::shared_ptr<IbgdaPortChannelDeviceHandle> device_handles;
// Underlying GPU-side MR-table arrays referenced by every device handle
// (see IbgdaPortChannelDeviceHandle::local_mrs / remote_mrs).
std::shared_ptr<IbgdaLocalMr> d_local_mrs; // length 1 (we have a single MR)
std::shared_ptr<IbgdaRemoteMr> d_remote_mrs; // length num_ranks
// CQ drain thread. The kernel-side rdma_write paths set CQ_UPDATE on
// every WR, so without periodic ibv_poll_cq() the send CQ would fill up
// (CQ size = kIbgdaMaxSendWr * 2). One thread polls all per-QP send CQs
// round-robin until `cq_poller_stop` is set in the destructor.
std::atomic<bool> cq_poller_stop{false};
std::thread cq_poller_thread;
};
// Build the full IBGDA setup. Bootstrap is used for cross-rank exchange of
// QP info and MR keys; rdma_buffer_ptr/num_rdma_bytes is the same buffer the
// existing PortChannel path uses. `ib_transport_index` is the IB device
// index this rank will use (== `device_id` on NDv5).
//
// Throws on any irrecoverable error. On success returns a fully-initialised
// IbgdaSetup with all QPs in RTS and the device-side handle array populated.
std::unique_ptr<IbgdaSetup> build_ibgda_setup(int rank, int num_ranks, int ib_transport_index, int num_channels,
void* rdma_buffer_ptr, std::size_t num_rdma_bytes,
std::shared_ptr<TcpBootstrap> bootstrap);
} // namespace ep
} // namespace mscclpp
#endif // USE_IBVERBS && MSCCLPP_USE_MLX5DV && !MSCCLPP_USE_ROCM


@@ -17,6 +17,13 @@
#include <mscclpp/port_channel_device.hpp>
#include <vector>
#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM)
#include <mscclpp/ibgda_port_channel_device.hpp>
#define MSCCLPP_EP_KERNEL_HAS_IBGDA 1
#else
namespace mscclpp { struct IbgdaPortChannelDeviceHandle; }
#endif
namespace mscclpp {
namespace ep {
@@ -131,7 +138,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank,
int num_ranks, bool use_fp8, void* workspace, cudaStream_t stream, int phases, void* rdma_buffer_ptr,
mscclpp::PortChannelDeviceHandle* port_channel_handles, void* const* peer_rdma_bases,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path);
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path,
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_handles, bool use_ibgda_path);
void combine(void* combined_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, const void* x,
const int64_t* topk_idx, const float* topk_weights, const int* src_info, const int64_t* layout_range,
@@ -139,7 +147,8 @@ void combine(void* combined_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void*
int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream, int phases, bool zero_copy, void* rdma_buffer_ptr,
mscclpp::PortChannelDeviceHandle* port_channel_handles, void* const* peer_rdma_bases,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path);
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path,
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_handles, bool use_ibgda_path);
} // namespace internode_ll


@@ -31,6 +31,11 @@
#include <mscclpp/memory_channel_device.hpp>
#include <mscclpp/port_channel_device.hpp>
#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM)
#include "ibgda_port_channel_device.cuh"
#define MSCCLPP_EP_LL_HAS_IBGDA 1
#endif
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
@@ -142,14 +147,15 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, int64_t* cl
// dispatch
// ---------------------------------------------------------------------------
template <bool kUseFP8, bool kIpcPath, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
template <bool kUseFP8, bool kIpcPath, bool kIbgdaPath, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
__global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dispatch(
void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count, void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x, const void* x,
const int64_t* topk_idx, int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert, int64_t* next_clean,
int num_next_clean_int, int num_tokens, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts,
int rank, int num_ranks, int phases, void* rdma_buffer_ptr, mscclpp::PortChannelDeviceHandle* port_channel_handles,
void* const* peer_rdma_bases, mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
void* const* peer_rdma_bases, mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_handles) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
@@ -249,14 +255,31 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis
const auto* dst_int4_ptr = reinterpret_cast<int4*>(peer_dst);
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
} else {
// MSCCL++ port-channel PUT (lane 0 issues one request).
if (lane_id == 0) {
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr);
port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank].put(dst_off, src_off,
num_bytes_per_msg);
#ifdef MSCCLPP_EP_LL_HAS_IBGDA
if constexpr (kIbgdaPath) {
// Native IBGDA: lane 0 issues one RDMA WRITE WQE directly.
// UNSIGNALED + no DB ring — the trailing count write below
// (signaled, rings DB) flushes the per-QP queue.
if (lane_id == 0) {
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr);
mscclpp::ibgda::port_put(ibgda_handles[dst_expert_local_idx * num_ranks + dst_rank], dst_off, src_off,
num_bytes_per_msg,
/*signal_cqe=*/false, /*ring_db=*/false);
}
__syncwarp();
} else
#endif
{
// MSCCL++ port-channel PUT (lane 0 issues one request).
if (lane_id == 0) {
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr);
port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank].put(dst_off, src_off,
num_bytes_per_msg);
}
__syncwarp();
}
__syncwarp();
}
} else {
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
@@ -322,10 +345,26 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis
peer_rdma_bases, rdma_buffer_ptr, dst_rank));
st_na_release(peer_counter, static_cast<int64_t>(-num_tokens_sent - 1));
} else {
auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank;
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(counter_ptr), rdma_buffer_ptr);
port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank].atomicAdd(
off, static_cast<int64_t>(-num_tokens_sent - 1));
#ifdef MSCCLPP_EP_LL_HAS_IBGDA
if constexpr (kIbgdaPath) {
// Single writer per (dst_expert_local_idx, rank) slot, so an
// 8-byte inline RDMA WRITE delivering the encoded count is
// semantically equivalent to atomicAdd from a zero-initialised
// remote slot. The receiver polls with ld_acquire_sys_global.
auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank;
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(counter_ptr), rdma_buffer_ptr);
const auto& ch = ibgda_handles[dst_expert_local_idx * num_ranks + dst_rank];
mscclpp::IbgdaRemoteMr r = ch.remote_mrs[ch.dst];
mscclpp::ibgda::rdma_write_inl8(ch.qp, static_cast<uint64_t>(static_cast<int64_t>(-num_tokens_sent - 1)),
r.addr + off, r.rkey_be);
} else
#endif
{
auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank;
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(counter_ptr), rdma_buffer_ptr);
port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank].atomicAdd(
off, static_cast<int64_t>(-num_tokens_sent - 1));
}
}
} else {
st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
@@ -405,7 +444,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank,
int num_ranks, bool use_fp8, void* workspace, cudaStream_t stream, int phases, void* rdma_buffer_ptr,
mscclpp::PortChannelDeviceHandle* port_channel_handles, void* const* peer_rdma_bases,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path) {
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path,
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_handles, bool use_ibgda_path) {
constexpr int kNumMaxTopK = 9;
// (kNumWarpGroups, kNumWarpsPerGroup) is path-dependent. Intra-node IPC
// benefits from 1 expert per SM with 32 warps cooperating on the recv-side
@@ -446,26 +486,37 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv
auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
#define DISPATCH_LAUNCH_CASE(hidden_case) \
{ \
if (use_ipc_path) { \
auto dispatch_func = use_fp8 ? dispatch<true, true, kNumWarpGroupsIpc, kNumWarpsPerGroupIpc, hidden_case> \
: dispatch<false, true, kNumWarpGroupsIpc, kNumWarpsPerGroupIpc, hidden_case>; \
LAUNCH_KERNEL(&cfg, dispatch_func, packed_recv_x, packed_recv_x_scales, packed_recv_src_info, \
packed_recv_layout_range, packed_recv_count, rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, phases, \
rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, memory_channel_handles); \
} else { \
auto dispatch_func = use_fp8 ? dispatch<true, false, kNumWarpGroupsRdma, kNumWarpsPerGroupRdma, hidden_case> \
: dispatch<false, false, kNumWarpGroupsRdma, kNumWarpsPerGroupRdma, hidden_case>; \
LAUNCH_KERNEL(&cfg, dispatch_func, packed_recv_x, packed_recv_x_scales, packed_recv_src_info, \
packed_recv_layout_range, packed_recv_count, rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, phases, \
rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, memory_channel_handles); \
} \
} \
#define DISPATCH_LAUNCH_CASE(hidden_case) \
{ \
if (use_ipc_path) { \
auto dispatch_func = \
use_fp8 ? dispatch<true, true, false, kNumWarpGroupsIpc, kNumWarpsPerGroupIpc, hidden_case> \
: dispatch<false, true, false, kNumWarpGroupsIpc, kNumWarpsPerGroupIpc, hidden_case>; \
LAUNCH_KERNEL(&cfg, dispatch_func, packed_recv_x, packed_recv_x_scales, packed_recv_src_info, \
packed_recv_layout_range, packed_recv_count, rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, phases, \
rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, memory_channel_handles, ibgda_handles); \
} else if (use_ibgda_path) { \
auto dispatch_func = \
use_fp8 ? dispatch<true, false, true, kNumWarpGroupsRdma, kNumWarpsPerGroupRdma, hidden_case> \
: dispatch<false, false, true, kNumWarpGroupsRdma, kNumWarpsPerGroupRdma, hidden_case>; \
LAUNCH_KERNEL(&cfg, dispatch_func, packed_recv_x, packed_recv_x_scales, packed_recv_src_info, \
packed_recv_layout_range, packed_recv_count, rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, phases, \
rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, memory_channel_handles, ibgda_handles); \
} else { \
auto dispatch_func = \
use_fp8 ? dispatch<true, false, false, kNumWarpGroupsRdma, kNumWarpsPerGroupRdma, hidden_case> \
: dispatch<false, false, false, kNumWarpGroupsRdma, kNumWarpsPerGroupRdma, hidden_case>; \
LAUNCH_KERNEL(&cfg, dispatch_func, packed_recv_x, packed_recv_x_scales, packed_recv_src_info, \
packed_recv_layout_range, packed_recv_count, rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, phases, \
rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, memory_channel_handles, ibgda_handles); \
} \
} \
break
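// Flag-to-template mapping used by the macro above: use_ipc_path selects
// dispatch<.., kIpcPath=true, kIbgdaPath=false, ..>, use_ibgda_path selects
// dispatch<.., false, true, ..>, and the default selects <.., false, false, ..>
// (host-FIFO PortChannel).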
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
@@ -477,14 +528,15 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv
// combine
// ---------------------------------------------------------------------------
template <bool kIpcPath, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk>
template <bool kIpcPath, bool kIbgdaPath, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk>
__global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void combine(
void* combined_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, const void* x,
const int64_t* topk_idx, const float* topk_weights, const int* src_info, const int64_t* layout_range,
int64_t* next_clean, int num_next_clean_int, int* atomic_clean_flag, int num_combined_tokens, int hidden,
int num_topk, int num_max_dispatch_tokens_per_rank, int num_experts, int rank, int num_ranks, int phases,
bool zero_copy, void* rdma_buffer_ptr, mscclpp::PortChannelDeviceHandle* port_channel_handles,
void* const* peer_rdma_bases, mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
void* const* peer_rdma_bases, mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_handles) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_sms = static_cast<int>(gridDim.x);
const auto thread_id = static_cast<int>(threadIdx.x);
@@ -546,17 +598,35 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com
const auto peer_dst_int4 = reinterpret_cast<int4*>(peer_dst);
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, peer_dst_int4, x_int4, ld_nc_global, st_na_global);
} else {
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
if (not zero_copy)
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
// MSCCL++ port-channel PUT.
if (lane_id == 0) {
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
const auto src_off = rdma_offset_of(static_cast<uint64_t>(buf_ptr), rdma_buffer_ptr);
port_channel_handles[local_expert_idx * num_ranks + dst_rank].put(dst_off, src_off,
hidden * sizeof(nv_bfloat16));
#ifdef MSCCLPP_EP_LL_HAS_IBGDA
if constexpr (kIbgdaPath) {
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
if (not zero_copy)
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
if (lane_id == 0) {
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
const auto src_off = rdma_offset_of(static_cast<uint64_t>(buf_ptr), rdma_buffer_ptr);
// UNSIGNALED + no DB ring — trailing flag write drives the doorbell.
mscclpp::ibgda::port_put(ibgda_handles[local_expert_idx * num_ranks + dst_rank], dst_off, src_off,
hidden * sizeof(nv_bfloat16),
/*signal_cqe=*/false, /*ring_db=*/false);
}
__syncwarp();
} else
#endif
{
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
if (not zero_copy)
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
// MSCCL++ port-channel PUT.
if (lane_id == 0) {
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
const auto src_off = rdma_offset_of(static_cast<uint64_t>(buf_ptr), rdma_buffer_ptr);
port_channel_handles[local_expert_idx * num_ranks + dst_rank].put(dst_off, src_off,
hidden * sizeof(nv_bfloat16));
}
__syncwarp();
}
__syncwarp();
}
}
}
@@ -573,9 +643,20 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com
peer_rdma_bases, rdma_buffer_ptr, dst_rank));
st_na_release(peer_flag, static_cast<int64_t>(1));
} else {
auto* flag_ptr = rdma_recv_flag + global_expert_idx;
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(flag_ptr), rdma_buffer_ptr);
port_channel_handles[local_expert_idx * num_ranks + dst_rank].atomicAdd(off, static_cast<int64_t>(1));
#ifdef MSCCLPP_EP_LL_HAS_IBGDA
if constexpr (kIbgdaPath) {
auto* flag_ptr = rdma_recv_flag + global_expert_idx;
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(flag_ptr), rdma_buffer_ptr);
const auto& ch = ibgda_handles[local_expert_idx * num_ranks + dst_rank];
mscclpp::IbgdaRemoteMr r = ch.remote_mrs[ch.dst];
mscclpp::ibgda::rdma_write_inl8(ch.qp, static_cast<uint64_t>(1), r.addr + off, r.rkey_be);
} else
#endif
{
auto* flag_ptr = rdma_recv_flag + global_expert_idx;
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(flag_ptr), rdma_buffer_ptr);
port_channel_handles[local_expert_idx * num_ranks + dst_rank].atomicAdd(off, static_cast<int64_t>(1));
}
}
} else {
st_na_release(rdma_recv_flag + global_expert_idx, static_cast<int64_t>(1));
@@ -639,7 +720,8 @@ void combine(void* combined_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void*
int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream, int phases, bool zero_copy, void* rdma_buffer_ptr,
mscclpp::PortChannelDeviceHandle* port_channel_handles, void* const* peer_rdma_bases,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path) {
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path,
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_handles, bool use_ibgda_path) {
// See the comment in `dispatch()`: (kNumWarpGroups, kNumWarpsPerGroup)
// is path-dependent. IPC uses (1, 32) to mirror NCCL-EP; PortChannel keeps
// (3, 10) to avoid host-proxy FIFO contention on the IB path.
@@ -671,24 +753,34 @@ void combine(void* combined_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void*
EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
#define COMBINE_LAUNCH_CASE(hidden_case) \
{ \
if (use_ipc_path) { \
auto combine_func = combine<true, kNumWarpGroupsIpc, kNumWarpsPerGroupIpc, hidden_case, kNumMaxTopk>; \
LAUNCH_KERNEL(&cfg, combine_func, combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, x, topk_idx, \
topk_weights, src_info, layout_range, next_clean, num_next_clean_int, atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, num_max_dispatch_tokens_per_rank, num_experts, rank, \
num_ranks, phases, zero_copy, rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, \
memory_channel_handles); \
} else { \
auto combine_func = combine<false, kNumWarpGroupsRdma, kNumWarpsPerGroupRdma, hidden_case, kNumMaxTopk>; \
LAUNCH_KERNEL(&cfg, combine_func, combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, x, topk_idx, \
topk_weights, src_info, layout_range, next_clean, num_next_clean_int, atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, num_max_dispatch_tokens_per_rank, num_experts, rank, \
num_ranks, phases, zero_copy, rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, \
memory_channel_handles); \
} \
} \
#define COMBINE_LAUNCH_CASE(hidden_case) \
{ \
if (use_ipc_path) { \
auto combine_func = \
combine<true, false, kNumWarpGroupsIpc, kNumWarpsPerGroupIpc, hidden_case, kNumMaxTopk>; \
LAUNCH_KERNEL(&cfg, combine_func, combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, x, topk_idx, \
topk_weights, src_info, layout_range, next_clean, num_next_clean_int, atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, num_max_dispatch_tokens_per_rank, num_experts, rank, \
num_ranks, phases, zero_copy, rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, \
memory_channel_handles, ibgda_handles); \
} else if (use_ibgda_path) { \
auto combine_func = \
combine<false, true, kNumWarpGroupsRdma, kNumWarpsPerGroupRdma, hidden_case, kNumMaxTopk>; \
LAUNCH_KERNEL(&cfg, combine_func, combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, x, topk_idx, \
topk_weights, src_info, layout_range, next_clean, num_next_clean_int, atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, num_max_dispatch_tokens_per_rank, num_experts, rank, \
num_ranks, phases, zero_copy, rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, \
memory_channel_handles, ibgda_handles); \
} else { \
auto combine_func = \
combine<false, false, kNumWarpGroupsRdma, kNumWarpsPerGroupRdma, hidden_case, kNumMaxTopk>; \
LAUNCH_KERNEL(&cfg, combine_func, combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, x, topk_idx, \
topk_weights, src_info, layout_range, next_clean, num_next_clean_int, atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, num_max_dispatch_tokens_per_rank, num_experts, rank, \
num_ranks, phases, zero_copy, rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, \
memory_channel_handles, ibgda_handles); \
} \
} \
break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);


@@ -43,14 +43,45 @@ import sys
# noisy 'recvValue failed / Connection was likely closed' stack traces.
os.environ.setdefault("TORCH_NCCL_ENABLE_MONITORING", "0")
import ctypes
import psutil
import torch
import torch.distributed as dist
# Load libnuma for NUMA-aware memory binding (mirrors DeepEP/tests/utils.py).
try:
_libnuma = ctypes.CDLL("libnuma.so")
_libnuma.numa_available.restype = ctypes.c_int
_libnuma.numa_run_on_node.argtypes = [ctypes.c_int]
_libnuma.numa_set_preferred.argtypes = [ctypes.c_int]
except OSError:
_libnuma = None
def set_numa_affinity(local_rank: int):
cores_per_rank = 12
numa_node = local_rank // 4
core_start = local_rank * cores_per_rank
core_end = core_start + cores_per_rank
p = psutil.Process(os.getpid())
p.cpu_affinity(list(range(core_start, core_end)))
print(f"Rank {local_rank} numa node {numa_node} bound to cores {core_start}-{core_end - 1}")
# Bind memory to NUMA node
if _libnuma is not None and _libnuma.numa_available() != -1:
_libnuma.numa_set_preferred(numa_node)
print(f"Rank {local_rank}: CPU affinity → cores {core_start}-{core_end - 1}, memory NUMA → node {numa_node}")
else:
print(f"Rank {local_rank}: libnuma not available")
def init_dist():
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ.get("LOCAL_RANK", rank))
set_numa_affinity(local_rank)
torch.cuda.set_device(local_rank)
dist.init_process_group(
backend="nccl",