From 08e80f1754527fe9f72026d032c1b08301587a8d Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Thu, 27 Apr 2023 04:01:46 +0000 Subject: [PATCH] IB: completely replaced with C++ interfaces --- src/communicator.cc | 87 +---- src/connection.cc | 34 +- src/ib.cc | 612 +++++++++++++----------------- src/include/comm.h | 7 +- src/include/communicator.hpp | 6 +- src/include/connection.hpp | 4 +- src/include/ib.h | 69 ---- src/include/ib.hpp | 53 ++- src/include/mscclpp.h | 17 +- src/include/proxy.h | 2 +- src/include/registered_memory.hpp | 6 +- src/init.cc | 79 ++-- src/proxy.cc | 2 +- src/registered_memory.cc | 5 +- tests/unittests/ib_test.cc | 64 ++-- 15 files changed, 409 insertions(+), 638 deletions(-) delete mode 100644 src/include/ib.h diff --git a/src/communicator.cc b/src/communicator.cc index c34dbb31..6c501d70 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -16,7 +16,7 @@ Communicator::Impl::Impl() : comm(nullptr) {} Communicator::Impl::~Impl() { for (auto& entry : ibContexts) { - mscclppIbContextDestroy(entry.second); + delete entry.second; } ibContexts.clear(); if (comm) { @@ -24,13 +24,12 @@ Communicator::Impl::~Impl() { } } -mscclppIbContext* Communicator::Impl::getIbContext(TransportFlags ibTransport) { +IbCtx* Communicator::Impl::getIbContext(TransportFlags ibTransport) { // Find IB context or create it auto it = ibContexts.find(ibTransport); if (it == ibContexts.end()) { auto ibDev = getIBDeviceName(ibTransport); - mscclppIbContext* ibCtx; - MSCCLPPTHROW(mscclppIbContextCreate(&ibCtx, ibDev.c_str())); + IbCtx* ibCtx = new IbCtx(ibDev); ibContexts[ibTransport] = ibCtx; return ibCtx; } else { @@ -92,6 +91,7 @@ MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank throw std::runtime_error("Unsupported transport"); } pimpl->connections.push_back(conn); + return conn; } MSCCLPP_API_CPP void Communicator::connectionSetup() { @@ -115,81 +115,4 @@ MSCCLPP_API_CPP int Communicator::size() { return result; } -// TODO: move these elsewhere - -int getIBDeviceCount() { - int num; - ibv_get_device_list(&num); - return num; -} - -std::string getIBDeviceName(TransportFlags ibTransport) { - int num; - struct ibv_device** devices = ibv_get_device_list(&num); - int ibTransportIndex; - switch (ibTransport) { // TODO: get rid of this ugly switch - case TransportIB0: - ibTransportIndex = 0; - break; - case TransportIB1: - ibTransportIndex = 1; - break; - case TransportIB2: - ibTransportIndex = 2; - break; - case TransportIB3: - ibTransportIndex = 3; - break; - case TransportIB4: - ibTransportIndex = 4; - break; - case TransportIB5: - ibTransportIndex = 5; - break; - case TransportIB6: - ibTransportIndex = 6; - break; - case TransportIB7: - ibTransportIndex = 7; - break; - default: - throw std::runtime_error("Not an IB transport"); - } - if (ibTransportIndex >= num) { - throw std::runtime_error("IB transport out of range"); - } - return devices[ibTransportIndex]->name; -} - -TransportFlags getIBTransportByDeviceName(const std::string& ibDeviceName) { - int num; - struct ibv_device** devices = ibv_get_device_list(&num); - for (int i = 0; i < num; ++i) { - if (ibDeviceName == devices[i]->name) { - switch (i) { // TODO: get rid of this ugly switch - case 0: - return TransportIB0; - case 1: - return TransportIB1; - case 2: - return TransportIB2; - case 3: - return TransportIB3; - case 4: - return TransportIB4; - case 5: - return TransportIB5; - case 6: - return TransportIB6; - case 7: - return TransportIB7; - default: - throw std::runtime_error("IB device index out of range"); - } - } - } - throw std::runtime_error("IB device not found"); -} - - -} // namespace mscclpp \ No newline at end of file +} // namespace mscclpp diff --git a/src/connection.cc b/src/connection.cc index 8d1b5e11..1e21694c 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -2,6 +2,7 @@ #include "checks.hpp" #include "registered_memory.hpp" #include "npkit/npkit.h" +#include "infiniband/verbs.h" namespace mscclpp { @@ -54,7 +55,7 @@ void CudaIpcConnection::flush() { // IBConnection IBConnection::IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl) : remoteRank(remoteRank), tag(tag), transport_(transport), remoteTransport_(TransportNone) { - MSCCLPPTHROW(mscclppIbContextCreateQp(commImpl.getIbContext(transport), &qp)); + qp = commImpl.getIbContext(transport)->createQp(); } IBConnection::~IBConnection() { @@ -85,13 +86,8 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem auto dstMrInfo = dstTransportInfo.ibMrInfo; auto srcMr = srcTransportInfo.ibMr; - qp->stageSend(srcMr, &dstMrInfo, (uint32_t)size, - /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/false); - int ret = qp->postSend(); - if (ret != 0) { - // Return value is errno. - WARN("data postSend failed: errno %d", ret); - } + qp->stageSend(srcMr, dstMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/false); + qp->postSend(); // npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)size); } @@ -104,15 +100,11 @@ void IBConnection::flush() { continue; } for (int i = 0; i < wcNum; ++i) { - struct ibv_wc* wc = &qp->wcs[i]; + const struct ibv_wc* wc = reinterpret_cast(qp->getWc(i)); if (wc->status != IBV_WC_SUCCESS) { WARN("wc status %d", wc->status); continue; } - if (wc->qp_num != qp->qp->qp_num) { - WARN("got wc of unknown qp_num %d", wc->qp_num); - continue; - } if (wc->opcode == IBV_WC_RDMA_WRITE) { isWaiting = false; break; @@ -123,18 +115,16 @@ void IBConnection::flush() { } void IBConnection::startSetup(Communicator& comm) { - comm.bootstrap().send(&qp->info, sizeof(qp->info), remoteRank, tag); + // TODO(chhwang): temporarily disabled to compile + // comm.bootstrap().send(&qp->getInfo(), sizeof(qp->getInfo()), remoteRank, tag); } void IBConnection::endSetup(Communicator& comm) { - mscclppIbQpInfo qpInfo; - comm.bootstrap().recv(&qpInfo, sizeof(qpInfo), remoteRank, tag); - if (qp->rtr(&qpInfo) != 0) { - throw std::runtime_error("Failed to transition QP to RTR"); - } - if (qp->rts() != 0) { - throw std::runtime_error("Failed to transition QP to RTS"); - } + IbQpInfo qpInfo; + // TODO(chhwang): temporarily disabled to compile + // comm.bootstrap().recv(&qpInfo, sizeof(qpInfo), remoteRank, tag); + qp->rtr(qpInfo); + qp->rts(); } } // namespace mscclpp diff --git a/src/ib.cc b/src/ib.cc index 4a094761..4dc0285b 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -4,370 +4,67 @@ #include #include #include -#include +#include "mscclpp.hpp" #include "alloc.h" #include "comm.h" #include "debug.h" -#include "ib.h" #include "ib.hpp" #include "checks.hpp" +#include +#include -mscclppResult_t mscclppIbContextCreate(struct mscclppIbContext** ctx, const char* ibDevName) -{ - struct mscclppIbContext* _ctx; - MSCCLPPCHECK(mscclppCalloc(&_ctx, 1)); +namespace mscclpp { - std::vector ports; - - int num; - struct ibv_device** devices = ibv_get_device_list(&num); - for (int i = 0; i < num; ++i) { - if (strncmp(devices[i]->name, ibDevName, IBV_SYSFS_NAME_MAX) == 0) { - _ctx->ctx = ibv_open_device(devices[i]); - break; - } - } - ibv_free_device_list(devices); - if (_ctx->ctx == nullptr) { - WARN("ibv_open_device failed (errno %d, device name %s)", errno, ibDevName); - goto fail; - } - - // Check available ports - struct ibv_device_attr devAttr; - if (ibv_query_device(_ctx->ctx, &devAttr) != 0) { - WARN("ibv_query_device failed (errno %d, device name %s)", errno, ibDevName); - goto fail; - } - - for (uint8_t i = 1; i <= devAttr.phys_port_cnt; ++i) { - struct ibv_port_attr portAttr; - if (ibv_query_port(_ctx->ctx, i, &portAttr) != 0) { - WARN("ibv_query_port failed (errno %d, port %d)", errno, i); - goto fail; - } - if (portAttr.state != IBV_PORT_ACTIVE) { - continue; - } - if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND && portAttr.link_layer != IBV_LINK_LAYER_ETHERNET) { - continue; - } - ports.push_back((int)i); - } - if (ports.size() == 0) { - WARN("no active IB port found"); - goto fail; - } - MSCCLPPCHECK(mscclppCalloc(&_ctx->ports, ports.size())); - _ctx->nPorts = (int)ports.size(); - for (int i = 0; i < _ctx->nPorts; ++i) { - _ctx->ports[i] = ports[i]; - } - - _ctx->pd = ibv_alloc_pd(_ctx->ctx); - if (_ctx->pd == NULL) { - WARN("ibv_alloc_pd failed (errno %d)", errno); - goto fail; - } - - *ctx = _ctx; - return mscclppSuccess; -fail: - *ctx = NULL; - if (_ctx->ports != NULL) { - free(_ctx->ports); - } - free(_ctx); - return mscclppInternalError; -} - -mscclppResult_t mscclppIbContextDestroy(struct mscclppIbContext* ctx) -{ - for (int i = 0; i < ctx->nMrs; ++i) { - if (ctx->mrs[i].mr) { - ibv_dereg_mr(ctx->mrs[i].mr); - } - } - for (int i = 0; i < ctx->nQps; ++i) { - if (ctx->qps[i].qp) { - ibv_destroy_qp(ctx->qps[i].qp); - } - ibv_destroy_cq(ctx->qps[i].cq); - free(ctx->qps[i].wcs); - free(ctx->qps[i].sges); - free(ctx->qps[i].wrs); - } - if (ctx->pd != NULL) { - ibv_dealloc_pd(ctx->pd); - } - if (ctx->ctx != NULL) { - ibv_close_device(ctx->ctx); - } - free(ctx->mrs); - free(ctx->qps); - free(ctx->ports); - free(ctx); - return mscclppSuccess; -} - -mscclppResult_t mscclppIbContextCreateQp(struct mscclppIbContext* ctx, struct mscclppIbQp** ibQp, int port /*=-1*/) -{ - if (port < 0) { - port = ctx->ports[0]; - } else { - bool found = false; - for (int i = 0; i < ctx->nPorts; ++i) { - if (ctx->ports[i] == port) { - found = true; - break; - } - } - if (!found) { - WARN("invalid IB port: %d", port); - return mscclppInternalError; - } - } - - struct ibv_cq* cq = ibv_create_cq(ctx->ctx, MSCCLPP_IB_CQ_SIZE, NULL, NULL, 0); - if (cq == NULL) { - WARN("ibv_create_cq failed (errno %d)", errno); - return mscclppInternalError; - } - - struct ibv_qp_init_attr qp_init_attr; - std::memset(&qp_init_attr, 0, sizeof(struct ibv_qp_init_attr)); - qp_init_attr.sq_sig_all = 0; - qp_init_attr.send_cq = cq; - qp_init_attr.recv_cq = cq; - qp_init_attr.qp_type = IBV_QPT_RC; - qp_init_attr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; - qp_init_attr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; - qp_init_attr.cap.max_send_sge = 1; - qp_init_attr.cap.max_recv_sge = 1; - qp_init_attr.cap.max_inline_data = 0; - struct ibv_qp* qp = ibv_create_qp(ctx->pd, &qp_init_attr); - if (qp == nullptr) { - WARN("ibv_create_qp failed (errno %d)", errno); - return mscclppInternalError; - } - struct ibv_port_attr port_attr; - if (ibv_query_port(ctx->ctx, port, &port_attr) != 0) { - WARN("ibv_query_port failed (errno %d, port %d)", errno, port); - return mscclppInternalError; - } - - // Register QP to this ctx - qp->context = ctx->ctx; - if (qp->context == NULL) { - WARN("IB context is NULL"); - return mscclppInternalError; - } - ctx->nQps++; - if (ctx->qps == NULL) { - MSCCLPPCHECK(mscclppCalloc(&ctx->qps, MAXCONNECTIONS)); - ctx->maxQps = MAXCONNECTIONS; - } - if (ctx->maxQps < ctx->nQps) { - WARN("too many QPs"); - return mscclppInternalError; - } - struct mscclppIbQp* _ibQp = &ctx->qps[ctx->nQps - 1]; - _ibQp->qp = qp; - _ibQp->info.lid = port_attr.lid; - _ibQp->info.port = port; - _ibQp->info.linkLayer = port_attr.link_layer; - _ibQp->info.qpn = qp->qp_num; - _ibQp->info.mtu = port_attr.active_mtu; - if (port_attr.link_layer != IBV_LINK_LAYER_INFINIBAND) { - union ibv_gid gid; - if (ibv_query_gid(ctx->ctx, port, 0, &gid) != 0) { - WARN("ibv_query_gid failed (errno %d)", errno); - return mscclppInternalError; - } - _ibQp->info.spn = gid.global.subnet_prefix; - } - - struct ibv_qp_attr qp_attr; - std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr)); - qp_attr.qp_state = IBV_QPS_INIT; - qp_attr.pkey_index = 0; - qp_attr.port_num = port; - qp_attr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ; - if (ibv_modify_qp(qp, &qp_attr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { - WARN("ibv_modify_qp failed (errno %d)", errno); - return mscclppInternalError; - } - - MSCCLPPCHECK(mscclppCalloc(&_ibQp->wrs, MSCCLPP_IB_MAX_SENDS)); - MSCCLPPCHECK(mscclppCalloc(&_ibQp->sges, MSCCLPP_IB_MAX_SENDS)); - MSCCLPPCHECK(mscclppCalloc(&_ibQp->wcs, MSCCLPP_IB_CQ_POLL_NUM)); - _ibQp->cq = cq; - - *ibQp = _ibQp; - - return mscclppSuccess; -} - -mscclppResult_t mscclppIbContextRegisterMr(struct mscclppIbContext* ctx, void* buff, size_t size, - struct mscclppIbMr** ibMr) +IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff) { if (size == 0) { - WARN("invalid size: %zu", size); - return mscclppInvalidArgument; + throw std::runtime_error("invalid size: " + std::to_string(size)); } static __thread uintptr_t pageSize = 0; if (pageSize == 0) { pageSize = sysconf(_SC_PAGESIZE); } uintptr_t addr = reinterpret_cast(buff) & -pageSize; - size_t pages = (size + (reinterpret_cast(buff) - addr) + pageSize - 1) / pageSize; - struct ibv_mr* mr = - ibv_reg_mr(ctx->pd, reinterpret_cast(addr), pages * pageSize, - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_RELAXED_ORDERING); - if (mr == nullptr) { - WARN("ibv_reg_mr failed (errno %d)", errno); - return mscclppInternalError; + std::size_t pages = (size + (reinterpret_cast(buff) - addr) + pageSize - 1) / pageSize; + struct ibv_pd* _pd = reinterpret_cast(pd); + struct ibv_mr* _mr = ibv_reg_mr(_pd, reinterpret_cast(addr), pages * pageSize, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_RELAXED_ORDERING); + if (_mr == nullptr) { + std::stringstream err; + err << "ibv_reg_mr failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); } - ctx->nMrs++; - if (ctx->mrs == NULL) { - MSCCLPPCHECK(mscclppCalloc(&ctx->mrs, MAXCONNECTIONS)); - ctx->maxMrs = MAXCONNECTIONS; - } - if (ctx->maxMrs < ctx->nMrs) { - WARN("too many MRs"); - return mscclppInternalError; - } - struct mscclppIbMr* _ibMr = &ctx->mrs[ctx->nMrs - 1]; - _ibMr->mr = mr; - _ibMr->buff = buff; - _ibMr->info.addr = (uint64_t)buff; - _ibMr->info.rkey = mr->rkey; - *ibMr = _ibMr; - return mscclppSuccess; + this->mr = _mr; + this->size = pages * pageSize; } -////////////////////////////////////////////////////////////////////////////// - -int mscclppIbQp::rtr(const mscclppIbQpInfo* info) +IbMr::~IbMr() { - struct ibv_qp_attr qp_attr; - std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr)); - qp_attr.qp_state = IBV_QPS_RTR; - qp_attr.path_mtu = info->mtu; - qp_attr.dest_qp_num = info->qpn; - qp_attr.rq_psn = 0; - qp_attr.max_dest_rd_atomic = 1; - qp_attr.min_rnr_timer = 0x12; - if (info->linkLayer == IBV_LINK_LAYER_ETHERNET) { - qp_attr.ah_attr.is_global = 1; - qp_attr.ah_attr.grh.dgid.global.subnet_prefix = info->spn; - qp_attr.ah_attr.grh.dgid.global.interface_id = info->lid; - qp_attr.ah_attr.grh.flow_label = 0; - qp_attr.ah_attr.grh.sgid_index = 0; - qp_attr.ah_attr.grh.hop_limit = 255; - qp_attr.ah_attr.grh.traffic_class = 0; - } else { - qp_attr.ah_attr.is_global = 0; - qp_attr.ah_attr.dlid = info->lid; - } - qp_attr.ah_attr.sl = 0; - qp_attr.ah_attr.src_path_bits = 0; - qp_attr.ah_attr.port_num = info->port; - return ibv_modify_qp(this->qp, &qp_attr, - IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | - IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER); + ibv_dereg_mr(reinterpret_cast(this->mr)); } -int mscclppIbQp::rts() +IbMrInfo IbMr::getInfo() const { - struct ibv_qp_attr qp_attr; - std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr)); - qp_attr.qp_state = IBV_QPS_RTS; - qp_attr.timeout = 18; - qp_attr.retry_cnt = 7; - qp_attr.rnr_retry = 7; - qp_attr.sq_psn = 0; - qp_attr.max_rd_atomic = 1; - return ibv_modify_qp(this->qp, &qp_attr, - IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | - IBV_QP_MAX_QP_RD_ATOMIC); + IbMrInfo info; + info.addr = reinterpret_cast(this->buff); + info.rkey = reinterpret_cast(this->mr)->rkey; + return info; } -int mscclppIbQp::stageSend(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId, - uint64_t srcOffset, uint64_t dstOffset, bool signaled) +const void* IbMr::getBuff() const { - if (this->wrn >= MSCCLPP_IB_MAX_SENDS) { - return -1; - } - int wrn = this->wrn; - struct ibv_send_wr* wr_ = &this->wrs[wrn]; - struct ibv_sge* sge_ = &this->sges[wrn]; - // std::memset(wr_, 0, sizeof(struct ibv_send_wr)); - // std::memset(sge_, 0, sizeof(struct ibv_sge)); - wr_->wr_id = wrId; - wr_->sg_list = sge_; - wr_->num_sge = 1; - wr_->opcode = IBV_WR_RDMA_WRITE; - wr_->send_flags = signaled ? IBV_SEND_SIGNALED : 0; - wr_->wr.rdma.remote_addr = (uint64_t)(info->addr) + dstOffset; - wr_->wr.rdma.rkey = info->rkey; - wr_->next = nullptr; - sge_->addr = (uint64_t)(ibMr->buff) + srcOffset; - sge_->length = size; - sge_->lkey = ibMr->mr->lkey; - if (wrn > 0) { - this->wrs[wrn - 1].next = wr_; - } - this->wrn++; - return this->wrn; + return this->buff; } -int mscclppIbQp::stageSendWithImm(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId, - uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData) +uint32_t IbMr::getLkey() const { - int wrn = this->stageSend(ibMr, info, size, wrId, srcOffset, dstOffset, signaled); - this->wrs[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - this->wrs[wrn - 1].imm_data = immData; - return wrn; + return reinterpret_cast(this->mr)->lkey; } -int mscclppIbQp::postSend() -{ - if (this->wrn == 0) { - return 0; - } - - struct ibv_send_wr* bad_wr; - int ret = ibv_post_send(this->qp, this->wrs, &bad_wr); - if (ret != 0) { - return ret; - } - this->wrn = 0; - return 0; -} - -int mscclppIbQp::postRecv(uint64_t wrId) -{ - struct ibv_recv_wr wr, *bad_wr; - wr.wr_id = wrId; - wr.sg_list = nullptr; - wr.num_sge = 0; - wr.next = nullptr; - return ibv_post_recv(this->qp, &wr, &bad_wr); -} - -int mscclppIbQp::pollCq() -{ - return ibv_poll_cq(this->cq, MSCCLPP_IB_CQ_POLL_NUM, this->wcs); -} - -namespace mscclpp { - IbQp::IbQp(void* ctx, void* pd, int port) { - struct ibv_context* _ctx = static_cast(ctx); - struct ibv_pd* _pd = static_cast(pd); + struct ibv_context* _ctx = reinterpret_cast(ctx); + struct ibv_pd* _pd = reinterpret_cast(pd); this->cq = ibv_create_cq(_ctx, MSCCLPP_IB_CQ_SIZE, nullptr, nullptr, 0); if (this->cq == nullptr) { @@ -379,8 +76,8 @@ IbQp::IbQp(void* ctx, void* pd, int port) struct ibv_qp_init_attr qpInitAttr; std::memset(&qpInitAttr, 0, sizeof(qpInitAttr)); qpInitAttr.sq_sig_all = 0; - qpInitAttr.send_cq = static_cast(this->cq); - qpInitAttr.recv_cq = static_cast(this->cq); + qpInitAttr.send_cq = reinterpret_cast(this->cq); + qpInitAttr.recv_cq = reinterpret_cast(this->cq); qpInitAttr.qp_type = IBV_QPT_RC; qpInitAttr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; qpInitAttr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; @@ -428,14 +125,160 @@ IbQp::IbQp(void* ctx, void* pd, int port) throw std::runtime_error(err.str()); } this->qp = _qp; + MSCCLPPTHROW(mscclppCalloc(reinterpret_cast(&this->wrs), MSCCLPP_IB_MAX_SENDS)); + MSCCLPPTHROW(mscclppCalloc(reinterpret_cast(&this->sges), MSCCLPP_IB_MAX_SENDS)); + MSCCLPPTHROW(mscclppCalloc(reinterpret_cast(&this->wcs), MSCCLPP_IB_CQ_POLL_NUM)); } -IbCtx::IbCtx(const std::string& ibDevName) +IbQp::~IbQp() +{ + ibv_destroy_qp(reinterpret_cast(this->qp)); + ibv_destroy_cq(reinterpret_cast(this->cq)); + std::free(this->wrs); + std::free(this->sges); + std::free(this->wcs); +} + +void IbQp::rtr(const IbQpInfo& info) +{ + struct ibv_qp_attr qp_attr; + std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr)); + qp_attr.qp_state = IBV_QPS_RTR; + qp_attr.path_mtu = static_cast(info.mtu); + qp_attr.dest_qp_num = info.qpn; + qp_attr.rq_psn = 0; + qp_attr.max_dest_rd_atomic = 1; + qp_attr.min_rnr_timer = 0x12; + if (info.linkLayer == IBV_LINK_LAYER_ETHERNET) { + qp_attr.ah_attr.is_global = 1; + qp_attr.ah_attr.grh.dgid.global.subnet_prefix = info.spn; + qp_attr.ah_attr.grh.dgid.global.interface_id = info.lid; + qp_attr.ah_attr.grh.flow_label = 0; + qp_attr.ah_attr.grh.sgid_index = 0; + qp_attr.ah_attr.grh.hop_limit = 255; + qp_attr.ah_attr.grh.traffic_class = 0; + } else { + qp_attr.ah_attr.is_global = 0; + qp_attr.ah_attr.dlid = info.lid; + } + qp_attr.ah_attr.sl = 0; + qp_attr.ah_attr.src_path_bits = 0; + qp_attr.ah_attr.port_num = info.port; + int ret = ibv_modify_qp(reinterpret_cast(this->qp), &qp_attr, IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER); + if (ret != 0) { + std::stringstream err; + err << "ibv_modify_qp failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } +} + +void IbQp::rts() +{ + struct ibv_qp_attr qp_attr; + std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr)); + qp_attr.qp_state = IBV_QPS_RTS; + qp_attr.timeout = 18; + qp_attr.retry_cnt = 7; + qp_attr.rnr_retry = 7; + qp_attr.sq_psn = 0; + qp_attr.max_rd_atomic = 1; + int ret = ibv_modify_qp(reinterpret_cast(this->qp), &qp_attr, IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC); + if (ret != 0) { + std::stringstream err; + err << "ibv_modify_qp failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } +} + +int IbQp::stageSend(const IbMr *mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled) +{ + if (this->wrn >= MSCCLPP_IB_MAX_SENDS) { + return -1; + } + int wrn = this->wrn; + struct ibv_send_wr* wrs_ = reinterpret_cast(this->wrs); + struct ibv_sge* sges_ = reinterpret_cast(this->sges); + + struct ibv_send_wr* wr_ = &wrs_[wrn]; + struct ibv_sge* sge_ = &sges_[wrn]; + wr_->wr_id = wrId; + wr_->sg_list = sge_; + wr_->num_sge = 1; + wr_->opcode = IBV_WR_RDMA_WRITE; + wr_->send_flags = signaled ? IBV_SEND_SIGNALED : 0; + wr_->wr.rdma.remote_addr = (uint64_t)(info.addr) + dstOffset; + wr_->wr.rdma.rkey = info.rkey; + wr_->next = nullptr; + sge_->addr = (uint64_t)(mr->getBuff()) + srcOffset; + sge_->length = size; + sge_->lkey = mr->getLkey(); + if (wrn > 0) { + wrs_[wrn - 1].next = wr_; + } + this->wrn++; + return this->wrn; +} + +int IbQp::stageSendWithImm(const IbMr *mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData) +{ + int wrn = this->stageSend(mr, info, size, wrId, srcOffset, dstOffset, signaled); + struct ibv_send_wr* wrs_ = reinterpret_cast(this->wrs); + wrs_[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wrs_[wrn - 1].imm_data = immData; + return wrn; +} + +void IbQp::postSend() +{ + if (this->wrn == 0) { + return; + } + struct ibv_send_wr* bad_wr; + int ret = ibv_post_send(reinterpret_cast(this->qp), reinterpret_cast(this->wrs), &bad_wr); + if (ret != 0) { + std::stringstream err; + err << "ibv_post_send failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } + this->wrn = 0; +} + +void IbQp::postRecv(uint64_t wrId) +{ + struct ibv_recv_wr wr, *bad_wr; + wr.wr_id = wrId; + wr.sg_list = nullptr; + wr.num_sge = 0; + wr.next = nullptr; + int ret = ibv_post_recv(reinterpret_cast(this->qp), &wr, &bad_wr); + if (ret != 0) { + std::stringstream err; + err << "ibv_post_recv failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } +} + +int IbQp::pollCq() +{ + return ibv_poll_cq(reinterpret_cast(this->cq), MSCCLPP_IB_CQ_POLL_NUM, reinterpret_cast(this->wcs)); +} + +const IbQpInfo& IbQp::getInfo() const +{ + return this->info; +} + +const void* IbQp::getWc(int idx) const +{ + return &reinterpret_cast(this->wcs)[idx]; +} + +IbCtx::IbCtx(const std::string& devName) : devName(devName) { int num; struct ibv_device** devices = ibv_get_device_list(&num); for (int i = 0; i < num; ++i) { - if (std::string(devices[i]->name) == ibDevName) { + if (std::string(devices[i]->name) == devName) { this->ctx = ibv_open_device(devices[i]); break; } @@ -443,10 +286,10 @@ IbCtx::IbCtx(const std::string& ibDevName) ibv_free_device_list(devices); if (this->ctx == nullptr) { std::stringstream err; - err << "ibv_open_device failed (errno " << errno << ", device name << " << ibDevName << ")"; + err << "ibv_open_device failed (errno " << errno << ", device name << " << devName << ")"; throw std::runtime_error(err.str()); } - this->pd = ibv_alloc_pd(static_cast(this->ctx)); + this->pd = ibv_alloc_pd(reinterpret_cast(this->ctx)); if (this->pd == nullptr) { std::stringstream err; err << "ibv_alloc_pd failed (errno " << errno << ")"; @@ -456,18 +299,20 @@ IbCtx::IbCtx(const std::string& ibDevName) IbCtx::~IbCtx() { + this->mrs.clear(); + this->qps.clear(); if (this->pd != nullptr) { - ibv_dealloc_pd(static_cast(this->pd)); + ibv_dealloc_pd(reinterpret_cast(this->pd)); } if (this->ctx != nullptr) { - ibv_close_device(static_cast(this->ctx)); + ibv_close_device(reinterpret_cast(this->ctx)); } } bool IbCtx::isPortUsable(int port) const { struct ibv_port_attr portAttr; - if (ibv_query_port(static_cast(this->ctx), port, &portAttr) != 0) { + if (ibv_query_port(reinterpret_cast(this->ctx), port, &portAttr) != 0) { std::stringstream err; err << "ibv_query_port failed (errno " << errno << ", port << " << port << ")"; throw std::runtime_error(err.str()); @@ -479,7 +324,7 @@ bool IbCtx::isPortUsable(int port) const int IbCtx::getAnyActivePort() const { struct ibv_device_attr devAttr; - if (ibv_query_device(static_cast(this->ctx), &devAttr) != 0) { + if (ibv_query_device(reinterpret_cast(this->ctx), &devAttr) != 0) { std::stringstream err; err << "ibv_query_device failed (errno " << errno << ")"; throw std::runtime_error(err.str()); @@ -506,4 +351,89 @@ IbQp* IbCtx::createQp(int port /*=-1*/) return qps.back().get(); } +const IbMr* IbCtx::registerMr(void* buff, std::size_t size) +{ + mrs.emplace_back(new IbMr(this->pd, buff, size)); + return mrs.back().get(); +} + +const std::string& IbCtx::getDevName() const +{ + return this->devName; +} + +int getIBDeviceCount() { + int num; + ibv_get_device_list(&num); + return num; +} + +std::string getIBDeviceName(TransportFlags ibTransport) { + int num; + struct ibv_device** devices = ibv_get_device_list(&num); + int ibTransportIndex; + switch (ibTransport) { // TODO: get rid of this ugly switch + case TransportIB0: + ibTransportIndex = 0; + break; + case TransportIB1: + ibTransportIndex = 1; + break; + case TransportIB2: + ibTransportIndex = 2; + break; + case TransportIB3: + ibTransportIndex = 3; + break; + case TransportIB4: + ibTransportIndex = 4; + break; + case TransportIB5: + ibTransportIndex = 5; + break; + case TransportIB6: + ibTransportIndex = 6; + break; + case TransportIB7: + ibTransportIndex = 7; + break; + default: + throw std::runtime_error("Not an IB transport"); + } + if (ibTransportIndex >= num) { + throw std::runtime_error("IB transport out of range"); + } + return devices[ibTransportIndex]->name; +} + +TransportFlags getIBTransportByDeviceName(const std::string& ibDeviceName) { + int num; + struct ibv_device** devices = ibv_get_device_list(&num); + for (int i = 0; i < num; ++i) { + if (ibDeviceName == devices[i]->name) { + switch (i) { // TODO: get rid of this ugly switch + case 0: + return TransportIB0; + case 1: + return TransportIB1; + case 2: + return TransportIB2; + case 3: + return TransportIB3; + case 4: + return TransportIB4; + case 5: + return TransportIB5; + case 6: + return TransportIB6; + case 7: + return TransportIB7; + default: + throw std::runtime_error("IB device index out of range"); + } + } + } + throw std::runtime_error("IB device not found"); +} + } // namespace mscclpp diff --git a/src/include/comm.h b/src/include/comm.h index 8275e0cb..dce724fa 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -7,9 +7,10 @@ #ifndef MSCCLPP_COMM_H_ #define MSCCLPP_COMM_H_ -#include "ib.h" +#include "ib.hpp" #include "proxy.h" #include +#include #define MAXCONNECTIONS 64 @@ -31,7 +32,7 @@ struct mscclppConn std::vector bufferRegistrations; std::vector remoteBufferRegistrations; - struct mscclppIbContext* ibCtx; + mscclpp::IbCtx* ibCtx; #if defined(ENABLE_NPKIT) std::vector npkitUsedReqIds; std::vector npkitFreeReqIds; @@ -57,7 +58,7 @@ struct mscclppComm // Flag to ask MSCCLPP kernels to abort volatile uint32_t* abortFlag; - struct mscclppIbContext* ibContext[MSCCLPP_IB_MAX_DEVS]; + std::unique_ptr ibContext[MSCCLPP_IB_MAX_DEVS]; struct mscclppProxyState* proxyState[MSCCLPP_PROXY_MAX_NUM]; }; diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index 879501c0..37abb31b 100644 --- a/src/include/communicator.hpp +++ b/src/include/communicator.hpp @@ -5,7 +5,7 @@ #include "mscclpp.h" #include "channel.hpp" #include "proxy.hpp" -#include "ib.h" +#include "ib.hpp" #include namespace mscclpp { @@ -15,13 +15,13 @@ class ConnectionBase; struct Communicator::Impl { mscclppComm_t comm; std::vector> connections; - std::unordered_map ibContexts; + std::unordered_map ibContexts; Impl(); ~Impl(); - mscclppIbContext* getIbContext(TransportFlags ibTransport); + IbCtx* getIbContext(TransportFlags ibTransport); }; } // namespace mscclpp diff --git a/src/include/connection.hpp b/src/include/connection.hpp index ac1dd6a1..dcf21362 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -3,7 +3,7 @@ #include "mscclpp.hpp" #include -#include "ib.h" +#include "ib.hpp" #include "communicator.hpp" namespace mscclpp { @@ -38,7 +38,7 @@ class IBConnection : public ConnectionBase { int tag; TransportFlags transport_; TransportFlags remoteTransport_; - mscclppIbQp* qp; + IbQp* qp; public: IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl); diff --git a/src/include/ib.h b/src/include/ib.h deleted file mode 100644 index 7494ab11..00000000 --- a/src/include/ib.h +++ /dev/null @@ -1,69 +0,0 @@ -#ifndef MSCCLPP_IB_H_ -#define MSCCLPP_IB_H_ - -#include "mscclpp.h" -#include -#include -#include -#include - -#define MSCCLPP_IB_CQ_SIZE 1024 -#define MSCCLPP_IB_CQ_POLL_NUM 4 -#define MSCCLPP_IB_MAX_SENDS 64 -#define MSCCLPP_IB_MAX_DEVS 8 - -// QP info to be shared with the remote peer -struct mscclppIbQpInfo -{ - uint16_t lid; - uint8_t port; - uint8_t linkLayer; - uint32_t qpn; - uint64_t spn; - ibv_mtu mtu; -}; - -// IB queue pair -struct mscclppIbQp -{ - struct ibv_qp* qp; - struct mscclppIbQpInfo info; - struct ibv_send_wr* wrs; - struct ibv_sge* sges; - struct ibv_cq* cq; - struct ibv_wc* wcs; - int wrn; - - int rtr(const mscclppIbQpInfo* info); - int rts(); - int stageSend(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId, uint64_t srcOffset, - uint64_t dstOffset, bool signaled); - int stageSendWithImm(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId, - uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData); - int postSend(); - int postRecv(uint64_t wrId); - int pollCq(); -}; - -// Holds resources of a single IB device. -struct mscclppIbContext -{ - struct ibv_context* ctx; - struct ibv_pd* pd; - int* ports; - int nPorts; - struct mscclppIbQp* qps; - int nQps; - int maxQps; - struct mscclppIbMr* mrs; - int nMrs; - int maxMrs; -}; - -mscclppResult_t mscclppIbContextCreate(struct mscclppIbContext** ctx, const char* ibDevName); -mscclppResult_t mscclppIbContextDestroy(struct mscclppIbContext* ctx); -mscclppResult_t mscclppIbContextCreateQp(struct mscclppIbContext* ctx, struct mscclppIbQp** ibQp, int port = -1); -mscclppResult_t mscclppIbContextRegisterMr(struct mscclppIbContext* ctx, void* buff, size_t size, - struct mscclppIbMr** ibMr); - -#endif diff --git a/src/include/ib.hpp b/src/include/ib.hpp index 85c92af7..d04b75bd 100644 --- a/src/include/ib.hpp +++ b/src/include/ib.hpp @@ -5,8 +5,38 @@ #include #include +#define MSCCLPP_IB_CQ_SIZE 1024 +#define MSCCLPP_IB_CQ_POLL_NUM 1 +#define MSCCLPP_IB_MAX_SENDS 64 +#define MSCCLPP_IB_MAX_DEVS 8 + namespace mscclpp { +struct IbMrInfo +{ + uint64_t addr; + uint32_t rkey; +}; + +class IbMr +{ +public: + ~IbMr(); + + IbMrInfo getInfo() const; + const void* getBuff() const; + uint32_t getLkey() const; + +private: + IbMr(void* pd, void* buff, std::size_t size); + + void* mr; + void* buff; + std::size_t size; + + friend class IbCtx; +}; + // QP info to be shared with the remote peer struct IbQpInfo { @@ -15,7 +45,7 @@ struct IbQpInfo uint8_t linkLayer; uint32_t qpn; uint64_t spn; - uint32_t mtu; + int mtu; }; class IbQp @@ -23,11 +53,22 @@ class IbQp public: ~IbQp(); - IbQpInfo info; + void rtr(const IbQpInfo& info); + void rts(); + int stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled); + int stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData); + void postSend(); + void postRecv(uint64_t wrId); + int pollCq(); + + const IbQpInfo& getInfo() const; + const void* getWc(int idx) const; private: IbQp(void* ctx, void* pd, int port); + IbQpInfo info; + void* qp; void* cq; void* wcs; @@ -38,22 +79,26 @@ private: friend class IbCtx; }; - class IbCtx { public: - IbCtx(const std::string& ibDevName); + IbCtx(const std::string& devName); ~IbCtx(); IbQp* createQp(int port = -1); + const IbMr* registerMr(void* buff, std::size_t size); + + const std::string& getDevName() const; private: bool isPortUsable(int port) const; int getAnyActivePort() const; + const std::string devName; void* ctx; void* pd; std::list> qps; + std::list> mrs; }; } // namespace mscclpp diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h index 6f96af10..c01246ab 100644 --- a/src/include/mscclpp.h +++ b/src/include/mscclpp.h @@ -207,25 +207,10 @@ typedef struct char internal[MSCCLPP_UNIQUE_ID_BYTES]; } mscclppUniqueId; -// MR info to be shared with the remote peer -struct mscclppIbMrInfo -{ - uint64_t addr; - uint32_t rkey; -}; - -// IB memory region -struct mscclppIbMr -{ - struct ibv_mr* mr; - void* buff; - struct mscclppIbMrInfo info; -}; - struct mscclppRegisteredMemoryP2P { void* remoteBuff; - mscclppIbMr* IbMr; + const void* IbMr; }; struct mscclppRegisteredMemory diff --git a/src/include/proxy.h b/src/include/proxy.h index 3da0196c..3746806b 100644 --- a/src/include/proxy.h +++ b/src/include/proxy.h @@ -59,7 +59,7 @@ struct mscclppProxyState mscclppProxyRunState_t run; int numaNodeToBind; - struct mscclppIbContext* ibContext; // For IB connection only + mscclpp::IbCtx* ibContext; // For IB connection only cudaStream_t p2pStream; // for P2P DMA engine only struct mscclppProxyFifo fifo; diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp index 7a0ab1d0..d2270d46 100644 --- a/src/include/registered_memory.hpp +++ b/src/include/registered_memory.hpp @@ -3,7 +3,7 @@ #include "mscclpp.hpp" #include "mscclpp.h" -#include "ib.h" +#include "ib.hpp" #include "communicator.hpp" #include @@ -16,8 +16,8 @@ struct TransportInfo { bool ibLocal; union { cudaIpcMemHandle_t cudaIpcHandle; - mscclppIbMr* ibMr; - mscclppIbMrInfo ibMrInfo; + const IbMr* ibMr; + IbMrInfo ibMrInfo; }; }; diff --git a/src/init.cc b/src/init.cc index 7cf159c8..c5b6a66b 100644 --- a/src/init.cc +++ b/src/init.cc @@ -7,6 +7,7 @@ #include "gdr.h" #endif #include "mscclpp.h" +#include "infiniband/verbs.h" #include #include #include @@ -191,7 +192,7 @@ MSCCLPP_API mscclppResult_t mscclppCommDestroy(mscclppComm_t comm) for (int i = 0; i < MSCCLPP_IB_MAX_DEVS; ++i) { if (comm->ibContext[i]) { - MSCCLPPCHECK(mscclppIbContextDestroy(comm->ibContext[i])); + comm->ibContext[i].reset(nullptr); } } @@ -366,24 +367,17 @@ struct mscclppHostIBConn : mscclppHostConn } void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, uint64_t dataSize) { - this->ibQp->stageSend(this->ibMrs[src], &this->remoteIbMrInfos[dst], (uint32_t)dataSize, + this->ibQp->stageSend(this->ibMrs[src], this->remoteIbMrInfos[dst], (uint32_t)dataSize, /*wrId=*/0, /*srcOffset=*/srcDataOffset, /*dstOffset=*/dstDataOffset, /*signaled=*/false); - int ret = this->ibQp->postSend(); - if (ret != 0) { - // Return value is errno. - WARN("data postSend failed: errno %d", ret); - } + this->ibQp->postSend(); npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)dataSize); } void signal() { // My local device flag is copied to the remote's proxy flag - this->ibQp->stageSend(this->ibMrs[0], &this->remoteIbMrInfos[0], sizeof(uint64_t), + this->ibQp->stageSend(this->ibMrs[0], this->remoteIbMrInfos[0], sizeof(uint64_t), /*wrId=*/0, /*srcOffset=*/0, /*dstOffset=*/sizeof(uint64_t), /*signaled=*/true); - int ret = this->ibQp->postSend(); - if (ret != 0) { - WARN("flag postSend failed: errno %d", ret); - } + this->ibQp->postSend(); npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_FLAG_ENTRY, (uint32_t)sizeof(uint64_t)); } void wait() @@ -399,15 +393,11 @@ struct mscclppHostIBConn : mscclppHostConn continue; } for (int i = 0; i < wcNum; ++i) { - struct ibv_wc* wc = &this->ibQp->wcs[i]; + struct ibv_wc* wc = (struct ibv_wc*)this->ibQp->getWc(i); if (wc->status != IBV_WC_SUCCESS) { WARN("wc status %d", wc->status); continue; } - if (wc->qp_num != this->ibQp->qp->qp_num) { - WARN("got wc of unknown qp_num %d", wc->qp_num); - continue; - } if (wc->opcode == IBV_WC_RDMA_WRITE) { isWaiting = false; break; @@ -418,9 +408,9 @@ struct mscclppHostIBConn : mscclppHostConn } mscclppConn* conn; - struct mscclppIbQp* ibQp; - std::vector ibMrs; - std::vector remoteIbMrInfos; + mscclpp::IbQp* ibQp; + std::vector ibMrs; + std::vector remoteIbMrInfos; }; MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, int tag, mscclppTransport_t transportType, const char* ibDev) @@ -458,7 +448,7 @@ MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int if (firstNullIdx == -1) { firstNullIdx = i; } - } else if (strncmp(comm->ibContext[i]->ctx->device->name, ibDev, IBV_SYSFS_NAME_MAX) == 0) { + } else if (strncmp(comm->ibContext[i]->getDevName().c_str(), ibDev, IBV_SYSFS_NAME_MAX) == 0) { ibDevIdx = i; break; } @@ -468,13 +458,10 @@ MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int if (ibDevIdx == -1) { // Create a new context. ibDevIdx = firstNullIdx; - if (mscclppIbContextCreate(&comm->ibContext[ibDevIdx], ibDev) != mscclppSuccess) { - WARN("Failed to create IB context"); - return mscclppInternalError; - } + comm->ibContext[ibDevIdx].reset(new mscclpp::IbCtx(std::string(ibDev))); } // Set the ib context for this conn - conn->ibCtx = comm->ibContext[ibDevIdx]; + conn->ibCtx = comm->ibContext[ibDevIdx].get(); } else if (transportType == mscclppTransportP2P) { // do the rest of the initialization later @@ -609,17 +596,17 @@ MSCCLPP_API mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t com struct mscclppBufferRegistrationInfo { cudaIpcMemHandle_t cudaHandle; - mscclppIbMrInfo ibMrInfo; + mscclpp::IbMrInfo ibMrInfo; uint64_t size; }; struct connInfo { - mscclppIbQpInfo infoQp; + mscclpp::IbQpInfo infoQp; std::vector bufferInfos; struct header { - mscclppIbQpInfo infoQp; + mscclpp::IbQpInfo infoQp; int numBufferInfos; }; @@ -702,22 +689,20 @@ mscclppResult_t mscclppIbConnectionSetupStart(struct connInfo* connInfo /*output devConn->remoteBuff = NULL; devConn->remoteSignalEpochId = NULL; - struct mscclppIbContext* ibCtx = conn->ibCtx; + mscclpp::IbCtx* ibCtx = conn->ibCtx; if (hostConn->ibQp == NULL) { - MSCCLPPCHECK(mscclppIbContextCreateQp(ibCtx, &hostConn->ibQp)); + hostConn->ibQp = ibCtx->createQp(); } // Add all registered buffers for (const auto &bufReg : conn->bufferRegistrations) { - hostConn->ibMrs.emplace_back(); - MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, bufReg.data, - sizeof(struct mscclppDevConnSignalEpochId), &hostConn->ibMrs.back())); + hostConn->ibMrs.emplace_back(ibCtx->registerMr(bufReg.data, sizeof(struct mscclppDevConnSignalEpochId))); connInfo->bufferInfos.emplace_back(); - connInfo->bufferInfos.back().ibMrInfo = hostConn->ibMrs.back()->info; + connInfo->bufferInfos.back().ibMrInfo = hostConn->ibMrs.back()->getInfo(); connInfo->bufferInfos.back().size = bufReg.size; } - connInfo->infoQp = hostConn->ibQp->info; + connInfo->infoQp = hostConn->ibQp->getInfo(); return mscclppSuccess; } @@ -728,14 +713,8 @@ mscclppResult_t mscclppIbConnectionSetupEnd(struct connInfo* connInfo /*input*/, return mscclppInternalError; } struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn; - if (hostConn->ibQp->rtr(&connInfo->infoQp) != 0) { - WARN("Failed to transition QP to RTR"); - return mscclppInvalidUsage; - } - if (hostConn->ibQp->rts() != 0) { - WARN("Failed to transition QP to RTS"); - return mscclppInvalidUsage; - } + hostConn->ibQp->rtr(connInfo->infoQp); + hostConn->ibQp->rts(); // No remote pointers to set with IB, so we just set the Mrs @@ -788,25 +767,25 @@ MSCCLPP_API mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm) struct bufferInfo { cudaIpcMemHandle_t handleBuff; - mscclppIbMrInfo infoBuffMr; + mscclpp::IbMrInfo infoBuffMr; }; MSCCLPP_API mscclppResult_t mscclppRegisterBuffer(mscclppComm_t comm, void* local_memory, size_t size, mscclppRegisteredMemory* regMem) { - std::vector ibMrs; + std::vector ibMrs; for (int i = 0; i < comm->nConns; ++i) { struct mscclppConn* conn = &comm->conns[i]; struct bufferInfo bInfo; - struct mscclppIbMr* ibBuffMr; + const mscclpp::IbMr* ibBuffMr; // TODO: (conn->transport & mscclppTransportP2P) to support both P2P and IB if (conn->transport == mscclppTransportP2P) { CUDACHECK(cudaIpcGetMemHandle(&bInfo.handleBuff, local_memory)); } else if (conn->transport == mscclppTransportIB) { - MSCCLPPCHECK(mscclppIbContextRegisterMr(conn->ibCtx, local_memory, size, &ibBuffMr)); - bInfo.infoBuffMr = ibBuffMr->info; - ibMrs.push_back(ibBuffMr); + ibBuffMr = conn->ibCtx->registerMr(local_memory, size); + bInfo.infoBuffMr = ibBuffMr->getInfo(); + ibMrs.emplace_back(ibBuffMr); } MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &bInfo, sizeof(bInfo))); diff --git a/src/proxy.cc b/src/proxy.cc index 6cfd799b..c8bf4414 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -2,7 +2,7 @@ #include "checks.h" #include "comm.h" #include "debug.h" -#include "ib.h" +#include "ib.hpp" #include "socket.h" #include diff --git a/src/registered_memory.cc b/src/registered_memory.cc index d9476e4f..f0db85ce 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -18,8 +18,7 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t auto addIb = [&](TransportFlags ibTransport) { TransportInfo transportInfo; transportInfo.transport = ibTransport; - mscclppIbMr* mr; - MSCCLPPTHROW(mscclppIbContextRegisterMr(commImpl.getIbContext(ibTransport), data, size, &mr)); + const IbMr* mr = commImpl.getIbContext(ibTransport)->registerMr(data, size); transportInfo.ibMr = mr; transportInfo.ibLocal = true; this->transportInfos.push_back(transportInfo); @@ -103,7 +102,7 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) { it += sizeof(handle); transportInfo.cudaIpcHandle = handle; } else if (transportInfo.transport & TransportAllIB) { - mscclppIbMrInfo info; + IbMrInfo info; std::copy_n(it, sizeof(info), reinterpret_cast(&info)); it += sizeof(info); transportInfo.ibMrInfo = info; diff --git a/tests/unittests/ib_test.cc b/tests/unittests/ib_test.cc index 2c194eaf..6f84398f 100644 --- a/tests/unittests/ib_test.cc +++ b/tests/unittests/ib_test.cc @@ -1,8 +1,10 @@ #include "alloc.h" #include "checks.h" -#include "ib.h" -#include +#include "ib.hpp" +#include "infiniband/verbs.h" +#include "mscclpp.hpp" #include +#include // Measure current time in second. static double getTime(void) @@ -24,8 +26,8 @@ int main(int argc, const char* argv[]) printf("Usage: %s <0(recv)/1(send)> \n", argv[0]); return 1; } - const char* ip_port = argv[1]; - int is_send = atoi(argv[2]); + const char* ipPortPair = argv[1]; + int isSend = atoi(argv[2]); int cudaDevId = atoi(argv[3]); std::string ibDevName = "mlx5_ib" + std::string(argv[4]); @@ -35,51 +37,40 @@ int main(int argc, const char* argv[]) int nelem = 1; MSCCLPPCHECK(mscclppCudaCalloc(&data, nelem)); - mscclppComm_t comm; - MSCCLPPCHECK(mscclppCommInitRank(&comm, 2, ip_port, is_send)); + std::shared_ptr bootstrap(new mscclpp::Bootstrap(isSend, 2)); + bootstrap->initialize(ipPortPair); - struct mscclppIbContext* ctx; - struct mscclppIbQp* qp; - struct mscclppIbMr* mr; - MSCCLPPCHECK(mscclppIbContextCreate(&ctx, ibDevName.c_str())); - MSCCLPPCHECK(mscclppIbContextCreateQp(ctx, &qp)); - MSCCLPPCHECK(mscclppIbContextRegisterMr(ctx, data, sizeof(int) * nelem, &mr)); + mscclpp::IbCtx ctx(ibDevName); + mscclpp::IbQp* qp = ctx.createQp(); + const mscclpp::IbMr* mr = ctx.registerMr(data, sizeof(int) * nelem); - struct mscclppIbQpInfo* qpInfo; - MSCCLPPCHECK(mscclppCalloc(&qpInfo, 2)); - qpInfo[is_send] = qp->info; + std::array qpInfo; + qpInfo[isSend] = qp->getInfo(); - struct mscclppIbMrInfo* mrInfo; - MSCCLPPCHECK(mscclppCalloc(&mrInfo, 2)); - mrInfo[is_send] = mr->info; + std::array mrInfo; + mrInfo[isSend] = mr->getInfo(); - MSCCLPPCHECK(mscclppBootstrapAllGather(comm, qpInfo, sizeof(struct mscclppIbQpInfo))); - MSCCLPPCHECK(mscclppBootstrapAllGather(comm, mrInfo, sizeof(struct mscclppIbMrInfo))); + bootstrap->allGather(qpInfo.data(), sizeof(mscclpp::IbQpInfo)); + bootstrap->allGather(mrInfo.data(), sizeof(mscclpp::IbMrInfo)); - for (int i = 0; i < 2; ++i) { - if (i == is_send) + for (int i = 0; i < bootstrap->getNranks(); ++i) { + if (i == isSend) continue; - qp->rtr(&qpInfo[i]); + qp->rtr(qpInfo[i]); qp->rts(); break; } printf("connection succeed\n"); - // A simple barrier - int* tmp; - MSCCLPPCHECK(mscclppCalloc(&tmp, 2)); - MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int))); + bootstrap->barrier(); - if (is_send) { + if (isSend) { int maxIter = 100000; double start = getTime(); for (int iter = 0; iter < maxIter; ++iter) { - qp->stageSend(mr, &mrInfo[0], sizeof(int) * nelem, 0, 0, 0, true); - if (qp->postSend() != 0) { - WARN("postSend failed"); - return 1; - } + qp->stageSend(mr, mrInfo[0], sizeof(int) * nelem, 0, 0, 0, true); + qp->postSend(); bool waiting = true; while (waiting) { int wcNum = qp->pollCq(); @@ -88,7 +79,7 @@ int main(int argc, const char* argv[]) return 1; } for (int i = 0; i < wcNum; ++i) { - struct ibv_wc* wc = &qp->wcs[i]; + const struct ibv_wc* wc = reinterpret_cast(qp->getWc(i)); if (wc->status != IBV_WC_SUCCESS) { WARN("wc status %d", wc->status); return 1; @@ -103,10 +94,7 @@ int main(int argc, const char* argv[]) } // A simple barrier - MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int))); - - MSCCLPPCHECK(mscclppIbContextDestroy(ctx)); - MSCCLPPCHECK(mscclppCommDestroy(comm)); + bootstrap->barrier(); return 0; }