This commit is contained in:
Changho Hwang
2023-09-06 12:15:12 +00:00
parent 89cad56721
commit 8232ec731f
7 changed files with 224 additions and 88 deletions

View File

@@ -543,7 +543,7 @@ class Communicator {
/// @return std::shared_ptr<Connection> A shared pointer to the connection.
std::shared_ptr<Connection> connectOnSetup(int remoteRank, int tag, Transport transport, int ibMaxCqSize = 1024,
int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64,
int ibMaxNumSgesPerWr = 16);
int ibMaxNumSgesPerWr = 1);
/// Add a custom Setuppable object to a list of objects to be set up later, when @ref setup() is called.
///

View File

@@ -141,7 +141,7 @@ void register_core(nb::module_& m) {
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
nb::arg("transport"), nb::arg("ibMaxCqSize") = 1024, nb::arg("ibMaxCqPollNum") = 1,
nb::arg("ibMaxSendWr") = 8192, nb::arg("ibMaxWrPerSend") = 64)
nb::arg("ibMaxSendWr") = 8192, nb::arg("ibMaxWrPerSend") = 64, nb::arg("ibMaxNumSgesPerWr") = 1)
.def("setup", &Communicator::setup);
}

View File

@@ -96,7 +96,7 @@ MSCCLPP_API_CPP NonblockingFuture<RegisteredMemory> Communicator::recvMemoryOnSe
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(
int remoteRank, int tag, Transport transport, int ibMaxCqSize /*=1024*/, int ibMaxCqPollNum /*=1*/,
int ibMaxSendWr /*=8192*/, int ibMaxWrPerSend /*=64*/, int ibMaxNumSgesPerWr /*=16*/) {
int ibMaxSendWr /*=8192*/, int ibMaxWrPerSend /*=64*/, int ibMaxNumSgesPerWr /*=1*/) {
std::shared_ptr<ConnectionBase> conn;
if (transport == Transport::CudaIpc) {
// sanity check: make sure the IPC connection is being made within a node

104
src/ib.cc
View File

@@ -16,16 +16,6 @@
#include "api.h"
#include "debug.h"
// Queries the attributes of the device behind `ctx` with ibv_query_device() and returns them by value.
// Throws mscclpp::IbError (carrying errno) when the query returns non-zero.
static ibv_device_attr getDeviceAttr(ibv_context* ctx) {
ibv_device_attr devAttr;
if (ibv_query_device(ctx, &devAttr) != 0) {
std::stringstream err;
err << "ibv_query_device failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
}
return devAttr;
}
static ibv_qp_attr createQpAttr() {
ibv_qp_attr qpAttr;
std::memset(&qpAttr, 0, sizeof(qpAttr));
@@ -34,17 +24,13 @@ static ibv_qp_attr createQpAttr() {
namespace mscclpp {
IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
if (size == 0) {
throw std::invalid_argument("invalid size: " + std::to_string(size));
}
IbMr::IbMr(ibv_pd* pd, void* buff, size_t alignedSize) : buff(buff) {
static __thread uintptr_t pageSize = 0;
if (pageSize == 0) {
pageSize = sysconf(_SC_PAGESIZE);
}
uintptr_t addr = reinterpret_cast<uintptr_t>(buff) & -pageSize;
std::size_t pages = (size + (reinterpret_cast<uintptr_t>(buff) - addr) + pageSize - 1) / pageSize;
this->mr = ibv_reg_mr(pd, reinterpret_cast<void*>(addr), pages * pageSize,
this->mr = ibv_reg_mr(pd, reinterpret_cast<void*>(addr), alignedSize,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ |
IBV_ACCESS_RELAXED_ORDERING | IBV_ACCESS_REMOTE_ATOMIC);
if (this->mr == nullptr) {
@@ -52,7 +38,7 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
err << "ibv_reg_mr failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
}
this->size = pages * pageSize;
this->size = alignedSize;
}
// Deregisters the memory region. The return value of ibv_dereg_mr() is not checked.
IbMr::~IbMr() { ibv_dereg_mr(this->mr); }
@@ -220,8 +206,8 @@ IbQp::WrInfo IbQp::getNewWrInfo(int numSges) {
return IbQp::WrInfo{wr_, sge_};
}
void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
uint64_t dstOffset, bool signaled) {
void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint32_t srcOffset,
uint32_t dstOffset, bool signaled) {
auto wrInfo = this->getNewWrInfo(1);
wrInfo.wr->wr_id = wrId;
wrInfo.wr->opcode = IBV_WR_RDMA_WRITE;
@@ -233,7 +219,7 @@ void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64
wrInfo.sge->lkey = mr->getLkey();
}
void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint64_t dstOffset, uint64_t addVal) {
void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint32_t dstOffset, uint64_t addVal) {
auto wrInfo = this->getNewWrInfo(1);
wrInfo.wr->wr_id = wrId;
wrInfo.wr->opcode = IBV_WR_ATOMIC_FETCH_AND_ADD;
@@ -246,8 +232,8 @@ void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, u
wrInfo.sge->lkey = mr->getLkey();
}
void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
uint64_t dstOffset, bool signaled, unsigned int immData) {
void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint32_t srcOffset,
uint32_t dstOffset, bool signaled, unsigned int immData) {
auto wrInfo = this->getNewWrInfo(1);
wrInfo.wr->wr_id = wrId;
wrInfo.wr->opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
@@ -260,26 +246,26 @@ void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size,
wrInfo.sge->lkey = mr->getLkey();
}
void IbQp::stageSendGather(const std::vector<IbMr*>& srcMrs, const IbMrInfo& dstInfo,
const std::vector<uint32_t>& srcSizes, uint64_t wrId,
const std::vector<uint64_t>& srcOffsets, uint64_t dstOffset, bool signaled) {
size_t numSrcs = srcMrs.size();
if (numSrcs != srcSizes.size() || numSrcs != srcOffsets.size()) {
void IbQp::stageSendGather(const std::vector<const IbMr*>& srcMrList, const IbMrInfo& dstMrInfo,
const std::vector<uint32_t>& srcSizeList, uint64_t wrId,
const std::vector<uint32_t>& srcOffsetList, uint32_t dstOffset, bool signaled) {
size_t numSrcs = srcMrList.size();
if (numSrcs != srcSizeList.size() || numSrcs != srcOffsetList.size()) {
std::stringstream err;
err << "invalid srcs: srcMrs.size()=" << numSrcs << ", srcSizes.size()=" << srcSizes.size()
<< ", srcOffsets.size()=" << srcOffsets.size();
err << "invalid srcs: srcMrList.size()=" << numSrcs << ", srcSizeList.size()=" << srcSizeList.size()
<< ", srcOffsetList.size()=" << srcOffsetList.size();
throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
}
auto wrInfo = this->getNewWrInfo(numSrcs);
wrInfo.wr->wr_id = wrId;
wrInfo.wr->opcode = IBV_WR_RDMA_READ;
wrInfo.wr->opcode = IBV_WR_RDMA_WRITE;
wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
wrInfo.wr->wr.rdma.remote_addr = (uint64_t)(dstInfo.addr) + dstOffset;
wrInfo.wr->wr.rdma.rkey = dstInfo.rkey;
wrInfo.wr->wr.rdma.remote_addr = (uint64_t)(dstMrInfo.addr) + dstOffset;
wrInfo.wr->wr.rdma.rkey = dstMrInfo.rkey;
for (size_t i = 0; i < numSrcs; ++i) {
wrInfo.sge[i].addr = (uint64_t)(srcMrs[i]->getBuff()) + srcOffsets[i];
wrInfo.sge[i].length = srcSizes[i];
wrInfo.sge[i].lkey = srcMrs[i]->getLkey();
wrInfo.sge[i].addr = (uint64_t)(srcMrList[i]->getBuff()) + srcOffsetList[i];
wrInfo.sge[i].length = srcSizeList[i];
wrInfo.sge[i].lkey = srcMrList[i]->getLkey();
}
}
@@ -339,6 +325,13 @@ IbCtx::IbCtx(const std::string& devName) : devName(devName) {
err << "ibv_alloc_pd failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
}
// TODO: do not use new
this->devAttr = new ibv_device_attr;
if (ibv_query_device(this->ctx, this->devAttr) != 0) {
std::stringstream err;
err << "ibv_query_device failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
}
}
IbCtx::~IbCtx() {
@@ -350,6 +343,8 @@ IbCtx::~IbCtx() {
if (this->ctx != nullptr) {
ibv_close_device(this->ctx);
}
// TODO: do not use delete
delete this->devAttr;
}
bool IbCtx::isPortUsable(int port) const {
@@ -364,8 +359,7 @@ bool IbCtx::isPortUsable(int port) const {
}
int IbCtx::getAnyActivePort() const {
ibv_device_attr devAttr = getDeviceAttr(this->ctx);
for (uint8_t port = 1; port <= devAttr.phys_port_cnt; ++port) {
for (uint8_t port = 1; port <= this->devAttr->phys_port_cnt; ++port) {
if (this->isPortUsable(port)) {
return port;
}
@@ -373,28 +367,27 @@ int IbCtx::getAnyActivePort() const {
return -1;
}
void IbCtx::validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend,
int maxNumSgesPerWr, int port) const {
void IbCtx::validateQpConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend,
int maxNumSgesPerWr, int port) const {
if (!this->isPortUsable(port)) {
throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InvalidUsage);
}
ibv_device_attr devAttr = getDeviceAttr(this->ctx);
if (maxCqSize > devAttr.max_cqe || maxCqSize < 1) {
if (maxCqSize > this->devAttr->max_cqe || maxCqSize < 1) {
throw mscclpp::Error("invalid maxCqSize: " + std::to_string(maxCqSize), ErrorCode::InvalidUsage);
}
if (maxCqPollNum > maxCqSize || maxCqPollNum < 1) {
throw mscclpp::Error("invalid maxCqPollNum: " + std::to_string(maxCqPollNum), ErrorCode::InvalidUsage);
}
if (maxSendWr > devAttr.max_qp_wr || maxSendWr < 1) {
if (maxSendWr > this->devAttr->max_qp_wr) {
throw mscclpp::Error("invalid maxSendWr: " + std::to_string(maxSendWr), ErrorCode::InvalidUsage);
}
if (maxRecvWr > devAttr.max_qp_wr || maxRecvWr < 1) {
if (maxRecvWr > this->devAttr->max_qp_wr) {
throw mscclpp::Error("invalid maxRecvWr: " + std::to_string(maxRecvWr), ErrorCode::InvalidUsage);
}
if (maxWrPerSend > maxSendWr || maxWrPerSend < 1) {
throw mscclpp::Error("invalid maxWrPerSend: " + std::to_string(maxWrPerSend), ErrorCode::InvalidUsage);
}
if (maxNumSgesPerWr > devAttr.max_sge || maxNumSgesPerWr < 1) {
if (maxNumSgesPerWr > this->devAttr->max_sge || maxNumSgesPerWr < 1) {
throw mscclpp::Error("invalid maxNumSgesPerWr: " + std::to_string(maxNumSgesPerWr), ErrorCode::InvalidUsage);
}
}
@@ -407,14 +400,31 @@ IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRec
throw mscclpp::Error("No active port found", ErrorCode::InternalError);
}
}
this->validateConfig(maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr, port);
this->validateQpConfig(maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr, port);
qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend,
maxNumSgesPerWr));
return qps.back().get();
}
const IbMr* IbCtx::registerMr(void* buff, std::size_t size) {
mrs.emplace_back(new IbMr(this->pd, buff, size));
// Registers `buff` (`size` bytes) as a memory region in this context's protection domain.
//
// The length handed to IbMr is the whole-page span that covers [buff, buff + size).
// Throws mscclpp::Error(InvalidUsage) when `size` is 0 or when that page-rounded length
// exceeds the device's max_mr_size.
// Returns a raw pointer to the new IbMr; the object itself is held by `mrs`.
const IbMr* IbCtx::registerMr(void* buff, uint32_t size) {
if (size == 0) {
throw mscclpp::Error("invalid size: " + std::to_string(size), ErrorCode::InvalidUsage);
}
// Page size is looked up once per thread and cached.
static __thread uintptr_t pageSize = 0;
if (pageSize == 0) {
pageSize = sysconf(_SC_PAGESIZE);
}
// Round the start address down to its page boundary.
// NOTE(review): the mask relies on pageSize being a power of two - presumably always true here; confirm.
uintptr_t addr = reinterpret_cast<uintptr_t>(buff) & -pageSize;
// Number of pages needed to cover the leading in-page offset plus `size` bytes, rounded up.
size_t pages = (size + (reinterpret_cast<uintptr_t>(buff) - addr) + pageSize - 1) / pageSize;
size_t alignedSize = pages * pageSize;
if (alignedSize > this->devAttr->max_mr_size) {
std::stringstream err;
err << "invalid MR size: " << alignedSize << " max " << this->devAttr->max_mr_size;
throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
}
// The original (unaligned) `buff` is passed on, together with the page-rounded length.
mrs.emplace_back(new IbMr(this->pd, buff, alignedSize));
return mrs.back().get();
}

View File

@@ -11,6 +11,7 @@
// Forward declarations of IB structures
struct ibv_context;
struct ibv_device_attr;
struct ibv_pd;
struct ibv_mr;
struct ibv_qp;
@@ -35,11 +36,11 @@ class IbMr {
uint32_t getLkey() const;
private:
IbMr(ibv_pd* pd, void* buff, std::size_t size);
IbMr(ibv_pd* pd, void* buff, size_t alignedSize);
ibv_mr* mr;
void* buff;
std::size_t size;
size_t size;
friend class IbCtx;
};
@@ -62,13 +63,14 @@ class IbQp {
void rtr(const IbQpInfo& info);
void rts();
void stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
uint64_t dstOffset, bool signaled);
void stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint64_t dstOffset, uint64_t addVal);
void stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
uint64_t dstOffset, bool signaled, unsigned int immData);
void stageSendGather(const std::vector<IbMr*>& srcMrs, const IbMrInfo& dstInfo, const std::vector<uint32_t>& srcSizes,
uint64_t wrId, const std::vector<uint64_t>& srcOffsets, uint64_t dstOffset, bool signaled);
void stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint32_t srcOffset,
uint32_t dstOffset, bool signaled);
void stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint32_t dstOffset, uint64_t addVal);
void stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint32_t srcOffset,
uint32_t dstOffset, bool signaled, unsigned int immData);
void stageSendGather(const std::vector<const IbMr*>& srcMrList, const IbMrInfo& dstMrInfo,
const std::vector<uint32_t>& srcSizeList, uint64_t wrId,
const std::vector<uint32_t>& srcOffsetList, uint32_t dstOffset, bool signaled);
void postSend();
void postRecv(uint64_t wrId);
int pollCq();
@@ -110,19 +112,20 @@ class IbCtx {
IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr,
int port = -1);
const IbMr* registerMr(void* buff, std::size_t size);
const IbMr* registerMr(void* buff, uint32_t size);
const std::string& getDevName() const;
private:
bool isPortUsable(int port) const;
int getAnyActivePort() const;
void validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend,
int maxNumSgesPerWr, int port) const;
void validateQpConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend,
int maxNumSgesPerWr, int port) const;
const std::string devName;
ibv_context* ctx;
ibv_pd* pd;
ibv_device_attr* devAttr;
std::list<std::unique_ptr<IbQp>> qps;
std::list<std::unique_ptr<IbMr>> mrs;
};

View File

@@ -38,39 +38,57 @@ void IbPeerToPeerTest::SetUp() {
ibCtx = std::make_shared<mscclpp::IbCtx>(ibDevName);
qp = ibCtx->createQp(1024, 1, 8192, 0, 64, 1);
qpInfo[gEnv->rank] = qp->getInfo();
bootstrap->allGather(qpInfo.data(), sizeof(mscclpp::IbQpInfo));
int remoteRank = (gEnv->rank == 0) ? 1 : 0;
mscclpp::IbQpInfo localQpInfo = qp->getInfo();
bootstrap->send(&localQpInfo, sizeof(mscclpp::IbQpInfo), remoteRank, /*tag=*/0);
bootstrap->recv(&remoteQpInfo, sizeof(mscclpp::IbQpInfo), remoteRank, /*tag=*/0);
}
void IbPeerToPeerTest::registerBufferAndConnect(void* buf, size_t size) {
bufSize = size;
mr = ibCtx->registerMr(buf, size);
mrInfo[gEnv->rank] = mr->getInfo();
bootstrap->allGather(mrInfo.data(), sizeof(mscclpp::IbMrInfo));
for (int i = 0; i < bootstrap->getNranks(); ++i) {
if (i == gEnv->rank) continue;
qp->rtr(qpInfo[i]);
qp->rts();
break;
// Registers every (bufList[i], sizeList[i]) pair as a local MR, exchanges MR info with the peer
// rank over the bootstrap, then moves the QP to RTR and RTS and synchronizes on a barrier.
//
// Local MRs are appended to `localMrList` in argument order; the peer's MR info is stored in
// `remoteMrInfoList` in the order the peer sent it.
// Throws std::runtime_error when the two argument lists differ in length.
void IbPeerToPeerTest::registerBuffersAndConnect(const std::vector<void*>& bufList,
const std::vector<uint32_t>& sizeList) {
size_t numMrs = bufList.size();
if (numMrs != sizeList.size()) {
throw std::runtime_error("bufList.size() != sizeList.size()");
}
// Assume the remote side registers the same number of MRs
std::vector<mscclpp::IbMrInfo> localMrInfo;
for (size_t i = 0; i < numMrs; ++i) {
const mscclpp::IbMr* mr = ibCtx->registerMr(bufList[i], sizeList[i]);
localMrList.push_back(mr);
localMrInfo.emplace_back(mr->getInfo());
}
// Only ranks 0 and 1 are paired: rank 0 talks to rank 1, every other rank talks to rank 0.
int remoteRank = (gEnv->rank == 0) ? 1 : 0;
// Send the number of MRs and the MR info to the remote side
// (tag 0 carries the count, tag 1 carries the IbMrInfo array)
bootstrap->send(&numMrs, sizeof(numMrs), remoteRank, /*tag=*/0);
bootstrap->send(localMrInfo.data(), sizeof(mscclpp::IbMrInfo) * numMrs, remoteRank, /*tag=*/1);
// Receive the number of MRs and the MR info from the remote side
size_t numRemoteMrs;
bootstrap->recv(&numRemoteMrs, sizeof(numRemoteMrs), remoteRank, /*tag=*/0);
remoteMrInfoList.resize(numRemoteMrs);
bootstrap->recv(remoteMrInfoList.data(), sizeof(mscclpp::IbMrInfo) * numRemoteMrs, remoteRank, /*tag=*/1);
// Connect the QP using the peer's QP info stored in `remoteQpInfo`.
qp->rtr(remoteQpInfo);
qp->rts();
bootstrap->barrier();
}
void IbPeerToPeerTest::stageSend(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled) {
const mscclpp::IbMrInfo& remoteMrInfo = mrInfo[(gEnv->rank == 1) ? 0 : 1];
qp->stageSend(mr, remoteMrInfo, size, wrId, srcOffset, dstOffset, signaled);
qp->stageSend(localMrList[0], remoteMrInfoList[0], size, wrId, srcOffset, dstOffset, signaled);
}
void IbPeerToPeerTest::stageAtomicAdd(uint64_t wrId, uint64_t dstOffset, uint64_t addVal) {
const mscclpp::IbMrInfo& remoteMrInfo = mrInfo[(gEnv->rank == 1) ? 0 : 1];
qp->stageAtomicAdd(mr, remoteMrInfo, wrId, dstOffset, addVal);
qp->stageAtomicAdd(localMrList[0], remoteMrInfoList[0], wrId, dstOffset, addVal);
}
void IbPeerToPeerTest::stageSendWithImm(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset,
bool signaled, unsigned int immData) {
const mscclpp::IbMrInfo& remoteMrInfo = mrInfo[(gEnv->rank == 1) ? 0 : 1];
qp->stageSendWithImm(mr, remoteMrInfo, size, wrId, srcOffset, dstOffset, signaled, immData);
qp->stageSendWithImm(localMrList[0], remoteMrInfoList[0], size, wrId, srcOffset, dstOffset, signaled, immData);
}
TEST_F(IbPeerToPeerTest, SimpleSendRecv) {
@@ -85,7 +103,7 @@ TEST_F(IbPeerToPeerTest, SimpleSendRecv) {
const int nelem = 1;
auto data = mscclpp::allocUniqueCuda<int>(nelem);
registerBufferAndConnect(data.get(), sizeof(int) * nelem);
registerBuffersAndConnect({data.get()}, {sizeof(int) * nelem});
if (gEnv->rank == 1) {
mscclpp::Timer timer;
@@ -194,7 +212,7 @@ TEST_F(IbPeerToPeerTest, MemoryConsistency) {
const uint64_t nelem = 65536 + 1;
auto data = mscclpp::allocUniqueCuda<uint64_t>(nelem);
registerBufferAndConnect(data.get(), sizeof(uint64_t) * nelem);
registerBuffersAndConnect({data.get()}, {sizeof(uint64_t) * nelem});
uint64_t res = 0;
uint64_t iter = 0;
@@ -288,3 +306,106 @@ TEST_F(IbPeerToPeerTest, MemoryConsistency) {
EXPECT_EQ(res, 0);
}
// Two-rank test of IbQp::stageSendGather().
//
// Rank 0 registers `numDataSrcs` device buffers plus a host semaphore word, stages one gather
// request that takes every data MR as a source and targets rank 1's data MR, and then stages an
// atomic add of 1 on rank 1's semaphore MR. Rank 1 registers one receive buffer plus its own
// semaphore word, spins until the semaphore becomes non-zero, and then checks the received data.
TEST_F(IbPeerToPeerTest, SendGather) {
if (gEnv->rank >= 2) {
// This test needs only two ranks
return;
}
// NOTE(review): presumably a 3-second watchdog for the whole test - Timer semantics are not visible here; confirm.
mscclpp::Timer timeout(3);
// NOTE(review): presumably kept at 1 because SetUp() creates the QP with maxNumSgesPerWr = 1 - confirm before raising.
const int numDataSrcs = 1;
const int nelemPerMr = 1024;
// Gather send from rank 0 to 1
if (gEnv->rank == 0) {
std::vector<mscclpp::UniqueCudaPtr<int>> dataList;
for (int i = 0; i < numDataSrcs; ++i) {
auto data = mscclpp::allocUniqueCuda<int>(nelemPerMr);
// Fill in data for correctness check
// (source i is filled with the value i + 1, which is what the receiver expects)
std::vector<int> hostData(nelemPerMr, i + 1);
mscclpp::memcpyCuda<int>(data.get(), hostData.data(), nelemPerMr);
dataList.emplace_back(std::move(data));
}
std::vector<void*> dataRefList;
for (int i = 0; i < numDataSrcs; ++i) {
dataRefList.emplace_back(dataList[i].get());
}
// For sending a completion signal to the remote side
uint64_t outboundSema = 1;
dataRefList.push_back(&outboundSema);
std::vector<uint32_t> sizeList(numDataSrcs, sizeof(int) * nelemPerMr);
sizeList.push_back(sizeof(outboundSema));
registerBuffersAndConnect(dataRefList, sizeList);
// Index 0 / 1 match the order in which rank 1 registers its buffers below: data first, semaphore second.
auto& remoteDataMrInfo = remoteMrInfoList[0];
auto& remoteSemaMrInfo = remoteMrInfoList[1];
// The local semaphore MR was registered last, i.e. right after the numDataSrcs data MRs.
auto& localSemaMr = localMrList[numDataSrcs];
std::vector<const mscclpp::IbMr*> gatherLocalMrList;
for (int i = 0; i < numDataSrcs; ++i) {
gatherLocalMrList.emplace_back(localMrList[i]);
}
// Every source contributes its whole buffer, starting at offset 0.
std::vector<uint32_t> gatherSizeList(numDataSrcs, sizeof(int) * nelemPerMr);
std::vector<uint32_t> gatherOffsetList(numDataSrcs, 0);
qp->stageSendGather(gatherLocalMrList, remoteDataMrInfo, gatherSizeList, /*wrId=*/0, gatherOffsetList,
/*dstOffset=*/0, /*signaled=*/true);
qp->postSend();
// Posted after the gather request: bumps the remote semaphore that rank 1 is spinning on.
qp->stageAtomicAdd(localSemaMr, remoteSemaMrInfo, /*wrId=*/0, /*dstOffset=*/0, /*addVal=*/1);
qp->postSend();
// Wait for send completion
// (the loop ends as soon as one completion is seen; only that first completion's status is checked)
bool waiting = true;
int spin = 0;
while (waiting) {
int wcNum = qp->pollCq();
ASSERT_GE(wcNum, 0);
for (int i = 0; i < wcNum; ++i) {
const ibv_wc* wc = qp->getWc(i);
EXPECT_EQ(wc->status, IBV_WC_SUCCESS);
waiting = false;
break;
}
// NOTE(review): FAIL() returns from the test body, so the final barrier below is skipped on this rank - confirm the peer cannot hang there.
if (spin++ > 1000000) {
FAIL() << "Polling is stuck.";
}
}
} else {
// Data array to receive
auto data = mscclpp::allocUniqueCuda<int>(nelemPerMr * numDataSrcs);
// For receiving a completion signal from the remote side
uint64_t inboundSema = 0;
registerBuffersAndConnect({data.get(), &inboundSema},
{sizeof(int) * nelemPerMr * numDataSrcs, sizeof(inboundSema)});
// Wait for a signal from the remote side
// (read through a volatile pointer so that the value is re-read on every iteration)
volatile uint64_t* ptrInboundSema = &inboundSema;
int spin = 0;
while (*ptrInboundSema == 0) {
if (spin++ > 1000000) {
FAIL() << "Polling is stuck.";
}
}
// Correctness check
// (segment i of the receive buffer must hold the fill value of source i, i.e. i + 1)
std::vector<int> hostData(nelemPerMr * numDataSrcs);
mscclpp::memcpyCuda<int>(hostData.data(), data.get(), nelemPerMr * numDataSrcs);
for (int i = 0; i < numDataSrcs; ++i) {
for (int j = 0; j < nelemPerMr; ++j) {
EXPECT_EQ(hostData[i * nelemPerMr + j], i + 1);
}
}
}
bootstrap->barrier();
}

View File

@@ -67,7 +67,7 @@ class IbPeerToPeerTest : public IbTestBase {
protected:
void SetUp() override;
void registerBufferAndConnect(void* buf, size_t size);
void registerBuffersAndConnect(const std::vector<void*>& bufList, const std::vector<uint32_t>& sizeList);
void stageSend(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled);
@@ -76,14 +76,16 @@ class IbPeerToPeerTest : public IbTestBase {
void stageSendWithImm(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled,
unsigned int immData);
void stageSendGather(const std::vector<uint32_t>& sizeList, uint64_t wrId, const std::vector<uint32_t>& srcOffsetList,
uint32_t dstOffset, bool signaled);
std::shared_ptr<mscclpp::TcpBootstrap> bootstrap;
std::shared_ptr<mscclpp::IbCtx> ibCtx;
mscclpp::IbQp* qp;
const mscclpp::IbMr* mr;
size_t bufSize;
std::vector<const mscclpp::IbMr*> localMrList;
std::array<mscclpp::IbQpInfo, 2> qpInfo;
std::array<mscclpp::IbMrInfo, 2> mrInfo;
mscclpp::IbQpInfo remoteQpInfo;
std::vector<mscclpp::IbMrInfo> remoteMrInfoList;
};
class CommunicatorTestBase : public MultiProcessTest {