IB in cpp style WIP

This commit is contained in:
Changho Hwang
2023-04-23 14:47:07 +00:00
parent 0bc3c3e574
commit 35ade686ff
5 changed files with 226 additions and 52 deletions

View File

@@ -1,3 +1,4 @@
#include "mscclpp.hpp"
#include "communicator.hpp"
#include "host_connection.hpp"
#include "comm.h"
@@ -16,14 +17,14 @@ Communicator::Impl::~Impl() {
MSCCLPP_API_CPP Communicator::~Communicator() = default;
mscclppTransport_t transportTypeToCStyle(TransportType type) {
switch (type) {
case TransportType::IB:
static mscclppTransport_t transportFlagsToCStyle(TransportFlags flags) {
switch (flags) {
case TransportIB:
return mscclppTransportIB;
case TransportType::P2P:
case TransportCudaIpc:
return mscclppTransportP2P;
default:
throw std::runtime_error("Unknown transport type");
throw std::runtime_error("Unsupported conversion");
}
}
@@ -45,9 +46,8 @@ MSCCLPP_API_CPP void Communicator::bootstrapBarrier() {
mscclppBootstrapBarrier(pimpl->comm);
}
MSCCLPP_API_CPP std::shared_ptr<HostConnection> Communicator::connect(int remoteRank, int tag,
TransportType transportType, const char* ibDev) {
mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportTypeToCStyle(transportType), ibDev);
MSCCLPP_API_CPP std::shared_ptr<HostConnection> Communicator::connect(int remoteRank, int tag, TransportFlags transportFlags, const char* ibDev) {
mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportFlagsToCStyle(transportFlags), ibDev);
auto connIdx = pimpl->connections.size();
auto conn = std::make_shared<HostConnection>(std::make_unique<HostConnection::Impl>(this, &pimpl->comm->conns[connIdx]));
pimpl->connections.push_back(conn);

191
src/ib.cc
View File

@@ -1,6 +1,7 @@
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <malloc.h>
#include <unistd.h>
#include <vector>
@@ -9,48 +10,8 @@
#include "comm.h"
#include "debug.h"
#include "ib.h"
static int getIbDevNumaNode(const char* ibDevPath)
{
if (ibDevPath == NULL) {
WARN("ibDevPath is NULL");
return -1;
}
const char* postfix = "/device/numa_node";
FILE* fp = NULL;
char* filePath = NULL;
int node = -1;
int res;
if (mscclppCalloc(&filePath, strlen(ibDevPath) + strlen(postfix) + 1) != mscclppSuccess) {
WARN("mscclppCalloc failed");
goto exit;
}
memcpy(filePath, ibDevPath, strlen(ibDevPath) * sizeof(char));
filePath[strlen(ibDevPath)] = '\0';
if (strncat(filePath, postfix, strlen(postfix)) == NULL) {
WARN("strncat failed");
goto exit;
}
fp = fopen(filePath, "r");
if (fp == NULL) {
WARN("fopen failed (errno %d, path %s)", errno, filePath);
goto exit;
}
res = fscanf(fp, "%d", &node);
if (res != 1) {
WARN("fscanf failed (errno %d, path %s)", errno, filePath);
node = -1;
goto exit;
}
exit:
if (filePath != NULL) {
free(filePath);
}
if (fp != NULL) {
fclose(fp);
}
return node;
}
#include "ib.hpp"
#include "checks.hpp"
mscclppResult_t mscclppIbContextCreate(struct mscclppIbContext** ctx, const char* ibDevName)
{
@@ -400,3 +361,149 @@ int mscclppIbQp::pollCq()
{
return ibv_poll_cq(this->cq, MSCCLPP_IB_CQ_POLL_NUM, this->wcs);
}
namespace mscclpp {
IbQp::IbQp(void* ctx, void* pd, int port)
{
struct ibv_context* _ctx = static_cast<struct ibv_context*>(ctx);
struct ibv_pd* _pd = static_cast<struct ibv_pd*>(pd);
this->cq = ibv_create_cq(_ctx, MSCCLPP_IB_CQ_SIZE, nullptr, nullptr, 0);
if (this->cq == nullptr) {
std::stringstream err;
err << "ibv_create_cq failed (errno " << errno << ")";
throw std::runtime_error(err.str());
}
struct ibv_qp_init_attr qpInitAttr;
std::memset(&qpInitAttr, 0, sizeof(qpInitAttr));
qpInitAttr.sq_sig_all = 0;
qpInitAttr.send_cq = static_cast<struct ibv_cq*>(this->cq);
qpInitAttr.recv_cq = static_cast<struct ibv_cq*>(this->cq);
qpInitAttr.qp_type = IBV_QPT_RC;
qpInitAttr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
qpInitAttr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
qpInitAttr.cap.max_send_sge = 1;
qpInitAttr.cap.max_recv_sge = 1;
qpInitAttr.cap.max_inline_data = 0;
struct ibv_qp* _qp = ibv_create_qp(_pd, &qpInitAttr);
if (_qp == nullptr) {
std::stringstream err;
err << "ibv_create_qp failed (errno " << errno << ")";
throw std::runtime_error(err.str());
}
struct ibv_port_attr portAttr;
if (ibv_query_port(_ctx, port, &portAttr) != 0) {
std::stringstream err;
err << "ibv_query_port failed (errno " << errno << ")";
throw std::runtime_error(err.str());
}
this->info.lid = portAttr.lid;
this->info.port = port;
this->info.linkLayer = portAttr.link_layer;
this->info.qpn = _qp->qp_num;
this->info.mtu = portAttr.active_mtu;
if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND) {
union ibv_gid gid;
if (ibv_query_gid(_ctx, port, 0, &gid) != 0) {
std::stringstream err;
err << "ibv_query_gid failed (errno " << errno << ")";
throw std::runtime_error(err.str());
}
this->info.spn = gid.global.subnet_prefix;
}
struct ibv_qp_attr qpAttr;
memset(&qpAttr, 0, sizeof(qpAttr));
qpAttr.qp_state = IBV_QPS_INIT;
qpAttr.pkey_index = 0;
qpAttr.port_num = port;
qpAttr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;
if (ibv_modify_qp(_qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
std::stringstream err;
err << "ibv_modify_qp failed (errno " << errno << ")";
throw std::runtime_error(err.str());
}
this->qp = _qp;
}
IbCtx::IbCtx(const std::string& ibDevName)
{
int num;
struct ibv_device** devices = ibv_get_device_list(&num);
for (int i = 0; i < num; ++i) {
if (std::string(devices[i]->name) == ibDevName) {
this->ctx = ibv_open_device(devices[i]);
break;
}
}
ibv_free_device_list(devices);
if (this->ctx == nullptr) {
std::stringstream err;
err << "ibv_open_device failed (errno " << errno << ", device name << " << ibDevName << ")";
throw std::runtime_error(err.str());
}
this->pd = ibv_alloc_pd(static_cast<struct ibv_context*>(this->ctx));
if (this->pd == nullptr) {
std::stringstream err;
err << "ibv_alloc_pd failed (errno " << errno << ")";
throw std::runtime_error(err.str());
}
}
IbCtx::~IbCtx()
{
if (this->pd != nullptr) {
ibv_dealloc_pd(static_cast<struct ibv_pd*>(this->pd));
}
if (this->ctx != nullptr) {
ibv_close_device(static_cast<struct ibv_context*>(this->ctx));
}
}
bool IbCtx::isPortUsable(int port) const
{
struct ibv_port_attr portAttr;
if (ibv_query_port(static_cast<struct ibv_context*>(this->ctx), port, &portAttr) != 0) {
std::stringstream err;
err << "ibv_query_port failed (errno " << errno << ", port << " << port << ")";
throw std::runtime_error(err.str());
}
return portAttr.state == IBV_PORT_ACTIVE && (portAttr.link_layer == IBV_LINK_LAYER_ETHERNET ||
portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND);
}
int IbCtx::getAnyActivePort() const
{
struct ibv_device_attr devAttr;
if (ibv_query_device(static_cast<struct ibv_context*>(this->ctx), &devAttr) != 0) {
std::stringstream err;
err << "ibv_query_device failed (errno " << errno << ")";
throw std::runtime_error(err.str());
}
for (uint8_t port = 1; port <= devAttr.phys_port_cnt; ++port) {
if (this->isPortUsable(port)) {
return port;
}
}
return -1;
}
IbQp* IbCtx::createQp(int port /*=-1*/)
{
if (port == -1) {
port = this->getAnyActivePort();
if (port == -1) {
throw std::runtime_error("No active port found");
}
} else if (!this->isPortUsable(port)) {
throw std::runtime_error("invalid IB port: " + std::to_string(port));
}
qps.emplace_back(new IbQp(this->ctx, this->pd, port));
return qps.back().get();
}
} // namespace mscclpp

View File

@@ -2,6 +2,7 @@
#define MSCCLPP_CHANNEL_HPP_
#include "mscclpp.hpp"
#include "epoch.hpp"
#include "proxy.hpp"
namespace mscclpp {
@@ -88,7 +89,7 @@ public:
~HostConnection();
void write()
void write();
int getId();
@@ -293,3 +294,6 @@ struct SimpleDeviceConnection {
BufferHandle src;
};
} // namespace mscclpp
#endif // MSCCLPP_CHANNEL_HPP_

View File

@@ -3,6 +3,8 @@
#include "mscclpp.hpp"
#include "mscclpp.h"
#include "channel.hpp"
#include "proxy.hpp"
namespace mscclpp {
@@ -20,4 +22,4 @@ struct Communicator::Impl {
} // namespace mscclpp
#endif
#endif // MSCCL_COMMUNICATOR_HPP_

61
src/include/ib.hpp Normal file
View File

@@ -0,0 +1,61 @@
#ifndef MSCCLPP_IB_HPP_
#define MSCCLPP_IB_HPP_
#include <memory>
#include <string>
#include <list>
namespace mscclpp {
// QP info to be shared with the remote peer
struct IbQpInfo
{
uint16_t lid;
uint8_t port;
uint8_t linkLayer;
uint32_t qpn;
uint64_t spn;
uint32_t mtu;
};
class IbQp
{
public:
~IbQp();
IbQpInfo info;
private:
IbQp(void* ctx, void* pd, int port);
void* qp;
void* cq;
void* wcs;
void* wrs;
void* sges;
int wrn;
friend class IbCtx;
};
class IbCtx
{
public:
IbCtx(const std::string& ibDevName);
~IbCtx();
IbQp* createQp(int port = -1);
private:
bool IbCtx::isPortUsable(int port) const;
int IbCtx::getAnyActivePort() const;
void* ctx;
void* pd;
std::list<std::unique_ptr<IbQp>> qps;
};
} // namespace mscclpp
#endif // MSCCLPP_IB_HPP_