mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 01:10:22 +00:00
Remove free and most reinterpret_casts in IB code
This commit is contained in:
86
src/ib.cc
86
src/ib.cc
@@ -11,7 +11,6 @@
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "alloc.h"
|
||||
#include "api.h"
|
||||
#include "checks.hpp"
|
||||
#include "debug.h"
|
||||
@@ -20,7 +19,7 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff) {
|
||||
IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
|
||||
if (size == 0) {
|
||||
throw std::invalid_argument("invalid size: " + std::to_string(size));
|
||||
}
|
||||
@@ -30,37 +29,32 @@ IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff) {
|
||||
}
|
||||
uintptr_t addr = reinterpret_cast<uintptr_t>(buff) & -pageSize;
|
||||
std::size_t pages = (size + (reinterpret_cast<uintptr_t>(buff) - addr) + pageSize - 1) / pageSize;
|
||||
struct ibv_pd* _pd = reinterpret_cast<struct ibv_pd*>(pd);
|
||||
struct ibv_mr* _mr = ibv_reg_mr(
|
||||
_pd, reinterpret_cast<void*>(addr), pages * pageSize,
|
||||
this->mr = ibv_reg_mr(
|
||||
pd, reinterpret_cast<void*>(addr), pages * pageSize,
|
||||
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_RELAXED_ORDERING);
|
||||
if (_mr == nullptr) {
|
||||
if (this->mr == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_reg_mr failed (errno " << errno << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
}
|
||||
this->mr = _mr;
|
||||
this->size = pages * pageSize;
|
||||
}
|
||||
|
||||
IbMr::~IbMr() { ibv_dereg_mr(reinterpret_cast<struct ibv_mr*>(this->mr)); }
|
||||
IbMr::~IbMr() { ibv_dereg_mr(this->mr); }
|
||||
|
||||
IbMrInfo IbMr::getInfo() const {
|
||||
IbMrInfo info;
|
||||
info.addr = reinterpret_cast<uint64_t>(this->buff);
|
||||
info.rkey = reinterpret_cast<struct ibv_mr*>(this->mr)->rkey;
|
||||
info.rkey = this->mr->rkey;
|
||||
return info;
|
||||
}
|
||||
|
||||
const void* IbMr::getBuff() const { return this->buff; }
|
||||
|
||||
uint32_t IbMr::getLkey() const { return reinterpret_cast<struct ibv_mr*>(this->mr)->lkey; }
|
||||
uint32_t IbMr::getLkey() const { return this->mr->lkey; }
|
||||
|
||||
IbQp::IbQp(void* ctx, void* pd, int port) {
|
||||
struct ibv_context* _ctx = reinterpret_cast<struct ibv_context*>(ctx);
|
||||
struct ibv_pd* _pd = reinterpret_cast<struct ibv_pd*>(pd);
|
||||
|
||||
this->cq = ibv_create_cq(_ctx, MSCCLPP_IB_CQ_SIZE, nullptr, nullptr, 0);
|
||||
IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port) {
|
||||
this->cq = ibv_create_cq(ctx, MSCCLPP_IB_CQ_SIZE, nullptr, nullptr, 0);
|
||||
if (this->cq == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_create_cq failed (errno " << errno << ")";
|
||||
@@ -70,8 +64,8 @@ IbQp::IbQp(void* ctx, void* pd, int port) {
|
||||
struct ibv_qp_init_attr qpInitAttr;
|
||||
std::memset(&qpInitAttr, 0, sizeof(qpInitAttr));
|
||||
qpInitAttr.sq_sig_all = 0;
|
||||
qpInitAttr.send_cq = reinterpret_cast<struct ibv_cq*>(this->cq);
|
||||
qpInitAttr.recv_cq = reinterpret_cast<struct ibv_cq*>(this->cq);
|
||||
qpInitAttr.send_cq = this->cq;
|
||||
qpInitAttr.recv_cq = this->cq;
|
||||
qpInitAttr.qp_type = IBV_QPT_RC;
|
||||
qpInitAttr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
|
||||
qpInitAttr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
|
||||
@@ -79,7 +73,7 @@ IbQp::IbQp(void* ctx, void* pd, int port) {
|
||||
qpInitAttr.cap.max_recv_sge = 1;
|
||||
qpInitAttr.cap.max_inline_data = 0;
|
||||
|
||||
struct ibv_qp* _qp = ibv_create_qp(_pd, &qpInitAttr);
|
||||
struct ibv_qp* _qp = ibv_create_qp(pd, &qpInitAttr);
|
||||
if (_qp == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_create_qp failed (errno " << errno << ")";
|
||||
@@ -87,7 +81,7 @@ IbQp::IbQp(void* ctx, void* pd, int port) {
|
||||
}
|
||||
|
||||
struct ibv_port_attr portAttr;
|
||||
if (ibv_query_port(_ctx, port, &portAttr) != 0) {
|
||||
if (ibv_query_port(ctx, port, &portAttr) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_port failed (errno " << errno << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
@@ -101,7 +95,7 @@ IbQp::IbQp(void* ctx, void* pd, int port) {
|
||||
|
||||
if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND || this->info.is_grh) {
|
||||
union ibv_gid gid;
|
||||
if (ibv_query_gid(_ctx, port, 0, &gid) != 0) {
|
||||
if (ibv_query_gid(ctx, port, 0, &gid) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_gid failed (errno " << errno << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
@@ -123,17 +117,14 @@ IbQp::IbQp(void* ctx, void* pd, int port) {
|
||||
}
|
||||
this->qp = _qp;
|
||||
this->wrn = 0;
|
||||
MSCCLPPTHROW(mscclppCalloc(reinterpret_cast<struct ibv_send_wr**>(&this->wrs), MSCCLPP_IB_MAX_SENDS));
|
||||
MSCCLPPTHROW(mscclppCalloc(reinterpret_cast<struct ibv_sge**>(&this->sges), MSCCLPP_IB_MAX_SENDS));
|
||||
MSCCLPPTHROW(mscclppCalloc(reinterpret_cast<struct ibv_wc**>(&this->wcs), MSCCLPP_IB_CQ_POLL_NUM));
|
||||
this->wrs = std::make_unique<ibv_send_wr[]>(MSCCLPP_IB_MAX_SENDS);
|
||||
this->sges = std::make_unique<ibv_sge[]>(MSCCLPP_IB_MAX_SENDS);
|
||||
this->wcs = std::make_unique<ibv_wc[]>(MSCCLPP_IB_CQ_POLL_NUM);
|
||||
}
|
||||
|
||||
IbQp::~IbQp() {
|
||||
ibv_destroy_qp(reinterpret_cast<struct ibv_qp*>(this->qp));
|
||||
ibv_destroy_cq(reinterpret_cast<struct ibv_cq*>(this->cq));
|
||||
std::free(this->wrs);
|
||||
std::free(this->sges);
|
||||
std::free(this->wcs);
|
||||
ibv_destroy_qp(this->qp);
|
||||
ibv_destroy_cq(this->cq);
|
||||
}
|
||||
|
||||
void IbQp::rtr(const IbQpInfo& info) {
|
||||
@@ -160,7 +151,7 @@ void IbQp::rtr(const IbQpInfo& info) {
|
||||
qp_attr.ah_attr.sl = 0;
|
||||
qp_attr.ah_attr.src_path_bits = 0;
|
||||
qp_attr.ah_attr.port_num = info.port;
|
||||
int ret = ibv_modify_qp(reinterpret_cast<struct ibv_qp*>(this->qp), &qp_attr,
|
||||
int ret = ibv_modify_qp(this->qp, &qp_attr,
|
||||
IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
|
||||
IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER);
|
||||
if (ret != 0) {
|
||||
@@ -180,7 +171,7 @@ void IbQp::rts() {
|
||||
qp_attr.sq_psn = 0;
|
||||
qp_attr.max_rd_atomic = 1;
|
||||
int ret = ibv_modify_qp(
|
||||
reinterpret_cast<struct ibv_qp*>(this->qp), &qp_attr,
|
||||
this->qp, &qp_attr,
|
||||
IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC);
|
||||
if (ret != 0) {
|
||||
std::stringstream err;
|
||||
@@ -195,11 +186,9 @@ int IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_
|
||||
return -1;
|
||||
}
|
||||
int wrn = this->wrn;
|
||||
struct ibv_send_wr* wrs_ = reinterpret_cast<struct ibv_send_wr*>(this->wrs);
|
||||
struct ibv_sge* sges_ = reinterpret_cast<struct ibv_sge*>(this->sges);
|
||||
|
||||
struct ibv_send_wr* wr_ = &wrs_[wrn];
|
||||
struct ibv_sge* sge_ = &sges_[wrn];
|
||||
struct ibv_send_wr* wr_ = &this->wrs[wrn];
|
||||
struct ibv_sge* sge_ = &this->sges[wrn];
|
||||
wr_->wr_id = wrId;
|
||||
wr_->sg_list = sge_;
|
||||
wr_->num_sge = 1;
|
||||
@@ -212,7 +201,7 @@ int IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_
|
||||
sge_->length = size;
|
||||
sge_->lkey = mr->getLkey();
|
||||
if (wrn > 0) {
|
||||
wrs_[wrn - 1].next = wr_;
|
||||
this->wrs[wrn - 1].next = wr_;
|
||||
}
|
||||
this->wrn++;
|
||||
return this->wrn;
|
||||
@@ -221,9 +210,8 @@ int IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_
|
||||
int IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled, unsigned int immData) {
|
||||
int wrn = this->stageSend(mr, info, size, wrId, srcOffset, dstOffset, signaled);
|
||||
struct ibv_send_wr* wrs_ = reinterpret_cast<struct ibv_send_wr*>(this->wrs);
|
||||
wrs_[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
|
||||
wrs_[wrn - 1].imm_data = immData;
|
||||
this->wrs[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
|
||||
this->wrs[wrn - 1].imm_data = immData;
|
||||
return wrn;
|
||||
}
|
||||
|
||||
@@ -232,8 +220,7 @@ void IbQp::postSend() {
|
||||
return;
|
||||
}
|
||||
struct ibv_send_wr* bad_wr;
|
||||
int ret = ibv_post_send(reinterpret_cast<struct ibv_qp*>(this->qp), reinterpret_cast<struct ibv_send_wr*>(this->wrs),
|
||||
&bad_wr);
|
||||
int ret = ibv_post_send(this->qp, this->wrs.get(), &bad_wr);
|
||||
if (ret != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_post_send failed (errno " << errno << ")";
|
||||
@@ -248,7 +235,7 @@ void IbQp::postRecv(uint64_t wrId) {
|
||||
wr.sg_list = nullptr;
|
||||
wr.num_sge = 0;
|
||||
wr.next = nullptr;
|
||||
int ret = ibv_post_recv(reinterpret_cast<struct ibv_qp*>(this->qp), &wr, &bad_wr);
|
||||
int ret = ibv_post_recv(this->qp, &wr, &bad_wr);
|
||||
if (ret != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_post_recv failed (errno " << errno << ")";
|
||||
@@ -256,14 +243,11 @@ void IbQp::postRecv(uint64_t wrId) {
|
||||
}
|
||||
}
|
||||
|
||||
int IbQp::pollCq() {
|
||||
return ibv_poll_cq(reinterpret_cast<struct ibv_cq*>(this->cq), MSCCLPP_IB_CQ_POLL_NUM,
|
||||
reinterpret_cast<struct ibv_wc*>(this->wcs));
|
||||
}
|
||||
int IbQp::pollCq() { return ibv_poll_cq(this->cq, MSCCLPP_IB_CQ_POLL_NUM, this->wcs.get()); }
|
||||
|
||||
IbQpInfo& IbQp::getInfo() { return this->info; }
|
||||
|
||||
const void* IbQp::getWc(int idx) const { return &reinterpret_cast<struct ibv_wc*>(this->wcs)[idx]; }
|
||||
const void* IbQp::getWc(int idx) const { return &this->wcs[idx]; }
|
||||
|
||||
IbCtx::IbCtx(const std::string& devName) : devName(devName) {
|
||||
int num;
|
||||
@@ -280,7 +264,7 @@ IbCtx::IbCtx(const std::string& devName) : devName(devName) {
|
||||
err << "ibv_open_device failed (errno " << errno << ", device name << " << devName << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
}
|
||||
this->pd = ibv_alloc_pd(reinterpret_cast<struct ibv_context*>(this->ctx));
|
||||
this->pd = ibv_alloc_pd(this->ctx);
|
||||
if (this->pd == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_alloc_pd failed (errno " << errno << ")";
|
||||
@@ -292,16 +276,16 @@ IbCtx::~IbCtx() {
|
||||
this->mrs.clear();
|
||||
this->qps.clear();
|
||||
if (this->pd != nullptr) {
|
||||
ibv_dealloc_pd(reinterpret_cast<struct ibv_pd*>(this->pd));
|
||||
ibv_dealloc_pd(this->pd);
|
||||
}
|
||||
if (this->ctx != nullptr) {
|
||||
ibv_close_device(reinterpret_cast<struct ibv_context*>(this->ctx));
|
||||
ibv_close_device(this->ctx);
|
||||
}
|
||||
}
|
||||
|
||||
bool IbCtx::isPortUsable(int port) const {
|
||||
struct ibv_port_attr portAttr;
|
||||
if (ibv_query_port(reinterpret_cast<struct ibv_context*>(this->ctx), port, &portAttr) != 0) {
|
||||
if (ibv_query_port(this->ctx, port, &portAttr) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_port failed (errno " << errno << ", port << " << port << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
@@ -312,7 +296,7 @@ bool IbCtx::isPortUsable(int port) const {
|
||||
|
||||
int IbCtx::getAnyActivePort() const {
|
||||
struct ibv_device_attr devAttr;
|
||||
if (ibv_query_device(reinterpret_cast<struct ibv_context*>(this->ctx), &devAttr) != 0) {
|
||||
if (ibv_query_device(this->ctx, &devAttr) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_device failed (errno " << errno << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
|
||||
@@ -10,6 +10,16 @@
|
||||
#define MSCCLPP_IB_MAX_SENDS 64
|
||||
#define MSCCLPP_IB_MAX_DEVS 8
|
||||
|
||||
// Forward declarations of IB structures
|
||||
struct ibv_context;
|
||||
struct ibv_pd;
|
||||
struct ibv_mr;
|
||||
struct ibv_qp;
|
||||
struct ibv_cq;
|
||||
struct ibv_wc;
|
||||
struct ibv_send_wr;
|
||||
struct ibv_sge;
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct IbMrInfo {
|
||||
@@ -26,9 +36,9 @@ class IbMr {
|
||||
uint32_t getLkey() const;
|
||||
|
||||
private:
|
||||
IbMr(void* pd, void* buff, std::size_t size);
|
||||
IbMr(ibv_pd* pd, void* buff, std::size_t size);
|
||||
|
||||
void* mr;
|
||||
ibv_mr* mr;
|
||||
void* buff;
|
||||
std::size_t size;
|
||||
|
||||
@@ -65,15 +75,15 @@ class IbQp {
|
||||
const void* getWc(int idx) const;
|
||||
|
||||
private:
|
||||
IbQp(void* ctx, void* pd, int port);
|
||||
IbQp(ibv_context* ctx, ibv_pd* pd, int port);
|
||||
|
||||
IbQpInfo info;
|
||||
|
||||
void* qp;
|
||||
void* cq;
|
||||
void* wcs;
|
||||
void* wrs;
|
||||
void* sges;
|
||||
ibv_qp* qp;
|
||||
ibv_cq* cq;
|
||||
std::unique_ptr<ibv_wc[]> wcs;
|
||||
std::unique_ptr<ibv_send_wr[]> wrs;
|
||||
std::unique_ptr<ibv_sge[]> sges;
|
||||
int wrn;
|
||||
|
||||
friend class IbCtx;
|
||||
@@ -94,8 +104,8 @@ class IbCtx {
|
||||
int getAnyActivePort() const;
|
||||
|
||||
const std::string devName;
|
||||
void* ctx;
|
||||
void* pd;
|
||||
ibv_context* ctx;
|
||||
ibv_pd* pd;
|
||||
std::list<std::unique_ptr<IbQp>> qps;
|
||||
std::list<std::unique_ptr<IbMr>> mrs;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user