mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
IB gather WIP
This commit is contained in:
@@ -542,7 +542,8 @@ class Communicator {
|
||||
/// @param ibMaxWrPerSend The maximum number of work requests per send for IB. Unused if transport is not IB.
|
||||
/// @return std::shared_ptr<Connection> A shared pointer to the connection.
|
||||
std::shared_ptr<Connection> connectOnSetup(int remoteRank, int tag, Transport transport, int ibMaxCqSize = 1024,
|
||||
int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64);
|
||||
int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64,
|
||||
int ibMaxNumSgesPerWr = 16);
|
||||
|
||||
/// Add a custom Setuppable object to a list of objects to be setup later, when @ref setup() is called.
|
||||
///
|
||||
|
||||
@@ -98,7 +98,8 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int rem
|
||||
int ibMaxCqSize /*=1024*/,
|
||||
int ibMaxCqPollNum /*=1*/,
|
||||
int ibMaxSendWr /*=8192*/,
|
||||
int ibMaxWrPerSend /*=64*/) {
|
||||
int ibMaxWrPerSend /*=64*/,
|
||||
int ibMaxNumSgesPerWr /*=16*/) {
|
||||
std::shared_ptr<ConnectionBase> conn;
|
||||
if (transport == Transport::CudaIpc) {
|
||||
// sanity check: make sure the IPC connection is being made within a node
|
||||
@@ -116,7 +117,7 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int rem
|
||||
pimpl->rankToHash_[remoteRank]);
|
||||
} else if (AllIBTransports.has(transport)) {
|
||||
auto ibConn = std::make_shared<IBConnection>(remoteRank, tag, transport, ibMaxCqSize, ibMaxCqPollNum, ibMaxSendWr,
|
||||
ibMaxWrPerSend, *pimpl);
|
||||
ibMaxWrPerSend, ibMaxNumSgesPerWr, *pimpl);
|
||||
conn = ibConn;
|
||||
INFO(MSCCLPP_NET, "IB connection between rank %d(%lx) via %s and remoteRank %d(%lx) created",
|
||||
pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()],
|
||||
|
||||
@@ -83,13 +83,13 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) {
|
||||
// IBConnection
|
||||
|
||||
IBConnection::IBConnection(int remoteRank, int tag, Transport transport, int maxCqSize, int maxCqPollNum, int maxSendWr,
|
||||
int maxWrPerSend, Communicator::Impl& commImpl)
|
||||
int maxWrPerSend, int maxNumSgesPerWr, Communicator::Impl& commImpl)
|
||||
: ConnectionBase(remoteRank, tag),
|
||||
transport_(transport),
|
||||
remoteTransport_(Transport::Unknown),
|
||||
numSignaledSends(0),
|
||||
dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
|
||||
qp = commImpl.getIbContext(transport)->createQp(maxCqSize, maxCqPollNum, maxSendWr, 0, maxWrPerSend);
|
||||
qp = commImpl.getIbContext(transport)->createQp(maxCqSize, maxCqPollNum, maxSendWr, 0, maxWrPerSend, maxNumSgesPerWr);
|
||||
dummyAtomicSourceMem_ = RegisteredMemory(std::make_shared<RegisteredMemory::Impl>(
|
||||
dummyAtomicSource_.get(), sizeof(uint64_t), commImpl.bootstrap_->getRank(), transport, commImpl));
|
||||
validateTransport(dummyAtomicSourceMem_, transport);
|
||||
|
||||
127
src/ib.cc
127
src/ib.cc
@@ -16,6 +16,16 @@
|
||||
#include "api.h"
|
||||
#include "debug.h"
|
||||
|
||||
// Queries the attributes of the IB device behind `ctx` via ibv_query_device.
// Returns the filled-in attribute struct when the call returns 0; otherwise throws
// mscclpp::IbError whose message and error code both carry the current errno.
static ibv_device_attr getDeviceAttr(ibv_context *ctx) {
  ibv_device_attr attrs;
  const int status = ibv_query_device(ctx, &attrs);
  if (status == 0) {
    return attrs;
  }
  // Failure path: report errno exactly as observed at this point.
  std::stringstream msg;
  msg << "ibv_query_device failed (errno " << errno << ")";
  throw mscclpp::IbError(msg.str(), errno);
}
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
|
||||
@@ -53,8 +63,8 @@ const void* IbMr::getBuff() const { return this->buff; }
|
||||
uint32_t IbMr::getLkey() const { return this->mr->lkey; }
|
||||
|
||||
IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr,
|
||||
int maxWrPerSend)
|
||||
: maxCqPollNum(maxCqPollNum), maxWrPerSend(maxWrPerSend) {
|
||||
int maxWrPerSend, int maxNumSgesPerWr)
|
||||
: maxCqPollNum_(maxCqPollNum), maxWrPerSend_(maxWrPerSend), maxNumSgesPerWr_(maxNumSgesPerWr) {
|
||||
this->cq = ibv_create_cq(ctx, maxCqSize, nullptr, nullptr, 0);
|
||||
if (this->cq == nullptr) {
|
||||
std::stringstream err;
|
||||
@@ -117,10 +127,11 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
}
|
||||
this->qp = _qp;
|
||||
this->wrn = 0;
|
||||
this->wrs = std::make_unique<ibv_send_wr[]>(maxWrPerSend);
|
||||
this->sges = std::make_unique<ibv_sge[]>(maxWrPerSend);
|
||||
this->wcs = std::make_unique<ibv_wc[]>(maxCqPollNum);
|
||||
this->wrs = std::make_unique<ibv_send_wr[]>(maxWrPerSend_);
|
||||
this->sges = std::make_unique<ibv_sge[]>(maxWrPerSend_ * maxNumSgesPerWr_);
|
||||
this->wcs = std::make_unique<ibv_wc[]>(maxCqPollNum_);
|
||||
numStagedWrs_ = 0;
|
||||
numStagedSges_ = 0;
|
||||
}
|
||||
|
||||
IbQp::~IbQp() {
|
||||
@@ -181,29 +192,34 @@ void IbQp::rts() {
|
||||
}
|
||||
}
|
||||
|
||||
IbQp::WrInfo IbQp::getNewWrInfo() {
|
||||
if (this->wrn >= this->maxWrPerSend) {
|
||||
IbQp::WrInfo IbQp::getNewWrInfo(int numSges) {
|
||||
if (numStagedWrs_ >= maxWrPerSend_) {
|
||||
std::stringstream err;
|
||||
err << "too many outstanding work requests. limit is " << this->maxWrPerSend;
|
||||
err << "too many outstanding work requests. limit is " << maxWrPerSend_;
|
||||
throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
|
||||
}
|
||||
int wrn = this->wrn;
|
||||
|
||||
ibv_send_wr* wr_ = &this->wrs[wrn];
|
||||
ibv_sge* sge_ = &this->sges[wrn];
|
||||
wr_->sg_list = sge_;
|
||||
wr_->num_sge = 1;
|
||||
wr_->next = nullptr;
|
||||
if (wrn > 0) {
|
||||
this->wrs[wrn - 1].next = wr_;
|
||||
if (numSges > maxNumSgesPerWr_) {
|
||||
std::stringstream err;
|
||||
err << "too many sges per work request. limit is " << maxNumSgesPerWr_;
|
||||
throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
|
||||
}
|
||||
this->wrn++;
|
||||
|
||||
ibv_send_wr* wr_ = &this->wrs[numStagedWrs_];
|
||||
ibv_sge* sge_ = &this->sges[numStagedSges_];
|
||||
wr_->sg_list = sge_;
|
||||
wr_->num_sge = numSges;
|
||||
wr_->next = nullptr;
|
||||
if (numStagedWrs_ > 0) {
|
||||
this->wrs[numStagedWrs_ - 1].next = wr_;
|
||||
}
|
||||
numStagedWrs_++;
|
||||
numStagedSges_ += numSges;
|
||||
return IbQp::WrInfo{wr_, sge_};
|
||||
}
|
||||
|
||||
void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled) {
|
||||
auto wrInfo = this->getNewWrInfo();
|
||||
auto wrInfo = this->getNewWrInfo(1);
|
||||
wrInfo.wr->wr_id = wrId;
|
||||
wrInfo.wr->opcode = IBV_WR_RDMA_WRITE;
|
||||
wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
|
||||
@@ -215,7 +231,7 @@ void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64
|
||||
}
|
||||
|
||||
void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint64_t dstOffset, uint64_t addVal) {
|
||||
auto wrInfo = this->getNewWrInfo();
|
||||
auto wrInfo = this->getNewWrInfo(1);
|
||||
wrInfo.wr->wr_id = wrId;
|
||||
wrInfo.wr->opcode = IBV_WR_ATOMIC_FETCH_AND_ADD;
|
||||
wrInfo.wr->send_flags = 0; // atomic op cannot be signaled
|
||||
@@ -229,7 +245,7 @@ void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, u
|
||||
|
||||
void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled, unsigned int immData) {
|
||||
auto wrInfo = this->getNewWrInfo();
|
||||
auto wrInfo = this->getNewWrInfo(1);
|
||||
wrInfo.wr->wr_id = wrId;
|
||||
wrInfo.wr->opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
|
||||
wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
|
||||
@@ -241,8 +257,28 @@ void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size,
|
||||
wrInfo.sge->lkey = mr->getLkey();
|
||||
}
|
||||
|
||||
// Stages one RDMA write work request that gathers several local source regions into a single
// contiguous remote destination.
//
// @param srcMrs      Local memory regions, one per scatter/gather entry. Entries must not be null.
// @param dstInfo     Remote memory region info (address and rkey) of the destination.
// @param srcSizes    Number of bytes to take from each source; same length as srcMrs.
// @param wrId        Work request ID stored in the staged WR.
// @param srcOffsets  Byte offset into each source region; same length as srcMrs.
// @param dstOffset   Byte offset into the remote region where the gathered data starts.
// @param signaled    Whether the WR is flagged IBV_SEND_SIGNALED.
//
// Throws mscclpp::Error (ErrorCode::InvalidUsage) if the three vectors differ in length, if the
// number of sources exceeds maxNumSgesPerWr_, or if any source MR is null. All argument checks run
// before a WR slot is taken, so a rejected call leaves the staged WR chain untouched.
void IbQp::stageSendGather(const std::vector<IbMr*>& srcMrs, const IbMrInfo& dstInfo,
                           const std::vector<uint32_t>& srcSizes, uint64_t wrId,
                           const std::vector<uint64_t>& srcOffsets, uint64_t dstOffset, bool signaled) {
  size_t numSrcs = srcMrs.size();
  if (numSrcs != srcSizes.size() || numSrcs != srcOffsets.size()) {
    std::stringstream err;
    err << "invalid srcs: srcMrs.size()=" << numSrcs << ", srcSizes.size()=" << srcSizes.size()
        << ", srcOffsets.size()=" << srcOffsets.size();
    throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
  }
  // Bound the count while it is still a size_t so the narrowing to int below cannot wrap.
  if (numSrcs > static_cast<size_t>(maxNumSgesPerWr_)) {
    std::stringstream err;
    err << "too many sges per work request. limit is " << maxNumSgesPerWr_;
    throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
  }
  for (size_t i = 0; i < numSrcs; ++i) {
    if (srcMrs[i] == nullptr) {
      std::stringstream err;
      err << "invalid srcs: srcMrs[" << i << "] is null";
      throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
    }
  }

  auto wrInfo = this->getNewWrInfo(static_cast<int>(numSrcs));
  wrInfo.wr->wr_id = wrId;
  // Gathering local buffers into remote memory is an RDMA write, matching stageSend().
  wrInfo.wr->opcode = IBV_WR_RDMA_WRITE;
  wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
  wrInfo.wr->wr.rdma.remote_addr = (uint64_t)(dstInfo.addr) + dstOffset;
  wrInfo.wr->wr.rdma.rkey = dstInfo.rkey;

  // wrInfo.sge is the first of numSrcs consecutive entries reserved for this WR.
  for (size_t i = 0; i < numSrcs; ++i) {
    ibv_sge& sge = wrInfo.sge[i];
    sge.addr = (uint64_t)(srcMrs[i]->getBuff()) + srcOffsets[i];
    sge.length = srcSizes[i];
    sge.lkey = srcMrs[i]->getLkey();
  }
}
|
||||
|
||||
void IbQp::postSend() {
|
||||
if (this->wrn == 0) {
|
||||
if (numStagedWrs_ == 0) {
|
||||
return;
|
||||
}
|
||||
struct ibv_send_wr* bad_wr;
|
||||
@@ -252,7 +288,8 @@ void IbQp::postSend() {
|
||||
err << "ibv_post_send failed (errno " << errno << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
}
|
||||
this->wrn = 0;
|
||||
numStagedWrs_ = 0;
|
||||
numStagedSges_ = 0;
|
||||
}
|
||||
|
||||
void IbQp::postRecv(uint64_t wrId) {
|
||||
@@ -269,7 +306,7 @@ void IbQp::postRecv(uint64_t wrId) {
|
||||
}
|
||||
}
|
||||
|
||||
int IbQp::pollCq() { return ibv_poll_cq(this->cq, this->maxCqPollNum, this->wcs.get()); }
|
||||
int IbQp::pollCq() { return ibv_poll_cq(this->cq, maxCqPollNum_, this->wcs.get()); }
|
||||
|
||||
IbQpInfo& IbQp::getInfo() { return this->info; }
|
||||
|
||||
@@ -321,12 +358,7 @@ bool IbCtx::isPortUsable(int port) const {
|
||||
}
|
||||
|
||||
int IbCtx::getAnyActivePort() const {
|
||||
struct ibv_device_attr devAttr;
|
||||
if (ibv_query_device(this->ctx, &devAttr) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_device failed (errno " << errno << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
}
|
||||
ibv_device_attr devAttr = getDeviceAttr(this->ctx);
|
||||
for (uint8_t port = 1; port <= devAttr.phys_port_cnt; ++port) {
|
||||
if (this->isPortUsable(port)) {
|
||||
return port;
|
||||
@@ -335,17 +367,42 @@ int IbCtx::getAnyActivePort() const {
|
||||
return -1;
|
||||
}
|
||||
|
||||
IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend,
|
||||
void IbCtx::validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr,
|
||||
int port) const {
|
||||
if (!this->isPortUsable(port)) {
|
||||
throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InvalidUsage);
|
||||
}
|
||||
ibv_device_attr devAttr = getDeviceAttr(this->ctx);
|
||||
if (maxCqSize > devAttr.max_cqe || maxCqSize < 1) {
|
||||
throw mscclpp::Error("invalid maxCqSize: " + std::to_string(maxCqSize), ErrorCode::InvalidUsage);
|
||||
}
|
||||
if (maxCqPollNum > maxCqSize || maxCqPollNum < 1) {
|
||||
throw mscclpp::Error("invalid maxCqPollNum: " + std::to_string(maxCqPollNum), ErrorCode::InvalidUsage);
|
||||
}
|
||||
if (maxSendWr > devAttr.max_qp_wr || maxSendWr < 1) {
|
||||
throw mscclpp::Error("invalid maxSendWr: " + std::to_string(maxSendWr), ErrorCode::InvalidUsage);
|
||||
}
|
||||
if (maxRecvWr > devAttr.max_qp_wr || maxRecvWr < 1) {
|
||||
throw mscclpp::Error("invalid maxRecvWr: " + std::to_string(maxRecvWr), ErrorCode::InvalidUsage);
|
||||
}
|
||||
if (maxWrPerSend > maxSendWr || maxWrPerSend < 1) {
|
||||
throw mscclpp::Error("invalid maxWrPerSend: " + std::to_string(maxWrPerSend), ErrorCode::InvalidUsage);
|
||||
}
|
||||
if (maxNumSgesPerWr > devAttr.max_sge || maxNumSgesPerWr < 1) {
|
||||
throw mscclpp::Error("invalid maxNumSgesPerWr: " + std::to_string(maxNumSgesPerWr), ErrorCode::InvalidUsage);
|
||||
}
|
||||
}
|
||||
|
||||
IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr,
|
||||
int port /*=-1*/) {
|
||||
if (port == -1) {
|
||||
port = this->getAnyActivePort();
|
||||
if (port == -1) {
|
||||
throw mscclpp::Error("No active port found", ErrorCode::InternalError);
|
||||
}
|
||||
} else if (!this->isPortUsable(port)) {
|
||||
throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InternalError);
|
||||
}
|
||||
qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend));
|
||||
validateConfig(maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr, port);
|
||||
qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr));
|
||||
return qps.back().get();
|
||||
}
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ class IBConnection : public ConnectionBase {
|
||||
|
||||
public:
|
||||
IBConnection(int remoteRank, int tag, Transport transport, int maxCqSize, int maxCqPollNum, int maxSendWr,
|
||||
int maxWrPerSend, Communicator::Impl& commImpl);
|
||||
int maxWrPerSend, int maxNumSgesPerWr, Communicator::Impl& commImpl);
|
||||
|
||||
Transport transport() override;
|
||||
|
||||
|
||||
@@ -66,6 +66,8 @@ class IbQp {
|
||||
void stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint64_t dstOffset, uint64_t addVal);
|
||||
void stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled, unsigned int immData);
|
||||
void stageSendGather(const std::vector<IbMr*>& srcMrs, const IbMrInfo& dstInfo, const std::vector<uint32_t>& srcSizes,
|
||||
uint64_t wrId, const std::vector<uint64_t>& srcOffsets, uint64_t dstOffset, bool signaled);
|
||||
void postSend();
|
||||
void postRecv(uint64_t wrId);
|
||||
int pollCq();
|
||||
@@ -80,8 +82,8 @@ class IbQp {
|
||||
};
|
||||
|
||||
IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr,
|
||||
int maxWrPerSend);
|
||||
WrInfo getNewWrInfo();
|
||||
int maxWrPerSend, int maxNumSgesPerWr);
|
||||
WrInfo getNewWrInfo(int numSges);
|
||||
|
||||
IbQpInfo info;
|
||||
|
||||
@@ -90,10 +92,12 @@ class IbQp {
|
||||
std::unique_ptr<ibv_wc[]> wcs;
|
||||
std::unique_ptr<ibv_send_wr[]> wrs;
|
||||
std::unique_ptr<ibv_sge[]> sges;
|
||||
int wrn;
|
||||
int numStagedWrs_;
|
||||
int numStagedSges_;
|
||||
|
||||
const int maxCqPollNum;
|
||||
const int maxWrPerSend;
|
||||
const int maxCqPollNum_;
|
||||
const int maxWrPerSend_;
|
||||
const int maxNumSgesPerWr_;
|
||||
|
||||
friend class IbCtx;
|
||||
};
|
||||
@@ -103,7 +107,7 @@ class IbCtx {
|
||||
IbCtx(const std::string& devName);
|
||||
~IbCtx();
|
||||
|
||||
IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int port = -1);
|
||||
IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, int port = -1);
|
||||
const IbMr* registerMr(void* buff, std::size_t size);
|
||||
|
||||
const std::string& getDevName() const;
|
||||
@@ -111,6 +115,7 @@ class IbCtx {
|
||||
private:
|
||||
bool isPortUsable(int port) const;
|
||||
int getAnyActivePort() const;
|
||||
void validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, int port) const;
|
||||
|
||||
const std::string devName;
|
||||
ibv_context* ctx;
|
||||
|
||||
@@ -36,7 +36,7 @@ void IbPeerToPeerTest::SetUp() {
|
||||
bootstrap->initialize(id);
|
||||
|
||||
ibCtx = std::make_shared<mscclpp::IbCtx>(ibDevName);
|
||||
qp = ibCtx->createQp(1024, 1, 8192, 0, 64);
|
||||
qp = ibCtx->createQp(1024, 1, 8192, 0, 64, 1);
|
||||
|
||||
qpInfo[gEnv->rank] = qp->getInfo();
|
||||
bootstrap->allGather(qpInfo.data(), sizeof(mscclpp::IbQpInfo));
|
||||
|
||||
Reference in New Issue
Block a user