mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 17:26:04 +00:00
WIP builds, but doesn't link
This commit is contained in:
5
Makefile
5
Makefile
@@ -120,7 +120,8 @@ LDFLAGS := $(NVLDFLAGS) $(GDRCOPY_LDFLAGS) -libverbs -lnuma
|
||||
|
||||
LIBSRCS := $(addprefix src/,debug.cc utils.cc init.cc proxy.cc ib.cc config.cc)
|
||||
LIBSRCS += $(addprefix src/bootstrap/,bootstrap.cc socket.cc)
|
||||
LIBSRCS += $(addprefix src/,communicator.cc fifo.cc host_connection.cc proxy_cpp.cc basic_proxy_handler.cc)
|
||||
LIBSRCS += $(addprefix src/,communicator.cc connection.cc registered_memory.cc)
|
||||
#LIBSRCS += $(addprefix src/,fifo.cc host_connection.cc proxy_cpp.cc basic_proxy_handler.cc)
|
||||
ifneq ($(NPKIT), 0)
|
||||
LIBSRCS += $(addprefix src/misc/,npkit.cc)
|
||||
endif
|
||||
@@ -148,7 +149,7 @@ UTOBJTARGETS := $(UTOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
|
||||
UTBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(UTOBJS))
|
||||
|
||||
TESTSDIR := tests
|
||||
TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu allgather_test_cpp.cu)
|
||||
TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu) # allgather_test_cpp.cu
|
||||
TESTSOBJS := $(patsubst %.cc,%.o,$(TESTSSRCS)) $(patsubst %.cu,%.o,$(TESTSSRCS))
|
||||
TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
|
||||
TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS))
|
||||
|
||||
@@ -4,17 +4,39 @@
|
||||
#include "comm.h"
|
||||
#include "basic_proxy_handler.hpp"
|
||||
#include "api.h"
|
||||
#include "utils.h"
|
||||
#include "checks.hpp"
|
||||
#include "debug.h"
|
||||
#include "connection.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
Communicator::Impl::Impl() : comm(nullptr), proxy(makeBasicProxyHandler(*this)) {}
|
||||
Communicator::Impl::Impl() : comm(nullptr) {}
|
||||
|
||||
Communicator::Impl::~Impl() {
|
||||
for (auto& entry : ibContexts) {
|
||||
mscclppIbContextDestroy(entry.second);
|
||||
}
|
||||
ibContexts.clear();
|
||||
if (comm) {
|
||||
mscclppCommDestroy(comm);
|
||||
}
|
||||
}
|
||||
|
||||
mscclppIbContext* Communicator::Impl::getIbContext(TransportFlags ibTransport) {
|
||||
// Find IB context or create it
|
||||
auto it = ibContexts.find(ibTransport);
|
||||
if (it == ibContexts.end()) {
|
||||
auto ibDev = getIBDeviceName(ibTransport);
|
||||
mscclppIbContext* ibCtx;
|
||||
MSCCLPPTHROW(mscclppIbContextCreate(&ibCtx, ibDev.c_str()));
|
||||
ibContexts[ibTransport] = ibCtx;
|
||||
return ibCtx;
|
||||
} else {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Communicator::~Communicator() = default;
|
||||
|
||||
static mscclppTransport_t transportToCStyle(TransportFlags flags) {
|
||||
@@ -54,24 +76,16 @@ MSCCLPP_API_CPP void Communicator::bootstrapBarrier() {
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connect(int remoteRank, int tag, TransportFlags transport) {
|
||||
std::string ibDev;
|
||||
switch (transport) {
|
||||
case TransportIB0:
|
||||
case TransportIB1:
|
||||
case TransportIB2:
|
||||
case TransportIB3:
|
||||
case TransportIB4:
|
||||
case TransportIB5:
|
||||
case TransportIB6:
|
||||
case TransportIB7:
|
||||
ibDev = getIBDeviceName(transport);
|
||||
break;
|
||||
std::shared_ptr<Connection> conn;
|
||||
if (transport | TransportCudaIpc) {
|
||||
auto cudaIpcConn = std::make_shared<CudaIpcConnection>();
|
||||
conn = cudaIpcConn;
|
||||
} else if (transport | TransportAllIB) {
|
||||
auto ibConn = std::make_shared<IBConnection>(transport, *pimpl);
|
||||
conn = ibConn;
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported transport");
|
||||
}
|
||||
mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportToCStyle(transport), ibDev.c_str());
|
||||
auto connIdx = pimpl->connections.size();
|
||||
auto conn = std::make_shared<Connection>(std::make_unique<Connection::Impl>(this, &pimpl->comm->conns[connIdx]));
|
||||
pimpl->connections.push_back(conn);
|
||||
return conn;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::connectionSetup() {
|
||||
|
||||
@@ -1,26 +1,18 @@
|
||||
#include "connection.hpp"
|
||||
#include "checks.hpp"
|
||||
#include "registered_memory.hpp"
|
||||
#include "npkit.h"
|
||||
#include "npkit/npkit.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
void validateTransport(RegisteredMemory mem, TransportFlags transport) {
|
||||
if (mem.transports() & transport == TransportNone) {
|
||||
if ((mem.transports() & transport) == TransportNone) {
|
||||
throw std::runtime_error("mem does not support transport");
|
||||
}
|
||||
}
|
||||
|
||||
// CudaIpcConnection
|
||||
|
||||
TransportFlags CudaIpcConnection::transport() {
|
||||
return TransportCudaIpc;
|
||||
}
|
||||
|
||||
TransportFlags CudaIpcConnection::remoteTransport() {
|
||||
return TransportCudaIpc;
|
||||
}
|
||||
|
||||
CudaIpcConnection::CudaIpcConnection() {
|
||||
cudaStreamCreate(&stream);
|
||||
}
|
||||
@@ -29,12 +21,20 @@ CudaIpcConnection::~CudaIpcConnection() {
|
||||
cudaStreamDestroy(stream);
|
||||
}
|
||||
|
||||
TransportFlags CudaIpcConnection::transport() {
|
||||
return TransportCudaIpc;
|
||||
}
|
||||
|
||||
TransportFlags CudaIpcConnection::remoteTransport() {
|
||||
return TransportCudaIpc;
|
||||
}
|
||||
|
||||
void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) {
|
||||
validateTransport(dst, remoteTransport());
|
||||
validateTransport(src, transport());
|
||||
|
||||
auto dstPtr = dst.impl->data;
|
||||
auto srcPtr = src.impl->data;
|
||||
auto dstPtr = dst.data();
|
||||
auto srcPtr = src.data();
|
||||
|
||||
CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, stream));
|
||||
// npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size);
|
||||
@@ -47,7 +47,13 @@ void CudaIpcConnection::flush() {
|
||||
|
||||
// IBConnection
|
||||
|
||||
IBConnection::IBConnection(TransportFlags transport) : transport_(transport), remoteTransport_(TransportNone) {}
|
||||
IBConnection::IBConnection(TransportFlags transport, Communicator::Impl& commImpl) : transport_(transport), remoteTransport_(TransportNone) {
|
||||
MSCCLPPTHROW(mscclppIbContextCreateQp(commImpl.getIbContext(transport), &qp));
|
||||
}
|
||||
|
||||
IBConnection::~IBConnection() {
|
||||
// TODO: Destroy QP?
|
||||
}
|
||||
|
||||
TransportFlags IBConnection::transport() {
|
||||
return transport_;
|
||||
@@ -57,20 +63,21 @@ TransportFlags IBConnection::remoteTransport() {
|
||||
return remoteTransport_;
|
||||
}
|
||||
|
||||
IBConnection::IBConnection(TransportFlags transport, Communicator::Impl& commImpl) : transport_(transport), remoteTransport_(TransportNone) {
|
||||
MSCCLPPTHROW(mscclppIbContextCreateQp(commImpl.getIbContext(transport), &qp));
|
||||
}
|
||||
|
||||
IBConnection::~IBConnection() {
|
||||
// TODO: Destroy QP?
|
||||
}
|
||||
|
||||
void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) {
|
||||
validateTransport(dst, remoteTransport());
|
||||
validateTransport(src, transport());
|
||||
|
||||
auto dstMrInfo = dst.impl->getTransportInfo<mscclppIbMrInfo>(remoteTransport());
|
||||
auto srcMr = src.impl->getTransportInfo<mscclppIbMr*>(transport());
|
||||
auto dstTransportInfo = getRegisteredMemoryImpl(dst)->getTransportInfo(remoteTransport());
|
||||
if (dstTransportInfo.ibLocal) {
|
||||
throw std::runtime_error("dst is local, which is not supported");
|
||||
}
|
||||
auto srcTransportInfo = getRegisteredMemoryImpl(src)->getTransportInfo(remoteTransport());
|
||||
if (!srcTransportInfo.ibLocal) {
|
||||
throw std::runtime_error("src is remote, which is not supported");
|
||||
}
|
||||
|
||||
auto dstMrInfo = dstTransportInfo.ibMrInfo;
|
||||
auto srcMr = srcTransportInfo.ibMr;
|
||||
|
||||
qp->stageSend(srcMr, &dstMrInfo, (uint32_t)size,
|
||||
/*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/false);
|
||||
|
||||
@@ -5,19 +5,21 @@
|
||||
#include "mscclpp.h"
|
||||
#include "channel.hpp"
|
||||
#include "proxy.hpp"
|
||||
#include "ib.h"
|
||||
#include <unordered_map>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct Communicator::Impl {
|
||||
mscclppComm_t comm;
|
||||
std::vector<std::shared_ptr<Connection>> connections;
|
||||
Proxy proxy;
|
||||
std::unordered_map<TransportFlags, mscclppIbContext*> ibContexts;
|
||||
|
||||
Impl();
|
||||
|
||||
~Impl();
|
||||
|
||||
friend class Connection;
|
||||
mscclppIbContext* getIbContext(TransportFlags ibTransport);
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "mscclpp.hpp"
|
||||
#include <cuda_runtime.h>
|
||||
#include "ib.h"
|
||||
#include "communicator.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
@@ -15,15 +16,15 @@ public:
|
||||
|
||||
CudaIpcConnection();
|
||||
|
||||
virtual ~CudaIpcConnection();
|
||||
~CudaIpcConnection();
|
||||
|
||||
virtual TransportFlags transport();
|
||||
TransportFlags transport() override;
|
||||
|
||||
virtual TransportFlags remoteTransport();
|
||||
TransportFlags remoteTransport() override;
|
||||
|
||||
virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size);
|
||||
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override;
|
||||
|
||||
virtual void flush();
|
||||
void flush() override;
|
||||
};
|
||||
|
||||
class IBConnection : public Connection {
|
||||
@@ -34,15 +35,15 @@ public:
|
||||
|
||||
IBConnection(TransportFlags transport, Communicator::Impl& commImpl);
|
||||
|
||||
virtual ~IBConnection();
|
||||
~IBConnection();
|
||||
|
||||
virtual TransportFlags transport();
|
||||
TransportFlags transport() override;
|
||||
|
||||
virtual TransportFlags remoteTransport();
|
||||
TransportFlags remoteTransport() override;
|
||||
|
||||
virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size);
|
||||
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override;
|
||||
|
||||
virtual void flush();
|
||||
void flush() override;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -48,8 +48,8 @@ public:
|
||||
IbQp* createQp(int port = -1);
|
||||
|
||||
private:
|
||||
bool IbCtx::isPortUsable(int port) const;
|
||||
int IbCtx::getAnyActivePort() const;
|
||||
bool isPortUsable(int port) const;
|
||||
int getAnyActivePort() const;
|
||||
|
||||
void* ctx;
|
||||
void* pd;
|
||||
|
||||
@@ -67,8 +67,7 @@ public:
|
||||
};
|
||||
|
||||
class Connection {
|
||||
virtual ~Connection() = 0;
|
||||
|
||||
public:
|
||||
virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) = 0;
|
||||
|
||||
virtual void flush() = 0;
|
||||
@@ -76,13 +75,13 @@ class Connection {
|
||||
virtual TransportFlags transport() = 0;
|
||||
|
||||
virtual TransportFlags remoteTransport() = 0;
|
||||
|
||||
protected:
|
||||
static std::shared_ptr<RegisteredMemory::Impl> getRegisteredMemoryImpl(RegisteredMemory&);
|
||||
};
|
||||
|
||||
class Communicator {
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
public:
|
||||
|
||||
/* Initialize the communicator. nranks processes with rank 0 to nranks-1 need to call this function.
|
||||
*
|
||||
* Inputs:
|
||||
@@ -159,6 +158,10 @@ public:
|
||||
* size: the number of ranks of the communicator
|
||||
*/
|
||||
int size();
|
||||
|
||||
struct Impl;
|
||||
private:
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -4,14 +4,21 @@
|
||||
#include "mscclpp.hpp"
|
||||
#include "mscclpp.h"
|
||||
#include "ib.h"
|
||||
#include <variant>
|
||||
#include "communicator.hpp"
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct TransportInfo {
|
||||
TransportFlags transport;
|
||||
std::variant<std::monostate, cudaIpcMemHandle_t, mscclppIbMr*, mscclppIbMrInfo> data;
|
||||
|
||||
// TODO: rewrite this using std::variant or something
|
||||
bool ibLocal;
|
||||
union {
|
||||
cudaIpcMemHandle_t cudaIpcHandle;
|
||||
mscclppIbMr* ibMr;
|
||||
mscclppIbMrInfo ibMrInfo;
|
||||
};
|
||||
};
|
||||
|
||||
struct RegisteredMemory::Impl {
|
||||
@@ -21,13 +28,13 @@ struct RegisteredMemory::Impl {
|
||||
TransportFlags transports;
|
||||
std::vector<TransportInfo> transportInfos;
|
||||
|
||||
Impl(void* data, size_t size, int rank, TransportFlags transports);
|
||||
Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl);
|
||||
Impl(const std::vector<char>& data);
|
||||
|
||||
template<class T> T& getTransportInfo(TransportFlags transport) {
|
||||
TransportInfo& getTransportInfo(TransportFlags transport) {
|
||||
for (auto& entry : transportInfos) {
|
||||
if (entry.transport == transport) {
|
||||
return std::get<T>(entry.data);
|
||||
return entry;
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("Transport data not found");
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
#include "registered_memory.hpp"
|
||||
#include "checks.hpp"
|
||||
#include <algorithm>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator& comm) : data(data), size(size), rank(rank), transports(transports) {
|
||||
RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl) : data(data), size(size), rank(rank), transports(transports) {
|
||||
if (transports & TransportCudaIpc) {
|
||||
TransportInfo transportInfo;
|
||||
transportInfo.transport = TransportCudaIpc;
|
||||
cudaIpcMemHandle_t handle;
|
||||
CUDATHROW(cudaIpcGetMemHandle(&handle, data));
|
||||
transportInfo.data = handle;
|
||||
transportInfo.cudaIpcHandle = handle;
|
||||
this->transportInfos.push_back(transportInfo);
|
||||
}
|
||||
if (transports & TransportAllIB) {
|
||||
@@ -16,8 +18,9 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t
|
||||
TransportInfo transportInfo;
|
||||
transportInfo.transport = ibTransport;
|
||||
mscclppIbMr* mr;
|
||||
MSCCLPPTHROW(mscclppIbContextRegisterMr(comm.pimpl->getIbContext(ibTransport), data, size, &mr));
|
||||
transportInfo.data = mr;
|
||||
MSCCLPPTHROW(mscclppIbContextRegisterMr(commImpl.getIbContext(ibTransport), data, size, &mr));
|
||||
transportInfo.ibMr = mr;
|
||||
transportInfo.ibLocal = true;
|
||||
this->transportInfos.push_back(transportInfo);
|
||||
};
|
||||
if (transports & TransportIB0) addIb(TransportIB0);
|
||||
@@ -31,62 +34,55 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t
|
||||
}
|
||||
}
|
||||
|
||||
RegisteredMemory::RegisteredMemory(std::shared_ptr<Impl> pimpl) : impl(pimpl) {}
|
||||
RegisteredMemory::RegisteredMemory(std::shared_ptr<Impl> pimpl) : pimpl(pimpl) {}
|
||||
|
||||
RegisteredMemory::~RegisteredMemory() = default;
|
||||
|
||||
void* RegisteredMemory::data() {
|
||||
return impl->data;
|
||||
return pimpl->data;
|
||||
}
|
||||
|
||||
size_t RegisteredMemory::size() {
|
||||
return impl->size;
|
||||
return pimpl->size;
|
||||
}
|
||||
|
||||
int RegisteredMemory::rank() {
|
||||
return impl->rank;
|
||||
return pimpl->rank;
|
||||
}
|
||||
|
||||
TransportFlags RegisteredMemory::transports() {
|
||||
return impl->transports;
|
||||
return pimpl->transports;
|
||||
}
|
||||
|
||||
std::vector<char> RegisteredMemory::serialize() {
|
||||
std::vector<char> result;
|
||||
std::copy_n(reinterpret_cast<char*>(&impl->size), sizeof(impl->size), std::back_inserter(result));
|
||||
std::copy_n(reinterpret_cast<char*>(&impl->rank), sizeof(impl->rank), std::back_inserter(result));
|
||||
std::copy_n(reinterpret_cast<char*>(&impl->transports), sizeof(impl->transports), std::back_inserter(result));
|
||||
if (impl->transportInfos.size() > std::numeric_limits<int8_t>::max()) {
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl->size), sizeof(pimpl->size), std::back_inserter(result));
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl->rank), sizeof(pimpl->rank), std::back_inserter(result));
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl->transports), sizeof(pimpl->transports), std::back_inserter(result));
|
||||
if (pimpl->transportInfos.size() > std::numeric_limits<int8_t>::max()) {
|
||||
throw std::runtime_error("Too many transport info entries");
|
||||
}
|
||||
int8_t transportCount = impl->transportInfos.size();
|
||||
int8_t transportCount = pimpl->transportInfos.size();
|
||||
std::copy_n(reinterpret_cast<char*>(&transportCount), sizeof(transportCount), std::back_inserter(result));
|
||||
for (auto& entry : impl->transportInfos) {
|
||||
for (auto& entry : pimpl->transportInfos) {
|
||||
std::copy_n(reinterpret_cast<char*>(&entry.transport), sizeof(entry.transport), std::back_inserter(result));
|
||||
std::visit(overloaded{
|
||||
[&](std::monostate&){
|
||||
throw std::runtime_error("Transport info not set");
|
||||
},
|
||||
[&](cudaIpcMemHandle_t handle){
|
||||
std::copy_n(reinterpret_cast<char*>(&handle), sizeof(handle), std::back_inserter(result));
|
||||
},
|
||||
[&](mscclppIbMr* mr){
|
||||
std::copy_n(reinterpret_cast<char*>(&mr->info), sizeof(mr->info), std::back_inserter(result));
|
||||
},
|
||||
[&](mscclppIbMrInfo info){
|
||||
std::copy_n(reinterpret_cast<char*>(&info), sizeof(info), std::back_inserter(result));
|
||||
}
|
||||
}, entry.data);
|
||||
if (entry.transport == TransportCudaIpc) {
|
||||
std::copy_n(reinterpret_cast<char*>(&entry.cudaIpcHandle), sizeof(entry.cudaIpcHandle), std::back_inserter(result));
|
||||
} else if (entry.transport & TransportAllIB) {
|
||||
std::copy_n(reinterpret_cast<char*>(&entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result));
|
||||
} else {
|
||||
throw std::runtime_error("Unknown transport");
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static RegisteredMemory RegisteredMemory::deserialize(const std::vector<char>& data) {
|
||||
RegisteredMemory RegisteredMemory::deserialize(const std::vector<char>& data) {
|
||||
return RegisteredMemory(std::make_shared<Impl>(data));
|
||||
}
|
||||
|
||||
RegisteredMemory::Impl::Impl(const std::vector<char>& data) {
|
||||
auto it = data.begin();
|
||||
RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
|
||||
auto it = serialization.begin();
|
||||
std::copy_n(it, sizeof(this->size), reinterpret_cast<char*>(&this->size));
|
||||
it += sizeof(this->size);
|
||||
std::copy_n(it, sizeof(this->rank), reinterpret_cast<char*>(&this->rank));
|
||||
@@ -104,24 +100,25 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& data) {
|
||||
cudaIpcMemHandle_t handle;
|
||||
std::copy_n(it, sizeof(handle), reinterpret_cast<char*>(&handle));
|
||||
it += sizeof(handle);
|
||||
transportInfo.data = handle;
|
||||
transportInfo.cudaIpcHandle = handle;
|
||||
} else if (transportInfo.transport & TransportAllIB) {
|
||||
mscclppIbMrInfo info;
|
||||
std::copy_n(it, sizeof(info), reinterpret_cast<char*>(&info));
|
||||
it += sizeof(info);
|
||||
transportInfo.data = info;
|
||||
transportInfo.ibMrInfo = info;
|
||||
transportInfo.ibLocal = false;
|
||||
} else {
|
||||
throw std::runtime_error("Unknown transport");
|
||||
}
|
||||
this->transportInfos.push_back(transportInfo);
|
||||
}
|
||||
if (it != data.end()) {
|
||||
if (it != serialization.end()) {
|
||||
throw std::runtime_error("Deserialization failed");
|
||||
}
|
||||
|
||||
if (transports & TransportCudaIpc) {
|
||||
auto cudaIpcHandle = getTransportInfo<cudaIpcMemHandle_t>(TransportCudaIpc);
|
||||
CUDATHROW(cudaIpcOpenMemHandle(&data, cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess));
|
||||
auto entry = getTransportInfo(TransportCudaIpc);
|
||||
CUDATHROW(cudaIpcOpenMemHandle(&data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user