diff --git a/Makefile b/Makefile index e544aeee..9aaf34b8 100644 --- a/Makefile +++ b/Makefile @@ -120,7 +120,8 @@ LDFLAGS := $(NVLDFLAGS) $(GDRCOPY_LDFLAGS) -libverbs -lnuma LIBSRCS := $(addprefix src/,debug.cc utils.cc init.cc proxy.cc ib.cc config.cc) LIBSRCS += $(addprefix src/bootstrap/,bootstrap.cc socket.cc) -LIBSRCS += $(addprefix src/,communicator.cc fifo.cc host_connection.cc proxy_cpp.cc basic_proxy_handler.cc) +LIBSRCS += $(addprefix src/,communicator.cc connection.cc registered_memory.cc) +#LIBSRCS += $(addprefix src/,fifo.cc host_connection.cc proxy_cpp.cc basic_proxy_handler.cc) ifneq ($(NPKIT), 0) LIBSRCS += $(addprefix src/misc/,npkit.cc) endif @@ -148,7 +149,7 @@ UTOBJTARGETS := $(UTOBJS:%=$(BUILDDIR)/$(OBJDIR)/%) UTBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(UTOBJS)) TESTSDIR := tests -TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu allgather_test_cpp.cu) +TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu) # allgather_test_cpp.cu TESTSOBJS := $(patsubst %.cc,%.o,$(TESTSSRCS)) $(patsubst %.cu,%.o,$(TESTSSRCS)) TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%) TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS)) diff --git a/src/communicator.cc b/src/communicator.cc index a74923bb..316801de 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -4,17 +4,39 @@ #include "comm.h" #include "basic_proxy_handler.hpp" #include "api.h" +#include "utils.h" +#include "checks.hpp" +#include "debug.h" +#include "connection.hpp" namespace mscclpp { -Communicator::Impl::Impl() : comm(nullptr), proxy(makeBasicProxyHandler(*this)) {} +Communicator::Impl::Impl() : comm(nullptr) {} Communicator::Impl::~Impl() { + for (auto& entry : ibContexts) { + mscclppIbContextDestroy(entry.second); + } + ibContexts.clear(); if (comm) { mscclppCommDestroy(comm); } } +mscclppIbContext* Communicator::Impl::getIbContext(TransportFlags ibTransport) { + // Find IB context or create 
it + auto it = ibContexts.find(ibTransport); + if (it == ibContexts.end()) { + auto ibDev = getIBDeviceName(ibTransport); + mscclppIbContext* ibCtx; + MSCCLPPTHROW(mscclppIbContextCreate(&ibCtx, ibDev.c_str())); + ibContexts[ibTransport] = ibCtx; + return ibCtx; + } else { + return it->second; + } +} + MSCCLPP_API_CPP Communicator::~Communicator() = default; static mscclppTransport_t transportToCStyle(TransportFlags flags) { @@ -54,24 +76,17 @@ MSCCLPP_API_CPP void Communicator::bootstrapBarrier() { } MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, TransportFlags transport) { - std::string ibDev; - switch (transport) { - case TransportIB0: - case TransportIB1: - case TransportIB2: - case TransportIB3: - case TransportIB4: - case TransportIB5: - case TransportIB6: - case TransportIB7: - ibDev = getIBDeviceName(transport); - break; + std::shared_ptr conn; + if (transport & TransportCudaIpc) { /* '&' tests whether the flag bit is set */ + auto cudaIpcConn = std::make_shared(); + conn = cudaIpcConn; + } else if (transport & TransportAllIB) { + auto ibConn = std::make_shared(transport, *pimpl); + conn = ibConn; + } else { + throw std::runtime_error("Unsupported transport"); } - mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportToCStyle(transport), ibDev.c_str()); - auto connIdx = pimpl->connections.size(); - auto conn = std::make_shared(std::make_unique(this, &pimpl->comm->conns[connIdx])); - pimpl->connections.push_back(conn); - return conn; + return conn; } MSCCLPP_API_CPP void Communicator::connectionSetup() { diff --git a/src/connection.cc b/src/connection.cc index 48b2d197..3e053cb3 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -1,26 +1,18 @@ #include "connection.hpp" #include "checks.hpp" #include "registered_memory.hpp" -#include "npkit.h" +#include "npkit/npkit.h" namespace mscclpp { void validateTransport(RegisteredMemory mem, TransportFlags transport) { - if (mem.transports() & transport == TransportNone) { + if ((mem.transports() & transport) == 
TransportNone) { throw std::runtime_error("mem does not support transport"); } } // CudaIpcConnection -TransportFlags CudaIpcConnection::transport() { - return TransportCudaIpc; -} - -TransportFlags CudaIpcConnection::remoteTransport() { - return TransportCudaIpc; -} - CudaIpcConnection::CudaIpcConnection() { cudaStreamCreate(&stream); } @@ -29,12 +21,20 @@ CudaIpcConnection::~CudaIpcConnection() { cudaStreamDestroy(stream); } +TransportFlags CudaIpcConnection::transport() { + return TransportCudaIpc; +} + +TransportFlags CudaIpcConnection::remoteTransport() { + return TransportCudaIpc; +} + void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { validateTransport(dst, remoteTransport()); validateTransport(src, transport()); - auto dstPtr = dst.impl->data; - auto srcPtr = src.impl->data; + auto dstPtr = static_cast<char*>(dst.data()); /* byte pointer: offsets below are in bytes and void* arithmetic is non-standard */ + auto srcPtr = static_cast<char*>(src.data()); CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, stream)); // npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size); @@ -47,7 +47,13 @@ void CudaIpcConnection::flush() { // IBConnection -IBConnection::IBConnection(TransportFlags transport) : transport_(transport), remoteTransport_(TransportNone) {} +IBConnection::IBConnection(TransportFlags transport, Communicator::Impl& commImpl) : transport_(transport), remoteTransport_(TransportNone) { + MSCCLPPTHROW(mscclppIbContextCreateQp(commImpl.getIbContext(transport), &qp)); +} + +IBConnection::~IBConnection() { + // TODO: Destroy QP? 
+} TransportFlags IBConnection::transport() { return transport_; @@ -57,20 +63,21 @@ TransportFlags IBConnection::remoteTransport() { return remoteTransport_; } -IBConnection::IBConnection(TransportFlags transport, Communicator::Impl& commImpl) : transport_(transport), remoteTransport_(TransportNone) { - MSCCLPPTHROW(mscclppIbContextCreateQp(commImpl.getIbContext(transport), &qp)); -} - -IBConnection::~IBConnection() { - // TODO: Destroy QP? -} - void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { validateTransport(dst, remoteTransport()); validateTransport(src, transport()); - auto dstMrInfo = dst.impl->getTransportInfo(remoteTransport()); - auto srcMr = src.impl->getTransportInfo(transport()); + auto dstTransportInfo = getRegisteredMemoryImpl(dst)->getTransportInfo(remoteTransport()); + if (dstTransportInfo.ibLocal) { + throw std::runtime_error("dst is local, which is not supported"); + } + auto srcTransportInfo = getRegisteredMemoryImpl(src)->getTransportInfo(transport()); /* src is looked up by the local transport, matching validateTransport(src, transport()) above */ + if (!srcTransportInfo.ibLocal) { + throw std::runtime_error("src is remote, which is not supported"); + } + + auto dstMrInfo = dstTransportInfo.ibMrInfo; + auto srcMr = srcTransportInfo.ibMr; qp->stageSend(srcMr, &dstMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/false); diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index 827b0281..8eb0e202 100644 --- a/src/include/communicator.hpp +++ b/src/include/communicator.hpp @@ -5,19 +5,21 @@ #include "mscclpp.h" #include "channel.hpp" #include "proxy.hpp" +#include "ib.h" +#include namespace mscclpp { struct Communicator::Impl { mscclppComm_t comm; std::vector> connections; - Proxy proxy; + std::unordered_map ibContexts; Impl(); ~Impl(); - friend class Connection; + mscclppIbContext* getIbContext(TransportFlags ibTransport); }; } // namespace mscclpp diff --git a/src/include/connection.hpp 
b/src/include/connection.hpp index 72f0eb90..94d727e7 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -4,6 +4,7 @@ #include "mscclpp.hpp" #include #include "ib.h" +#include "communicator.hpp" namespace mscclpp { @@ -15,15 +16,15 @@ public: CudaIpcConnection(); - virtual ~CudaIpcConnection(); + ~CudaIpcConnection(); - virtual TransportFlags transport(); + TransportFlags transport() override; - virtual TransportFlags remoteTransport(); + TransportFlags remoteTransport() override; - virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size); + void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override; - virtual void flush(); + void flush() override; }; class IBConnection : public Connection { @@ -34,15 +35,15 @@ public: IBConnection(TransportFlags transport, Communicator::Impl& commImpl); - virtual ~IBConnection(); + ~IBConnection(); - virtual TransportFlags transport(); + TransportFlags transport() override; - virtual TransportFlags remoteTransport(); + TransportFlags remoteTransport() override; - virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size); + void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override; - virtual void flush(); + void flush() override; }; } // namespace mscclpp diff --git a/src/include/ib.hpp b/src/include/ib.hpp index 4c58cfdc..85c92af7 100644 --- a/src/include/ib.hpp +++ b/src/include/ib.hpp @@ -48,8 +48,8 @@ public: IbQp* createQp(int port = -1); private: - bool IbCtx::isPortUsable(int port) const; - int IbCtx::getAnyActivePort() const; + bool isPortUsable(int port) const; + int getAnyActivePort() const; void* ctx; void* pd; diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index 52b0511b..9c699efb 100644 --- a/src/include/mscclpp.hpp +++ 
b/src/include/mscclpp.hpp @@ -67,8 +67,7 @@ public: }; class Connection { - virtual ~Connection() = 0; - +public: virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) = 0; virtual void flush() = 0; @@ -76,13 +75,13 @@ class Connection { virtual TransportFlags transport() = 0; virtual TransportFlags remoteTransport() = 0; + +protected: + static std::shared_ptr getRegisteredMemoryImpl(RegisteredMemory&); }; class Communicator { - struct Impl; - std::unique_ptr pimpl; public: - /* Initialize the communicator. nranks processes with rank 0 to nranks-1 need to call this function. * * Inputs: @@ -159,6 +158,10 @@ public: * size: the number of ranks of the communicator */ int size(); + + struct Impl; +private: + std::unique_ptr pimpl; }; } // namespace mscclpp diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp index 24eed981..7a0ab1d0 100644 --- a/src/include/registered_memory.hpp +++ b/src/include/registered_memory.hpp @@ -4,14 +4,21 @@ #include "mscclpp.hpp" #include "mscclpp.h" #include "ib.h" -#include +#include "communicator.hpp" #include namespace mscclpp { struct TransportInfo { TransportFlags transport; - std::variant data; + + // TODO: rewrite this using std::variant or something + bool ibLocal = false; /* defaulted so entries that never set it (CUDA IPC) are not copied uninitialized */ + union { + cudaIpcMemHandle_t cudaIpcHandle; + mscclppIbMr* ibMr; + mscclppIbMrInfo ibMrInfo; + }; }; struct RegisteredMemory::Impl { @@ -21,13 +28,13 @@ struct RegisteredMemory::Impl { TransportFlags transports; std::vector transportInfos; - Impl(void* data, size_t size, int rank, TransportFlags transports); + Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl); Impl(const std::vector& data); - template T& getTransportInfo(TransportFlags transport) { + TransportInfo& getTransportInfo(TransportFlags transport) { for (auto& entry : transportInfos) { if (entry.transport == transport) { - return std::get(entry.data); + return entry; } 
} throw std::runtime_error("Transport data not found"); diff --git a/src/registered_memory.cc b/src/registered_memory.cc index eabb9e7d..7a5a0725 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -1,14 +1,16 @@ #include "registered_memory.hpp" +#include "checks.hpp" +#include namespace mscclpp { -RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator& comm) : data(data), size(size), rank(rank), transports(transports) { +RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl) : data(data), size(size), rank(rank), transports(transports) { if (transports & TransportCudaIpc) { TransportInfo transportInfo; transportInfo.transport = TransportCudaIpc; cudaIpcMemHandle_t handle; CUDATHROW(cudaIpcGetMemHandle(&handle, data)); - transportInfo.data = handle; + transportInfo.cudaIpcHandle = handle; this->transportInfos.push_back(transportInfo); } if (transports & TransportAllIB) { @@ -16,8 +18,9 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t TransportInfo transportInfo; transportInfo.transport = ibTransport; mscclppIbMr* mr; - MSCCLPPTHROW(mscclppIbContextRegisterMr(comm.pimpl->getIbContext(ibTransport), data, size, &mr)); - transportInfo.data = mr; + MSCCLPPTHROW(mscclppIbContextRegisterMr(commImpl.getIbContext(ibTransport), data, size, &mr)); + transportInfo.ibMr = mr; + transportInfo.ibLocal = true; this->transportInfos.push_back(transportInfo); }; if (transports & TransportIB0) addIb(TransportIB0); @@ -31,62 +34,55 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t } } -RegisteredMemory::RegisteredMemory(std::shared_ptr pimpl) : impl(pimpl) {} +RegisteredMemory::RegisteredMemory(std::shared_ptr pimpl) : pimpl(pimpl) {} RegisteredMemory::~RegisteredMemory() = default; void* RegisteredMemory::data() { - return impl->data; + return pimpl->data; } size_t 
RegisteredMemory::size() { - return impl->size; + return pimpl->size; } int RegisteredMemory::rank() { - return impl->rank; + return pimpl->rank; } TransportFlags RegisteredMemory::transports() { - return impl->transports; + return pimpl->transports; } std::vector RegisteredMemory::serialize() { std::vector result; - std::copy_n(reinterpret_cast(&impl->size), sizeof(impl->size), std::back_inserter(result)); - std::copy_n(reinterpret_cast(&impl->rank), sizeof(impl->rank), std::back_inserter(result)); - std::copy_n(reinterpret_cast(&impl->transports), sizeof(impl->transports), std::back_inserter(result)); - if (impl->transportInfos.size() > std::numeric_limits::max()) { + std::copy_n(reinterpret_cast(&pimpl->size), sizeof(pimpl->size), std::back_inserter(result)); + std::copy_n(reinterpret_cast(&pimpl->rank), sizeof(pimpl->rank), std::back_inserter(result)); + std::copy_n(reinterpret_cast(&pimpl->transports), sizeof(pimpl->transports), std::back_inserter(result)); + if (pimpl->transportInfos.size() > std::numeric_limits::max()) { throw std::runtime_error("Too many transport info entries"); } - int8_t transportCount = impl->transportInfos.size(); + int8_t transportCount = pimpl->transportInfos.size(); std::copy_n(reinterpret_cast(&transportCount), sizeof(transportCount), std::back_inserter(result)); - for (auto& entry : impl->transportInfos) { + for (auto& entry : pimpl->transportInfos) { std::copy_n(reinterpret_cast(&entry.transport), sizeof(entry.transport), std::back_inserter(result)); - std::visit(overloaded{ - [&](std::monostate&){ - throw std::runtime_error("Transport info not set"); - }, - [&](cudaIpcMemHandle_t handle){ - std::copy_n(reinterpret_cast(&handle), sizeof(handle), std::back_inserter(result)); - }, - [&](mscclppIbMr* mr){ - std::copy_n(reinterpret_cast(&mr->info), sizeof(mr->info), std::back_inserter(result)); - }, - [&](mscclppIbMrInfo info){ - std::copy_n(reinterpret_cast(&info), sizeof(info), std::back_inserter(result)); - } - }, entry.data); + 
if (entry.transport == TransportCudaIpc) { + std::copy_n(reinterpret_cast(&entry.cudaIpcHandle), sizeof(entry.cudaIpcHandle), std::back_inserter(result)); + } else if (entry.transport & TransportAllIB) { + std::copy_n(reinterpret_cast(entry.ibLocal ? &entry.ibMr->info : &entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result)); /* a local entry's union holds the ibMr pointer, so its info is read through ibMr */ + } else { + throw std::runtime_error("Unknown transport"); + } } return result; } -static RegisteredMemory RegisteredMemory::deserialize(const std::vector& data) { +RegisteredMemory RegisteredMemory::deserialize(const std::vector& data) { return RegisteredMemory(std::make_shared(data)); } -RegisteredMemory::Impl::Impl(const std::vector& data) { - auto it = data.begin(); +RegisteredMemory::Impl::Impl(const std::vector& serialization) { + auto it = serialization.begin(); std::copy_n(it, sizeof(this->size), reinterpret_cast(&this->size)); it += sizeof(this->size); std::copy_n(it, sizeof(this->rank), reinterpret_cast(&this->rank)); @@ -104,24 +100,25 @@ RegisteredMemory::Impl::Impl(const std::vector& data) { cudaIpcMemHandle_t handle; std::copy_n(it, sizeof(handle), reinterpret_cast(&handle)); it += sizeof(handle); - transportInfo.data = handle; + transportInfo.cudaIpcHandle = handle; } else if (transportInfo.transport & TransportAllIB) { mscclppIbMrInfo info; std::copy_n(it, sizeof(info), reinterpret_cast(&info)); it += sizeof(info); - transportInfo.data = info; + transportInfo.ibMrInfo = info; + transportInfo.ibLocal = false; } else { throw std::runtime_error("Unknown transport"); } this->transportInfos.push_back(transportInfo); } - if (it != data.end()) { + if (it != serialization.end()) { throw std::runtime_error("Deserialization failed"); } if (transports & TransportCudaIpc) { - auto cudaIpcHandle = getTransportInfo(TransportCudaIpc); - CUDATHROW(cudaIpcOpenMemHandle(&data, cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess)); + auto entry = getTransportInfo(TransportCudaIpc); + CUDATHROW(cudaIpcOpenMemHandle(&data, entry.cudaIpcHandle, 
cudaIpcMemLazyEnablePeerAccess)); } }