diff --git a/src/ib.cc b/src/ib.cc index 7eed6b5e..ccccbfc3 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -11,7 +11,6 @@ #include #include -#include "alloc.h" #include "api.h" #include "checks.hpp" #include "debug.h" @@ -20,7 +19,7 @@ namespace mscclpp { -IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff) { +IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) { if (size == 0) { throw std::invalid_argument("invalid size: " + std::to_string(size)); } @@ -30,37 +29,32 @@ IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff) { } uintptr_t addr = reinterpret_cast(buff) & -pageSize; std::size_t pages = (size + (reinterpret_cast(buff) - addr) + pageSize - 1) / pageSize; - struct ibv_pd* _pd = reinterpret_cast(pd); - struct ibv_mr* _mr = ibv_reg_mr( - _pd, reinterpret_cast(addr), pages * pageSize, + this->mr = ibv_reg_mr( + pd, reinterpret_cast(addr), pages * pageSize, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_RELAXED_ORDERING); - if (_mr == nullptr) { + if (this->mr == nullptr) { std::stringstream err; err << "ibv_reg_mr failed (errno " << errno << ")"; throw mscclpp::IbError(err.str(), errno); } - this->mr = _mr; this->size = pages * pageSize; } -IbMr::~IbMr() { ibv_dereg_mr(reinterpret_cast(this->mr)); } +IbMr::~IbMr() { ibv_dereg_mr(this->mr); } IbMrInfo IbMr::getInfo() const { IbMrInfo info; info.addr = reinterpret_cast(this->buff); - info.rkey = reinterpret_cast(this->mr)->rkey; + info.rkey = this->mr->rkey; return info; } const void* IbMr::getBuff() const { return this->buff; } -uint32_t IbMr::getLkey() const { return reinterpret_cast(this->mr)->lkey; } +uint32_t IbMr::getLkey() const { return this->mr->lkey; } -IbQp::IbQp(void* ctx, void* pd, int port) { - struct ibv_context* _ctx = reinterpret_cast(ctx); - struct ibv_pd* _pd = reinterpret_cast(pd); - - this->cq = ibv_create_cq(_ctx, MSCCLPP_IB_CQ_SIZE, nullptr, nullptr, 0); +IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port) { + this->cq = ibv_create_cq(ctx, MSCCLPP_IB_CQ_SIZE, nullptr, nullptr, 0); if (this->cq == nullptr) { std::stringstream err; err << "ibv_create_cq failed (errno " << errno << ")"; @@ -70,8 +64,8 @@ IbQp::IbQp(void* ctx, void* pd, int port) { struct ibv_qp_init_attr qpInitAttr; std::memset(&qpInitAttr, 0, sizeof(qpInitAttr)); qpInitAttr.sq_sig_all = 0; - qpInitAttr.send_cq = reinterpret_cast(this->cq); - qpInitAttr.recv_cq = reinterpret_cast(this->cq); + qpInitAttr.send_cq = this->cq; + qpInitAttr.recv_cq = this->cq; qpInitAttr.qp_type = IBV_QPT_RC; qpInitAttr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; qpInitAttr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; @@ -79,7 +73,7 @@ IbQp::IbQp(void* ctx, void* pd, int port) { qpInitAttr.cap.max_recv_sge = 1; qpInitAttr.cap.max_inline_data = 0; - struct ibv_qp* _qp = ibv_create_qp(_pd, &qpInitAttr); + struct ibv_qp* _qp = ibv_create_qp(pd, &qpInitAttr); if (_qp == nullptr) { std::stringstream err; err << "ibv_create_qp failed (errno " << errno << ")"; @@ -87,7 +81,7 @@ IbQp::IbQp(void* ctx, void* pd, int port) { } struct ibv_port_attr portAttr; - if (ibv_query_port(_ctx, port, &portAttr) != 0) { + if (ibv_query_port(ctx, port, &portAttr) != 0) { std::stringstream err; err << "ibv_query_port failed (errno " << errno << ")"; throw mscclpp::IbError(err.str(), errno); @@ -101,7 +95,7 @@ IbQp::IbQp(void* ctx, void* pd, int port) { if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND || this->info.is_grh) { union ibv_gid gid; - if (ibv_query_gid(_ctx, port, 0, &gid) != 0) { + if (ibv_query_gid(ctx, port, 0, &gid) != 0) { std::stringstream err; err << "ibv_query_gid failed (errno " << errno << ")"; throw mscclpp::IbError(err.str(), errno); @@ -123,17 +117,14 @@ IbQp::IbQp(void* ctx, void* pd, int port) { } this->qp = _qp; this->wrn = 0; - MSCCLPPTHROW(mscclppCalloc(reinterpret_cast(&this->wrs), MSCCLPP_IB_MAX_SENDS)); - MSCCLPPTHROW(mscclppCalloc(reinterpret_cast(&this->sges), MSCCLPP_IB_MAX_SENDS)); - MSCCLPPTHROW(mscclppCalloc(reinterpret_cast(&this->wcs), MSCCLPP_IB_CQ_POLL_NUM)); + this->wrs = std::make_unique(MSCCLPP_IB_MAX_SENDS); + this->sges = std::make_unique(MSCCLPP_IB_MAX_SENDS); + this->wcs = std::make_unique(MSCCLPP_IB_CQ_POLL_NUM); } IbQp::~IbQp() { - ibv_destroy_qp(reinterpret_cast(this->qp)); - ibv_destroy_cq(reinterpret_cast(this->cq)); - std::free(this->wrs); - std::free(this->sges); - std::free(this->wcs); + ibv_destroy_qp(this->qp); + ibv_destroy_cq(this->cq); } void IbQp::rtr(const IbQpInfo& info) { @@ -160,7 +151,7 @@ void IbQp::rtr(const IbQpInfo& info) { qp_attr.ah_attr.sl = 0; qp_attr.ah_attr.src_path_bits = 0; qp_attr.ah_attr.port_num = info.port; - int ret = ibv_modify_qp(reinterpret_cast(this->qp), &qp_attr, + int ret = ibv_modify_qp(this->qp, &qp_attr, IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER); if (ret != 0) { @@ -180,7 +171,7 @@ void IbQp::rts() { qp_attr.sq_psn = 0; qp_attr.max_rd_atomic = 1; int ret = ibv_modify_qp( - reinterpret_cast(this->qp), &qp_attr, + this->qp, &qp_attr, IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC); if (ret != 0) { std::stringstream err; @@ -195,11 +186,9 @@ int IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_ return -1; } int wrn = this->wrn; - struct ibv_send_wr* wrs_ = reinterpret_cast(this->wrs); - struct ibv_sge* sges_ = reinterpret_cast(this->sges); - struct ibv_send_wr* wr_ = &wrs_[wrn]; - struct ibv_sge* sge_ = &sges_[wrn]; + struct ibv_send_wr* wr_ = &this->wrs[wrn]; + struct ibv_sge* sge_ = &this->sges[wrn]; wr_->wr_id = wrId; wr_->sg_list = sge_; wr_->num_sge = 1; @@ -212,7 +201,7 @@ int IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_ sge_->length = size; sge_->lkey = mr->getLkey(); if (wrn > 0) { - wrs_[wrn - 1].next = wr_; + this->wrs[wrn - 1].next = wr_; } this->wrn++; return this->wrn; @@ -221,9 +210,8 @@ int IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_ int IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData) { int wrn = this->stageSend(mr, info, size, wrId, srcOffset, dstOffset, signaled); - struct ibv_send_wr* wrs_ = reinterpret_cast(this->wrs); - wrs_[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - wrs_[wrn - 1].imm_data = immData; + this->wrs[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + this->wrs[wrn - 1].imm_data = immData; return wrn; } @@ -232,8 +220,7 @@ void IbQp::postSend() { return; } struct ibv_send_wr* bad_wr; - int ret = ibv_post_send(reinterpret_cast(this->qp), reinterpret_cast(this->wrs), - &bad_wr); + int ret = ibv_post_send(this->qp, this->wrs.get(), &bad_wr); if (ret != 0) { std::stringstream err; err << "ibv_post_send failed (errno " << errno << ")"; @@ -248,7 +235,7 @@ void IbQp::postRecv(uint64_t wrId) { wr.sg_list = nullptr; wr.num_sge = 0; wr.next = nullptr; - int ret = ibv_post_recv(reinterpret_cast(this->qp), &wr, &bad_wr); + int ret = ibv_post_recv(this->qp, &wr, &bad_wr); if (ret != 0) { std::stringstream err; err << "ibv_post_recv failed (errno " << errno << ")"; @@ -256,14 +243,11 @@ void IbQp::postRecv(uint64_t wrId) { } } -int IbQp::pollCq() { - return ibv_poll_cq(reinterpret_cast(this->cq), MSCCLPP_IB_CQ_POLL_NUM, - reinterpret_cast(this->wcs)); -} +int IbQp::pollCq() { return ibv_poll_cq(this->cq, MSCCLPP_IB_CQ_POLL_NUM, this->wcs.get()); } IbQpInfo& IbQp::getInfo() { return this->info; } -const void* IbQp::getWc(int idx) const { return &reinterpret_cast(this->wcs)[idx]; } +const void* IbQp::getWc(int idx) const { return &this->wcs[idx]; } IbCtx::IbCtx(const std::string& devName) : devName(devName) { int num; @@ -280,7 +264,7 @@ IbCtx::IbCtx(const std::string& devName) : devName(devName) { err << "ibv_open_device failed (errno " << errno << ", device name << " << devName << ")"; throw mscclpp::IbError(err.str(), errno); } - this->pd = ibv_alloc_pd(reinterpret_cast(this->ctx)); + this->pd = ibv_alloc_pd(this->ctx); if (this->pd == nullptr) { std::stringstream err; err << "ibv_alloc_pd failed (errno " << errno << ")"; @@ -292,16 +276,16 @@ IbCtx::~IbCtx() { this->mrs.clear(); this->qps.clear(); if (this->pd != nullptr) { - ibv_dealloc_pd(reinterpret_cast(this->pd)); + ibv_dealloc_pd(this->pd); } if (this->ctx != nullptr) { - ibv_close_device(reinterpret_cast(this->ctx)); + ibv_close_device(this->ctx); } } bool IbCtx::isPortUsable(int port) const { struct ibv_port_attr portAttr; - if (ibv_query_port(reinterpret_cast(this->ctx), port, &portAttr) != 0) { + if (ibv_query_port(this->ctx, port, &portAttr) != 0) { std::stringstream err; err << "ibv_query_port failed (errno " << errno << ", port << " << port << ")"; throw mscclpp::IbError(err.str(), errno); @@ -312,7 +296,7 @@ bool IbCtx::isPortUsable(int port) const { int IbCtx::getAnyActivePort() const { struct ibv_device_attr devAttr; - if (ibv_query_device(reinterpret_cast(this->ctx), &devAttr) != 0) { + if (ibv_query_device(this->ctx, &devAttr) != 0) { std::stringstream err; err << "ibv_query_device failed (errno " << errno << ")"; throw mscclpp::IbError(err.str(), errno); diff --git a/src/include/ib.hpp b/src/include/ib.hpp index 2fe9a447..6bf86218 100644 --- a/src/include/ib.hpp +++ b/src/include/ib.hpp @@ -10,6 +10,16 @@ #define MSCCLPP_IB_MAX_SENDS 64 #define MSCCLPP_IB_MAX_DEVS 8 +// Forward declarations of IB structures +struct ibv_context; +struct ibv_pd; +struct ibv_mr; +struct ibv_qp; +struct ibv_cq; +struct ibv_wc; +struct ibv_send_wr; +struct ibv_sge; + namespace mscclpp { struct IbMrInfo { @@ -26,9 +36,9 @@ class IbMr { uint32_t getLkey() const; private: - IbMr(void* pd, void* buff, std::size_t size); + IbMr(ibv_pd* pd, void* buff, std::size_t size); - void* mr; + ibv_mr* mr; void* buff; std::size_t size; @@ -65,15 +75,15 @@ class IbQp { const void* getWc(int idx) const; private: - IbQp(void* ctx, void* pd, int port); + IbQp(ibv_context* ctx, ibv_pd* pd, int port); IbQpInfo info; - void* qp; - void* cq; - void* wcs; - void* wrs; - void* sges; + ibv_qp* qp; + ibv_cq* cq; + std::unique_ptr wcs; + std::unique_ptr wrs; + std::unique_ptr sges; int wrn; friend class IbCtx; @@ -94,8 +104,8 @@ class IbCtx { int getAnyActivePort() const; const std::string devName; - void* ctx; - void* pd; + ibv_context* ctx; + ibv_pd* pd; std::list> qps; std::list> mrs; };