WIP builds, but doesn't link

This commit is contained in:
Olli Saarikivi
2023-04-26 17:46:47 +00:00
parent 90a8860bcc
commit d746201287
9 changed files with 136 additions and 104 deletions

View File

@@ -120,7 +120,8 @@ LDFLAGS := $(NVLDFLAGS) $(GDRCOPY_LDFLAGS) -libverbs -lnuma
LIBSRCS := $(addprefix src/,debug.cc utils.cc init.cc proxy.cc ib.cc config.cc)
LIBSRCS += $(addprefix src/bootstrap/,bootstrap.cc socket.cc)
LIBSRCS += $(addprefix src/,communicator.cc fifo.cc host_connection.cc proxy_cpp.cc basic_proxy_handler.cc)
LIBSRCS += $(addprefix src/,communicator.cc connection.cc registered_memory.cc)
#LIBSRCS += $(addprefix src/,fifo.cc host_connection.cc proxy_cpp.cc basic_proxy_handler.cc)
ifneq ($(NPKIT), 0)
LIBSRCS += $(addprefix src/misc/,npkit.cc)
endif
@@ -148,7 +149,7 @@ UTOBJTARGETS := $(UTOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
UTBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(UTOBJS))
TESTSDIR := tests
TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu allgather_test_cpp.cu)
TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu) # allgather_test_cpp.cu
TESTSOBJS := $(patsubst %.cc,%.o,$(TESTSSRCS)) $(patsubst %.cu,%.o,$(TESTSSRCS))
TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS))

View File

@@ -4,17 +4,39 @@
#include "comm.h"
#include "basic_proxy_handler.hpp"
#include "api.h"
#include "utils.h"
#include "checks.hpp"
#include "debug.h"
#include "connection.hpp"
namespace mscclpp {
// Builds an implementation object that does not yet own a communicator:
// `comm` starts out null and the IB context map starts out empty. Only this
// definition is kept; a second definition that also initialized a `proxy`
// member would be a redefinition of the same constructor.
Communicator::Impl::Impl() : comm(nullptr) {}
// Tears down everything this object created, in the same order as it was
// written originally: first every cached IB context, then the underlying
// communicator handle (only if one was ever set).
Communicator::Impl::~Impl() {
  for (auto it = ibContexts.begin(); it != ibContexts.end(); ++it) {
    mscclppIbContextDestroy(it->second);
  }
  // Drop the now-dangling pointers so nothing can look them up afterwards.
  ibContexts.clear();

  if (comm) {
    mscclppCommDestroy(comm);
  }
}
// Returns the IB context associated with `ibTransport`.
//
// Contexts are cached in `ibContexts`: the first request for a given transport
// flag resolves the device name with getIBDeviceName(), creates the context and
// stores it; later requests return the stored pointer. The creation call is
// wrapped in MSCCLPPTHROW, which is expected to raise on a failing status.
mscclppIbContext* Communicator::Impl::getIbContext(TransportFlags ibTransport) {
  auto cached = ibContexts.find(ibTransport);
  if (cached != ibContexts.end()) {
    return cached->second;
  }

  // Cache miss: create the context for this device and remember it.
  auto deviceName = getIBDeviceName(ibTransport);
  mscclppIbContext* created;
  MSCCLPPTHROW(mscclppIbContextCreate(&created, deviceName.c_str()));
  ibContexts[ibTransport] = created;
  return created;
}
MSCCLPP_API_CPP Communicator::~Communicator() = default;
static mscclppTransport_t transportToCStyle(TransportFlags flags) {
@@ -54,24 +76,16 @@ MSCCLPP_API_CPP void Communicator::bootstrapBarrier() {
}
// Creates a connection object for the requested transport, records it in the
// communicator's connection list and returns it.
//
// Parameters:
//   remoteRank, tag - accepted for interface compatibility; this body does not
//                     consume them.
//   transport       - transport flag(s) selecting the connection type. When the
//                     CUDA IPC bit is set a CudaIpcConnection is created;
//                     otherwise, when any IB bit is set, an IBConnection is
//                     created on the matching IB context.
// Throws:
//   std::runtime_error("Unsupported transport") when neither kind of bit is set.
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connect(int remoteRank, int tag, TransportFlags transport) {
  std::shared_ptr<Connection> conn;
  // Membership must be tested with `&`. A bitwise OR with a non-zero flag is
  // always non-zero, which would route every request to the CUDA IPC branch
  // and make the IB and error branches unreachable.
  if (transport & TransportCudaIpc) {
    conn = std::make_shared<CudaIpcConnection>();
  } else if (transport & TransportAllIB) {
    conn = std::make_shared<IBConnection>(transport, *pimpl);
  } else {
    throw std::runtime_error("Unsupported transport");
  }
  pimpl->connections.push_back(conn);
  return conn;
}
MSCCLPP_API_CPP void Communicator::connectionSetup() {

View File

@@ -1,26 +1,18 @@
#include "connection.hpp"
#include "checks.hpp"
#include "registered_memory.hpp"
#include "npkit.h"
#include "npkit/npkit.h"
namespace mscclpp {
// Checks that `mem` was registered for at least one of the transports in
// `transport`.
//
// Throws:
//   std::runtime_error("mem does not support transport") when the two flag
//   sets have no bit in common.
void validateTransport(RegisteredMemory mem, TransportFlags transport) {
  // The parentheses are required: `==` binds tighter than `&`, so the
  // unparenthesized form would compare `transport == TransportNone` first and
  // then AND that boolean with the memory's flags.
  if ((mem.transports() & transport) == TransportNone) {
    throw std::runtime_error("mem does not support transport");
  }
}
// CudaIpcConnection
TransportFlags CudaIpcConnection::transport() {
return TransportCudaIpc;
}
TransportFlags CudaIpcConnection::remoteTransport() {
return TransportCudaIpc;
}
// Creates the CUDA stream this connection issues its copies on.
//
// The result of cudaStreamCreate is checked with CUDATHROW (the same wrapper
// write() uses around cudaMemcpyAsync) so that a failed creation surfaces here
// instead of leaving `stream` unusable and failing later in write().
CudaIpcConnection::CudaIpcConnection() {
  CUDATHROW(cudaStreamCreate(&stream));
}
@@ -29,12 +21,20 @@ CudaIpcConnection::~CudaIpcConnection() {
cudaStreamDestroy(stream);
}
TransportFlags CudaIpcConnection::transport() {
return TransportCudaIpc;
}
TransportFlags CudaIpcConnection::remoteTransport() {
return TransportCudaIpc;
}
void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) {
validateTransport(dst, remoteTransport());
validateTransport(src, transport());
auto dstPtr = dst.impl->data;
auto srcPtr = src.impl->data;
auto dstPtr = dst.data();
auto srcPtr = src.data();
CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, stream));
// npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size);
@@ -47,7 +47,13 @@ void CudaIpcConnection::flush() {
// IBConnection
IBConnection::IBConnection(TransportFlags transport) : transport_(transport), remoteTransport_(TransportNone) {}
IBConnection::IBConnection(TransportFlags transport, Communicator::Impl& commImpl) : transport_(transport), remoteTransport_(TransportNone) {
MSCCLPPTHROW(mscclppIbContextCreateQp(commImpl.getIbContext(transport), &qp));
}
IBConnection::~IBConnection() {
// TODO: Destroy QP?
}
TransportFlags IBConnection::transport() {
return transport_;
@@ -57,20 +63,21 @@ TransportFlags IBConnection::remoteTransport() {
return remoteTransport_;
}
IBConnection::IBConnection(TransportFlags transport, Communicator::Impl& commImpl) : transport_(transport), remoteTransport_(TransportNone) {
MSCCLPPTHROW(mscclppIbContextCreateQp(commImpl.getIbContext(transport), &qp));
}
IBConnection::~IBConnection() {
// TODO: Destroy QP?
}
void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) {
validateTransport(dst, remoteTransport());
validateTransport(src, transport());
auto dstMrInfo = dst.impl->getTransportInfo<mscclppIbMrInfo>(remoteTransport());
auto srcMr = src.impl->getTransportInfo<mscclppIbMr*>(transport());
auto dstTransportInfo = getRegisteredMemoryImpl(dst)->getTransportInfo(remoteTransport());
if (dstTransportInfo.ibLocal) {
throw std::runtime_error("dst is local, which is not supported");
}
auto srcTransportInfo = getRegisteredMemoryImpl(src)->getTransportInfo(remoteTransport());
if (!srcTransportInfo.ibLocal) {
throw std::runtime_error("src is remote, which is not supported");
}
auto dstMrInfo = dstTransportInfo.ibMrInfo;
auto srcMr = srcTransportInfo.ibMr;
qp->stageSend(srcMr, &dstMrInfo, (uint32_t)size,
/*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/false);

View File

@@ -5,19 +5,21 @@
#include "mscclpp.h"
#include "channel.hpp"
#include "proxy.hpp"
#include "ib.h"
#include <unordered_map>
namespace mscclpp {
// Private implementation state behind Communicator.
struct Communicator::Impl {
// Handle to the underlying C-style communicator. The constructor sets it to
// nullptr; the destructor passes it to mscclppCommDestroy when it is set.
mscclppComm_t comm;
// Connections handed out by Communicator::connect(), in creation order.
std::vector<std::shared_ptr<Connection>> connections;
// NOTE(review): the constructor that only initializes `comm` does not set up
// this member - confirm whether `proxy` is still meant to exist.
Proxy proxy;
// IB contexts keyed by IB transport flag. Entries are added on demand by
// getIbContext() and destroyed in ~Impl().
std::unordered_map<TransportFlags, mscclppIbContext*> ibContexts;
Impl();
~Impl();
// NOTE(review): presumably a leftover from when Connection reached into this
// struct directly - confirm it is still needed.
friend class Connection;
// Returns the IB context for `ibTransport`, creating and caching it on first use.
mscclppIbContext* getIbContext(TransportFlags ibTransport);
};
} // namespace mscclpp

View File

@@ -4,6 +4,7 @@
#include "mscclpp.hpp"
#include <cuda_runtime.h>
#include "ib.h"
#include "communicator.hpp"
namespace mscclpp {
@@ -15,15 +16,15 @@ public:
CudaIpcConnection();
virtual ~CudaIpcConnection();
~CudaIpcConnection();
virtual TransportFlags transport();
TransportFlags transport() override;
virtual TransportFlags remoteTransport();
TransportFlags remoteTransport() override;
virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size);
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override;
virtual void flush();
void flush() override;
};
class IBConnection : public Connection {
@@ -34,15 +35,15 @@ public:
IBConnection(TransportFlags transport, Communicator::Impl& commImpl);
virtual ~IBConnection();
~IBConnection();
virtual TransportFlags transport();
TransportFlags transport() override;
virtual TransportFlags remoteTransport();
TransportFlags remoteTransport() override;
virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size);
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override;
virtual void flush();
void flush() override;
};
} // namespace mscclpp

View File

@@ -48,8 +48,8 @@ public:
IbQp* createQp(int port = -1);
private:
bool IbCtx::isPortUsable(int port) const;
int IbCtx::getAnyActivePort() const;
bool isPortUsable(int port) const;
int getAnyActivePort() const;
void* ctx;
void* pd;

View File

@@ -67,8 +67,7 @@ public:
};
class Connection {
virtual ~Connection() = 0;
public:
virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) = 0;
virtual void flush() = 0;
@@ -76,13 +75,13 @@ class Connection {
virtual TransportFlags transport() = 0;
virtual TransportFlags remoteTransport() = 0;
protected:
static std::shared_ptr<RegisteredMemory::Impl> getRegisteredMemoryImpl(RegisteredMemory&);
};
class Communicator {
struct Impl;
std::unique_ptr<Impl> pimpl;
public:
/* Initialize the communicator. nranks processes with rank 0 to nranks-1 need to call this function.
*
* Inputs:
@@ -159,6 +158,10 @@ public:
* size: the number of ranks of the communicator
*/
int size();
struct Impl;
private:
std::unique_ptr<Impl> pimpl;
};
} // namespace mscclpp

View File

@@ -4,14 +4,21 @@
#include "mscclpp.hpp"
#include "mscclpp.h"
#include "ib.h"
#include <variant>
#include "communicator.hpp"
#include <cuda_runtime.h>
namespace mscclpp {
// Per-transport registration data attached to a RegisteredMemory.
struct TransportInfo {
// The single transport this entry describes; getTransportInfo() matches on it.
TransportFlags transport;
// NOTE(review): the registration and (de)serialization code assigns the union
// members below rather than this variant - confirm whether `data` should
// still be declared.
std::variant<std::monostate, cudaIpcMemHandle_t, mscclppIbMr*, mscclppIbMrInfo> data;
// TODO: rewrite this using std::variant or something
// Discriminates the IB members of the union: set to true where `ibMr` is
// stored after a local registration, and to false where `ibMrInfo` is stored
// after deserialization. The CUDA IPC paths shown here do not assign it.
bool ibLocal;
// Exactly one member is meaningful, selected by `transport` (and by `ibLocal`
// for IB transports).
union {
// Valid when transport is TransportCudaIpc.
cudaIpcMemHandle_t cudaIpcHandle;
// Valid for an IB transport when ibLocal is true.
mscclppIbMr* ibMr;
// Valid for an IB transport when ibLocal is false.
mscclppIbMrInfo ibMrInfo;
};
};
struct RegisteredMemory::Impl {
@@ -21,13 +28,13 @@ struct RegisteredMemory::Impl {
TransportFlags transports;
std::vector<TransportInfo> transportInfos;
Impl(void* data, size_t size, int rank, TransportFlags transports);
Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl);
Impl(const std::vector<char>& data);
template<class T> T& getTransportInfo(TransportFlags transport) {
TransportInfo& getTransportInfo(TransportFlags transport) {
for (auto& entry : transportInfos) {
if (entry.transport == transport) {
return std::get<T>(entry.data);
return entry;
}
}
throw std::runtime_error("Transport data not found");

View File

@@ -1,14 +1,16 @@
#include "registered_memory.hpp"
#include "checks.hpp"
#include <algorithm>
namespace mscclpp {
RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator& comm) : data(data), size(size), rank(rank), transports(transports) {
RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl) : data(data), size(size), rank(rank), transports(transports) {
if (transports & TransportCudaIpc) {
TransportInfo transportInfo;
transportInfo.transport = TransportCudaIpc;
cudaIpcMemHandle_t handle;
CUDATHROW(cudaIpcGetMemHandle(&handle, data));
transportInfo.data = handle;
transportInfo.cudaIpcHandle = handle;
this->transportInfos.push_back(transportInfo);
}
if (transports & TransportAllIB) {
@@ -16,8 +18,9 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t
TransportInfo transportInfo;
transportInfo.transport = ibTransport;
mscclppIbMr* mr;
MSCCLPPTHROW(mscclppIbContextRegisterMr(comm.pimpl->getIbContext(ibTransport), data, size, &mr));
transportInfo.data = mr;
MSCCLPPTHROW(mscclppIbContextRegisterMr(commImpl.getIbContext(ibTransport), data, size, &mr));
transportInfo.ibMr = mr;
transportInfo.ibLocal = true;
this->transportInfos.push_back(transportInfo);
};
if (transports & TransportIB0) addIb(TransportIB0);
@@ -31,62 +34,55 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t
}
}
// Wraps an existing implementation object; the handle shares ownership of it.
// The member is named `pimpl` (the accessors and serialize() read through it),
// so only the definition initializing `pimpl` is kept.
RegisteredMemory::RegisteredMemory(std::shared_ptr<Impl> pimpl) : pimpl(pimpl) {}
RegisteredMemory::~RegisteredMemory() = default;
// Returns the base pointer stored for this registration.
// Reads through `pimpl`; an earlier `return` through the renamed `impl` member
// would have made this one unreachable and is not kept.
void* RegisteredMemory::data() {
  return pimpl->data;
}
// Returns the size stored for this registration.
// Reads through `pimpl`; an earlier `return` through the renamed `impl` member
// would have made this one unreachable and is not kept.
size_t RegisteredMemory::size() {
  return pimpl->size;
}
// Returns the rank stored for this registration.
// Reads through `pimpl`; an earlier `return` through the renamed `impl` member
// would have made this one unreachable and is not kept.
int RegisteredMemory::rank() {
  return pimpl->rank;
}
// Returns the set of transport flags stored for this registration.
// Reads through `pimpl`; an earlier `return` through the renamed `impl` member
// would have made this one unreachable and is not kept.
TransportFlags RegisteredMemory::transports() {
  return pimpl->transports;
}
// Flattens this registration into a byte buffer that
// RegisteredMemory::deserialize() can read back.
//
// Layout (raw object bytes, in this order):
//   size, rank, transports, an int8_t entry count, then for each entry its
//   transport flag followed by either a cudaIpcMemHandle_t (CUDA IPC) or an
//   mscclppIbMrInfo (IB).
//
// Throws:
//   std::runtime_error("Too many transport info entries") when the entry count
//   does not fit in an int8_t;
//   std::runtime_error("Unknown transport") for an entry that is neither CUDA
//   IPC nor IB.
std::vector<char> RegisteredMemory::serialize() {
  std::vector<char> result;
  // Appends the raw object representation of `value` to the buffer.
  auto appendBytes = [&result](const auto& value) {
    const char* bytes = reinterpret_cast<const char*>(&value);
    result.insert(result.end(), bytes, bytes + sizeof(value));
  };

  appendBytes(pimpl->size);
  appendBytes(pimpl->rank);
  appendBytes(pimpl->transports);

  if (pimpl->transportInfos.size() > static_cast<size_t>(std::numeric_limits<int8_t>::max())) {
    throw std::runtime_error("Too many transport info entries");
  }
  int8_t transportCount = static_cast<int8_t>(pimpl->transportInfos.size());
  appendBytes(transportCount);

  for (auto& entry : pimpl->transportInfos) {
    appendBytes(entry.transport);
    if (entry.transport == TransportCudaIpc) {
      appendBytes(entry.cudaIpcHandle);
    } else if (entry.transport & TransportAllIB) {
      // A locally registered entry stores a pointer (`ibMr`) in the union, so
      // reading `ibMrInfo` from it would serialize pointer bytes. Send the
      // info record the memory region carries instead; only entries that came
      // from deserialization hold `ibMrInfo` directly.
      if (entry.ibLocal) {
        appendBytes(entry.ibMr->info);
      } else {
        appendBytes(entry.ibMrInfo);
      }
    } else {
      throw std::runtime_error("Unknown transport");
    }
  }
  return result;
}
// Rebuilds a RegisteredMemory from a buffer produced by serialize().
// `static` belongs only on the in-class declaration; repeating it on this
// out-of-class definition is ill-formed, so it is omitted here.
RegisteredMemory RegisteredMemory::deserialize(const std::vector<char>& data) {
  return RegisteredMemory(std::make_shared<Impl>(data));
}
RegisteredMemory::Impl::Impl(const std::vector<char>& data) {
auto it = data.begin();
RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
auto it = serialization.begin();
std::copy_n(it, sizeof(this->size), reinterpret_cast<char*>(&this->size));
it += sizeof(this->size);
std::copy_n(it, sizeof(this->rank), reinterpret_cast<char*>(&this->rank));
@@ -104,24 +100,25 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& data) {
cudaIpcMemHandle_t handle;
std::copy_n(it, sizeof(handle), reinterpret_cast<char*>(&handle));
it += sizeof(handle);
transportInfo.data = handle;
transportInfo.cudaIpcHandle = handle;
} else if (transportInfo.transport & TransportAllIB) {
mscclppIbMrInfo info;
std::copy_n(it, sizeof(info), reinterpret_cast<char*>(&info));
it += sizeof(info);
transportInfo.data = info;
transportInfo.ibMrInfo = info;
transportInfo.ibLocal = false;
} else {
throw std::runtime_error("Unknown transport");
}
this->transportInfos.push_back(transportInfo);
}
if (it != data.end()) {
if (it != serialization.end()) {
throw std::runtime_error("Deserialization failed");
}
if (transports & TransportCudaIpc) {
auto cudaIpcHandle = getTransportInfo<cudaIpcMemHandle_t>(TransportCudaIpc);
CUDATHROW(cudaIpcOpenMemHandle(&data, cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess));
auto entry = getTransportInfo(TransportCudaIpc);
CUDATHROW(cudaIpcOpenMemHandle(&data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess));
}
}