mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-21 13:29:45 +00:00
WIP Connection in C++
This commit is contained in:
@@ -17,9 +17,16 @@ Communicator::Impl::~Impl() {
|
||||
|
||||
MSCCLPP_API_CPP Communicator::~Communicator() = default;
|
||||
|
||||
static mscclppTransport_t transportFlagsToCStyle(TransportFlags flags) {
|
||||
static mscclppTransport_t transportToCStyle(TransportFlags flags) {
|
||||
switch (flags) {
|
||||
case TransportIB:
|
||||
case TransportIB0:
|
||||
case TransportIB1:
|
||||
case TransportIB2:
|
||||
case TransportIB3:
|
||||
case TransportIB4:
|
||||
case TransportIB5:
|
||||
case TransportIB6:
|
||||
case TransportIB7:
|
||||
return mscclppTransportIB;
|
||||
case TransportCudaIpc:
|
||||
return mscclppTransportP2P;
|
||||
@@ -46,10 +53,23 @@ MSCCLPP_API_CPP void Communicator::bootstrapBarrier() {
|
||||
mscclppBootstrapBarrier(pimpl->comm);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<HostConnection> Communicator::connect(int remoteRank, int tag, TransportFlags transportFlags, const char* ibDev) {
|
||||
mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportFlagsToCStyle(transportFlags), ibDev);
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connect(int remoteRank, int tag, TransportFlags transport) {
|
||||
std::string ibDev;
|
||||
switch (transport) {
|
||||
case TransportIB0:
|
||||
case TransportIB1:
|
||||
case TransportIB2:
|
||||
case TransportIB3:
|
||||
case TransportIB4:
|
||||
case TransportIB5:
|
||||
case TransportIB6:
|
||||
case TransportIB7:
|
||||
ibDev = getIBDeviceName(transport);
|
||||
break;
|
||||
}
|
||||
mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportToCStyle(transport), ibDev.c_str());
|
||||
auto connIdx = pimpl->connections.size();
|
||||
auto conn = std::make_shared<HostConnection>(std::make_unique<HostConnection::Impl>(this, &pimpl->comm->conns[connIdx]));
|
||||
auto conn = std::make_shared<Connection>(std::make_unique<Connection::Impl>(this, &pimpl->comm->conns[connIdx]));
|
||||
pimpl->connections.push_back(conn);
|
||||
return conn;
|
||||
}
|
||||
@@ -58,14 +78,6 @@ MSCCLPP_API_CPP void Communicator::connectionSetup() {
|
||||
mscclppConnectionSetup(pimpl->comm);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::startProxying() {
|
||||
pimpl->proxy.start();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::stopProxying() {
|
||||
pimpl->proxy.stop();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int Communicator::rank() {
|
||||
int result;
|
||||
mscclppCommRank(pimpl->comm, &result);
|
||||
|
||||
54
src/connection.cc
Normal file
54
src/connection.cc
Normal file
@@ -0,0 +1,54 @@
|
||||
#include "connection.hpp"
|
||||
#include "checks.hpp"
|
||||
#include "registered_memory.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
/// Checks that `mem` is usable over `transport`.
///
/// @param mem        Registered memory whose transport mask is inspected.
/// @param transport  Transport flag(s) that must overlap with `mem.transports()`.
/// @throws std::runtime_error when the two masks share no bit.
void validateTransport(RegisteredMemory mem, TransportFlags transport) {
  // The mask must be parenthesized: `==` binds tighter than `&`, so without the
  // parentheses `transport` was compared with TransportNone first and the
  // resulting 0/1 was then AND-ed with the memory's transport mask.
  if ((mem.transports() & transport) == TransportNone) {
    throw std::runtime_error("mem does not support transport");
  }
}
|
||||
|
||||
/// Local-side transport of this connection; a CUDA IPC connection always reports TransportCudaIpc.
TransportFlags CudaIpcConnection::transport() { return TransportCudaIpc; }
|
||||
|
||||
/// Remote-side transport of this connection; a CUDA IPC connection always reports TransportCudaIpc.
TransportFlags CudaIpcConnection::remoteTransport() { return TransportCudaIpc; }
|
||||
|
||||
/// Creates the CUDA stream that write() enqueues copies on and flush() synchronizes.
CudaIpcConnection::CudaIpcConnection() {
  // Route the status through CUDATHROW like the other CUDA calls in this file,
  // so a failed stream creation is not silently ignored and `stream` is never
  // used while invalid.
  CUDATHROW(cudaStreamCreate(&stream));
}
|
||||
|
||||
/// Destroys the CUDA stream created by the constructor.
CudaIpcConnection::~CudaIpcConnection() {
  // The status returned by cudaStreamDestroy is not inspected here.
  cudaStreamDestroy(stream);
}
|
||||
|
||||
/// Enqueues an asynchronous device-to-device copy of `size` bytes on this connection's stream.
///
/// @param dst        Destination memory; must support remoteTransport().
/// @param dstOffset  Byte offset into the destination.
/// @param src        Source memory; must support transport().
/// @param srcOffset  Byte offset into the source.
/// @param size       Number of bytes to copy.
/// @throws std::runtime_error (via validateTransport) when either memory lacks the required transport.
///
/// The copy is only enqueued; call flush() to wait for the stream.
void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
                              uint64_t size) {
  validateTransport(dst, remoteTransport());
  validateTransport(src, transport());

  // Offsets are applied in bytes. Arithmetic on void* is not valid standard C++
  // (it is only a compiler extension), so step through char* instead.
  auto* dstPtr = static_cast<char*>(dst.impl->getTransportData<void*>(remoteTransport()));
  auto* srcPtr = static_cast<char*>(src.impl->getTransportData<void*>(transport()));
  CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, stream));

  // The byte count parameter is `size`; the previous `dataSize` was not declared anywhere in this function.
  // NOTE(review): `conn` is not declared in this function or in the visible class definition — confirm
  // which connection handle NPKit should receive here.
  npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, static_cast<uint32_t>(size));
}
|
||||
|
||||
/// Blocks until all work enqueued on this connection's stream has finished,
/// then records the NPKit exit events.
void CudaIpcConnection::flush() {
  // Wait for every copy previously enqueued by write().
  CUDATHROW(cudaStreamSynchronize(stream));

  // NOTE(review): `conn` is not declared in the visible code — confirm where it is meant to come from.
  npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT);
}
|
||||
|
||||
/// Records the local IB transport; the remote transport starts out as TransportNone.
/// `qp` is not initialized by this constructor.
IBConnection::IBConnection(TransportFlags transport)
    : transport_(transport),
      remoteTransport_(TransportNone) {}
|
||||
|
||||
/// Returns the local transport flag given to the constructor.
TransportFlags IBConnection::transport() { return transport_; }
|
||||
|
||||
/// Returns the stored remote transport flag (TransportNone until something assigns it).
TransportFlags IBConnection::remoteTransport() { return remoteTransport_; }
|
||||
|
||||
} // namespace mscclpp
|
||||
@@ -9,15 +9,15 @@
|
||||
namespace mscclpp {
|
||||
|
||||
struct Communicator::Impl {
|
||||
mscclppComm_t comm;
|
||||
std::vector<std::shared_ptr<HostConnection>> connections;
|
||||
Proxy proxy;
|
||||
mscclppComm_t comm;
|
||||
std::vector<std::shared_ptr<Connection>> connections;
|
||||
Proxy proxy;
|
||||
|
||||
Impl();
|
||||
Impl();
|
||||
|
||||
~Impl();
|
||||
~Impl();
|
||||
|
||||
friend class HostConnection;
|
||||
friend class Connection;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
48
src/include/connection.hpp
Normal file
48
src/include/connection.hpp
Normal file
@@ -0,0 +1,48 @@
|
||||
#ifndef MSCCLPP_CONNECTION_HPP_
|
||||
#define MSCCLPP_CONNECTION_HPP_
|
||||
|
||||
#include "mscclpp.hpp"
|
||||
#include <cuda_runtime.h>
|
||||
#include "ib.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
class CudaIpcConnection : public Connection {
|
||||
cudaStream_t stream;
|
||||
public:
|
||||
|
||||
CudaIpcConnection();
|
||||
|
||||
virtual ~CudaIpcConnection();
|
||||
|
||||
virtual TransportFlags transport();
|
||||
|
||||
virtual TransportFlags remoteTransport();
|
||||
|
||||
virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size);
|
||||
|
||||
virtual void flush();
|
||||
};
|
||||
|
||||
class IBConnection : public Connection {
|
||||
TransportFlags transport_;
|
||||
TransportFlags remoteTransport_;
|
||||
mscclppIbQp qp;
|
||||
public:
|
||||
|
||||
IBConnection(TransportFlags transport);
|
||||
|
||||
virtual ~IBConnection();
|
||||
|
||||
virtual TransportFlags transport();
|
||||
|
||||
virtual TransportFlags remoteTransport();
|
||||
|
||||
virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size);
|
||||
|
||||
virtual void flush();
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_CONNECTION_HPP_
|
||||
@@ -26,8 +26,9 @@ struct UniqueId {
|
||||
std::unique_ptr<UniqueId> getUniqueId();
|
||||
|
||||
using TransportFlags = uint32_t;
|
||||
const TransportFlags TransportNone = 0b0;
|
||||
const TransportFlags TransportCudaIpc = 0b1;
|
||||
const TransportFlags TransportIB = 0b10;
|
||||
const TransportFlags TransportIB0 = 0b10;
|
||||
const TransportFlags TransportIB1 = 0b100;
|
||||
const TransportFlags TransportIB2 = 0b1000;
|
||||
const TransportFlags TransportIB3 = 0b10000;
|
||||
@@ -37,7 +38,12 @@ const TransportFlags TransportIB6 = 0b10000000;
|
||||
const TransportFlags TransportIB7 = 0b100000000;
|
||||
const TransportFlags TransportAll = 0b111111111;
|
||||
|
||||
int getIBDeviceCount();
|
||||
std::string getIBDeviceName(TransportFlags ibTransport);
|
||||
TransportFlags getIBTransportByDeviceName(const std::string& ibDeviceName);
|
||||
|
||||
class Communicator;
|
||||
class Connection;
|
||||
|
||||
class RegisteredMemory {
|
||||
struct Impl;
|
||||
@@ -55,31 +61,20 @@ public:
|
||||
static RegisteredMemory deserialize(const std::vector<char>& data);
|
||||
|
||||
int rank();
|
||||
bool isLocal();
|
||||
bool isRemote();
|
||||
|
||||
friend class Connection;
|
||||
};
|
||||
|
||||
class Connection {
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
public:
|
||||
virtual ~Connection() = 0;
|
||||
|
||||
/* Connection can not be constructed from user code and must instead be created through Communicator::connect */
|
||||
Connection(std::unique_ptr<Impl>);
|
||||
~Connection();
|
||||
virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) = 0;
|
||||
|
||||
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size);
|
||||
virtual void flush() = 0;
|
||||
|
||||
void flush();
|
||||
virtual TransportFlags transport() = 0;
|
||||
|
||||
TransportFlags transport();
|
||||
TransportFlags remoteTransport(); // Good to have because different IB transports can still connect to each other
|
||||
|
||||
// template<typename T> void write(RegisteredPtr<T> dst, RegisteredPtr<T> src, uint64_t size) {
|
||||
// write(dst.memory(), dst.offset() * sizeof(T), src.memory(), src.offset() * sizeof(T), size);
|
||||
// }
|
||||
|
||||
friend class Communicator;
|
||||
virtual TransportFlags remoteTransport() = 0;
|
||||
};
|
||||
|
||||
class Communicator {
|
||||
@@ -145,6 +140,11 @@ public:
|
||||
*/
|
||||
std::shared_ptr<Connection> connect(int remoteRank, int tag, TransportFlags transport);
|
||||
|
||||
/* Establish all connections declared by connect(). This function must be called after all connect()
|
||||
* calls are made. This function ensures that all remote ranks are ready to communicate when it returns.
|
||||
*/
|
||||
void connectionSetup();
|
||||
|
||||
/* Return the rank of the calling process.
|
||||
*
|
||||
* Outputs:
|
||||
|
||||
46
src/include/registered_memory.hpp
Normal file
46
src/include/registered_memory.hpp
Normal file
@@ -0,0 +1,46 @@
|
||||
#ifndef MSCCLPP_REGISTERED_MEMORY_HPP_
|
||||
#define MSCCLPP_REGISTERED_MEMORY_HPP_
|
||||
|
||||
#include "mscclpp.hpp"
#include "ib.h"
#include <cuda_runtime.h>
#include <stdexcept>
#include <type_traits>
#include <variant>
#include <vector>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct IBTransportData {
|
||||
mscclppIbMr localIbMr;
|
||||
mscclppIbMrInfo remoteIbMrInfo;
|
||||
};
|
||||
|
||||
struct TransportData {
|
||||
TransportFlags transport;
|
||||
union {
|
||||
void* cudaIpcPtr;
|
||||
IBTransportData ibData;
|
||||
}
|
||||
};
|
||||
|
||||
/// Private state behind RegisteredMemory: the registered buffer, its size, the
/// mask of transports it was registered for, and one TransportData entry per transport.
struct RegisteredMemory::Impl {
  void* data;
  size_t size;
  TransportFlags transports;
  std::vector<TransportData> transportData;

  Impl(void* data, size_t size, TransportFlags transports);

  ~Impl();

  /// Returns the transport-specific payload stored for `transport`.
  ///
  /// @tparam T  `void*` to get the CUDA IPC pointer, or `IBTransportData` to get the IB data.
  /// @param transport  Transport flag to look up; compared for exact equality with each entry.
  /// @throws std::runtime_error when no entry is stored for `transport`.
  ///
  /// NOTE(review): nothing here checks that `T` matches the kind of `transport`
  /// (e.g. `void*` with an IB flag reads the inactive union member) — callers must pair them correctly.
  template <typename T>
  T& getTransportData(TransportFlags transport) {
    static_assert(std::is_same_v<T, void*> || std::is_same_v<T, IBTransportData>,
                  "T must be void* (CUDA IPC) or IBTransportData (IB)");
    // `entry` rather than `data`, to avoid shadowing the member of the same name.
    for (auto& entry : transportData) {
      if (entry.transport != transport) {
        continue;
      }
      // Hand back the union member matching T; returning the whole entry as T&
      // (as before) cannot compile for either supported T.
      if constexpr (std::is_same_v<T, void*>) {
        return entry.cudaIpcPtr;
      } else {
        return entry.ibData;
      }
    }
    throw std::runtime_error("Transport data not found");
  }
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_REGISTERED_MEMORY_HPP_
|
||||
7
src/registered_memory.cc
Normal file
7
src/registered_memory.cc
Normal file
@@ -0,0 +1,7 @@
|
||||
#include "registered_memory.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
|
||||
|
||||
} // namespace mscclpp
|
||||
Reference in New Issue
Block a user