mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 01:10:22 +00:00
tests for host hash
This commit is contained in:
@@ -26,7 +26,7 @@ Communicator::Impl::~Impl() {
|
||||
ibContexts.clear();
|
||||
}
|
||||
|
||||
IbCtx* Communicator::Impl::getIbContext(TransportFlags ibTransport) {
|
||||
IbCtx* Communicator::Impl::getIbContext(Transport ibTransport) {
|
||||
// Find IB context or create it
|
||||
auto it = ibContexts.find(ibTransport);
|
||||
if (it == ibContexts.end()) {
|
||||
@@ -40,24 +40,6 @@ IbCtx* Communicator::Impl::getIbContext(TransportFlags ibTransport) {
|
||||
|
||||
MSCCLPP_API_CPP Communicator::~Communicator() = default;
|
||||
|
||||
static mscclppTransport_t transportToCStyle(TransportFlags flags) {
|
||||
switch (flags) {
|
||||
case TransportIB0:
|
||||
case TransportIB1:
|
||||
case TransportIB2:
|
||||
case TransportIB3:
|
||||
case TransportIB4:
|
||||
case TransportIB5:
|
||||
case TransportIB6:
|
||||
case TransportIB7:
|
||||
return mscclppTransportIB;
|
||||
case TransportCudaIpc:
|
||||
return mscclppTransportP2P;
|
||||
default:
|
||||
throw std::runtime_error("Unsupported conversion");
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr<BaseBootstrap> bootstrap) : pimpl(std::make_unique<Impl>(bootstrap)) {}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::bootstrapAllGather(void* data, int size) {
|
||||
@@ -72,20 +54,19 @@ RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportF
|
||||
return RegisteredMemory(std::make_shared<RegisteredMemory::Impl>(ptr, size, pimpl->comm->rank, transports, *pimpl));
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connect(int remoteRank, int tag, TransportFlags transport) {
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connect(int remoteRank, int tag, Transport transport) {
|
||||
std::shared_ptr<ConnectionBase> conn;
|
||||
if (transport | TransportCudaIpc) {
|
||||
if (transport == Transport::CudaIpc) {
|
||||
// sanity check: make sure the IPC connection is being made within a node
|
||||
if (pimpl->rankToHash_[remoteRank] != pimpl->rankToHash_[pimpl->bootstrap_->getRank()]) {
|
||||
std::stringstream ss;
|
||||
ss << "Cuda IPC connection can only be made within a node: " << remoteRank << " != " << pimpl->bootstrap_->getRank();
|
||||
ss << "Cuda IPC connection can only be made within a node: " << remoteRank << "(" << std::hex << pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")" << " != "
|
||||
<< pimpl->bootstrap_->getRank() << "(" << std::hex << pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")";
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
auto cudaIpcConn = std::make_shared<CudaIpcConnection>();
|
||||
conn = cudaIpcConn;
|
||||
INFO(MSCCLPP_INIT, "Cuda IPC connection between %d(%lx) and %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()],
|
||||
remoteRank, pimpl->rankToHash_[remoteRank]);
|
||||
} else if (transport | TransportAllIB) {
|
||||
} else if (AllIBTransports.has(transport)) {
|
||||
auto ibConn = std::make_shared<IBConnection>(remoteRank, tag, transport, *pimpl);
|
||||
conn = ibConn;
|
||||
INFO(MSCCLPP_INIT, "IB connection between %d(%lx) via %s and %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()],
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
void validateTransport(RegisteredMemory mem, TransportFlags transport) {
|
||||
if ((mem.transports() & transport) == TransportNone) {
|
||||
// Checks that `transport` is among the transports reported by mem.transports().
// Throws std::runtime_error when it is not; returns normally otherwise.
void validateTransport(RegisteredMemory mem, Transport transport) {
  const bool supported = mem.transports().has(transport);
  if (supported) {
    return;
  }
  throw std::runtime_error("mem does not support transport");
}
|
||||
@@ -28,12 +28,12 @@ CudaIpcConnection::~CudaIpcConnection() {
|
||||
cudaStreamDestroy(stream);
|
||||
}
|
||||
|
||||
TransportFlags CudaIpcConnection::transport() {
|
||||
return TransportCudaIpc;
|
||||
// Local transport of a CUDA IPC connection: unconditionally Transport::CudaIpc.
Transport CudaIpcConnection::transport() { return Transport::CudaIpc; }
|
||||
|
||||
TransportFlags CudaIpcConnection::remoteTransport() {
|
||||
return TransportCudaIpc;
|
||||
// Remote transport of a CUDA IPC connection: unconditionally Transport::CudaIpc.
Transport CudaIpcConnection::remoteTransport() { return Transport::CudaIpc; }
|
||||
|
||||
void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) {
|
||||
@@ -54,7 +54,7 @@ void CudaIpcConnection::flush() {
|
||||
|
||||
// IBConnection
|
||||
|
||||
IBConnection::IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl) : remoteRank_(remoteRank), tag_(tag), transport_(transport), remoteTransport_(TransportNone) {
|
||||
// Records the peer rank, tag and local IB transport, leaves the remote transport as
// Transport::Unknown, and creates this connection's queue pair from the IB context
// that commImpl returns for `transport`.
IBConnection::IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl)
    : remoteRank_(remoteRank), tag_(tag), transport_(transport), remoteTransport_(Transport::Unknown) {
  auto* ibContext = commImpl.getIbContext(transport);
  qp = ibContext->createQp();
}
|
||||
|
||||
@@ -62,11 +62,11 @@ IBConnection::~IBConnection() {
|
||||
// TODO: Destroy QP?
|
||||
}
|
||||
|
||||
TransportFlags IBConnection::transport() {
|
||||
// Returns the local transport this connection was constructed with.
Transport IBConnection::transport() { return transport_; }
|
||||
|
||||
TransportFlags IBConnection::remoteTransport() {
|
||||
// Returns the stored remote transport (set to Transport::Unknown by the constructor).
Transport IBConnection::remoteTransport() { return remoteTransport_; }
|
||||
|
||||
@@ -115,13 +115,11 @@ void IBConnection::flush() {
|
||||
}
|
||||
|
||||
// First half of connection setup: sends this side's queue-pair info, as returned by
// qp->getInfo(), to remoteRank_ through the bootstrap using tag_.
void IBConnection::startSetup(std::shared_ptr<BaseBootstrap> bootstrap) {
  auto& localQpInfo = qp->getInfo();
  bootstrap->send(&localQpInfo, sizeof(localQpInfo), remoteRank_, tag_);
}
|
||||
|
||||
void IBConnection::endSetup(std::shared_ptr<BaseBootstrap> bootstrap) {
|
||||
IbQpInfo qpInfo;
|
||||
// TODO(chhwang): temporarily disabled to compile
|
||||
bootstrap->recv(&qpInfo, sizeof(qpInfo), remoteRank_, tag_);
|
||||
qp->rtr(qpInfo);
|
||||
qp->rts();
|
||||
|
||||
36
src/ib.cc
36
src/ib.cc
@@ -368,33 +368,33 @@ int getIBDeviceCount() {
|
||||
return num;
|
||||
}
|
||||
|
||||
std::string getIBDeviceName(TransportFlags ibTransport) {
|
||||
std::string getIBDeviceName(Transport ibTransport) {
|
||||
int num;
|
||||
struct ibv_device** devices = ibv_get_device_list(&num);
|
||||
int ibTransportIndex;
|
||||
switch (ibTransport) { // TODO: get rid of this ugly switch
|
||||
case TransportIB0:
|
||||
case Transport::IB0:
|
||||
ibTransportIndex = 0;
|
||||
break;
|
||||
case TransportIB1:
|
||||
case Transport::IB1:
|
||||
ibTransportIndex = 1;
|
||||
break;
|
||||
case TransportIB2:
|
||||
case Transport::IB2:
|
||||
ibTransportIndex = 2;
|
||||
break;
|
||||
case TransportIB3:
|
||||
case Transport::IB3:
|
||||
ibTransportIndex = 3;
|
||||
break;
|
||||
case TransportIB4:
|
||||
case Transport::IB4:
|
||||
ibTransportIndex = 4;
|
||||
break;
|
||||
case TransportIB5:
|
||||
case Transport::IB5:
|
||||
ibTransportIndex = 5;
|
||||
break;
|
||||
case TransportIB6:
|
||||
case Transport::IB6:
|
||||
ibTransportIndex = 6;
|
||||
break;
|
||||
case TransportIB7:
|
||||
case Transport::IB7:
|
||||
ibTransportIndex = 7;
|
||||
break;
|
||||
default:
|
||||
@@ -406,28 +406,28 @@ std::string getIBDeviceName(TransportFlags ibTransport) {
|
||||
return devices[ibTransportIndex]->name;
|
||||
}
|
||||
|
||||
TransportFlags getIBTransportByDeviceName(const std::string& ibDeviceName) {
|
||||
Transport getIBTransportByDeviceName(const std::string& ibDeviceName) {
|
||||
int num;
|
||||
struct ibv_device** devices = ibv_get_device_list(&num);
|
||||
for (int i = 0; i < num; ++i) {
|
||||
if (ibDeviceName == devices[i]->name) {
|
||||
switch (i) { // TODO: get rid of this ugly switch
|
||||
case 0:
|
||||
return TransportIB0;
|
||||
return Transport::IB0;
|
||||
case 1:
|
||||
return TransportIB1;
|
||||
return Transport::IB1;
|
||||
case 2:
|
||||
return TransportIB2;
|
||||
return Transport::IB2;
|
||||
case 3:
|
||||
return TransportIB3;
|
||||
return Transport::IB3;
|
||||
case 4:
|
||||
return TransportIB4;
|
||||
return Transport::IB4;
|
||||
case 5:
|
||||
return TransportIB5;
|
||||
return Transport::IB5;
|
||||
case 6:
|
||||
return TransportIB6;
|
||||
return Transport::IB6;
|
||||
case 7:
|
||||
return TransportIB7;
|
||||
return Transport::IB7;
|
||||
default:
|
||||
throw std::runtime_error("IB device index out of range");
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ class ConnectionBase;
|
||||
struct Communicator::Impl {
|
||||
mscclppComm_t comm;
|
||||
std::vector<std::shared_ptr<ConnectionBase>> connections;
|
||||
std::unordered_map<TransportFlags, std::unique_ptr<IbCtx>> ibContexts;
|
||||
std::unordered_map<Transport, std::unique_ptr<IbCtx>> ibContexts;
|
||||
std::shared_ptr<BaseBootstrap> bootstrap_;
|
||||
std::vector<uint64_t> rankToHash_;
|
||||
|
||||
@@ -24,7 +24,7 @@ struct Communicator::Impl {
|
||||
|
||||
~Impl();
|
||||
|
||||
IbCtx* getIbContext(TransportFlags ibTransport);
|
||||
IbCtx* getIbContext(Transport ibTransport);
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -24,9 +24,9 @@ public:
|
||||
|
||||
~CudaIpcConnection();
|
||||
|
||||
TransportFlags transport() override;
|
||||
Transport transport() override;
|
||||
|
||||
TransportFlags remoteTransport() override;
|
||||
Transport remoteTransport() override;
|
||||
|
||||
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override;
|
||||
|
||||
@@ -36,18 +36,18 @@ public:
|
||||
class IBConnection : public ConnectionBase {
|
||||
int remoteRank_;
|
||||
int tag_;
|
||||
TransportFlags transport_;
|
||||
TransportFlags remoteTransport_;
|
||||
Transport transport_;
|
||||
Transport remoteTransport_;
|
||||
IbQp* qp;
|
||||
public:
|
||||
|
||||
IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl);
|
||||
IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl);
|
||||
|
||||
~IBConnection();
|
||||
|
||||
TransportFlags transport() override;
|
||||
Transport transport() override;
|
||||
|
||||
TransportFlags remoteTransport() override;
|
||||
Transport remoteTransport() override;
|
||||
|
||||
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override;
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <bitset>
|
||||
|
||||
|
||||
namespace mscclpp {
|
||||
@@ -63,24 +64,129 @@ private:
|
||||
*/
|
||||
std::unique_ptr<UniqueId> getUniqueId();
|
||||
|
||||
using TransportFlags = uint32_t;
|
||||
const TransportFlags TransportNone = 0b0;
|
||||
const TransportFlags TransportCudaIpc = 0b1;
|
||||
const TransportFlags TransportIB0 = 0b10;
|
||||
const TransportFlags TransportIB1 = 0b100;
|
||||
const TransportFlags TransportIB2 = 0b1000;
|
||||
const TransportFlags TransportIB3 = 0b10000;
|
||||
const TransportFlags TransportIB4 = 0b100000;
|
||||
const TransportFlags TransportIB5 = 0b1000000;
|
||||
const TransportFlags TransportIB6 = 0b10000000;
|
||||
const TransportFlags TransportIB7 = 0b100000000;
|
||||
/// Identifies a single transport.
///
/// Enumerators keep their implicit, consecutive values starting at 0. TransportFlags
/// uses each value as a bit index, and detail::TransportFlagsSize is static_asserted
/// against NumTransports, so new transports must be inserted before NumTransports.
enum class Transport {
  Unknown,       // no specific transport; value 0
  CudaIpc,       // CUDA IPC
  IB0,           // InfiniBand transports 0 through 7
  IB1,
  IB2,
  IB3,
  IB4,
  IB5,
  IB6,
  IB7,
  NumTransports  // count sentinel, not a transport; must stay last
};
|
||||
|
||||
const TransportFlags TransportAll = 0b111111111;
|
||||
const TransportFlags TransportAllIB = 0b111111110;
|
||||
namespace detail {
|
||||
const size_t TransportFlagsSize = 10;
|
||||
static_assert(TransportFlagsSize == static_cast<size_t>(Transport::NumTransports), "TransportFlagsSize must match the number of transports");
|
||||
using TransportFlagsBase = std::bitset<TransportFlagsSize>;
|
||||
}
|
||||
|
||||
class TransportFlags : private detail::TransportFlagsBase {
|
||||
public:
|
||||
TransportFlags() = default;
|
||||
TransportFlags(Transport transport) : detail::TransportFlagsBase(1 << static_cast<size_t>(transport)) {}
|
||||
|
||||
bool has(Transport transport) const {
|
||||
return detail::TransportFlagsBase::test(static_cast<size_t>(transport));
|
||||
}
|
||||
|
||||
bool none() const {
|
||||
return detail::TransportFlagsBase::none();
|
||||
}
|
||||
|
||||
bool any() const {
|
||||
return detail::TransportFlagsBase::any();
|
||||
}
|
||||
|
||||
bool all() const {
|
||||
return detail::TransportFlagsBase::all();
|
||||
}
|
||||
|
||||
size_t count() const {
|
||||
return detail::TransportFlagsBase::count();
|
||||
}
|
||||
|
||||
TransportFlags& operator|=(TransportFlags other) {
|
||||
detail::TransportFlagsBase::operator|=(other);
|
||||
return *this;
|
||||
}
|
||||
|
||||
TransportFlags operator|(TransportFlags other) const {
|
||||
return TransportFlags(*this) |= other;
|
||||
}
|
||||
|
||||
TransportFlags operator|(Transport transport) const {
|
||||
return *this | TransportFlags(transport);
|
||||
}
|
||||
|
||||
TransportFlags& operator&=(TransportFlags other) {
|
||||
detail::TransportFlagsBase::operator&=(other);
|
||||
return *this;
|
||||
}
|
||||
|
||||
TransportFlags operator&(TransportFlags other) const {
|
||||
return TransportFlags(*this) &= other;
|
||||
}
|
||||
|
||||
TransportFlags operator&(Transport transport) const {
|
||||
return *this & TransportFlags(transport);
|
||||
}
|
||||
|
||||
TransportFlags& operator^=(TransportFlags other) {
|
||||
detail::TransportFlagsBase::operator^=(other);
|
||||
return *this;
|
||||
}
|
||||
|
||||
TransportFlags operator^(TransportFlags other) const {
|
||||
return TransportFlags(*this) ^= other;
|
||||
}
|
||||
|
||||
TransportFlags operator^(Transport transport) const {
|
||||
return *this ^ TransportFlags(transport);
|
||||
}
|
||||
|
||||
TransportFlags operator~() const {
|
||||
return TransportFlags(*this).flip();
|
||||
}
|
||||
|
||||
bool operator==(TransportFlags other) const {
|
||||
return detail::TransportFlagsBase::operator==(other);
|
||||
}
|
||||
|
||||
bool operator!=(TransportFlags other) const {
|
||||
return detail::TransportFlagsBase::operator!=(other);
|
||||
}
|
||||
|
||||
detail::TransportFlagsBase toBitset() const {
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
TransportFlags(detail::TransportFlagsBase bitset) : detail::TransportFlagsBase(bitset) {}
|
||||
};
|
||||
|
||||
/// Builds the set holding both transports, so `Transport::A | Transport::B` works
/// without spelling out TransportFlags.
inline TransportFlags operator|(Transport transport1, Transport transport2) {
  const TransportFlags lhs(transport1);
  return lhs | transport2;
}
|
||||
|
||||
/// Intersects the single-transport sets built from the two arguments.
inline TransportFlags operator&(Transport transport1, Transport transport2) {
  const TransportFlags lhs(transport1);
  return lhs & transport2;
}
|
||||
|
||||
/// Symmetric difference of the single-transport sets built from the two arguments.
inline TransportFlags operator^(Transport transport1, Transport transport2) {
  const TransportFlags lhs(transport1);
  return lhs ^ transport2;
}
|
||||
|
||||
/// The empty transport set.
const TransportFlags NoTransports{};

/// Every InfiniBand transport, IB0 through IB7.
const TransportFlags AllIBTransports = Transport::IB0 | Transport::IB1 | Transport::IB2 | Transport::IB3 |
                                       Transport::IB4 | Transport::IB5 | Transport::IB6 | Transport::IB7;

/// All InfiniBand transports plus CUDA IPC.
const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc;
|
||||
|
||||
int getIBDeviceCount();
|
||||
std::string getIBDeviceName(TransportFlags ibTransport);
|
||||
TransportFlags getIBTransportByDeviceName(const std::string& ibDeviceName);
|
||||
std::string getIBDeviceName(Transport ibTransport);
|
||||
Transport getIBTransportByDeviceName(const std::string& ibDeviceName);
|
||||
|
||||
class Communicator;
|
||||
class Connection;
|
||||
@@ -111,9 +217,9 @@ public:
|
||||
|
||||
virtual void flush() = 0;
|
||||
|
||||
virtual TransportFlags transport() = 0;
|
||||
virtual Transport transport() = 0;
|
||||
|
||||
virtual TransportFlags remoteTransport() = 0;
|
||||
virtual Transport remoteTransport() = 0;
|
||||
|
||||
protected:
|
||||
static std::shared_ptr<RegisteredMemory::Impl> getRegisteredMemoryImpl(RegisteredMemory&);
|
||||
@@ -166,7 +272,7 @@ public:
|
||||
* transportType: the type of transport to be used (mscclppTransportP2P or mscclppTransportIB)
|
||||
* ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P.
|
||||
*/
|
||||
std::shared_ptr<Connection> connect(int remoteRank, int tag, TransportFlags transport);
|
||||
std::shared_ptr<Connection> connect(int remoteRank, int tag, Transport transport);
|
||||
|
||||
/* Establish all connections declared by connect(). This function must be called after all connect()
|
||||
* calls are made. This function ensures that all remote ranks are ready to communicate when it returns.
|
||||
@@ -180,4 +286,13 @@ private:
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
namespace std {
|
||||
template <>
|
||||
struct hash<mscclpp::TransportFlags> {
|
||||
size_t operator()(const mscclpp::TransportFlags& flags) const {
|
||||
return hash<mscclpp::detail::TransportFlagsBase>()(flags.toBitset());
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#endif // MSCCLPP_H_
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
namespace mscclpp {
|
||||
|
||||
struct TransportInfo {
|
||||
TransportFlags transport;
|
||||
Transport transport;
|
||||
|
||||
// TODO: rewrite this using std::variant or something
|
||||
bool ibLocal;
|
||||
@@ -31,7 +31,7 @@ struct RegisteredMemory::Impl {
|
||||
Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl);
|
||||
Impl(const std::vector<char>& data);
|
||||
|
||||
TransportInfo& getTransportInfo(TransportFlags transport) {
|
||||
TransportInfo& getTransportInfo(Transport transport) {
|
||||
for (auto& entry : transportInfos) {
|
||||
if (entry.transport == transport) {
|
||||
return entry;
|
||||
|
||||
@@ -5,17 +5,17 @@
|
||||
namespace mscclpp {
|
||||
|
||||
RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl) : data(data), size(size), rank(rank), transports(transports) {
|
||||
if (transports & TransportCudaIpc) {
|
||||
if (transports.has(Transport::CudaIpc)) {
|
||||
TransportInfo transportInfo;
|
||||
transportInfo.transport = TransportCudaIpc;
|
||||
transportInfo.transport = Transport::CudaIpc;
|
||||
cudaIpcMemHandle_t handle;
|
||||
// TODO: translate data to a base pointer
|
||||
CUDATHROW(cudaIpcGetMemHandle(&handle, data));
|
||||
transportInfo.cudaIpcHandle = handle;
|
||||
this->transportInfos.push_back(transportInfo);
|
||||
}
|
||||
if (transports & TransportAllIB) {
|
||||
auto addIb = [&](TransportFlags ibTransport) {
|
||||
if ((transports & AllIBTransports).any()) {
|
||||
auto addIb = [&](Transport ibTransport) {
|
||||
TransportInfo transportInfo;
|
||||
transportInfo.transport = ibTransport;
|
||||
const IbMr* mr = commImpl.getIbContext(ibTransport)->registerMr(data, size);
|
||||
@@ -23,14 +23,14 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t
|
||||
transportInfo.ibLocal = true;
|
||||
this->transportInfos.push_back(transportInfo);
|
||||
};
|
||||
if (transports & TransportIB0) addIb(TransportIB0);
|
||||
if (transports & TransportIB1) addIb(TransportIB1);
|
||||
if (transports & TransportIB2) addIb(TransportIB2);
|
||||
if (transports & TransportIB3) addIb(TransportIB3);
|
||||
if (transports & TransportIB4) addIb(TransportIB4);
|
||||
if (transports & TransportIB5) addIb(TransportIB5);
|
||||
if (transports & TransportIB6) addIb(TransportIB6);
|
||||
if (transports & TransportIB7) addIb(TransportIB7);
|
||||
if (transports.has(Transport::IB0)) addIb(Transport::IB0);
|
||||
if (transports.has(Transport::IB1)) addIb(Transport::IB1);
|
||||
if (transports.has(Transport::IB2)) addIb(Transport::IB2);
|
||||
if (transports.has(Transport::IB3)) addIb(Transport::IB3);
|
||||
if (transports.has(Transport::IB4)) addIb(Transport::IB4);
|
||||
if (transports.has(Transport::IB5)) addIb(Transport::IB5);
|
||||
if (transports.has(Transport::IB6)) addIb(Transport::IB6);
|
||||
if (transports.has(Transport::IB7)) addIb(Transport::IB7);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,9 +66,9 @@ std::vector<char> RegisteredMemory::serialize() {
|
||||
std::copy_n(reinterpret_cast<char*>(&transportCount), sizeof(transportCount), std::back_inserter(result));
|
||||
for (auto& entry : pimpl->transportInfos) {
|
||||
std::copy_n(reinterpret_cast<char*>(&entry.transport), sizeof(entry.transport), std::back_inserter(result));
|
||||
if (entry.transport == TransportCudaIpc) {
|
||||
if (entry.transport == Transport::CudaIpc) {
|
||||
std::copy_n(reinterpret_cast<char*>(&entry.cudaIpcHandle), sizeof(entry.cudaIpcHandle), std::back_inserter(result));
|
||||
} else if (entry.transport & TransportAllIB) {
|
||||
} else if (AllIBTransports.has(entry.transport)) {
|
||||
std::copy_n(reinterpret_cast<char*>(&entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result));
|
||||
} else {
|
||||
throw std::runtime_error("Unknown transport");
|
||||
@@ -96,12 +96,12 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
|
||||
TransportInfo transportInfo;
|
||||
std::copy_n(it, sizeof(transportInfo.transport), reinterpret_cast<char*>(&transportInfo.transport));
|
||||
it += sizeof(transportInfo.transport);
|
||||
if (transportInfo.transport & TransportCudaIpc) {
|
||||
if (transportInfo.transport == Transport::CudaIpc) {
|
||||
cudaIpcMemHandle_t handle;
|
||||
std::copy_n(it, sizeof(handle), reinterpret_cast<char*>(&handle));
|
||||
it += sizeof(handle);
|
||||
transportInfo.cudaIpcHandle = handle;
|
||||
} else if (transportInfo.transport & TransportAllIB) {
|
||||
} else if (AllIBTransports.has(transportInfo.transport)) {
|
||||
IbMrInfo info;
|
||||
std::copy_n(it, sizeof(info), reinterpret_cast<char*>(&info));
|
||||
it += sizeof(info);
|
||||
@@ -116,8 +116,8 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
|
||||
throw std::runtime_error("Deserialization failed");
|
||||
}
|
||||
|
||||
if (transports & TransportCudaIpc) {
|
||||
auto entry = getTransportInfo(TransportCudaIpc);
|
||||
if (transports.has(Transport::CudaIpc)) {
|
||||
auto entry = getTransportInfo(Transport::CudaIpc);
|
||||
CUDATHROW(cudaIpcOpenMemHandle(&data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,16 +5,16 @@
|
||||
#include <iostream>
#include <stdexcept>
#include <mpi.h>
|
||||
|
||||
mscclpp::TransportFlags findIb(int localRank){
|
||||
mscclpp::TransportFlags IBs[] = {
|
||||
mscclpp::TransportIB0,
|
||||
mscclpp::TransportIB1,
|
||||
mscclpp::TransportIB2,
|
||||
mscclpp::TransportIB3,
|
||||
mscclpp::TransportIB4,
|
||||
mscclpp::TransportIB5,
|
||||
mscclpp::TransportIB6,
|
||||
mscclpp::TransportIB7
|
||||
// Maps a node-local rank to the IB transport with the same index
// (0 -> IB0, 1 -> IB1, ..., 7 -> IB7).
//
// Throws std::out_of_range when localRank is negative or has no matching entry,
// instead of indexing past the end of the table.
mscclpp::Transport findIb(int localRank) {
  static const mscclpp::Transport IBs[] = {
    mscclpp::Transport::IB0,
    mscclpp::Transport::IB1,
    mscclpp::Transport::IB2,
    mscclpp::Transport::IB3,
    mscclpp::Transport::IB4,
    mscclpp::Transport::IB5,
    mscclpp::Transport::IB6,
    mscclpp::Transport::IB7
  };
  // Derive the bound from the table so it cannot go stale if entries are added.
  const int numIbs = static_cast<int>(sizeof(IBs) / sizeof(IBs[0]));
  if (localRank < 0 || localRank >= numIbs) {
    throw std::out_of_range("findIb: localRank has no matching IB transport");
  }
  return IBs[localRank];
}
|
||||
@@ -31,8 +31,7 @@ void test_communicator(int rank, int worldSize, int nranksPerNode){
|
||||
for (int i = 0; i < worldSize; i++){
|
||||
if (i != rank){
|
||||
if (i / nranksPerNode == rank / nranksPerNode){
|
||||
printf("i %d rank %d nranksPerNode %d\n", i, rank, nranksPerNode);
|
||||
auto connect = communicator->connect(i, 0, mscclpp::TransportCudaIpc);
|
||||
auto connect = communicator->connect(i, 0, mscclpp::Transport::CudaIpc);
|
||||
} else {
|
||||
auto connect = communicator->connect(i, 0, findIb(rank % nranksPerNode));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user