diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp
index 1e9e6abd..67ed523a 100644
--- a/include/mscclpp/core.hpp
+++ b/include/mscclpp/core.hpp
@@ -542,7 +542,9 @@ class Communicator {
   /// @param ibMaxWrPerSend The maximum number of work requests per send for IB. Unused if transport is not IB.
+  /// @param ibMaxNumSgesPerWr The maximum number of scatter/gather elements per work request for IB. Unused if transport is not IB.
   /// @return std::shared_ptr A shared pointer to the connection.
   std::shared_ptr connectOnSetup(int remoteRank, int tag, Transport transport, int ibMaxCqSize = 1024,
-                                 int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64);
+                                 int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64,
+                                 int ibMaxNumSgesPerWr = 16);
 
   /// Add a custom Setuppable object to a list of objects to be setup later, when @ref setup() is called.
   ///
diff --git a/src/communicator.cc b/src/communicator.cc
index cc032355..d5b49fa1 100644
--- a/src/communicator.cc
+++ b/src/communicator.cc
@@ -98,7 +98,8 @@ MSCCLPP_API_CPP std::shared_ptr Communicator::connectOnSetup(int rem
                                                              int ibMaxCqSize /*=1024*/,
                                                              int ibMaxCqPollNum /*=1*/,
                                                              int ibMaxSendWr /*=8192*/,
-                                                             int ibMaxWrPerSend /*=64*/) {
+                                                             int ibMaxWrPerSend /*=64*/,
+                                                             int ibMaxNumSgesPerWr /*=16*/) {
   std::shared_ptr conn;
   if (transport == Transport::CudaIpc) {
     // sanity check: make sure the IPC connection is being made within a node
@@ -116,7 +117,7 @@ MSCCLPP_API_CPP std::shared_ptr Communicator::connectOnSetup(int rem
                                    pimpl->rankToHash_[remoteRank]);
   } else if (AllIBTransports.has(transport)) {
     auto ibConn = std::make_shared(remoteRank, tag, transport, ibMaxCqSize, ibMaxCqPollNum, ibMaxSendWr,
-                                   ibMaxWrPerSend, *pimpl);
+                                   ibMaxWrPerSend, ibMaxNumSgesPerWr, *pimpl);
     conn = ibConn;
     INFO(MSCCLPP_NET, "IB connection between rank %d(%lx) via %s and remoteRank %d(%lx) created",
          pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()],
diff --git a/src/connection.cc b/src/connection.cc
index 112e1178..d8a055b5 100644
--- a/src/connection.cc
+++ b/src/connection.cc
@@ -83,13 +83,13 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) {
 // IBConnection
 
 IBConnection::IBConnection(int remoteRank, int tag, Transport transport, int maxCqSize, int maxCqPollNum, int maxSendWr,
-                           int maxWrPerSend, Communicator::Impl& commImpl)
+                           int maxWrPerSend, int maxNumSgesPerWr, Communicator::Impl& commImpl)
     : ConnectionBase(remoteRank, tag),
       transport_(transport),
       remoteTransport_(Transport::Unknown),
       numSignaledSends(0),
       dummyAtomicSource_(std::make_unique(0)) {
-  qp = commImpl.getIbContext(transport)->createQp(maxCqSize, maxCqPollNum, maxSendWr, 0, maxWrPerSend);
+  qp = commImpl.getIbContext(transport)->createQp(maxCqSize, maxCqPollNum, maxSendWr, 0, maxWrPerSend, maxNumSgesPerWr);
   dummyAtomicSourceMem_ = RegisteredMemory(std::make_shared(
       dummyAtomicSource_.get(), sizeof(uint64_t), commImpl.bootstrap_->getRank(), transport, commImpl));
   validateTransport(dummyAtomicSourceMem_, transport);
diff --git a/src/ib.cc b/src/ib.cc
index 7a93a650..0dcb33fd 100644
--- a/src/ib.cc
+++ b/src/ib.cc
@@ -16,6 +16,18 @@
 #include "api.h"
 #include "debug.h"
 
+/// Query the attributes of the IB device behind @p ctx.
+/// @throws mscclpp::IbError if ibv_query_device() fails.
+static ibv_device_attr getDeviceAttr(ibv_context* ctx) {
+  ibv_device_attr devAttr;
+  if (ibv_query_device(ctx, &devAttr) != 0) {
+    std::stringstream err;
+    err << "ibv_query_device failed (errno " << errno << ")";
+    throw mscclpp::IbError(err.str(), errno);
+  }
+  return devAttr;
+}
+
 namespace mscclpp {
 
 IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
@@ -53,8 +65,8 @@ const void* IbMr::getBuff() const { return this->buff; }
 uint32_t IbMr::getLkey() const { return this->mr->lkey; }
 
 IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr,
-           int maxWrPerSend)
-    : maxCqPollNum(maxCqPollNum), maxWrPerSend(maxWrPerSend) {
+           int maxWrPerSend, int maxNumSgesPerWr)
+    : maxCqPollNum_(maxCqPollNum), maxWrPerSend_(maxWrPerSend), maxNumSgesPerWr_(maxNumSgesPerWr) {
   this->cq = ibv_create_cq(ctx, maxCqSize, nullptr, nullptr, 0);
   if (this->cq == nullptr) {
     std::stringstream err;
@@ -117,10 +129,11 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN
     throw mscclpp::IbError(err.str(), errno);
   }
   this->qp = _qp;
-  this->wrn = 0;
-  this->wrs = std::make_unique(maxWrPerSend);
-  this->sges = std::make_unique(maxWrPerSend);
-  this->wcs = std::make_unique(maxCqPollNum);
+  this->wrs = std::make_unique(maxWrPerSend_);
+  this->sges = std::make_unique(maxWrPerSend_ * maxNumSgesPerWr_);
+  this->wcs = std::make_unique(maxCqPollNum_);
+  numStagedWrs_ = 0;
+  numStagedSges_ = 0;
 }
 
 IbQp::~IbQp() {
@@ -181,29 +194,36 @@ void IbQp::rts() {
   }
 }
 
-IbQp::WrInfo IbQp::getNewWrInfo() {
-  if (this->wrn >= this->maxWrPerSend) {
+/// Reserve the next staged work request together with @p numSges consecutive scatter/gather entries.
+/// @throws mscclpp::Error (InvalidUsage) if the WR budget is exhausted or @p numSges is out of range.
+IbQp::WrInfo IbQp::getNewWrInfo(int numSges) {
+  if (numStagedWrs_ >= maxWrPerSend_) {
     std::stringstream err;
-    err << "too many outstanding work requests. limit is " << this->maxWrPerSend;
+    err << "too many outstanding work requests. limit is " << maxWrPerSend_;
     throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
   }
-  int wrn = this->wrn;
-
-  ibv_send_wr* wr_ = &this->wrs[wrn];
-  ibv_sge* sge_ = &this->sges[wrn];
-  wr_->sg_list = sge_;
-  wr_->num_sge = 1;
-  wr_->next = nullptr;
-  if (wrn > 0) {
-    this->wrs[wrn - 1].next = wr_;
+  if (numSges < 1 || numSges > maxNumSgesPerWr_) {
+    std::stringstream err;
+    err << "invalid number of sges per work request: " << numSges << ". limit is " << maxNumSgesPerWr_;
+    throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
   }
-  this->wrn++;
+
+  ibv_send_wr* wr_ = &this->wrs[numStagedWrs_];
+  ibv_sge* sge_ = &this->sges[numStagedSges_];
+  wr_->sg_list = sge_;
+  wr_->num_sge = numSges;
+  wr_->next = nullptr;
+  if (numStagedWrs_ > 0) {
+    this->wrs[numStagedWrs_ - 1].next = wr_;
+  }
+  numStagedWrs_++;
+  numStagedSges_ += numSges;
   return IbQp::WrInfo{wr_, sge_};
 }
 
 void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
                      uint64_t dstOffset, bool signaled) {
-  auto wrInfo = this->getNewWrInfo();
+  auto wrInfo = this->getNewWrInfo(1);
   wrInfo.wr->wr_id = wrId;
   wrInfo.wr->opcode = IBV_WR_RDMA_WRITE;
   wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
@@ -215,7 +235,7 @@ void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64
 }
 
 void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint64_t dstOffset, uint64_t addVal) {
-  auto wrInfo = this->getNewWrInfo();
+  auto wrInfo = this->getNewWrInfo(1);
   wrInfo.wr->wr_id = wrId;
   wrInfo.wr->opcode = IBV_WR_ATOMIC_FETCH_AND_ADD;
   wrInfo.wr->send_flags = 0;  // atomic op cannot be signaled
@@ -229,7 +249,7 @@ void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, u
 
 void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
                             uint64_t dstOffset, bool signaled, unsigned int immData) {
-  auto wrInfo = this->getNewWrInfo();
+  auto wrInfo = this->getNewWrInfo(1);
   wrInfo.wr->wr_id = wrId;
   wrInfo.wr->opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
   wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
@@ -241,8 +261,35 @@ void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size,
   wrInfo.sge->lkey = mr->getLkey();
 }
 
+/// Stage a single RDMA write that gathers several local source regions into one remote destination.
+/// @p srcMrs, @p srcSizes and @p srcOffsets are parallel lists holding one entry per scatter/gather element.
+/// NOTE(review): a WR with more than one SGE presumably also needs the QP's max_send_sge capability raised to
+/// maxNumSgesPerWr; the QP creation attributes are not part of this patch -- TODO confirm.
+void IbQp::stageSendGather(const std::vector& srcMrs, const IbMrInfo& dstInfo, const std::vector& srcSizes,
+                           uint64_t wrId, const std::vector& srcOffsets, uint64_t dstOffset, bool signaled) {
+  size_t numSrcs = srcMrs.size();
+  if (numSrcs != srcSizes.size() || numSrcs != srcOffsets.size()) {
+    std::stringstream err;
+    err << "invalid srcs: srcMrs.size()=" << numSrcs << ", srcSizes.size()=" << srcSizes.size()
+        << ", srcOffsets.size()=" << srcOffsets.size();
+    throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
+  }
+  auto wrInfo = this->getNewWrInfo(static_cast<int>(numSrcs));
+  wrInfo.wr->wr_id = wrId;
+  wrInfo.wr->opcode = IBV_WR_RDMA_WRITE;
+  wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
+  wrInfo.wr->wr.rdma.remote_addr = (uint64_t)(dstInfo.addr) + dstOffset;
+  wrInfo.wr->wr.rdma.rkey = dstInfo.rkey;
+  // Fill one SGE per source; getNewWrInfo() reserved numSrcs consecutive entries starting at wrInfo.sge.
+  for (size_t i = 0; i < numSrcs; ++i) {
+    wrInfo.sge[i].addr = (uint64_t)(srcMrs[i]->getBuff()) + srcOffsets[i];
+    wrInfo.sge[i].length = srcSizes[i];
+    wrInfo.sge[i].lkey = srcMrs[i]->getLkey();
+  }
+}
+
 void IbQp::postSend() {
-  if (this->wrn == 0) {
+  if (numStagedWrs_ == 0) {
     return;
   }
   struct ibv_send_wr* bad_wr;
@@ -252,7 +299,8 @@ void IbQp::postSend() {
     err << "ibv_post_send failed (errno " << errno << ")";
     throw mscclpp::IbError(err.str(), errno);
   }
-  this->wrn = 0;
+  numStagedWrs_ = 0;
+  numStagedSges_ = 0;
 }
 
 void IbQp::postRecv(uint64_t wrId) {
@@ -269,7 +317,7 @@ void IbQp::postRecv(uint64_t wrId) {
   }
 }
 
-int IbQp::pollCq() { return ibv_poll_cq(this->cq, this->maxCqPollNum, this->wcs.get()); }
+int IbQp::pollCq() { return ibv_poll_cq(this->cq, maxCqPollNum_, this->wcs.get()); }
 
 IbQpInfo& IbQp::getInfo() { return this->info; }
 
@@ -321,12 +369,7 @@ bool IbCtx::isPortUsable(int port) const {
 }
 
 int IbCtx::getAnyActivePort() const {
-  struct ibv_device_attr devAttr;
-  if (ibv_query_device(this->ctx, &devAttr) != 0) {
-    std::stringstream err;
-    err << "ibv_query_device failed (errno " << errno << ")";
-    throw mscclpp::IbError(err.str(), errno);
-  }
+  ibv_device_attr devAttr = getDeviceAttr(this->ctx);
   for (uint8_t port = 1; port <= devAttr.phys_port_cnt; ++port) {
     if (this->isPortUsable(port)) {
       return port;
@@ -335,17 +378,45 @@ int IbCtx::getAnyActivePort() const {
   return -1;
 }
 
-IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend,
+/// Check a QP configuration against the device limits returned by getDeviceAttr().
+/// maxRecvWr may be 0: the callers in this patch (IBConnection and the IB unit test) pass 0.
+/// @throws mscclpp::Error (InvalidUsage) if @p port is not usable or any value is out of range.
+void IbCtx::validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr,
+                           int port) const {
+  if (!this->isPortUsable(port)) {
+    throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InvalidUsage);
+  }
+  ibv_device_attr devAttr = getDeviceAttr(this->ctx);
+  if (maxCqSize > devAttr.max_cqe || maxCqSize < 1) {
+    throw mscclpp::Error("invalid maxCqSize: " + std::to_string(maxCqSize), ErrorCode::InvalidUsage);
+  }
+  if (maxCqPollNum > maxCqSize || maxCqPollNum < 1) {
+    throw mscclpp::Error("invalid maxCqPollNum: " + std::to_string(maxCqPollNum), ErrorCode::InvalidUsage);
+  }
+  if (maxSendWr > devAttr.max_qp_wr || maxSendWr < 1) {
+    throw mscclpp::Error("invalid maxSendWr: " + std::to_string(maxSendWr), ErrorCode::InvalidUsage);
+  }
+  if (maxRecvWr > devAttr.max_qp_wr || maxRecvWr < 0) {
+    throw mscclpp::Error("invalid maxRecvWr: " + std::to_string(maxRecvWr), ErrorCode::InvalidUsage);
+  }
+  if (maxWrPerSend > maxSendWr || maxWrPerSend < 1) {
+    throw mscclpp::Error("invalid maxWrPerSend: " + std::to_string(maxWrPerSend), ErrorCode::InvalidUsage);
+  }
+  if (maxNumSgesPerWr > devAttr.max_sge || maxNumSgesPerWr < 1) {
+    throw mscclpp::Error("invalid maxNumSgesPerWr: " + std::to_string(maxNumSgesPerWr), ErrorCode::InvalidUsage);
+  }
+}
+
+IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr,
                       int port /*=-1*/) {
   if (port == -1) {
     port = this->getAnyActivePort();
     if (port == -1) {
       throw mscclpp::Error("No active port found", ErrorCode::InternalError);
     }
-  } else if (!this->isPortUsable(port)) {
-    throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InternalError);
   }
-  qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend));
+  validateConfig(maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr, port);
+  qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr));
   return qps.back().get();
 }
 
diff --git a/src/include/connection.hpp b/src/include/connection.hpp
index 0475691c..106b86d9 100644
--- a/src/include/connection.hpp
+++ b/src/include/connection.hpp
@@ -57,7 +57,7 @@ class IBConnection : public ConnectionBase {
 
  public:
   IBConnection(int remoteRank, int tag, Transport transport, int maxCqSize, int maxCqPollNum, int maxSendWr,
-               int maxWrPerSend, Communicator::Impl& commImpl);
+               int maxWrPerSend, int maxNumSgesPerWr, Communicator::Impl& commImpl);
 
   Transport transport() override;
 
diff --git a/src/include/ib.hpp b/src/include/ib.hpp
index 1bec30b8..0126ef89 100644
--- a/src/include/ib.hpp
+++ b/src/include/ib.hpp
@@ -66,6 +66,8 @@ class IbQp {
   void stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint64_t dstOffset, uint64_t addVal);
   void stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
                         uint64_t dstOffset, bool signaled, unsigned int immData);
+  void stageSendGather(const std::vector& srcMrs, const IbMrInfo& dstInfo, const std::vector& srcSizes,
+                       uint64_t wrId, const std::vector& srcOffsets, uint64_t dstOffset, bool signaled);
   void postSend();
   void postRecv(uint64_t wrId);
   int pollCq();
@@ -80,8 +82,8 @@ class IbQp {
   };
 
   IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr,
-       int maxWrPerSend);
-  WrInfo getNewWrInfo();
+       int maxWrPerSend, int maxNumSgesPerWr);
+  WrInfo getNewWrInfo(int numSges);
 
   IbQpInfo info;
 
@@ -90,10 +92,12 @@ class IbQp {
   std::unique_ptr wcs;
   std::unique_ptr wrs;
   std::unique_ptr sges;
-  int wrn;
+  int numStagedWrs_;
+  int numStagedSges_;
 
-  const int maxCqPollNum;
-  const int maxWrPerSend;
+  const int maxCqPollNum_;
+  const int maxWrPerSend_;
+  const int maxNumSgesPerWr_;
 
   friend class IbCtx;
 };
@@ -103,7 +107,7 @@ class IbCtx {
   IbCtx(const std::string& devName);
   ~IbCtx();
 
-  IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int port = -1);
+  IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, int port = -1);
   const IbMr* registerMr(void* buff, std::size_t size);
 
   const std::string& getDevName() const;
@@ -111,6 +115,7 @@ class IbCtx {
  private:
   bool isPortUsable(int port) const;
   int getAnyActivePort() const;
+  void validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, int port) const;
 
   const std::string devName;
   ibv_context* ctx;
diff --git a/test/mp_unit/ib_tests.cu b/test/mp_unit/ib_tests.cu
index 7ab892b5..a38f992e 100644
--- a/test/mp_unit/ib_tests.cu
+++ b/test/mp_unit/ib_tests.cu
@@ -36,7 +36,7 @@ void IbPeerToPeerTest::SetUp() {
   bootstrap->initialize(id);
 
   ibCtx = std::make_shared(ibDevName);
-  qp = ibCtx->createQp(1024, 1, 8192, 0, 64);
+  qp = ibCtx->createQp(1024, 1, 8192, 0, 64, 1);
   qpInfo[gEnv->rank] = qp->getInfo();
 
   bootstrap->allGather(qpInfo.data(), sizeof(mscclpp::IbQpInfo));