From e4ee2eba25de399e4242b5ee9fd9f607b1b40e88 Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Tue, 25 Apr 2023 00:41:45 +0000 Subject: [PATCH] WIP Connection in C++ --- src/communicator.cc | 38 ++++++++++++++-------- src/connection.cc | 54 +++++++++++++++++++++++++++++++ src/include/communicator.hpp | 12 +++---- src/include/connection.hpp | 48 +++++++++++++++++++++++++++ src/include/mscclpp.hpp | 38 +++++++++++----------- src/include/registered_memory.hpp | 46 ++++++++++++++++++++++++++ src/registered_memory.cc | 7 ++++ 7 files changed, 205 insertions(+), 38 deletions(-) create mode 100644 src/connection.cc create mode 100644 src/include/connection.hpp create mode 100644 src/include/registered_memory.hpp create mode 100644 src/registered_memory.cc diff --git a/src/communicator.cc b/src/communicator.cc index d12b20e4..a74923bb 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -17,9 +17,16 @@ Communicator::Impl::~Impl() { MSCCLPP_API_CPP Communicator::~Communicator() = default; -static mscclppTransport_t transportFlagsToCStyle(TransportFlags flags) { +static mscclppTransport_t transportToCStyle(TransportFlags flags) { switch (flags) { - case TransportIB: + case TransportIB0: + case TransportIB1: + case TransportIB2: + case TransportIB3: + case TransportIB4: + case TransportIB5: + case TransportIB6: + case TransportIB7: return mscclppTransportIB; case TransportCudaIpc: return mscclppTransportP2P; @@ -46,10 +53,23 @@ MSCCLPP_API_CPP void Communicator::bootstrapBarrier() { mscclppBootstrapBarrier(pimpl->comm); } -MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, TransportFlags transportFlags, const char* ibDev) { - mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportFlagsToCStyle(transportFlags), ibDev); +MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, TransportFlags transport) { + std::string ibDev; + switch (transport) { + case TransportIB0: + case TransportIB1: + case TransportIB2: + case TransportIB3: + case TransportIB4: + case TransportIB5: + case TransportIB6: + case TransportIB7: + ibDev = getIBDeviceName(transport); + break; + } + mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportToCStyle(transport), ibDev.c_str()); auto connIdx = pimpl->connections.size(); - auto conn = std::make_shared(std::make_unique(this, &pimpl->comm->conns[connIdx])); + auto conn = std::make_shared(std::make_unique(this, &pimpl->comm->conns[connIdx])); pimpl->connections.push_back(conn); return conn; } @@ -58,14 +78,6 @@ MSCCLPP_API_CPP void Communicator::connectionSetup() { mscclppConnectionSetup(pimpl->comm); } -MSCCLPP_API_CPP void Communicator::startProxying() { - pimpl->proxy.start(); -} - -MSCCLPP_API_CPP void Communicator::stopProxying() { - pimpl->proxy.stop(); -} - MSCCLPP_API_CPP int Communicator::rank() { int result; mscclppCommRank(pimpl->comm, &result); diff --git a/src/connection.cc b/src/connection.cc new file mode 100644 index 00000000..12ebee02 --- /dev/null +++ b/src/connection.cc @@ -0,0 +1,54 @@ +#include "connection.hpp" +#include "checks.hpp" +#include "registered_memory.hpp" + +namespace mscclpp { + +void validateTransport(RegisteredMemory mem, TransportFlags transport) { + if (mem.transports() & transport == TransportNone) { + throw std::runtime_error("mem does not support transport"); + } +} + +TransportFlags CudaIpcConnection::transport() { + return TransportCudaIpc; +} + +TransportFlags CudaIpcConnection::remoteTransport() { + return TransportCudaIpc; +} + +CudaIpcConnection::CudaIpcConnection() { + cudaStreamCreate(&stream); +} + +CudaIpcConnection::~CudaIpcConnection() { + cudaStreamDestroy(stream); +} + +void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { + validateTransport(dst, remoteTransport()); + validateTransport(src, transport()); + + auto dstPtr = dst.impl->getTransportData(remoteTransport()); + auto srcPtr = src.impl->getTransportData(transport()); + CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, stream)); + npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)dataSize); +} + +void CudaIpcConnection::flush() { + CUDATHROW(cudaStreamSynchronize(stream)); + npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT); +} + +IBConnection::IBConnection(TransportFlags transport) : transport_(transport), remoteTransport_(TransportNone) {} + +TransportFlags IBConnection::transport() { + return transport_; +} + +TransportFlags IBConnection::remoteTransport() { + return remoteTransport_; +} + +} // namespace mscclpp diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index f2816c1a..827b0281 100644 --- a/src/include/communicator.hpp +++ b/src/include/communicator.hpp @@ -9,15 +9,15 @@ namespace mscclpp { struct Communicator::Impl { - mscclppComm_t comm; - std::vector> connections; - Proxy proxy; + mscclppComm_t comm; + std::vector> connections; + Proxy proxy; - Impl(); + Impl(); - ~Impl(); + ~Impl(); - friend class HostConnection; + friend class Connection; }; } // namespace mscclpp diff --git a/src/include/connection.hpp b/src/include/connection.hpp new file mode 100644 index 00000000..048e2c6a --- /dev/null +++ b/src/include/connection.hpp @@ -0,0 +1,48 @@ +#ifndef MSCCLPP_CONNECTION_HPP_ +#define MSCCLPP_CONNECTION_HPP_ + +#include "mscclpp.hpp" +#include +#include "ib.h" + +namespace mscclpp { + +class CudaIpcConnection : public Connection { + cudaStream_t stream; +public: + + CudaIpcConnection(); + + virtual ~CudaIpcConnection(); + + virtual TransportFlags transport(); + + virtual TransportFlags remoteTransport(); + + virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size); + + virtual void flush(); +}; + +class IBConnection : public Connection { + TransportFlags transport_; + TransportFlags remoteTransport_; + mscclppIbQp qp; +public: + + IBConnection(TransportFlags transport); + + virtual ~IBConnection(); + + virtual TransportFlags transport(); + + virtual TransportFlags remoteTransport(); + + virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size); + + virtual void flush(); +}; + +} // namespace mscclpp + +#endif // MSCCLPP_CONNECTION_HPP_ diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index 67d40050..f4d73ab4 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -26,8 +26,9 @@ struct UniqueId { std::unique_ptr getUniqueId(); using TransportFlags = uint32_t; +const TransportFlags TransportNone = 0b0; const TransportFlags TransportCudaIpc = 0b1; -const TransportFlags TransportIB = 0b10; +const TransportFlags TransportIB0 = 0b10; const TransportFlags TransportIB1 = 0b100; const TransportFlags TransportIB2 = 0b1000; const TransportFlags TransportIB3 = 0b10000; @@ -37,7 +38,12 @@ const TransportFlags TransportIB6 = 0b10000000; const TransportFlags TransportIB7 = 0b100000000; const TransportFlags TransportAll = 0b111111111; +int getIBDeviceCount(); +std::string getIBDeviceName(TransportFlags ibTransport); +TransportFlags getIBTransportByDeviceName(const std::string& ibDeviceName); + class Communicator; +class Connection; class RegisteredMemory { struct Impl; @@ -55,31 +61,20 @@ public: static RegisteredMemory deserialize(const std::vector& data); int rank(); - bool isLocal(); - bool isRemote(); + + friend class Connection; }; class Connection { - struct Impl; - std::unique_ptr pimpl; -public: + virtual ~Connection() = 0; - /* Connection can not be constructed from user code and must instead be created through Communicator::connect */ - Connection(std::unique_ptr); - ~Connection(); + virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) = 0; - void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size); + virtual void flush() = 0; - void flush(); + virtual TransportFlags transport() = 0; - TransportFlags transport(); - TransportFlags remoteTransport(); // Good to have because different IB transports can still connect to each other - - // template void write(RegisteredPtr dst, RegisteredPtr src, uint64_t size) { - // write(dst.memory(), dst.offset() * sizeof(T), src.memory(), src.offset() * sizeof(T), size); - // } - - friend class Communicator; + virtual TransportFlags remoteTransport() = 0; }; class Communicator { @@ -145,6 +140,11 @@ public: */ std::shared_ptr connect(int remoteRank, int tag, TransportFlags transport); + /* Establish all connections declared by connect(). This function must be called after all connect() + * calls are made. This function ensures that all remote ranks are ready to communicate when it returns. + */ + void connectionSetup(); + /* Return the rank of the calling process. * * Outputs: diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp new file mode 100644 index 00000000..82fe942e --- /dev/null +++ b/src/include/registered_memory.hpp @@ -0,0 +1,46 @@ +#ifndef MSCCLPP_REGISTERED_MEMORY_HPP_ +#define MSCCLPP_REGISTERED_MEMORY_HPP_ + +#include "mscclpp.hpp" +#include "ib.h" +#include +#include + +namespace mscclpp { + +struct IBTransportData { + mscclppIbMr localIbMr; + mscclppIbMrInfo remoteIbMrInfo; +}; + +struct TransportData { + TransportFlags transport; + union { + void* cudaIpcPtr; + IBTransportData ibData; + } +}; + +struct RegisteredMemory::Impl { + void* data; + size_t size; + TransportFlags transports; + std::vector transportData; + + Impl(void* data, size_t size, TransportFlags transports); + + ~Impl(); + + template T& getTransportData(TransportFlags transport) { + for (auto& data : transportData) { + if (data.transport == transport) { + return data; + } + } + throw std::runtime_error("Transport data not found"); + } +}; + +} // namespace mscclpp + +#endif // MSCCLPP_REGISTERED_MEMORY_HPP_ diff --git a/src/registered_memory.cc b/src/registered_memory.cc new file mode 100644 index 00000000..d491e72f --- /dev/null +++ b/src/registered_memory.cc @@ -0,0 +1,7 @@ +#include "registered_memory.hpp" + +namespace mscclpp { + + + +} // namespace mscclpp