mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 17:26:04 +00:00
IB in cpp style WIP
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
#include "mscclpp.hpp"
|
||||
#include "communicator.hpp"
|
||||
#include "host_connection.hpp"
|
||||
#include "comm.h"
|
||||
@@ -16,14 +17,14 @@ Communicator::Impl::~Impl() {
|
||||
|
||||
MSCCLPP_API_CPP Communicator::~Communicator() = default;
|
||||
|
||||
mscclppTransport_t transportTypeToCStyle(TransportType type) {
|
||||
switch (type) {
|
||||
case TransportType::IB:
|
||||
static mscclppTransport_t transportFlagsToCStyle(TransportFlags flags) {
|
||||
switch (flags) {
|
||||
case TransportIB:
|
||||
return mscclppTransportIB;
|
||||
case TransportType::P2P:
|
||||
case TransportCudaIpc:
|
||||
return mscclppTransportP2P;
|
||||
default:
|
||||
throw std::runtime_error("Unknown transport type");
|
||||
throw std::runtime_error("Unsupported conversion");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,9 +46,8 @@ MSCCLPP_API_CPP void Communicator::bootstrapBarrier() {
|
||||
mscclppBootstrapBarrier(pimpl->comm);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<HostConnection> Communicator::connect(int remoteRank, int tag,
|
||||
TransportType transportType, const char* ibDev) {
|
||||
mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportTypeToCStyle(transportType), ibDev);
|
||||
MSCCLPP_API_CPP std::shared_ptr<HostConnection> Communicator::connect(int remoteRank, int tag, TransportFlags transportFlags, const char* ibDev) {
|
||||
mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportFlagsToCStyle(transportFlags), ibDev);
|
||||
auto connIdx = pimpl->connections.size();
|
||||
auto conn = std::make_shared<HostConnection>(std::make_unique<HostConnection::Impl>(this, &pimpl->comm->conns[connIdx]));
|
||||
pimpl->connections.push_back(conn);
|
||||
|
||||
191
src/ib.cc
191
src/ib.cc
@@ -1,6 +1,7 @@
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <malloc.h>
|
||||
#include <unistd.h>
|
||||
#include <vector>
|
||||
@@ -9,48 +10,8 @@
|
||||
#include "comm.h"
|
||||
#include "debug.h"
|
||||
#include "ib.h"
|
||||
|
||||
static int getIbDevNumaNode(const char* ibDevPath)
|
||||
{
|
||||
if (ibDevPath == NULL) {
|
||||
WARN("ibDevPath is NULL");
|
||||
return -1;
|
||||
}
|
||||
const char* postfix = "/device/numa_node";
|
||||
FILE* fp = NULL;
|
||||
char* filePath = NULL;
|
||||
int node = -1;
|
||||
int res;
|
||||
if (mscclppCalloc(&filePath, strlen(ibDevPath) + strlen(postfix) + 1) != mscclppSuccess) {
|
||||
WARN("mscclppCalloc failed");
|
||||
goto exit;
|
||||
}
|
||||
memcpy(filePath, ibDevPath, strlen(ibDevPath) * sizeof(char));
|
||||
filePath[strlen(ibDevPath)] = '\0';
|
||||
if (strncat(filePath, postfix, strlen(postfix)) == NULL) {
|
||||
WARN("strncat failed");
|
||||
goto exit;
|
||||
}
|
||||
fp = fopen(filePath, "r");
|
||||
if (fp == NULL) {
|
||||
WARN("fopen failed (errno %d, path %s)", errno, filePath);
|
||||
goto exit;
|
||||
}
|
||||
res = fscanf(fp, "%d", &node);
|
||||
if (res != 1) {
|
||||
WARN("fscanf failed (errno %d, path %s)", errno, filePath);
|
||||
node = -1;
|
||||
goto exit;
|
||||
}
|
||||
exit:
|
||||
if (filePath != NULL) {
|
||||
free(filePath);
|
||||
}
|
||||
if (fp != NULL) {
|
||||
fclose(fp);
|
||||
}
|
||||
return node;
|
||||
}
|
||||
#include "ib.hpp"
|
||||
#include "checks.hpp"
|
||||
|
||||
mscclppResult_t mscclppIbContextCreate(struct mscclppIbContext** ctx, const char* ibDevName)
|
||||
{
|
||||
@@ -400,3 +361,149 @@ int mscclppIbQp::pollCq()
|
||||
{
|
||||
return ibv_poll_cq(this->cq, MSCCLPP_IB_CQ_POLL_NUM, this->wcs);
|
||||
}
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
IbQp::IbQp(void* ctx, void* pd, int port)
|
||||
{
|
||||
struct ibv_context* _ctx = static_cast<struct ibv_context*>(ctx);
|
||||
struct ibv_pd* _pd = static_cast<struct ibv_pd*>(pd);
|
||||
|
||||
this->cq = ibv_create_cq(_ctx, MSCCLPP_IB_CQ_SIZE, nullptr, nullptr, 0);
|
||||
if (this->cq == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_create_cq failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
struct ibv_qp_init_attr qpInitAttr;
|
||||
std::memset(&qpInitAttr, 0, sizeof(qpInitAttr));
|
||||
qpInitAttr.sq_sig_all = 0;
|
||||
qpInitAttr.send_cq = static_cast<struct ibv_cq*>(this->cq);
|
||||
qpInitAttr.recv_cq = static_cast<struct ibv_cq*>(this->cq);
|
||||
qpInitAttr.qp_type = IBV_QPT_RC;
|
||||
qpInitAttr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
|
||||
qpInitAttr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
|
||||
qpInitAttr.cap.max_send_sge = 1;
|
||||
qpInitAttr.cap.max_recv_sge = 1;
|
||||
qpInitAttr.cap.max_inline_data = 0;
|
||||
|
||||
struct ibv_qp* _qp = ibv_create_qp(_pd, &qpInitAttr);
|
||||
if (_qp == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_create_qp failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
struct ibv_port_attr portAttr;
|
||||
if (ibv_query_port(_ctx, port, &portAttr) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_port failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
this->info.lid = portAttr.lid;
|
||||
this->info.port = port;
|
||||
this->info.linkLayer = portAttr.link_layer;
|
||||
this->info.qpn = _qp->qp_num;
|
||||
this->info.mtu = portAttr.active_mtu;
|
||||
if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND) {
|
||||
union ibv_gid gid;
|
||||
if (ibv_query_gid(_ctx, port, 0, &gid) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_gid failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
this->info.spn = gid.global.subnet_prefix;
|
||||
}
|
||||
|
||||
struct ibv_qp_attr qpAttr;
|
||||
memset(&qpAttr, 0, sizeof(qpAttr));
|
||||
qpAttr.qp_state = IBV_QPS_INIT;
|
||||
qpAttr.pkey_index = 0;
|
||||
qpAttr.port_num = port;
|
||||
qpAttr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;
|
||||
if (ibv_modify_qp(_qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_modify_qp failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
this->qp = _qp;
|
||||
}
|
||||
|
||||
IbCtx::IbCtx(const std::string& ibDevName)
|
||||
{
|
||||
int num;
|
||||
struct ibv_device** devices = ibv_get_device_list(&num);
|
||||
for (int i = 0; i < num; ++i) {
|
||||
if (std::string(devices[i]->name) == ibDevName) {
|
||||
this->ctx = ibv_open_device(devices[i]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
ibv_free_device_list(devices);
|
||||
if (this->ctx == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_open_device failed (errno " << errno << ", device name << " << ibDevName << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
this->pd = ibv_alloc_pd(static_cast<struct ibv_context*>(this->ctx));
|
||||
if (this->pd == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_alloc_pd failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
IbCtx::~IbCtx()
|
||||
{
|
||||
if (this->pd != nullptr) {
|
||||
ibv_dealloc_pd(static_cast<struct ibv_pd*>(this->pd));
|
||||
}
|
||||
if (this->ctx != nullptr) {
|
||||
ibv_close_device(static_cast<struct ibv_context*>(this->ctx));
|
||||
}
|
||||
}
|
||||
|
||||
bool IbCtx::isPortUsable(int port) const
|
||||
{
|
||||
struct ibv_port_attr portAttr;
|
||||
if (ibv_query_port(static_cast<struct ibv_context*>(this->ctx), port, &portAttr) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_port failed (errno " << errno << ", port << " << port << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
return portAttr.state == IBV_PORT_ACTIVE && (portAttr.link_layer == IBV_LINK_LAYER_ETHERNET ||
|
||||
portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND);
|
||||
}
|
||||
|
||||
int IbCtx::getAnyActivePort() const
|
||||
{
|
||||
struct ibv_device_attr devAttr;
|
||||
if (ibv_query_device(static_cast<struct ibv_context*>(this->ctx), &devAttr) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_device failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
for (uint8_t port = 1; port <= devAttr.phys_port_cnt; ++port) {
|
||||
if (this->isPortUsable(port)) {
|
||||
return port;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
IbQp* IbCtx::createQp(int port /*=-1*/)
|
||||
{
|
||||
if (port == -1) {
|
||||
port = this->getAnyActivePort();
|
||||
if (port == -1) {
|
||||
throw std::runtime_error("No active port found");
|
||||
}
|
||||
} else if (!this->isPortUsable(port)) {
|
||||
throw std::runtime_error("invalid IB port: " + std::to_string(port));
|
||||
}
|
||||
qps.emplace_back(new IbQp(this->ctx, this->pd, port));
|
||||
return qps.back().get();
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#define MSCCLPP_CHANNEL_HPP_
|
||||
|
||||
#include "mscclpp.hpp"
|
||||
#include "epoch.hpp"
|
||||
#include "proxy.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
@@ -88,7 +89,7 @@ public:
|
||||
|
||||
~HostConnection();
|
||||
|
||||
void write()
|
||||
void write();
|
||||
|
||||
int getId();
|
||||
|
||||
@@ -293,3 +294,6 @@ struct SimpleDeviceConnection {
|
||||
BufferHandle src;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_CHANNEL_HPP_
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
|
||||
#include "mscclpp.hpp"
|
||||
#include "mscclpp.h"
|
||||
#include "channel.hpp"
|
||||
#include "proxy.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
@@ -20,4 +22,4 @@ struct Communicator::Impl {
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif
|
||||
#endif // MSCCL_COMMUNICATOR_HPP_
|
||||
|
||||
61
src/include/ib.hpp
Normal file
61
src/include/ib.hpp
Normal file
@@ -0,0 +1,61 @@
|
||||
#ifndef MSCCLPP_IB_HPP_
|
||||
#define MSCCLPP_IB_HPP_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <list>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
// QP info to be shared with the remote peer
|
||||
struct IbQpInfo
|
||||
{
|
||||
uint16_t lid;
|
||||
uint8_t port;
|
||||
uint8_t linkLayer;
|
||||
uint32_t qpn;
|
||||
uint64_t spn;
|
||||
uint32_t mtu;
|
||||
};
|
||||
|
||||
class IbQp
|
||||
{
|
||||
public:
|
||||
~IbQp();
|
||||
|
||||
IbQpInfo info;
|
||||
|
||||
private:
|
||||
IbQp(void* ctx, void* pd, int port);
|
||||
|
||||
void* qp;
|
||||
void* cq;
|
||||
void* wcs;
|
||||
void* wrs;
|
||||
void* sges;
|
||||
int wrn;
|
||||
|
||||
friend class IbCtx;
|
||||
};
|
||||
|
||||
|
||||
class IbCtx
|
||||
{
|
||||
public:
|
||||
IbCtx(const std::string& ibDevName);
|
||||
~IbCtx();
|
||||
|
||||
IbQp* createQp(int port = -1);
|
||||
|
||||
private:
|
||||
bool IbCtx::isPortUsable(int port) const;
|
||||
int IbCtx::getAnyActivePort() const;
|
||||
|
||||
void* ctx;
|
||||
void* pd;
|
||||
std::list<std::unique_ptr<IbQp>> qps;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_IB_HPP_
|
||||
Reference in New Issue
Block a user