diff --git a/src/communicator.cc b/src/communicator.cc index 5a843c78..d12b20e4 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -1,3 +1,4 @@ +#include "mscclpp.hpp" #include "communicator.hpp" #include "host_connection.hpp" #include "comm.h" @@ -16,14 +17,14 @@ Communicator::Impl::~Impl() { MSCCLPP_API_CPP Communicator::~Communicator() = default; -mscclppTransport_t transportTypeToCStyle(TransportType type) { - switch (type) { - case TransportType::IB: +static mscclppTransport_t transportFlagsToCStyle(TransportFlags flags) { + switch (flags) { + case TransportIB: return mscclppTransportIB; - case TransportType::P2P: + case TransportCudaIpc: return mscclppTransportP2P; default: - throw std::runtime_error("Unknown transport type"); + throw std::runtime_error("Unsupported conversion"); } } @@ -45,9 +46,8 @@ MSCCLPP_API_CPP void Communicator::bootstrapBarrier() { mscclppBootstrapBarrier(pimpl->comm); } -MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, - TransportType transportType, const char* ibDev) { - mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportTypeToCStyle(transportType), ibDev); +MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, TransportFlags transportFlags, const char* ibDev) { + mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportFlagsToCStyle(transportFlags), ibDev); auto connIdx = pimpl->connections.size(); auto conn = std::make_shared(std::make_unique(this, &pimpl->comm->conns[connIdx])); pimpl->connections.push_back(conn); diff --git a/src/ib.cc b/src/ib.cc index bb574e21..4a094761 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -9,48 +10,8 @@ #include "comm.h" #include "debug.h" #include "ib.h" - -static int getIbDevNumaNode(const char* ibDevPath) -{ - if (ibDevPath == NULL) { - WARN("ibDevPath is NULL"); - return -1; - } - const char* postfix = "/device/numa_node"; - FILE* fp = NULL; - char* filePath = NULL; - int node = -1; - int res; - if (mscclppCalloc(&filePath, strlen(ibDevPath) + strlen(postfix) + 1) != mscclppSuccess) { - WARN("mscclppCalloc failed"); - goto exit; - } - memcpy(filePath, ibDevPath, strlen(ibDevPath) * sizeof(char)); - filePath[strlen(ibDevPath)] = '\0'; - if (strncat(filePath, postfix, strlen(postfix)) == NULL) { - WARN("strncat failed"); - goto exit; - } - fp = fopen(filePath, "r"); - if (fp == NULL) { - WARN("fopen failed (errno %d, path %s)", errno, filePath); - goto exit; - } - res = fscanf(fp, "%d", &node); - if (res != 1) { - WARN("fscanf failed (errno %d, path %s)", errno, filePath); - node = -1; - goto exit; - } -exit: - if (filePath != NULL) { - free(filePath); - } - if (fp != NULL) { - fclose(fp); - } - return node; -} +#include "ib.hpp" +#include "checks.hpp" mscclppResult_t mscclppIbContextCreate(struct mscclppIbContext** ctx, const char* ibDevName) { @@ -400,3 +361,149 @@ int mscclppIbQp::pollCq() { return ibv_poll_cq(this->cq, MSCCLPP_IB_CQ_POLL_NUM, this->wcs); } + +namespace mscclpp { + +IbQp::IbQp(void* ctx, void* pd, int port) +{ + struct ibv_context* _ctx = static_cast(ctx); + struct ibv_pd* _pd = static_cast(pd); + + this->cq = ibv_create_cq(_ctx, MSCCLPP_IB_CQ_SIZE, nullptr, nullptr, 0); + if (this->cq == nullptr) { + std::stringstream err; + err << "ibv_create_cq failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } + + struct ibv_qp_init_attr qpInitAttr; + std::memset(&qpInitAttr, 0, sizeof(qpInitAttr)); + qpInitAttr.sq_sig_all = 0; + qpInitAttr.send_cq = static_cast(this->cq); + qpInitAttr.recv_cq = static_cast(this->cq); + qpInitAttr.qp_type = IBV_QPT_RC; + qpInitAttr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; + qpInitAttr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; + qpInitAttr.cap.max_send_sge = 1; + qpInitAttr.cap.max_recv_sge = 1; + qpInitAttr.cap.max_inline_data = 0; + + struct ibv_qp* _qp = ibv_create_qp(_pd, &qpInitAttr); + if (_qp == nullptr) { + std::stringstream err; + err << "ibv_create_qp failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } + + struct ibv_port_attr portAttr; + if (ibv_query_port(_ctx, port, &portAttr) != 0) { + std::stringstream err; + err << "ibv_query_port failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } + this->info.lid = portAttr.lid; + this->info.port = port; + this->info.linkLayer = portAttr.link_layer; + this->info.qpn = _qp->qp_num; + this->info.mtu = portAttr.active_mtu; + if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND) { + union ibv_gid gid; + if (ibv_query_gid(_ctx, port, 0, &gid) != 0) { + std::stringstream err; + err << "ibv_query_gid failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } + this->info.spn = gid.global.subnet_prefix; + } + + struct ibv_qp_attr qpAttr; + memset(&qpAttr, 0, sizeof(qpAttr)); + qpAttr.qp_state = IBV_QPS_INIT; + qpAttr.pkey_index = 0; + qpAttr.port_num = port; + qpAttr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ; + if (ibv_modify_qp(_qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { + std::stringstream err; + err << "ibv_modify_qp failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } + this->qp = _qp; +} + +IbCtx::IbCtx(const std::string& ibDevName) +{ + int num; + struct ibv_device** devices = ibv_get_device_list(&num); + for (int i = 0; i < num; ++i) { + if (std::string(devices[i]->name) == ibDevName) { + this->ctx = ibv_open_device(devices[i]); + break; + } + } + ibv_free_device_list(devices); + if (this->ctx == nullptr) { + std::stringstream err; + err << "ibv_open_device failed (errno " << errno << ", device name << " << ibDevName << ")"; + throw std::runtime_error(err.str()); + } + this->pd = ibv_alloc_pd(static_cast(this->ctx)); + if (this->pd == nullptr) { + std::stringstream err; + err << "ibv_alloc_pd failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } +} + +IbCtx::~IbCtx() +{ + if (this->pd != nullptr) { + ibv_dealloc_pd(static_cast(this->pd)); + } + if (this->ctx != nullptr) { + ibv_close_device(static_cast(this->ctx)); + } +} + +bool IbCtx::isPortUsable(int port) const +{ + struct ibv_port_attr portAttr; + if (ibv_query_port(static_cast(this->ctx), port, &portAttr) != 0) { + std::stringstream err; + err << "ibv_query_port failed (errno " << errno << ", port << " << port << ")"; + throw std::runtime_error(err.str()); + } + return portAttr.state == IBV_PORT_ACTIVE && (portAttr.link_layer == IBV_LINK_LAYER_ETHERNET || + portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND); +} + +int IbCtx::getAnyActivePort() const +{ + struct ibv_device_attr devAttr; + if (ibv_query_device(static_cast(this->ctx), &devAttr) != 0) { + std::stringstream err; + err << "ibv_query_device failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } + for (uint8_t port = 1; port <= devAttr.phys_port_cnt; ++port) { + if (this->isPortUsable(port)) { + return port; + } + } + return -1; +} + +IbQp* IbCtx::createQp(int port /*=-1*/) +{ + if (port == -1) { + port = this->getAnyActivePort(); + if (port == -1) { + throw std::runtime_error("No active port found"); + } + } else if (!this->isPortUsable(port)) { + throw std::runtime_error("invalid IB port: " + std::to_string(port)); + } + qps.emplace_back(new IbQp(this->ctx, this->pd, port)); + return qps.back().get(); +} + +} // namespace mscclpp diff --git a/src/include/channel.hpp b/src/include/channel.hpp index cb1931b0..10a5f601 100644 --- a/src/include/channel.hpp +++ b/src/include/channel.hpp @@ -2,6 +2,7 @@ #define MSCCLPP_CHANNEL_HPP_ #include "mscclpp.hpp" +#include "epoch.hpp" #include "proxy.hpp" namespace mscclpp { @@ -88,7 +89,7 @@ public: ~HostConnection(); - void write() + void write(); int getId(); @@ -293,3 +294,6 @@ struct SimpleDeviceConnection { BufferHandle src; }; +} // namespace mscclpp + +#endif // MSCCLPP_CHANNEL_HPP_ diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index 8294eeb6..f2816c1a 100644 --- a/src/include/communicator.hpp +++ b/src/include/communicator.hpp @@ -3,6 +3,8 @@ #include "mscclpp.hpp" #include "mscclpp.h" +#include "channel.hpp" +#include "proxy.hpp" namespace mscclpp { @@ -20,4 +22,4 @@ struct Communicator::Impl { } // namespace mscclpp -#endif \ No newline at end of file +#endif // MSCCL_COMMUNICATOR_HPP_ diff --git a/src/include/ib.hpp b/src/include/ib.hpp new file mode 100644 index 00000000..4c58cfdc --- /dev/null +++ b/src/include/ib.hpp @@ -0,0 +1,61 @@ +#ifndef MSCCLPP_IB_HPP_ +#define MSCCLPP_IB_HPP_ + +#include +#include +#include + +namespace mscclpp { + +// QP info to be shared with the remote peer +struct IbQpInfo +{ + uint16_t lid; + uint8_t port; + uint8_t linkLayer; + uint32_t qpn; + uint64_t spn; + uint32_t mtu; +}; + +class IbQp +{ +public: + ~IbQp(); + + IbQpInfo info; + +private: + IbQp(void* ctx, void* pd, int port); + + void* qp; + void* cq; + void* wcs; + void* wrs; + void* sges; + int wrn; + + friend class IbCtx; +}; + + +class IbCtx +{ +public: + IbCtx(const std::string& ibDevName); + ~IbCtx(); + + IbQp* createQp(int port = -1); + +private: + bool IbCtx::isPortUsable(int port) const; + int IbCtx::getAnyActivePort() const; + + void* ctx; + void* pd; + std::list> qps; +}; + +} // namespace mscclpp + +#endif // MSCCLPP_IB_HPP_