diff --git a/include/mscclpp/ibgda_port_channel_device.hpp b/include/mscclpp/ibgda_port_channel_device.hpp new file mode 100644 index 00000000..7d37a770 --- /dev/null +++ b/include/mscclpp/ibgda_port_channel_device.hpp @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Stage 3: device-side IBGDA "port channel" handle. +// +// Mirrors the device-facing API of mscclpp::PortChannelDeviceHandle (see +// include/mscclpp/port_channel_device.hpp) but issues RDMA traffic directly +// from the GPU via mscclpp::ibgda::rdma_write — no host proxy / FIFO. +// +// put(dstOffset, srcOffset, size) — issue 1 RDMA WRITE WQE +// signal() — issue 1 inline 4-byte RDMA WRITE to a +// remote 4B counter slot +// wait() — spin on the local-side mirror of that +// counter +// +// Memory model (each rank constructs): +// - For every registered local memory region M, a flat row +// local_mrs[M] = { base_addr, lkey_be } +// - For every (peer rank R, registered region M) pair, a flat row +// remote_mrs[M] = { base_addr_on_R, rkey_be_on_R } +// (a channel is bound to a single peer R, so the rank dimension is +// flattened away inside the handle). +// +// The signal counter: +// - Each rank cudaMallocs a 4-byte "signal slot" registered with this NIC, +// and exchanges its address+rkey with peers. The handle's +// {sig_remote_addr, sig_rkey_be} points at the *peer's* slot, while +// sig_local_addr is the local mirror used by wait() to compare against. +// +// - signal() atomically bumps sig_seq (CTA-shared-by-channel u32 in GPU +// memory), writes that value as a 4-byte inline RDMA WRITE to peer's +// slot. Receiver polls its local slot. + +#ifndef MSCCLPP_IBGDA_PORT_CHANNEL_DEVICE_HPP_ +#define MSCCLPP_IBGDA_PORT_CHANNEL_DEVICE_HPP_ + +#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM) + +#include + +#include "ibgda.hpp" + +namespace mscclpp { + +using IbgdaMemoryId = uint32_t; + +struct IbgdaLocalMr { + uint64_t addr; // host-order base virtual address (GPU-mapped) + uint32_t lkey_be; // BE-encoded local key + uint32_t pad; +}; + +struct IbgdaRemoteMr { + uint64_t addr; // host-order base virtual address on the peer + uint32_t rkey_be; // BE-encoded remote key + uint32_t pad; +}; + +// POD copied to GPU. Fields kept dense and 8-byte aligned. +struct IbgdaPortChannelDeviceHandle { + IbgdaQpHandle qp; + + // Tables: indexed by IbgdaMemoryId. local_mrs[id] is for THIS rank; + // remote_mrs[id] is for the peer rank this channel is bound to. + // Both pointers must dereference from device code. + const IbgdaLocalMr* local_mrs; + const IbgdaRemoteMr* remote_mrs; + + // Signal slot: + // sig_local_addr — pointer to a 4-byte slot in this rank's GPU memory + // used by wait() to poll. NIC writes here from the peer. + // sig_local_lkey — BE-encoded lkey of the local 4B inline staging + // buffer (same MR as sig_local_addr; we use it as both + // the receive slot for wait() and as a backing MR for + // completeness — but inline WRs don't actually need + // a local MR. Kept for symmetry / future non-inline + // fallback). + // sig_remote_addr — peer's 4-byte slot + // sig_rkey_be — peer's rkey for sig_remote_addr + // sig_seq — pointer to a 4-byte GPU-resident counter incremented + // by signal(); the value sent is the post-increment. + // Distinct from sig_local_addr. 
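  //
  // A minimal put/signal/wait sketch for illustration (the device-side inline
  // methods live in src/core/include/ibgda_port_channel_device.cuh; the channel
  // variables and the receiver-side expected counter are hypothetical names,
  // and each side tracks that counter independently):
  //
  //   // rank A, one thread, after filling its send region:
  //   ibgda::port_put(chanToB, dstOffset, srcOffset, bytes);  // 1 RDMA WRITE WQE
  //   ibgda::port_signal(chanToB);                            // inline 4B counter bump
  //
  //   // rank B, one thread:
  //   expectedFromA += 1;
  //   ibgda::port_wait(chanFromA, expectedFromA);             // poll local sig slot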
+ uint32_t* sig_local_addr; + uint32_t sig_local_lkey; // unused for inline; kept for layout stability + uint64_t sig_remote_addr; + uint32_t sig_rkey_be; + uint32_t* sig_seq; + + // Bound peer (informational); MemoryIds default to (0,0) in the simple + // case but we keep them for parity with PortChannelDeviceHandle. + IbgdaMemoryId dst; + IbgdaMemoryId src; + uint32_t peer_rank; + uint32_t _pad; +}; + +} // namespace mscclpp + +#endif // USE_IBVERBS && MSCCLPP_USE_MLX5DV && !MSCCLPP_USE_ROCM + +#endif // MSCCLPP_IBGDA_PORT_CHANNEL_DEVICE_HPP_ diff --git a/include/mscclpp/port_channel.hpp b/include/mscclpp/port_channel.hpp index ed660407..77e3c730 100644 --- a/include/mscclpp/port_channel.hpp +++ b/include/mscclpp/port_channel.hpp @@ -4,6 +4,9 @@ #ifndef MSCCLPP_PORT_CHANNEL_HPP_ #define MSCCLPP_PORT_CHANNEL_HPP_ +#include +#include + #include "core.hpp" #include "port_channel_device.hpp" #include "proxy.hpp" @@ -84,6 +87,14 @@ class ProxyService : public BaseProxyService { std::vector memories_; std::shared_ptr proxy_; std::unordered_map, int> inflightRequests_; + // Connections with WRs staged but not yet posted. Mapped to the count of + // WRs staged since the last postPending() call (used to trip the batch + // threshold). Distinct from `inflightRequests_` which counts QP-inflight + // WRs since the last flush() (used for QP overflow detection). + std::unordered_map, int> stagedConns_; + std::unordered_set> dirtyConns_; + + void postPendingAll(); ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw); }; diff --git a/include/mscclpp/proxy.hpp b/include/mscclpp/proxy.hpp index 990deabb..9ae6c547 100644 --- a/include/mscclpp/proxy.hpp +++ b/include/mscclpp/proxy.hpp @@ -53,6 +53,13 @@ class Proxy { /// @return Shared pointer to FIFO. std::shared_ptr fifo(); + /// Set a callback invoked by the proxy thread when the FIFO transitions + /// from busy to idle (i.e., a poll returns no trigger after at least one + /// trigger was processed). Useful for batching: handlers can defer + /// expensive system calls (e.g., ibv_post_send) until the FIFO drains. + /// Must be called before start(). + void setOnIdle(std::function onIdle); + private: struct Impl; std::unique_ptr pimpl_; diff --git a/src/core/connection.cc b/src/core/connection.cc index 276b3d75..2c419532 100644 --- a/src/core/connection.cc +++ b/src/core/connection.cc @@ -396,7 +396,10 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem qp_.lock()->stageSendWrite(srcMr, dstMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/true); - qp_.lock()->postSend(); + // NOTE: postSend() is intentionally deferred. The proxy service batches + // many triggers and calls postPending() once at idle. External callers + // that pair write() with flush() still observe correct semantics because + // flush() now drains staged WRs before polling the CQ. INFO(CONN, "IBConnection write: from ", (uint8_t*)srcMr->getBuff() + srcOffset, " to ", (uint8_t*)dstMrInfo.addr + dstOffset, ", size ", size); @@ -438,12 +441,12 @@ void IBConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint6 /*size=*/0, /*wrId=*/0, /*srcOffset=*/0, /*dstOffset=*/0, /*signaled=*/true, /*immData=*/immData); - qp_.lock()->postSend(); + // postSend deferred; see IBConnection::write. 
INFO(CONN, "IBConnection signal forwarding: value ", oldValue, " -> ", newValue); } else { qp_.lock()->stageSendAtomicAdd(atomicSrcTransportInfo_.ibMr, dstMrInfo, /*wrId=*/0, dstOffset, newValue - oldValue, /*signaled=*/true); - qp_.lock()->postSend(); + // postSend deferred; see IBConnection::write. INFO(CONN, "IBConnection atomic write: from ", src, " to ", (uint8_t*)dstMrInfo.addr + dstOffset, ", ", oldValue, " -> ", newValue); } @@ -458,6 +461,9 @@ void IBConnection::flush(int64_t timeoutUsec) { NpKit::CollectCpuEvent(NPKIT_EVENT_CONN_IB_FLUSH_ENTRY, 0, 0, *NpKit::GetCpuTimestamp(), 0); #endif + // Drain any staged WRs that were deferred by the staging-only write path. + qp_.lock()->postSend(); + // Check if the recv thread has already reported an error (e.g., QP entered error state). if (recvThreadError_.load(std::memory_order_acquire)) { THROW(CONN, Error, ErrorCode::SystemError, "IBConnection recv thread failed: ", recvThreadErrorMsg_); @@ -503,10 +509,16 @@ void IBConnection::atomicAdd(RegisteredMemory dst, uint64_t dstOffset, int64_t v qp_.lock()->stageSendAtomicAdd(atomicSrcTransportInfo_.ibMr, dstMrInfo, /*wrId=*/0, dstOffset, static_cast(value), /*signaled=*/true); - qp_.lock()->postSend(); + // postSend deferred; see IBConnection::write. INFO(CONN, "IBConnection atomicAdd: dst ", (uint8_t*)dstMrInfo.addr + dstOffset, ", value ", value); } +void IBConnection::postPending() { + // Drain all staged WRs to the NIC in a single ibv_post_send call. + // Cheap no-op when nothing is staged (IbQp::postSend early-returns). + qp_.lock()->postSend(); +} + // EthernetConnection EthernetConnection::EthernetConnection(std::shared_ptr context, const Endpoint& localEndpoint, diff --git a/src/core/ibgda.cc b/src/core/ibgda.cc new file mode 100644 index 00000000..98419a64 --- /dev/null +++ b/src/core/ibgda.cc @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "ibgda.hpp" + +#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM) + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "logger.hpp" +#include "mlx5dv_wrapper.hpp" + +namespace mscclpp { + +struct IbgdaResources::Impl { + void* sq_buf_host = nullptr; + size_t sq_bytes = 0; + void* dbrec_host = nullptr; // not necessarily aligned to anything > 8B + void* dbrec_register_addr = nullptr; + size_t dbrec_register_bytes = 0; + void* uar_page_host = nullptr; // page-aligned base of the UAR + void* state_dev = nullptr; // device-resident state + bool sq_registered = false; + bool dbrec_registered = false; + bool uar_registered = false; +}; + +namespace { + +inline uintptr_t pageMask() { + static const uintptr_t mask = static_cast(::sysconf(_SC_PAGESIZE)) - 1; + return mask; +} + +} // namespace + +IbgdaResources::IbgdaResources(ibv_qp* qp) : pimpl_(std::make_unique()) { + if (qp == nullptr) { + THROW(NET, Error, ErrorCode::InvalidUsage, "IbgdaResources: qp is null"); + } + if (!MLX5DV::isAvailable()) { + THROW(NET, Error, ErrorCode::InvalidUsage, "IbgdaResources: libmlx5 not available"); + } + + // 1. Extract sq.buf / dbrec / bf via mlx5dv_init_obj. 
+ struct mlx5dv_qp dvQp{}; + if (MLX5DV::mlx5dv_init_obj_qp(qp, &dvQp) != 0) { + THROW(NET, IbError, errno, "mlx5dv_init_obj(QP) failed (errno ", errno, ")"); + } + + pimpl_->sq_buf_host = dvQp.sq.buf; + pimpl_->sq_bytes = static_cast(dvQp.sq.wqe_cnt) * dvQp.sq.stride; + pimpl_->dbrec_host = dvQp.dbrec; + if (pimpl_->sq_buf_host == nullptr || pimpl_->dbrec_host == nullptr || dvQp.bf.reg == nullptr) { + THROW(NET, Error, ErrorCode::InvalidUsage, "mlx5dv_qp returned null pointers"); + } + if (dvQp.sq.wqe_cnt == 0 || (dvQp.sq.wqe_cnt & (dvQp.sq.wqe_cnt - 1)) != 0) { + THROW(NET, Error, ErrorCode::InvalidUsage, "sq.wqe_cnt must be a power of 2, got ", dvQp.sq.wqe_cnt); + } + + // 2. Register the SQ buffer with CUDA (host-mapped). We register the + // enclosing whole pages because cudaHostRegister requires that. + { + uintptr_t base = reinterpret_cast(pimpl_->sq_buf_host); + uintptr_t pageBase = base & ~pageMask(); + size_t pad = static_cast(base - pageBase); + size_t regBytes = (pad + pimpl_->sq_bytes + pageMask()) & ~pageMask(); + void* regAddr = reinterpret_cast(pageBase); + cudaError_t e = cudaHostRegister(regAddr, regBytes, cudaHostRegisterDefault); + if (e != cudaSuccess) { + THROW(NET, SysError, static_cast(e), "cudaHostRegister(sq.buf) failed: ", + cudaGetErrorString(e)); + } + pimpl_->sq_registered = true; + void* dev = nullptr; + MSCCLPP_CUDATHROW(cudaHostGetDevicePointer(&dev, pimpl_->sq_buf_host, 0)); + handle_.sq_buf = dev; + } + + // 3. Register the DBR record. mlx5 hands us a pointer into a small block + // of DBR records (8 B each); register a whole page starting from its + // page base. + { + uintptr_t base = reinterpret_cast(pimpl_->dbrec_host); + uintptr_t pageBase = base & ~pageMask(); + size_t regBytes = pageMask() + 1; // 1 page + pimpl_->dbrec_register_addr = reinterpret_cast(pageBase); + pimpl_->dbrec_register_bytes = regBytes; + cudaError_t e = cudaHostRegister(pimpl_->dbrec_register_addr, regBytes, cudaHostRegisterDefault); + if (e == cudaErrorHostMemoryAlreadyRegistered) { + // The DBR page may overlap with another QP we already registered. + // Treat as success but skip unregister on shutdown. Clear sticky error. + pimpl_->dbrec_registered = false; + (void)cudaGetLastError(); + } else if (e != cudaSuccess) { + THROW(NET, SysError, static_cast(e), "cudaHostRegister(dbrec) failed: ", + cudaGetErrorString(e)); + } else { + pimpl_->dbrec_registered = true; + } + void* dev = nullptr; + MSCCLPP_CUDATHROW(cudaHostGetDevicePointer(&dev, pimpl_->dbrec_host, 0)); + handle_.dbrec = static_cast(dev); + } + + // 4. Map the UAR page (NIC MMIO) into GPU VA via cuMemHostRegister IOMEMORY. + { + uintptr_t bfAddr = reinterpret_cast(dvQp.bf.reg); + uintptr_t pageAddr = bfAddr & ~uintptr_t(4095); // UAR pages are 4 KB + uintptr_t bfOffset = bfAddr - pageAddr; + pimpl_->uar_page_host = reinterpret_cast(pageAddr); + CUresult cuRes = + cuMemHostRegister(pimpl_->uar_page_host, 4096, CU_MEMHOSTREGISTER_IOMEMORY); + if (cuRes == CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED) { + // The UAR page is shared across QPs on the same context; if a sibling + // QP already registered it, treat as success and skip unregister. + pimpl_->uar_registered = false; + (void)cudaGetLastError(); + } else if (cuRes != CUDA_SUCCESS) { + const char* s = nullptr; + cuGetErrorString(cuRes, &s); + THROW(NET, SysError, static_cast(cuRes), + "cuMemHostRegister(UAR, IOMEMORY) failed: ", s ? s : "?", + ". 
Ensure 'options nvidia NVreg_RegistryDwords=\"PeerMappingOverride=1;\"' " + "is set and nvidia_peermem is loaded."); + } else { + pimpl_->uar_registered = true; + } + CUdeviceptr dPage = 0; + MSCCLPP_CUTHROW(cuMemHostGetDevicePointer(&dPage, pimpl_->uar_page_host, 0)); + handle_.bf_reg = reinterpret_cast(dPage + bfOffset); + handle_.bf_offset = static_cast(bfOffset); + } + + // 5. Allocate GPU-resident state (zero-initialized). + { + constexpr size_t kStateBytes = 4 * sizeof(uint64_t); // resv/ready/prod/lock + MSCCLPP_CUDATHROW(cudaMalloc(&pimpl_->state_dev, kStateBytes)); + MSCCLPP_CUDATHROW(cudaMemset(pimpl_->state_dev, 0, kStateBytes)); + handle_.state = static_cast(pimpl_->state_dev); + } + + // 6. Fill the rest of the handle. + handle_.qpn = qp->qp_num; + handle_.wqe_cnt = dvQp.sq.wqe_cnt; + handle_.stride = dvQp.sq.stride; +} + +IbgdaResources::~IbgdaResources() { + if (!pimpl_) return; + if (pimpl_->state_dev) { + cudaFree(pimpl_->state_dev); + } + if (pimpl_->uar_registered && pimpl_->uar_page_host) { + cuMemHostUnregister(pimpl_->uar_page_host); + } + if (pimpl_->dbrec_registered && pimpl_->dbrec_register_addr) { + cudaHostUnregister(pimpl_->dbrec_register_addr); + } + if (pimpl_->sq_registered && pimpl_->sq_buf_host) { + uintptr_t base = reinterpret_cast(pimpl_->sq_buf_host); + void* regAddr = reinterpret_cast(base & ~pageMask()); + cudaHostUnregister(regAddr); + } +} + +} // namespace mscclpp + +#endif // USE_IBVERBS && MSCCLPP_USE_MLX5DV && !MSCCLPP_USE_ROCM diff --git a/src/core/include/connection.hpp b/src/core/include/connection.hpp index 97f00ece..ace78c9a 100644 --- a/src/core/include/connection.hpp +++ b/src/core/include/connection.hpp @@ -37,6 +37,12 @@ class BaseConnection { virtual void atomicAdd(RegisteredMemory dst, uint64_t dstOffset, int64_t value) = 0; + /// Post any locally-staged work requests to the NIC. Called by the proxy + /// service to amortize ibv_post_send across many FIFO triggers. The default + /// no-op is correct for connections whose write/updateAndSync/atomicAdd + /// already post immediately (e.g., CudaIpcConnection). + virtual void postPending() {} + virtual void flush(int64_t timeoutUsec = -1) = 0; /// Start signal forwarding to the given memory address. @@ -152,6 +158,8 @@ class IBConnection : public BaseConnection { void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override; void atomicAdd(RegisteredMemory dst, uint64_t dstOffset, int64_t value) override; + void postPending() override; + void flush(int64_t timeoutUsec) override; }; diff --git a/src/core/include/ib.hpp b/src/core/include/ib.hpp index 36c5a237..15a82c19 100644 --- a/src/core/include/ib.hpp +++ b/src/core/include/ib.hpp @@ -82,6 +82,9 @@ class IbQp { int pollRecvCq(); IbQpInfo& getInfo() { return info_; } + /// Raw ibv_qp pointer. Owned by this IbQp. Provided so other components + /// (e.g. IbgdaResources) can call mlx5dv_init_obj on the same QP. + ibv_qp* getRawQp() const { return qp_; } int getSendWcStatus(int idx) const; std::string getSendWcStatusString(int idx) const; int getNumSendCqItems() const; diff --git a/src/core/include/ibgda.hpp b/src/core/include/ibgda.hpp new file mode 100644 index 00000000..eb625f86 --- /dev/null +++ b/src/core/include/ibgda.hpp @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// IBGDA-style GPU-direct work-request submission for an mlx5 QP. 
+// +// Given an existing ibv_qp (RC, plain ibv_create_qp), IbgdaResources extracts +// the SQ ring buffer, doorbell record, and BlueFlame UAR via mlx5dv_init_obj +// and maps them into GPU virtual address space: +// - SQ buf and DBR are host RAM, mapped via cudaHostRegister. +// - The BF UAR is device MMIO, mapped via cuMemHostRegister(IOMEMORY). +// It also allocates a small GPU-resident state struct (resv_head / ready_head +// / prod_idx / lock) used by the device-side WQE writer. +// +// The host retains ownership of the QP and never touches sq.buf/dbrec/bf +// itself once GPU posting starts — completion polling on the send CQ remains +// a host-side ibv_poll_cq() call. + +#ifndef MSCCLPP_IBGDA_HPP_ +#define MSCCLPP_IBGDA_HPP_ + +#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM) + +#include +#include + +struct ibv_qp; + +namespace mscclpp { + +// POD copied to the GPU; consumed by the device-side WQE writer (added in +// Stage 2). Field order/sizes must match include/mscclpp/ibgda_device.hpp. +struct IbgdaQpHandle { + // Device-mapped pointers (MUST be dereferenceable from the GPU only). + void* sq_buf; // SQ WQE ring (host RAM, cudaHostRegister) + uint32_t* dbrec; // 32-bit doorbell record (host RAM, cudaHostRegister) + uint64_t* bf_reg; // BlueFlame doorbell (NIC MMIO, cuMemHostRegister IOMEMORY) + // GPU-resident bookkeeping state (allocated by IbgdaResources). + // Layout: { uint64_t resv_head; uint64_t ready_head; uint64_t prod_idx; int post_send_lock; }. + uint64_t* state; + // Constants (host-order). + uint32_t qpn; + uint32_t wqe_cnt; // SQ length in WQEBBs (power of 2) + uint32_t stride; // bytes per WQEBB (typically 64) + uint32_t bf_offset; // byte offset of BF doorbell within its UAR page +}; + +// Wraps an existing ibv_qp and prepares the GPU mappings + GPU state. Owns +// the cudaHostRegister/cuMemHostRegister registrations + GPU state buffer. +// Lifetime must enclose any kernel launch that uses getHandle(). +class IbgdaResources { + public: + // The qp must already be modified to RTS before kernel posting; that is + // the caller's responsibility. + explicit IbgdaResources(ibv_qp* qp); + ~IbgdaResources(); + + IbgdaResources(const IbgdaResources&) = delete; + IbgdaResources& operator=(const IbgdaResources&) = delete; + + const IbgdaQpHandle& getHandle() const { return handle_; } + + private: + struct Impl; + std::unique_ptr pimpl_; + IbgdaQpHandle handle_{}; +}; + +} // namespace mscclpp + +#endif // USE_IBVERBS && MSCCLPP_USE_MLX5DV && !MSCCLPP_USE_ROCM +#endif // MSCCLPP_IBGDA_HPP_ diff --git a/src/core/include/ibgda_device.cuh b/src/core/include/ibgda_device.cuh new file mode 100644 index 00000000..4bfaac6e --- /dev/null +++ b/src/core/include/ibgda_device.cuh @@ -0,0 +1,426 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Stage 2: device-side IBGDA primitives. +// +// Operates on an `IbgdaQpHandle` (see ibgda.hpp). All pointers in the handle +// must already be GPU-mapped by `IbgdaResources` on the host. This header is +// device-only and should be included from .cu translation units. 
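//
// A rough host-side sketch of that contract (assumes `ibQp` is an IbQp whose
// underlying RC QP is already in RTS; the kernel name is hypothetical):
//
//   ibv_qp* rawQp = ibQp->getRawQp();            // raw verbs QP owned by IbQp
//   mscclpp::IbgdaResources res(rawQp);          // maps SQ/DBR/UAR, allocs state
//   mscclpp::IbgdaQpHandle h = res.getHandle();  // POD, pass by value to a kernel
//   myPutKernel<<<1, 32>>>(h);                   // kernel calls ibgda::rdma_write(h, ...)
//   // ... host keeps draining the send CQ with ibv_poll_cq, as noted above.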
+// +// State buffer layout (allocated by IbgdaResources, zero-initialised): +// offset 0: uint64_t resv_head // next free WQEBB slot (atomic) +// offset 8: uint64_t ready_head // next slot whose WQE write is done +// offset 16: uint64_t prod_idx // last value rung on the doorbell +// offset 24: int post_send_lock // cta-level lock for ringing +// +// Conventions copied from NVSHMEM/DeepEP: +// - wqe_idx is a 16-bit-truncated cyclic index into the SQ ring. +// - one RDMA WRITE WQE = 1 WQEBB (64B): ctrl(16) + raddr(16) + data(16) + pad(16). +// ctrl_seg.qpn_ds encodes ds=3 (three 16-byte segments). +// - DBR record holds BE32(prod_idx & 0xFFFF) in the SQ slot (dbrec[1] for +// mlx5; IbgdaResources sets handle.dbrec to that slot directly). +// - BF doorbell: 64-bit write of {opmod_idx_opcode, qpn_ds} to bf_reg. + +#ifndef MSCCLPP_IBGDA_DEVICE_CUH_ +#define MSCCLPP_IBGDA_DEVICE_CUH_ + +#include +#include + +#include "ibgda.hpp" + +namespace mscclpp { +namespace ibgda { + +#ifndef MSCCLPP_IBGDA_OPCODE_RDMA_WRITE +#define MSCCLPP_IBGDA_OPCODE_RDMA_WRITE 0x08 +#endif +#ifndef MSCCLPP_IBGDA_CTRL_CQ_UPDATE +#define MSCCLPP_IBGDA_CTRL_CQ_UPDATE uint8_t(0x8) +#endif + +// ---- BE swap helpers (PTX prmt) ------------------------------------------- +__device__ static __forceinline__ uint32_t HtoBE32(uint32_t x) { + uint32_t r; + asm("{ .reg .b32 ig; prmt.b32 %0, %1, ig, 0x0123; }" : "=r"(r) : "r"(x)); + return r; +} + +__device__ static __forceinline__ uint64_t HtoBE64(uint64_t x) { + uint64_t r; + asm("{\n\t" + ".reg .b32 ig;\n\t" + ".reg .b32 lo, hi, nl, nh;\n\t" + "mov.b64 {lo,hi}, %1;\n\t" + "prmt.b32 nh, lo, ig, 0x0123;\n\t" + "prmt.b32 nl, hi, ig, 0x0123;\n\t" + "mov.b64 %0, {nl,nh};\n\t" "}" + : "=l"(r) : "l"(x)); + return r; +} + +// ---- Memory ordering helpers ---------------------------------------------- +// st.relaxed.sys + st.release.sys equivalents. We use `volatile` writes plus +// __threadfence_system() at the points the caller needs ordering, mirroring +// the NVSHMEM `st_na_*` style without depending on its internal headers. 
+__device__ static __forceinline__ void store_relaxed_u32(uint32_t* p, uint32_t v) { + *reinterpret_cast(p) = v; +} +__device__ static __forceinline__ void store_relaxed_u64(uint64_t* p, uint64_t v) { + *reinterpret_cast(p) = v; +} +__device__ static __forceinline__ void store_relaxed_int4(int4* p, int4 v) { + asm volatile("st.relaxed.sys.v4.b32 [%0], {%1, %2, %3, %4};" :: + "l"(p), "r"(v.x), "r"(v.y), "r"(v.z), "r"(v.w) : "memory"); +} +__device__ static __forceinline__ void store_release_u32(uint32_t* p, uint32_t v) { + asm volatile("st.release.sys.b32 [%0], %1;" :: "l"(p), "r"(v) : "memory"); +} +__device__ static __forceinline__ void store_release_u64(uint64_t* p, uint64_t v) { + asm volatile("st.release.sys.b64 [%0], %1;" :: "l"(p), "l"(v) : "memory"); +} + +// ---- WQE segment layout (mirrors mlx5_wqe_*; redefined to avoid pulling +// libmlx5 headers into device code) ---------------------------------------- +struct __align__(8) CtrlSeg { + uint32_t opmod_idx_opcode; // BE32: (wqe_idx << 8) | opcode + uint32_t qpn_ds; // BE32: (qpn << 8) | ds + uint8_t signature; + uint8_t rsvd0; + uint8_t rsvd1; + uint8_t fm_ce_se; // CQ_UPDATE in bit 3 (0x8) + uint32_t imm; // BE32 immediate (or 0) +}; +struct __align__(8) RaddrSeg { + uint64_t raddr; // BE64 + uint32_t rkey; // BE32 (caller pre-swaps) + uint32_t reserved; +}; +struct __align__(8) DataSeg { + uint32_t byte_count; // BE32 + uint32_t lkey; // BE32 (caller pre-swaps) + uint64_t addr; // BE64 +}; + +static_assert(sizeof(CtrlSeg) == 16, "CtrlSeg must be 16B"); +static_assert(sizeof(RaddrSeg) == 16, "RaddrSeg must be 16B"); +static_assert(sizeof(DataSeg) == 16, "DataSeg must be 16B"); + +// ---- State accessors ------------------------------------------------------ +struct StateView { + unsigned long long* resv_head; + unsigned long long* ready_head; + unsigned long long* prod_idx; + int* post_send_lock; +}; + +__device__ static __forceinline__ StateView stateView(const IbgdaQpHandle& h) { + auto* base = reinterpret_cast(h.state); + StateView v; + v.resv_head = reinterpret_cast(&base[0]); + v.ready_head = reinterpret_cast(&base[1]); + v.prod_idx = reinterpret_cast(&base[2]); + v.post_send_lock = reinterpret_cast(&base[3]); + return v; +} + +// ---- WQE ring helpers ----------------------------------------------------- +__device__ static __forceinline__ +uint64_t reserve_wqe_slots(const IbgdaQpHandle& h, uint32_t num_wqes) { + auto v = stateView(h); + return atomicAdd(v.resv_head, static_cast(num_wqes)); +} + +__device__ static __forceinline__ +void* get_wqe_ptr(const IbgdaQpHandle& h, uint16_t wqe_idx) { + uint16_t mask = static_cast(h.wqe_cnt - 1); + uint16_t idx = wqe_idx & mask; + return reinterpret_cast(h.sq_buf) + + (static_cast(idx) * h.stride); +} + +// ---- Doorbell record + BF ring ------------------------------------------- +__device__ static __forceinline__ +void update_dbr(const IbgdaQpHandle& h, uint32_t dbrec_head) { + // BE32(dbrec_head & 0xFFFF) — see DeepEP/NVSHMEM. + uint32_t v; + asm("{\n\t" + ".reg .b32 lo16; .reg .b32 ig;\n\t" + "and.b32 lo16, %1, 0xffff;\n\t" + "prmt.b32 %0, lo16, ig, 0x123;\n\t" + "}" : "=r"(v) : "r"(dbrec_head)); + // dbrec is the SQ slot directly (set by IbgdaResources). + store_release_u32(reinterpret_cast(h.dbrec), v); +} + +__device__ static __forceinline__ +void ring_db(const IbgdaQpHandle& h, uint16_t prod_idx) { + // 64-bit BF write of {opmod_idx_opcode, qpn_ds}. + // opmod_idx_opcode = BE32(prod_idx << 8); qpn_ds = BE32(qpn << 8). 
+ uint32_t lo = HtoBE32(static_cast(prod_idx) << 8); + uint32_t hi = HtoBE32(h.qpn << 8); + uint64_t v = static_cast(lo) | (static_cast(hi) << 32); + store_release_u64(h.bf_reg, v); +} + +__device__ static __forceinline__ +void post_send(const IbgdaQpHandle& h, uint64_t new_prod_idx) { + auto v = stateView(h); + // CTA-level lock so concurrent posters serialise their DBR/BF writes. + while (atomicCAS(v.post_send_lock, 0, 1) == 1) {} + __threadfence(); + + unsigned long long old_prod_idx = + atomicMax(v.prod_idx, static_cast(new_prod_idx)); + if (new_prod_idx > old_prod_idx) { + update_dbr(h, static_cast(new_prod_idx)); + ring_db(h, static_cast(new_prod_idx)); + } + + __threadfence(); + store_release_u32(reinterpret_cast(v.post_send_lock), 0); +} + +// ---- Submit (publish ready_head, optionally ring DB) ---------------------- +template +__device__ static __forceinline__ +void submit_requests(const IbgdaQpHandle& h, uint64_t base_wqe_idx, + uint32_t num_wqes, int message_idx = 0) { + auto v = stateView(h); + uint64_t new_wqe_idx = base_wqe_idx + num_wqes; + + // The WQE writes themselves must be globally visible before we publish. + __threadfence_system(); + + // Wait for any earlier reservations to publish first, preserving order. + auto base_ull = static_cast(base_wqe_idx); + auto new_ull = static_cast(new_wqe_idx); + while (atomicCAS(v.ready_head, base_ull, new_ull) != base_ull) {} + + constexpr int kBatch = 4; + if (kAlwaysDoPostSend || ((message_idx + 1) % kBatch == 0)) { + post_send(h, new_wqe_idx); + } +} + +// Publish-only variant: advance ready_head in WQE-issue order (preserving the +// in-order chain that submit_requests builds) but DO NOT ring the doorbell. +// Use this for the data WRs in a coalesced "unsignaled body + signaled tail" +// pattern; the trailing WR's submit_requests will atomicMax prod_idx +// past all earlier WRs and ring the doorbell once for the whole batch. +__device__ static __forceinline__ +void submit_no_db(const IbgdaQpHandle& h, uint64_t base_wqe_idx, + uint32_t num_wqes) { + auto v = stateView(h); + uint64_t new_wqe_idx = base_wqe_idx + num_wqes; + __threadfence_system(); + auto base_ull = static_cast(base_wqe_idx); + auto new_ull = static_cast(new_wqe_idx); + while (atomicCAS(v.ready_head, base_ull, new_ull) != base_ull) {} + // No post_send: the next signaled WR on this QP (or any concurrent ringer) + // will atomicMax prod_idx past us and the NIC will pick up our slot. +} + +// ---- WQE writers ---------------------------------------------------------- +// All addresses BE-encoded inside; lkey/rkey expected pre-BE-swapped. +// `signal_cqe`: if true, set CQ_UPDATE so the NIC posts a CQE for this WR; +// if false, the WR is UNSIGNALED — useful for batched data WRs whose +// completion is implicitly bounded by a later signaled WR on the same QP. +__device__ static __forceinline__ +void write_rdma_write_wqe(const IbgdaQpHandle& h, void* wqe_slot, + uint64_t laddr, uint32_t lkey_be, + uint64_t raddr, uint32_t rkey_be, + uint32_t bytes, uint16_t wqe_idx, + bool signal_cqe = true) { + auto* ctrl = reinterpret_cast(wqe_slot); + auto* raddrp = reinterpret_cast(reinterpret_cast(wqe_slot) + 16); + auto* datap = reinterpret_cast(reinterpret_cast(wqe_slot) + 32); + + CtrlSeg c{}; + c.opmod_idx_opcode = HtoBE32((static_cast(wqe_idx) << 8) | + MSCCLPP_IBGDA_OPCODE_RDMA_WRITE); + c.qpn_ds = HtoBE32((h.qpn << 8) | 3u); + c.fm_ce_se = signal_cqe ? 
MSCCLPP_IBGDA_CTRL_CQ_UPDATE : 0u; + c.imm = 0; + + RaddrSeg r{}; + r.raddr = HtoBE64(raddr); + r.rkey = rkey_be; + r.reserved = 0; + + DataSeg d{}; + d.byte_count = HtoBE32(bytes); + d.lkey = lkey_be; + d.addr = HtoBE64(laddr); + + store_relaxed_int4(reinterpret_cast(ctrl), *reinterpret_cast(&c)); + store_relaxed_int4(reinterpret_cast(raddrp), *reinterpret_cast(&r)); + store_relaxed_int4(reinterpret_cast(datap), *reinterpret_cast(&d)); +} + +// One-shot helper: reserve a slot, write WQE, submit + ring doorbell. +// Returns the wqe_idx of the issued WR (caller can match against CQE wr_id +// if it embeds the index there; here we don't use wr_id). +__device__ static __forceinline__ +uint64_t rdma_write(const IbgdaQpHandle& h, + uint64_t laddr, uint32_t lkey_be, + uint64_t raddr, uint32_t rkey_be, + uint32_t bytes, + bool signal_cqe = true, bool ring_db = true) { + uint64_t base = reserve_wqe_slots(h, 1); + void* slot = get_wqe_ptr(h, static_cast(base)); + write_rdma_write_wqe(h, slot, laddr, lkey_be, raddr, rkey_be, bytes, + static_cast(base), signal_cqe); + if (ring_db) submit_requests(h, base, 1); + else submit_no_db(h, base, 1); + return base; +} + +// ---- Inline 4-byte RDMA WRITE -------------------------------------------- +// Single WQEBB carrying (ctrl=16) + (raddr=16) + (inline_hdr=4 + data=4) +// padded to 48B. ds=3 in ctrl_seg. +// +// MLX5_INLINE_SEG = 0x80000000 in inline_seg.byte_count signals "inline". +// +// The 4-byte payload `value` is written verbatim to remote `raddr`. The +// caller does NOT need a local MR for this op. +#ifndef MSCCLPP_IBGDA_INLINE_SEG +#define MSCCLPP_IBGDA_INLINE_SEG 0x80000000u +#endif + +__device__ static __forceinline__ +void write_rdma_write_inl4_wqe(const IbgdaQpHandle& h, void* wqe_slot, + uint32_t value, + uint64_t raddr, uint32_t rkey_be, + uint16_t wqe_idx, + bool signal_cqe = true) { + auto* base8 = reinterpret_cast(wqe_slot); + auto* ctrl = reinterpret_cast(base8 + 0); + auto* raddrp = reinterpret_cast(base8 + 16); + // inline header is a 4-byte field (byte_count) at offset 32, followed by + // 4 bytes of payload at offset 36. + uint32_t* inl_hdr = reinterpret_cast(base8 + 32); + uint32_t* inl_dat = reinterpret_cast(base8 + 36); + + CtrlSeg c{}; + c.opmod_idx_opcode = HtoBE32((static_cast(wqe_idx) << 8) | + MSCCLPP_IBGDA_OPCODE_RDMA_WRITE); + // ds=3 (3 * 16B = 48B; ctrl + raddr + 8B of inline hdr+data, padded). + c.qpn_ds = HtoBE32((h.qpn << 8) | 3u); + c.fm_ce_se = signal_cqe ? MSCCLPP_IBGDA_CTRL_CQ_UPDATE : 0u; + c.imm = 0; + + RaddrSeg r{}; + r.raddr = HtoBE64(raddr); + r.rkey = rkey_be; + r.reserved = 0; + + store_relaxed_int4(reinterpret_cast(ctrl), *reinterpret_cast(&c)); + store_relaxed_int4(reinterpret_cast(raddrp), *reinterpret_cast(&r)); + // inline byte_count: 4 | INLINE_SEG, BE32. + store_relaxed_u32(inl_hdr, HtoBE32(4u | MSCCLPP_IBGDA_INLINE_SEG)); + // The 4-byte payload is written as-is in network order — the NIC treats + // the inline data block as opaque bytes copied to remote memory. 
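  // Resulting WQEBB for this inline write (sketch; ds=3 counts the first 48 B,
  // the remainder of the 64 B slot is unused):
  //   [ 0..15] CtrlSeg    opmod_idx_opcode | qpn_ds | fm_ce_se
  //   [16..31] RaddrSeg   raddr (BE64) | rkey (BE32, pre-swapped)
  //   [32..35] inline hdr = BE32(4 | MSCCLPP_IBGDA_INLINE_SEG)
  //   [36..39] payload    = the 4 bytes of `value`, copied verbatim
  //   [40..47] pad up to the ds boundary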
+ store_relaxed_u32(inl_dat, value); +} + +__device__ static __forceinline__ +uint64_t rdma_write_inl4(const IbgdaQpHandle& h, + uint32_t value, + uint64_t raddr, uint32_t rkey_be, + bool signal_cqe = true, bool ring_db = true) { + uint64_t base = reserve_wqe_slots(h, 1); + void* slot = get_wqe_ptr(h, static_cast(base)); + write_rdma_write_inl4_wqe(h, slot, value, raddr, rkey_be, + static_cast(base), signal_cqe); + if (ring_db) submit_requests(h, base, 1); + else submit_no_db(h, base, 1); + return base; +} + +// ---- Inline 8-byte RDMA WRITE -------------------------------------------- +// Same WQE shape as inl4 (ds=3 / 48B): ctrl(16) + raddr(16) + inline_hdr(4) +// + payload(8) + pad(4). Used to publish 8B counters (e.g. dispatch +// per-(local_expert, src_rank) recv_count slot polled with int64 acquire). +__device__ static __forceinline__ +void write_rdma_write_inl8_wqe(const IbgdaQpHandle& h, void* wqe_slot, + uint64_t value, + uint64_t raddr, uint32_t rkey_be, + uint16_t wqe_idx, + bool signal_cqe = true) { + auto* base8 = reinterpret_cast(wqe_slot); + auto* ctrl = reinterpret_cast(base8 + 0); + auto* raddrp = reinterpret_cast(base8 + 16); + uint32_t* inl_hdr = reinterpret_cast(base8 + 32); + // Two 32-bit halves of the 8B payload at offsets 36 / 40. + uint32_t* inl_dat_lo = reinterpret_cast(base8 + 36); + uint32_t* inl_dat_hi = reinterpret_cast(base8 + 40); + + CtrlSeg c{}; + c.opmod_idx_opcode = HtoBE32((static_cast(wqe_idx) << 8) | + MSCCLPP_IBGDA_OPCODE_RDMA_WRITE); + c.qpn_ds = HtoBE32((h.qpn << 8) | 3u); + c.fm_ce_se = signal_cqe ? MSCCLPP_IBGDA_CTRL_CQ_UPDATE : 0u; + c.imm = 0; + + RaddrSeg r{}; + r.raddr = HtoBE64(raddr); + r.rkey = rkey_be; + r.reserved = 0; + + store_relaxed_int4(reinterpret_cast(ctrl), *reinterpret_cast(&c)); + store_relaxed_int4(reinterpret_cast(raddrp), *reinterpret_cast(&r)); + // inline byte_count: 8 | INLINE_SEG, BE32. + store_relaxed_u32(inl_hdr, HtoBE32(8u | MSCCLPP_IBGDA_INLINE_SEG)); + // 8B payload as two 32-bit words; treated as opaque bytes by the NIC. + store_relaxed_u32(inl_dat_lo, static_cast(value & 0xffffffffull)); + store_relaxed_u32(inl_dat_hi, static_cast(value >> 32)); +} + +__device__ static __forceinline__ +uint64_t rdma_write_inl8(const IbgdaQpHandle& h, + uint64_t value, + uint64_t raddr, uint32_t rkey_be, + bool signal_cqe = true, bool ring_db = true) { + uint64_t base = reserve_wqe_slots(h, 1); + void* slot = get_wqe_ptr(h, static_cast(base)); + write_rdma_write_inl8_wqe(h, slot, value, raddr, rkey_be, + static_cast(base), signal_cqe); + if (ring_db) submit_requests(h, base, 1); + else submit_no_db(h, base, 1); + return base; +} + +// ---- Warp-coalesced burst (lane 0 issues N WRs as a single batch) -------- +// This is the "warp granularity" pattern used by NVSHMEM IBGDA: a single +// thread reserves N contiguous slots (one atomic), writes N WQEs in a tight +// loop, then publishes ready_head and rings the doorbell exactly once. This +// amortises the ready_head CAS-spin and the BF doorbell MMIO across N WRs. +// +// Caller is responsible for ensuring that only one thread per warp invokes +// this with a given (laddr/raddr) layout — typically `if (lane_id == 0)`. +// All N WRs share the same lkey_be/rkey_be and have stride `bytes` (i.e. a +// contiguous chunk laddr_base..laddr_base+N*bytes -> raddr_base..). 
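//
// Illustrative call pattern (a sketch; warp bookkeeping, addresses, and the
// per-WR size are the caller's responsibility):
//
//   int lane = threadIdx.x % 32;
//   if (lane == 0) {
//     // one thread posts the warp's contiguous chunk: one resv_head atomic,
//     // num_wrs WQE writes, one ready_head publish, one doorbell ring
//     rdma_write_strided_burst(h, laddrBase, lkeyBe, raddrBase, rkeyBe,
//                              bytesPerWr, /*num_wrs=*/32);
//   }
//   __syncwarp();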
+__device__ static __forceinline__ +uint64_t rdma_write_strided_burst(const IbgdaQpHandle& h, + uint64_t laddr_base, uint32_t lkey_be, + uint64_t raddr_base, uint32_t rkey_be, + uint32_t bytes, uint32_t num_wrs) { + if (num_wrs == 0) return 0; + uint64_t base = reserve_wqe_slots(h, num_wrs); + for (uint32_t i = 0; i < num_wrs; ++i) { + uint16_t idx = static_cast(base + i); + void* slot = get_wqe_ptr(h, idx); + write_rdma_write_wqe(h, slot, + laddr_base + size_t(i) * bytes, lkey_be, + raddr_base + size_t(i) * bytes, rkey_be, + bytes, idx); + } + submit_requests(h, base, num_wrs); + return base; +} + +} // namespace ibgda +} // namespace mscclpp + +#endif // MSCCLPP_IBGDA_DEVICE_CUH_ diff --git a/src/core/include/ibgda_port_channel_device.cuh b/src/core/include/ibgda_port_channel_device.cuh new file mode 100644 index 00000000..85585861 --- /dev/null +++ b/src/core/include/ibgda_port_channel_device.cuh @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Stage 3: device-side methods that act on `IbgdaPortChannelDeviceHandle`. +// Split from the POD header so that the POD can be included by host code +// (which doesn't want PTX inline asm or CUDA-only types) while these inline +// device functions only appear in .cu translation units. + +#ifndef MSCCLPP_IBGDA_PORT_CHANNEL_DEVICE_CUH_ +#define MSCCLPP_IBGDA_PORT_CHANNEL_DEVICE_CUH_ + +#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM) + +#include + +#include + +#include "ibgda_device.cuh" + +namespace mscclpp { +namespace ibgda { + +// ---------------- put ----------------------------------------------------- +// `signal_cqe` / `ring_db` default to true (preserves prior call sites); +// pass false/false for batched data WRs whose completion is implicitly +// bounded by a later signaled WR on the same QP (e.g., the trailing count +// or flag write at the end of LL dispatch / combine). +__device__ static __forceinline__ +void port_put(const IbgdaPortChannelDeviceHandle& ch, + IbgdaMemoryId dst, uint64_t dstOffset, + IbgdaMemoryId src, uint64_t srcOffset, + uint64_t size, + bool signal_cqe = true, bool ring_db = true) { + IbgdaLocalMr l = ch.local_mrs[src]; + IbgdaRemoteMr r = ch.remote_mrs[dst]; + rdma_write(ch.qp, + l.addr + srcOffset, l.lkey_be, + r.addr + dstOffset, r.rkey_be, + static_cast(size), + signal_cqe, ring_db); +} + +__device__ static __forceinline__ +void port_put(const IbgdaPortChannelDeviceHandle& ch, + uint64_t dstOffset, uint64_t srcOffset, uint64_t size, + bool signal_cqe = true, bool ring_db = true) { + port_put(ch, ch.dst, dstOffset, ch.src, srcOffset, size, signal_cqe, ring_db); +} + +__device__ static __forceinline__ +void port_put(const IbgdaPortChannelDeviceHandle& ch, + uint64_t offset, uint64_t size, + bool signal_cqe = true, bool ring_db = true) { + port_put(ch, offset, offset, size, signal_cqe, ring_db); +} + +// ---------------- signal -------------------------------------------------- +// Increments the local sequence counter (system-wide atomic so multiple +// CTAs serialise) and writes the new value as an inline 4B RDMA WRITE to +// the peer's signal slot. +__device__ static __forceinline__ +uint32_t port_signal(const IbgdaPortChannelDeviceHandle& ch) { + // Use atomicAdd_system so the counter is consistent across CTAs sharing + // this handle (rare but possible). 
+ uint32_t prev = atomicAdd(ch.sig_seq, 1u); + uint32_t v = prev + 1u; + rdma_write_inl4(ch.qp, v, ch.sig_remote_addr, ch.sig_rkey_be); + return v; +} + +// ---------------- wait ---------------------------------------------------- +// Returns when the local signal slot's value reaches `expected` (sticky; +// callers track the expected counter externally — same pattern as +// PortChannelDeviceHandle::wait()). +__device__ static __forceinline__ +void port_wait(const IbgdaPortChannelDeviceHandle& ch, uint32_t expected, + int64_t maxSpinCount = 10000000) { + volatile uint32_t* p = ch.sig_local_addr; + if (maxSpinCount < 0) { + while (*p < expected) { /* spin */ } + return; + } + for (int64_t i = 0; i < maxSpinCount; ++i) { + if (*p >= expected) return; + } + // Out of patience — caller can detect with a follow-up poll(); we do NOT + // assert here to keep the device path lean. +} + +__device__ static __forceinline__ +bool port_poll(const IbgdaPortChannelDeviceHandle& ch, uint32_t expected) { + return *(volatile uint32_t*)ch.sig_local_addr >= expected; +} + +} // namespace ibgda +} // namespace mscclpp + +#endif // USE_IBVERBS && MSCCLPP_USE_MLX5DV && !MSCCLPP_USE_ROCM + +#endif // MSCCLPP_IBGDA_PORT_CHANNEL_DEVICE_CUH_ diff --git a/src/core/include/mlx5dv_wrapper.hpp b/src/core/include/mlx5dv_wrapper.hpp index 79403a36..83912bc6 100644 --- a/src/core/include/mlx5dv_wrapper.hpp +++ b/src/core/include/mlx5dv_wrapper.hpp @@ -28,6 +28,11 @@ struct MLX5DV { /// Returns 0 on success (device supports Data Direct), non-zero otherwise. static int mlx5dv_get_data_direct_sysfs_path(struct ibv_context* context, char* buf, size_t buf_len); + /// Wraps mlx5dv_init_obj(MLX5DV_OBJ_QP). Returns 0 on success. + /// `out` must be a pointer to an mlx5dv_qp; we keep this typeless to avoid + /// pulling into the public-ish wrapper header. + static int mlx5dv_init_obj_qp(struct ibv_qp* qp, void* out); + private: static void* dlsym(const std::string& symbol, bool allowReturnNull = false); }; diff --git a/src/core/mlx5dv_wrapper.cc b/src/core/mlx5dv_wrapper.cc index a56fad96..8b55f177 100644 --- a/src/core/mlx5dv_wrapper.cc +++ b/src/core/mlx5dv_wrapper.cc @@ -121,6 +121,20 @@ int MLX5DV::mlx5dv_get_data_direct_sysfs_path(struct ibv_context* context, char* return impl(context, buf, buf_len); } +int MLX5DV::mlx5dv_init_obj_qp(struct ibv_qp* qp, void* out_qp) { + using FuncType = int (*)(struct mlx5dv_obj*, uint64_t); + static FuncType impl = nullptr; + if (!impl) { + void* ptr = MLX5DV::dlsym("mlx5dv_init_obj", /*allowReturnNull=*/true); + if (!ptr) return -1; + impl = reinterpret_cast(ptr); + } + struct mlx5dv_obj obj{}; + obj.qp.in = qp; + obj.qp.out = static_cast(out_qp); + return impl(&obj, MLX5DV_OBJ_QP); +} + } // namespace mscclpp #endif // defined(MSCCLPP_USE_MLX5DV) diff --git a/src/core/port_channel.cc b/src/core/port_channel.cc index f2fc1efe..f47890bd 100644 --- a/src/core/port_channel.cc +++ b/src/core/port_channel.cc @@ -3,12 +3,75 @@ #include #include +#include + +#include +#include +#include #include "api.h" +#include "connection.hpp" #include "debug.h" namespace mscclpp { +// Lightweight diagnostic counters kept in a side-table keyed by ProxyService* +// so the ProxyService class layout stays ABI-compatible with prebuilt +// extensions (e.g. mscclpp/_mscclpp.cpython-*.so) that were compiled against +// the older header. Populated only when MSCCLPP_PROXY_STATS=1. 
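//
// Usage sketch (the launch command is illustrative): running with
//   MSCCLPP_PROXY_STATS=1 ./my_app
// makes stopProxy() print one summary line per ProxyService to stderr, e.g.
//   [mscclpp proxy stats] triggers=... postCalls=... idleDrains=... triggersPerPost=...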
+namespace { +struct ProxyStats { + bool enabled = false; + bool printed = false; + uint64_t triggers = 0; + uint64_t trigData = 0; + uint64_t trigFlag = 0; + uint64_t trigAtomic = 0; + uint64_t trigSync = 0; + uint64_t postCalls = 0; + uint64_t idleDrains = 0; +}; +// Process-wide flag: 0 unless any ProxyService has stats enabled. Avoids +// taking the side-table mutex on the proxy hot path when nobody asked for +// stats. +std::atomic& statsAnyEnabled() { + static std::atomic v{false}; + return v; +} +std::mutex& statsMu() { static std::mutex m; return m; } +std::unordered_map& statsTable() { + static std::unordered_map t; + return t; +} +ProxyStats* getStats(const ProxyService* self) { + if (!statsAnyEnabled().load(std::memory_order_relaxed)) return nullptr; + std::lock_guard lk(statsMu()); + auto it = statsTable().find(self); + if (it == statsTable().end()) return nullptr; + return &it->second; +} +void printAndEraseStats(const ProxyService* self) { + std::lock_guard lk(statsMu()); + auto it = statsTable().find(self); + if (it == statsTable().end()) return; + ProxyStats& s = it->second; + if (s.enabled && !s.printed) { + s.printed = true; + uint64_t posts = s.postCalls + s.idleDrains; + double wrPerPost = posts ? static_cast(s.triggers) / static_cast(posts) : 0.0; + fprintf(stderr, + "[mscclpp proxy stats] triggers=%llu (data=%llu flag=%llu atomic=%llu sync=%llu) " + "postCalls=%llu idleDrains=%llu triggersPerPost=%.2f\n", + (unsigned long long)s.triggers, (unsigned long long)s.trigData, + (unsigned long long)s.trigFlag, (unsigned long long)s.trigAtomic, + (unsigned long long)s.trigSync, (unsigned long long)s.postCalls, + (unsigned long long)s.idleDrains, wrPerPost); + fflush(stderr); + } + statsTable().erase(it); +} +} // namespace + MSCCLPP_API_CPP BasePortChannel::BasePortChannel(SemaphoreId semaphoreId, std::shared_ptr semaphore, std::shared_ptr proxy) @@ -30,6 +93,15 @@ MSCCLPP_API_CPP ProxyService::ProxyService(int fifoSize) { int cudaDevice; MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice)); int deviceNumaNode = getDeviceNumaNode(cudaDevice); + if (const char* env = std::getenv("MSCCLPP_PROXY_STATS")) { + if (std::atoi(env) > 0) { + { + std::lock_guard lk(statsMu()); + statsTable()[this].enabled = true; + } + statsAnyEnabled().store(true, std::memory_order_relaxed); + } + } auto initFunc = [cudaDevice, deviceNumaNode]() { MSCCLPP_CUDATHROW(cudaSetDevice(cudaDevice)); if (deviceNumaNode >= 0) { @@ -39,6 +111,21 @@ MSCCLPP_API_CPP ProxyService::ProxyService(int fifoSize) { }; auto handlerFunc = [&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }; proxy_ = std::make_shared(handlerFunc, initFunc, fifoSize); + // Drain any deferred ibv_post_send work the moment the FIFO empties out, so + // we batch as many triggers as possible into a single syscall while still + // keeping latency within one FIFO drain. 
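  // Rough per-connection timeline under this scheme (a sketch):
  //   TriggerData           -> conn.write() stages a WR, no ibv_post_send yet
  //   atomic / TriggerFlag  -> staged, then posted right away (completion
  //                            semantics, never deferred to the idle drain)
  //   TriggerSync / full QP -> conn.flush(): post staged WRs, wait for CQEs
  //   FIFO drains to empty  -> onIdle -> postPendingAll(): one ibv_post_send
  //                            per connection that still has staged WRs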
+ proxy_->setOnIdle([this]() { this->postPendingAll(); }); +} + +void ProxyService::postPendingAll() { + if (!stagedConns_.empty()) { + if (auto* s = getStats(this); s && s->enabled) s->idleDrains++; + } + for (auto& kv : stagedConns_) { + kv.first->postPending(); + } + stagedConns_.clear(); + dirtyConns_.clear(); } MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& communicator, @@ -84,14 +171,45 @@ MSCCLPP_API_CPP PortChannel ProxyService::portChannel(SemaphoreId id, MemoryId d MSCCLPP_API_CPP void ProxyService::startProxy(bool blocking) { proxy_->start(blocking); } -MSCCLPP_API_CPP void ProxyService::stopProxy() { proxy_->stop(); } +MSCCLPP_API_CPP void ProxyService::stopProxy() { + proxy_->stop(); + printAndEraseStats(this); +} ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger trigger) { std::shared_ptr semaphore = semaphores_[trigger.fields.semaphoreId]; + ProxyStats* stats = getStats(this); + if (stats && stats->enabled) { + stats->triggers++; + if (trigger.fields.type == 0) stats->trigAtomic++; + if (trigger.fields.type & TriggerData) stats->trigData++; + if (trigger.fields.type & TriggerFlag) stats->trigFlag++; + if (trigger.fields.type & TriggerSync) stats->trigSync++; + } + auto& conn = semaphore->connection(); int maxWriteQueueSize = conn.getMaxWriteQueueSize(); - auto& numRequests = inflightRequests_[conn.impl_]; + auto connImpl = BaseConnection::getImpl(conn); + auto& numRequests = inflightRequests_[connImpl]; + + // Batch threshold: how many staged WRs we let accumulate on one connection + // before forcing an ibv_post_send. Lower values reduce tail latency for + // signalling triggers; higher values reduce syscalls. + // Override via MSCCLPP_PROXY_BATCH_THRESHOLD (1 = post every trigger, + // matching the historical behavior; the empirical sweep on H100/IB shows + // LL traffic is wire-latency bound rather than syscall bound, so deeper + // batching does not help and may regress dispatch tail latency). + static const int kPostBatchThreshold = []() { + if (const char* env = std::getenv("MSCCLPP_PROXY_BATCH_THRESHOLD")) { + int v = std::atoi(env); + if (v >= 1) return v; + } + return 1; + }(); + bool stagedWork = false; + // Atomic/signal triggers convey completion semantics; never defer them. + bool isSignal = (trigger.fields.type == 0) || (trigger.fields.type & TriggerFlag); if (trigger.fields.type == 0) { // type == 0 indicates an atomic add operation. 
@@ -100,6 +218,7 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger trigger) { int64_t value = static_cast(trigger.fst); conn.atomicAdd(dst, trigger.fields.dstOffset, value); numRequests++; + stagedWork = true; } if (trigger.fields.type & TriggerData) { @@ -107,17 +226,38 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger trigger) { RegisteredMemory& src = memories_[trigger.fields.srcMemoryId]; conn.write(dst, trigger.fields.dstOffset, src, trigger.fields.srcOffset, trigger.fields.size); numRequests++; + stagedWork = true; } if (trigger.fields.type & TriggerFlag) { semaphore->signal(); numRequests++; + stagedWork = true; } - if (((trigger.fields.type & TriggerSync) && numRequests > 0) || - (maxWriteQueueSize != -1 && numRequests >= maxWriteQueueSize)) { + if (stagedWork) { + stagedConns_[connImpl]++; + dirtyConns_.insert(connImpl); + } + + bool needFlush = (trigger.fields.type & TriggerSync) && numRequests > 0; + bool needFlush2 = maxWriteQueueSize != -1 && numRequests >= maxWriteQueueSize; + if (needFlush || needFlush2) { + // flush() drains staged WRs and waits for completion. conn.flush(); numRequests = 0; + stagedConns_.erase(connImpl); + dirtyConns_.erase(connImpl); + } else if (isSignal || stagedConns_[connImpl] >= kPostBatchThreshold) { + // Post the staged batch now: either we just queued a completion-signal + // WR (don't make the receiver wait for the next idle drain), or we've + // accumulated enough WRs that the QP send queue might overflow. + if (stats && stats->enabled) { + stats->postCalls++; + } + connImpl->postPending(); + stagedConns_.erase(connImpl); + dirtyConns_.erase(connImpl); } return ProxyHandlerResult::Continue; diff --git a/src/core/proxy.cc b/src/core/proxy.cc index de5b90fc..64ceea11 100644 --- a/src/core/proxy.cc +++ b/src/core/proxy.cc @@ -20,6 +20,7 @@ constexpr int ProxyStartWarnPeriod = 1000; struct Proxy::Impl { ProxyHandler handler; std::function threadInit; + std::function onIdle; std::shared_ptr fifo; std::atomic_bool threadStarted; std::thread service; @@ -28,6 +29,7 @@ struct Proxy::Impl { Impl(ProxyHandler handler, std::function threadInit, int fifoSize) : handler(handler), threadInit(threadInit), + onIdle(nullptr), fifo(std::make_shared(fifoSize)), threadStarted(false), running(false) {} @@ -71,9 +73,11 @@ MSCCLPP_API_CPP void Proxy::start(bool blocking) { pimpl_->threadStarted.store(true, std::memory_order_release); ProxyHandler handler = this->pimpl_->handler; + auto onIdle = this->pimpl_->onIdle; auto fifo = this->pimpl_->fifo; ProxyTrigger trigger; + bool wasBusy = false; int runCnt = ProxyStopCheckPeriod; for (;;) { if (runCnt-- == 0) { @@ -85,8 +89,13 @@ MSCCLPP_API_CPP void Proxy::start(bool blocking) { // Poll to see if we are ready to send anything trigger = fifo->poll(); if (trigger.fst == 0 || trigger.snd == 0) { // TODO: this check is a potential pitfall for custom triggers + if (wasBusy && onIdle) { + onIdle(); + } + wasBusy = false; continue; // there is one in progress } + wasBusy = true; trigger.snd ^= (uint64_t{1} << uint64_t{63}); // this is where the last bit of snd is reverted. 
ProxyHandlerResult result = handler(trigger); @@ -95,6 +104,7 @@ MSCCLPP_API_CPP void Proxy::start(bool blocking) { fifo->pop(); if (result == ProxyHandlerResult::Stop) { + if (onIdle) onIdle(); break; } } @@ -123,4 +133,6 @@ MSCCLPP_API_CPP void Proxy::stop() { MSCCLPP_API_CPP std::shared_ptr Proxy::fifo() { return pimpl_->fifo; } +MSCCLPP_API_CPP void Proxy::setOnIdle(std::function onIdle) { pimpl_->onIdle = std::move(onIdle); } + } // namespace mscclpp diff --git a/src/ext/ep/CMakeLists.txt b/src/ext/ep/CMakeLists.txt index c32132c7..f1394a0e 100644 --- a/src/ext/ep/CMakeLists.txt +++ b/src/ext/ep/CMakeLists.txt @@ -42,6 +42,7 @@ endif() file(GLOB_RECURSE EP_SOURCES CONFIGURE_DEPENDS buffer.cc bindings.cpp + ibgda_setup.cc kernels/*.cu ) @@ -58,6 +59,17 @@ endif() Python_add_library(mscclpp_ep_cpp MODULE ${EP_SOURCES}) target_compile_definitions(mscclpp_ep_cpp PRIVATE TORCH_EXTENSION_NAME=mscclpp_ep_cpp) +# Inherit ibverbs / mlx5dv defs from the core library so the IBGDA host-side +# plumbing (see ibgda_setup.cc, gated by MSCCLPP_EP_HAVE_IBGDA) compiles in. +if(IBVERBS_FOUND) + target_compile_definitions(mscclpp_ep_cpp PRIVATE USE_IBVERBS) + target_link_libraries(mscclpp_ep_cpp PRIVATE ${IBVERBS_LIBRARIES}) + if(MLX5_FOUND) + target_compile_definitions(mscclpp_ep_cpp PRIVATE MSCCLPP_USE_MLX5DV) + target_include_directories(mscclpp_ep_cpp SYSTEM PRIVATE ${MLX5_INCLUDE_DIRS}) + target_link_libraries(mscclpp_ep_cpp PRIVATE ${MLX5_LIBRARIES}) + endif() +endif() target_include_directories(mscclpp_ep_cpp PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} ${PROJECT_SOURCE_DIR}/include diff --git a/src/ext/ep/buffer.cc b/src/ext/ep/buffer.cc index 35702f8c..b13c8762 100644 --- a/src/ext/ep/buffer.cc +++ b/src/ext/ep/buffer.cc @@ -15,6 +15,10 @@ #include "kernels/api.cuh" #include "kernels/configs.cuh" +#ifdef MSCCLPP_EP_HAVE_IBGDA +#include "ibgda_setup.hpp" +#endif + namespace mscclpp { namespace ep { @@ -76,6 +80,15 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ printf("[mscclpp_ep] num_proxy_services=%d (set MSCCLPP_EP_NUM_PROXIES to override)\n", num_proxy_services); fflush(stdout); } + // Resolve native-IBGDA opt-in flag. The state is built in sync() and + // currently unused by the kernels (4b.1 plumbing only); see buffer.hpp. + if (const char* e = std::getenv("MSCCLPP_EP_USE_IBGDA")) { + use_ibgda_path_ = (std::atoi(e) != 0); + } + if (rank == 0) { + printf("[mscclpp_ep] MSCCLPP_EP_USE_IBGDA=%d\n", use_ibgda_path_ ? 1 : 0); + fflush(stdout); + } // Task fifo memory int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS; int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS; @@ -518,6 +531,44 @@ void Buffer::sync(const std::vector& device_ids, } } +#ifdef MSCCLPP_EP_HAVE_IBGDA + // ------------------------------------------------------------------ + // Stage 4b.1: native IBGDA host-side plumbing. + // + // Built only when: + // - MSCCLPP_EP_USE_IBGDA=1 + // - The run is cross-node (num_rdma_ranks > 1) — intra-node sticks with + // the NVLink IPC fast path already established above. + // - We have an RDMA buffer (num_rdma_bytes > 0). + // + // The setup is built but NOT yet consumed by any kernel (4b.2 will add + // the kernel branch). With or without the flag set, existing test + // outcomes must be identical. 
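  //
  // Opt-in sketch (the launch command and script name are illustrative only):
  //   MSCCLPP_EP_USE_IBGDA=1 torchrun --nproc-per-node=8 your_ep_test.py
  // With 4b.1 alone this must leave existing test outcomes unchanged.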
+ // ------------------------------------------------------------------ + if (use_ibgda_path_ && num_rdma_bytes > 0 && num_rdma_ranks > 1) { + constexpr int kNumIbgdaChannels = 16; // mirrors num_port_channels_per_rank above + try { + ibgda_setup_ = mscclpp::ep::build_ibgda_setup(rank, num_ranks, /*ib_transport_index=*/device_id, + kNumIbgdaChannels, rdma_buffer_ptr, + static_cast(num_rdma_bytes), bootstrap); + if (rank == 0) { + printf("[mscclpp_ep] IBGDA setup built: channels=%d num_ranks=%d (per-rank QPs=%d)\n", + kNumIbgdaChannels, num_ranks, kNumIbgdaChannels * (num_ranks - 1)); + fflush(stdout); + } + // Clear any benign CUDA sticky error left by overlapping host/UAR + // registrations across sibling QPs (e.g., cudaErrorHostMemoryAlreadyRegistered). + (void)cudaGetLastError(); + } catch (const std::exception& e) { + // Don't take down the run — just log and leave use_ibgda_path_ on as + // a hint (kernel-side fallback in 4b.2 will check ibgda_setup_ != null). + fprintf(stderr, "[mscclpp_ep][rank=%d] IBGDA setup failed, falling back to host FIFO: %s\n", rank, e.what()); + ibgda_setup_.reset(); + (void)cudaGetLastError(); + } + } +#endif // MSCCLPP_EP_HAVE_IBGDA + // Ready to use available = true; } @@ -1452,6 +1503,13 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i auto peer_bases = peer_rdma_bases_gpu; const bool use_ipc = ll_ipc_ready; auto rdma_base = rdma_buffer_ptr; +#ifdef MSCCLPP_EP_HAVE_IBGDA + mscclpp::IbgdaPortChannelDeviceHandle* ibgda_dev = ibgda_setup_ ? ibgda_setup_->device_handles.get() : nullptr; + const bool use_ibgda = use_ibgda_path_ && ibgda_dev != nullptr && !use_ipc; +#else + mscclpp::IbgdaPortChannelDeviceHandle* ibgda_dev = nullptr; + const bool use_ibgda = false; +#endif auto launcher = [=](int phases) { internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, packed_recv_src_info.data_ptr(), packed_recv_layout_range.data_ptr(), packed_recv_count.data_ptr(), @@ -1459,7 +1517,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i buffer.dispatch_rdma_send_buffer, x.data_ptr(), topk_idx.data_ptr(), next_clean_meta.first, next_clean_meta.second, num_tokens, hidden, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, use_fp8, workspace, - launch_stream, phases, rdma_base, port_handles, peer_bases, mem_handles, use_ipc); + launch_stream, phases, rdma_base, port_handles, peer_bases, mem_handles, use_ipc, + ibgda_dev, use_ibgda); }; launcher(return_recv_hook ? 
LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); @@ -1528,13 +1587,21 @@ std::tuple, std::optionaldevice_handles.get() : nullptr; + const bool use_ibgda = use_ibgda_path_ && ibgda_dev != nullptr && !use_ipc; +#else + mscclpp::IbgdaPortChannelDeviceHandle* ibgda_dev = nullptr; + const bool use_ibgda = false; +#endif auto launcher = [=](int phases) { internode_ll::combine( combined_x.data_ptr(), buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer, buffer.combine_rdma_send_buffer, x.data_ptr(), topk_idx.data_ptr(), topk_weights.data_ptr(), src_info.data_ptr(), layout_range.data_ptr(), next_clean_meta.first, next_clean_meta.second, num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, - workspace, launch_stream, phases, zero_copy, rdma_base, port_handles, peer_bases, mem_handles, use_ipc); + workspace, launch_stream, phases, zero_copy, rdma_base, port_handles, peer_bases, mem_handles, use_ipc, + ibgda_dev, use_ibgda); }; launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); diff --git a/src/ext/ep/buffer.hpp b/src/ext/ep/buffer.hpp index 7c2a0540..704cc16d 100644 --- a/src/ext/ep/buffer.hpp +++ b/src/ext/ep/buffer.hpp @@ -17,6 +17,11 @@ #include "kernels/configs.cuh" #include "kernels/exception.cuh" +#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM) +#define MSCCLPP_EP_HAVE_IBGDA 1 +namespace mscclpp { namespace ep { struct IbgdaSetup; } } +#endif + #ifndef TORCH_EXTENSION_NAME #define TORCH_EXTENSION_NAME mscclpp_ep_cpp #endif @@ -101,6 +106,17 @@ struct Buffer { std::shared_ptr ll_memory_channel_handles_device_ptr; bool ll_ipc_ready = false; + // ------------------------------------------------------------------ + // Native IBGDA path (Stage 4b). Built lazily in `sync()` when env + // `MSCCLPP_EP_USE_IBGDA=1` is set AND the run is cross-node. + // The kernels do NOT consume `ibgda_setup_` until 4b.2 lands; for now + // it is constructed-but-unused, so existing tests are unaffected. + // ------------------------------------------------------------------ + bool use_ibgda_path_ = false; +#ifdef MSCCLPP_EP_HAVE_IBGDA + std::unique_ptr ibgda_setup_; +#endif + private: void move_fifo_slots(int num_slots = 1); diff --git a/src/ext/ep/ibgda_setup.cc b/src/ext/ep/ibgda_setup.cc new file mode 100644 index 00000000..f8ea2ca5 --- /dev/null +++ b/src/ext/ep/ibgda_setup.cc @@ -0,0 +1,341 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +// +// Stage 4b.1 implementation. See ibgda_setup.hpp for the contract. + +#include "ibgda_setup.hpp" + +#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM) + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "kernels/exception.cuh" + +namespace mscclpp { +namespace ep { + +IbgdaSetup::~IbgdaSetup() { + // Stop the CQ poller first so it doesn't race with QP teardown. + if (cq_poller_thread.joinable()) { + cq_poller_stop.store(true, std::memory_order_release); + cq_poller_thread.join(); + } + // Tear down in reverse order of construction: + // - device_handles, d_local_mrs, d_remote_mrs: shared_ptr from + // gpuCallocShared, auto-freed via custom deleter. + // - resources / qps / rdma_mr: smart ptrs, auto-freed. + // - sig_mr (raw ibv_mr) + sig_slots / sig_seq (raw cudaMalloc): explicit. 
+ if (sig_mr != nullptr) { + ibv_dereg_mr(sig_mr); + sig_mr = nullptr; + } + if (sig_slots != nullptr) { + cudaFree(sig_slots); + sig_slots = nullptr; + } + if (sig_seq != nullptr) { + cudaFree(sig_seq); + sig_seq = nullptr; + } +} + +namespace { + +// Tunables that mirror the existing PortChannel path: +// - 16 channels (== num_port_channels_per_rank in Buffer::sync) +// - port=1, gid_index=3 (matches the value used in Stage 0..4a probes and +// the typical mscclpp/NDv5 default) +constexpr int kIbgdaPort = 1; +constexpr int kIbgdaGid = 3; +// SQ depth per QP. Each LL dispatch+combine iteration posts up to a few +// dozen WRs per QP at default LL configs; the bench loop runs 50 iters +// without explicit per-iter drain, so we size the SQ at 8192 to avoid +// wrap-around overwriting in-flight WQEs (we have no device-side +// back-pressure check in reserve_wqe_slots — the CQ poller drains +// asynchronously). 8k entries × 64B stride = 512 KiB SQ buf per QP, well +// under typical mlx5 HCA per-QP caps. +constexpr int kIbgdaMaxSendWr = 8192; + +} // namespace + +std::unique_ptr build_ibgda_setup(int rank, int num_ranks, int ib_transport_index, int num_channels, + void* rdma_buffer_ptr, std::size_t num_rdma_bytes, + std::shared_ptr bootstrap) { + EP_HOST_ASSERT(rdma_buffer_ptr != nullptr); + EP_HOST_ASSERT(num_rdma_bytes > 0); + EP_HOST_ASSERT(num_ranks > 1); + EP_HOST_ASSERT(num_channels > 0); + EP_HOST_ASSERT(ib_transport_index >= 0 && ib_transport_index < 8); + + auto setup = std::make_unique(); + setup->rank = rank; + setup->num_ranks = num_ranks; + setup->num_channels = num_channels; + + // 1. Resolve IB device name and build the IbCtx. + auto ib_transport = static_cast(static_cast(Transport::IB0) + ib_transport_index); + std::string dev_name = getIBDeviceName(ib_transport); + setup->ib_ctx = std::make_unique(dev_name); + EP_HOST_ASSERT(setup->ib_ctx->isMlx5() && "IBGDA requires an mlx5 NIC"); + + // 2. Create QPs. Layout: qps[channel * num_ranks + peer]. + // Self entries are nullptr. + const int total_slots = num_channels * num_ranks; + setup->qps.resize(total_slots); + setup->resources.resize(total_slots); + + for (int c = 0; c < num_channels; ++c) { + for (int r = 0; r < num_ranks; ++r) { + if (r == rank) continue; + auto qp = setup->ib_ctx->createQp(/*port=*/kIbgdaPort, /*gidIndex=*/kIbgdaGid, + /*maxSendCqSize=*/kIbgdaMaxSendWr * 2, + /*maxSendCqPollNum=*/64, + /*maxSendWr=*/kIbgdaMaxSendWr, + /*maxRecvWr=*/1, + /*maxWrPerSend=*/1, + /*noAtomic=*/true); + setup->qps[c * num_ranks + r] = qp; + } + } + + // 3. AllGather IbQpInfos so every rank can RTR every QP. + // Layout per-rank: total_slots IbQpInfo records, in [c * num_ranks + peer] + // order. Self entries (peer == rank) are zeroed; remote entries describe + // the QP that THIS rank uses to TALK TO peer == r. + std::vector my_infos(total_slots); + std::memset(my_infos.data(), 0, total_slots * sizeof(IbQpInfo)); + for (int c = 0; c < num_channels; ++c) { + for (int r = 0; r < num_ranks; ++r) { + if (r == rank) continue; + my_infos[c * num_ranks + r] = setup->qps[c * num_ranks + r]->getInfo(); + } + } + // bootstrap->allGather expects each rank to fill its own slot in a + // contiguous buffer of size num_ranks * record_bytes. + const std::size_t record_bytes = total_slots * sizeof(IbQpInfo); + std::vector all_infos(num_ranks * total_slots); + std::memcpy(&all_infos[rank * total_slots], my_infos.data(), record_bytes); + bootstrap->allGather(all_infos.data(), record_bytes); + + // 4. 
RTR/RTS each local QP using the matching peer-side QP info. The peer + // info we want is "rank r's QP for talking to us" == r's record indexed at + // [c * num_ranks + rank]. + for (int c = 0; c < num_channels; ++c) { + for (int r = 0; r < num_ranks; ++r) { + if (r == rank) continue; + const IbQpInfo& peer_info = all_infos[r * total_slots + c * num_ranks + rank]; + auto& qp = setup->qps[c * num_ranks + r]; + qp->rtr(peer_info); + qp->rts(); + } + } + + // 5. Wrap each QP with IbgdaResources (Stage 1) — produces GPU-mapped + // sq_buf / dbrec / bf_reg / state pointers usable from the kernel. + for (int c = 0; c < num_channels; ++c) { + for (int r = 0; r < num_ranks; ++r) { + if (r == rank) continue; + ibv_qp* raw = setup->qps[c * num_ranks + r]->getRawQp(); + setup->resources[c * num_ranks + r] = std::make_unique(raw); + } + } + + // 6. Register the existing rdma_buffer_ptr as an MR on this IbCtx, then + // allgather (addr, rkey) so we can build per-peer remote_mrs entries. + setup->rdma_mr = setup->ib_ctx->registerMr(rdma_buffer_ptr, num_rdma_bytes); + IbgdaSetup::PeerMr my_rdma_mr{}; + { + auto info = setup->rdma_mr->getInfo(); + my_rdma_mr.addr = info.addr; + my_rdma_mr.rkey = info.rkey; + } + setup->peer_rdma.assign(num_ranks, IbgdaSetup::PeerMr{}); + setup->peer_rdma[rank] = my_rdma_mr; + bootstrap->allGather(setup->peer_rdma.data(), sizeof(IbgdaSetup::PeerMr)); + + // 7. Allocate signal slots: total_slots * 4 bytes on GPU. Layout: + // sig_slots[c * num_ranks + sender_peer] — when peer P sends signal() to + // us through channel c, it RDMA-WRITEs to *our* sig_slots[c * num_ranks + P]. + // Each peer therefore needs to know our sig MR (addr+rkey) AND the offset + // within it that corresponds to (channel, sender=P) == "we are receiving + // from P". We allgather the base addr+rkey only; the offset is derivable. + const std::size_t sig_bytes = std::size_t(total_slots) * sizeof(uint32_t); + CUDA_CHECK(cudaMalloc(&setup->sig_slots, sig_bytes)); + CUDA_CHECK(cudaMemset(setup->sig_slots, 0, sig_bytes)); + CUDA_CHECK(cudaMalloc(&setup->sig_seq, sig_bytes)); + CUDA_CHECK(cudaMemset(setup->sig_seq, 0, sig_bytes)); + + // Use raw verbs for the signal MR (we need the rkey/lkey directly). + ibv_pd* pd = setup->qps[(rank == 0 ? 1 : 0)]->getRawQp()->pd; // any non-null QP shares the same pd + for (int c = 0; c < num_channels && pd == nullptr; ++c) + for (int r = 0; r < num_ranks && pd == nullptr; ++r) + if (auto& q = setup->qps[c * num_ranks + r]; q) pd = q->getRawQp()->pd; + EP_HOST_ASSERT(pd != nullptr); + setup->sig_mr = ibv_reg_mr(pd, setup->sig_slots, sig_bytes, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + if (!setup->sig_mr) { + throw std::runtime_error("ibv_reg_mr(sig_slots) failed errno=" + std::to_string(errno)); + } + + IbgdaSetup::PeerMr my_sig_mr{}; + my_sig_mr.addr = reinterpret_cast(setup->sig_slots); + my_sig_mr.rkey = setup->sig_mr->rkey; + setup->peer_sig.assign(num_ranks, IbgdaSetup::PeerMr{}); + setup->peer_sig[rank] = my_sig_mr; + bootstrap->allGather(setup->peer_sig.data(), sizeof(IbgdaSetup::PeerMr)); + + // 8. Build the GPU-resident MR tables. We have a single local MR (the + // rdma_buffer_ptr); remote_mrs has one entry per peer rank. 
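+  // Illustrative only — with this single-MR layout a device-side put of
+  // `size` bytes to peer `r` resolves both sides relative to rdma_buffer_ptr:
+  //   raddr = remote_mrs[r].addr + dst_off;  rkey = remote_mrs[r].rkey_be;
+  //   laddr = local_mrs[0].addr  + src_off;  lkey = local_mrs[0].lkey_be;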
+ std::vector h_local(1); + h_local[0].addr = reinterpret_cast(rdma_buffer_ptr); + h_local[0].lkey_be = htonl(setup->rdma_mr->getLkey()); + h_local[0].pad = 0; + + std::vector h_remote(num_ranks); + for (int r = 0; r < num_ranks; ++r) { + h_remote[r].addr = setup->peer_rdma[r].addr; + h_remote[r].rkey_be = htonl(setup->peer_rdma[r].rkey); + h_remote[r].pad = 0; + } + + setup->d_local_mrs = mscclpp::detail::gpuCallocShared(1); + setup->d_remote_mrs = mscclpp::detail::gpuCallocShared(num_ranks); + mscclpp::gpuMemcpy(setup->d_local_mrs.get(), h_local.data(), 1, cudaMemcpyHostToDevice); + mscclpp::gpuMemcpy(setup->d_remote_mrs.get(), h_remote.data(), num_ranks, cudaMemcpyHostToDevice); + + // 9. Build the device handle array (channel × num_ranks). + std::vector h_handles(total_slots); + std::memset(h_handles.data(), 0, h_handles.size() * sizeof(IbgdaPortChannelDeviceHandle)); + for (int c = 0; c < num_channels; ++c) { + for (int r = 0; r < num_ranks; ++r) { + auto& h = h_handles[c * num_ranks + r]; + if (r == rank) continue; // self-slot left zeroed + h.qp = setup->resources[c * num_ranks + r]->getHandle(); + h.local_mrs = setup->d_local_mrs.get(); + h.remote_mrs = setup->d_remote_mrs.get(); + + // Local sig slot WE poll for messages from peer r through channel c. + h.sig_local_addr = &setup->sig_slots[c * num_ranks + r]; + h.sig_local_lkey = htonl(setup->sig_mr->lkey); + + // Remote sig slot peer r polls for messages FROM US through channel c. + // Peer r's slot for "rank == us" is at offset (c * num_ranks + rank) * 4 + // inside their signal buffer. + h.sig_remote_addr = setup->peer_sig[r].addr + + static_cast(c * num_ranks + rank) * sizeof(uint32_t); + h.sig_rkey_be = htonl(setup->peer_sig[r].rkey); + + // Outbound seq counter is per-handle; carve out one u32 from sig_seq + // (sig_seq is num_channels × num_ranks u32s so we can reuse the same + // index function — it is unrelated to the inbound sig_slots buffer + // beyond reusing the size). + h.sig_seq = &setup->sig_seq[c * num_ranks + r]; + + h.dst = static_cast(r); // index into remote_mrs[] (peer-rank-based) + h.src = 0; // single-entry local_mrs table + h.peer_rank = static_cast(r); + h._pad = 0; + } + } + + setup->device_handles = mscclpp::detail::gpuCallocShared(total_slots); + mscclpp::gpuMemcpy(setup->device_handles.get(), h_handles.data(), total_slots, + cudaMemcpyHostToDevice); + + // 10. Spawn CQ drain thread. Kernel-issued rdma_write paths set CQ_UPDATE + // on every WR; without this drain, the send CQ (sized 2 × kIbgdaMaxSendWr) + // would fill within a few iterations of LL dispatch+combine and the QP + // would error out. We collect raw send_cq pointers from each QP up front. + { + std::vector send_cqs; + send_cqs.reserve(total_slots); + for (int idx = 0; idx < total_slots; ++idx) { + auto& qp = setup->qps[idx]; + if (!qp) continue; + ibv_cq* cq = qp->getRawQp()->send_cq; + if (cq) send_cqs.push_back(cq); + } + IbgdaSetup* raw = setup.get(); + int dbg_rank = rank; + setup->cq_poller_thread = std::thread([raw, send_cqs, dbg_rank]() { + // Tight loop polling all CQs round-robin. Each ibv_poll_cq is cheap + // (one PCIe-mapped read of CQE buffer + valid bit). We don't inspect + // wc fields beyond their existence — a status != IBV_WC_SUCCESS would + // indicate a fatal QP error which we surface lazily through later + // failures (the test path will hang and the user sees the error). 
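+      // (The loop below does decode wc.status and logs any non-success CQE to
+      // stderr; "surface lazily" above means the error is not propagated back
+      // to the caller programmatically.)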
+ constexpr int kBatch = 16; + ibv_wc wc[kBatch]; + uint64_t total_polled = 0; + uint64_t total_errors = 0; + auto t_last_log = std::chrono::steady_clock::now(); + while (!raw->cq_poller_stop.load(std::memory_order_acquire)) { + bool any = false; + for (ibv_cq* cq : send_cqs) { + int n = ibv_poll_cq(cq, kBatch, wc); + if (n > 0) { + any = true; + total_polled += n; + for (int i = 0; i < n; ++i) { + if (wc[i].status != IBV_WC_SUCCESS) { + total_errors++; + fprintf(stderr, + "[mscclpp_ep][ibgda][rank=%d] CQE error: status=%d (%s) opcode=%d vendor_err=0x%x wr_id=%llu\n", + dbg_rank, wc[i].status, ibv_wc_status_str(wc[i].status), wc[i].opcode, wc[i].vendor_err, + static_cast(wc[i].wr_id)); + fflush(stderr); + } + } + } + for (int rep = 0; rep < 3 && n == kBatch; ++rep) { + n = ibv_poll_cq(cq, kBatch, wc); + if (n > 0) { + any = true; + total_polled += n; + for (int i = 0; i < n; ++i) { + if (wc[i].status != IBV_WC_SUCCESS) { + total_errors++; + fprintf(stderr, + "[mscclpp_ep][ibgda][rank=%d] CQE error: status=%d (%s) opcode=%d vendor_err=0x%x\n", + dbg_rank, wc[i].status, ibv_wc_status_str(wc[i].status), wc[i].opcode, wc[i].vendor_err); + fflush(stderr); + } + } + } + } + } + auto now = std::chrono::steady_clock::now(); + if (std::chrono::duration_cast(now - t_last_log).count() >= 5) { + fprintf(stderr, "[mscclpp_ep][ibgda][rank=%d] poller: polled=%llu errors=%llu\n", dbg_rank, + static_cast(total_polled), static_cast(total_errors)); + fflush(stderr); + t_last_log = now; + } + if (!any) { + std::this_thread::sleep_for(std::chrono::microseconds(20)); + } + } + }); + } + + return setup; +} + +} // namespace ep +} // namespace mscclpp + +#endif // USE_IBVERBS && MSCCLPP_USE_MLX5DV && !MSCCLPP_USE_ROCM diff --git a/src/ext/ep/ibgda_setup.hpp b/src/ext/ep/ibgda_setup.hpp new file mode 100644 index 00000000..faad75bc --- /dev/null +++ b/src/ext/ep/ibgda_setup.hpp @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +// +// Stage 4b.1: native IBGDA host-side plumbing for src/ext/ep. +// +// Owned by `Buffer` and built lazily in `Buffer::sync()` when the env var +// `MSCCLPP_EP_USE_IBGDA=1` is set AND the run is cross-node +// (num_rdma_ranks > 1). Builds a parallel "shadow" of the existing +// PortChannel layout: a flat (channel × num_ranks) array of +// `IbgdaPortChannelDeviceHandle` POD structs sitting on the GPU, plus the +// host-side resources (IbCtx, QPs, IbgdaResources, MRs, signal slots) that +// keep them valid for the lifetime of the Buffer. +// +// 4b.1 only constructs the state — the kernels do NOT consume it yet. That +// happens in 4b.2 behind a `kIbgdaPath` template branch. + +#pragma once + +#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM) + +#include +#include +#include +#include +#include + +#include +#include + +#include "../../core/include/ib.hpp" +#include "../../core/include/ibgda.hpp" + +struct ibv_mr; + +namespace mscclpp { +namespace ep { + +// One owning bundle of host-side resources backing the device handle array. +struct IbgdaSetup { + IbgdaSetup() = default; + ~IbgdaSetup(); + IbgdaSetup(const IbgdaSetup&) = delete; + IbgdaSetup& operator=(const IbgdaSetup&) = delete; + // Layout constants (must match the existing port_channel_handles layout in + // Buffer::sync, so kernels can reuse the same channel × num_ranks index + // arithmetic in the IBGDA branch). 
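+  // e.g. (illustrative) the handle a kernel uses to reach peer `r` over
+  // channel `c` lives at device_handles[c * num_ranks + r]; slots with
+  // r == rank are left zeroed and must not be dereferenced.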
+ int num_channels = 0; // == num_port_channels_per_rank in Buffer::sync + int num_ranks = 0; + int rank = 0; + + // IB context + QPs + Stage-1 GPU mappings. + std::unique_ptr ib_ctx; + // Indexed [channel * num_ranks + peer]; entries with peer == rank are null. + std::vector> qps; + std::vector> resources; + + // RDMA buffer MR. We register the *same* `rdma_buffer_ptr` that the + // existing PortChannel path uses, so dst/src offsets in the kernel work + // unchanged. + std::unique_ptr rdma_mr; + + // Signal slots (GPU-resident). + // Layout: 4 bytes per (channel, peer) on the receiving side. `sig_slots` + // is the local mirror polled by `port_wait()`; remote peers RDMA-WRITE + // into the corresponding slot of *their* signal MR for *us*. Each rank's + // signal slot for (channel, peer=A) is at offset + // offset_in_buf(channel, peer) = (channel * num_ranks + peer) * 4 + // within rank A's signal buffer (i.e. peer=A's view of "messages from + // rank=A's POV indexed by channel × peer"). See the per-handle wiring + // in `IbgdaSetup::populate_handles`. + uint32_t* sig_slots = nullptr; // GPU mem; size = num_channels * num_ranks * 4 + ibv_mr* sig_mr = nullptr; // raw verbs (we keep the IbMr only for rdma_mr) + uint32_t* sig_seq = nullptr; // GPU mem; per-handle outbound counters + + // Per-peer (addr, rkey) for the RDMA buffer and the signal buffer. Filled + // by an allgather over the bootstrap. + struct PeerMr { + uint64_t addr = 0; + uint32_t rkey = 0; + uint32_t pad = 0; + }; + std::vector peer_rdma; // size = num_ranks + std::vector peer_sig; // size = num_ranks + + // Flat device-side handle array: num_channels * num_ranks entries. + // Self entries (peer == rank) are zeroed and unused. + std::shared_ptr device_handles; + + // Underlying GPU-side MR-table arrays referenced by every device handle + // (see IbgdaPortChannelDeviceHandle::local_mrs / remote_mrs). + std::shared_ptr d_local_mrs; // length 1 (we have a single MR) + std::shared_ptr d_remote_mrs; // length num_ranks + + // CQ drain thread. The kernel-side rdma_write paths set CQ_UPDATE on + // every WR, so without periodic ibv_poll_cq() the send CQ would fill up + // (CQ size = kIbgdaMaxSendWr * 2). One thread polls all per-QP send CQs + // round-robin until `cq_poller_stop` is set in the destructor. + std::atomic cq_poller_stop{false}; + std::thread cq_poller_thread; +}; + +// Build the full IBGDA setup. Bootstrap is used for cross-rank exchange of +// QP info and MR keys; rdma_buffer_ptr/num_rdma_bytes is the same buffer the +// existing PortChannel path uses. `ib_transport_index` is the IB device +// index this rank will use (== `device_id` on NDv5). +// +// Throws on any irrecoverable error. On success returns a fully-initialised +// IbgdaSetup with all QPs in RTS and the device-side handle array populated. 
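+//
+// Sketch of the intended call site (Buffer::sync; values illustrative — they
+// mirror the 16-channel PortChannel layout — not a prescriptive API):
+//
+//   if (use_ibgda_path_ && num_rdma_ranks > 1) {
+//     ibgda_setup_ = mscclpp::ep::build_ibgda_setup(
+//         rank, num_ranks, /*ib_transport_index=*/device_id,
+//         /*num_channels=*/16, rdma_buffer_ptr, num_rdma_bytes, bootstrap);
+//   }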
+std::unique_ptr build_ibgda_setup(int rank, int num_ranks, int ib_transport_index, int num_channels, + void* rdma_buffer_ptr, std::size_t num_rdma_bytes, + std::shared_ptr bootstrap); + +} // namespace ep +} // namespace mscclpp + +#endif // USE_IBVERBS && MSCCLPP_USE_MLX5DV && !MSCCLPP_USE_ROCM diff --git a/src/ext/ep/kernels/api.cuh b/src/ext/ep/kernels/api.cuh index e83a480d..2792c8c0 100644 --- a/src/ext/ep/kernels/api.cuh +++ b/src/ext/ep/kernels/api.cuh @@ -17,6 +17,13 @@ #include #include +#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM) +#include +#define MSCCLPP_EP_KERNEL_HAS_IBGDA 1 +#else +namespace mscclpp { struct IbgdaPortChannelDeviceHandle; } +#endif + namespace mscclpp { namespace ep { @@ -131,7 +138,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8, void* workspace, cudaStream_t stream, int phases, void* rdma_buffer_ptr, mscclpp::PortChannelDeviceHandle* port_channel_handles, void* const* peer_rdma_bases, - mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path); + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path, + mscclpp::IbgdaPortChannelDeviceHandle* ibgda_handles, bool use_ibgda_path); void combine(void* combined_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, const void* x, const int64_t* topk_idx, const float* topk_weights, const int* src_info, const int64_t* layout_range, @@ -139,7 +147,8 @@ void combine(void* combined_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks, void* workspace, cudaStream_t stream, int phases, bool zero_copy, void* rdma_buffer_ptr, mscclpp::PortChannelDeviceHandle* port_channel_handles, void* const* peer_rdma_bases, - mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path); + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path, + mscclpp::IbgdaPortChannelDeviceHandle* ibgda_handles, bool use_ibgda_path); } // namespace internode_ll diff --git a/src/ext/ep/kernels/internode_ll.cu b/src/ext/ep/kernels/internode_ll.cu index 5c451577..49980ed0 100644 --- a/src/ext/ep/kernels/internode_ll.cu +++ b/src/ext/ep/kernels/internode_ll.cu @@ -31,6 +31,11 @@ #include #include +#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM) +#include "ibgda_port_channel_device.cuh" +#define MSCCLPP_EP_LL_HAS_IBGDA 1 +#endif + #include "configs.cuh" #include "exception.cuh" #include "launch.cuh" @@ -142,14 +147,15 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, int64_t* cl // dispatch // --------------------------------------------------------------------------- -template +template __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dispatch( void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_count, void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x, const void* x, const int64_t* topk_idx, int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert, int64_t* next_clean, int num_next_clean_int, int num_tokens, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks, int phases, void* rdma_buffer_ptr, mscclpp::PortChannelDeviceHandle* 
port_channel_handles, - void* const* peer_rdma_bases, mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) { + void* const* peer_rdma_bases, mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, + mscclpp::IbgdaPortChannelDeviceHandle* ibgda_handles) { const auto sm_id = static_cast(blockIdx.x); const auto thread_id = static_cast(threadIdx.x); const auto warp_id = thread_id / 32, lane_id = get_lane_id(); @@ -249,14 +255,31 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis const auto* dst_int4_ptr = reinterpret_cast(peer_dst); UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global); } else { - // MSCCL++ port-channel PUT (lane 0 issues one request). - if (lane_id == 0) { - const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr); - const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr); - port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank].put(dst_off, src_off, - num_bytes_per_msg); +#ifdef MSCCLPP_EP_LL_HAS_IBGDA + if constexpr (kIbgdaPath) { + // Native IBGDA: lane 0 issues one RDMA WRITE WQE directly. + // UNSIGNALED + no DB ring — the trailing count write below + // (signaled, rings DB) flushes the per-QP queue. + if (lane_id == 0) { + const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr); + const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr); + mscclpp::ibgda::port_put(ibgda_handles[dst_expert_local_idx * num_ranks + dst_rank], dst_off, src_off, + num_bytes_per_msg, + /*signal_cqe=*/false, /*ring_db=*/false); + } + __syncwarp(); + } else +#endif + { + // MSCCL++ port-channel PUT (lane 0 issues one request). + if (lane_id == 0) { + const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr); + const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr); + port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank].put(dst_off, src_off, + num_bytes_per_msg); + } + __syncwarp(); } - __syncwarp(); } } else { const auto* src_int4_ptr = reinterpret_cast(src_ptr); @@ -322,10 +345,26 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis peer_rdma_bases, rdma_buffer_ptr, dst_rank)); st_na_release(peer_counter, static_cast(-num_tokens_sent - 1)); } else { - auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank; - const auto off = rdma_offset_of(reinterpret_cast(counter_ptr), rdma_buffer_ptr); - port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank].atomicAdd( - off, static_cast(-num_tokens_sent - 1)); +#ifdef MSCCLPP_EP_LL_HAS_IBGDA + if constexpr (kIbgdaPath) { + // Single writer per (dst_expert_local_idx, rank) slot, so an + // 8-byte inline RDMA WRITE delivering the encoded count is + // semantically equivalent to atomicAdd from a zero-initialised + // remote slot. The receiver polls with ld_acquire_sys_global. 
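+            // Encoding example (matches the value written below): the wire
+            // value is -(num_tokens_sent + 1), i.e. 0 tokens -> -1, 7 -> -8,
+            // so the receiver can tell "arrived" apart from the cleared 0
+            // state and recovers the count as -(value) - 1.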
+ auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank; + const auto off = rdma_offset_of(reinterpret_cast(counter_ptr), rdma_buffer_ptr); + const auto& ch = ibgda_handles[dst_expert_local_idx * num_ranks + dst_rank]; + mscclpp::IbgdaRemoteMr r = ch.remote_mrs[ch.dst]; + mscclpp::ibgda::rdma_write_inl8(ch.qp, static_cast(static_cast(-num_tokens_sent - 1)), + r.addr + off, r.rkey_be); + } else +#endif + { + auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank; + const auto off = rdma_offset_of(reinterpret_cast(counter_ptr), rdma_buffer_ptr); + port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank].atomicAdd( + off, static_cast(-num_tokens_sent - 1)); + } } } else { st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, @@ -405,7 +444,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8, void* workspace, cudaStream_t stream, int phases, void* rdma_buffer_ptr, mscclpp::PortChannelDeviceHandle* port_channel_handles, void* const* peer_rdma_bases, - mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path) { + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path, + mscclpp::IbgdaPortChannelDeviceHandle* ibgda_handles, bool use_ibgda_path) { constexpr int kNumMaxTopK = 9; // (kNumWarpGroups, kNumWarpsPerGroup) is path-dependent. Intra-node IPC // benefits from 1 expert per SM with 32 warps cooperating on the recv-side @@ -446,26 +486,37 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts; EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES); -#define DISPATCH_LAUNCH_CASE(hidden_case) \ - { \ - if (use_ipc_path) { \ - auto dispatch_func = use_fp8 ? dispatch \ - : dispatch; \ - LAUNCH_KERNEL(&cfg, dispatch_func, packed_recv_x, packed_recv_x_scales, packed_recv_src_info, \ - packed_recv_layout_range, packed_recv_count, rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \ - atomic_counter_per_expert, atomic_finish_counter_per_expert, next_clean, num_next_clean_int, \ - num_tokens, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, phases, \ - rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, memory_channel_handles); \ - } else { \ - auto dispatch_func = use_fp8 ? dispatch \ - : dispatch; \ - LAUNCH_KERNEL(&cfg, dispatch_func, packed_recv_x, packed_recv_x_scales, packed_recv_src_info, \ - packed_recv_layout_range, packed_recv_count, rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \ - atomic_counter_per_expert, atomic_finish_counter_per_expert, next_clean, num_next_clean_int, \ - num_tokens, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, phases, \ - rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, memory_channel_handles); \ - } \ - } \ +#define DISPATCH_LAUNCH_CASE(hidden_case) \ + { \ + if (use_ipc_path) { \ + auto dispatch_func = \ + use_fp8 ? 
dispatch \ + : dispatch; \ + LAUNCH_KERNEL(&cfg, dispatch_func, packed_recv_x, packed_recv_x_scales, packed_recv_src_info, \ + packed_recv_layout_range, packed_recv_count, rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \ + atomic_counter_per_expert, atomic_finish_counter_per_expert, next_clean, num_next_clean_int, \ + num_tokens, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, phases, \ + rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, memory_channel_handles, ibgda_handles); \ + } else if (use_ibgda_path) { \ + auto dispatch_func = \ + use_fp8 ? dispatch \ + : dispatch; \ + LAUNCH_KERNEL(&cfg, dispatch_func, packed_recv_x, packed_recv_x_scales, packed_recv_src_info, \ + packed_recv_layout_range, packed_recv_count, rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \ + atomic_counter_per_expert, atomic_finish_counter_per_expert, next_clean, num_next_clean_int, \ + num_tokens, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, phases, \ + rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, memory_channel_handles, ibgda_handles); \ + } else { \ + auto dispatch_func = \ + use_fp8 ? dispatch \ + : dispatch; \ + LAUNCH_KERNEL(&cfg, dispatch_func, packed_recv_x, packed_recv_x_scales, packed_recv_src_info, \ + packed_recv_layout_range, packed_recv_count, rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \ + atomic_counter_per_expert, atomic_finish_counter_per_expert, next_clean, num_next_clean_int, \ + num_tokens, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, phases, \ + rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, memory_channel_handles, ibgda_handles); \ + } \ + } \ break SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream); @@ -477,14 +528,15 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv // combine // --------------------------------------------------------------------------- -template +template __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void combine( void* combined_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, const void* x, const int64_t* topk_idx, const float* topk_weights, const int* src_info, const int64_t* layout_range, int64_t* next_clean, int num_next_clean_int, int* atomic_clean_flag, int num_combined_tokens, int hidden, int num_topk, int num_max_dispatch_tokens_per_rank, int num_experts, int rank, int num_ranks, int phases, bool zero_copy, void* rdma_buffer_ptr, mscclpp::PortChannelDeviceHandle* port_channel_handles, - void* const* peer_rdma_bases, mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) { + void* const* peer_rdma_bases, mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, + mscclpp::IbgdaPortChannelDeviceHandle* ibgda_handles) { const auto sm_id = static_cast(blockIdx.x); const auto num_sms = static_cast(gridDim.x); const auto thread_id = static_cast(threadIdx.x); @@ -546,17 +598,35 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com const auto peer_dst_int4 = reinterpret_cast(peer_dst); UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, peer_dst_int4, x_int4, ld_nc_global, st_na_global); } else { - const auto buf_int4_ptr = reinterpret_cast(buf_ptr); - if (not zero_copy) - UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global); - // MSCCL++ port-channel PUT. 
- if (lane_id == 0) { - const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr); - const auto src_off = rdma_offset_of(static_cast(buf_ptr), rdma_buffer_ptr); - port_channel_handles[local_expert_idx * num_ranks + dst_rank].put(dst_off, src_off, - hidden * sizeof(nv_bfloat16)); +#ifdef MSCCLPP_EP_LL_HAS_IBGDA + if constexpr (kIbgdaPath) { + const auto buf_int4_ptr = reinterpret_cast(buf_ptr); + if (not zero_copy) + UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global); + if (lane_id == 0) { + const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr); + const auto src_off = rdma_offset_of(static_cast(buf_ptr), rdma_buffer_ptr); + // UNSIGNALED + no DB ring — trailing flag write drives the doorbell. + mscclpp::ibgda::port_put(ibgda_handles[local_expert_idx * num_ranks + dst_rank], dst_off, src_off, + hidden * sizeof(nv_bfloat16), + /*signal_cqe=*/false, /*ring_db=*/false); + } + __syncwarp(); + } else +#endif + { + const auto buf_int4_ptr = reinterpret_cast(buf_ptr); + if (not zero_copy) + UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global); + // MSCCL++ port-channel PUT. + if (lane_id == 0) { + const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr); + const auto src_off = rdma_offset_of(static_cast(buf_ptr), rdma_buffer_ptr); + port_channel_handles[local_expert_idx * num_ranks + dst_rank].put(dst_off, src_off, + hidden * sizeof(nv_bfloat16)); + } + __syncwarp(); } - __syncwarp(); } } } @@ -573,9 +643,20 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com peer_rdma_bases, rdma_buffer_ptr, dst_rank)); st_na_release(peer_flag, static_cast(1)); } else { - auto* flag_ptr = rdma_recv_flag + global_expert_idx; - const auto off = rdma_offset_of(reinterpret_cast(flag_ptr), rdma_buffer_ptr); - port_channel_handles[local_expert_idx * num_ranks + dst_rank].atomicAdd(off, static_cast(1)); +#ifdef MSCCLPP_EP_LL_HAS_IBGDA + if constexpr (kIbgdaPath) { + auto* flag_ptr = rdma_recv_flag + global_expert_idx; + const auto off = rdma_offset_of(reinterpret_cast(flag_ptr), rdma_buffer_ptr); + const auto& ch = ibgda_handles[local_expert_idx * num_ranks + dst_rank]; + mscclpp::IbgdaRemoteMr r = ch.remote_mrs[ch.dst]; + mscclpp::ibgda::rdma_write_inl8(ch.qp, static_cast(1), r.addr + off, r.rkey_be); + } else +#endif + { + auto* flag_ptr = rdma_recv_flag + global_expert_idx; + const auto off = rdma_offset_of(reinterpret_cast(flag_ptr), rdma_buffer_ptr); + port_channel_handles[local_expert_idx * num_ranks + dst_rank].atomicAdd(off, static_cast(1)); + } } } else { st_na_release(rdma_recv_flag + global_expert_idx, static_cast(1)); @@ -639,7 +720,8 @@ void combine(void* combined_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks, void* workspace, cudaStream_t stream, int phases, bool zero_copy, void* rdma_buffer_ptr, mscclpp::PortChannelDeviceHandle* port_channel_handles, void* const* peer_rdma_bases, - mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path) { + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path, + mscclpp::IbgdaPortChannelDeviceHandle* ibgda_handles, bool use_ibgda_path) { // See the comment in `dispatch()`: (kNumWarpGroups, kNumWarpsPerGroup) // is path-dependent. IPC uses (1, 32) to mirror NCCL-EP; PortChannel keeps // (3, 10) to avoid host-proxy FIFO contention on the IB path. 
@@ -671,24 +753,34 @@ void combine(void* combined_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES); EP_HOST_ASSERT(num_topk <= kNumMaxTopk); -#define COMBINE_LAUNCH_CASE(hidden_case) \ - { \ - if (use_ipc_path) { \ - auto combine_func = combine; \ - LAUNCH_KERNEL(&cfg, combine_func, combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, x, topk_idx, \ - topk_weights, src_info, layout_range, next_clean, num_next_clean_int, atomic_clean_flag, \ - num_combined_tokens, hidden, num_topk, num_max_dispatch_tokens_per_rank, num_experts, rank, \ - num_ranks, phases, zero_copy, rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, \ - memory_channel_handles); \ - } else { \ - auto combine_func = combine; \ - LAUNCH_KERNEL(&cfg, combine_func, combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, x, topk_idx, \ - topk_weights, src_info, layout_range, next_clean, num_next_clean_int, atomic_clean_flag, \ - num_combined_tokens, hidden, num_topk, num_max_dispatch_tokens_per_rank, num_experts, rank, \ - num_ranks, phases, zero_copy, rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, \ - memory_channel_handles); \ - } \ - } \ +#define COMBINE_LAUNCH_CASE(hidden_case) \ + { \ + if (use_ipc_path) { \ + auto combine_func = \ + combine; \ + LAUNCH_KERNEL(&cfg, combine_func, combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, x, topk_idx, \ + topk_weights, src_info, layout_range, next_clean, num_next_clean_int, atomic_clean_flag, \ + num_combined_tokens, hidden, num_topk, num_max_dispatch_tokens_per_rank, num_experts, rank, \ + num_ranks, phases, zero_copy, rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, \ + memory_channel_handles, ibgda_handles); \ + } else if (use_ibgda_path) { \ + auto combine_func = \ + combine; \ + LAUNCH_KERNEL(&cfg, combine_func, combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, x, topk_idx, \ + topk_weights, src_info, layout_range, next_clean, num_next_clean_int, atomic_clean_flag, \ + num_combined_tokens, hidden, num_topk, num_max_dispatch_tokens_per_rank, num_experts, rank, \ + num_ranks, phases, zero_copy, rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, \ + memory_channel_handles, ibgda_handles); \ + } else { \ + auto combine_func = \ + combine; \ + LAUNCH_KERNEL(&cfg, combine_func, combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, x, topk_idx, \ + topk_weights, src_info, layout_range, next_clean, num_next_clean_int, atomic_clean_flag, \ + num_combined_tokens, hidden, num_topk, num_max_dispatch_tokens_per_rank, num_experts, rank, \ + num_ranks, phases, zero_copy, rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, \ + memory_channel_handles, ibgda_handles); \ + } \ + } \ break SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream); diff --git a/test/python/ext/ep/test_low_latency_multirank.py b/test/python/ext/ep/test_low_latency_multirank.py index 18e64a4e..f392553a 100644 --- a/test/python/ext/ep/test_low_latency_multirank.py +++ b/test/python/ext/ep/test_low_latency_multirank.py @@ -43,14 +43,45 @@ import sys # noisy 'recvValue failed / Connection was likely closed' stack traces. os.environ.setdefault("TORCH_NCCL_ENABLE_MONITORING", "0") +import ctypes + +import psutil import torch import torch.distributed as dist +# Load libnuma for NUMA-aware memory binding (mirrors DeepEP/tests/utils.py). 
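+# The mapping below assumes the topology this test is tuned for: 12 cores per
+# local rank and 4 ranks per NUMA node (e.g. local_rank 5 -> cores 60-71 on
+# NUMA node 1); adjust cores_per_rank / the rank->node mapping on other boxes.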
+try: + _libnuma = ctypes.CDLL("libnuma.so") + _libnuma.numa_available.restype = ctypes.c_int + _libnuma.numa_run_on_node.argtypes = [ctypes.c_int] + _libnuma.numa_set_preferred.argtypes = [ctypes.c_int] +except OSError: + _libnuma = None + + +def set_numa_affinity(local_rank: int): + cores_per_rank = 12 + numa_node = local_rank // 4 + core_start = local_rank * cores_per_rank + core_end = core_start + cores_per_rank + p = psutil.Process(os.getpid()) + p.cpu_affinity(list(range(core_start, core_end))) + print(f"Rank {local_rank} numa node {numa_node} bound to cores {core_start}-{core_end - 1}") + + # Bind memory to NUMA node + if _libnuma is not None and _libnuma.numa_available() != -1: + _libnuma.numa_set_preferred(numa_node) + print(f"Rank {local_rank}: CPU affinity → cores {core_start}-{core_end - 1}, memory NUMA → node {numa_node}") + else: + print(f"Rank {local_rank}: libnuma not available") + + def init_dist(): rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) local_rank = int(os.environ.get("LOCAL_RANK", rank)) + set_numa_affinity(local_rank) torch.cuda.set_device(local_rank) dist.init_process_group( backend="nccl",