WIP Connection in C++

This commit is contained in:
Olli Saarikivi
2023-04-25 00:41:45 +00:00
parent 35ade686ff
commit e4ee2eba25
7 changed files with 205 additions and 38 deletions

View File

@@ -17,9 +17,16 @@ Communicator::Impl::~Impl() {
MSCCLPP_API_CPP Communicator::~Communicator() = default;
static mscclppTransport_t transportFlagsToCStyle(TransportFlags flags) {
static mscclppTransport_t transportToCStyle(TransportFlags flags) {
switch (flags) {
case TransportIB:
case TransportIB0:
case TransportIB1:
case TransportIB2:
case TransportIB3:
case TransportIB4:
case TransportIB5:
case TransportIB6:
case TransportIB7:
return mscclppTransportIB;
case TransportCudaIpc:
return mscclppTransportP2P;
@@ -46,10 +53,23 @@ MSCCLPP_API_CPP void Communicator::bootstrapBarrier() {
mscclppBootstrapBarrier(pimpl->comm);
}
MSCCLPP_API_CPP std::shared_ptr<HostConnection> Communicator::connect(int remoteRank, int tag, TransportFlags transportFlags, const char* ibDev) {
mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportFlagsToCStyle(transportFlags), ibDev);
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connect(int remoteRank, int tag, TransportFlags transport) {
std::string ibDev;
switch (transport) {
case TransportIB0:
case TransportIB1:
case TransportIB2:
case TransportIB3:
case TransportIB4:
case TransportIB5:
case TransportIB6:
case TransportIB7:
ibDev = getIBDeviceName(transport);
break;
}
mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportToCStyle(transport), ibDev.c_str());
auto connIdx = pimpl->connections.size();
auto conn = std::make_shared<HostConnection>(std::make_unique<HostConnection::Impl>(this, &pimpl->comm->conns[connIdx]));
auto conn = std::make_shared<Connection>(std::make_unique<Connection::Impl>(this, &pimpl->comm->conns[connIdx]));
pimpl->connections.push_back(conn);
return conn;
}
@@ -58,14 +78,6 @@ MSCCLPP_API_CPP void Communicator::connectionSetup() {
mscclppConnectionSetup(pimpl->comm);
}
MSCCLPP_API_CPP void Communicator::startProxying() {
pimpl->proxy.start();
}
MSCCLPP_API_CPP void Communicator::stopProxying() {
pimpl->proxy.stop();
}
MSCCLPP_API_CPP int Communicator::rank() {
int result;
mscclppCommRank(pimpl->comm, &result);

54
src/connection.cc Normal file
View File

@@ -0,0 +1,54 @@
#include "connection.hpp"
#include "checks.hpp"
#include "registered_memory.hpp"
namespace mscclpp {
void validateTransport(RegisteredMemory mem, TransportFlags transport) {
if (mem.transports() & transport == TransportNone) {
throw std::runtime_error("mem does not support transport");
}
}
TransportFlags CudaIpcConnection::transport() {
return TransportCudaIpc;
}
TransportFlags CudaIpcConnection::remoteTransport() {
return TransportCudaIpc;
}
CudaIpcConnection::CudaIpcConnection() {
cudaStreamCreate(&stream);
}
CudaIpcConnection::~CudaIpcConnection() {
cudaStreamDestroy(stream);
}
void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) {
validateTransport(dst, remoteTransport());
validateTransport(src, transport());
auto dstPtr = dst.impl->getTransportData<void*>(remoteTransport());
auto srcPtr = src.impl->getTransportData<void*>(transport());
CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, stream));
npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)dataSize);
}
void CudaIpcConnection::flush() {
CUDATHROW(cudaStreamSynchronize(stream));
npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT);
}
IBConnection::IBConnection(TransportFlags transport) : transport_(transport), remoteTransport_(TransportNone) {}
TransportFlags IBConnection::transport() {
return transport_;
}
TransportFlags IBConnection::remoteTransport() {
return remoteTransport_;
}
} // namespace mscclpp

View File

@@ -9,15 +9,15 @@
namespace mscclpp {
struct Communicator::Impl {
mscclppComm_t comm;
std::vector<std::shared_ptr<HostConnection>> connections;
Proxy proxy;
mscclppComm_t comm;
std::vector<std::shared_ptr<Connection>> connections;
Proxy proxy;
Impl();
Impl();
~Impl();
~Impl();
friend class HostConnection;
friend class Connection;
};
} // namespace mscclpp

View File

@@ -0,0 +1,48 @@
#ifndef MSCCLPP_CONNECTION_HPP_
#define MSCCLPP_CONNECTION_HPP_
#include "mscclpp.hpp"
#include <cuda_runtime.h>
#include "ib.h"
namespace mscclpp {
class CudaIpcConnection : public Connection {
cudaStream_t stream;
public:
CudaIpcConnection();
virtual ~CudaIpcConnection();
virtual TransportFlags transport();
virtual TransportFlags remoteTransport();
virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size);
virtual void flush();
};
class IBConnection : public Connection {
TransportFlags transport_;
TransportFlags remoteTransport_;
mscclppIbQp qp;
public:
IBConnection(TransportFlags transport);
virtual ~IBConnection();
virtual TransportFlags transport();
virtual TransportFlags remoteTransport();
virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size);
virtual void flush();
};
} // namespace mscclpp
#endif // MSCCLPP_CONNECTION_HPP_

View File

@@ -26,8 +26,9 @@ struct UniqueId {
std::unique_ptr<UniqueId> getUniqueId();
using TransportFlags = uint32_t;
const TransportFlags TransportNone = 0b0;
const TransportFlags TransportCudaIpc = 0b1;
const TransportFlags TransportIB = 0b10;
const TransportFlags TransportIB0 = 0b10;
const TransportFlags TransportIB1 = 0b100;
const TransportFlags TransportIB2 = 0b1000;
const TransportFlags TransportIB3 = 0b10000;
@@ -37,7 +38,12 @@ const TransportFlags TransportIB6 = 0b10000000;
const TransportFlags TransportIB7 = 0b100000000;
const TransportFlags TransportAll = 0b111111111;
int getIBDeviceCount();
std::string getIBDeviceName(TransportFlags ibTransport);
TransportFlags getIBTransportByDeviceName(const std::string& ibDeviceName);
class Communicator;
class Connection;
class RegisteredMemory {
struct Impl;
@@ -55,31 +61,20 @@ public:
static RegisteredMemory deserialize(const std::vector<char>& data);
int rank();
bool isLocal();
bool isRemote();
friend class Connection;
};
class Connection {
struct Impl;
std::unique_ptr<Impl> pimpl;
public:
virtual ~Connection() = 0;
/* Connection can not be constructed from user code and must instead be created through Communicator::connect */
Connection(std::unique_ptr<Impl>);
~Connection();
virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) = 0;
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size);
virtual void flush() = 0;
void flush();
virtual TransportFlags transport() = 0;
TransportFlags transport();
TransportFlags remoteTransport(); // Good to have because different IB transports can still connect to each other
// template<typename T> void write(RegisteredPtr<T> dst, RegisteredPtr<T> src, uint64_t size) {
// write(dst.memory(), dst.offset() * sizeof(T), src.memory(), src.offset() * sizeof(T), size);
// }
friend class Communicator;
virtual TransportFlags remoteTransport() = 0;
};
class Communicator {
@@ -145,6 +140,11 @@ public:
*/
std::shared_ptr<Connection> connect(int remoteRank, int tag, TransportFlags transport);
/* Establish all connections declared by connect(). This function must be called after all connect()
* calls are made. This function ensures that all remote ranks are ready to communicate when it returns.
*/
void connectionSetup();
/* Return the rank of the calling process.
*
* Outputs:

View File

@@ -0,0 +1,46 @@
#ifndef MSCCLPP_REGISTERED_MEMORY_HPP_
#define MSCCLPP_REGISTERED_MEMORY_HPP_
#include "mscclpp.hpp"
#include "ib.h"
#include <variant>
#include <cuda_runtime.h>
namespace mscclpp {
struct IBTransportData {
mscclppIbMr localIbMr;
mscclppIbMrInfo remoteIbMrInfo;
};
struct TransportData {
TransportFlags transport;
union {
void* cudaIpcPtr;
IBTransportData ibData;
}
};
struct RegisteredMemory::Impl {
void* data;
size_t size;
TransportFlags transports;
std::vector<TransportData> transportData;
Impl(void* data, size_t size, TransportFlags transports);
~Impl();
template<typename T> T& getTransportData(TransportFlags transport) {
for (auto& data : transportData) {
if (data.transport == transport) {
return data;
}
}
throw std::runtime_error("Transport data not found");
}
};
} // namespace mscclpp
#endif // MSCCLPP_REGISTERED_MEMORY_HPP_

7
src/registered_memory.cc Normal file
View File

@@ -0,0 +1,7 @@
#include "registered_memory.hpp"
namespace mscclpp {
} // namespace mscclpp