mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 01:10:22 +00:00
IB: completely replaced with C++ interfaces
This commit is contained in:
@@ -16,7 +16,7 @@ Communicator::Impl::Impl() : comm(nullptr) {}
|
||||
|
||||
Communicator::Impl::~Impl() {
|
||||
for (auto& entry : ibContexts) {
|
||||
mscclppIbContextDestroy(entry.second);
|
||||
delete entry.second;
|
||||
}
|
||||
ibContexts.clear();
|
||||
if (comm) {
|
||||
@@ -24,13 +24,12 @@ Communicator::Impl::~Impl() {
|
||||
}
|
||||
}
|
||||
|
||||
mscclppIbContext* Communicator::Impl::getIbContext(TransportFlags ibTransport) {
|
||||
IbCtx* Communicator::Impl::getIbContext(TransportFlags ibTransport) {
|
||||
// Find IB context or create it
|
||||
auto it = ibContexts.find(ibTransport);
|
||||
if (it == ibContexts.end()) {
|
||||
auto ibDev = getIBDeviceName(ibTransport);
|
||||
mscclppIbContext* ibCtx;
|
||||
MSCCLPPTHROW(mscclppIbContextCreate(&ibCtx, ibDev.c_str()));
|
||||
IbCtx* ibCtx = new IbCtx(ibDev);
|
||||
ibContexts[ibTransport] = ibCtx;
|
||||
return ibCtx;
|
||||
} else {
|
||||
@@ -92,6 +91,7 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connect(int remoteRank
|
||||
throw std::runtime_error("Unsupported transport");
|
||||
}
|
||||
pimpl->connections.push_back(conn);
|
||||
return conn;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::connectionSetup() {
|
||||
@@ -115,81 +115,4 @@ MSCCLPP_API_CPP int Communicator::size() {
|
||||
return result;
|
||||
}
|
||||
|
||||
// TODO: move these elsewhere
|
||||
|
||||
int getIBDeviceCount() {
|
||||
int num;
|
||||
ibv_get_device_list(&num);
|
||||
return num;
|
||||
}
|
||||
|
||||
std::string getIBDeviceName(TransportFlags ibTransport) {
|
||||
int num;
|
||||
struct ibv_device** devices = ibv_get_device_list(&num);
|
||||
int ibTransportIndex;
|
||||
switch (ibTransport) { // TODO: get rid of this ugly switch
|
||||
case TransportIB0:
|
||||
ibTransportIndex = 0;
|
||||
break;
|
||||
case TransportIB1:
|
||||
ibTransportIndex = 1;
|
||||
break;
|
||||
case TransportIB2:
|
||||
ibTransportIndex = 2;
|
||||
break;
|
||||
case TransportIB3:
|
||||
ibTransportIndex = 3;
|
||||
break;
|
||||
case TransportIB4:
|
||||
ibTransportIndex = 4;
|
||||
break;
|
||||
case TransportIB5:
|
||||
ibTransportIndex = 5;
|
||||
break;
|
||||
case TransportIB6:
|
||||
ibTransportIndex = 6;
|
||||
break;
|
||||
case TransportIB7:
|
||||
ibTransportIndex = 7;
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Not an IB transport");
|
||||
}
|
||||
if (ibTransportIndex >= num) {
|
||||
throw std::runtime_error("IB transport out of range");
|
||||
}
|
||||
return devices[ibTransportIndex]->name;
|
||||
}
|
||||
|
||||
TransportFlags getIBTransportByDeviceName(const std::string& ibDeviceName) {
|
||||
int num;
|
||||
struct ibv_device** devices = ibv_get_device_list(&num);
|
||||
for (int i = 0; i < num; ++i) {
|
||||
if (ibDeviceName == devices[i]->name) {
|
||||
switch (i) { // TODO: get rid of this ugly switch
|
||||
case 0:
|
||||
return TransportIB0;
|
||||
case 1:
|
||||
return TransportIB1;
|
||||
case 2:
|
||||
return TransportIB2;
|
||||
case 3:
|
||||
return TransportIB3;
|
||||
case 4:
|
||||
return TransportIB4;
|
||||
case 5:
|
||||
return TransportIB5;
|
||||
case 6:
|
||||
return TransportIB6;
|
||||
case 7:
|
||||
return TransportIB7;
|
||||
default:
|
||||
throw std::runtime_error("IB device index out of range");
|
||||
}
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("IB device not found");
|
||||
}
|
||||
|
||||
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "checks.hpp"
|
||||
#include "registered_memory.hpp"
|
||||
#include "npkit/npkit.h"
|
||||
#include "infiniband/verbs.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
@@ -54,7 +55,7 @@ void CudaIpcConnection::flush() {
|
||||
// IBConnection
|
||||
|
||||
IBConnection::IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl) : remoteRank(remoteRank), tag(tag), transport_(transport), remoteTransport_(TransportNone) {
|
||||
MSCCLPPTHROW(mscclppIbContextCreateQp(commImpl.getIbContext(transport), &qp));
|
||||
qp = commImpl.getIbContext(transport)->createQp();
|
||||
}
|
||||
|
||||
IBConnection::~IBConnection() {
|
||||
@@ -85,13 +86,8 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem
|
||||
auto dstMrInfo = dstTransportInfo.ibMrInfo;
|
||||
auto srcMr = srcTransportInfo.ibMr;
|
||||
|
||||
qp->stageSend(srcMr, &dstMrInfo, (uint32_t)size,
|
||||
/*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/false);
|
||||
int ret = qp->postSend();
|
||||
if (ret != 0) {
|
||||
// Return value is errno.
|
||||
WARN("data postSend failed: errno %d", ret);
|
||||
}
|
||||
qp->stageSend(srcMr, dstMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/false);
|
||||
qp->postSend();
|
||||
// npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)size);
|
||||
}
|
||||
|
||||
@@ -104,15 +100,11 @@ void IBConnection::flush() {
|
||||
continue;
|
||||
}
|
||||
for (int i = 0; i < wcNum; ++i) {
|
||||
struct ibv_wc* wc = &qp->wcs[i];
|
||||
const struct ibv_wc* wc = reinterpret_cast<const struct ibv_wc*>(qp->getWc(i));
|
||||
if (wc->status != IBV_WC_SUCCESS) {
|
||||
WARN("wc status %d", wc->status);
|
||||
continue;
|
||||
}
|
||||
if (wc->qp_num != qp->qp->qp_num) {
|
||||
WARN("got wc of unknown qp_num %d", wc->qp_num);
|
||||
continue;
|
||||
}
|
||||
if (wc->opcode == IBV_WC_RDMA_WRITE) {
|
||||
isWaiting = false;
|
||||
break;
|
||||
@@ -123,18 +115,16 @@ void IBConnection::flush() {
|
||||
}
|
||||
|
||||
void IBConnection::startSetup(Communicator& comm) {
|
||||
comm.bootstrap().send(&qp->info, sizeof(qp->info), remoteRank, tag);
|
||||
// TODO(chhwang): temporarily disabled to compile
|
||||
// comm.bootstrap().send(&qp->getInfo(), sizeof(qp->getInfo()), remoteRank, tag);
|
||||
}
|
||||
|
||||
void IBConnection::endSetup(Communicator& comm) {
|
||||
mscclppIbQpInfo qpInfo;
|
||||
comm.bootstrap().recv(&qpInfo, sizeof(qpInfo), remoteRank, tag);
|
||||
if (qp->rtr(&qpInfo) != 0) {
|
||||
throw std::runtime_error("Failed to transition QP to RTR");
|
||||
}
|
||||
if (qp->rts() != 0) {
|
||||
throw std::runtime_error("Failed to transition QP to RTS");
|
||||
}
|
||||
IbQpInfo qpInfo;
|
||||
// TODO(chhwang): temporarily disabled to compile
|
||||
// comm.bootstrap().recv(&qpInfo, sizeof(qpInfo), remoteRank, tag);
|
||||
qp->rtr(qpInfo);
|
||||
qp->rts();
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
612
src/ib.cc
612
src/ib.cc
@@ -4,370 +4,67 @@
|
||||
#include <sstream>
|
||||
#include <malloc.h>
|
||||
#include <unistd.h>
|
||||
#include <vector>
|
||||
|
||||
#include "mscclpp.hpp"
|
||||
#include "alloc.h"
|
||||
#include "comm.h"
|
||||
#include "debug.h"
|
||||
#include "ib.h"
|
||||
#include "ib.hpp"
|
||||
#include "checks.hpp"
|
||||
#include <infiniband/verbs.h>
|
||||
#include <string>
|
||||
|
||||
mscclppResult_t mscclppIbContextCreate(struct mscclppIbContext** ctx, const char* ibDevName)
|
||||
{
|
||||
struct mscclppIbContext* _ctx;
|
||||
MSCCLPPCHECK(mscclppCalloc(&_ctx, 1));
|
||||
namespace mscclpp {
|
||||
|
||||
std::vector<int> ports;
|
||||
|
||||
int num;
|
||||
struct ibv_device** devices = ibv_get_device_list(&num);
|
||||
for (int i = 0; i < num; ++i) {
|
||||
if (strncmp(devices[i]->name, ibDevName, IBV_SYSFS_NAME_MAX) == 0) {
|
||||
_ctx->ctx = ibv_open_device(devices[i]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
ibv_free_device_list(devices);
|
||||
if (_ctx->ctx == nullptr) {
|
||||
WARN("ibv_open_device failed (errno %d, device name %s)", errno, ibDevName);
|
||||
goto fail;
|
||||
}
|
||||
|
||||
// Check available ports
|
||||
struct ibv_device_attr devAttr;
|
||||
if (ibv_query_device(_ctx->ctx, &devAttr) != 0) {
|
||||
WARN("ibv_query_device failed (errno %d, device name %s)", errno, ibDevName);
|
||||
goto fail;
|
||||
}
|
||||
|
||||
for (uint8_t i = 1; i <= devAttr.phys_port_cnt; ++i) {
|
||||
struct ibv_port_attr portAttr;
|
||||
if (ibv_query_port(_ctx->ctx, i, &portAttr) != 0) {
|
||||
WARN("ibv_query_port failed (errno %d, port %d)", errno, i);
|
||||
goto fail;
|
||||
}
|
||||
if (portAttr.state != IBV_PORT_ACTIVE) {
|
||||
continue;
|
||||
}
|
||||
if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND && portAttr.link_layer != IBV_LINK_LAYER_ETHERNET) {
|
||||
continue;
|
||||
}
|
||||
ports.push_back((int)i);
|
||||
}
|
||||
if (ports.size() == 0) {
|
||||
WARN("no active IB port found");
|
||||
goto fail;
|
||||
}
|
||||
MSCCLPPCHECK(mscclppCalloc(&_ctx->ports, ports.size()));
|
||||
_ctx->nPorts = (int)ports.size();
|
||||
for (int i = 0; i < _ctx->nPorts; ++i) {
|
||||
_ctx->ports[i] = ports[i];
|
||||
}
|
||||
|
||||
_ctx->pd = ibv_alloc_pd(_ctx->ctx);
|
||||
if (_ctx->pd == NULL) {
|
||||
WARN("ibv_alloc_pd failed (errno %d)", errno);
|
||||
goto fail;
|
||||
}
|
||||
|
||||
*ctx = _ctx;
|
||||
return mscclppSuccess;
|
||||
fail:
|
||||
*ctx = NULL;
|
||||
if (_ctx->ports != NULL) {
|
||||
free(_ctx->ports);
|
||||
}
|
||||
free(_ctx);
|
||||
return mscclppInternalError;
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppIbContextDestroy(struct mscclppIbContext* ctx)
|
||||
{
|
||||
for (int i = 0; i < ctx->nMrs; ++i) {
|
||||
if (ctx->mrs[i].mr) {
|
||||
ibv_dereg_mr(ctx->mrs[i].mr);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < ctx->nQps; ++i) {
|
||||
if (ctx->qps[i].qp) {
|
||||
ibv_destroy_qp(ctx->qps[i].qp);
|
||||
}
|
||||
ibv_destroy_cq(ctx->qps[i].cq);
|
||||
free(ctx->qps[i].wcs);
|
||||
free(ctx->qps[i].sges);
|
||||
free(ctx->qps[i].wrs);
|
||||
}
|
||||
if (ctx->pd != NULL) {
|
||||
ibv_dealloc_pd(ctx->pd);
|
||||
}
|
||||
if (ctx->ctx != NULL) {
|
||||
ibv_close_device(ctx->ctx);
|
||||
}
|
||||
free(ctx->mrs);
|
||||
free(ctx->qps);
|
||||
free(ctx->ports);
|
||||
free(ctx);
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppIbContextCreateQp(struct mscclppIbContext* ctx, struct mscclppIbQp** ibQp, int port /*=-1*/)
|
||||
{
|
||||
if (port < 0) {
|
||||
port = ctx->ports[0];
|
||||
} else {
|
||||
bool found = false;
|
||||
for (int i = 0; i < ctx->nPorts; ++i) {
|
||||
if (ctx->ports[i] == port) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
WARN("invalid IB port: %d", port);
|
||||
return mscclppInternalError;
|
||||
}
|
||||
}
|
||||
|
||||
struct ibv_cq* cq = ibv_create_cq(ctx->ctx, MSCCLPP_IB_CQ_SIZE, NULL, NULL, 0);
|
||||
if (cq == NULL) {
|
||||
WARN("ibv_create_cq failed (errno %d)", errno);
|
||||
return mscclppInternalError;
|
||||
}
|
||||
|
||||
struct ibv_qp_init_attr qp_init_attr;
|
||||
std::memset(&qp_init_attr, 0, sizeof(struct ibv_qp_init_attr));
|
||||
qp_init_attr.sq_sig_all = 0;
|
||||
qp_init_attr.send_cq = cq;
|
||||
qp_init_attr.recv_cq = cq;
|
||||
qp_init_attr.qp_type = IBV_QPT_RC;
|
||||
qp_init_attr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
|
||||
qp_init_attr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
|
||||
qp_init_attr.cap.max_send_sge = 1;
|
||||
qp_init_attr.cap.max_recv_sge = 1;
|
||||
qp_init_attr.cap.max_inline_data = 0;
|
||||
struct ibv_qp* qp = ibv_create_qp(ctx->pd, &qp_init_attr);
|
||||
if (qp == nullptr) {
|
||||
WARN("ibv_create_qp failed (errno %d)", errno);
|
||||
return mscclppInternalError;
|
||||
}
|
||||
struct ibv_port_attr port_attr;
|
||||
if (ibv_query_port(ctx->ctx, port, &port_attr) != 0) {
|
||||
WARN("ibv_query_port failed (errno %d, port %d)", errno, port);
|
||||
return mscclppInternalError;
|
||||
}
|
||||
|
||||
// Register QP to this ctx
|
||||
qp->context = ctx->ctx;
|
||||
if (qp->context == NULL) {
|
||||
WARN("IB context is NULL");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
ctx->nQps++;
|
||||
if (ctx->qps == NULL) {
|
||||
MSCCLPPCHECK(mscclppCalloc(&ctx->qps, MAXCONNECTIONS));
|
||||
ctx->maxQps = MAXCONNECTIONS;
|
||||
}
|
||||
if (ctx->maxQps < ctx->nQps) {
|
||||
WARN("too many QPs");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
struct mscclppIbQp* _ibQp = &ctx->qps[ctx->nQps - 1];
|
||||
_ibQp->qp = qp;
|
||||
_ibQp->info.lid = port_attr.lid;
|
||||
_ibQp->info.port = port;
|
||||
_ibQp->info.linkLayer = port_attr.link_layer;
|
||||
_ibQp->info.qpn = qp->qp_num;
|
||||
_ibQp->info.mtu = port_attr.active_mtu;
|
||||
if (port_attr.link_layer != IBV_LINK_LAYER_INFINIBAND) {
|
||||
union ibv_gid gid;
|
||||
if (ibv_query_gid(ctx->ctx, port, 0, &gid) != 0) {
|
||||
WARN("ibv_query_gid failed (errno %d)", errno);
|
||||
return mscclppInternalError;
|
||||
}
|
||||
_ibQp->info.spn = gid.global.subnet_prefix;
|
||||
}
|
||||
|
||||
struct ibv_qp_attr qp_attr;
|
||||
std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr));
|
||||
qp_attr.qp_state = IBV_QPS_INIT;
|
||||
qp_attr.pkey_index = 0;
|
||||
qp_attr.port_num = port;
|
||||
qp_attr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;
|
||||
if (ibv_modify_qp(qp, &qp_attr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
|
||||
WARN("ibv_modify_qp failed (errno %d)", errno);
|
||||
return mscclppInternalError;
|
||||
}
|
||||
|
||||
MSCCLPPCHECK(mscclppCalloc(&_ibQp->wrs, MSCCLPP_IB_MAX_SENDS));
|
||||
MSCCLPPCHECK(mscclppCalloc(&_ibQp->sges, MSCCLPP_IB_MAX_SENDS));
|
||||
MSCCLPPCHECK(mscclppCalloc(&_ibQp->wcs, MSCCLPP_IB_CQ_POLL_NUM));
|
||||
_ibQp->cq = cq;
|
||||
|
||||
*ibQp = _ibQp;
|
||||
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppIbContextRegisterMr(struct mscclppIbContext* ctx, void* buff, size_t size,
|
||||
struct mscclppIbMr** ibMr)
|
||||
IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff)
|
||||
{
|
||||
if (size == 0) {
|
||||
WARN("invalid size: %zu", size);
|
||||
return mscclppInvalidArgument;
|
||||
throw std::runtime_error("invalid size: " + std::to_string(size));
|
||||
}
|
||||
static __thread uintptr_t pageSize = 0;
|
||||
if (pageSize == 0) {
|
||||
pageSize = sysconf(_SC_PAGESIZE);
|
||||
}
|
||||
uintptr_t addr = reinterpret_cast<uintptr_t>(buff) & -pageSize;
|
||||
size_t pages = (size + (reinterpret_cast<uintptr_t>(buff) - addr) + pageSize - 1) / pageSize;
|
||||
struct ibv_mr* mr =
|
||||
ibv_reg_mr(ctx->pd, reinterpret_cast<void*>(addr), pages * pageSize,
|
||||
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_RELAXED_ORDERING);
|
||||
if (mr == nullptr) {
|
||||
WARN("ibv_reg_mr failed (errno %d)", errno);
|
||||
return mscclppInternalError;
|
||||
std::size_t pages = (size + (reinterpret_cast<uintptr_t>(buff) - addr) + pageSize - 1) / pageSize;
|
||||
struct ibv_pd* _pd = reinterpret_cast<struct ibv_pd*>(pd);
|
||||
struct ibv_mr* _mr = ibv_reg_mr(_pd, reinterpret_cast<void*>(addr), pages * pageSize, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_RELAXED_ORDERING);
|
||||
if (_mr == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_reg_mr failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
ctx->nMrs++;
|
||||
if (ctx->mrs == NULL) {
|
||||
MSCCLPPCHECK(mscclppCalloc(&ctx->mrs, MAXCONNECTIONS));
|
||||
ctx->maxMrs = MAXCONNECTIONS;
|
||||
}
|
||||
if (ctx->maxMrs < ctx->nMrs) {
|
||||
WARN("too many MRs");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
struct mscclppIbMr* _ibMr = &ctx->mrs[ctx->nMrs - 1];
|
||||
_ibMr->mr = mr;
|
||||
_ibMr->buff = buff;
|
||||
_ibMr->info.addr = (uint64_t)buff;
|
||||
_ibMr->info.rkey = mr->rkey;
|
||||
*ibMr = _ibMr;
|
||||
return mscclppSuccess;
|
||||
this->mr = _mr;
|
||||
this->size = pages * pageSize;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int mscclppIbQp::rtr(const mscclppIbQpInfo* info)
|
||||
IbMr::~IbMr()
|
||||
{
|
||||
struct ibv_qp_attr qp_attr;
|
||||
std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr));
|
||||
qp_attr.qp_state = IBV_QPS_RTR;
|
||||
qp_attr.path_mtu = info->mtu;
|
||||
qp_attr.dest_qp_num = info->qpn;
|
||||
qp_attr.rq_psn = 0;
|
||||
qp_attr.max_dest_rd_atomic = 1;
|
||||
qp_attr.min_rnr_timer = 0x12;
|
||||
if (info->linkLayer == IBV_LINK_LAYER_ETHERNET) {
|
||||
qp_attr.ah_attr.is_global = 1;
|
||||
qp_attr.ah_attr.grh.dgid.global.subnet_prefix = info->spn;
|
||||
qp_attr.ah_attr.grh.dgid.global.interface_id = info->lid;
|
||||
qp_attr.ah_attr.grh.flow_label = 0;
|
||||
qp_attr.ah_attr.grh.sgid_index = 0;
|
||||
qp_attr.ah_attr.grh.hop_limit = 255;
|
||||
qp_attr.ah_attr.grh.traffic_class = 0;
|
||||
} else {
|
||||
qp_attr.ah_attr.is_global = 0;
|
||||
qp_attr.ah_attr.dlid = info->lid;
|
||||
}
|
||||
qp_attr.ah_attr.sl = 0;
|
||||
qp_attr.ah_attr.src_path_bits = 0;
|
||||
qp_attr.ah_attr.port_num = info->port;
|
||||
return ibv_modify_qp(this->qp, &qp_attr,
|
||||
IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
|
||||
IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER);
|
||||
ibv_dereg_mr(reinterpret_cast<struct ibv_mr*>(this->mr));
|
||||
}
|
||||
|
||||
int mscclppIbQp::rts()
|
||||
IbMrInfo IbMr::getInfo() const
|
||||
{
|
||||
struct ibv_qp_attr qp_attr;
|
||||
std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr));
|
||||
qp_attr.qp_state = IBV_QPS_RTS;
|
||||
qp_attr.timeout = 18;
|
||||
qp_attr.retry_cnt = 7;
|
||||
qp_attr.rnr_retry = 7;
|
||||
qp_attr.sq_psn = 0;
|
||||
qp_attr.max_rd_atomic = 1;
|
||||
return ibv_modify_qp(this->qp, &qp_attr,
|
||||
IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
|
||||
IBV_QP_MAX_QP_RD_ATOMIC);
|
||||
IbMrInfo info;
|
||||
info.addr = reinterpret_cast<uint64_t>(this->buff);
|
||||
info.rkey = reinterpret_cast<struct ibv_mr*>(this->mr)->rkey;
|
||||
return info;
|
||||
}
|
||||
|
||||
int mscclppIbQp::stageSend(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId,
|
||||
uint64_t srcOffset, uint64_t dstOffset, bool signaled)
|
||||
const void* IbMr::getBuff() const
|
||||
{
|
||||
if (this->wrn >= MSCCLPP_IB_MAX_SENDS) {
|
||||
return -1;
|
||||
}
|
||||
int wrn = this->wrn;
|
||||
struct ibv_send_wr* wr_ = &this->wrs[wrn];
|
||||
struct ibv_sge* sge_ = &this->sges[wrn];
|
||||
// std::memset(wr_, 0, sizeof(struct ibv_send_wr));
|
||||
// std::memset(sge_, 0, sizeof(struct ibv_sge));
|
||||
wr_->wr_id = wrId;
|
||||
wr_->sg_list = sge_;
|
||||
wr_->num_sge = 1;
|
||||
wr_->opcode = IBV_WR_RDMA_WRITE;
|
||||
wr_->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
|
||||
wr_->wr.rdma.remote_addr = (uint64_t)(info->addr) + dstOffset;
|
||||
wr_->wr.rdma.rkey = info->rkey;
|
||||
wr_->next = nullptr;
|
||||
sge_->addr = (uint64_t)(ibMr->buff) + srcOffset;
|
||||
sge_->length = size;
|
||||
sge_->lkey = ibMr->mr->lkey;
|
||||
if (wrn > 0) {
|
||||
this->wrs[wrn - 1].next = wr_;
|
||||
}
|
||||
this->wrn++;
|
||||
return this->wrn;
|
||||
return this->buff;
|
||||
}
|
||||
|
||||
int mscclppIbQp::stageSendWithImm(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId,
|
||||
uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData)
|
||||
uint32_t IbMr::getLkey() const
|
||||
{
|
||||
int wrn = this->stageSend(ibMr, info, size, wrId, srcOffset, dstOffset, signaled);
|
||||
this->wrs[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
|
||||
this->wrs[wrn - 1].imm_data = immData;
|
||||
return wrn;
|
||||
return reinterpret_cast<struct ibv_mr*>(this->mr)->lkey;
|
||||
}
|
||||
|
||||
int mscclppIbQp::postSend()
|
||||
{
|
||||
if (this->wrn == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct ibv_send_wr* bad_wr;
|
||||
int ret = ibv_post_send(this->qp, this->wrs, &bad_wr);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
this->wrn = 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int mscclppIbQp::postRecv(uint64_t wrId)
|
||||
{
|
||||
struct ibv_recv_wr wr, *bad_wr;
|
||||
wr.wr_id = wrId;
|
||||
wr.sg_list = nullptr;
|
||||
wr.num_sge = 0;
|
||||
wr.next = nullptr;
|
||||
return ibv_post_recv(this->qp, &wr, &bad_wr);
|
||||
}
|
||||
|
||||
int mscclppIbQp::pollCq()
|
||||
{
|
||||
return ibv_poll_cq(this->cq, MSCCLPP_IB_CQ_POLL_NUM, this->wcs);
|
||||
}
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
IbQp::IbQp(void* ctx, void* pd, int port)
|
||||
{
|
||||
struct ibv_context* _ctx = static_cast<struct ibv_context*>(ctx);
|
||||
struct ibv_pd* _pd = static_cast<struct ibv_pd*>(pd);
|
||||
struct ibv_context* _ctx = reinterpret_cast<struct ibv_context*>(ctx);
|
||||
struct ibv_pd* _pd = reinterpret_cast<struct ibv_pd*>(pd);
|
||||
|
||||
this->cq = ibv_create_cq(_ctx, MSCCLPP_IB_CQ_SIZE, nullptr, nullptr, 0);
|
||||
if (this->cq == nullptr) {
|
||||
@@ -379,8 +76,8 @@ IbQp::IbQp(void* ctx, void* pd, int port)
|
||||
struct ibv_qp_init_attr qpInitAttr;
|
||||
std::memset(&qpInitAttr, 0, sizeof(qpInitAttr));
|
||||
qpInitAttr.sq_sig_all = 0;
|
||||
qpInitAttr.send_cq = static_cast<struct ibv_cq*>(this->cq);
|
||||
qpInitAttr.recv_cq = static_cast<struct ibv_cq*>(this->cq);
|
||||
qpInitAttr.send_cq = reinterpret_cast<struct ibv_cq*>(this->cq);
|
||||
qpInitAttr.recv_cq = reinterpret_cast<struct ibv_cq*>(this->cq);
|
||||
qpInitAttr.qp_type = IBV_QPT_RC;
|
||||
qpInitAttr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
|
||||
qpInitAttr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
|
||||
@@ -428,14 +125,160 @@ IbQp::IbQp(void* ctx, void* pd, int port)
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
this->qp = _qp;
|
||||
MSCCLPPTHROW(mscclppCalloc(reinterpret_cast<struct ibv_send_wr**>(&this->wrs), MSCCLPP_IB_MAX_SENDS));
|
||||
MSCCLPPTHROW(mscclppCalloc(reinterpret_cast<struct ibv_sge**>(&this->sges), MSCCLPP_IB_MAX_SENDS));
|
||||
MSCCLPPTHROW(mscclppCalloc(reinterpret_cast<struct ibv_wc**>(&this->wcs), MSCCLPP_IB_CQ_POLL_NUM));
|
||||
}
|
||||
|
||||
IbCtx::IbCtx(const std::string& ibDevName)
|
||||
IbQp::~IbQp()
|
||||
{
|
||||
ibv_destroy_qp(reinterpret_cast<struct ibv_qp*>(this->qp));
|
||||
ibv_destroy_cq(reinterpret_cast<struct ibv_cq*>(this->cq));
|
||||
std::free(this->wrs);
|
||||
std::free(this->sges);
|
||||
std::free(this->wcs);
|
||||
}
|
||||
|
||||
void IbQp::rtr(const IbQpInfo& info)
|
||||
{
|
||||
struct ibv_qp_attr qp_attr;
|
||||
std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr));
|
||||
qp_attr.qp_state = IBV_QPS_RTR;
|
||||
qp_attr.path_mtu = static_cast<ibv_mtu>(info.mtu);
|
||||
qp_attr.dest_qp_num = info.qpn;
|
||||
qp_attr.rq_psn = 0;
|
||||
qp_attr.max_dest_rd_atomic = 1;
|
||||
qp_attr.min_rnr_timer = 0x12;
|
||||
if (info.linkLayer == IBV_LINK_LAYER_ETHERNET) {
|
||||
qp_attr.ah_attr.is_global = 1;
|
||||
qp_attr.ah_attr.grh.dgid.global.subnet_prefix = info.spn;
|
||||
qp_attr.ah_attr.grh.dgid.global.interface_id = info.lid;
|
||||
qp_attr.ah_attr.grh.flow_label = 0;
|
||||
qp_attr.ah_attr.grh.sgid_index = 0;
|
||||
qp_attr.ah_attr.grh.hop_limit = 255;
|
||||
qp_attr.ah_attr.grh.traffic_class = 0;
|
||||
} else {
|
||||
qp_attr.ah_attr.is_global = 0;
|
||||
qp_attr.ah_attr.dlid = info.lid;
|
||||
}
|
||||
qp_attr.ah_attr.sl = 0;
|
||||
qp_attr.ah_attr.src_path_bits = 0;
|
||||
qp_attr.ah_attr.port_num = info.port;
|
||||
int ret = ibv_modify_qp(reinterpret_cast<struct ibv_qp*>(this->qp), &qp_attr, IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER);
|
||||
if (ret != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_modify_qp failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
void IbQp::rts()
|
||||
{
|
||||
struct ibv_qp_attr qp_attr;
|
||||
std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr));
|
||||
qp_attr.qp_state = IBV_QPS_RTS;
|
||||
qp_attr.timeout = 18;
|
||||
qp_attr.retry_cnt = 7;
|
||||
qp_attr.rnr_retry = 7;
|
||||
qp_attr.sq_psn = 0;
|
||||
qp_attr.max_rd_atomic = 1;
|
||||
int ret = ibv_modify_qp(reinterpret_cast<struct ibv_qp*>(this->qp), &qp_attr, IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC);
|
||||
if (ret != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_modify_qp failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
int IbQp::stageSend(const IbMr *mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled)
|
||||
{
|
||||
if (this->wrn >= MSCCLPP_IB_MAX_SENDS) {
|
||||
return -1;
|
||||
}
|
||||
int wrn = this->wrn;
|
||||
struct ibv_send_wr* wrs_ = reinterpret_cast<struct ibv_send_wr*>(this->wrs);
|
||||
struct ibv_sge* sges_ = reinterpret_cast<struct ibv_sge*>(this->sges);
|
||||
|
||||
struct ibv_send_wr* wr_ = &wrs_[wrn];
|
||||
struct ibv_sge* sge_ = &sges_[wrn];
|
||||
wr_->wr_id = wrId;
|
||||
wr_->sg_list = sge_;
|
||||
wr_->num_sge = 1;
|
||||
wr_->opcode = IBV_WR_RDMA_WRITE;
|
||||
wr_->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
|
||||
wr_->wr.rdma.remote_addr = (uint64_t)(info.addr) + dstOffset;
|
||||
wr_->wr.rdma.rkey = info.rkey;
|
||||
wr_->next = nullptr;
|
||||
sge_->addr = (uint64_t)(mr->getBuff()) + srcOffset;
|
||||
sge_->length = size;
|
||||
sge_->lkey = mr->getLkey();
|
||||
if (wrn > 0) {
|
||||
wrs_[wrn - 1].next = wr_;
|
||||
}
|
||||
this->wrn++;
|
||||
return this->wrn;
|
||||
}
|
||||
|
||||
int IbQp::stageSendWithImm(const IbMr *mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData)
|
||||
{
|
||||
int wrn = this->stageSend(mr, info, size, wrId, srcOffset, dstOffset, signaled);
|
||||
struct ibv_send_wr* wrs_ = reinterpret_cast<struct ibv_send_wr*>(this->wrs);
|
||||
wrs_[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
|
||||
wrs_[wrn - 1].imm_data = immData;
|
||||
return wrn;
|
||||
}
|
||||
|
||||
void IbQp::postSend()
|
||||
{
|
||||
if (this->wrn == 0) {
|
||||
return;
|
||||
}
|
||||
struct ibv_send_wr* bad_wr;
|
||||
int ret = ibv_post_send(reinterpret_cast<struct ibv_qp*>(this->qp), reinterpret_cast<struct ibv_send_wr*>(this->wrs), &bad_wr);
|
||||
if (ret != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_post_send failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
this->wrn = 0;
|
||||
}
|
||||
|
||||
void IbQp::postRecv(uint64_t wrId)
|
||||
{
|
||||
struct ibv_recv_wr wr, *bad_wr;
|
||||
wr.wr_id = wrId;
|
||||
wr.sg_list = nullptr;
|
||||
wr.num_sge = 0;
|
||||
wr.next = nullptr;
|
||||
int ret = ibv_post_recv(reinterpret_cast<struct ibv_qp*>(this->qp), &wr, &bad_wr);
|
||||
if (ret != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_post_recv failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
int IbQp::pollCq()
|
||||
{
|
||||
return ibv_poll_cq(reinterpret_cast<struct ibv_cq*>(this->cq), MSCCLPP_IB_CQ_POLL_NUM, reinterpret_cast<struct ibv_wc*>(this->wcs));
|
||||
}
|
||||
|
||||
const IbQpInfo& IbQp::getInfo() const
|
||||
{
|
||||
return this->info;
|
||||
}
|
||||
|
||||
const void* IbQp::getWc(int idx) const
|
||||
{
|
||||
return &reinterpret_cast<struct ibv_wc*>(this->wcs)[idx];
|
||||
}
|
||||
|
||||
IbCtx::IbCtx(const std::string& devName) : devName(devName)
|
||||
{
|
||||
int num;
|
||||
struct ibv_device** devices = ibv_get_device_list(&num);
|
||||
for (int i = 0; i < num; ++i) {
|
||||
if (std::string(devices[i]->name) == ibDevName) {
|
||||
if (std::string(devices[i]->name) == devName) {
|
||||
this->ctx = ibv_open_device(devices[i]);
|
||||
break;
|
||||
}
|
||||
@@ -443,10 +286,10 @@ IbCtx::IbCtx(const std::string& ibDevName)
|
||||
ibv_free_device_list(devices);
|
||||
if (this->ctx == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_open_device failed (errno " << errno << ", device name << " << ibDevName << ")";
|
||||
err << "ibv_open_device failed (errno " << errno << ", device name << " << devName << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
this->pd = ibv_alloc_pd(static_cast<struct ibv_context*>(this->ctx));
|
||||
this->pd = ibv_alloc_pd(reinterpret_cast<struct ibv_context*>(this->ctx));
|
||||
if (this->pd == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_alloc_pd failed (errno " << errno << ")";
|
||||
@@ -456,18 +299,20 @@ IbCtx::IbCtx(const std::string& ibDevName)
|
||||
|
||||
IbCtx::~IbCtx()
|
||||
{
|
||||
this->mrs.clear();
|
||||
this->qps.clear();
|
||||
if (this->pd != nullptr) {
|
||||
ibv_dealloc_pd(static_cast<struct ibv_pd*>(this->pd));
|
||||
ibv_dealloc_pd(reinterpret_cast<struct ibv_pd*>(this->pd));
|
||||
}
|
||||
if (this->ctx != nullptr) {
|
||||
ibv_close_device(static_cast<struct ibv_context*>(this->ctx));
|
||||
ibv_close_device(reinterpret_cast<struct ibv_context*>(this->ctx));
|
||||
}
|
||||
}
|
||||
|
||||
bool IbCtx::isPortUsable(int port) const
|
||||
{
|
||||
struct ibv_port_attr portAttr;
|
||||
if (ibv_query_port(static_cast<struct ibv_context*>(this->ctx), port, &portAttr) != 0) {
|
||||
if (ibv_query_port(reinterpret_cast<struct ibv_context*>(this->ctx), port, &portAttr) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_port failed (errno " << errno << ", port << " << port << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
@@ -479,7 +324,7 @@ bool IbCtx::isPortUsable(int port) const
|
||||
int IbCtx::getAnyActivePort() const
|
||||
{
|
||||
struct ibv_device_attr devAttr;
|
||||
if (ibv_query_device(static_cast<struct ibv_context*>(this->ctx), &devAttr) != 0) {
|
||||
if (ibv_query_device(reinterpret_cast<struct ibv_context*>(this->ctx), &devAttr) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_device failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
@@ -506,4 +351,89 @@ IbQp* IbCtx::createQp(int port /*=-1*/)
|
||||
return qps.back().get();
|
||||
}
|
||||
|
||||
const IbMr* IbCtx::registerMr(void* buff, std::size_t size)
|
||||
{
|
||||
mrs.emplace_back(new IbMr(this->pd, buff, size));
|
||||
return mrs.back().get();
|
||||
}
|
||||
|
||||
const std::string& IbCtx::getDevName() const
|
||||
{
|
||||
return this->devName;
|
||||
}
|
||||
|
||||
int getIBDeviceCount() {
|
||||
int num;
|
||||
ibv_get_device_list(&num);
|
||||
return num;
|
||||
}
|
||||
|
||||
std::string getIBDeviceName(TransportFlags ibTransport) {
|
||||
int num;
|
||||
struct ibv_device** devices = ibv_get_device_list(&num);
|
||||
int ibTransportIndex;
|
||||
switch (ibTransport) { // TODO: get rid of this ugly switch
|
||||
case TransportIB0:
|
||||
ibTransportIndex = 0;
|
||||
break;
|
||||
case TransportIB1:
|
||||
ibTransportIndex = 1;
|
||||
break;
|
||||
case TransportIB2:
|
||||
ibTransportIndex = 2;
|
||||
break;
|
||||
case TransportIB3:
|
||||
ibTransportIndex = 3;
|
||||
break;
|
||||
case TransportIB4:
|
||||
ibTransportIndex = 4;
|
||||
break;
|
||||
case TransportIB5:
|
||||
ibTransportIndex = 5;
|
||||
break;
|
||||
case TransportIB6:
|
||||
ibTransportIndex = 6;
|
||||
break;
|
||||
case TransportIB7:
|
||||
ibTransportIndex = 7;
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Not an IB transport");
|
||||
}
|
||||
if (ibTransportIndex >= num) {
|
||||
throw std::runtime_error("IB transport out of range");
|
||||
}
|
||||
return devices[ibTransportIndex]->name;
|
||||
}
|
||||
|
||||
TransportFlags getIBTransportByDeviceName(const std::string& ibDeviceName) {
|
||||
int num;
|
||||
struct ibv_device** devices = ibv_get_device_list(&num);
|
||||
for (int i = 0; i < num; ++i) {
|
||||
if (ibDeviceName == devices[i]->name) {
|
||||
switch (i) { // TODO: get rid of this ugly switch
|
||||
case 0:
|
||||
return TransportIB0;
|
||||
case 1:
|
||||
return TransportIB1;
|
||||
case 2:
|
||||
return TransportIB2;
|
||||
case 3:
|
||||
return TransportIB3;
|
||||
case 4:
|
||||
return TransportIB4;
|
||||
case 5:
|
||||
return TransportIB5;
|
||||
case 6:
|
||||
return TransportIB6;
|
||||
case 7:
|
||||
return TransportIB7;
|
||||
default:
|
||||
throw std::runtime_error("IB device index out of range");
|
||||
}
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("IB device not found");
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -7,9 +7,10 @@
|
||||
#ifndef MSCCLPP_COMM_H_
|
||||
#define MSCCLPP_COMM_H_
|
||||
|
||||
#include "ib.h"
|
||||
#include "ib.hpp"
|
||||
#include "proxy.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#define MAXCONNECTIONS 64
|
||||
|
||||
@@ -31,7 +32,7 @@ struct mscclppConn
|
||||
std::vector<mscclppBufferRegistration> bufferRegistrations;
|
||||
std::vector<mscclppBufferRegistration> remoteBufferRegistrations;
|
||||
|
||||
struct mscclppIbContext* ibCtx;
|
||||
mscclpp::IbCtx* ibCtx;
|
||||
#if defined(ENABLE_NPKIT)
|
||||
std::vector<uint64_t> npkitUsedReqIds;
|
||||
std::vector<uint64_t> npkitFreeReqIds;
|
||||
@@ -57,7 +58,7 @@ struct mscclppComm
|
||||
// Flag to ask MSCCLPP kernels to abort
|
||||
volatile uint32_t* abortFlag;
|
||||
|
||||
struct mscclppIbContext* ibContext[MSCCLPP_IB_MAX_DEVS];
|
||||
std::unique_ptr<mscclpp::IbCtx> ibContext[MSCCLPP_IB_MAX_DEVS];
|
||||
struct mscclppProxyState* proxyState[MSCCLPP_PROXY_MAX_NUM];
|
||||
};
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include "mscclpp.h"
|
||||
#include "channel.hpp"
|
||||
#include "proxy.hpp"
|
||||
#include "ib.h"
|
||||
#include "ib.hpp"
|
||||
#include <unordered_map>
|
||||
|
||||
namespace mscclpp {
|
||||
@@ -15,13 +15,13 @@ class ConnectionBase;
|
||||
struct Communicator::Impl {
|
||||
mscclppComm_t comm;
|
||||
std::vector<std::shared_ptr<ConnectionBase>> connections;
|
||||
std::unordered_map<TransportFlags, mscclppIbContext*> ibContexts;
|
||||
std::unordered_map<TransportFlags, IbCtx*> ibContexts;
|
||||
|
||||
Impl();
|
||||
|
||||
~Impl();
|
||||
|
||||
mscclppIbContext* getIbContext(TransportFlags ibTransport);
|
||||
IbCtx* getIbContext(TransportFlags ibTransport);
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "mscclpp.hpp"
|
||||
#include <cuda_runtime.h>
|
||||
#include "ib.h"
|
||||
#include "ib.hpp"
|
||||
#include "communicator.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
@@ -38,7 +38,7 @@ class IBConnection : public ConnectionBase {
|
||||
int tag;
|
||||
TransportFlags transport_;
|
||||
TransportFlags remoteTransport_;
|
||||
mscclppIbQp* qp;
|
||||
IbQp* qp;
|
||||
public:
|
||||
|
||||
IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl);
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
#ifndef MSCCLPP_IB_H_
|
||||
#define MSCCLPP_IB_H_
|
||||
|
||||
#include "mscclpp.h"
|
||||
#include <infiniband/verbs.h>
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#define MSCCLPP_IB_CQ_SIZE 1024
|
||||
#define MSCCLPP_IB_CQ_POLL_NUM 4
|
||||
#define MSCCLPP_IB_MAX_SENDS 64
|
||||
#define MSCCLPP_IB_MAX_DEVS 8
|
||||
|
||||
// QP info to be shared with the remote peer
|
||||
struct mscclppIbQpInfo
|
||||
{
|
||||
uint16_t lid;
|
||||
uint8_t port;
|
||||
uint8_t linkLayer;
|
||||
uint32_t qpn;
|
||||
uint64_t spn;
|
||||
ibv_mtu mtu;
|
||||
};
|
||||
|
||||
// IB queue pair
|
||||
struct mscclppIbQp
|
||||
{
|
||||
struct ibv_qp* qp;
|
||||
struct mscclppIbQpInfo info;
|
||||
struct ibv_send_wr* wrs;
|
||||
struct ibv_sge* sges;
|
||||
struct ibv_cq* cq;
|
||||
struct ibv_wc* wcs;
|
||||
int wrn;
|
||||
|
||||
int rtr(const mscclppIbQpInfo* info);
|
||||
int rts();
|
||||
int stageSend(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled);
|
||||
int stageSendWithImm(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId,
|
||||
uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData);
|
||||
int postSend();
|
||||
int postRecv(uint64_t wrId);
|
||||
int pollCq();
|
||||
};
|
||||
|
||||
// Holds resources of a single IB device.
|
||||
struct mscclppIbContext
|
||||
{
|
||||
struct ibv_context* ctx;
|
||||
struct ibv_pd* pd;
|
||||
int* ports;
|
||||
int nPorts;
|
||||
struct mscclppIbQp* qps;
|
||||
int nQps;
|
||||
int maxQps;
|
||||
struct mscclppIbMr* mrs;
|
||||
int nMrs;
|
||||
int maxMrs;
|
||||
};
|
||||
|
||||
mscclppResult_t mscclppIbContextCreate(struct mscclppIbContext** ctx, const char* ibDevName);
|
||||
mscclppResult_t mscclppIbContextDestroy(struct mscclppIbContext* ctx);
|
||||
mscclppResult_t mscclppIbContextCreateQp(struct mscclppIbContext* ctx, struct mscclppIbQp** ibQp, int port = -1);
|
||||
mscclppResult_t mscclppIbContextRegisterMr(struct mscclppIbContext* ctx, void* buff, size_t size,
|
||||
struct mscclppIbMr** ibMr);
|
||||
|
||||
#endif
|
||||
@@ -5,8 +5,38 @@
|
||||
#include <string>
|
||||
#include <list>
|
||||
|
||||
#define MSCCLPP_IB_CQ_SIZE 1024
|
||||
#define MSCCLPP_IB_CQ_POLL_NUM 1
|
||||
#define MSCCLPP_IB_MAX_SENDS 64
|
||||
#define MSCCLPP_IB_MAX_DEVS 8
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct IbMrInfo
|
||||
{
|
||||
uint64_t addr;
|
||||
uint32_t rkey;
|
||||
};
|
||||
|
||||
class IbMr
|
||||
{
|
||||
public:
|
||||
~IbMr();
|
||||
|
||||
IbMrInfo getInfo() const;
|
||||
const void* getBuff() const;
|
||||
uint32_t getLkey() const;
|
||||
|
||||
private:
|
||||
IbMr(void* pd, void* buff, std::size_t size);
|
||||
|
||||
void* mr;
|
||||
void* buff;
|
||||
std::size_t size;
|
||||
|
||||
friend class IbCtx;
|
||||
};
|
||||
|
||||
// QP info to be shared with the remote peer
|
||||
struct IbQpInfo
|
||||
{
|
||||
@@ -15,7 +45,7 @@ struct IbQpInfo
|
||||
uint8_t linkLayer;
|
||||
uint32_t qpn;
|
||||
uint64_t spn;
|
||||
uint32_t mtu;
|
||||
int mtu;
|
||||
};
|
||||
|
||||
class IbQp
|
||||
@@ -23,11 +53,22 @@ class IbQp
|
||||
public:
|
||||
~IbQp();
|
||||
|
||||
IbQpInfo info;
|
||||
void rtr(const IbQpInfo& info);
|
||||
void rts();
|
||||
int stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled);
|
||||
int stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData);
|
||||
void postSend();
|
||||
void postRecv(uint64_t wrId);
|
||||
int pollCq();
|
||||
|
||||
const IbQpInfo& getInfo() const;
|
||||
const void* getWc(int idx) const;
|
||||
|
||||
private:
|
||||
IbQp(void* ctx, void* pd, int port);
|
||||
|
||||
IbQpInfo info;
|
||||
|
||||
void* qp;
|
||||
void* cq;
|
||||
void* wcs;
|
||||
@@ -38,22 +79,26 @@ private:
|
||||
friend class IbCtx;
|
||||
};
|
||||
|
||||
|
||||
class IbCtx
|
||||
{
|
||||
public:
|
||||
IbCtx(const std::string& ibDevName);
|
||||
IbCtx(const std::string& devName);
|
||||
~IbCtx();
|
||||
|
||||
IbQp* createQp(int port = -1);
|
||||
const IbMr* registerMr(void* buff, std::size_t size);
|
||||
|
||||
const std::string& getDevName() const;
|
||||
|
||||
private:
|
||||
bool isPortUsable(int port) const;
|
||||
int getAnyActivePort() const;
|
||||
|
||||
const std::string devName;
|
||||
void* ctx;
|
||||
void* pd;
|
||||
std::list<std::unique_ptr<IbQp>> qps;
|
||||
std::list<std::unique_ptr<IbMr>> mrs;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -207,25 +207,10 @@ typedef struct
|
||||
char internal[MSCCLPP_UNIQUE_ID_BYTES];
|
||||
} mscclppUniqueId;
|
||||
|
||||
// MR info to be shared with the remote peer
|
||||
struct mscclppIbMrInfo
|
||||
{
|
||||
uint64_t addr;
|
||||
uint32_t rkey;
|
||||
};
|
||||
|
||||
// IB memory region
|
||||
struct mscclppIbMr
|
||||
{
|
||||
struct ibv_mr* mr;
|
||||
void* buff;
|
||||
struct mscclppIbMrInfo info;
|
||||
};
|
||||
|
||||
struct mscclppRegisteredMemoryP2P
|
||||
{
|
||||
void* remoteBuff;
|
||||
mscclppIbMr* IbMr;
|
||||
const void* IbMr;
|
||||
};
|
||||
|
||||
struct mscclppRegisteredMemory
|
||||
|
||||
@@ -59,7 +59,7 @@ struct mscclppProxyState
|
||||
mscclppProxyRunState_t run;
|
||||
|
||||
int numaNodeToBind;
|
||||
struct mscclppIbContext* ibContext; // For IB connection only
|
||||
mscclpp::IbCtx* ibContext; // For IB connection only
|
||||
cudaStream_t p2pStream; // for P2P DMA engine only
|
||||
|
||||
struct mscclppProxyFifo fifo;
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "mscclpp.hpp"
|
||||
#include "mscclpp.h"
|
||||
#include "ib.h"
|
||||
#include "ib.hpp"
|
||||
#include "communicator.hpp"
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
@@ -16,8 +16,8 @@ struct TransportInfo {
|
||||
bool ibLocal;
|
||||
union {
|
||||
cudaIpcMemHandle_t cudaIpcHandle;
|
||||
mscclppIbMr* ibMr;
|
||||
mscclppIbMrInfo ibMrInfo;
|
||||
const IbMr* ibMr;
|
||||
IbMrInfo ibMrInfo;
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
79
src/init.cc
79
src/init.cc
@@ -7,6 +7,7 @@
|
||||
#include "gdr.h"
|
||||
#endif
|
||||
#include "mscclpp.h"
|
||||
#include "infiniband/verbs.h"
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
@@ -191,7 +192,7 @@ MSCCLPP_API mscclppResult_t mscclppCommDestroy(mscclppComm_t comm)
|
||||
|
||||
for (int i = 0; i < MSCCLPP_IB_MAX_DEVS; ++i) {
|
||||
if (comm->ibContext[i]) {
|
||||
MSCCLPPCHECK(mscclppIbContextDestroy(comm->ibContext[i]));
|
||||
comm->ibContext[i].reset(nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -366,24 +367,17 @@ struct mscclppHostIBConn : mscclppHostConn
|
||||
}
|
||||
void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, uint64_t dataSize)
|
||||
{
|
||||
this->ibQp->stageSend(this->ibMrs[src], &this->remoteIbMrInfos[dst], (uint32_t)dataSize,
|
||||
this->ibQp->stageSend(this->ibMrs[src], this->remoteIbMrInfos[dst], (uint32_t)dataSize,
|
||||
/*wrId=*/0, /*srcOffset=*/srcDataOffset, /*dstOffset=*/dstDataOffset, /*signaled=*/false);
|
||||
int ret = this->ibQp->postSend();
|
||||
if (ret != 0) {
|
||||
// Return value is errno.
|
||||
WARN("data postSend failed: errno %d", ret);
|
||||
}
|
||||
this->ibQp->postSend();
|
||||
npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)dataSize);
|
||||
}
|
||||
void signal()
|
||||
{
|
||||
// My local device flag is copied to the remote's proxy flag
|
||||
this->ibQp->stageSend(this->ibMrs[0], &this->remoteIbMrInfos[0], sizeof(uint64_t),
|
||||
this->ibQp->stageSend(this->ibMrs[0], this->remoteIbMrInfos[0], sizeof(uint64_t),
|
||||
/*wrId=*/0, /*srcOffset=*/0, /*dstOffset=*/sizeof(uint64_t), /*signaled=*/true);
|
||||
int ret = this->ibQp->postSend();
|
||||
if (ret != 0) {
|
||||
WARN("flag postSend failed: errno %d", ret);
|
||||
}
|
||||
this->ibQp->postSend();
|
||||
npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_FLAG_ENTRY, (uint32_t)sizeof(uint64_t));
|
||||
}
|
||||
void wait()
|
||||
@@ -399,15 +393,11 @@ struct mscclppHostIBConn : mscclppHostConn
|
||||
continue;
|
||||
}
|
||||
for (int i = 0; i < wcNum; ++i) {
|
||||
struct ibv_wc* wc = &this->ibQp->wcs[i];
|
||||
struct ibv_wc* wc = (struct ibv_wc*)this->ibQp->getWc(i);
|
||||
if (wc->status != IBV_WC_SUCCESS) {
|
||||
WARN("wc status %d", wc->status);
|
||||
continue;
|
||||
}
|
||||
if (wc->qp_num != this->ibQp->qp->qp_num) {
|
||||
WARN("got wc of unknown qp_num %d", wc->qp_num);
|
||||
continue;
|
||||
}
|
||||
if (wc->opcode == IBV_WC_RDMA_WRITE) {
|
||||
isWaiting = false;
|
||||
break;
|
||||
@@ -418,9 +408,9 @@ struct mscclppHostIBConn : mscclppHostConn
|
||||
}
|
||||
|
||||
mscclppConn* conn;
|
||||
struct mscclppIbQp* ibQp;
|
||||
std::vector<mscclppIbMr*> ibMrs;
|
||||
std::vector<mscclppIbMrInfo> remoteIbMrInfos;
|
||||
mscclpp::IbQp* ibQp;
|
||||
std::vector<const mscclpp::IbMr*> ibMrs;
|
||||
std::vector<mscclpp::IbMrInfo> remoteIbMrInfos;
|
||||
};
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, int tag, mscclppTransport_t transportType, const char* ibDev)
|
||||
@@ -458,7 +448,7 @@ MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int
|
||||
if (firstNullIdx == -1) {
|
||||
firstNullIdx = i;
|
||||
}
|
||||
} else if (strncmp(comm->ibContext[i]->ctx->device->name, ibDev, IBV_SYSFS_NAME_MAX) == 0) {
|
||||
} else if (strncmp(comm->ibContext[i]->getDevName().c_str(), ibDev, IBV_SYSFS_NAME_MAX) == 0) {
|
||||
ibDevIdx = i;
|
||||
break;
|
||||
}
|
||||
@@ -468,13 +458,10 @@ MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int
|
||||
if (ibDevIdx == -1) {
|
||||
// Create a new context.
|
||||
ibDevIdx = firstNullIdx;
|
||||
if (mscclppIbContextCreate(&comm->ibContext[ibDevIdx], ibDev) != mscclppSuccess) {
|
||||
WARN("Failed to create IB context");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
comm->ibContext[ibDevIdx].reset(new mscclpp::IbCtx(std::string(ibDev)));
|
||||
}
|
||||
// Set the ib context for this conn
|
||||
conn->ibCtx = comm->ibContext[ibDevIdx];
|
||||
conn->ibCtx = comm->ibContext[ibDevIdx].get();
|
||||
|
||||
} else if (transportType == mscclppTransportP2P) {
|
||||
// do the rest of the initialization later
|
||||
@@ -609,17 +596,17 @@ MSCCLPP_API mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t com
|
||||
struct mscclppBufferRegistrationInfo
|
||||
{
|
||||
cudaIpcMemHandle_t cudaHandle;
|
||||
mscclppIbMrInfo ibMrInfo;
|
||||
mscclpp::IbMrInfo ibMrInfo;
|
||||
uint64_t size;
|
||||
};
|
||||
|
||||
struct connInfo
|
||||
{
|
||||
mscclppIbQpInfo infoQp;
|
||||
mscclpp::IbQpInfo infoQp;
|
||||
std::vector<mscclppBufferRegistrationInfo> bufferInfos;
|
||||
|
||||
struct header {
|
||||
mscclppIbQpInfo infoQp;
|
||||
mscclpp::IbQpInfo infoQp;
|
||||
int numBufferInfos;
|
||||
};
|
||||
|
||||
@@ -702,22 +689,20 @@ mscclppResult_t mscclppIbConnectionSetupStart(struct connInfo* connInfo /*output
|
||||
devConn->remoteBuff = NULL;
|
||||
devConn->remoteSignalEpochId = NULL;
|
||||
|
||||
struct mscclppIbContext* ibCtx = conn->ibCtx;
|
||||
mscclpp::IbCtx* ibCtx = conn->ibCtx;
|
||||
if (hostConn->ibQp == NULL) {
|
||||
MSCCLPPCHECK(mscclppIbContextCreateQp(ibCtx, &hostConn->ibQp));
|
||||
hostConn->ibQp = ibCtx->createQp();
|
||||
}
|
||||
|
||||
// Add all registered buffers
|
||||
for (const auto &bufReg : conn->bufferRegistrations) {
|
||||
hostConn->ibMrs.emplace_back();
|
||||
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, bufReg.data,
|
||||
sizeof(struct mscclppDevConnSignalEpochId), &hostConn->ibMrs.back()));
|
||||
hostConn->ibMrs.emplace_back(ibCtx->registerMr(bufReg.data, sizeof(struct mscclppDevConnSignalEpochId)));
|
||||
connInfo->bufferInfos.emplace_back();
|
||||
connInfo->bufferInfos.back().ibMrInfo = hostConn->ibMrs.back()->info;
|
||||
connInfo->bufferInfos.back().ibMrInfo = hostConn->ibMrs.back()->getInfo();
|
||||
connInfo->bufferInfos.back().size = bufReg.size;
|
||||
}
|
||||
|
||||
connInfo->infoQp = hostConn->ibQp->info;
|
||||
connInfo->infoQp = hostConn->ibQp->getInfo();
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
@@ -728,14 +713,8 @@ mscclppResult_t mscclppIbConnectionSetupEnd(struct connInfo* connInfo /*input*/,
|
||||
return mscclppInternalError;
|
||||
}
|
||||
struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn;
|
||||
if (hostConn->ibQp->rtr(&connInfo->infoQp) != 0) {
|
||||
WARN("Failed to transition QP to RTR");
|
||||
return mscclppInvalidUsage;
|
||||
}
|
||||
if (hostConn->ibQp->rts() != 0) {
|
||||
WARN("Failed to transition QP to RTS");
|
||||
return mscclppInvalidUsage;
|
||||
}
|
||||
hostConn->ibQp->rtr(connInfo->infoQp);
|
||||
hostConn->ibQp->rts();
|
||||
|
||||
// No remote pointers to set with IB, so we just set the Mrs
|
||||
|
||||
@@ -788,25 +767,25 @@ MSCCLPP_API mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm)
|
||||
struct bufferInfo
|
||||
{
|
||||
cudaIpcMemHandle_t handleBuff;
|
||||
mscclppIbMrInfo infoBuffMr;
|
||||
mscclpp::IbMrInfo infoBuffMr;
|
||||
};
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppRegisterBuffer(mscclppComm_t comm, void* local_memory, size_t size,
|
||||
mscclppRegisteredMemory* regMem)
|
||||
{
|
||||
std::vector<struct mscclppIbMr*> ibMrs;
|
||||
std::vector<const mscclpp::IbMr*> ibMrs;
|
||||
for (int i = 0; i < comm->nConns; ++i) {
|
||||
struct mscclppConn* conn = &comm->conns[i];
|
||||
struct bufferInfo bInfo;
|
||||
struct mscclppIbMr* ibBuffMr;
|
||||
const mscclpp::IbMr* ibBuffMr;
|
||||
|
||||
// TODO: (conn->transport & mscclppTransportP2P) to support both P2P and IB
|
||||
if (conn->transport == mscclppTransportP2P) {
|
||||
CUDACHECK(cudaIpcGetMemHandle(&bInfo.handleBuff, local_memory));
|
||||
} else if (conn->transport == mscclppTransportIB) {
|
||||
MSCCLPPCHECK(mscclppIbContextRegisterMr(conn->ibCtx, local_memory, size, &ibBuffMr));
|
||||
bInfo.infoBuffMr = ibBuffMr->info;
|
||||
ibMrs.push_back(ibBuffMr);
|
||||
ibBuffMr = conn->ibCtx->registerMr(local_memory, size);
|
||||
bInfo.infoBuffMr = ibBuffMr->getInfo();
|
||||
ibMrs.emplace_back(ibBuffMr);
|
||||
}
|
||||
|
||||
MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &bInfo, sizeof(bInfo)));
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
#include "checks.h"
|
||||
#include "comm.h"
|
||||
#include "debug.h"
|
||||
#include "ib.h"
|
||||
#include "ib.hpp"
|
||||
#include "socket.h"
|
||||
|
||||
#include <emmintrin.h>
|
||||
|
||||
@@ -18,8 +18,7 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t
|
||||
auto addIb = [&](TransportFlags ibTransport) {
|
||||
TransportInfo transportInfo;
|
||||
transportInfo.transport = ibTransport;
|
||||
mscclppIbMr* mr;
|
||||
MSCCLPPTHROW(mscclppIbContextRegisterMr(commImpl.getIbContext(ibTransport), data, size, &mr));
|
||||
const IbMr* mr = commImpl.getIbContext(ibTransport)->registerMr(data, size);
|
||||
transportInfo.ibMr = mr;
|
||||
transportInfo.ibLocal = true;
|
||||
this->transportInfos.push_back(transportInfo);
|
||||
@@ -103,7 +102,7 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
|
||||
it += sizeof(handle);
|
||||
transportInfo.cudaIpcHandle = handle;
|
||||
} else if (transportInfo.transport & TransportAllIB) {
|
||||
mscclppIbMrInfo info;
|
||||
IbMrInfo info;
|
||||
std::copy_n(it, sizeof(info), reinterpret_cast<char*>(&info));
|
||||
it += sizeof(info);
|
||||
transportInfo.ibMrInfo = info;
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
#include "alloc.h"
|
||||
#include "checks.h"
|
||||
#include "ib.h"
|
||||
#include <set>
|
||||
#include "ib.hpp"
|
||||
#include "infiniband/verbs.h"
|
||||
#include "mscclpp.hpp"
|
||||
#include <string>
|
||||
#include <array>
|
||||
|
||||
// Measure current time in second.
|
||||
static double getTime(void)
|
||||
@@ -24,8 +26,8 @@ int main(int argc, const char* argv[])
|
||||
printf("Usage: %s <ip:port> <0(recv)/1(send)> <gpu id> <ib id>\n", argv[0]);
|
||||
return 1;
|
||||
}
|
||||
const char* ip_port = argv[1];
|
||||
int is_send = atoi(argv[2]);
|
||||
const char* ipPortPair = argv[1];
|
||||
int isSend = atoi(argv[2]);
|
||||
int cudaDevId = atoi(argv[3]);
|
||||
std::string ibDevName = "mlx5_ib" + std::string(argv[4]);
|
||||
|
||||
@@ -35,51 +37,40 @@ int main(int argc, const char* argv[])
|
||||
int nelem = 1;
|
||||
MSCCLPPCHECK(mscclppCudaCalloc(&data, nelem));
|
||||
|
||||
mscclppComm_t comm;
|
||||
MSCCLPPCHECK(mscclppCommInitRank(&comm, 2, ip_port, is_send));
|
||||
std::shared_ptr<mscclpp::Bootstrap> bootstrap(new mscclpp::Bootstrap(isSend, 2));
|
||||
bootstrap->initialize(ipPortPair);
|
||||
|
||||
struct mscclppIbContext* ctx;
|
||||
struct mscclppIbQp* qp;
|
||||
struct mscclppIbMr* mr;
|
||||
MSCCLPPCHECK(mscclppIbContextCreate(&ctx, ibDevName.c_str()));
|
||||
MSCCLPPCHECK(mscclppIbContextCreateQp(ctx, &qp));
|
||||
MSCCLPPCHECK(mscclppIbContextRegisterMr(ctx, data, sizeof(int) * nelem, &mr));
|
||||
mscclpp::IbCtx ctx(ibDevName);
|
||||
mscclpp::IbQp* qp = ctx.createQp();
|
||||
const mscclpp::IbMr* mr = ctx.registerMr(data, sizeof(int) * nelem);
|
||||
|
||||
struct mscclppIbQpInfo* qpInfo;
|
||||
MSCCLPPCHECK(mscclppCalloc(&qpInfo, 2));
|
||||
qpInfo[is_send] = qp->info;
|
||||
std::array<mscclpp::IbQpInfo, 2> qpInfo;
|
||||
qpInfo[isSend] = qp->getInfo();
|
||||
|
||||
struct mscclppIbMrInfo* mrInfo;
|
||||
MSCCLPPCHECK(mscclppCalloc(&mrInfo, 2));
|
||||
mrInfo[is_send] = mr->info;
|
||||
std::array<mscclpp::IbMrInfo, 2> mrInfo;
|
||||
mrInfo[isSend] = mr->getInfo();
|
||||
|
||||
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, qpInfo, sizeof(struct mscclppIbQpInfo)));
|
||||
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, mrInfo, sizeof(struct mscclppIbMrInfo)));
|
||||
bootstrap->allGather(qpInfo.data(), sizeof(mscclpp::IbQpInfo));
|
||||
bootstrap->allGather(mrInfo.data(), sizeof(mscclpp::IbMrInfo));
|
||||
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
if (i == is_send)
|
||||
for (int i = 0; i < bootstrap->getNranks(); ++i) {
|
||||
if (i == isSend)
|
||||
continue;
|
||||
qp->rtr(&qpInfo[i]);
|
||||
qp->rtr(qpInfo[i]);
|
||||
qp->rts();
|
||||
break;
|
||||
}
|
||||
|
||||
printf("connection succeed\n");
|
||||
|
||||
// A simple barrier
|
||||
int* tmp;
|
||||
MSCCLPPCHECK(mscclppCalloc(&tmp, 2));
|
||||
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
|
||||
bootstrap->barrier();
|
||||
|
||||
if (is_send) {
|
||||
if (isSend) {
|
||||
int maxIter = 100000;
|
||||
double start = getTime();
|
||||
for (int iter = 0; iter < maxIter; ++iter) {
|
||||
qp->stageSend(mr, &mrInfo[0], sizeof(int) * nelem, 0, 0, 0, true);
|
||||
if (qp->postSend() != 0) {
|
||||
WARN("postSend failed");
|
||||
return 1;
|
||||
}
|
||||
qp->stageSend(mr, mrInfo[0], sizeof(int) * nelem, 0, 0, 0, true);
|
||||
qp->postSend();
|
||||
bool waiting = true;
|
||||
while (waiting) {
|
||||
int wcNum = qp->pollCq();
|
||||
@@ -88,7 +79,7 @@ int main(int argc, const char* argv[])
|
||||
return 1;
|
||||
}
|
||||
for (int i = 0; i < wcNum; ++i) {
|
||||
struct ibv_wc* wc = &qp->wcs[i];
|
||||
const struct ibv_wc* wc = reinterpret_cast<const struct ibv_wc*>(qp->getWc(i));
|
||||
if (wc->status != IBV_WC_SUCCESS) {
|
||||
WARN("wc status %d", wc->status);
|
||||
return 1;
|
||||
@@ -103,10 +94,7 @@ int main(int argc, const char* argv[])
|
||||
}
|
||||
|
||||
// A simple barrier
|
||||
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
|
||||
|
||||
MSCCLPPCHECK(mscclppIbContextDestroy(ctx));
|
||||
MSCCLPPCHECK(mscclppCommDestroy(comm));
|
||||
bootstrap->barrier();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user