IB gather WIP

This commit is contained in:
Changho Hwang
2023-09-05 14:41:08 +00:00
parent 858e381829
commit ad13693fe8
7 changed files with 112 additions and 48 deletions

View File

@@ -542,7 +542,8 @@ class Communicator {
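/// @param ibMaxWrPerSend The maximum number of work requests per send for IB. Unused if transport is not IB.
/// @param ibMaxNumSgesPerWr The maximum number of scatter/gather elements per work request for IB. Unused if transport is not IB.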
/// @return std::shared_ptr<Connection> A shared pointer to the connection.
std::shared_ptr<Connection> connectOnSetup(int remoteRank, int tag, Transport transport, int ibMaxCqSize = 1024,
int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64);
int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64,
int ibMaxNumSgesPerWr = 16);
/// Add a custom Setuppable object to a list of objects to be setup later, when @ref setup() is called.
///

View File

@@ -98,7 +98,8 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int rem
int ibMaxCqSize /*=1024*/,
int ibMaxCqPollNum /*=1*/,
int ibMaxSendWr /*=8192*/,
int ibMaxWrPerSend /*=64*/) {
int ibMaxWrPerSend /*=64*/,
int ibMaxNumSgesPerWr /*=16*/) {
std::shared_ptr<ConnectionBase> conn;
if (transport == Transport::CudaIpc) {
// sanity check: make sure the IPC connection is being made within a node
@@ -116,7 +117,7 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int rem
pimpl->rankToHash_[remoteRank]);
} else if (AllIBTransports.has(transport)) {
auto ibConn = std::make_shared<IBConnection>(remoteRank, tag, transport, ibMaxCqSize, ibMaxCqPollNum, ibMaxSendWr,
ibMaxWrPerSend, *pimpl);
ibMaxWrPerSend, ibMaxNumSgesPerWr, *pimpl);
conn = ibConn;
INFO(MSCCLPP_NET, "IB connection between rank %d(%lx) via %s and remoteRank %d(%lx) created",
pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()],

View File

@@ -83,13 +83,13 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) {
// IBConnection
IBConnection::IBConnection(int remoteRank, int tag, Transport transport, int maxCqSize, int maxCqPollNum, int maxSendWr,
int maxWrPerSend, Communicator::Impl& commImpl)
int maxWrPerSend, int maxNumSgesPerWr, Communicator::Impl& commImpl)
: ConnectionBase(remoteRank, tag),
transport_(transport),
remoteTransport_(Transport::Unknown),
numSignaledSends(0),
dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
qp = commImpl.getIbContext(transport)->createQp(maxCqSize, maxCqPollNum, maxSendWr, 0, maxWrPerSend);
qp = commImpl.getIbContext(transport)->createQp(maxCqSize, maxCqPollNum, maxSendWr, 0, maxWrPerSend, maxNumSgesPerWr);
dummyAtomicSourceMem_ = RegisteredMemory(std::make_shared<RegisteredMemory::Impl>(
dummyAtomicSource_.get(), sizeof(uint64_t), commImpl.bootstrap_->getRank(), transport, commImpl));
validateTransport(dummyAtomicSourceMem_, transport);

127
src/ib.cc
View File

@@ -16,6 +16,16 @@
#include "api.h"
#include "debug.h"
/// Query the attributes of the IB device behind @p ctx.
///
/// @param ctx The verbs context to query.
/// @return The device attributes filled in by `ibv_query_device`.
/// @throws mscclpp::IbError if `ibv_query_device` reports a failure.
static ibv_device_attr getDeviceAttr(ibv_context *ctx) {
  ibv_device_attr attr;
  const int status = ibv_query_device(ctx, &attr);
  if (status == 0) {
    return attr;
  }
  // Query failed: report the current errno both in the message and as the error code.
  std::stringstream msg;
  msg << "ibv_query_device failed (errno " << errno << ")";
  throw mscclpp::IbError(msg.str(), errno);
}
namespace mscclpp {
IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
@@ -53,8 +63,8 @@ const void* IbMr::getBuff() const { return this->buff; }
uint32_t IbMr::getLkey() const { return this->mr->lkey; }
IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr,
int maxWrPerSend)
: maxCqPollNum(maxCqPollNum), maxWrPerSend(maxWrPerSend) {
int maxWrPerSend, int maxNumSgesPerWr)
: maxCqPollNum_(maxCqPollNum), maxWrPerSend_(maxWrPerSend), maxNumSgesPerWr_(maxNumSgesPerWr) {
this->cq = ibv_create_cq(ctx, maxCqSize, nullptr, nullptr, 0);
if (this->cq == nullptr) {
std::stringstream err;
@@ -117,10 +127,11 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN
throw mscclpp::IbError(err.str(), errno);
}
this->qp = _qp;
this->wrn = 0;
this->wrs = std::make_unique<ibv_send_wr[]>(maxWrPerSend);
this->sges = std::make_unique<ibv_sge[]>(maxWrPerSend);
this->wcs = std::make_unique<ibv_wc[]>(maxCqPollNum);
this->wrs = std::make_unique<ibv_send_wr[]>(maxWrPerSend_);
this->sges = std::make_unique<ibv_sge[]>(maxWrPerSend_ * maxNumSgesPerWr_);
this->wcs = std::make_unique<ibv_wc[]>(maxCqPollNum_);
numStagedWrs_ = 0;
numStagedSges_ = 0;
}
IbQp::~IbQp() {
@@ -181,29 +192,34 @@ void IbQp::rts() {
}
}
IbQp::WrInfo IbQp::getNewWrInfo() {
if (this->wrn >= this->maxWrPerSend) {
IbQp::WrInfo IbQp::getNewWrInfo(int numSges) {
if (numStagedWrs_ >= maxWrPerSend_) {
std::stringstream err;
err << "too many outstanding work requests. limit is " << this->maxWrPerSend;
err << "too many outstanding work requests. limit is " << maxWrPerSend_;
throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
}
int wrn = this->wrn;
ibv_send_wr* wr_ = &this->wrs[wrn];
ibv_sge* sge_ = &this->sges[wrn];
wr_->sg_list = sge_;
wr_->num_sge = 1;
wr_->next = nullptr;
if (wrn > 0) {
this->wrs[wrn - 1].next = wr_;
if (numSges > maxNumSgesPerWr_) {
std::stringstream err;
err << "too many sges per work request. limit is " << maxNumSgesPerWr_;
throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
}
this->wrn++;
ibv_send_wr* wr_ = &this->wrs[numStagedWrs_];
ibv_sge* sge_ = &this->sges[numStagedSges_];
wr_->sg_list = sge_;
wr_->num_sge = numSges;
wr_->next = nullptr;
if (numStagedWrs_ > 0) {
this->wrs[numStagedWrs_ - 1].next = wr_;
}
numStagedWrs_++;
numStagedSges_ += numSges;
return IbQp::WrInfo{wr_, sge_};
}
void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
uint64_t dstOffset, bool signaled) {
auto wrInfo = this->getNewWrInfo();
auto wrInfo = this->getNewWrInfo(1);
wrInfo.wr->wr_id = wrId;
wrInfo.wr->opcode = IBV_WR_RDMA_WRITE;
wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
@@ -215,7 +231,7 @@ void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64
}
void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint64_t dstOffset, uint64_t addVal) {
auto wrInfo = this->getNewWrInfo();
auto wrInfo = this->getNewWrInfo(1);
wrInfo.wr->wr_id = wrId;
wrInfo.wr->opcode = IBV_WR_ATOMIC_FETCH_AND_ADD;
wrInfo.wr->send_flags = 0; // atomic op cannot be signaled
@@ -229,7 +245,7 @@ void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, u
void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
uint64_t dstOffset, bool signaled, unsigned int immData) {
auto wrInfo = this->getNewWrInfo();
auto wrInfo = this->getNewWrInfo(1);
wrInfo.wr->wr_id = wrId;
wrInfo.wr->opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
@@ -241,8 +257,28 @@ void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size,
wrInfo.sge->lkey = mr->getLkey();
}
/// Stage a single RDMA write work request that gathers several local source regions into one
/// contiguous remote destination.
///
/// The i-th scatter/gather entry covers `srcSizes[i]` bytes starting at `srcOffsets[i]` inside
/// `srcMrs[i]`. The request is only staged here; it is submitted by a later call to postSend().
///
/// @param srcMrs Local memory regions to gather from (one per scatter/gather entry).
/// @param dstInfo Address and rkey of the remote memory region to write into.
/// @param srcSizes Number of bytes to take from each source; must have the same length as @p srcMrs.
/// @param wrId Work request ID reported back with the completion.
/// @param srcOffsets Byte offset into each source region; must have the same length as @p srcMrs.
/// @param dstOffset Byte offset into the remote region.
/// @param signaled Whether the work request should generate a completion.
/// @throws mscclpp::Error (InvalidUsage) if the vector lengths disagree, if there are more sources
///         than the per-WR SGE limit, or if too many work requests are already staged.
void IbQp::stageSendGather(const std::vector<IbMr*>& srcMrs, const IbMrInfo& dstInfo, const std::vector<uint32_t>& srcSizes,
                           uint64_t wrId, const std::vector<uint64_t>& srcOffsets, uint64_t dstOffset, bool signaled) {
  const size_t numSrcs = srcMrs.size();
  if (numSrcs != srcSizes.size() || numSrcs != srcOffsets.size()) {
    std::stringstream err;
    err << "invalid srcs: srcMrs.size()=" << numSrcs << ", srcSizes.size()=" << srcSizes.size()
        << ", srcOffsets.size()=" << srcOffsets.size();
    throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
  }
  // Check the limit in size_t before narrowing to int, so an oversized count cannot wrap around.
  if (numSrcs > static_cast<size_t>(maxNumSgesPerWr_)) {
    std::stringstream err;
    err << "too many sges per work request. limit is " << maxNumSgesPerWr_;
    throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
  }
  auto wrInfo = this->getNewWrInfo(static_cast<int>(numSrcs));
  wrInfo.wr->wr_id = wrId;
  // Local sources -> remote destination is a write, matching stageSend(); an RDMA read would move
  // data in the opposite direction.
  wrInfo.wr->opcode = IBV_WR_RDMA_WRITE;
  wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
  wrInfo.wr->wr.rdma.remote_addr = (uint64_t)(dstInfo.addr) + dstOffset;
  wrInfo.wr->wr.rdma.rkey = dstInfo.rkey;
  // wrInfo.sge is the first of the numSrcs consecutive entries reserved for this work request.
  for (size_t i = 0; i < numSrcs; ++i) {
    ibv_sge& sge = wrInfo.sge[i];
    sge.addr = (uint64_t)(srcMrs[i]->getBuff()) + srcOffsets[i];
    sge.length = srcSizes[i];
    sge.lkey = srcMrs[i]->getLkey();
  }
}
void IbQp::postSend() {
if (this->wrn == 0) {
if (numStagedWrs_ == 0) {
return;
}
struct ibv_send_wr* bad_wr;
@@ -252,7 +288,8 @@ void IbQp::postSend() {
err << "ibv_post_send failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
}
this->wrn = 0;
numStagedWrs_ = 0;
numStagedSges_ = 0;
}
void IbQp::postRecv(uint64_t wrId) {
@@ -269,7 +306,7 @@ void IbQp::postRecv(uint64_t wrId) {
}
}
int IbQp::pollCq() { return ibv_poll_cq(this->cq, this->maxCqPollNum, this->wcs.get()); }
int IbQp::pollCq() { return ibv_poll_cq(this->cq, maxCqPollNum_, this->wcs.get()); }
IbQpInfo& IbQp::getInfo() { return this->info; }
@@ -321,12 +358,7 @@ bool IbCtx::isPortUsable(int port) const {
}
int IbCtx::getAnyActivePort() const {
struct ibv_device_attr devAttr;
if (ibv_query_device(this->ctx, &devAttr) != 0) {
std::stringstream err;
err << "ibv_query_device failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
}
ibv_device_attr devAttr = getDeviceAttr(this->ctx);
for (uint8_t port = 1; port <= devAttr.phys_port_cnt; ++port) {
if (this->isPortUsable(port)) {
return port;
@@ -335,17 +367,42 @@ int IbCtx::getAnyActivePort() const {
return -1;
}
IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend,
void IbCtx::validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr,
int port) const {
if (!this->isPortUsable(port)) {
throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InvalidUsage);
}
ibv_device_attr devAttr = getDeviceAttr(this->ctx);
if (maxCqSize > devAttr.max_cqe || maxCqSize < 1) {
throw mscclpp::Error("invalid maxCqSize: " + std::to_string(maxCqSize), ErrorCode::InvalidUsage);
}
if (maxCqPollNum > maxCqSize || maxCqPollNum < 1) {
throw mscclpp::Error("invalid maxCqPollNum: " + std::to_string(maxCqPollNum), ErrorCode::InvalidUsage);
}
if (maxSendWr > devAttr.max_qp_wr || maxSendWr < 1) {
throw mscclpp::Error("invalid maxSendWr: " + std::to_string(maxSendWr), ErrorCode::InvalidUsage);
}
if (maxRecvWr > devAttr.max_qp_wr || maxRecvWr < 1) {
throw mscclpp::Error("invalid maxRecvWr: " + std::to_string(maxRecvWr), ErrorCode::InvalidUsage);
}
if (maxWrPerSend > maxSendWr || maxWrPerSend < 1) {
throw mscclpp::Error("invalid maxWrPerSend: " + std::to_string(maxWrPerSend), ErrorCode::InvalidUsage);
}
if (maxNumSgesPerWr > devAttr.max_sge || maxNumSgesPerWr < 1) {
throw mscclpp::Error("invalid maxNumSgesPerWr: " + std::to_string(maxNumSgesPerWr), ErrorCode::InvalidUsage);
}
}
IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr,
int port /*=-1*/) {
if (port == -1) {
port = this->getAnyActivePort();
if (port == -1) {
throw mscclpp::Error("No active port found", ErrorCode::InternalError);
}
} else if (!this->isPortUsable(port)) {
throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InternalError);
}
qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend));
validateConfig(maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr, port);
qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr));
return qps.back().get();
}

View File

@@ -57,7 +57,7 @@ class IBConnection : public ConnectionBase {
public:
IBConnection(int remoteRank, int tag, Transport transport, int maxCqSize, int maxCqPollNum, int maxSendWr,
int maxWrPerSend, Communicator::Impl& commImpl);
int maxWrPerSend, int maxNumSgesPerWr, Communicator::Impl& commImpl);
Transport transport() override;

View File

@@ -66,6 +66,8 @@ class IbQp {
void stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint64_t dstOffset, uint64_t addVal);
void stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
uint64_t dstOffset, bool signaled, unsigned int immData);
void stageSendGather(const std::vector<IbMr*>& srcMrs, const IbMrInfo& dstInfo, const std::vector<uint32_t>& srcSizes,
uint64_t wrId, const std::vector<uint64_t>& srcOffsets, uint64_t dstOffset, bool signaled);
void postSend();
void postRecv(uint64_t wrId);
int pollCq();
@@ -80,8 +82,8 @@ class IbQp {
};
IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr,
int maxWrPerSend);
WrInfo getNewWrInfo();
int maxWrPerSend, int maxNumSgesPerWr);
WrInfo getNewWrInfo(int numSges);
IbQpInfo info;
@@ -90,10 +92,12 @@ class IbQp {
std::unique_ptr<ibv_wc[]> wcs;
std::unique_ptr<ibv_send_wr[]> wrs;
std::unique_ptr<ibv_sge[]> sges;
int wrn;
int numStagedWrs_;
int numStagedSges_;
const int maxCqPollNum;
const int maxWrPerSend;
const int maxCqPollNum_;
const int maxWrPerSend_;
const int maxNumSgesPerWr_;
friend class IbCtx;
};
@@ -103,7 +107,7 @@ class IbCtx {
IbCtx(const std::string& devName);
~IbCtx();
IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int port = -1);
IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, int port = -1);
const IbMr* registerMr(void* buff, std::size_t size);
const std::string& getDevName() const;
@@ -111,6 +115,7 @@ class IbCtx {
private:
bool isPortUsable(int port) const;
int getAnyActivePort() const;
void validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, int port) const;
const std::string devName;
ibv_context* ctx;

View File

@@ -36,7 +36,7 @@ void IbPeerToPeerTest::SetUp() {
bootstrap->initialize(id);
ibCtx = std::make_shared<mscclpp::IbCtx>(ibDevName);
qp = ibCtx->createQp(1024, 1, 8192, 0, 64);
qp = ibCtx->createQp(1024, 1, 8192, 0, 64, 1);
qpInfo[gEnv->rank] = qp->getInfo();
bootstrap->allGather(qpInfo.data(), sizeof(mscclpp::IbQpInfo));