mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
ext/ep: GPU-initiated IBGDA path for low-latency dispatch/combine
Add a GPU-initiated RDMA WRITE path for the LL dispatch/combine kernels
based on mlx5dv direct verbs, alongside the existing IPC and host-FIFO
PortChannel paths. Selected at runtime via MSCCLPP_EP_USE_IBGDA when
num_rdma_ranks > 1.
Core (src/core, include/mscclpp):
- New ibgda module (ibgda.{hpp,cc}, ibgda_device.cuh): per-peer mlx5
QP/MR/CQ setup, device-side WQE writers (write_rdma_wqe,
write_rdma_write_inl_wqe for 4B/8B), submit_requests / submit_no_db
ring helpers, and a poller thread for send CQs.
- ibgda_port_channel_device.{hpp,cuh}: thin port_put() wrapper over
rdma_write with signal_cqe / ring_db flags so callers can issue
UNSIGNALED batched WRs and ring the doorbell once at the tail.
- mlx5dv_wrapper: expose extra symbols needed for direct WQE
construction; minor connection.cc / proxy.cc / port_channel.cc
plumbing to surface QP / MR handles and rkeys to the EP layer.
EP layer (src/ext/ep):
- ibgda_setup.{hpp,cc}: build per-(local_expert, peer_rank) GpuQp
handles, exchange remote MR addr/rkey via the bootstrap, own the
CQ poller. h.dst is set to the per-peer remote_mrs index.
- buffer.{hpp,cc}: gate IBGDA path with use_ibgda_path_ &&
ibgda_setup_ != nullptr && !use_ipc; pass device_handles to the
kernel launchers.
- kernels/internode_ll.cu: 3-way DISPATCH_LAUNCH_CASE /
COMBINE_LAUNCH_CASE (IPC / IBGDA / port-FIFO), templated on
kIbgdaPath. Data PUTs are issued UNSIGNALED with ring_db=false;
the trailing per-QP count write (dispatch) and flag write
(combine) keep the defaults so each QP gets a single signaled
WR that advances prod_idx past all queued data WRs and rings
the doorbell once.
Test (test/python/ext/ep): extend test_low_latency_multirank.py with
env-driven config knobs (MSCCLPP_EP_LL_TOKENS / _HIDDEN / _TOPK /
_EXPERTS_PER_RANK) for sweeping the new path.
This commit is contained in:
101 lines — include/mscclpp/ibgda_port_channel_device.hpp (new file)
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
// Stage 3: device-side IBGDA "port channel" handle.
//
// Device-facing mirror of mscclpp::PortChannelDeviceHandle (see
// include/mscclpp/port_channel_device.hpp), except traffic is issued straight
// from the GPU through mscclpp::ibgda::rdma_write — no host proxy or FIFO.
//
//   put(dstOffset, srcOffset, size)  — emits one RDMA WRITE WQE
//   signal()                         — emits one inline 4-byte RDMA WRITE
//                                      into a remote 4B counter slot
//   wait()                           — busy-waits on the local mirror of
//                                      that counter
//
// Memory model (built by every rank):
//   - one flat row per registered local region M:
//       local_mrs[M]  = { base_addr, lkey_be }
//   - one flat row per (peer rank R, registered region M) pair:
//       remote_mrs[M] = { base_addr_on_R, rkey_be_on_R }
//     A channel is bound to exactly one peer R, so the rank dimension is
//     already flattened away inside the handle.
//
// Signal counter protocol:
//   - Every rank cudaMallocs a 4-byte "signal slot", registers it with this
//     NIC, and exchanges address + rkey with its peers. In the handle,
//     {sig_remote_addr, sig_rkey_be} names the *peer's* slot, while
//     sig_local_addr is the local mirror that wait() polls.
//   - signal() atomically increments sig_seq (a u32 in GPU memory shared by
//     all users of the channel) and sends the post-increment value as a
//     4-byte inline RDMA WRITE to the peer's slot; the receiver polls its
//     own local slot.

#ifndef MSCCLPP_IBGDA_PORT_CHANNEL_DEVICE_HPP_
#define MSCCLPP_IBGDA_PORT_CHANNEL_DEVICE_HPP_

#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM)

#include <cstdint>

#include "ibgda.hpp"

namespace mscclpp {

using IbgdaMemoryId = uint32_t;

// One row of the local-MR table.
struct IbgdaLocalMr {
  uint64_t addr;     // host-order base virtual address (GPU-mapped)
  uint32_t lkey_be;  // BE-encoded local key
  uint32_t pad;
};

// One row of the remote-MR table (for the peer this channel is bound to).
struct IbgdaRemoteMr {
  uint64_t addr;     // host-order base virtual address on the peer
  uint32_t rkey_be;  // BE-encoded remote key
  uint32_t pad;
};

// POD copied to GPU. Fields kept dense and 8-byte aligned.
struct IbgdaPortChannelDeviceHandle {
  IbgdaQpHandle qp;

  // MR tables, indexed by IbgdaMemoryId: local_mrs[id] belongs to THIS rank,
  // remote_mrs[id] to the single peer this channel is bound to. Both must be
  // dereferenceable from device code.
  const IbgdaLocalMr* local_mrs;
  const IbgdaRemoteMr* remote_mrs;

  // Signal slot plumbing:
  //   sig_local_addr  — 4-byte slot in this rank's GPU memory; wait() polls
  //                     it, the peer's NIC writes into it.
  //   sig_local_lkey  — BE-encoded lkey of the slot's MR. Unused on the
  //                     inline-WR path (inline WRs need no local MR); kept
  //                     for layout stability / a future non-inline fallback.
  //   sig_remote_addr — the peer's 4-byte slot.
  //   sig_rkey_be     — the peer's rkey for sig_remote_addr.
  //   sig_seq         — 4-byte GPU-resident counter bumped by signal(); the
  //                     post-increment value is what gets sent. Distinct
  //                     from sig_local_addr.
  uint32_t* sig_local_addr;
  uint32_t sig_local_lkey;  // unused for inline; kept for layout stability
  uint64_t sig_remote_addr;
  uint32_t sig_rkey_be;
  uint32_t* sig_seq;

  // Bound peer (informational); MemoryIds default to (0,0) in the simple
  // case but are kept for parity with PortChannelDeviceHandle.
  IbgdaMemoryId dst;
  IbgdaMemoryId src;
  uint32_t peer_rank;
  uint32_t _pad;
};

}  // namespace mscclpp

#endif  // USE_IBVERBS && MSCCLPP_USE_MLX5DV && !MSCCLPP_USE_ROCM

#endif  // MSCCLPP_IBGDA_PORT_CHANNEL_DEVICE_HPP_
|
||||
@@ -4,6 +4,9 @@
|
||||
#ifndef MSCCLPP_PORT_CHANNEL_HPP_
|
||||
#define MSCCLPP_PORT_CHANNEL_HPP_
|
||||
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "core.hpp"
|
||||
#include "port_channel_device.hpp"
|
||||
#include "proxy.hpp"
|
||||
@@ -84,6 +87,14 @@ class ProxyService : public BaseProxyService {
|
||||
std::vector<RegisteredMemory> memories_;
|
||||
std::shared_ptr<Proxy> proxy_;
|
||||
std::unordered_map<std::shared_ptr<BaseConnection>, int> inflightRequests_;
|
||||
// Connections with WRs staged but not yet posted. Mapped to the count of
|
||||
// WRs staged since the last postPending() call (used to trip the batch
|
||||
// threshold). Distinct from `inflightRequests_` which counts QP-inflight
|
||||
// WRs since the last flush() (used for QP overflow detection).
|
||||
std::unordered_map<std::shared_ptr<BaseConnection>, int> stagedConns_;
|
||||
std::unordered_set<std::shared_ptr<BaseConnection>> dirtyConns_;
|
||||
|
||||
void postPendingAll();
|
||||
|
||||
ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw);
|
||||
};
|
||||
|
||||
@@ -53,6 +53,13 @@ class Proxy {
|
||||
/// @return Shared pointer to FIFO.
|
||||
std::shared_ptr<Fifo> fifo();
|
||||
|
||||
/// Set a callback invoked by the proxy thread when the FIFO transitions
|
||||
/// from busy to idle (i.e., a poll returns no trigger after at least one
|
||||
/// trigger was processed). Useful for batching: handlers can defer
|
||||
/// expensive system calls (e.g., ibv_post_send) until the FIFO drains.
|
||||
/// Must be called before start().
|
||||
void setOnIdle(std::function<void()> onIdle);
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> pimpl_;
|
||||
|
||||
@@ -396,7 +396,10 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem
|
||||
qp_.lock()->stageSendWrite(srcMr, dstMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset,
|
||||
/*dstOffset=*/dstOffset, /*signaled=*/true);
|
||||
|
||||
qp_.lock()->postSend();
|
||||
// NOTE: postSend() is intentionally deferred. The proxy service batches
|
||||
// many triggers and calls postPending() once at idle. External callers
|
||||
// that pair write() with flush() still observe correct semantics because
|
||||
// flush() now drains staged WRs before polling the CQ.
|
||||
INFO(CONN, "IBConnection write: from ", (uint8_t*)srcMr->getBuff() + srcOffset, " to ",
|
||||
(uint8_t*)dstMrInfo.addr + dstOffset, ", size ", size);
|
||||
|
||||
@@ -438,12 +441,12 @@ void IBConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint6
|
||||
/*size=*/0, /*wrId=*/0,
|
||||
/*srcOffset=*/0, /*dstOffset=*/0,
|
||||
/*signaled=*/true, /*immData=*/immData);
|
||||
qp_.lock()->postSend();
|
||||
// postSend deferred; see IBConnection::write.
|
||||
INFO(CONN, "IBConnection signal forwarding: value ", oldValue, " -> ", newValue);
|
||||
} else {
|
||||
qp_.lock()->stageSendAtomicAdd(atomicSrcTransportInfo_.ibMr, dstMrInfo, /*wrId=*/0, dstOffset, newValue - oldValue,
|
||||
/*signaled=*/true);
|
||||
qp_.lock()->postSend();
|
||||
// postSend deferred; see IBConnection::write.
|
||||
INFO(CONN, "IBConnection atomic write: from ", src, " to ", (uint8_t*)dstMrInfo.addr + dstOffset, ", ", oldValue,
|
||||
" -> ", newValue);
|
||||
}
|
||||
@@ -458,6 +461,9 @@ void IBConnection::flush(int64_t timeoutUsec) {
|
||||
NpKit::CollectCpuEvent(NPKIT_EVENT_CONN_IB_FLUSH_ENTRY, 0, 0, *NpKit::GetCpuTimestamp(), 0);
|
||||
#endif
|
||||
|
||||
// Drain any staged WRs that were deferred by the staging-only write path.
|
||||
qp_.lock()->postSend();
|
||||
|
||||
// Check if the recv thread has already reported an error (e.g., QP entered error state).
|
||||
if (recvThreadError_.load(std::memory_order_acquire)) {
|
||||
THROW(CONN, Error, ErrorCode::SystemError, "IBConnection recv thread failed: ", recvThreadErrorMsg_);
|
||||
@@ -503,10 +509,16 @@ void IBConnection::atomicAdd(RegisteredMemory dst, uint64_t dstOffset, int64_t v
|
||||
|
||||
qp_.lock()->stageSendAtomicAdd(atomicSrcTransportInfo_.ibMr, dstMrInfo, /*wrId=*/0, dstOffset,
|
||||
static_cast<uint64_t>(value), /*signaled=*/true);
|
||||
qp_.lock()->postSend();
|
||||
// postSend deferred; see IBConnection::write.
|
||||
INFO(CONN, "IBConnection atomicAdd: dst ", (uint8_t*)dstMrInfo.addr + dstOffset, ", value ", value);
|
||||
}
|
||||
|
||||
void IBConnection::postPending() {
|
||||
// Drain all staged WRs to the NIC in a single ibv_post_send call.
|
||||
// Cheap no-op when nothing is staged (IbQp::postSend early-returns).
|
||||
qp_.lock()->postSend();
|
||||
}
|
||||
|
||||
// EthernetConnection
|
||||
|
||||
EthernetConnection::EthernetConnection(std::shared_ptr<Context> context, const Endpoint& localEndpoint,
|
||||
|
||||
177 lines — src/core/ibgda.cc (new file)
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include "ibgda.hpp"

#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM)

#include <cuda.h>
#include <cuda_runtime.h>
#include <infiniband/mlx5dv.h>
#include <unistd.h>

#include <cerrno>
#include <cstdint>
#include <cstring>
#include <mscclpp/errors.hpp>
#include <mscclpp/gpu_utils.hpp>

#include "logger.hpp"
#include "mlx5dv_wrapper.hpp"

namespace mscclpp {

// Host-side bookkeeping for the GPU mappings. Kept behind a pimpl so the
// public header stays free of CUDA / mlx5 types.
struct IbgdaResources::Impl {
  void* sq_buf_host = nullptr;          // SQ WQE ring as mapped by libmlx5 (host VA)
  size_t sq_bytes = 0;                  // wqe_cnt * stride
  void* dbrec_host = nullptr;           // not necessarily aligned to anything > 8B
  void* dbrec_register_addr = nullptr;  // page-aligned address handed to cudaHostRegister
  size_t dbrec_register_bytes = 0;
  void* uar_page_host = nullptr;        // page-aligned base of the UAR
  void* state_dev = nullptr;            // device-resident state
  // *_registered flags record whether WE performed the registration, so the
  // destructor only unregisters mappings this instance owns (a sibling QP may
  // share the DBR page / UAR page).
  bool sq_registered = false;
  bool dbrec_registered = false;
  bool uar_registered = false;
};

namespace {

// System page size minus one; cached once. Used for page-aligning the
// cudaHostRegister ranges below.
inline uintptr_t pageMask() {
  static const uintptr_t mask = static_cast<uintptr_t>(::sysconf(_SC_PAGESIZE)) - 1;
  return mask;
}

}  // namespace

/// Wraps an existing RC ibv_qp and maps its SQ ring, doorbell record, and
/// BlueFlame UAR into GPU virtual address space so WQEs can be written and
/// rung from device code.
/// @param qp  Raw verbs QP; must be non-null and remains owned by the caller.
/// @throws Error/IbError/SysError on null input, missing libmlx5, or any
///         mlx5dv/CUDA registration failure.
IbgdaResources::IbgdaResources(ibv_qp* qp) : pimpl_(std::make_unique<Impl>()) {
  if (qp == nullptr) {
    THROW(NET, Error, ErrorCode::InvalidUsage, "IbgdaResources: qp is null");
  }
  if (!MLX5DV::isAvailable()) {
    THROW(NET, Error, ErrorCode::InvalidUsage, "IbgdaResources: libmlx5 not available");
  }

  // 1. Extract sq.buf / dbrec / bf via mlx5dv_init_obj.
  struct mlx5dv_qp dvQp{};
  if (MLX5DV::mlx5dv_init_obj_qp(qp, &dvQp) != 0) {
    THROW(NET, IbError, errno, "mlx5dv_init_obj(QP) failed (errno ", errno, ")");
  }

  pimpl_->sq_buf_host = dvQp.sq.buf;
  pimpl_->sq_bytes = static_cast<size_t>(dvQp.sq.wqe_cnt) * dvQp.sq.stride;
  pimpl_->dbrec_host = dvQp.dbrec;
  if (pimpl_->sq_buf_host == nullptr || pimpl_->dbrec_host == nullptr || dvQp.bf.reg == nullptr) {
    THROW(NET, Error, ErrorCode::InvalidUsage, "mlx5dv_qp returned null pointers");
  }
  // The device-side ring helpers mask with (wqe_cnt - 1), so the ring length
  // must be a power of two.
  if (dvQp.sq.wqe_cnt == 0 || (dvQp.sq.wqe_cnt & (dvQp.sq.wqe_cnt - 1)) != 0) {
    THROW(NET, Error, ErrorCode::InvalidUsage, "sq.wqe_cnt must be a power of 2, got ", dvQp.sq.wqe_cnt);
  }

  // 2. Register the SQ buffer with CUDA (host-mapped). We register the
  //    enclosing whole pages because cudaHostRegister requires that.
  {
    uintptr_t base = reinterpret_cast<uintptr_t>(pimpl_->sq_buf_host);
    uintptr_t pageBase = base & ~pageMask();
    size_t pad = static_cast<size_t>(base - pageBase);
    size_t regBytes = (pad + pimpl_->sq_bytes + pageMask()) & ~pageMask();
    void* regAddr = reinterpret_cast<void*>(pageBase);
    cudaError_t e = cudaHostRegister(regAddr, regBytes, cudaHostRegisterDefault);
    if (e != cudaSuccess) {
      THROW(NET, SysError, static_cast<int>(e), "cudaHostRegister(sq.buf) failed: ",
            cudaGetErrorString(e));
    }
    pimpl_->sq_registered = true;
    void* dev = nullptr;
    MSCCLPP_CUDATHROW(cudaHostGetDevicePointer(&dev, pimpl_->sq_buf_host, 0));
    handle_.sq_buf = dev;
  }

  // 3. Register the DBR record. mlx5 hands us a pointer into a small block
  //    of DBR records (8 B each); register a whole page starting from its
  //    page base.
  {
    uintptr_t base = reinterpret_cast<uintptr_t>(pimpl_->dbrec_host);
    uintptr_t pageBase = base & ~pageMask();
    size_t regBytes = pageMask() + 1;  // 1 page
    pimpl_->dbrec_register_addr = reinterpret_cast<void*>(pageBase);
    pimpl_->dbrec_register_bytes = regBytes;
    cudaError_t e = cudaHostRegister(pimpl_->dbrec_register_addr, regBytes, cudaHostRegisterDefault);
    if (e == cudaErrorHostMemoryAlreadyRegistered) {
      // The DBR page may overlap with another QP we already registered.
      // Treat as success but skip unregister on shutdown. Clear the runtime
      // API's sticky error so it doesn't surface from a later unrelated call.
      pimpl_->dbrec_registered = false;
      (void)cudaGetLastError();
    } else if (e != cudaSuccess) {
      THROW(NET, SysError, static_cast<int>(e), "cudaHostRegister(dbrec) failed: ",
            cudaGetErrorString(e));
    } else {
      pimpl_->dbrec_registered = true;
    }
    void* dev = nullptr;
    MSCCLPP_CUDATHROW(cudaHostGetDevicePointer(&dev, pimpl_->dbrec_host, 0));
    handle_.dbrec = static_cast<uint32_t*>(dev);
  }

  // 4. Map the UAR page (NIC MMIO) into GPU VA via cuMemHostRegister IOMEMORY.
  {
    uintptr_t bfAddr = reinterpret_cast<uintptr_t>(dvQp.bf.reg);
    uintptr_t pageAddr = bfAddr & ~uintptr_t(4095);  // UAR pages are 4 KB on mlx5
    uintptr_t bfOffset = bfAddr - pageAddr;
    pimpl_->uar_page_host = reinterpret_cast<void*>(pageAddr);
    CUresult cuRes =
        cuMemHostRegister(pimpl_->uar_page_host, 4096, CU_MEMHOSTREGISTER_IOMEMORY);
    if (cuRes == CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED) {
      // The UAR page is shared across QPs on the same context; if a sibling
      // QP already registered it, treat as success and skip unregister.
      // NOTE: no cudaGetLastError() here — cuMemHostRegister is a driver API
      // call and does not set the runtime API's sticky error, so clearing it
      // would at best be a no-op and at worst swallow an unrelated pending
      // runtime error.
      pimpl_->uar_registered = false;
    } else if (cuRes != CUDA_SUCCESS) {
      const char* s = nullptr;
      cuGetErrorString(cuRes, &s);
      THROW(NET, SysError, static_cast<int>(cuRes),
            "cuMemHostRegister(UAR, IOMEMORY) failed: ", s ? s : "?",
            ". Ensure 'options nvidia NVreg_RegistryDwords=\"PeerMappingOverride=1;\"' "
            "is set and nvidia_peermem is loaded.");
    } else {
      pimpl_->uar_registered = true;
    }
    CUdeviceptr dPage = 0;
    MSCCLPP_CUTHROW(cuMemHostGetDevicePointer(&dPage, pimpl_->uar_page_host, 0));
    handle_.bf_reg = reinterpret_cast<uint64_t*>(dPage + bfOffset);
    handle_.bf_offset = static_cast<uint32_t>(bfOffset);
  }

  // 5. Allocate GPU-resident state (zero-initialized).
  {
    constexpr size_t kStateBytes = 4 * sizeof(uint64_t);  // resv/ready/prod/lock
    MSCCLPP_CUDATHROW(cudaMalloc(&pimpl_->state_dev, kStateBytes));
    MSCCLPP_CUDATHROW(cudaMemset(pimpl_->state_dev, 0, kStateBytes));
    handle_.state = static_cast<uint64_t*>(pimpl_->state_dev);
  }

  // 6. Fill the rest of the handle.
  handle_.qpn = qp->qp_num;
  handle_.wqe_cnt = dvQp.sq.wqe_cnt;
  handle_.stride = dvQp.sq.stride;
}

/// Tears down only the resources this instance created: the GPU state buffer
/// and any CUDA registrations whose *_registered flag is set (registrations
/// shared with sibling QPs are deliberately left alone).
IbgdaResources::~IbgdaResources() {
  if (!pimpl_) return;
  if (pimpl_->state_dev) {
    cudaFree(pimpl_->state_dev);
  }
  if (pimpl_->uar_registered && pimpl_->uar_page_host) {
    cuMemHostUnregister(pimpl_->uar_page_host);
  }
  if (pimpl_->dbrec_registered && pimpl_->dbrec_register_addr) {
    cudaHostUnregister(pimpl_->dbrec_register_addr);
  }
  if (pimpl_->sq_registered && pimpl_->sq_buf_host) {
    // Mirror the page-base computation from registration so the same address
    // is passed to cudaHostUnregister.
    uintptr_t base = reinterpret_cast<uintptr_t>(pimpl_->sq_buf_host);
    void* regAddr = reinterpret_cast<void*>(base & ~pageMask());
    cudaHostUnregister(regAddr);
  }
}

}  // namespace mscclpp

#endif  // USE_IBVERBS && MSCCLPP_USE_MLX5DV && !MSCCLPP_USE_ROCM
|
||||
@@ -37,6 +37,12 @@ class BaseConnection {
|
||||
|
||||
virtual void atomicAdd(RegisteredMemory dst, uint64_t dstOffset, int64_t value) = 0;
|
||||
|
||||
/// Post any locally-staged work requests to the NIC. Called by the proxy
|
||||
/// service to amortize ibv_post_send across many FIFO triggers. The default
|
||||
/// no-op is correct for connections whose write/updateAndSync/atomicAdd
|
||||
/// already post immediately (e.g., CudaIpcConnection).
|
||||
virtual void postPending() {}
|
||||
|
||||
virtual void flush(int64_t timeoutUsec = -1) = 0;
|
||||
|
||||
/// Start signal forwarding to the given memory address.
|
||||
@@ -152,6 +158,8 @@ class IBConnection : public BaseConnection {
|
||||
void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override;
|
||||
void atomicAdd(RegisteredMemory dst, uint64_t dstOffset, int64_t value) override;
|
||||
|
||||
void postPending() override;
|
||||
|
||||
void flush(int64_t timeoutUsec) override;
|
||||
};
|
||||
|
||||
|
||||
@@ -82,6 +82,9 @@ class IbQp {
|
||||
int pollRecvCq();
|
||||
|
||||
IbQpInfo& getInfo() { return info_; }
|
||||
/// Raw ibv_qp pointer. Owned by this IbQp. Provided so other components
|
||||
/// (e.g. IbgdaResources) can call mlx5dv_init_obj on the same QP.
|
||||
ibv_qp* getRawQp() const { return qp_; }
|
||||
int getSendWcStatus(int idx) const;
|
||||
std::string getSendWcStatusString(int idx) const;
|
||||
int getNumSendCqItems() const;
|
||||
|
||||
71 lines — src/core/include/ibgda.hpp (new file)
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
// IBGDA-style GPU-direct work-request submission for an mlx5 QP.
//
// Starting from an existing plain RC ibv_qp (created with ibv_create_qp),
// IbgdaResources pulls out the SQ ring buffer, the doorbell record, and the
// BlueFlame UAR via mlx5dv_init_obj, then maps each into GPU virtual address
// space:
//   - SQ buf and DBR live in host RAM and are mapped with cudaHostRegister.
//   - The BF UAR is device MMIO, mapped with cuMemHostRegister(IOMEMORY).
// A small GPU-resident state struct (resv_head / ready_head / prod_idx /
// lock) is also allocated for the device-side WQE writer.
//
// Ownership of the QP stays with the host, and the host never touches
// sq.buf/dbrec/bf once GPU posting begins — send-CQ completion polling stays
// a host-side ibv_poll_cq() call.

#ifndef MSCCLPP_IBGDA_HPP_
#define MSCCLPP_IBGDA_HPP_

#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM)

#include <cstdint>
#include <memory>

struct ibv_qp;

namespace mscclpp {

// POD copied to the GPU; consumed by the device-side WQE writer (added in
// Stage 2). Field order/sizes must match include/mscclpp/ibgda_device.hpp.
struct IbgdaQpHandle {
  // Device-mapped pointers (MUST be dereferenceable from the GPU only).
  void* sq_buf;      // SQ WQE ring (host RAM, cudaHostRegister)
  uint32_t* dbrec;   // 32-bit doorbell record (host RAM, cudaHostRegister)
  uint64_t* bf_reg;  // BlueFlame doorbell (NIC MMIO, cuMemHostRegister IOMEMORY)
  // GPU-resident bookkeeping state (allocated by IbgdaResources).
  // Layout: { uint64_t resv_head; uint64_t ready_head; uint64_t prod_idx; int post_send_lock; }.
  uint64_t* state;
  // Constants (host-order).
  uint32_t qpn;
  uint32_t wqe_cnt;    // SQ length in WQEBBs (power of 2)
  uint32_t stride;     // bytes per WQEBB (typically 64)
  uint32_t bf_offset;  // byte offset of BF doorbell within its UAR page
};

// Prepares the GPU mappings + GPU state for an existing ibv_qp. Owns the
// cudaHostRegister/cuMemHostRegister registrations and the GPU state buffer;
// its lifetime must enclose every kernel launch that uses getHandle().
class IbgdaResources {
 public:
  // Moving the qp to RTS before any kernel posting is the caller's
  // responsibility.
  explicit IbgdaResources(ibv_qp* qp);
  ~IbgdaResources();

  IbgdaResources(const IbgdaResources&) = delete;
  IbgdaResources& operator=(const IbgdaResources&) = delete;

  const IbgdaQpHandle& getHandle() const { return handle_; }

 private:
  struct Impl;
  std::unique_ptr<Impl> pimpl_;
  IbgdaQpHandle handle_{};
};

}  // namespace mscclpp

#endif  // USE_IBVERBS && MSCCLPP_USE_MLX5DV && !MSCCLPP_USE_ROCM
#endif  // MSCCLPP_IBGDA_HPP_
|
||||
426 lines — src/core/include/ibgda_device.cuh (new file)
@@ -0,0 +1,426 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
//
|
||||
// Stage 2: device-side IBGDA primitives.
|
||||
//
|
||||
// Operates on an `IbgdaQpHandle` (see ibgda.hpp). All pointers in the handle
|
||||
// must already be GPU-mapped by `IbgdaResources` on the host. This header is
|
||||
// device-only and should be included from .cu translation units.
|
||||
//
|
||||
// State buffer layout (allocated by IbgdaResources, zero-initialised):
|
||||
// offset 0: uint64_t resv_head // next free WQEBB slot (atomic)
|
||||
// offset 8: uint64_t ready_head // next slot whose WQE write is done
|
||||
// offset 16: uint64_t prod_idx // last value rung on the doorbell
|
||||
// offset 24: int post_send_lock // cta-level lock for ringing
|
||||
//
|
||||
// Conventions copied from NVSHMEM/DeepEP:
|
||||
// - wqe_idx is a 16-bit-truncated cyclic index into the SQ ring.
|
||||
// - one RDMA WRITE WQE = 1 WQEBB (64B): ctrl(16) + raddr(16) + data(16) + pad(16).
|
||||
// ctrl_seg.qpn_ds encodes ds=3 (three 16-byte segments).
|
||||
// - DBR record holds BE32(prod_idx & 0xFFFF) in the SQ slot (dbrec[1] for
|
||||
// mlx5; IbgdaResources sets handle.dbrec to that slot directly).
|
||||
// - BF doorbell: 64-bit write of {opmod_idx_opcode, qpn_ds} to bf_reg.
|
||||
|
||||
#ifndef MSCCLPP_IBGDA_DEVICE_CUH_
|
||||
#define MSCCLPP_IBGDA_DEVICE_CUH_
|
||||
|
||||
#include <cstdint>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "ibgda.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
namespace ibgda {
|
||||
|
||||
#ifndef MSCCLPP_IBGDA_OPCODE_RDMA_WRITE
|
||||
#define MSCCLPP_IBGDA_OPCODE_RDMA_WRITE 0x08
|
||||
#endif
|
||||
#ifndef MSCCLPP_IBGDA_CTRL_CQ_UPDATE
|
||||
#define MSCCLPP_IBGDA_CTRL_CQ_UPDATE uint8_t(0x8)
|
||||
#endif
|
||||
|
||||
// ---- BE swap helpers (PTX prmt) -------------------------------------------
|
||||
__device__ static __forceinline__ uint32_t HtoBE32(uint32_t x) {
|
||||
uint32_t r;
|
||||
asm("{ .reg .b32 ig; prmt.b32 %0, %1, ig, 0x0123; }" : "=r"(r) : "r"(x));
|
||||
return r;
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ uint64_t HtoBE64(uint64_t x) {
|
||||
uint64_t r;
|
||||
asm("{\n\t"
|
||||
".reg .b32 ig;\n\t"
|
||||
".reg .b32 lo, hi, nl, nh;\n\t"
|
||||
"mov.b64 {lo,hi}, %1;\n\t"
|
||||
"prmt.b32 nh, lo, ig, 0x0123;\n\t"
|
||||
"prmt.b32 nl, hi, ig, 0x0123;\n\t"
|
||||
"mov.b64 %0, {nl,nh};\n\t" "}"
|
||||
: "=l"(r) : "l"(x));
|
||||
return r;
|
||||
}
|
||||
|
||||
// ---- Memory ordering helpers ----------------------------------------------
|
||||
// st.relaxed.sys + st.release.sys equivalents. We use `volatile` writes plus
|
||||
// __threadfence_system() at the points the caller needs ordering, mirroring
|
||||
// the NVSHMEM `st_na_*` style without depending on its internal headers.
|
||||
__device__ static __forceinline__ void store_relaxed_u32(uint32_t* p, uint32_t v) {
|
||||
*reinterpret_cast<volatile uint32_t*>(p) = v;
|
||||
}
|
||||
__device__ static __forceinline__ void store_relaxed_u64(uint64_t* p, uint64_t v) {
|
||||
*reinterpret_cast<volatile uint64_t*>(p) = v;
|
||||
}
|
||||
__device__ static __forceinline__ void store_relaxed_int4(int4* p, int4 v) {
|
||||
asm volatile("st.relaxed.sys.v4.b32 [%0], {%1, %2, %3, %4};" ::
|
||||
"l"(p), "r"(v.x), "r"(v.y), "r"(v.z), "r"(v.w) : "memory");
|
||||
}
|
||||
__device__ static __forceinline__ void store_release_u32(uint32_t* p, uint32_t v) {
|
||||
asm volatile("st.release.sys.b32 [%0], %1;" :: "l"(p), "r"(v) : "memory");
|
||||
}
|
||||
__device__ static __forceinline__ void store_release_u64(uint64_t* p, uint64_t v) {
|
||||
asm volatile("st.release.sys.b64 [%0], %1;" :: "l"(p), "l"(v) : "memory");
|
||||
}
|
||||
|
||||
// ---- WQE segment layout (mirrors mlx5_wqe_*; redefined to avoid pulling
|
||||
// libmlx5 headers into device code) ----------------------------------------
|
||||
struct __align__(8) CtrlSeg {
|
||||
uint32_t opmod_idx_opcode; // BE32: (wqe_idx << 8) | opcode
|
||||
uint32_t qpn_ds; // BE32: (qpn << 8) | ds
|
||||
uint8_t signature;
|
||||
uint8_t rsvd0;
|
||||
uint8_t rsvd1;
|
||||
uint8_t fm_ce_se; // CQ_UPDATE in bit 3 (0x8)
|
||||
uint32_t imm; // BE32 immediate (or 0)
|
||||
};
|
||||
struct __align__(8) RaddrSeg {
|
||||
uint64_t raddr; // BE64
|
||||
uint32_t rkey; // BE32 (caller pre-swaps)
|
||||
uint32_t reserved;
|
||||
};
|
||||
struct __align__(8) DataSeg {
|
||||
uint32_t byte_count; // BE32
|
||||
uint32_t lkey; // BE32 (caller pre-swaps)
|
||||
uint64_t addr; // BE64
|
||||
};
|
||||
|
||||
static_assert(sizeof(CtrlSeg) == 16, "CtrlSeg must be 16B");
|
||||
static_assert(sizeof(RaddrSeg) == 16, "RaddrSeg must be 16B");
|
||||
static_assert(sizeof(DataSeg) == 16, "DataSeg must be 16B");
|
||||
|
||||
// ---- State accessors ------------------------------------------------------
|
||||
struct StateView {
|
||||
unsigned long long* resv_head;
|
||||
unsigned long long* ready_head;
|
||||
unsigned long long* prod_idx;
|
||||
int* post_send_lock;
|
||||
};
|
||||
|
||||
__device__ static __forceinline__ StateView stateView(const IbgdaQpHandle& h) {
|
||||
auto* base = reinterpret_cast<uint64_t*>(h.state);
|
||||
StateView v;
|
||||
v.resv_head = reinterpret_cast<unsigned long long*>(&base[0]);
|
||||
v.ready_head = reinterpret_cast<unsigned long long*>(&base[1]);
|
||||
v.prod_idx = reinterpret_cast<unsigned long long*>(&base[2]);
|
||||
v.post_send_lock = reinterpret_cast<int*>(&base[3]);
|
||||
return v;
|
||||
}
|
||||
|
||||
// ---- WQE ring helpers -----------------------------------------------------
|
||||
__device__ static __forceinline__
|
||||
uint64_t reserve_wqe_slots(const IbgdaQpHandle& h, uint32_t num_wqes) {
|
||||
auto v = stateView(h);
|
||||
return atomicAdd(v.resv_head, static_cast<unsigned long long>(num_wqes));
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
void* get_wqe_ptr(const IbgdaQpHandle& h, uint16_t wqe_idx) {
|
||||
uint16_t mask = static_cast<uint16_t>(h.wqe_cnt - 1);
|
||||
uint16_t idx = wqe_idx & mask;
|
||||
return reinterpret_cast<uint8_t*>(h.sq_buf) +
|
||||
(static_cast<size_t>(idx) * h.stride);
|
||||
}
|
||||
|
||||
// ---- Doorbell record + BF ring -------------------------------------------
|
||||
__device__ static __forceinline__
|
||||
void update_dbr(const IbgdaQpHandle& h, uint32_t dbrec_head) {
|
||||
// BE32(dbrec_head & 0xFFFF) — see DeepEP/NVSHMEM.
|
||||
uint32_t v;
|
||||
asm("{\n\t"
|
||||
".reg .b32 lo16; .reg .b32 ig;\n\t"
|
||||
"and.b32 lo16, %1, 0xffff;\n\t"
|
||||
"prmt.b32 %0, lo16, ig, 0x123;\n\t"
|
||||
"}" : "=r"(v) : "r"(dbrec_head));
|
||||
// dbrec is the SQ slot directly (set by IbgdaResources).
|
||||
store_release_u32(reinterpret_cast<uint32_t*>(h.dbrec), v);
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
void ring_db(const IbgdaQpHandle& h, uint16_t prod_idx) {
|
||||
// 64-bit BF write of {opmod_idx_opcode, qpn_ds}.
|
||||
// opmod_idx_opcode = BE32(prod_idx << 8); qpn_ds = BE32(qpn << 8).
|
||||
uint32_t lo = HtoBE32(static_cast<uint32_t>(prod_idx) << 8);
|
||||
uint32_t hi = HtoBE32(h.qpn << 8);
|
||||
uint64_t v = static_cast<uint64_t>(lo) | (static_cast<uint64_t>(hi) << 32);
|
||||
store_release_u64(h.bf_reg, v);
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
void post_send(const IbgdaQpHandle& h, uint64_t new_prod_idx) {
|
||||
auto v = stateView(h);
|
||||
// CTA-level lock so concurrent posters serialise their DBR/BF writes.
|
||||
while (atomicCAS(v.post_send_lock, 0, 1) == 1) {}
|
||||
__threadfence();
|
||||
|
||||
unsigned long long old_prod_idx =
|
||||
atomicMax(v.prod_idx, static_cast<unsigned long long>(new_prod_idx));
|
||||
if (new_prod_idx > old_prod_idx) {
|
||||
update_dbr(h, static_cast<uint32_t>(new_prod_idx));
|
||||
ring_db(h, static_cast<uint16_t>(new_prod_idx));
|
||||
}
|
||||
|
||||
__threadfence();
|
||||
store_release_u32(reinterpret_cast<uint32_t*>(v.post_send_lock), 0);
|
||||
}
|
||||
|
||||
// ---- Submit (publish ready_head, optionally ring DB) ----------------------
|
||||
// Publish `num_wqes` previously written WQEs starting at `base_wqe_idx`:
// make the WQE stores globally visible, then advance ready_head strictly in
// reservation order. The doorbell is rung unconditionally when
// kAlwaysDoPostSend, otherwise only every kBatch-th message (amortising the
// doorbell MMIO across messages).
template <bool kAlwaysDoPostSend>
__device__ static __forceinline__
void submit_requests(const IbgdaQpHandle& h, uint64_t base_wqe_idx,
                     uint32_t num_wqes, int message_idx = 0) {
  auto v = stateView(h);
  uint64_t new_wqe_idx = base_wqe_idx + num_wqes;

  // The WQE writes themselves must be globally visible before we publish.
  __threadfence_system();

  // Wait for any earlier reservations to publish first, preserving order.
  auto base_ull = static_cast<unsigned long long>(base_wqe_idx);
  auto new_ull = static_cast<unsigned long long>(new_wqe_idx);
  while (atomicCAS(v.ready_head, base_ull, new_ull) != base_ull) {}

  // Skipped doorbells are eventually picked up by a later post_send on this
  // QP, since post_send advances prod_idx with atomicMax.
  constexpr int kBatch = 4;
  if (kAlwaysDoPostSend || ((message_idx + 1) % kBatch == 0)) {
    post_send(h, new_wqe_idx);
  }
}
|
||||
|
||||
// Publish-only variant: advance ready_head in WQE-issue order (preserving the
|
||||
// in-order chain that submit_requests builds) but DO NOT ring the doorbell.
|
||||
// Use this for the data WRs in a coalesced "unsignaled body + signaled tail"
|
||||
// pattern; the trailing WR's submit_requests<true> will atomicMax prod_idx
|
||||
// past all earlier WRs and ring the doorbell once for the whole batch.
|
||||
__device__ static __forceinline__
void submit_no_db(const IbgdaQpHandle& h, uint64_t base_wqe_idx,
                  uint32_t num_wqes) {
  auto v = stateView(h);
  uint64_t new_wqe_idx = base_wqe_idx + num_wqes;
  // WQE stores must be globally visible before ready_head is advanced.
  __threadfence_system();
  // In-order publish: spin until earlier reservations have published, then
  // swing ready_head across our slots (same chain as submit_requests).
  auto base_ull = static_cast<unsigned long long>(base_wqe_idx);
  auto new_ull = static_cast<unsigned long long>(new_wqe_idx);
  while (atomicCAS(v.ready_head, base_ull, new_ull) != base_ull) {}
  // No post_send: the next signaled WR on this QP (or any concurrent ringer)
  // will atomicMax prod_idx past us and the NIC will pick up our slot.
}
|
||||
|
||||
// ---- WQE writers ----------------------------------------------------------
|
||||
// All addresses BE-encoded inside; lkey/rkey expected pre-BE-swapped.
|
||||
// `signal_cqe`: if true, set CQ_UPDATE so the NIC posts a CQE for this WR;
|
||||
// if false, the WR is UNSIGNALED — useful for batched data WRs whose
|
||||
// completion is implicitly bounded by a later signaled WR on the same QP.
|
||||
// Write a 3-segment RDMA WRITE WQE (ctrl + raddr + data, ds=3 => 48B) into
// `wqe_slot`. Addresses are BE-encoded here; lkey_be/rkey_be must already be
// BE-swapped by the caller (see section comment above). With
// signal_cqe=false the WR is UNSIGNALED — no CQE is generated for it.
__device__ static __forceinline__
void write_rdma_write_wqe(const IbgdaQpHandle& h, void* wqe_slot,
                          uint64_t laddr, uint32_t lkey_be,
                          uint64_t raddr, uint32_t rkey_be,
                          uint32_t bytes, uint16_t wqe_idx,
                          bool signal_cqe = true) {
  // Segment layout within the slot: ctrl@0, raddr@16, data@32.
  auto* ctrl = reinterpret_cast<CtrlSeg*>(wqe_slot);
  auto* raddrp = reinterpret_cast<RaddrSeg*>(reinterpret_cast<uint8_t*>(wqe_slot) + 16);
  auto* datap = reinterpret_cast<DataSeg*>(reinterpret_cast<uint8_t*>(wqe_slot) + 32);

  CtrlSeg c{};
  // opmod_idx_opcode = BE32((wqe_idx << 8) | opcode).
  c.opmod_idx_opcode = HtoBE32((static_cast<uint32_t>(wqe_idx) << 8) |
                               MSCCLPP_IBGDA_OPCODE_RDMA_WRITE);
  // ds=3: three 16-byte segments follow.
  c.qpn_ds = HtoBE32((h.qpn << 8) | 3u);
  c.fm_ce_se = signal_cqe ? MSCCLPP_IBGDA_CTRL_CQ_UPDATE : 0u;
  c.imm = 0;

  RaddrSeg r{};
  r.raddr = HtoBE64(raddr);
  r.rkey = rkey_be;  // pre-BE-swapped by the caller
  r.reserved = 0;

  DataSeg d{};
  d.byte_count = HtoBE32(bytes);
  d.lkey = lkey_be;  // pre-BE-swapped by the caller
  d.addr = HtoBE64(laddr);

  // 16B vector stores; global visibility is established later by the
  // __threadfence_system() in submit_requests / submit_no_db.
  store_relaxed_int4(reinterpret_cast<int4*>(ctrl), *reinterpret_cast<int4*>(&c));
  store_relaxed_int4(reinterpret_cast<int4*>(raddrp), *reinterpret_cast<int4*>(&r));
  store_relaxed_int4(reinterpret_cast<int4*>(datap), *reinterpret_cast<int4*>(&d));
}
|
||||
|
||||
// One-shot helper: reserve a slot, write WQE, submit + ring doorbell.
|
||||
// Returns the wqe_idx of the issued WR (caller can match against CQE wr_id
|
||||
// if it embeds the index there; here we don't use wr_id).
|
||||
// Convenience one-shot RDMA WRITE: reserve one slot, fill it, publish.
// Returns the wqe_idx of the issued WR (caller can match against CQE wr_id
// if it embeds the index there; here we don't use wr_id).
__device__ static __forceinline__
uint64_t rdma_write(const IbgdaQpHandle& h,
                    uint64_t laddr, uint32_t lkey_be,
                    uint64_t raddr, uint32_t rkey_be,
                    uint32_t bytes,
                    bool signal_cqe = true, bool ring_db = true) {
  const uint64_t slot_idx = reserve_wqe_slots(h, 1);
  const uint16_t idx16 = static_cast<uint16_t>(slot_idx);
  write_rdma_write_wqe(h, get_wqe_ptr(h, idx16), laddr, lkey_be,
                       raddr, rkey_be, bytes, idx16, signal_cqe);
  if (!ring_db) {
    submit_no_db(h, slot_idx, 1);
  } else {
    submit_requests<true>(h, slot_idx, 1);
  }
  return slot_idx;
}
|
||||
|
||||
// ---- Inline 4-byte RDMA WRITE --------------------------------------------
|
||||
// Single WQEBB carrying (ctrl=16) + (raddr=16) + (inline_hdr=4 + data=4)
|
||||
// padded to 48B. ds=3 in ctrl_seg.
|
||||
//
|
||||
// MLX5_INLINE_SEG = 0x80000000 in inline_seg.byte_count signals "inline".
|
||||
//
|
||||
// The 4-byte payload `value` is written verbatim to remote `raddr`. The
|
||||
// caller does NOT need a local MR for this op.
|
||||
#ifndef MSCCLPP_IBGDA_INLINE_SEG
|
||||
#define MSCCLPP_IBGDA_INLINE_SEG 0x80000000u
|
||||
#endif
|
||||
|
||||
// Write a single-WQEBB inline 4-byte RDMA WRITE WQE (see section comment
// above for the layout rationale). No local MR is needed: the payload
// travels inside the WQE itself.
__device__ static __forceinline__
void write_rdma_write_inl4_wqe(const IbgdaQpHandle& h, void* wqe_slot,
                               uint32_t value,
                               uint64_t raddr, uint32_t rkey_be,
                               uint16_t wqe_idx,
                               bool signal_cqe = true) {
  // Segment layout: ctrl@0, raddr@16, inline header@32, 4B payload@36.
  auto* base8 = reinterpret_cast<uint8_t*>(wqe_slot);
  auto* ctrl = reinterpret_cast<CtrlSeg*>(base8 + 0);
  auto* raddrp = reinterpret_cast<RaddrSeg*>(base8 + 16);
  // inline header is a 4-byte field (byte_count) at offset 32, followed by
  // 4 bytes of payload at offset 36.
  uint32_t* inl_hdr = reinterpret_cast<uint32_t*>(base8 + 32);
  uint32_t* inl_dat = reinterpret_cast<uint32_t*>(base8 + 36);

  CtrlSeg c{};
  c.opmod_idx_opcode = HtoBE32((static_cast<uint32_t>(wqe_idx) << 8) |
                               MSCCLPP_IBGDA_OPCODE_RDMA_WRITE);
  // ds=3 (3 * 16B = 48B; ctrl + raddr + 8B of inline hdr+data, padded).
  c.qpn_ds = HtoBE32((h.qpn << 8) | 3u);
  c.fm_ce_se = signal_cqe ? MSCCLPP_IBGDA_CTRL_CQ_UPDATE : 0u;
  c.imm = 0;

  RaddrSeg r{};
  r.raddr = HtoBE64(raddr);
  r.rkey = rkey_be;  // pre-BE-swapped by the caller
  r.reserved = 0;

  store_relaxed_int4(reinterpret_cast<int4*>(ctrl), *reinterpret_cast<int4*>(&c));
  store_relaxed_int4(reinterpret_cast<int4*>(raddrp), *reinterpret_cast<int4*>(&r));
  // inline byte_count: 4 | INLINE_SEG, BE32.
  store_relaxed_u32(inl_hdr, HtoBE32(4u | MSCCLPP_IBGDA_INLINE_SEG));
  // The 4-byte payload is written as-is in network order — the NIC treats
  // the inline data block as opaque bytes copied to remote memory.
  store_relaxed_u32(inl_dat, value);
}
|
||||
|
||||
// One-shot inline 4B RDMA WRITE: reserve a slot, fill it, publish.
// Returns the wqe_idx of the issued WR.
__device__ static __forceinline__
uint64_t rdma_write_inl4(const IbgdaQpHandle& h,
                         uint32_t value,
                         uint64_t raddr, uint32_t rkey_be,
                         bool signal_cqe = true, bool ring_db = true) {
  const uint64_t slot_idx = reserve_wqe_slots(h, 1);
  const uint16_t idx16 = static_cast<uint16_t>(slot_idx);
  write_rdma_write_inl4_wqe(h, get_wqe_ptr(h, idx16), value, raddr, rkey_be,
                            idx16, signal_cqe);
  if (!ring_db) {
    submit_no_db(h, slot_idx, 1);
  } else {
    submit_requests<true>(h, slot_idx, 1);
  }
  return slot_idx;
}
|
||||
|
||||
// ---- Inline 8-byte RDMA WRITE --------------------------------------------
|
||||
// Same WQE shape as inl4 (ds=3 / 48B): ctrl(16) + raddr(16) + inline_hdr(4)
|
||||
// + payload(8) + pad(4). Used to publish 8B counters (e.g. dispatch
|
||||
// per-(local_expert, src_rank) recv_count slot polled with int64 acquire).
|
||||
// Inline 8-byte variant of write_rdma_write_inl4_wqe (see section comment
// above): same ds=3 / 48B shape, with the payload split across two 32-bit
// words at offsets 36 and 40.
__device__ static __forceinline__
void write_rdma_write_inl8_wqe(const IbgdaQpHandle& h, void* wqe_slot,
                               uint64_t value,
                               uint64_t raddr, uint32_t rkey_be,
                               uint16_t wqe_idx,
                               bool signal_cqe = true) {
  // Segment layout: ctrl@0, raddr@16, inline header@32, payload@36..44.
  auto* base8 = reinterpret_cast<uint8_t*>(wqe_slot);
  auto* ctrl = reinterpret_cast<CtrlSeg*>(base8 + 0);
  auto* raddrp = reinterpret_cast<RaddrSeg*>(base8 + 16);
  uint32_t* inl_hdr = reinterpret_cast<uint32_t*>(base8 + 32);
  // Two 32-bit halves of the 8B payload at offsets 36 / 40.
  uint32_t* inl_dat_lo = reinterpret_cast<uint32_t*>(base8 + 36);
  uint32_t* inl_dat_hi = reinterpret_cast<uint32_t*>(base8 + 40);

  CtrlSeg c{};
  c.opmod_idx_opcode = HtoBE32((static_cast<uint32_t>(wqe_idx) << 8) |
                               MSCCLPP_IBGDA_OPCODE_RDMA_WRITE);
  // ds=3: ctrl + raddr + inline hdr/payload, padded to 48B.
  c.qpn_ds = HtoBE32((h.qpn << 8) | 3u);
  c.fm_ce_se = signal_cqe ? MSCCLPP_IBGDA_CTRL_CQ_UPDATE : 0u;
  c.imm = 0;

  RaddrSeg r{};
  r.raddr = HtoBE64(raddr);
  r.rkey = rkey_be;  // pre-BE-swapped by the caller
  r.reserved = 0;

  store_relaxed_int4(reinterpret_cast<int4*>(ctrl), *reinterpret_cast<int4*>(&c));
  store_relaxed_int4(reinterpret_cast<int4*>(raddrp), *reinterpret_cast<int4*>(&r));
  // inline byte_count: 8 | INLINE_SEG, BE32.
  store_relaxed_u32(inl_hdr, HtoBE32(8u | MSCCLPP_IBGDA_INLINE_SEG));
  // 8B payload as two 32-bit words; treated as opaque bytes by the NIC.
  store_relaxed_u32(inl_dat_lo, static_cast<uint32_t>(value & 0xffffffffull));
  store_relaxed_u32(inl_dat_hi, static_cast<uint32_t>(value >> 32));
}
|
||||
|
||||
// One-shot inline 8B RDMA WRITE: reserve a slot, fill it, publish.
// Returns the wqe_idx of the issued WR.
__device__ static __forceinline__
uint64_t rdma_write_inl8(const IbgdaQpHandle& h,
                         uint64_t value,
                         uint64_t raddr, uint32_t rkey_be,
                         bool signal_cqe = true, bool ring_db = true) {
  const uint64_t slot_idx = reserve_wqe_slots(h, 1);
  const uint16_t idx16 = static_cast<uint16_t>(slot_idx);
  write_rdma_write_inl8_wqe(h, get_wqe_ptr(h, idx16), value, raddr, rkey_be,
                            idx16, signal_cqe);
  if (!ring_db) {
    submit_no_db(h, slot_idx, 1);
  } else {
    submit_requests<true>(h, slot_idx, 1);
  }
  return slot_idx;
}
|
||||
|
||||
// ---- Warp-coalesced burst (lane 0 issues N WRs as a single batch) --------
|
||||
// This is the "warp granularity" pattern used by NVSHMEM IBGDA: a single
|
||||
// thread reserves N contiguous slots (one atomic), writes N WQEs in a tight
|
||||
// loop, then publishes ready_head and rings the doorbell exactly once. This
|
||||
// amortises the ready_head CAS-spin and the BF doorbell MMIO across N WRs.
|
||||
//
|
||||
// Caller is responsible for ensuring that only one thread per warp invokes
|
||||
// this with a given (laddr/raddr) layout — typically `if (lane_id == 0)`.
|
||||
// All N WRs share the same lkey_be/rkey_be and have stride `bytes` (i.e. a
|
||||
// contiguous chunk laddr_base..laddr_base+N*bytes -> raddr_base..).
|
||||
__device__ static __forceinline__
uint64_t rdma_write_strided_burst(const IbgdaQpHandle& h,
                                  uint64_t laddr_base, uint32_t lkey_be,
                                  uint64_t raddr_base, uint32_t rkey_be,
                                  uint32_t bytes, uint32_t num_wrs) {
  // Empty burst: 0 is a dummy return; callers must not interpret it as a
  // valid base index when num_wrs == 0.
  if (num_wrs == 0) return 0;
  // One atomic reserves all N contiguous slots.
  uint64_t base = reserve_wqe_slots(h, num_wrs);
  for (uint32_t i = 0; i < num_wrs; ++i) {
    uint16_t idx = static_cast<uint16_t>(base + i);
    void* slot = get_wqe_ptr(h, idx);
    // NOTE(review): each WR here takes write_rdma_write_wqe's default
    // signal_cqe=true, so the NIC posts one CQE per WR. If a single CQE
    // per burst is intended, only the last WR should be signaled — confirm
    // against the send-CQ poller's accounting.
    write_rdma_write_wqe(h, slot,
                         laddr_base + size_t(i) * bytes, lkey_be,
                         raddr_base + size_t(i) * bytes, rkey_be,
                         bytes, idx);
  }
  // Single ready_head publish + doorbell for the whole batch.
  submit_requests<true>(h, base, num_wrs);
  return base;
}
|
||||
|
||||
} // namespace ibgda
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_IBGDA_DEVICE_CUH_
|
||||
100
src/core/include/ibgda_port_channel_device.cuh
Normal file
100
src/core/include/ibgda_port_channel_device.cuh
Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
//
|
||||
// Stage 3: device-side methods that act on `IbgdaPortChannelDeviceHandle`.
|
||||
// Split from the POD header so that the POD can be included by host code
|
||||
// (which doesn't want PTX inline asm or CUDA-only types) while these inline
|
||||
// device functions only appear in .cu translation units.
|
||||
|
||||
#ifndef MSCCLPP_IBGDA_PORT_CHANNEL_DEVICE_CUH_
|
||||
#define MSCCLPP_IBGDA_PORT_CHANNEL_DEVICE_CUH_
|
||||
|
||||
#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM)
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <mscclpp/ibgda_port_channel_device.hpp>
|
||||
|
||||
#include "ibgda_device.cuh"
|
||||
|
||||
namespace mscclpp {
|
||||
namespace ibgda {
|
||||
|
||||
// ---------------- put -----------------------------------------------------
|
||||
// `signal_cqe` / `ring_db` default to true (preserves prior call sites);
|
||||
// pass false/false for batched data WRs whose completion is implicitly
|
||||
// bounded by a later signaled WR on the same QP (e.g., the trailing count
|
||||
// or flag write at the end of LL dispatch / combine).
|
||||
// Copy `size` bytes from local MR `src` (+srcOffset) to the peer's MR
// `dst` (+dstOffset) as one RDMA WRITE. See section comment above for the
// signal_cqe / ring_db batching semantics.
__device__ static __forceinline__
void port_put(const IbgdaPortChannelDeviceHandle& ch,
              IbgdaMemoryId dst, uint64_t dstOffset,
              IbgdaMemoryId src, uint64_t srcOffset,
              uint64_t size,
              bool signal_cqe = true, bool ring_db = true) {
  const IbgdaLocalMr local = ch.local_mrs[src];
  const IbgdaRemoteMr remote = ch.remote_mrs[dst];
  const uint64_t src_va = local.addr + srcOffset;
  const uint64_t dst_va = remote.addr + dstOffset;
  rdma_write(ch.qp, src_va, local.lkey_be, dst_va, remote.rkey_be,
             static_cast<uint32_t>(size), signal_cqe, ring_db);
}
|
||||
|
||||
// Convenience overload: use the channel's default src/dst memory ids.
__device__ static __forceinline__
void port_put(const IbgdaPortChannelDeviceHandle& ch,
              uint64_t dstOffset, uint64_t srcOffset, uint64_t size,
              bool signal_cqe = true, bool ring_db = true) {
  port_put(ch, ch.dst, dstOffset, ch.src, srcOffset, size,
           signal_cqe, ring_db);
}
|
||||
|
||||
// Convenience overload: mirrored layout, same offset on both sides.
__device__ static __forceinline__
void port_put(const IbgdaPortChannelDeviceHandle& ch,
              uint64_t offset, uint64_t size,
              bool signal_cqe = true, bool ring_db = true) {
  port_put(ch, /*dstOffset=*/offset, /*srcOffset=*/offset, size,
           signal_cqe, ring_db);
}
|
||||
|
||||
// ---------------- signal --------------------------------------------------
|
||||
// Increments the local sequence counter (system-wide atomic so multiple
|
||||
// CTAs serialise) and writes the new value as an inline 4B RDMA WRITE to
|
||||
// the peer's signal slot.
|
||||
__device__ static __forceinline__
uint32_t port_signal(const IbgdaPortChannelDeviceHandle& ch) {
  // Device-scope atomicAdd keeps the counter consistent across CTAs sharing
  // this handle (rare but possible). NOTE(review): if the host (or another
  // device) ever reads/writes ch.sig_seq concurrently, this would need
  // atomicAdd_system — confirm; the original comment claimed _system but
  // the code is device-scope.
  uint32_t prev = atomicAdd(ch.sig_seq, 1u);
  uint32_t v = prev + 1u;
  // Publish the incremented value to the peer's signal slot as an inline
  // 4B RDMA WRITE (signaled + doorbell: rdma_write_inl4's defaults).
  rdma_write_inl4(ch.qp, v, ch.sig_remote_addr, ch.sig_rkey_be);
  return v;
}
|
||||
|
||||
// ---------------- wait ----------------------------------------------------
|
||||
// Returns when the local signal slot's value reaches `expected` (sticky;
|
||||
// callers track the expected counter externally — same pattern as
|
||||
// PortChannelDeviceHandle::wait()).
|
||||
__device__ static __forceinline__
void port_wait(const IbgdaPortChannelDeviceHandle& ch, uint32_t expected,
               int64_t maxSpinCount = 10000000) {
  volatile uint32_t* p = ch.sig_local_addr;
  // A negative maxSpinCount means "spin forever".
  if (maxSpinCount < 0) {
    while (*p < expected) { /* spin */ }
    return;
  }
  // Bounded spin. NOTE(review): the </>= comparisons are not wrap-safe; a
  // uint32 sequence that wraps past `expected` would stall here — confirm
  // the counter cannot wrap within a run.
  for (int64_t i = 0; i < maxSpinCount; ++i) {
    if (*p >= expected) return;
  }
  // Out of patience — caller can detect with a follow-up poll(); we do NOT
  // assert here to keep the device path lean.
}
|
||||
|
||||
// Non-blocking check of the local signal slot: true once it has reached
// `expected` (same sticky-counter semantics as port_wait).
__device__ static __forceinline__
bool port_poll(const IbgdaPortChannelDeviceHandle& ch, uint32_t expected) {
  // Volatile read through the same access pattern as port_wait; replaces
  // the original's C-style cast.
  volatile uint32_t* p = ch.sig_local_addr;
  return *p >= expected;
}
|
||||
|
||||
} // namespace ibgda
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // USE_IBVERBS && MSCCLPP_USE_MLX5DV && !MSCCLPP_USE_ROCM
|
||||
|
||||
#endif // MSCCLPP_IBGDA_PORT_CHANNEL_DEVICE_CUH_
|
||||
@@ -28,6 +28,11 @@ struct MLX5DV {
|
||||
/// Returns 0 on success (device supports Data Direct), non-zero otherwise.
|
||||
static int mlx5dv_get_data_direct_sysfs_path(struct ibv_context* context, char* buf, size_t buf_len);
|
||||
|
||||
/// Wraps mlx5dv_init_obj(MLX5DV_OBJ_QP). Returns 0 on success.
|
||||
/// `out` must be a pointer to an mlx5dv_qp; we keep this typeless to avoid
|
||||
/// pulling <infiniband/mlx5dv.h> into the public-ish wrapper header.
|
||||
static int mlx5dv_init_obj_qp(struct ibv_qp* qp, void* out);
|
||||
|
||||
private:
|
||||
static void* dlsym(const std::string& symbol, bool allowReturnNull = false);
|
||||
};
|
||||
|
||||
@@ -121,6 +121,20 @@ int MLX5DV::mlx5dv_get_data_direct_sysfs_path(struct ibv_context* context, char*
|
||||
return impl(context, buf, buf_len);
|
||||
}
|
||||
|
||||
int MLX5DV::mlx5dv_init_obj_qp(struct ibv_qp* qp, void* out_qp) {
|
||||
using FuncType = int (*)(struct mlx5dv_obj*, uint64_t);
|
||||
static FuncType impl = nullptr;
|
||||
if (!impl) {
|
||||
void* ptr = MLX5DV::dlsym("mlx5dv_init_obj", /*allowReturnNull=*/true);
|
||||
if (!ptr) return -1;
|
||||
impl = reinterpret_cast<FuncType>(ptr);
|
||||
}
|
||||
struct mlx5dv_obj obj{};
|
||||
obj.qp.in = qp;
|
||||
obj.qp.out = static_cast<struct mlx5dv_qp*>(out_qp);
|
||||
return impl(&obj, MLX5DV_OBJ_QP);
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // defined(MSCCLPP_USE_MLX5DV)
|
||||
|
||||
@@ -3,12 +3,75 @@
|
||||
|
||||
#include <mscclpp/numa.hpp>
|
||||
#include <mscclpp/port_channel.hpp>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "api.h"
|
||||
#include "connection.hpp"
|
||||
#include "debug.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
// Lightweight diagnostic counters kept in a side-table keyed by ProxyService*
|
||||
// so the ProxyService class layout stays ABI-compatible with prebuilt
|
||||
// extensions (e.g. mscclpp/_mscclpp.cpython-*.so) that were compiled against
|
||||
// the older header. Populated only when MSCCLPP_PROXY_STATS=1.
|
||||
namespace {
|
||||
// Per-ProxyService diagnostic counters (see the side-table comment above).
struct ProxyStats {
  bool enabled = false;     // MSCCLPP_PROXY_STATS=1 was set for this service
  bool printed = false;     // summary already emitted (print-once guard)
  uint64_t triggers = 0;    // total triggers handled
  uint64_t trigData = 0;    // triggers with TriggerData set
  uint64_t trigFlag = 0;    // triggers with TriggerFlag set
  uint64_t trigAtomic = 0;  // type==0 triggers (atomic add)
  uint64_t trigSync = 0;    // triggers with TriggerSync set
  uint64_t postCalls = 0;   // explicit postPending() batches from the handler
  uint64_t idleDrains = 0;  // staged batches drained from the idle callback
};
|
||||
// Process-wide flag: 0 unless any ProxyService has stats enabled. Avoids
|
||||
// taking the side-table mutex on the proxy hot path when nobody asked for
|
||||
// stats.
|
||||
// Process-wide flag: set once any ProxyService enables stats; checked
// lock-free on the hot path before touching the side table.
std::atomic<bool>& statsAnyEnabled() {
  static std::atomic<bool> anyEnabled{false};
  return anyEnabled;
}
|
||||
// Mutex guarding the stats side table.
std::mutex& statsMu() {
  static std::mutex mu;
  return mu;
}
|
||||
// Side table mapping each ProxyService to its stats. All access must hold
// statsMu() (getStats / printAndEraseStats below do so).
std::unordered_map<const ProxyService*, ProxyStats>& statsTable() {
  static std::unordered_map<const ProxyService*, ProxyStats> t;
  return t;
}
|
||||
// Look up the stats slot for `self`, or nullptr when stats are disabled.
// Fast path: a relaxed atomic check skips the mutex entirely when no
// service has stats enabled.
ProxyStats* getStats(const ProxyService* self) {
  if (!statsAnyEnabled().load(std::memory_order_relaxed)) return nullptr;
  std::lock_guard<std::mutex> lock(statsMu());
  auto& table = statsTable();
  auto entry = table.find(self);
  return entry == table.end() ? nullptr : &entry->second;
}
|
||||
// Emit the one-shot stats summary for `self` (if enabled and not yet
// printed) and drop its side-table entry. Called from stopProxy().
void printAndEraseStats(const ProxyService* self) {
  std::lock_guard<std::mutex> lk(statsMu());
  auto it = statsTable().find(self);
  if (it == statsTable().end()) return;
  ProxyStats& s = it->second;
  if (s.enabled && !s.printed) {
    s.printed = true;
    // Idle drains also post staged WRs, so they count as posts in the
    // triggers-per-post ratio.
    uint64_t posts = s.postCalls + s.idleDrains;
    double wrPerPost = posts ? static_cast<double>(s.triggers) / static_cast<double>(posts) : 0.0;
    fprintf(stderr,
            "[mscclpp proxy stats] triggers=%llu (data=%llu flag=%llu atomic=%llu sync=%llu) "
            "postCalls=%llu idleDrains=%llu triggersPerPost=%.2f\n",
            (unsigned long long)s.triggers, (unsigned long long)s.trigData,
            (unsigned long long)s.trigFlag, (unsigned long long)s.trigAtomic,
            (unsigned long long)s.trigSync, (unsigned long long)s.postCalls,
            (unsigned long long)s.idleDrains, wrPerPost);
    fflush(stderr);
  }
  // Erase unconditionally so a re-created service at the same address
  // starts with a fresh entry.
  statsTable().erase(it);
}
|
||||
} // namespace
|
||||
|
||||
MSCCLPP_API_CPP BasePortChannel::BasePortChannel(SemaphoreId semaphoreId,
|
||||
std::shared_ptr<Host2DeviceSemaphore> semaphore,
|
||||
std::shared_ptr<Proxy> proxy)
|
||||
@@ -30,6 +93,15 @@ MSCCLPP_API_CPP ProxyService::ProxyService(int fifoSize) {
|
||||
int cudaDevice;
|
||||
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
|
||||
int deviceNumaNode = getDeviceNumaNode(cudaDevice);
|
||||
if (const char* env = std::getenv("MSCCLPP_PROXY_STATS")) {
|
||||
if (std::atoi(env) > 0) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(statsMu());
|
||||
statsTable()[this].enabled = true;
|
||||
}
|
||||
statsAnyEnabled().store(true, std::memory_order_relaxed);
|
||||
}
|
||||
}
|
||||
auto initFunc = [cudaDevice, deviceNumaNode]() {
|
||||
MSCCLPP_CUDATHROW(cudaSetDevice(cudaDevice));
|
||||
if (deviceNumaNode >= 0) {
|
||||
@@ -39,6 +111,21 @@ MSCCLPP_API_CPP ProxyService::ProxyService(int fifoSize) {
|
||||
};
|
||||
auto handlerFunc = [&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); };
|
||||
proxy_ = std::make_shared<Proxy>(handlerFunc, initFunc, fifoSize);
|
||||
// Drain any deferred ibv_post_send work the moment the FIFO empties out, so
|
||||
// we batch as many triggers as possible into a single syscall while still
|
||||
// keeping latency within one FIFO drain.
|
||||
proxy_->setOnIdle([this]() { this->postPendingAll(); });
|
||||
}
|
||||
|
||||
void ProxyService::postPendingAll() {
|
||||
if (!stagedConns_.empty()) {
|
||||
if (auto* s = getStats(this); s && s->enabled) s->idleDrains++;
|
||||
}
|
||||
for (auto& kv : stagedConns_) {
|
||||
kv.first->postPending();
|
||||
}
|
||||
stagedConns_.clear();
|
||||
dirtyConns_.clear();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& communicator,
|
||||
@@ -84,14 +171,45 @@ MSCCLPP_API_CPP PortChannel ProxyService::portChannel(SemaphoreId id, MemoryId d
|
||||
|
||||
MSCCLPP_API_CPP void ProxyService::startProxy(bool blocking) { proxy_->start(blocking); }
|
||||
|
||||
MSCCLPP_API_CPP void ProxyService::stopProxy() { proxy_->stop(); }
|
||||
MSCCLPP_API_CPP void ProxyService::stopProxy() {
|
||||
proxy_->stop();
|
||||
printAndEraseStats(this);
|
||||
}
|
||||
|
||||
ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger trigger) {
|
||||
std::shared_ptr<Host2DeviceSemaphore> semaphore = semaphores_[trigger.fields.semaphoreId];
|
||||
|
||||
ProxyStats* stats = getStats(this);
|
||||
if (stats && stats->enabled) {
|
||||
stats->triggers++;
|
||||
if (trigger.fields.type == 0) stats->trigAtomic++;
|
||||
if (trigger.fields.type & TriggerData) stats->trigData++;
|
||||
if (trigger.fields.type & TriggerFlag) stats->trigFlag++;
|
||||
if (trigger.fields.type & TriggerSync) stats->trigSync++;
|
||||
}
|
||||
|
||||
auto& conn = semaphore->connection();
|
||||
int maxWriteQueueSize = conn.getMaxWriteQueueSize();
|
||||
auto& numRequests = inflightRequests_[conn.impl_];
|
||||
auto connImpl = BaseConnection::getImpl(conn);
|
||||
auto& numRequests = inflightRequests_[connImpl];
|
||||
|
||||
// Batch threshold: how many staged WRs we let accumulate on one connection
|
||||
// before forcing an ibv_post_send. Lower values reduce tail latency for
|
||||
// signalling triggers; higher values reduce syscalls.
|
||||
// Override via MSCCLPP_PROXY_BATCH_THRESHOLD (1 = post every trigger,
|
||||
// matching the historical behavior; the empirical sweep on H100/IB shows
|
||||
// LL traffic is wire-latency bound rather than syscall bound, so deeper
|
||||
// batching does not help and may regress dispatch tail latency).
|
||||
static const int kPostBatchThreshold = []() {
|
||||
if (const char* env = std::getenv("MSCCLPP_PROXY_BATCH_THRESHOLD")) {
|
||||
int v = std::atoi(env);
|
||||
if (v >= 1) return v;
|
||||
}
|
||||
return 1;
|
||||
}();
|
||||
bool stagedWork = false;
|
||||
// Atomic/signal triggers convey completion semantics; never defer them.
|
||||
bool isSignal = (trigger.fields.type == 0) || (trigger.fields.type & TriggerFlag);
|
||||
|
||||
if (trigger.fields.type == 0) {
|
||||
// type == 0 indicates an atomic add operation.
|
||||
@@ -100,6 +218,7 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger trigger) {
|
||||
int64_t value = static_cast<int64_t>(trigger.fst);
|
||||
conn.atomicAdd(dst, trigger.fields.dstOffset, value);
|
||||
numRequests++;
|
||||
stagedWork = true;
|
||||
}
|
||||
|
||||
if (trigger.fields.type & TriggerData) {
|
||||
@@ -107,17 +226,38 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger trigger) {
|
||||
RegisteredMemory& src = memories_[trigger.fields.srcMemoryId];
|
||||
conn.write(dst, trigger.fields.dstOffset, src, trigger.fields.srcOffset, trigger.fields.size);
|
||||
numRequests++;
|
||||
stagedWork = true;
|
||||
}
|
||||
|
||||
if (trigger.fields.type & TriggerFlag) {
|
||||
semaphore->signal();
|
||||
numRequests++;
|
||||
stagedWork = true;
|
||||
}
|
||||
|
||||
if (((trigger.fields.type & TriggerSync) && numRequests > 0) ||
|
||||
(maxWriteQueueSize != -1 && numRequests >= maxWriteQueueSize)) {
|
||||
if (stagedWork) {
|
||||
stagedConns_[connImpl]++;
|
||||
dirtyConns_.insert(connImpl);
|
||||
}
|
||||
|
||||
bool needFlush = (trigger.fields.type & TriggerSync) && numRequests > 0;
|
||||
bool needFlush2 = maxWriteQueueSize != -1 && numRequests >= maxWriteQueueSize;
|
||||
if (needFlush || needFlush2) {
|
||||
// flush() drains staged WRs and waits for completion.
|
||||
conn.flush();
|
||||
numRequests = 0;
|
||||
stagedConns_.erase(connImpl);
|
||||
dirtyConns_.erase(connImpl);
|
||||
} else if (isSignal || stagedConns_[connImpl] >= kPostBatchThreshold) {
|
||||
// Post the staged batch now: either we just queued a completion-signal
|
||||
// WR (don't make the receiver wait for the next idle drain), or we've
|
||||
// accumulated enough WRs that the QP send queue might overflow.
|
||||
if (stats && stats->enabled) {
|
||||
stats->postCalls++;
|
||||
}
|
||||
connImpl->postPending();
|
||||
stagedConns_.erase(connImpl);
|
||||
dirtyConns_.erase(connImpl);
|
||||
}
|
||||
|
||||
return ProxyHandlerResult::Continue;
|
||||
|
||||
@@ -20,6 +20,7 @@ constexpr int ProxyStartWarnPeriod = 1000;
|
||||
struct Proxy::Impl {
|
||||
ProxyHandler handler;
|
||||
std::function<void()> threadInit;
|
||||
std::function<void()> onIdle;
|
||||
std::shared_ptr<Fifo> fifo;
|
||||
std::atomic_bool threadStarted;
|
||||
std::thread service;
|
||||
@@ -28,6 +29,7 @@ struct Proxy::Impl {
|
||||
Impl(ProxyHandler handler, std::function<void()> threadInit, int fifoSize)
|
||||
: handler(handler),
|
||||
threadInit(threadInit),
|
||||
onIdle(nullptr),
|
||||
fifo(std::make_shared<Fifo>(fifoSize)),
|
||||
threadStarted(false),
|
||||
running(false) {}
|
||||
@@ -71,9 +73,11 @@ MSCCLPP_API_CPP void Proxy::start(bool blocking) {
|
||||
pimpl_->threadStarted.store(true, std::memory_order_release);
|
||||
|
||||
ProxyHandler handler = this->pimpl_->handler;
|
||||
auto onIdle = this->pimpl_->onIdle;
|
||||
auto fifo = this->pimpl_->fifo;
|
||||
ProxyTrigger trigger;
|
||||
|
||||
bool wasBusy = false;
|
||||
int runCnt = ProxyStopCheckPeriod;
|
||||
for (;;) {
|
||||
if (runCnt-- == 0) {
|
||||
@@ -85,8 +89,13 @@ MSCCLPP_API_CPP void Proxy::start(bool blocking) {
|
||||
// Poll to see if we are ready to send anything
|
||||
trigger = fifo->poll();
|
||||
if (trigger.fst == 0 || trigger.snd == 0) { // TODO: this check is a potential pitfall for custom triggers
|
||||
if (wasBusy && onIdle) {
|
||||
onIdle();
|
||||
}
|
||||
wasBusy = false;
|
||||
continue; // there is one in progress
|
||||
}
|
||||
wasBusy = true;
|
||||
trigger.snd ^= (uint64_t{1} << uint64_t{63}); // this is where the last bit of snd is reverted.
|
||||
|
||||
ProxyHandlerResult result = handler(trigger);
|
||||
@@ -95,6 +104,7 @@ MSCCLPP_API_CPP void Proxy::start(bool blocking) {
|
||||
fifo->pop();
|
||||
|
||||
if (result == ProxyHandlerResult::Stop) {
|
||||
if (onIdle) onIdle();
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -123,4 +133,6 @@ MSCCLPP_API_CPP void Proxy::stop() {
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<Fifo> Proxy::fifo() { return pimpl_->fifo; }
|
||||
|
||||
MSCCLPP_API_CPP void Proxy::setOnIdle(std::function<void()> onIdle) { pimpl_->onIdle = std::move(onIdle); }
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -42,6 +42,7 @@ endif()
|
||||
file(GLOB_RECURSE EP_SOURCES CONFIGURE_DEPENDS
|
||||
buffer.cc
|
||||
bindings.cpp
|
||||
ibgda_setup.cc
|
||||
kernels/*.cu
|
||||
)
|
||||
|
||||
@@ -58,6 +59,17 @@ endif()
|
||||
Python_add_library(mscclpp_ep_cpp MODULE ${EP_SOURCES})
|
||||
|
||||
target_compile_definitions(mscclpp_ep_cpp PRIVATE TORCH_EXTENSION_NAME=mscclpp_ep_cpp)
|
||||
# Inherit ibverbs / mlx5dv defs from the core library so the IBGDA host-side
|
||||
# plumbing (see ibgda_setup.cc, gated by MSCCLPP_EP_HAVE_IBGDA) compiles in.
|
||||
if(IBVERBS_FOUND)
|
||||
target_compile_definitions(mscclpp_ep_cpp PRIVATE USE_IBVERBS)
|
||||
target_link_libraries(mscclpp_ep_cpp PRIVATE ${IBVERBS_LIBRARIES})
|
||||
if(MLX5_FOUND)
|
||||
target_compile_definitions(mscclpp_ep_cpp PRIVATE MSCCLPP_USE_MLX5DV)
|
||||
target_include_directories(mscclpp_ep_cpp SYSTEM PRIVATE ${MLX5_INCLUDE_DIRS})
|
||||
target_link_libraries(mscclpp_ep_cpp PRIVATE ${MLX5_LIBRARIES})
|
||||
endif()
|
||||
endif()
|
||||
target_include_directories(mscclpp_ep_cpp PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
|
||||
@@ -15,6 +15,10 @@
|
||||
#include "kernels/api.cuh"
|
||||
#include "kernels/configs.cuh"
|
||||
|
||||
#ifdef MSCCLPP_EP_HAVE_IBGDA
|
||||
#include "ibgda_setup.hpp"
|
||||
#endif
|
||||
|
||||
namespace mscclpp {
|
||||
namespace ep {
|
||||
|
||||
@@ -76,6 +80,15 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
|
||||
printf("[mscclpp_ep] num_proxy_services=%d (set MSCCLPP_EP_NUM_PROXIES to override)\n", num_proxy_services);
|
||||
fflush(stdout);
|
||||
}
|
||||
// Resolve native-IBGDA opt-in flag. The state is built in sync() and
|
||||
// currently unused by the kernels (4b.1 plumbing only); see buffer.hpp.
|
||||
if (const char* e = std::getenv("MSCCLPP_EP_USE_IBGDA")) {
|
||||
use_ibgda_path_ = (std::atoi(e) != 0);
|
||||
}
|
||||
if (rank == 0) {
|
||||
printf("[mscclpp_ep] MSCCLPP_EP_USE_IBGDA=%d\n", use_ibgda_path_ ? 1 : 0);
|
||||
fflush(stdout);
|
||||
}
|
||||
// Task fifo memory
|
||||
int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS;
|
||||
int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS;
|
||||
@@ -518,6 +531,44 @@ void Buffer::sync(const std::vector<int>& device_ids,
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef MSCCLPP_EP_HAVE_IBGDA
|
||||
// ------------------------------------------------------------------
|
||||
// Stage 4b.1: native IBGDA host-side plumbing.
|
||||
//
|
||||
// Built only when:
|
||||
// - MSCCLPP_EP_USE_IBGDA=1
|
||||
// - The run is cross-node (num_rdma_ranks > 1) — intra-node sticks with
|
||||
// the NVLink IPC fast path already established above.
|
||||
// - We have an RDMA buffer (num_rdma_bytes > 0).
|
||||
//
|
||||
// The setup is built but NOT yet consumed by any kernel (4b.2 will add
|
||||
// the kernel branch). With or without the flag set, existing test
|
||||
// outcomes must be identical.
|
||||
// ------------------------------------------------------------------
|
||||
if (use_ibgda_path_ && num_rdma_bytes > 0 && num_rdma_ranks > 1) {
|
||||
constexpr int kNumIbgdaChannels = 16; // mirrors num_port_channels_per_rank above
|
||||
try {
|
||||
ibgda_setup_ = mscclpp::ep::build_ibgda_setup(rank, num_ranks, /*ib_transport_index=*/device_id,
|
||||
kNumIbgdaChannels, rdma_buffer_ptr,
|
||||
static_cast<std::size_t>(num_rdma_bytes), bootstrap);
|
||||
if (rank == 0) {
|
||||
printf("[mscclpp_ep] IBGDA setup built: channels=%d num_ranks=%d (per-rank QPs=%d)\n",
|
||||
kNumIbgdaChannels, num_ranks, kNumIbgdaChannels * (num_ranks - 1));
|
||||
fflush(stdout);
|
||||
}
|
||||
// Clear any benign CUDA sticky error left by overlapping host/UAR
|
||||
// registrations across sibling QPs (e.g., cudaErrorHostMemoryAlreadyRegistered).
|
||||
(void)cudaGetLastError();
|
||||
} catch (const std::exception& e) {
|
||||
// Don't take down the run — just log and leave use_ibgda_path_ on as
|
||||
// a hint (kernel-side fallback in 4b.2 will check ibgda_setup_ != null).
|
||||
fprintf(stderr, "[mscclpp_ep][rank=%d] IBGDA setup failed, falling back to host FIFO: %s\n", rank, e.what());
|
||||
ibgda_setup_.reset();
|
||||
(void)cudaGetLastError();
|
||||
}
|
||||
}
|
||||
#endif // MSCCLPP_EP_HAVE_IBGDA
|
||||
|
||||
// Ready to use
|
||||
available = true;
|
||||
}
|
||||
@@ -1452,6 +1503,13 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
|
||||
auto peer_bases = peer_rdma_bases_gpu;
|
||||
const bool use_ipc = ll_ipc_ready;
|
||||
auto rdma_base = rdma_buffer_ptr;
|
||||
#ifdef MSCCLPP_EP_HAVE_IBGDA
|
||||
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_dev = ibgda_setup_ ? ibgda_setup_->device_handles.get() : nullptr;
|
||||
const bool use_ibgda = use_ibgda_path_ && ibgda_dev != nullptr && !use_ipc;
|
||||
#else
|
||||
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_dev = nullptr;
|
||||
const bool use_ibgda = false;
|
||||
#endif
|
||||
auto launcher = [=](int phases) {
|
||||
internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, packed_recv_src_info.data_ptr<int>(),
|
||||
packed_recv_layout_range.data_ptr<int64_t>(), packed_recv_count.data_ptr<int>(),
|
||||
@@ -1459,7 +1517,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
|
||||
buffer.dispatch_rdma_send_buffer, x.data_ptr(), topk_idx.data_ptr<int64_t>(),
|
||||
next_clean_meta.first, next_clean_meta.second, num_tokens, hidden,
|
||||
num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, use_fp8, workspace,
|
||||
launch_stream, phases, rdma_base, port_handles, peer_bases, mem_handles, use_ipc);
|
||||
launch_stream, phases, rdma_base, port_handles, peer_bases, mem_handles, use_ipc,
|
||||
ibgda_dev, use_ibgda);
|
||||
};
|
||||
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
|
||||
|
||||
@@ -1528,13 +1587,21 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
|
||||
auto peer_bases = peer_rdma_bases_gpu;
|
||||
const bool use_ipc = ll_ipc_ready;
|
||||
auto rdma_base = rdma_buffer_ptr;
|
||||
#ifdef MSCCLPP_EP_HAVE_IBGDA
|
||||
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_dev = ibgda_setup_ ? ibgda_setup_->device_handles.get() : nullptr;
|
||||
const bool use_ibgda = use_ibgda_path_ && ibgda_dev != nullptr && !use_ipc;
|
||||
#else
|
||||
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_dev = nullptr;
|
||||
const bool use_ibgda = false;
|
||||
#endif
|
||||
auto launcher = [=](int phases) {
|
||||
internode_ll::combine(
|
||||
combined_x.data_ptr(), buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer,
|
||||
buffer.combine_rdma_send_buffer, x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
|
||||
src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(), next_clean_meta.first, next_clean_meta.second,
|
||||
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks,
|
||||
workspace, launch_stream, phases, zero_copy, rdma_base, port_handles, peer_bases, mem_handles, use_ipc);
|
||||
workspace, launch_stream, phases, zero_copy, rdma_base, port_handles, peer_bases, mem_handles, use_ipc,
|
||||
ibgda_dev, use_ibgda);
|
||||
};
|
||||
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
|
||||
|
||||
|
||||
@@ -17,6 +17,11 @@
|
||||
#include "kernels/configs.cuh"
|
||||
#include "kernels/exception.cuh"
|
||||
|
||||
#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM)
|
||||
#define MSCCLPP_EP_HAVE_IBGDA 1
|
||||
namespace mscclpp { namespace ep { struct IbgdaSetup; } }
|
||||
#endif
|
||||
|
||||
#ifndef TORCH_EXTENSION_NAME
|
||||
#define TORCH_EXTENSION_NAME mscclpp_ep_cpp
|
||||
#endif
|
||||
@@ -101,6 +106,17 @@ struct Buffer {
|
||||
std::shared_ptr<mscclpp::MemoryChannelDeviceHandle> ll_memory_channel_handles_device_ptr;
|
||||
bool ll_ipc_ready = false;
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Native IBGDA path (Stage 4b). Built lazily in `sync()` when env
|
||||
// `MSCCLPP_EP_USE_IBGDA=1` is set AND the run is cross-node.
|
||||
// The kernels do NOT consume `ibgda_setup_` until 4b.2 lands; for now
|
||||
// it is constructed-but-unused, so existing tests are unaffected.
|
||||
// ------------------------------------------------------------------
|
||||
bool use_ibgda_path_ = false;
|
||||
#ifdef MSCCLPP_EP_HAVE_IBGDA
|
||||
std::unique_ptr<mscclpp::ep::IbgdaSetup> ibgda_setup_;
|
||||
#endif
|
||||
|
||||
private:
|
||||
void move_fifo_slots(int num_slots = 1);
|
||||
|
||||
|
||||
341
src/ext/ep/ibgda_setup.cc
Normal file
341
src/ext/ep/ibgda_setup.cc
Normal file
@@ -0,0 +1,341 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
//
|
||||
// Stage 4b.1 implementation. See ibgda_setup.hpp for the contract.
|
||||
|
||||
#include "ibgda_setup.hpp"
|
||||
|
||||
#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM)
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <infiniband/verbs.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/errors.hpp>
|
||||
#include <mscclpp/gpu_utils.hpp>
|
||||
#include <mscclpp/utils.hpp>
|
||||
|
||||
#include "kernels/exception.cuh"
|
||||
|
||||
namespace mscclpp {
|
||||
namespace ep {
|
||||
|
||||
IbgdaSetup::~IbgdaSetup() {
  // The CQ poller must be gone before any QP/CQ teardown happens, so shut
  // it down first: request stop, then join.
  if (cq_poller_thread.joinable()) {
    cq_poller_stop.store(true, std::memory_order_release);
    cq_poller_thread.join();
  }
  // Everything else unwinds in reverse construction order. The GPU-side
  // tables (device_handles, d_local_mrs, d_remote_mrs) come from
  // gpuCallocShared and free themselves via their custom deleters; the
  // smart-pointer members (resources, qps, rdma_mr, ib_ctx) likewise.
  // Only the raw verbs MR and the raw cudaMalloc'd signal buffers need
  // explicit cleanup here.
  if (sig_mr) {
    ibv_dereg_mr(sig_mr);
    sig_mr = nullptr;
  }
  if (sig_slots) {
    cudaFree(sig_slots);
    sig_slots = nullptr;
  }
  if (sig_seq) {
    cudaFree(sig_seq);
    sig_seq = nullptr;
  }
}
|
||||
|
||||
namespace {

// Tunables that mirror the existing PortChannel path:
// - The channel count itself comes from the caller (kNumIbgdaChannels == 16
//   in Buffer::sync, matching num_port_channels_per_rank there).
// - port=1, gid_index=3 match the values used in the Stage 0..4a probes and
//   the typical mscclpp/NDv5 default. NOTE(review): hard-coded here —
//   confirm these hold on other fabrics before widening the rollout.
constexpr int kIbgdaPort = 1;
constexpr int kIbgdaGid = 3;
// SQ depth per QP. Each LL dispatch+combine iteration posts up to a few
// dozen WRs per QP at default LL configs; the bench loop runs 50 iters
// without explicit per-iter drain, so we size the SQ at 8192 to avoid
// wrap-around overwriting in-flight WQEs (there is no device-side
// back-pressure check in reserve_wqe_slots — the host CQ poller drains
// asynchronously). 8k entries x 64B stride = 512 KiB SQ buf per QP, well
// under typical mlx5 HCA per-QP caps.
constexpr int kIbgdaMaxSendWr = 8192;

}  // namespace
|
||||
|
||||
std::unique_ptr<IbgdaSetup> build_ibgda_setup(int rank, int num_ranks, int ib_transport_index, int num_channels,
|
||||
void* rdma_buffer_ptr, std::size_t num_rdma_bytes,
|
||||
std::shared_ptr<TcpBootstrap> bootstrap) {
|
||||
EP_HOST_ASSERT(rdma_buffer_ptr != nullptr);
|
||||
EP_HOST_ASSERT(num_rdma_bytes > 0);
|
||||
EP_HOST_ASSERT(num_ranks > 1);
|
||||
EP_HOST_ASSERT(num_channels > 0);
|
||||
EP_HOST_ASSERT(ib_transport_index >= 0 && ib_transport_index < 8);
|
||||
|
||||
auto setup = std::make_unique<IbgdaSetup>();
|
||||
setup->rank = rank;
|
||||
setup->num_ranks = num_ranks;
|
||||
setup->num_channels = num_channels;
|
||||
|
||||
// 1. Resolve IB device name and build the IbCtx.
|
||||
auto ib_transport = static_cast<Transport>(static_cast<int>(Transport::IB0) + ib_transport_index);
|
||||
std::string dev_name = getIBDeviceName(ib_transport);
|
||||
setup->ib_ctx = std::make_unique<IbCtx>(dev_name);
|
||||
EP_HOST_ASSERT(setup->ib_ctx->isMlx5() && "IBGDA requires an mlx5 NIC");
|
||||
|
||||
// 2. Create QPs. Layout: qps[channel * num_ranks + peer].
|
||||
// Self entries are nullptr.
|
||||
const int total_slots = num_channels * num_ranks;
|
||||
setup->qps.resize(total_slots);
|
||||
setup->resources.resize(total_slots);
|
||||
|
||||
for (int c = 0; c < num_channels; ++c) {
|
||||
for (int r = 0; r < num_ranks; ++r) {
|
||||
if (r == rank) continue;
|
||||
auto qp = setup->ib_ctx->createQp(/*port=*/kIbgdaPort, /*gidIndex=*/kIbgdaGid,
|
||||
/*maxSendCqSize=*/kIbgdaMaxSendWr * 2,
|
||||
/*maxSendCqPollNum=*/64,
|
||||
/*maxSendWr=*/kIbgdaMaxSendWr,
|
||||
/*maxRecvWr=*/1,
|
||||
/*maxWrPerSend=*/1,
|
||||
/*noAtomic=*/true);
|
||||
setup->qps[c * num_ranks + r] = qp;
|
||||
}
|
||||
}
|
||||
|
||||
// 3. AllGather IbQpInfos so every rank can RTR every QP.
|
||||
// Layout per-rank: total_slots IbQpInfo records, in [c * num_ranks + peer]
|
||||
// order. Self entries (peer == rank) are zeroed; remote entries describe
|
||||
// the QP that THIS rank uses to TALK TO peer == r.
|
||||
std::vector<IbQpInfo> my_infos(total_slots);
|
||||
std::memset(my_infos.data(), 0, total_slots * sizeof(IbQpInfo));
|
||||
for (int c = 0; c < num_channels; ++c) {
|
||||
for (int r = 0; r < num_ranks; ++r) {
|
||||
if (r == rank) continue;
|
||||
my_infos[c * num_ranks + r] = setup->qps[c * num_ranks + r]->getInfo();
|
||||
}
|
||||
}
|
||||
// bootstrap->allGather expects each rank to fill its own slot in a
|
||||
// contiguous buffer of size num_ranks * record_bytes.
|
||||
const std::size_t record_bytes = total_slots * sizeof(IbQpInfo);
|
||||
std::vector<IbQpInfo> all_infos(num_ranks * total_slots);
|
||||
std::memcpy(&all_infos[rank * total_slots], my_infos.data(), record_bytes);
|
||||
bootstrap->allGather(all_infos.data(), record_bytes);
|
||||
|
||||
// 4. RTR/RTS each local QP using the matching peer-side QP info. The peer
|
||||
// info we want is "rank r's QP for talking to us" == r's record indexed at
|
||||
// [c * num_ranks + rank].
|
||||
for (int c = 0; c < num_channels; ++c) {
|
||||
for (int r = 0; r < num_ranks; ++r) {
|
||||
if (r == rank) continue;
|
||||
const IbQpInfo& peer_info = all_infos[r * total_slots + c * num_ranks + rank];
|
||||
auto& qp = setup->qps[c * num_ranks + r];
|
||||
qp->rtr(peer_info);
|
||||
qp->rts();
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Wrap each QP with IbgdaResources (Stage 1) — produces GPU-mapped
|
||||
// sq_buf / dbrec / bf_reg / state pointers usable from the kernel.
|
||||
for (int c = 0; c < num_channels; ++c) {
|
||||
for (int r = 0; r < num_ranks; ++r) {
|
||||
if (r == rank) continue;
|
||||
ibv_qp* raw = setup->qps[c * num_ranks + r]->getRawQp();
|
||||
setup->resources[c * num_ranks + r] = std::make_unique<IbgdaResources>(raw);
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Register the existing rdma_buffer_ptr as an MR on this IbCtx, then
|
||||
// allgather (addr, rkey) so we can build per-peer remote_mrs entries.
|
||||
setup->rdma_mr = setup->ib_ctx->registerMr(rdma_buffer_ptr, num_rdma_bytes);
|
||||
IbgdaSetup::PeerMr my_rdma_mr{};
|
||||
{
|
||||
auto info = setup->rdma_mr->getInfo();
|
||||
my_rdma_mr.addr = info.addr;
|
||||
my_rdma_mr.rkey = info.rkey;
|
||||
}
|
||||
setup->peer_rdma.assign(num_ranks, IbgdaSetup::PeerMr{});
|
||||
setup->peer_rdma[rank] = my_rdma_mr;
|
||||
bootstrap->allGather(setup->peer_rdma.data(), sizeof(IbgdaSetup::PeerMr));
|
||||
|
||||
// 7. Allocate signal slots: total_slots * 4 bytes on GPU. Layout:
|
||||
// sig_slots[c * num_ranks + sender_peer] — when peer P sends signal() to
|
||||
// us through channel c, it RDMA-WRITEs to *our* sig_slots[c * num_ranks + P].
|
||||
// Each peer therefore needs to know our sig MR (addr+rkey) AND the offset
|
||||
// within it that corresponds to (channel, sender=P) == "we are receiving
|
||||
// from P". We allgather the base addr+rkey only; the offset is derivable.
|
||||
const std::size_t sig_bytes = std::size_t(total_slots) * sizeof(uint32_t);
|
||||
CUDA_CHECK(cudaMalloc(&setup->sig_slots, sig_bytes));
|
||||
CUDA_CHECK(cudaMemset(setup->sig_slots, 0, sig_bytes));
|
||||
CUDA_CHECK(cudaMalloc(&setup->sig_seq, sig_bytes));
|
||||
CUDA_CHECK(cudaMemset(setup->sig_seq, 0, sig_bytes));
|
||||
|
||||
// Use raw verbs for the signal MR (we need the rkey/lkey directly).
|
||||
ibv_pd* pd = setup->qps[(rank == 0 ? 1 : 0)]->getRawQp()->pd; // any non-null QP shares the same pd
|
||||
for (int c = 0; c < num_channels && pd == nullptr; ++c)
|
||||
for (int r = 0; r < num_ranks && pd == nullptr; ++r)
|
||||
if (auto& q = setup->qps[c * num_ranks + r]; q) pd = q->getRawQp()->pd;
|
||||
EP_HOST_ASSERT(pd != nullptr);
|
||||
setup->sig_mr = ibv_reg_mr(pd, setup->sig_slots, sig_bytes,
|
||||
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
|
||||
if (!setup->sig_mr) {
|
||||
throw std::runtime_error("ibv_reg_mr(sig_slots) failed errno=" + std::to_string(errno));
|
||||
}
|
||||
|
||||
IbgdaSetup::PeerMr my_sig_mr{};
|
||||
my_sig_mr.addr = reinterpret_cast<uint64_t>(setup->sig_slots);
|
||||
my_sig_mr.rkey = setup->sig_mr->rkey;
|
||||
setup->peer_sig.assign(num_ranks, IbgdaSetup::PeerMr{});
|
||||
setup->peer_sig[rank] = my_sig_mr;
|
||||
bootstrap->allGather(setup->peer_sig.data(), sizeof(IbgdaSetup::PeerMr));
|
||||
|
||||
// 8. Build the GPU-resident MR tables. We have a single local MR (the
|
||||
// rdma_buffer_ptr); remote_mrs has one entry per peer rank.
|
||||
std::vector<IbgdaLocalMr> h_local(1);
|
||||
h_local[0].addr = reinterpret_cast<uint64_t>(rdma_buffer_ptr);
|
||||
h_local[0].lkey_be = htonl(setup->rdma_mr->getLkey());
|
||||
h_local[0].pad = 0;
|
||||
|
||||
std::vector<IbgdaRemoteMr> h_remote(num_ranks);
|
||||
for (int r = 0; r < num_ranks; ++r) {
|
||||
h_remote[r].addr = setup->peer_rdma[r].addr;
|
||||
h_remote[r].rkey_be = htonl(setup->peer_rdma[r].rkey);
|
||||
h_remote[r].pad = 0;
|
||||
}
|
||||
|
||||
setup->d_local_mrs = mscclpp::detail::gpuCallocShared<IbgdaLocalMr>(1);
|
||||
setup->d_remote_mrs = mscclpp::detail::gpuCallocShared<IbgdaRemoteMr>(num_ranks);
|
||||
mscclpp::gpuMemcpy<IbgdaLocalMr>(setup->d_local_mrs.get(), h_local.data(), 1, cudaMemcpyHostToDevice);
|
||||
mscclpp::gpuMemcpy<IbgdaRemoteMr>(setup->d_remote_mrs.get(), h_remote.data(), num_ranks, cudaMemcpyHostToDevice);
|
||||
|
||||
// 9. Build the device handle array (channel × num_ranks).
|
||||
std::vector<IbgdaPortChannelDeviceHandle> h_handles(total_slots);
|
||||
std::memset(h_handles.data(), 0, h_handles.size() * sizeof(IbgdaPortChannelDeviceHandle));
|
||||
for (int c = 0; c < num_channels; ++c) {
|
||||
for (int r = 0; r < num_ranks; ++r) {
|
||||
auto& h = h_handles[c * num_ranks + r];
|
||||
if (r == rank) continue; // self-slot left zeroed
|
||||
h.qp = setup->resources[c * num_ranks + r]->getHandle();
|
||||
h.local_mrs = setup->d_local_mrs.get();
|
||||
h.remote_mrs = setup->d_remote_mrs.get();
|
||||
|
||||
// Local sig slot WE poll for messages from peer r through channel c.
|
||||
h.sig_local_addr = &setup->sig_slots[c * num_ranks + r];
|
||||
h.sig_local_lkey = htonl(setup->sig_mr->lkey);
|
||||
|
||||
// Remote sig slot peer r polls for messages FROM US through channel c.
|
||||
// Peer r's slot for "rank == us" is at offset (c * num_ranks + rank) * 4
|
||||
// inside their signal buffer.
|
||||
h.sig_remote_addr = setup->peer_sig[r].addr +
|
||||
static_cast<uint64_t>(c * num_ranks + rank) * sizeof(uint32_t);
|
||||
h.sig_rkey_be = htonl(setup->peer_sig[r].rkey);
|
||||
|
||||
// Outbound seq counter is per-handle; carve out one u32 from sig_seq
|
||||
// (sig_seq is num_channels × num_ranks u32s so we can reuse the same
|
||||
// index function — it is unrelated to the inbound sig_slots buffer
|
||||
// beyond reusing the size).
|
||||
h.sig_seq = &setup->sig_seq[c * num_ranks + r];
|
||||
|
||||
h.dst = static_cast<uint32_t>(r); // index into remote_mrs[] (peer-rank-based)
|
||||
h.src = 0; // single-entry local_mrs table
|
||||
h.peer_rank = static_cast<uint32_t>(r);
|
||||
h._pad = 0;
|
||||
}
|
||||
}
|
||||
|
||||
setup->device_handles = mscclpp::detail::gpuCallocShared<IbgdaPortChannelDeviceHandle>(total_slots);
|
||||
mscclpp::gpuMemcpy<IbgdaPortChannelDeviceHandle>(setup->device_handles.get(), h_handles.data(), total_slots,
|
||||
cudaMemcpyHostToDevice);
|
||||
|
||||
// 10. Spawn CQ drain thread. Kernel-issued rdma_write paths set CQ_UPDATE
|
||||
// on every WR; without this drain, the send CQ (sized 2 × kIbgdaMaxSendWr)
|
||||
// would fill within a few iterations of LL dispatch+combine and the QP
|
||||
// would error out. We collect raw send_cq pointers from each QP up front.
|
||||
{
|
||||
std::vector<ibv_cq*> send_cqs;
|
||||
send_cqs.reserve(total_slots);
|
||||
for (int idx = 0; idx < total_slots; ++idx) {
|
||||
auto& qp = setup->qps[idx];
|
||||
if (!qp) continue;
|
||||
ibv_cq* cq = qp->getRawQp()->send_cq;
|
||||
if (cq) send_cqs.push_back(cq);
|
||||
}
|
||||
IbgdaSetup* raw = setup.get();
|
||||
int dbg_rank = rank;
|
||||
setup->cq_poller_thread = std::thread([raw, send_cqs, dbg_rank]() {
|
||||
// Tight loop polling all CQs round-robin. Each ibv_poll_cq is cheap
|
||||
// (one PCIe-mapped read of CQE buffer + valid bit). We don't inspect
|
||||
// wc fields beyond their existence — a status != IBV_WC_SUCCESS would
|
||||
// indicate a fatal QP error which we surface lazily through later
|
||||
// failures (the test path will hang and the user sees the error).
|
||||
constexpr int kBatch = 16;
|
||||
ibv_wc wc[kBatch];
|
||||
uint64_t total_polled = 0;
|
||||
uint64_t total_errors = 0;
|
||||
auto t_last_log = std::chrono::steady_clock::now();
|
||||
while (!raw->cq_poller_stop.load(std::memory_order_acquire)) {
|
||||
bool any = false;
|
||||
for (ibv_cq* cq : send_cqs) {
|
||||
int n = ibv_poll_cq(cq, kBatch, wc);
|
||||
if (n > 0) {
|
||||
any = true;
|
||||
total_polled += n;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (wc[i].status != IBV_WC_SUCCESS) {
|
||||
total_errors++;
|
||||
fprintf(stderr,
|
||||
"[mscclpp_ep][ibgda][rank=%d] CQE error: status=%d (%s) opcode=%d vendor_err=0x%x wr_id=%llu\n",
|
||||
dbg_rank, wc[i].status, ibv_wc_status_str(wc[i].status), wc[i].opcode, wc[i].vendor_err,
|
||||
static_cast<unsigned long long>(wc[i].wr_id));
|
||||
fflush(stderr);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int rep = 0; rep < 3 && n == kBatch; ++rep) {
|
||||
n = ibv_poll_cq(cq, kBatch, wc);
|
||||
if (n > 0) {
|
||||
any = true;
|
||||
total_polled += n;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (wc[i].status != IBV_WC_SUCCESS) {
|
||||
total_errors++;
|
||||
fprintf(stderr,
|
||||
"[mscclpp_ep][ibgda][rank=%d] CQE error: status=%d (%s) opcode=%d vendor_err=0x%x\n",
|
||||
dbg_rank, wc[i].status, ibv_wc_status_str(wc[i].status), wc[i].opcode, wc[i].vendor_err);
|
||||
fflush(stderr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
if (std::chrono::duration_cast<std::chrono::seconds>(now - t_last_log).count() >= 5) {
|
||||
fprintf(stderr, "[mscclpp_ep][ibgda][rank=%d] poller: polled=%llu errors=%llu\n", dbg_rank,
|
||||
static_cast<unsigned long long>(total_polled), static_cast<unsigned long long>(total_errors));
|
||||
fflush(stderr);
|
||||
t_last_log = now;
|
||||
}
|
||||
if (!any) {
|
||||
std::this_thread::sleep_for(std::chrono::microseconds(20));
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return setup;
|
||||
}
|
||||
|
||||
} // namespace ep
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // USE_IBVERBS && MSCCLPP_USE_MLX5DV && !MSCCLPP_USE_ROCM
|
||||
116
src/ext/ep/ibgda_setup.hpp
Normal file
116
src/ext/ep/ibgda_setup.hpp
Normal file
@@ -0,0 +1,116 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
//
|
||||
// Stage 4b.1: native IBGDA host-side plumbing for src/ext/ep.
|
||||
//
|
||||
// Owned by `Buffer` and built lazily in `Buffer::sync()` when the env var
|
||||
// `MSCCLPP_EP_USE_IBGDA=1` is set AND the run is cross-node
|
||||
// (num_rdma_ranks > 1). Builds a parallel "shadow" of the existing
|
||||
// PortChannel layout: a flat (channel × num_ranks) array of
|
||||
// `IbgdaPortChannelDeviceHandle` POD structs sitting on the GPU, plus the
|
||||
// host-side resources (IbCtx, QPs, IbgdaResources, MRs, signal slots) that
|
||||
// keep them valid for the lifetime of the Buffer.
|
||||
//
|
||||
// 4b.1 only constructs the state — the kernels do NOT consume it yet. That
|
||||
// happens in 4b.2 behind a `kIbgdaPath` template branch.
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM)
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <atomic>
|
||||
#include <thread>
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/ibgda_port_channel_device.hpp>
|
||||
|
||||
#include "../../core/include/ib.hpp"
|
||||
#include "../../core/include/ibgda.hpp"
|
||||
|
||||
struct ibv_mr;
|
||||
|
||||
namespace mscclpp {
|
||||
namespace ep {
|
||||
|
||||
// One owning bundle of host-side resources backing the device handle array.
|
||||
// One owning bundle of host-side resources backing the device handle array.
// Non-copyable; destruction stops the CQ poller before tearing down QPs.
struct IbgdaSetup {
  IbgdaSetup() = default;
  ~IbgdaSetup();
  IbgdaSetup(const IbgdaSetup&) = delete;
  IbgdaSetup& operator=(const IbgdaSetup&) = delete;

  // Layout constants (must match the existing port_channel_handles layout in
  // Buffer::sync, so kernels can reuse the same channel x num_ranks index
  // arithmetic in the IBGDA branch).
  int num_channels = 0;  // == num_port_channels_per_rank in Buffer::sync
  int num_ranks = 0;
  int rank = 0;

  // IB context + QPs + Stage-1 GPU mappings.
  std::unique_ptr<IbCtx> ib_ctx;
  // Indexed [channel * num_ranks + peer]; entries with peer == rank are null.
  std::vector<std::shared_ptr<IbQp>> qps;
  // Per-QP GPU-mapped SQ/doorbell state; same indexing as `qps`.
  std::vector<std::unique_ptr<IbgdaResources>> resources;

  // RDMA buffer MR. We register the *same* `rdma_buffer_ptr` that the
  // existing PortChannel path uses, so dst/src offsets in the kernel work
  // unchanged.
  std::unique_ptr<const IbMr> rdma_mr;

  // Signal slots (GPU-resident).
  // Layout: 4 bytes per (channel, peer) on the receiving side. `sig_slots`
  // is the local mirror polled by `port_wait()`; remote peers RDMA-WRITE
  // into the corresponding slot of OUR signal MR. Rank A's slot for
  // (channel c, sender P) lives at byte offset
  //   (c * num_ranks + P) * 4
  // within A's signal buffer. See the per-handle wiring in
  // build_ibgda_setup (step 9 in ibgda_setup.cc).
  uint32_t* sig_slots = nullptr;  // GPU mem; size = num_channels * num_ranks * 4
  ibv_mr* sig_mr = nullptr;       // raw verbs (we keep the IbMr only for rdma_mr)
  uint32_t* sig_seq = nullptr;    // GPU mem; per-handle outbound counters

  // Per-peer (addr, rkey) for the RDMA buffer and the signal buffer. Filled
  // by an allgather over the bootstrap.
  struct PeerMr {
    uint64_t addr = 0;
    uint32_t rkey = 0;
    uint32_t pad = 0;
  };
  std::vector<PeerMr> peer_rdma;  // size = num_ranks
  std::vector<PeerMr> peer_sig;   // size = num_ranks

  // Flat device-side handle array: num_channels * num_ranks entries.
  // Self entries (peer == rank) are zeroed and unused.
  std::shared_ptr<IbgdaPortChannelDeviceHandle> device_handles;

  // Underlying GPU-side MR-table arrays referenced by every device handle
  // (see IbgdaPortChannelDeviceHandle::local_mrs / remote_mrs).
  std::shared_ptr<IbgdaLocalMr> d_local_mrs;    // length 1 (we have a single MR)
  std::shared_ptr<IbgdaRemoteMr> d_remote_mrs;  // length num_ranks

  // CQ drain thread. The kernel-side rdma_write paths set CQ_UPDATE on
  // every WR, so without periodic ibv_poll_cq() the send CQ would fill up
  // (CQ size = kIbgdaMaxSendWr * 2). One thread polls all per-QP send CQs
  // round-robin until `cq_poller_stop` is set in the destructor.
  std::atomic<bool> cq_poller_stop{false};
  std::thread cq_poller_thread;
};
|
||||
|
||||
// Build the full IBGDA setup. Bootstrap is used for cross-rank exchange of
|
||||
// QP info and MR keys; rdma_buffer_ptr/num_rdma_bytes is the same buffer the
|
||||
// existing PortChannel path uses. `ib_transport_index` is the IB device
|
||||
// index this rank will use (== `device_id` on NDv5).
|
||||
//
|
||||
// Throws on any irrecoverable error. On success returns a fully-initialised
|
||||
// IbgdaSetup with all QPs in RTS and the device-side handle array populated.
|
||||
std::unique_ptr<IbgdaSetup> build_ibgda_setup(int rank, int num_ranks, int ib_transport_index, int num_channels,
|
||||
void* rdma_buffer_ptr, std::size_t num_rdma_bytes,
|
||||
std::shared_ptr<TcpBootstrap> bootstrap);
|
||||
|
||||
} // namespace ep
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // USE_IBVERBS && MSCCLPP_USE_MLX5DV && !MSCCLPP_USE_ROCM
|
||||
@@ -17,6 +17,13 @@
|
||||
#include <mscclpp/port_channel_device.hpp>
|
||||
#include <vector>
|
||||
|
||||
#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM)
|
||||
#include <mscclpp/ibgda_port_channel_device.hpp>
|
||||
#define MSCCLPP_EP_KERNEL_HAS_IBGDA 1
|
||||
#else
|
||||
namespace mscclpp { struct IbgdaPortChannelDeviceHandle; }
|
||||
#endif
|
||||
|
||||
namespace mscclpp {
|
||||
namespace ep {
|
||||
|
||||
@@ -131,7 +138,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv
|
||||
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank,
|
||||
int num_ranks, bool use_fp8, void* workspace, cudaStream_t stream, int phases, void* rdma_buffer_ptr,
|
||||
mscclpp::PortChannelDeviceHandle* port_channel_handles, void* const* peer_rdma_bases,
|
||||
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path);
|
||||
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path,
|
||||
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_handles, bool use_ibgda_path);
|
||||
|
||||
void combine(void* combined_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, const void* x,
|
||||
const int64_t* topk_idx, const float* topk_weights, const int* src_info, const int64_t* layout_range,
|
||||
@@ -139,7 +147,8 @@ void combine(void* combined_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void*
|
||||
int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks,
|
||||
void* workspace, cudaStream_t stream, int phases, bool zero_copy, void* rdma_buffer_ptr,
|
||||
mscclpp::PortChannelDeviceHandle* port_channel_handles, void* const* peer_rdma_bases,
|
||||
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path);
|
||||
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path,
|
||||
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_handles, bool use_ibgda_path);
|
||||
|
||||
} // namespace internode_ll
|
||||
|
||||
|
||||
@@ -31,6 +31,11 @@
|
||||
#include <mscclpp/memory_channel_device.hpp>
|
||||
#include <mscclpp/port_channel_device.hpp>
|
||||
|
||||
#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM)
|
||||
#include "ibgda_port_channel_device.cuh"
|
||||
#define MSCCLPP_EP_LL_HAS_IBGDA 1
|
||||
#endif
|
||||
|
||||
#include "configs.cuh"
|
||||
#include "exception.cuh"
|
||||
#include "launch.cuh"
|
||||
@@ -142,14 +147,15 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, int64_t* cl
|
||||
// dispatch
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
template <bool kUseFP8, bool kIpcPath, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
|
||||
template <bool kUseFP8, bool kIpcPath, bool kIbgdaPath, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
|
||||
__global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dispatch(
|
||||
void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv_src_info, int64_t* packed_recv_layout_range,
|
||||
int* packed_recv_count, void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x, const void* x,
|
||||
const int64_t* topk_idx, int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert, int64_t* next_clean,
|
||||
int num_next_clean_int, int num_tokens, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts,
|
||||
int rank, int num_ranks, int phases, void* rdma_buffer_ptr, mscclpp::PortChannelDeviceHandle* port_channel_handles,
|
||||
void* const* peer_rdma_bases, mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
|
||||
void* const* peer_rdma_bases, mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
|
||||
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_handles) {
|
||||
const auto sm_id = static_cast<int>(blockIdx.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
|
||||
@@ -249,14 +255,31 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis
|
||||
const auto* dst_int4_ptr = reinterpret_cast<int4*>(peer_dst);
|
||||
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
|
||||
} else {
|
||||
// MSCCL++ port-channel PUT (lane 0 issues one request).
|
||||
if (lane_id == 0) {
|
||||
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
|
||||
const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr);
|
||||
port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank].put(dst_off, src_off,
|
||||
num_bytes_per_msg);
|
||||
#ifdef MSCCLPP_EP_LL_HAS_IBGDA
|
||||
if constexpr (kIbgdaPath) {
|
||||
// Native IBGDA: lane 0 issues one RDMA WRITE WQE directly.
|
||||
// UNSIGNALED + no DB ring — the trailing count write below
|
||||
// (signaled, rings DB) flushes the per-QP queue.
|
||||
if (lane_id == 0) {
|
||||
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
|
||||
const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr);
|
||||
mscclpp::ibgda::port_put(ibgda_handles[dst_expert_local_idx * num_ranks + dst_rank], dst_off, src_off,
|
||||
num_bytes_per_msg,
|
||||
/*signal_cqe=*/false, /*ring_db=*/false);
|
||||
}
|
||||
__syncwarp();
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
// MSCCL++ port-channel PUT (lane 0 issues one request).
|
||||
if (lane_id == 0) {
|
||||
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
|
||||
const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr);
|
||||
port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank].put(dst_off, src_off,
|
||||
num_bytes_per_msg);
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
} else {
|
||||
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
|
||||
@@ -322,10 +345,26 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis
|
||||
peer_rdma_bases, rdma_buffer_ptr, dst_rank));
|
||||
st_na_release(peer_counter, static_cast<int64_t>(-num_tokens_sent - 1));
|
||||
} else {
|
||||
auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank;
|
||||
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(counter_ptr), rdma_buffer_ptr);
|
||||
port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank].atomicAdd(
|
||||
off, static_cast<int64_t>(-num_tokens_sent - 1));
|
||||
#ifdef MSCCLPP_EP_LL_HAS_IBGDA
|
||||
if constexpr (kIbgdaPath) {
|
||||
// Single writer per (dst_expert_local_idx, rank) slot, so an
|
||||
// 8-byte inline RDMA WRITE delivering the encoded count is
|
||||
// semantically equivalent to atomicAdd from a zero-initialised
|
||||
// remote slot. The receiver polls with ld_acquire_sys_global.
|
||||
auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank;
|
||||
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(counter_ptr), rdma_buffer_ptr);
|
||||
const auto& ch = ibgda_handles[dst_expert_local_idx * num_ranks + dst_rank];
|
||||
mscclpp::IbgdaRemoteMr r = ch.remote_mrs[ch.dst];
|
||||
mscclpp::ibgda::rdma_write_inl8(ch.qp, static_cast<uint64_t>(static_cast<int64_t>(-num_tokens_sent - 1)),
|
||||
r.addr + off, r.rkey_be);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank;
|
||||
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(counter_ptr), rdma_buffer_ptr);
|
||||
port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank].atomicAdd(
|
||||
off, static_cast<int64_t>(-num_tokens_sent - 1));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
|
||||
@@ -405,7 +444,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv
|
||||
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank,
|
||||
int num_ranks, bool use_fp8, void* workspace, cudaStream_t stream, int phases, void* rdma_buffer_ptr,
|
||||
mscclpp::PortChannelDeviceHandle* port_channel_handles, void* const* peer_rdma_bases,
|
||||
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path) {
|
||||
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path,
|
||||
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_handles, bool use_ibgda_path) {
|
||||
constexpr int kNumMaxTopK = 9;
|
||||
// (kNumWarpGroups, kNumWarpsPerGroup) is path-dependent. Intra-node IPC
|
||||
// benefits from 1 expert per SM with 32 warps cooperating on the recv-side
|
||||
@@ -446,26 +486,37 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv
|
||||
auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
|
||||
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
|
||||
|
||||
#define DISPATCH_LAUNCH_CASE(hidden_case) \
|
||||
{ \
|
||||
if (use_ipc_path) { \
|
||||
auto dispatch_func = use_fp8 ? dispatch<true, true, kNumWarpGroupsIpc, kNumWarpsPerGroupIpc, hidden_case> \
|
||||
: dispatch<false, true, kNumWarpGroupsIpc, kNumWarpsPerGroupIpc, hidden_case>; \
|
||||
LAUNCH_KERNEL(&cfg, dispatch_func, packed_recv_x, packed_recv_x_scales, packed_recv_src_info, \
|
||||
packed_recv_layout_range, packed_recv_count, rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
|
||||
atomic_counter_per_expert, atomic_finish_counter_per_expert, next_clean, num_next_clean_int, \
|
||||
num_tokens, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, phases, \
|
||||
rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, memory_channel_handles); \
|
||||
} else { \
|
||||
auto dispatch_func = use_fp8 ? dispatch<true, false, kNumWarpGroupsRdma, kNumWarpsPerGroupRdma, hidden_case> \
|
||||
: dispatch<false, false, kNumWarpGroupsRdma, kNumWarpsPerGroupRdma, hidden_case>; \
|
||||
LAUNCH_KERNEL(&cfg, dispatch_func, packed_recv_x, packed_recv_x_scales, packed_recv_src_info, \
|
||||
packed_recv_layout_range, packed_recv_count, rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
|
||||
atomic_counter_per_expert, atomic_finish_counter_per_expert, next_clean, num_next_clean_int, \
|
||||
num_tokens, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, phases, \
|
||||
rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, memory_channel_handles); \
|
||||
} \
|
||||
} \
|
||||
// Instantiate and launch the dispatch kernel for one compile-time hidden size.
// Template bools are <use_fp8, kIpcPath, kIbgdaPath, ...>; exactly one of the
// three launch-time transports is selected:
//   * use_ipc_path   -> dispatch<*, true,  false, Ipc shape>  : intra-node IPC (memory channels).
//   * use_ibgda_path -> dispatch<*, false, true,  Rdma shape> : GPU-initiated RDMA (device-written WQEs).
//   * otherwise      -> dispatch<*, false, false, Rdma shape> : host-proxy FIFO port channels.
// All three branches pass `ibgda_handles` so the kernel signature stays uniform;
// only the kIbgdaPath instantiation dereferences it.
// NOTE(review): the trailing `break` implies this macro is expanded as a `case`
// body inside a switch over the hidden size -- confirm at the expansion site.
#define DISPATCH_LAUNCH_CASE(hidden_case)                                                                 \
  {                                                                                                       \
    if (use_ipc_path) {                                                                                   \
      /* Intra-node IPC: (kNumWarpGroupsIpc, kNumWarpsPerGroupIpc) launch shape. */                       \
      auto dispatch_func =                                                                                \
          use_fp8 ? dispatch<true, true, false, kNumWarpGroupsIpc, kNumWarpsPerGroupIpc, hidden_case>     \
                  : dispatch<false, true, false, kNumWarpGroupsIpc, kNumWarpsPerGroupIpc, hidden_case>;   \
      LAUNCH_KERNEL(&cfg, dispatch_func, packed_recv_x, packed_recv_x_scales, packed_recv_src_info,       \
                    packed_recv_layout_range, packed_recv_count, rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
                    atomic_counter_per_expert, atomic_finish_counter_per_expert, next_clean, num_next_clean_int, \
                    num_tokens, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, phases, \
                    rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, memory_channel_handles, ibgda_handles); \
    } else if (use_ibgda_path) {                                                                          \
      /* GPU-initiated RDMA: kernel issues WQEs directly (kIbgdaPath = true). */                          \
      auto dispatch_func =                                                                                \
          use_fp8 ? dispatch<true, false, true, kNumWarpGroupsRdma, kNumWarpsPerGroupRdma, hidden_case>   \
                  : dispatch<false, false, true, kNumWarpGroupsRdma, kNumWarpsPerGroupRdma, hidden_case>; \
      LAUNCH_KERNEL(&cfg, dispatch_func, packed_recv_x, packed_recv_x_scales, packed_recv_src_info,       \
                    packed_recv_layout_range, packed_recv_count, rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
                    atomic_counter_per_expert, atomic_finish_counter_per_expert, next_clean, num_next_clean_int, \
                    num_tokens, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, phases, \
                    rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, memory_channel_handles, ibgda_handles); \
    } else {                                                                                              \
      /* Default inter-node path: host-proxy FIFO port channels. */                                       \
      auto dispatch_func =                                                                                \
          use_fp8 ? dispatch<true, false, false, kNumWarpGroupsRdma, kNumWarpsPerGroupRdma, hidden_case>  \
                  : dispatch<false, false, false, kNumWarpGroupsRdma, kNumWarpsPerGroupRdma, hidden_case>; \
      LAUNCH_KERNEL(&cfg, dispatch_func, packed_recv_x, packed_recv_x_scales, packed_recv_src_info,       \
                    packed_recv_layout_range, packed_recv_count, rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
                    atomic_counter_per_expert, atomic_finish_counter_per_expert, next_clean, num_next_clean_int, \
                    num_tokens, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, phases, \
                    rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, memory_channel_handles, ibgda_handles); \
    }                                                                                                     \
  }                                                                                                       \
  break
|
||||
|
||||
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
|
||||
@@ -477,14 +528,15 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv
|
||||
// combine
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
template <bool kIpcPath, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk>
|
||||
template <bool kIpcPath, bool kIbgdaPath, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk>
|
||||
__global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void combine(
|
||||
void* combined_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, const void* x,
|
||||
const int64_t* topk_idx, const float* topk_weights, const int* src_info, const int64_t* layout_range,
|
||||
int64_t* next_clean, int num_next_clean_int, int* atomic_clean_flag, int num_combined_tokens, int hidden,
|
||||
int num_topk, int num_max_dispatch_tokens_per_rank, int num_experts, int rank, int num_ranks, int phases,
|
||||
bool zero_copy, void* rdma_buffer_ptr, mscclpp::PortChannelDeviceHandle* port_channel_handles,
|
||||
void* const* peer_rdma_bases, mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
|
||||
void* const* peer_rdma_bases, mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
|
||||
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_handles) {
|
||||
const auto sm_id = static_cast<int>(blockIdx.x);
|
||||
const auto num_sms = static_cast<int>(gridDim.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
@@ -546,17 +598,35 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com
|
||||
const auto peer_dst_int4 = reinterpret_cast<int4*>(peer_dst);
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, peer_dst_int4, x_int4, ld_nc_global, st_na_global);
|
||||
} else {
|
||||
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
|
||||
if (not zero_copy)
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
||||
// MSCCL++ port-channel PUT.
|
||||
if (lane_id == 0) {
|
||||
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
|
||||
const auto src_off = rdma_offset_of(static_cast<uint64_t>(buf_ptr), rdma_buffer_ptr);
|
||||
port_channel_handles[local_expert_idx * num_ranks + dst_rank].put(dst_off, src_off,
|
||||
hidden * sizeof(nv_bfloat16));
|
||||
#ifdef MSCCLPP_EP_LL_HAS_IBGDA
|
||||
if constexpr (kIbgdaPath) {
|
||||
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
|
||||
if (not zero_copy)
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
||||
if (lane_id == 0) {
|
||||
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
|
||||
const auto src_off = rdma_offset_of(static_cast<uint64_t>(buf_ptr), rdma_buffer_ptr);
|
||||
// UNSIGNALED + no DB ring — trailing flag write drives the doorbell.
|
||||
mscclpp::ibgda::port_put(ibgda_handles[local_expert_idx * num_ranks + dst_rank], dst_off, src_off,
|
||||
hidden * sizeof(nv_bfloat16),
|
||||
/*signal_cqe=*/false, /*ring_db=*/false);
|
||||
}
|
||||
__syncwarp();
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
|
||||
if (not zero_copy)
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
||||
// MSCCL++ port-channel PUT.
|
||||
if (lane_id == 0) {
|
||||
const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
|
||||
const auto src_off = rdma_offset_of(static_cast<uint64_t>(buf_ptr), rdma_buffer_ptr);
|
||||
port_channel_handles[local_expert_idx * num_ranks + dst_rank].put(dst_off, src_off,
|
||||
hidden * sizeof(nv_bfloat16));
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -573,9 +643,20 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com
|
||||
peer_rdma_bases, rdma_buffer_ptr, dst_rank));
|
||||
st_na_release(peer_flag, static_cast<int64_t>(1));
|
||||
} else {
|
||||
auto* flag_ptr = rdma_recv_flag + global_expert_idx;
|
||||
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(flag_ptr), rdma_buffer_ptr);
|
||||
port_channel_handles[local_expert_idx * num_ranks + dst_rank].atomicAdd(off, static_cast<int64_t>(1));
|
||||
#ifdef MSCCLPP_EP_LL_HAS_IBGDA
|
||||
if constexpr (kIbgdaPath) {
|
||||
auto* flag_ptr = rdma_recv_flag + global_expert_idx;
|
||||
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(flag_ptr), rdma_buffer_ptr);
|
||||
const auto& ch = ibgda_handles[local_expert_idx * num_ranks + dst_rank];
|
||||
mscclpp::IbgdaRemoteMr r = ch.remote_mrs[ch.dst];
|
||||
mscclpp::ibgda::rdma_write_inl8(ch.qp, static_cast<uint64_t>(1), r.addr + off, r.rkey_be);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
auto* flag_ptr = rdma_recv_flag + global_expert_idx;
|
||||
const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(flag_ptr), rdma_buffer_ptr);
|
||||
port_channel_handles[local_expert_idx * num_ranks + dst_rank].atomicAdd(off, static_cast<int64_t>(1));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
st_na_release(rdma_recv_flag + global_expert_idx, static_cast<int64_t>(1));
|
||||
@@ -639,7 +720,8 @@ void combine(void* combined_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void*
|
||||
int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks,
|
||||
void* workspace, cudaStream_t stream, int phases, bool zero_copy, void* rdma_buffer_ptr,
|
||||
mscclpp::PortChannelDeviceHandle* port_channel_handles, void* const* peer_rdma_bases,
|
||||
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path) {
|
||||
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path,
|
||||
mscclpp::IbgdaPortChannelDeviceHandle* ibgda_handles, bool use_ibgda_path) {
|
||||
// See the comment in `dispatch()`: (kNumWarpGroups, kNumWarpsPerGroup)
|
||||
// is path-dependent. IPC uses (1, 32) to mirror NCCL-EP; PortChannel keeps
|
||||
// (3, 10) to avoid host-proxy FIFO contention on the IB path.
|
||||
@@ -671,24 +753,34 @@ void combine(void* combined_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void*
|
||||
EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
|
||||
EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
|
||||
|
||||
#define COMBINE_LAUNCH_CASE(hidden_case) \
|
||||
{ \
|
||||
if (use_ipc_path) { \
|
||||
auto combine_func = combine<true, kNumWarpGroupsIpc, kNumWarpsPerGroupIpc, hidden_case, kNumMaxTopk>; \
|
||||
LAUNCH_KERNEL(&cfg, combine_func, combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, x, topk_idx, \
|
||||
topk_weights, src_info, layout_range, next_clean, num_next_clean_int, atomic_clean_flag, \
|
||||
num_combined_tokens, hidden, num_topk, num_max_dispatch_tokens_per_rank, num_experts, rank, \
|
||||
num_ranks, phases, zero_copy, rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, \
|
||||
memory_channel_handles); \
|
||||
} else { \
|
||||
auto combine_func = combine<false, kNumWarpGroupsRdma, kNumWarpsPerGroupRdma, hidden_case, kNumMaxTopk>; \
|
||||
LAUNCH_KERNEL(&cfg, combine_func, combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, x, topk_idx, \
|
||||
topk_weights, src_info, layout_range, next_clean, num_next_clean_int, atomic_clean_flag, \
|
||||
num_combined_tokens, hidden, num_topk, num_max_dispatch_tokens_per_rank, num_experts, rank, \
|
||||
num_ranks, phases, zero_copy, rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, \
|
||||
memory_channel_handles); \
|
||||
} \
|
||||
} \
|
||||
// Instantiate and launch the combine kernel for one compile-time hidden size.
// Template bools are <kIpcPath, kIbgdaPath, ...>; exactly one of the three
// launch-time transports is selected:
//   * use_ipc_path   -> combine<true,  false, Ipc shape>  : intra-node IPC (memory channels).
//   * use_ibgda_path -> combine<false, true,  Rdma shape> : GPU-initiated RDMA (device-written WQEs).
//   * otherwise      -> combine<false, false, Rdma shape> : host-proxy FIFO port channels.
// All three branches pass `ibgda_handles` so the kernel signature stays uniform;
// only the kIbgdaPath instantiation dereferences it.
// NOTE(review): the trailing `break` implies this macro is expanded as a `case`
// body inside a switch over the hidden size -- confirm at the expansion site.
#define COMBINE_LAUNCH_CASE(hidden_case)                                                                  \
  {                                                                                                       \
    if (use_ipc_path) {                                                                                   \
      /* Intra-node IPC: (kNumWarpGroupsIpc, kNumWarpsPerGroupIpc) launch shape. */                       \
      auto combine_func =                                                                                 \
          combine<true, false, kNumWarpGroupsIpc, kNumWarpsPerGroupIpc, hidden_case, kNumMaxTopk>;        \
      LAUNCH_KERNEL(&cfg, combine_func, combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, x, topk_idx, \
                    topk_weights, src_info, layout_range, next_clean, num_next_clean_int, atomic_clean_flag, \
                    num_combined_tokens, hidden, num_topk, num_max_dispatch_tokens_per_rank, num_experts, rank, \
                    num_ranks, phases, zero_copy, rdma_buffer_ptr, port_channel_handles, peer_rdma_bases,  \
                    memory_channel_handles, ibgda_handles);                                               \
    } else if (use_ibgda_path) {                                                                          \
      /* GPU-initiated RDMA: kernel issues WQEs directly (kIbgdaPath = true). */                          \
      auto combine_func =                                                                                 \
          combine<false, true, kNumWarpGroupsRdma, kNumWarpsPerGroupRdma, hidden_case, kNumMaxTopk>;      \
      LAUNCH_KERNEL(&cfg, combine_func, combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, x, topk_idx, \
                    topk_weights, src_info, layout_range, next_clean, num_next_clean_int, atomic_clean_flag, \
                    num_combined_tokens, hidden, num_topk, num_max_dispatch_tokens_per_rank, num_experts, rank, \
                    num_ranks, phases, zero_copy, rdma_buffer_ptr, port_channel_handles, peer_rdma_bases,  \
                    memory_channel_handles, ibgda_handles);                                               \
    } else {                                                                                              \
      /* Default inter-node path: host-proxy FIFO port channels. */                                       \
      auto combine_func =                                                                                 \
          combine<false, false, kNumWarpGroupsRdma, kNumWarpsPerGroupRdma, hidden_case, kNumMaxTopk>;     \
      LAUNCH_KERNEL(&cfg, combine_func, combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, x, topk_idx, \
                    topk_weights, src_info, layout_range, next_clean, num_next_clean_int, atomic_clean_flag, \
                    num_combined_tokens, hidden, num_topk, num_max_dispatch_tokens_per_rank, num_experts, rank, \
                    num_ranks, phases, zero_copy, rdma_buffer_ptr, port_channel_handles, peer_rdma_bases,  \
                    memory_channel_handles, ibgda_handles);                                               \
    }                                                                                                     \
  }                                                                                                       \
  break
|
||||
|
||||
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
|
||||
|
||||
@@ -43,14 +43,45 @@ import sys
|
||||
# noisy 'recvValue failed / Connection was likely closed' stack traces.
|
||||
os.environ.setdefault("TORCH_NCCL_ENABLE_MONITORING", "0")
|
||||
|
||||
import ctypes
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
# Load libnuma for NUMA-aware memory binding (mirrors DeepEP/tests/utils.py).
|
||||
try:
|
||||
_libnuma = ctypes.CDLL("libnuma.so")
|
||||
_libnuma.numa_available.restype = ctypes.c_int
|
||||
_libnuma.numa_run_on_node.argtypes = [ctypes.c_int]
|
||||
_libnuma.numa_set_preferred.argtypes = [ctypes.c_int]
|
||||
except OSError:
|
||||
_libnuma = None
|
||||
|
||||
|
||||
def set_numa_affinity(local_rank: int):
    """Pin this process to a per-rank slice of CPU cores and prefer the
    matching NUMA node for memory allocations.

    Assumes a fixed host topology of 12 cores per rank and 4 ranks per
    NUMA node -- TODO confirm against the target machine.
    """
    CORES_PER_RANK = 12  # hard-coded host topology
    RANKS_PER_NUMA_NODE = 4
    numa_node = local_rank // RANKS_PER_NUMA_NODE
    first_core = local_rank * CORES_PER_RANK
    core_end = first_core + CORES_PER_RANK  # exclusive

    # CPU affinity: rank i owns cores [first_core, core_end).
    psutil.Process(os.getpid()).cpu_affinity(list(range(first_core, core_end)))
    print(f"Rank {local_rank} numa node {numa_node} bound to cores {first_core}-{core_end - 1}")

    # Memory affinity: prefer (not strictly bind) allocations on the rank's
    # NUMA node when libnuma was loaded and reports NUMA support.
    if _libnuma is not None and _libnuma.numa_available() != -1:
        _libnuma.numa_set_preferred(numa_node)
        print(f"Rank {local_rank}: CPU affinity → cores {first_core}-{core_end - 1}, memory NUMA → node {numa_node}")
    else:
        print(f"Rank {local_rank}: libnuma not available")
|
||||
|
||||
|
||||
def init_dist():
|
||||
rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", rank))
|
||||
set_numa_affinity(local_rank)
|
||||
torch.cuda.set_device(local_rank)
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
|
||||
Reference in New Issue
Block a user