From 8232ec731ffa1ea814a428017ea729e7643b6093 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Wed, 6 Sep 2023 12:15:12 +0000 Subject: [PATCH] Working --- include/mscclpp/core.hpp | 2 +- python/mscclpp/core_py.cpp | 2 +- src/communicator.cc | 2 +- src/ib.cc | 104 +++++++++++---------- src/include/ib.hpp | 27 +++--- test/mp_unit/ib_tests.cu | 163 ++++++++++++++++++++++++++++----- test/mp_unit/mp_unit_tests.hpp | 12 ++- 7 files changed, 224 insertions(+), 88 deletions(-) diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index 67ed523a..76332c60 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -543,7 +543,7 @@ class Communicator { /// @return std::shared_ptr A shared pointer to the connection. std::shared_ptr connectOnSetup(int remoteRank, int tag, Transport transport, int ibMaxCqSize = 1024, int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64, - int ibMaxNumSgesPerWr = 16); + int ibMaxNumSgesPerWr = 1); /// Add a custom Setuppable object to a list of objects to be setup later, when @ref setup() is called. /// diff --git a/python/mscclpp/core_py.cpp b/python/mscclpp/core_py.cpp index a65a443a..ce0fc606 100644 --- a/python/mscclpp/core_py.cpp +++ b/python/mscclpp/core_py.cpp @@ -141,7 +141,7 @@ void register_core(nb::module_& m) { .def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag")) .def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("transport"), nb::arg("ibMaxCqSize") = 1024, nb::arg("ibMaxCqPollNum") = 1, - nb::arg("ibMaxSendWr") = 8192, nb::arg("ibMaxWrPerSend") = 64) + nb::arg("ibMaxSendWr") = 8192, nb::arg("ibMaxWrPerSend") = 64, nb::arg("ibMaxNumSgesPerWr") = 1) .def("setup", &Communicator::setup); } diff --git a/src/communicator.cc b/src/communicator.cc index b87388b7..891e9a91 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -96,7 +96,7 @@ MSCCLPP_API_CPP NonblockingFuture Communicator::recvMemoryOnSe MSCCLPP_API_CPP std::shared_ptr Communicator::connectOnSetup( int remoteRank, int tag, Transport transport, int ibMaxCqSize /*=1024*/, int ibMaxCqPollNum /*=1*/, - int ibMaxSendWr /*=8192*/, int ibMaxWrPerSend /*=64*/, int ibMaxNumSgesPerWr /*=16*/) { + int ibMaxSendWr /*=8192*/, int ibMaxWrPerSend /*=64*/, int ibMaxNumSgesPerWr /*=1*/) { std::shared_ptr conn; if (transport == Transport::CudaIpc) { // sanity check: make sure the IPC connection is being made within a node diff --git a/src/ib.cc b/src/ib.cc index 63eb1040..50fddc96 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -16,16 +16,6 @@ #include "api.h" #include "debug.h" -static ibv_device_attr getDeviceAttr(ibv_context* ctx) { - ibv_device_attr devAttr; - if (ibv_query_device(ctx, &devAttr) != 0) { - std::stringstream err; - err << "ibv_query_device failed (errno " << errno << ")"; - throw mscclpp::IbError(err.str(), errno); - } - return devAttr; -} - static ibv_qp_attr createQpAttr() { ibv_qp_attr qpAttr; std::memset(&qpAttr, 0, sizeof(qpAttr)); @@ -34,17 +24,13 @@ static ibv_qp_attr createQpAttr() { namespace mscclpp { -IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) { - if (size == 0) { - throw std::invalid_argument("invalid size: " + std::to_string(size)); - } +IbMr::IbMr(ibv_pd* pd, void* buff, size_t alignedSize) : buff(buff) { static __thread uintptr_t pageSize = 0; if (pageSize == 0) { pageSize = sysconf(_SC_PAGESIZE); } uintptr_t addr = reinterpret_cast(buff) & -pageSize; - std::size_t pages = (size + (reinterpret_cast(buff) - addr) + pageSize - 1) / pageSize; - this->mr = ibv_reg_mr(pd, reinterpret_cast(addr), pages * pageSize, + this->mr = ibv_reg_mr(pd, reinterpret_cast(addr), alignedSize, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_RELAXED_ORDERING | IBV_ACCESS_REMOTE_ATOMIC); if (this->mr == nullptr) { @@ -52,7 +38,7 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) { err << "ibv_reg_mr failed (errno " << errno << ")"; throw mscclpp::IbError(err.str(), errno); } - this->size = pages * pageSize; + this->size = alignedSize; } IbMr::~IbMr() { ibv_dereg_mr(this->mr); } @@ -220,8 +206,8 @@ IbQp::WrInfo IbQp::getNewWrInfo(int numSges) { return IbQp::WrInfo{wr_, sge_}; } -void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, - uint64_t dstOffset, bool signaled) { +void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint32_t srcOffset, + uint32_t dstOffset, bool signaled) { auto wrInfo = this->getNewWrInfo(1); wrInfo.wr->wr_id = wrId; wrInfo.wr->opcode = IBV_WR_RDMA_WRITE; @@ -233,7 +219,7 @@ void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64 wrInfo.sge->lkey = mr->getLkey(); } -void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint64_t dstOffset, uint64_t addVal) { +void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint32_t dstOffset, uint64_t addVal) { auto wrInfo = this->getNewWrInfo(1); wrInfo.wr->wr_id = wrId; wrInfo.wr->opcode = IBV_WR_ATOMIC_FETCH_AND_ADD; @@ -246,8 +232,8 @@ void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, u wrInfo.sge->lkey = mr->getLkey(); } -void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, - uint64_t dstOffset, bool signaled, unsigned int immData) { +void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint32_t srcOffset, + uint32_t dstOffset, bool signaled, unsigned int immData) { auto wrInfo = this->getNewWrInfo(1); wrInfo.wr->wr_id = wrId; wrInfo.wr->opcode = IBV_WR_RDMA_WRITE_WITH_IMM; @@ -260,26 +246,26 @@ void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, wrInfo.sge->lkey = mr->getLkey(); } -void IbQp::stageSendGather(const std::vector& srcMrs, const IbMrInfo& dstInfo, - const std::vector& srcSizes, uint64_t wrId, - const std::vector& srcOffsets, uint64_t dstOffset, bool signaled) { - size_t numSrcs = srcMrs.size(); - if (numSrcs != srcSizes.size() || numSrcs != srcOffsets.size()) { +void IbQp::stageSendGather(const std::vector& srcMrList, const IbMrInfo& dstMrInfo, + const std::vector& srcSizeList, uint64_t wrId, + const std::vector& srcOffsetList, uint32_t dstOffset, bool signaled) { + size_t numSrcs = srcMrList.size(); + if (numSrcs != srcSizeList.size() || numSrcs != srcOffsetList.size()) { std::stringstream err; - err << "invalid srcs: srcMrs.size()=" << numSrcs << ", srcSizes.size()=" << srcSizes.size() - << ", srcOffsets.size()=" << srcOffsets.size(); + err << "invalid srcs: srcMrList.size()=" << numSrcs << ", srcSizeList.size()=" << srcSizeList.size() + << ", srcOffsetList.size()=" << srcOffsetList.size(); throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage); } auto wrInfo = this->getNewWrInfo(numSrcs); wrInfo.wr->wr_id = wrId; - wrInfo.wr->opcode = IBV_WR_RDMA_READ; + wrInfo.wr->opcode = IBV_WR_RDMA_WRITE; wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0; - wrInfo.wr->wr.rdma.remote_addr = (uint64_t)(dstInfo.addr) + dstOffset; - wrInfo.wr->wr.rdma.rkey = dstInfo.rkey; + wrInfo.wr->wr.rdma.remote_addr = (uint64_t)(dstMrInfo.addr) + dstOffset; + wrInfo.wr->wr.rdma.rkey = dstMrInfo.rkey; for (size_t i = 0; i < numSrcs; ++i) { - wrInfo.sge[i].addr = (uint64_t)(srcMrs[i]->getBuff()) + srcOffsets[i]; - wrInfo.sge[i].length = srcSizes[i]; - wrInfo.sge[i].lkey = srcMrs[i]->getLkey(); + wrInfo.sge[i].addr = (uint64_t)(srcMrList[i]->getBuff()) + srcOffsetList[i]; + wrInfo.sge[i].length = srcSizeList[i]; + wrInfo.sge[i].lkey = srcMrList[i]->getLkey(); } } @@ -339,6 +325,13 @@ IbCtx::IbCtx(const std::string& devName) : devName(devName) { err << "ibv_alloc_pd failed (errno " << errno << ")"; throw mscclpp::IbError(err.str(), errno); } + // TODO: do not use new + this->devAttr = new ibv_device_attr; + if (ibv_query_device(this->ctx, this->devAttr) != 0) { + std::stringstream err; + err << "ibv_query_device failed (errno " << errno << ")"; + throw mscclpp::IbError(err.str(), errno); + } } IbCtx::~IbCtx() { @@ -350,6 +343,8 @@ IbCtx::~IbCtx() { if (this->ctx != nullptr) { ibv_close_device(this->ctx); } + // TODO: do not use delete + delete this->devAttr; } bool IbCtx::isPortUsable(int port) const { @@ -364,8 +359,7 @@ bool IbCtx::isPortUsable(int port) const { } int IbCtx::getAnyActivePort() const { - ibv_device_attr devAttr = getDeviceAttr(this->ctx); - for (uint8_t port = 1; port <= devAttr.phys_port_cnt; ++port) { + for (uint8_t port = 1; port <= this->devAttr->phys_port_cnt; ++port) { if (this->isPortUsable(port)) { return port; } @@ -373,28 +367,27 @@ int IbCtx::getAnyActivePort() const { return -1; } -void IbCtx::validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, - int maxNumSgesPerWr, int port) const { +void IbCtx::validateQpConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, + int maxNumSgesPerWr, int port) const { if (!this->isPortUsable(port)) { throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InvalidUsage); } - ibv_device_attr devAttr = getDeviceAttr(this->ctx); - if (maxCqSize > devAttr.max_cqe || maxCqSize < 1) { + if (maxCqSize > this->devAttr->max_cqe || maxCqSize < 1) { throw mscclpp::Error("invalid maxCqSize: " + std::to_string(maxCqSize), ErrorCode::InvalidUsage); } if (maxCqPollNum > maxCqSize || maxCqPollNum < 1) { throw mscclpp::Error("invalid maxCqPollNum: " + std::to_string(maxCqPollNum), ErrorCode::InvalidUsage); } - if (maxSendWr > devAttr.max_qp_wr || maxSendWr < 1) { + if (maxSendWr > this->devAttr->max_qp_wr) { throw mscclpp::Error("invalid maxSendWr: " + std::to_string(maxSendWr), ErrorCode::InvalidUsage); } - if (maxRecvWr > devAttr.max_qp_wr || maxRecvWr < 1) { + if (maxRecvWr > this->devAttr->max_qp_wr) { throw mscclpp::Error("invalid maxRecvWr: " + std::to_string(maxRecvWr), ErrorCode::InvalidUsage); } if (maxWrPerSend > maxSendWr || maxWrPerSend < 1) { throw mscclpp::Error("invalid maxWrPerSend: " + std::to_string(maxWrPerSend), ErrorCode::InvalidUsage); } - if (maxNumSgesPerWr > devAttr.max_sge || maxNumSgesPerWr < 1) { + if (maxNumSgesPerWr > this->devAttr->max_sge || maxNumSgesPerWr < 1) { throw mscclpp::Error("invalid maxNumSgesPerWr: " + std::to_string(maxNumSgesPerWr), ErrorCode::InvalidUsage); } } @@ -407,14 +400,31 @@ IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRec throw mscclpp::Error("No active port found", ErrorCode::InternalError); } } - this->validateConfig(maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr, port); + this->validateQpConfig(maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr, port); qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr)); return qps.back().get(); } -const IbMr* IbCtx::registerMr(void* buff, std::size_t size) { - mrs.emplace_back(new IbMr(this->pd, buff, size)); +const IbMr* IbCtx::registerMr(void* buff, uint32_t size) { + if (size == 0) { + throw mscclpp::Error("invalid size: " + std::to_string(size), ErrorCode::InvalidUsage); + } + static __thread uintptr_t pageSize = 0; + if (pageSize == 0) { + pageSize = sysconf(_SC_PAGESIZE); + } + uintptr_t addr = reinterpret_cast(buff) & -pageSize; + size_t pages = (size + (reinterpret_cast(buff) - addr) + pageSize - 1) / pageSize; + + size_t alignedSize = pages * pageSize; + if (alignedSize > this->devAttr->max_mr_size) { + std::stringstream err; + err << "invalid MR size: " << alignedSize << " max " << this->devAttr->max_mr_size; + throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage); + } + + mrs.emplace_back(new IbMr(this->pd, buff, alignedSize)); return mrs.back().get(); } diff --git a/src/include/ib.hpp b/src/include/ib.hpp index cb909111..440c208c 100644 --- a/src/include/ib.hpp +++ b/src/include/ib.hpp @@ -11,6 +11,7 @@ // Forward declarations of IB structures struct ibv_context; +struct ibv_device_attr; struct ibv_pd; struct ibv_mr; struct ibv_qp; @@ -35,11 +36,11 @@ class IbMr { uint32_t getLkey() const; private: - IbMr(ibv_pd* pd, void* buff, std::size_t size); + IbMr(ibv_pd* pd, void* buff, size_t alignedSize); ibv_mr* mr; void* buff; - std::size_t size; + size_t size; friend class IbCtx; }; @@ -62,13 +63,14 @@ class IbQp { void rtr(const IbQpInfo& info); void rts(); - void stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, - uint64_t dstOffset, bool signaled); - void stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint64_t dstOffset, uint64_t addVal); - void stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, - uint64_t dstOffset, bool signaled, unsigned int immData); - void stageSendGather(const std::vector& srcMrs, const IbMrInfo& dstInfo, const std::vector& srcSizes, - uint64_t wrId, const std::vector& srcOffsets, uint64_t dstOffset, bool signaled); + void stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint32_t srcOffset, + uint32_t dstOffset, bool signaled); + void stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint32_t dstOffset, uint64_t addVal); + void stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint32_t srcOffset, + uint32_t dstOffset, bool signaled, unsigned int immData); + void stageSendGather(const std::vector& srcMrList, const IbMrInfo& dstMrInfo, + const std::vector& srcSizeList, uint64_t wrId, + const std::vector& srcOffsetList, uint32_t dstOffset, bool signaled); void postSend(); void postRecv(uint64_t wrId); int pollCq(); @@ -110,19 +112,20 @@ class IbCtx { IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, int port = -1); - const IbMr* registerMr(void* buff, std::size_t size); + const IbMr* registerMr(void* buff, uint32_t size); const std::string& getDevName() const; private: bool isPortUsable(int port) const; int getAnyActivePort() const; - void validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, - int maxNumSgesPerWr, int port) const; + void validateQpConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, + int maxNumSgesPerWr, int port) const; const std::string devName; ibv_context* ctx; ibv_pd* pd; + ibv_device_attr* devAttr; std::list> qps; std::list> mrs; }; diff --git a/test/mp_unit/ib_tests.cu b/test/mp_unit/ib_tests.cu index a38f992e..2a7d3841 100644 --- a/test/mp_unit/ib_tests.cu +++ b/test/mp_unit/ib_tests.cu @@ -38,39 +38,57 @@ void IbPeerToPeerTest::SetUp() { ibCtx = std::make_shared(ibDevName); qp = ibCtx->createQp(1024, 1, 8192, 0, 64, 1); - qpInfo[gEnv->rank] = qp->getInfo(); - bootstrap->allGather(qpInfo.data(), sizeof(mscclpp::IbQpInfo)); + int remoteRank = (gEnv->rank == 0) ? 1 : 0; + + mscclpp::IbQpInfo localQpInfo = qp->getInfo(); + bootstrap->send(&localQpInfo, sizeof(mscclpp::IbQpInfo), remoteRank, /*tag=*/0); + bootstrap->recv(&remoteQpInfo, sizeof(mscclpp::IbQpInfo), remoteRank, /*tag=*/0); } -void IbPeerToPeerTest::registerBufferAndConnect(void* buf, size_t size) { - bufSize = size; - mr = ibCtx->registerMr(buf, size); - mrInfo[gEnv->rank] = mr->getInfo(); - bootstrap->allGather(mrInfo.data(), sizeof(mscclpp::IbMrInfo)); - - for (int i = 0; i < bootstrap->getNranks(); ++i) { - if (i == gEnv->rank) continue; - qp->rtr(qpInfo[i]); - qp->rts(); - break; +void IbPeerToPeerTest::registerBuffersAndConnect(const std::vector& bufList, + const std::vector& sizeList) { + size_t numMrs = bufList.size(); + if (numMrs != sizeList.size()) { + throw std::runtime_error("bufList.size() != sizeList.size()"); } + + // Assume the remote side registers the same number of MRs + std::vector localMrInfo; + for (size_t i = 0; i < numMrs; ++i) { + const mscclpp::IbMr* mr = ibCtx->registerMr(bufList[i], sizeList[i]); + localMrList.push_back(mr); + localMrInfo.emplace_back(mr->getInfo()); + } + + int remoteRank = (gEnv->rank == 0) ? 1 : 0; + + // Send the number of MRs and the MR info to the remote side + bootstrap->send(&numMrs, sizeof(numMrs), remoteRank, /*tag=*/0); + bootstrap->send(localMrInfo.data(), sizeof(mscclpp::IbMrInfo) * numMrs, remoteRank, /*tag=*/1); + + // Receive the number of MRs and the MR info from the remote side + size_t numRemoteMrs; + bootstrap->recv(&numRemoteMrs, sizeof(numRemoteMrs), remoteRank, /*tag=*/0); + remoteMrInfoList.resize(numRemoteMrs); + bootstrap->recv(remoteMrInfoList.data(), sizeof(mscclpp::IbMrInfo) * numRemoteMrs, remoteRank, /*tag=*/1); + + qp->rtr(remoteQpInfo); + qp->rts(); + bootstrap->barrier(); } void IbPeerToPeerTest::stageSend(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled) { - const mscclpp::IbMrInfo& remoteMrInfo = mrInfo[(gEnv->rank == 1) ? 0 : 1]; - qp->stageSend(mr, remoteMrInfo, size, wrId, srcOffset, dstOffset, signaled); + qp->stageSend(localMrList[0], remoteMrInfoList[0], size, wrId, srcOffset, dstOffset, signaled); } void IbPeerToPeerTest::stageAtomicAdd(uint64_t wrId, uint64_t dstOffset, uint64_t addVal) { - const mscclpp::IbMrInfo& remoteMrInfo = mrInfo[(gEnv->rank == 1) ? 0 : 1]; - qp->stageAtomicAdd(mr, remoteMrInfo, wrId, dstOffset, addVal); + qp->stageAtomicAdd(localMrList[0], remoteMrInfoList[0], wrId, dstOffset, addVal); } void IbPeerToPeerTest::stageSendWithImm(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData) { - const mscclpp::IbMrInfo& remoteMrInfo = mrInfo[(gEnv->rank == 1) ? 0 : 1]; - qp->stageSendWithImm(mr, remoteMrInfo, size, wrId, srcOffset, dstOffset, signaled, immData); + qp->stageSendWithImm(localMrList[0], remoteMrInfoList[0], size, wrId, srcOffset, dstOffset, signaled, immData); } TEST_F(IbPeerToPeerTest, SimpleSendRecv) { @@ -85,7 +103,7 @@ TEST_F(IbPeerToPeerTest, SimpleSendRecv) { const int nelem = 1; auto data = mscclpp::allocUniqueCuda(nelem); - registerBufferAndConnect(data.get(), sizeof(int) * nelem); + registerBuffersAndConnect({data.get()}, {sizeof(int) * nelem}); if (gEnv->rank == 1) { mscclpp::Timer timer; @@ -194,7 +212,7 @@ TEST_F(IbPeerToPeerTest, MemoryConsistency) { const uint64_t nelem = 65536 + 1; auto data = mscclpp::allocUniqueCuda(nelem); - registerBufferAndConnect(data.get(), sizeof(uint64_t) * nelem); + registerBuffersAndConnect({data.get()}, {sizeof(uint64_t) * nelem}); uint64_t res = 0; uint64_t iter = 0; @@ -288,3 +306,106 @@ TEST_F(IbPeerToPeerTest, MemoryConsistency) { EXPECT_EQ(res, 0); } + +TEST_F(IbPeerToPeerTest, SendGather) { + if (gEnv->rank >= 2) { + // This test needs only two ranks + return; + } + + mscclpp::Timer timeout(3); + + const int numDataSrcs = 1; + const int nelemPerMr = 1024; + + // Gather send from rank 0 to 1 + if (gEnv->rank == 0) { + std::vector> dataList; + for (int i = 0; i < numDataSrcs; ++i) { + auto data = mscclpp::allocUniqueCuda(nelemPerMr); + // Fill in data for correctness check + std::vector hostData(nelemPerMr, i + 1); + mscclpp::memcpyCuda(data.get(), hostData.data(), nelemPerMr); + dataList.emplace_back(std::move(data)); + } + + std::vector dataRefList; + for (int i = 0; i < numDataSrcs; ++i) { + dataRefList.emplace_back(dataList[i].get()); + } + + // For sending a completion signal to the remote side + uint64_t outboundSema = 1; + + dataRefList.push_back(&outboundSema); + + std::vector sizeList(numDataSrcs, sizeof(int) * nelemPerMr); + sizeList.push_back(sizeof(outboundSema)); + + registerBuffersAndConnect(dataRefList, sizeList); + + auto& remoteDataMrInfo = remoteMrInfoList[0]; + auto& remoteSemaMrInfo = remoteMrInfoList[1]; + auto& localSemaMr = localMrList[numDataSrcs]; + + std::vector gatherLocalMrList; + for (int i = 0; i < numDataSrcs; ++i) { + gatherLocalMrList.emplace_back(localMrList[i]); + } + std::vector gatherSizeList(numDataSrcs, sizeof(int) * nelemPerMr); + std::vector gatherOffsetList(numDataSrcs, 0); + + qp->stageSendGather(gatherLocalMrList, remoteDataMrInfo, gatherSizeList, /*wrId=*/0, gatherOffsetList, + /*dstOffset=*/0, /*signaled=*/true); + qp->postSend(); + + qp->stageAtomicAdd(localSemaMr, remoteSemaMrInfo, /*wrId=*/0, /*dstOffset=*/0, /*addVal=*/1); + qp->postSend(); + + // Wait for send completion + bool waiting = true; + int spin = 0; + while (waiting) { + int wcNum = qp->pollCq(); + ASSERT_GE(wcNum, 0); + for (int i = 0; i < wcNum; ++i) { + const ibv_wc* wc = qp->getWc(i); + EXPECT_EQ(wc->status, IBV_WC_SUCCESS); + waiting = false; + break; + } + if (spin++ > 1000000) { + FAIL() << "Polling is stuck."; + } + } + } else { + // Data array to receive + auto data = mscclpp::allocUniqueCuda(nelemPerMr * numDataSrcs); + + // For receiving a completion signal from the remote side + uint64_t inboundSema = 0; + + registerBuffersAndConnect({data.get(), &inboundSema}, + {sizeof(int) * nelemPerMr * numDataSrcs, sizeof(inboundSema)}); + + // Wait for a signal from the remote side + volatile uint64_t* ptrInboundSema = &inboundSema; + int spin = 0; + while (*ptrInboundSema == 0) { + if (spin++ > 1000000) { + FAIL() << "Polling is stuck."; + } + } + + // Correctness check + std::vector hostData(nelemPerMr * numDataSrcs); + mscclpp::memcpyCuda(hostData.data(), data.get(), nelemPerMr * numDataSrcs); + for (int i = 0; i < numDataSrcs; ++i) { + for (int j = 0; j < nelemPerMr; ++j) { + EXPECT_EQ(hostData[i * nelemPerMr + j], i + 1); + } + } + } + + bootstrap->barrier(); +} diff --git a/test/mp_unit/mp_unit_tests.hpp b/test/mp_unit/mp_unit_tests.hpp index 39325563..0ec62e79 100644 --- a/test/mp_unit/mp_unit_tests.hpp +++ b/test/mp_unit/mp_unit_tests.hpp @@ -67,7 +67,7 @@ class IbPeerToPeerTest : public IbTestBase { protected: void SetUp() override; - void registerBufferAndConnect(void* buf, size_t size); + void registerBuffersAndConnect(const std::vector& bufList, const std::vector& sizeList); void stageSend(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled); @@ -76,14 +76,16 @@ class IbPeerToPeerTest : public IbTestBase { void stageSendWithImm(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData); + void stageSendGather(const std::vector& sizeList, uint64_t wrId, const std::vector& srcOffsetList, + uint32_t dstOffset, bool signaled); + std::shared_ptr bootstrap; std::shared_ptr ibCtx; mscclpp::IbQp* qp; - const mscclpp::IbMr* mr; - size_t bufSize; + std::vector localMrList; - std::array qpInfo; - std::array mrInfo; + mscclpp::IbQpInfo remoteQpInfo; + std::vector remoteMrInfoList; }; class CommunicatorTestBase : public MultiProcessTest {