tests for host hash

This commit is contained in:
Saeed Maleki
2023-04-27 20:09:47 +00:00
9 changed files with 207 additions and 114 deletions

View File

@@ -26,7 +26,7 @@ Communicator::Impl::~Impl() {
ibContexts.clear();
}
IbCtx* Communicator::Impl::getIbContext(TransportFlags ibTransport) {
IbCtx* Communicator::Impl::getIbContext(Transport ibTransport) {
// Find IB context or create it
auto it = ibContexts.find(ibTransport);
if (it == ibContexts.end()) {
@@ -40,24 +40,6 @@ IbCtx* Communicator::Impl::getIbContext(TransportFlags ibTransport) {
MSCCLPP_API_CPP Communicator::~Communicator() = default;
// Converts one transport flag into its C-style mscclppTransport_t counterpart.
// TransportCudaIpc becomes mscclppTransportP2P, any of TransportIB0..TransportIB7
// becomes mscclppTransportIB, and every other value raises std::runtime_error.
static mscclppTransport_t transportToCStyle(TransportFlags flags) {
  if (flags == TransportCudaIpc) {
    return mscclppTransportP2P;
  }
  const bool isSingleIb = flags == TransportIB0 || flags == TransportIB1 || flags == TransportIB2 ||
                          flags == TransportIB3 || flags == TransportIB4 || flags == TransportIB5 ||
                          flags == TransportIB6 || flags == TransportIB7;
  if (isSingleIb) {
    return mscclppTransportIB;
  }
  throw std::runtime_error("Unsupported conversion");
}
// Creates a communicator whose implementation object is built from `bootstrap`.
MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr<BaseBootstrap> bootstrap)
    : pimpl(std::make_unique<Impl>(bootstrap)) {}
MSCCLPP_API_CPP void Communicator::bootstrapAllGather(void* data, int size) {
@@ -72,20 +54,19 @@ RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportF
return RegisteredMemory(std::make_shared<RegisteredMemory::Impl>(ptr, size, pimpl->comm->rank, transports, *pimpl));
}
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connect(int remoteRank, int tag, TransportFlags transport) {
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connect(int remoteRank, int tag, Transport transport) {
std::shared_ptr<ConnectionBase> conn;
if (transport | TransportCudaIpc) {
if (transport == Transport::CudaIpc) {
// sanity check: make sure the IPC connection is being made within a node
if (pimpl->rankToHash_[remoteRank] != pimpl->rankToHash_[pimpl->bootstrap_->getRank()]) {
std::stringstream ss;
ss << "Cuda IPC connection can only be made within a node: " << remoteRank << " != " << pimpl->bootstrap_->getRank();
ss << "Cuda IPC connection can only be made within a node: " << remoteRank << "(" << std::hex << pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")" << " != "
<< pimpl->bootstrap_->getRank() << "(" << std::hex << pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")";
throw std::runtime_error(ss.str());
}
}
auto cudaIpcConn = std::make_shared<CudaIpcConnection>();
conn = cudaIpcConn;
INFO(MSCCLPP_INIT, "Cuda IPC connection between %d(%lx) and %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()],
remoteRank, pimpl->rankToHash_[remoteRank]);
} else if (transport | TransportAllIB) {
} else if (AllIBTransports.has(transport)) {
auto ibConn = std::make_shared<IBConnection>(remoteRank, tag, transport, *pimpl);
conn = ibConn;
INFO(MSCCLPP_INIT, "IB connection between %d(%lx) via %s and %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()],

View File

@@ -6,8 +6,8 @@
namespace mscclpp {
void validateTransport(RegisteredMemory mem, TransportFlags transport) {
if ((mem.transports() & transport) == TransportNone) {
// Throws std::runtime_error unless `transport` is one of the transports reported by `mem`.
void validateTransport(RegisteredMemory mem, Transport transport) {
  const bool supported = mem.transports().has(transport);
  if (supported) {
    return;
  }
  throw std::runtime_error("mem does not support transport");
}
@@ -28,12 +28,12 @@ CudaIpcConnection::~CudaIpcConnection() {
cudaStreamDestroy(stream);
}
TransportFlags CudaIpcConnection::transport() {
return TransportCudaIpc;
// Transport of this connection: always Transport::CudaIpc.
Transport CudaIpcConnection::transport() { return Transport::CudaIpc; }
TransportFlags CudaIpcConnection::remoteTransport() {
return TransportCudaIpc;
// Transport reported for the remote side: always Transport::CudaIpc.
Transport CudaIpcConnection::remoteTransport() { return Transport::CudaIpc; }
void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) {
@@ -54,7 +54,7 @@ void CudaIpcConnection::flush() {
// IBConnection
IBConnection::IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl) : remoteRank_(remoteRank), tag_(tag), transport_(transport), remoteTransport_(TransportNone) {
// Records the remote rank, tag and local transport, starts with the remote transport set to
// Transport::Unknown, and creates a QP from the IB context that `commImpl` returns for `transport`.
IBConnection::IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl)
    : remoteRank_(remoteRank),
      tag_(tag),
      transport_(transport),
      remoteTransport_(Transport::Unknown) {
  IbCtx* ibCtx = commImpl.getIbContext(transport);
  qp = ibCtx->createQp();
}
@@ -62,11 +62,11 @@ IBConnection::~IBConnection() {
// TODO: Destroy QP?
}
TransportFlags IBConnection::transport() {
// Returns the transport this connection was constructed with.
Transport IBConnection::transport() { return transport_; }
TransportFlags IBConnection::remoteTransport() {
// Returns the stored remote transport (initialised to Transport::Unknown by the constructor).
Transport IBConnection::remoteTransport() { return remoteTransport_; }
@@ -115,13 +115,11 @@ void IBConnection::flush() {
}
// Sends this connection's QP info to the remote rank through the bootstrap, using the connection's tag.
void IBConnection::startSetup(std::shared_ptr<BaseBootstrap> bootstrap) {
  auto& localQpInfo = qp->getInfo();
  bootstrap->send(&localQpInfo, sizeof(localQpInfo), remoteRank_, tag_);
}
void IBConnection::endSetup(std::shared_ptr<BaseBootstrap> bootstrap) {
IbQpInfo qpInfo;
// TODO(chhwang): temporarily disabled to compile
bootstrap->recv(&qpInfo, sizeof(qpInfo), remoteRank_, tag_);
qp->rtr(qpInfo);
qp->rts();

View File

@@ -368,33 +368,33 @@ int getIBDeviceCount() {
return num;
}
std::string getIBDeviceName(TransportFlags ibTransport) {
std::string getIBDeviceName(Transport ibTransport) {
int num;
struct ibv_device** devices = ibv_get_device_list(&num);
int ibTransportIndex;
switch (ibTransport) { // TODO: get rid of this ugly switch
case TransportIB0:
case Transport::IB0:
ibTransportIndex = 0;
break;
case TransportIB1:
case Transport::IB1:
ibTransportIndex = 1;
break;
case TransportIB2:
case Transport::IB2:
ibTransportIndex = 2;
break;
case TransportIB3:
case Transport::IB3:
ibTransportIndex = 3;
break;
case TransportIB4:
case Transport::IB4:
ibTransportIndex = 4;
break;
case TransportIB5:
case Transport::IB5:
ibTransportIndex = 5;
break;
case TransportIB6:
case Transport::IB6:
ibTransportIndex = 6;
break;
case TransportIB7:
case Transport::IB7:
ibTransportIndex = 7;
break;
default:
@@ -406,28 +406,28 @@ std::string getIBDeviceName(TransportFlags ibTransport) {
return devices[ibTransportIndex]->name;
}
TransportFlags getIBTransportByDeviceName(const std::string& ibDeviceName) {
Transport getIBTransportByDeviceName(const std::string& ibDeviceName) {
int num;
struct ibv_device** devices = ibv_get_device_list(&num);
for (int i = 0; i < num; ++i) {
if (ibDeviceName == devices[i]->name) {
switch (i) { // TODO: get rid of this ugly switch
case 0:
return TransportIB0;
return Transport::IB0;
case 1:
return TransportIB1;
return Transport::IB1;
case 2:
return TransportIB2;
return Transport::IB2;
case 3:
return TransportIB3;
return Transport::IB3;
case 4:
return TransportIB4;
return Transport::IB4;
case 5:
return TransportIB5;
return Transport::IB5;
case 6:
return TransportIB6;
return Transport::IB6;
case 7:
return TransportIB7;
return Transport::IB7;
default:
throw std::runtime_error("IB device index out of range");
}

View File

@@ -16,7 +16,7 @@ class ConnectionBase;
struct Communicator::Impl {
mscclppComm_t comm;
std::vector<std::shared_ptr<ConnectionBase>> connections;
std::unordered_map<TransportFlags, std::unique_ptr<IbCtx>> ibContexts;
std::unordered_map<Transport, std::unique_ptr<IbCtx>> ibContexts;
std::shared_ptr<BaseBootstrap> bootstrap_;
std::vector<uint64_t> rankToHash_;
@@ -24,7 +24,7 @@ struct Communicator::Impl {
~Impl();
IbCtx* getIbContext(TransportFlags ibTransport);
IbCtx* getIbContext(Transport ibTransport);
};
} // namespace mscclpp

View File

@@ -24,9 +24,9 @@ public:
~CudaIpcConnection();
TransportFlags transport() override;
Transport transport() override;
TransportFlags remoteTransport() override;
Transport remoteTransport() override;
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override;
@@ -36,18 +36,18 @@ public:
class IBConnection : public ConnectionBase {
int remoteRank_;
int tag_;
TransportFlags transport_;
TransportFlags remoteTransport_;
Transport transport_;
Transport remoteTransport_;
IbQp* qp;
public:
IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl);
IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl);
~IBConnection();
TransportFlags transport() override;
Transport transport() override;
TransportFlags remoteTransport() override;
Transport remoteTransport() override;
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override;

View File

@@ -9,6 +9,7 @@
#include <vector>
#include <memory>
#include <string>
#include <bitset>
namespace mscclpp {
@@ -63,24 +64,129 @@ private:
*/
std::unique_ptr<UniqueId> getUniqueId();
using TransportFlags = uint32_t;
const TransportFlags TransportNone = 0b0;
const TransportFlags TransportCudaIpc = 0b1;
const TransportFlags TransportIB0 = 0b10;
const TransportFlags TransportIB1 = 0b100;
const TransportFlags TransportIB2 = 0b1000;
const TransportFlags TransportIB3 = 0b10000;
const TransportFlags TransportIB4 = 0b100000;
const TransportFlags TransportIB5 = 0b1000000;
const TransportFlags TransportIB6 = 0b10000000;
const TransportFlags TransportIB7 = 0b100000000;
// Identifies a single transport. The enumerators are consecutive starting at 0, so
// NumTransports (the last one) equals the number of enumerators that precede it.
enum class Transport {
  Unknown = 0,
  CudaIpc,
  IB0, IB1, IB2, IB3,
  IB4, IB5, IB6, IB7,
  NumTransports  // count sentinel, keep last
};
const TransportFlags TransportAll = 0b111111111;
const TransportFlags TransportAllIB = 0b111111110;
namespace detail {
// Number of bits in the flag set; the static_assert ties it to Transport::NumTransports.
constexpr size_t TransportFlagsSize = 10;
static_assert(static_cast<size_t>(Transport::NumTransports) == TransportFlagsSize, "TransportFlagsSize must match the number of transports");
// Storage type that TransportFlags is built on.
using TransportFlagsBase = std::bitset<TransportFlagsSize>;
}  // namespace detail
/// A set of Transport values, stored as one bit per Transport enumerator.
///
/// The bitset base is inherited privately, so only the operations declared
/// below are available to users of the class.
class TransportFlags : private detail::TransportFlagsBase {
 public:
  /// Creates an empty set.
  TransportFlags() = default;

  /// Creates a set holding exactly `transport`. Deliberately implicit so that a
  /// single Transport can be used wherever a TransportFlags is expected.
  /// The shift is done on an unsigned 64-bit value, matching the bitset constructor's
  /// parameter type and avoiding a signed shift.
  TransportFlags(Transport transport) : detail::TransportFlagsBase(1ULL << static_cast<size_t>(transport)) {}

  /// Returns true when `transport` is in the set.
  /// Uses std::bitset::test, which throws std::out_of_range for a position that
  /// is not a valid bit index (e.g. Transport::NumTransports).
  bool has(Transport transport) const {
    return detail::TransportFlagsBase::test(static_cast<size_t>(transport));
  }

  /// Returns true when no bit is set.
  bool none() const {
    return detail::TransportFlagsBase::none();
  }

  /// Returns true when at least one bit is set.
  bool any() const {
    return detail::TransportFlagsBase::any();
  }

  /// Returns true when every bit is set (this includes the Transport::Unknown bit).
  bool all() const {
    return detail::TransportFlagsBase::all();
  }

  /// Returns the number of bits that are set.
  size_t count() const {
    return detail::TransportFlagsBase::count();
  }

  /// In-place union with `other`.
  TransportFlags& operator|=(TransportFlags other) {
    detail::TransportFlagsBase::operator|=(other);
    return *this;
  }

  /// Union of this set and `other`.
  TransportFlags operator|(TransportFlags other) const {
    return TransportFlags(*this) |= other;
  }

  /// Union of this set and the single `transport`.
  TransportFlags operator|(Transport transport) const {
    return *this | TransportFlags(transport);
  }

  /// In-place intersection with `other`.
  TransportFlags& operator&=(TransportFlags other) {
    detail::TransportFlagsBase::operator&=(other);
    return *this;
  }

  /// Intersection of this set and `other`.
  TransportFlags operator&(TransportFlags other) const {
    return TransportFlags(*this) &= other;
  }

  /// Intersection of this set and the single `transport`.
  TransportFlags operator&(Transport transport) const {
    return *this & TransportFlags(transport);
  }

  /// In-place symmetric difference with `other`.
  TransportFlags& operator^=(TransportFlags other) {
    detail::TransportFlagsBase::operator^=(other);
    return *this;
  }

  /// Symmetric difference of this set and `other`.
  TransportFlags operator^(TransportFlags other) const {
    return TransportFlags(*this) ^= other;
  }

  /// Symmetric difference of this set and the single `transport`.
  TransportFlags operator^(Transport transport) const {
    return *this ^ TransportFlags(transport);
  }

  /// Returns a copy with every bit inverted.
  TransportFlags operator~() const {
    TransportFlags inverted(*this);
    inverted.flip();
    return inverted;
  }

  /// Returns true when both sets hold exactly the same bits.
  bool operator==(TransportFlags other) const {
    return detail::TransportFlagsBase::operator==(other);
  }

  /// Returns true when the sets differ.
  bool operator!=(TransportFlags other) const {
    // Expressed through operator== because std::bitset no longer has a member
    // operator!= in C++20, so calling it explicitly does not compile there.
    return !(*this == other);
  }

  /// Returns a copy of the underlying bitset.
  detail::TransportFlagsBase toBitset() const {
    return *this;
  }

 private:
  /// Wraps a raw bitset; used internally to convert base-class results back into TransportFlags.
  TransportFlags(detail::TransportFlagsBase bitset) : detail::TransportFlagsBase(bitset) {}
};
/// Builds the set containing both `transport1` and `transport2`.
inline TransportFlags operator|(Transport transport1, Transport transport2) {
  TransportFlags combined(transport1);
  combined |= transport2;
  return combined;
}
/// Intersects the single-element sets of `transport1` and `transport2`.
inline TransportFlags operator&(Transport transport1, Transport transport2) {
  TransportFlags common(transport1);
  common &= transport2;
  return common;
}
/// Symmetric difference of the single-element sets of `transport1` and `transport2`.
inline TransportFlags operator^(Transport transport1, Transport transport2) {
  TransportFlags difference(transport1);
  difference ^= transport2;
  return difference;
}
// The empty transport set.
const TransportFlags NoTransports{};
// Every IB transport (IB0 through IB7).
const TransportFlags AllIBTransports = Transport::IB0 | Transport::IB1 | Transport::IB2 | Transport::IB3 |
                                       Transport::IB4 | Transport::IB5 | Transport::IB6 | Transport::IB7;
// Every IB transport plus CUDA IPC (Transport::Unknown is not included).
const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc;
int getIBDeviceCount();
std::string getIBDeviceName(TransportFlags ibTransport);
TransportFlags getIBTransportByDeviceName(const std::string& ibDeviceName);
std::string getIBDeviceName(Transport ibTransport);
Transport getIBTransportByDeviceName(const std::string& ibDeviceName);
class Communicator;
class Connection;
@@ -111,9 +217,9 @@ public:
virtual void flush() = 0;
virtual TransportFlags transport() = 0;
virtual Transport transport() = 0;
virtual TransportFlags remoteTransport() = 0;
virtual Transport remoteTransport() = 0;
protected:
static std::shared_ptr<RegisteredMemory::Impl> getRegisteredMemoryImpl(RegisteredMemory&);
@@ -166,7 +272,7 @@ public:
* transport: the transport to be used (Transport::CudaIpc or one of Transport::IB0 ... Transport::IB7)
* The IB device is implied by the chosen IB transport (see getIBDeviceName()); no separate device name is passed.
*/
std::shared_ptr<Connection> connect(int remoteRank, int tag, TransportFlags transport);
std::shared_ptr<Connection> connect(int remoteRank, int tag, Transport transport);
/* Establish all connections declared by connect(). This function must be called after all connect()
* calls are made. This function ensures that all remote ranks are ready to communicate when it returns.
@@ -180,4 +286,13 @@ private:
} // namespace mscclpp
namespace std {
/// std::hash support for mscclpp::TransportFlags: hashes the underlying bitset.
template <>
struct hash<mscclpp::TransportFlags> {
  size_t operator()(const mscclpp::TransportFlags& flags) const {
    const mscclpp::detail::TransportFlagsBase bits = flags.toBitset();
    return hash<mscclpp::detail::TransportFlagsBase>{}(bits);
  }
};
}  // namespace std
#endif // MSCCLPP_H_

View File

@@ -10,7 +10,7 @@
namespace mscclpp {
struct TransportInfo {
TransportFlags transport;
Transport transport;
// TODO: rewrite this using std::variant or something
bool ibLocal;
@@ -31,7 +31,7 @@ struct RegisteredMemory::Impl {
Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl);
Impl(const std::vector<char>& data);
TransportInfo& getTransportInfo(TransportFlags transport) {
TransportInfo& getTransportInfo(Transport transport) {
for (auto& entry : transportInfos) {
if (entry.transport == transport) {
return entry;

View File

@@ -5,17 +5,17 @@
namespace mscclpp {
RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl) : data(data), size(size), rank(rank), transports(transports) {
if (transports & TransportCudaIpc) {
if (transports.has(Transport::CudaIpc)) {
TransportInfo transportInfo;
transportInfo.transport = TransportCudaIpc;
transportInfo.transport = Transport::CudaIpc;
cudaIpcMemHandle_t handle;
// TODO: translate data to a base pointer
CUDATHROW(cudaIpcGetMemHandle(&handle, data));
transportInfo.cudaIpcHandle = handle;
this->transportInfos.push_back(transportInfo);
}
if (transports & TransportAllIB) {
auto addIb = [&](TransportFlags ibTransport) {
if ((transports & AllIBTransports).any()) {
auto addIb = [&](Transport ibTransport) {
TransportInfo transportInfo;
transportInfo.transport = ibTransport;
const IbMr* mr = commImpl.getIbContext(ibTransport)->registerMr(data, size);
@@ -23,14 +23,14 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t
transportInfo.ibLocal = true;
this->transportInfos.push_back(transportInfo);
};
if (transports & TransportIB0) addIb(TransportIB0);
if (transports & TransportIB1) addIb(TransportIB1);
if (transports & TransportIB2) addIb(TransportIB2);
if (transports & TransportIB3) addIb(TransportIB3);
if (transports & TransportIB4) addIb(TransportIB4);
if (transports & TransportIB5) addIb(TransportIB5);
if (transports & TransportIB6) addIb(TransportIB6);
if (transports & TransportIB7) addIb(TransportIB7);
if (transports.has(Transport::IB0)) addIb(Transport::IB0);
if (transports.has(Transport::IB1)) addIb(Transport::IB1);
if (transports.has(Transport::IB2)) addIb(Transport::IB2);
if (transports.has(Transport::IB3)) addIb(Transport::IB3);
if (transports.has(Transport::IB4)) addIb(Transport::IB4);
if (transports.has(Transport::IB5)) addIb(Transport::IB5);
if (transports.has(Transport::IB6)) addIb(Transport::IB6);
if (transports.has(Transport::IB7)) addIb(Transport::IB7);
}
}
@@ -66,9 +66,9 @@ std::vector<char> RegisteredMemory::serialize() {
std::copy_n(reinterpret_cast<char*>(&transportCount), sizeof(transportCount), std::back_inserter(result));
for (auto& entry : pimpl->transportInfos) {
std::copy_n(reinterpret_cast<char*>(&entry.transport), sizeof(entry.transport), std::back_inserter(result));
if (entry.transport == TransportCudaIpc) {
if (entry.transport == Transport::CudaIpc) {
std::copy_n(reinterpret_cast<char*>(&entry.cudaIpcHandle), sizeof(entry.cudaIpcHandle), std::back_inserter(result));
} else if (entry.transport & TransportAllIB) {
} else if (AllIBTransports.has(entry.transport)) {
std::copy_n(reinterpret_cast<char*>(&entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result));
} else {
throw std::runtime_error("Unknown transport");
@@ -96,12 +96,12 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
TransportInfo transportInfo;
std::copy_n(it, sizeof(transportInfo.transport), reinterpret_cast<char*>(&transportInfo.transport));
it += sizeof(transportInfo.transport);
if (transportInfo.transport & TransportCudaIpc) {
if (transportInfo.transport == Transport::CudaIpc) {
cudaIpcMemHandle_t handle;
std::copy_n(it, sizeof(handle), reinterpret_cast<char*>(&handle));
it += sizeof(handle);
transportInfo.cudaIpcHandle = handle;
} else if (transportInfo.transport & TransportAllIB) {
} else if (AllIBTransports.has(transportInfo.transport)) {
IbMrInfo info;
std::copy_n(it, sizeof(info), reinterpret_cast<char*>(&info));
it += sizeof(info);
@@ -116,8 +116,8 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
throw std::runtime_error("Deserialization failed");
}
if (transports & TransportCudaIpc) {
auto entry = getTransportInfo(TransportCudaIpc);
if (transports.has(Transport::CudaIpc)) {
auto entry = getTransportInfo(Transport::CudaIpc);
CUDATHROW(cudaIpcOpenMemHandle(&data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess));
}
}

View File

@@ -5,16 +5,16 @@
#include <iostream>
#include <mpi.h>
#include <stdexcept>
mscclpp::TransportFlags findIb(int localRank){
mscclpp::TransportFlags IBs[] = {
mscclpp::TransportIB0,
mscclpp::TransportIB1,
mscclpp::TransportIB2,
mscclpp::TransportIB3,
mscclpp::TransportIB4,
mscclpp::TransportIB5,
mscclpp::TransportIB6,
mscclpp::TransportIB7
// Maps a local rank to an IB transport: local rank i is given Transport::IBi.
// Throws std::out_of_range when localRank is negative or has no matching entry
// in the table, rather than indexing past the array.
mscclpp::Transport findIb(int localRank){
  constexpr mscclpp::Transport IBs[] = {
    mscclpp::Transport::IB0,
    mscclpp::Transport::IB1,
    mscclpp::Transport::IB2,
    mscclpp::Transport::IB3,
    mscclpp::Transport::IB4,
    mscclpp::Transport::IB5,
    mscclpp::Transport::IB6,
    mscclpp::Transport::IB7
  };
  constexpr int numIBs = static_cast<int>(sizeof(IBs) / sizeof(IBs[0]));
  if (localRank < 0 || localRank >= numIBs) {
    throw std::out_of_range("findIb: localRank is outside the range of available IB transports");
  }
  return IBs[localRank];
}
@@ -31,8 +31,7 @@ void test_communicator(int rank, int worldSize, int nranksPerNode){
for (int i = 0; i < worldSize; i++){
if (i != rank){
if (i / nranksPerNode == rank / nranksPerNode){
printf("i %d rank %d nranksPerNode %d\n", i, rank, nranksPerNode);
auto connect = communicator->connect(i, 0, mscclpp::TransportCudaIpc);
auto connect = communicator->connect(i, 0, mscclpp::Transport::CudaIpc);
} else {
auto connect = communicator->connect(i, 0, findIb(rank % nranksPerNode));
}