diff --git a/src/communicator.cc b/src/communicator.cc index 6f458fe5..726efbc8 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -26,7 +26,7 @@ Communicator::Impl::~Impl() { ibContexts.clear(); } -IbCtx* Communicator::Impl::getIbContext(TransportFlags ibTransport) { +IbCtx* Communicator::Impl::getIbContext(Transport ibTransport) { // Find IB context or create it auto it = ibContexts.find(ibTransport); if (it == ibContexts.end()) { @@ -40,24 +40,6 @@ IbCtx* Communicator::Impl::getIbContext(TransportFlags ibTransport) { MSCCLPP_API_CPP Communicator::~Communicator() = default; -static mscclppTransport_t transportToCStyle(TransportFlags flags) { - switch (flags) { - case TransportIB0: - case TransportIB1: - case TransportIB2: - case TransportIB3: - case TransportIB4: - case TransportIB5: - case TransportIB6: - case TransportIB7: - return mscclppTransportIB; - case TransportCudaIpc: - return mscclppTransportP2P; - default: - throw std::runtime_error("Unsupported conversion"); - } -} - MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr bootstrap) : pimpl(std::make_unique(bootstrap)) {} MSCCLPP_API_CPP void Communicator::bootstrapAllGather(void* data, int size) { @@ -72,20 +54,19 @@ RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportF return RegisteredMemory(std::make_shared(ptr, size, pimpl->comm->rank, transports, *pimpl)); } -MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, TransportFlags transport) { +MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, Transport transport) { std::shared_ptr conn; - if (transport | TransportCudaIpc) { + if (transport == Transport::CudaIpc) { // sanity check: make sure the IPC connection is being made within a node if (pimpl->rankToHash_[remoteRank] != pimpl->rankToHash_[pimpl->bootstrap_->getRank()]) { std::stringstream ss; - ss << "Cuda IPC connection can only be made within a node: " << remoteRank << " != " << pimpl->bootstrap_->getRank(); + ss << "Cuda IPC connection can only be made within a node: " << remoteRank << "(" << std::hex << pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")" << " != " + << pimpl->bootstrap_->getRank() << "(" << std::hex << pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")"; throw std::runtime_error(ss.str()); - } + } auto cudaIpcConn = std::make_shared(); conn = cudaIpcConn; - INFO(MSCCLPP_INIT, "Cuda IPC connection between %d(%lx) and %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], - remoteRank, pimpl->rankToHash_[remoteRank]); - } else if (transport | TransportAllIB) { + } else if (AllIBTransports.has(transport)) { auto ibConn = std::make_shared(remoteRank, tag, transport, *pimpl); conn = ibConn; INFO(MSCCLPP_INIT, "IB connection between %d(%lx) via %s and %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], diff --git a/src/connection.cc b/src/connection.cc index fc653c2a..031f63ec 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -6,8 +6,8 @@ namespace mscclpp { -void validateTransport(RegisteredMemory mem, TransportFlags transport) { - if ((mem.transports() & transport) == TransportNone) { +void validateTransport(RegisteredMemory mem, Transport transport) { + if (!mem.transports().has(transport)) { throw std::runtime_error("mem does not support transport"); } } @@ -28,12 +28,12 @@ CudaIpcConnection::~CudaIpcConnection() { cudaStreamDestroy(stream); } -TransportFlags CudaIpcConnection::transport() { - return TransportCudaIpc; +Transport CudaIpcConnection::transport() { + return Transport::CudaIpc; } -TransportFlags CudaIpcConnection::remoteTransport() { - return TransportCudaIpc; +Transport CudaIpcConnection::remoteTransport() { + return Transport::CudaIpc; } void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { @@ -54,7 +54,7 @@ void CudaIpcConnection::flush() { // IBConnection -IBConnection::IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl) : remoteRank_(remoteRank), tag_(tag), transport_(transport), remoteTransport_(TransportNone) { +IBConnection::IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl) : remoteRank_(remoteRank), tag_(tag), transport_(transport), remoteTransport_(Transport::Unknown) { qp = commImpl.getIbContext(transport)->createQp(); } @@ -62,11 +62,11 @@ IBConnection::~IBConnection() { // TODO: Destroy QP? } -TransportFlags IBConnection::transport() { +Transport IBConnection::transport() { return transport_; } -TransportFlags IBConnection::remoteTransport() { +Transport IBConnection::remoteTransport() { return remoteTransport_; } @@ -115,13 +115,11 @@ void IBConnection::flush() { } void IBConnection::startSetup(std::shared_ptr bootstrap) { - // TODO(chhwang): temporarily disabled to compile bootstrap->send(&qp->getInfo(), sizeof(qp->getInfo()), remoteRank_, tag_); } void IBConnection::endSetup(std::shared_ptr bootstrap) { IbQpInfo qpInfo; - // TODO(chhwang): temporarily disabled to compile bootstrap->recv(&qpInfo, sizeof(qpInfo), remoteRank_, tag_); qp->rtr(qpInfo); qp->rts(); diff --git a/src/ib.cc b/src/ib.cc index fe3334a3..88d14d8e 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -368,33 +368,33 @@ int getIBDeviceCount() { return num; } -std::string getIBDeviceName(TransportFlags ibTransport) { +std::string getIBDeviceName(Transport ibTransport) { int num; struct ibv_device** devices = ibv_get_device_list(&num); int ibTransportIndex; switch (ibTransport) { // TODO: get rid of this ugly switch - case TransportIB0: + case Transport::IB0: ibTransportIndex = 0; break; - case TransportIB1: + case Transport::IB1: ibTransportIndex = 1; break; - case TransportIB2: + case Transport::IB2: ibTransportIndex = 2; break; - case TransportIB3: + case Transport::IB3: ibTransportIndex = 3; break; - case TransportIB4: + case Transport::IB4: ibTransportIndex = 4; break; - case TransportIB5: + case Transport::IB5: ibTransportIndex = 5; break; - case TransportIB6: + case Transport::IB6: ibTransportIndex = 6; break; - case TransportIB7: + case Transport::IB7: ibTransportIndex = 7; break; default: @@ -406,28 +406,28 @@ std::string getIBDeviceName(TransportFlags ibTransport) { return devices[ibTransportIndex]->name; } -TransportFlags getIBTransportByDeviceName(const std::string& ibDeviceName) { +Transport getIBTransportByDeviceName(const std::string& ibDeviceName) { int num; struct ibv_device** devices = ibv_get_device_list(&num); for (int i = 0; i < num; ++i) { if (ibDeviceName == devices[i]->name) { switch (i) { // TODO: get rid of this ugly switch case 0: - return TransportIB0; + return Transport::IB0; case 1: - return TransportIB1; + return Transport::IB1; case 2: - return TransportIB2; + return Transport::IB2; case 3: - return TransportIB3; + return Transport::IB3; case 4: - return TransportIB4; + return Transport::IB4; case 5: - return TransportIB5; + return Transport::IB5; case 6: - return TransportIB6; + return Transport::IB6; case 7: - return TransportIB7; + return Transport::IB7; default: throw std::runtime_error("IB device index out of range"); } diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index 5be00a67..e8e274b9 100644 --- a/src/include/communicator.hpp +++ b/src/include/communicator.hpp @@ -16,7 +16,7 @@ class ConnectionBase; struct Communicator::Impl { mscclppComm_t comm; std::vector> connections; - std::unordered_map> ibContexts; + std::unordered_map> ibContexts; std::shared_ptr bootstrap_; std::vector rankToHash_; @@ -24,7 +24,7 @@ struct Communicator::Impl { ~Impl(); - IbCtx* getIbContext(TransportFlags ibTransport); + IbCtx* getIbContext(Transport ibTransport); }; } // namespace mscclpp diff --git a/src/include/connection.hpp b/src/include/connection.hpp index 132726f7..bd08802c 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -24,9 +24,9 @@ public: ~CudaIpcConnection(); - TransportFlags transport() override; + Transport transport() override; - TransportFlags remoteTransport() override; + Transport remoteTransport() override; void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override; @@ -36,18 +36,18 @@ public: class IBConnection : public ConnectionBase { int remoteRank_; int tag_; - TransportFlags transport_; - TransportFlags remoteTransport_; + Transport transport_; + Transport remoteTransport_; IbQp* qp; public: - IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl); + IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl); ~IBConnection(); - TransportFlags transport() override; + Transport transport() override; - TransportFlags remoteTransport() override; + Transport remoteTransport() override; void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override; diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index f14e19c1..3b9c6d8d 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -9,6 +9,7 @@ #include #include #include +#include namespace mscclpp { @@ -63,24 +64,129 @@ private: */ std::unique_ptr getUniqueId(); -using TransportFlags = uint32_t; -const TransportFlags TransportNone = 0b0; -const TransportFlags TransportCudaIpc = 0b1; -const TransportFlags TransportIB0 = 0b10; -const TransportFlags TransportIB1 = 0b100; -const TransportFlags TransportIB2 = 0b1000; -const TransportFlags TransportIB3 = 0b10000; -const TransportFlags TransportIB4 = 0b100000; -const TransportFlags TransportIB5 = 0b1000000; -const TransportFlags TransportIB6 = 0b10000000; -const TransportFlags TransportIB7 = 0b100000000; +enum class Transport { + Unknown, + CudaIpc, + IB0, + IB1, + IB2, + IB3, + IB4, + IB5, + IB6, + IB7, + NumTransports +}; -const TransportFlags TransportAll = 0b111111111; -const TransportFlags TransportAllIB = 0b111111110; +namespace detail { + const size_t TransportFlagsSize = 10; + static_assert(TransportFlagsSize == static_cast(Transport::NumTransports), "TransportFlagsSize must match the number of transports"); + using TransportFlagsBase = std::bitset; +} + +class TransportFlags : private detail::TransportFlagsBase { +public: + TransportFlags() = default; + TransportFlags(Transport transport) : detail::TransportFlagsBase(1 << static_cast(transport)) {} + + bool has(Transport transport) const { + return detail::TransportFlagsBase::test(static_cast(transport)); + } + + bool none() const { + return detail::TransportFlagsBase::none(); + } + + bool any() const { + return detail::TransportFlagsBase::any(); + } + + bool all() const { + return detail::TransportFlagsBase::all(); + } + + size_t count() const { + return detail::TransportFlagsBase::count(); + } + + TransportFlags& operator|=(TransportFlags other) { + detail::TransportFlagsBase::operator|=(other); + return *this; + } + + TransportFlags operator|(TransportFlags other) const { + return TransportFlags(*this) |= other; + } + + TransportFlags operator|(Transport transport) const { + return *this | TransportFlags(transport); + } + + TransportFlags& operator&=(TransportFlags other) { + detail::TransportFlagsBase::operator&=(other); + return *this; + } + + TransportFlags operator&(TransportFlags other) const { + return TransportFlags(*this) &= other; + } + + TransportFlags operator&(Transport transport) const { + return *this & TransportFlags(transport); + } + + TransportFlags& operator^=(TransportFlags other) { + detail::TransportFlagsBase::operator^=(other); + return *this; + } + + TransportFlags operator^(TransportFlags other) const { + return TransportFlags(*this) ^= other; + } + + TransportFlags operator^(Transport transport) const { + return *this ^ TransportFlags(transport); + } + + TransportFlags operator~() const { + return TransportFlags(*this).flip(); + } + + bool operator==(TransportFlags other) const { + return detail::TransportFlagsBase::operator==(other); + } + + bool operator!=(TransportFlags other) const { + return detail::TransportFlagsBase::operator!=(other); + } + + detail::TransportFlagsBase toBitset() const { + return *this; + } + +private: + TransportFlags(detail::TransportFlagsBase bitset) : detail::TransportFlagsBase(bitset) {} +}; + +inline TransportFlags operator|(Transport transport1, Transport transport2) { + return TransportFlags(transport1) | transport2; +} + +inline TransportFlags operator&(Transport transport1, Transport transport2) { + return TransportFlags(transport1) & transport2; +} + +inline TransportFlags operator^(Transport transport1, Transport transport2) { + return TransportFlags(transport1) ^ transport2; +} + +const TransportFlags NoTransports = TransportFlags(); +const TransportFlags AllIBTransports = Transport::IB0 | Transport::IB1 | Transport::IB2 | Transport::IB3 | Transport::IB4 | Transport::IB5 | Transport::IB6 | Transport::IB7; +const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc; int getIBDeviceCount(); -std::string getIBDeviceName(TransportFlags ibTransport); -TransportFlags getIBTransportByDeviceName(const std::string& ibDeviceName); +std::string getIBDeviceName(Transport ibTransport); +Transport getIBTransportByDeviceName(const std::string& ibDeviceName); class Communicator; class Connection; @@ -111,9 +217,9 @@ public: virtual void flush() = 0; - virtual TransportFlags transport() = 0; + virtual Transport transport() = 0; - virtual TransportFlags remoteTransport() = 0; + virtual Transport remoteTransport() = 0; protected: static std::shared_ptr getRegisteredMemoryImpl(RegisteredMemory&); @@ -166,7 +272,7 @@ public: * transportType: the type of transport to be used (mscclppTransportP2P or mscclppTransportIB) * ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P. */ - std::shared_ptr connect(int remoteRank, int tag, TransportFlags transport); + std::shared_ptr connect(int remoteRank, int tag, Transport transport); /* Establish all connections declared by connect(). This function must be called after all connect() * calls are made. This function ensures that all remote ranks are ready to communicate when it returns. @@ -180,4 +286,13 @@ private: } // namespace mscclpp +namespace std { + template <> + struct hash { + size_t operator()(const mscclpp::TransportFlags& flags) const { + return hash()(flags.toBitset()); + } + }; +} + #endif // MSCCLPP_H_ diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp index d2270d46..afe42da4 100644 --- a/src/include/registered_memory.hpp +++ b/src/include/registered_memory.hpp @@ -10,7 +10,7 @@ namespace mscclpp { struct TransportInfo { - TransportFlags transport; + Transport transport; // TODO: rewrite this using std::variant or something bool ibLocal; @@ -31,7 +31,7 @@ struct RegisteredMemory::Impl { Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl); Impl(const std::vector& data); - TransportInfo& getTransportInfo(TransportFlags transport) { + TransportInfo& getTransportInfo(Transport transport) { for (auto& entry : transportInfos) { if (entry.transport == transport) { return entry; diff --git a/src/registered_memory.cc b/src/registered_memory.cc index f0db85ce..b26ea2d5 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -5,17 +5,17 @@ namespace mscclpp { RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl) : data(data), size(size), rank(rank), transports(transports) { - if (transports & TransportCudaIpc) { + if (transports.has(Transport::CudaIpc)) { TransportInfo transportInfo; - transportInfo.transport = TransportCudaIpc; + transportInfo.transport = Transport::CudaIpc; cudaIpcMemHandle_t handle; // TODO: translate data to a base pointer CUDATHROW(cudaIpcGetMemHandle(&handle, data)); transportInfo.cudaIpcHandle = handle; this->transportInfos.push_back(transportInfo); } - if (transports & TransportAllIB) { - auto addIb = [&](TransportFlags ibTransport) { + if ((transports & AllIBTransports).any()) { + auto addIb = [&](Transport ibTransport) { TransportInfo transportInfo; transportInfo.transport = ibTransport; const IbMr* mr = commImpl.getIbContext(ibTransport)->registerMr(data, size); @@ -23,14 +23,14 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t transportInfo.ibLocal = true; this->transportInfos.push_back(transportInfo); }; - if (transports & TransportIB0) addIb(TransportIB0); - if (transports & TransportIB1) addIb(TransportIB1); - if (transports & TransportIB2) addIb(TransportIB2); - if (transports & TransportIB3) addIb(TransportIB3); - if (transports & TransportIB4) addIb(TransportIB4); - if (transports & TransportIB5) addIb(TransportIB5); - if (transports & TransportIB6) addIb(TransportIB6); - if (transports & TransportIB7) addIb(TransportIB7); + if (transports.has(Transport::IB0)) addIb(Transport::IB0); + if (transports.has(Transport::IB1)) addIb(Transport::IB1); + if (transports.has(Transport::IB2)) addIb(Transport::IB2); + if (transports.has(Transport::IB3)) addIb(Transport::IB3); + if (transports.has(Transport::IB4)) addIb(Transport::IB4); + if (transports.has(Transport::IB5)) addIb(Transport::IB5); + if (transports.has(Transport::IB6)) addIb(Transport::IB6); + if (transports.has(Transport::IB7)) addIb(Transport::IB7); } } @@ -66,9 +66,9 @@ std::vector RegisteredMemory::serialize() { std::copy_n(reinterpret_cast(&transportCount), sizeof(transportCount), std::back_inserter(result)); for (auto& entry : pimpl->transportInfos) { std::copy_n(reinterpret_cast(&entry.transport), sizeof(entry.transport), std::back_inserter(result)); - if (entry.transport == TransportCudaIpc) { + if (entry.transport == Transport::CudaIpc) { std::copy_n(reinterpret_cast(&entry.cudaIpcHandle), sizeof(entry.cudaIpcHandle), std::back_inserter(result)); - } else if (entry.transport & TransportAllIB) { + } else if (AllIBTransports.has(entry.transport)) { std::copy_n(reinterpret_cast(&entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result)); } else { throw std::runtime_error("Unknown transport"); @@ -96,12 +96,12 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) { TransportInfo transportInfo; std::copy_n(it, sizeof(transportInfo.transport), reinterpret_cast(&transportInfo.transport)); it += sizeof(transportInfo.transport); - if (transportInfo.transport & TransportCudaIpc) { + if (transportInfo.transport == Transport::CudaIpc) { cudaIpcMemHandle_t handle; std::copy_n(it, sizeof(handle), reinterpret_cast(&handle)); it += sizeof(handle); transportInfo.cudaIpcHandle = handle; - } else if (transportInfo.transport & TransportAllIB) { + } else if (AllIBTransports.has(transportInfo.transport)) { IbMrInfo info; std::copy_n(it, sizeof(info), reinterpret_cast(&info)); it += sizeof(info); @@ -116,8 +116,8 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) { throw std::runtime_error("Deserialization failed"); } - if (transports & TransportCudaIpc) { - auto entry = getTransportInfo(TransportCudaIpc); + if (transports.has(Transport::CudaIpc)) { + auto entry = getTransportInfo(Transport::CudaIpc); CUDATHROW(cudaIpcOpenMemHandle(&data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess)); } } diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc index 05595313..1f14ca79 100644 --- a/tests/communicator_test_cpp.cc +++ b/tests/communicator_test_cpp.cc @@ -5,16 +5,16 @@ #include #include -mscclpp::TransportFlags findIb(int localRank){ - mscclpp::TransportFlags IBs[] = { - mscclpp::TransportIB0, - mscclpp::TransportIB1, - mscclpp::TransportIB2, - mscclpp::TransportIB3, - mscclpp::TransportIB4, - mscclpp::TransportIB5, - mscclpp::TransportIB6, - mscclpp::TransportIB7 +mscclpp::Transport findIb(int localRank){ + mscclpp::Transport IBs[] = { + mscclpp::Transport::IB0, + mscclpp::Transport::IB1, + mscclpp::Transport::IB2, + mscclpp::Transport::IB3, + mscclpp::Transport::IB4, + mscclpp::Transport::IB5, + mscclpp::Transport::IB6, + mscclpp::Transport::IB7 }; return IBs[localRank]; } @@ -31,8 +31,7 @@ void test_communicator(int rank, int worldSize, int nranksPerNode){ for (int i = 0; i < worldSize; i++){ if (i != rank){ if (i / nranksPerNode == rank / nranksPerNode){ - printf("i %d rank %d nranksPerNode %d\n", i, rank, nranksPerNode); - auto connect = communicator->connect(i, 0, mscclpp::TransportCudaIpc); + auto connect = communicator->connect(i, 0, mscclpp::Transport::CudaIpc); } else { auto connect = communicator->connect(i, 0, findIb(rank % nranksPerNode)); }