diff --git a/include/mscclpp/algorithm.hpp b/include/mscclpp/algorithm.hpp index 6acf3ea9..1f6a5708 100644 --- a/include/mscclpp/algorithm.hpp +++ b/include/mscclpp/algorithm.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include namespace mscclpp { @@ -86,22 +87,15 @@ class Algorithm { namespace std { -// Refer https://www.boost.org/doc/libs/1_86_0/libs/container_hash/doc/html/hash.html#combine -template -inline void hash_combine(std::size_t& seed, const T& value) { - std::hash hasher; - seed ^= hasher(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); -} - template <> struct hash { std::size_t operator()(const mscclpp::AlgorithmCtxKey& key) const { std::size_t seed = 42; - hash_combine(seed, key.baseSendBuff); - hash_combine(seed, key.baseRecvBuff); - hash_combine(seed, key.baseSendSize); - hash_combine(seed, key.baseRecvSize); - hash_combine(seed, key.tag); + mscclpp::detail::hashCombine(seed, key.baseSendBuff); + mscclpp::detail::hashCombine(seed, key.baseRecvBuff); + mscclpp::detail::hashCombine(seed, key.baseSendSize); + mscclpp::detail::hashCombine(seed, key.baseRecvSize); + mscclpp::detail::hashCombine(seed, key.tag); return seed; } }; diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index bc0c1328..a4273f58 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -947,18 +947,24 @@ DeviceHandle> deviceHandle(T&& t) { template using PacketPayload = typename T::Payload; +/// Convert Transport to string and output to stream. +/// @param os Output stream. +/// @param transport Input transport. +/// @return Output stream. +std::ostream& operator<<(std::ostream& os, const Transport& transport); + +/// Convert DeviceType to string and output to stream. +/// @param os Output stream. +/// @param deviceType Input device type. +/// @return Output stream. +std::ostream& operator<<(std::ostream& os, const DeviceType& deviceType); + +/// Convert Device to string and output to stream. +/// @param os Output stream. +/// @param device Input device. +/// @return Output stream. +std::ostream& operator<<(std::ostream& os, const Device& device); + } // namespace mscclpp -namespace std { - -std::string to_string(const mscclpp::Transport& transport); - -std::string to_string(const mscclpp::Device& device); - -/// Specialization of the std::hash template for mscclpp::TransportFlags. -template <> -struct hash; - -} // namespace std - #endif // MSCCLPP_CORE_HPP_ diff --git a/include/mscclpp/utils.hpp b/include/mscclpp/utils.hpp index 882622fc..ffe269da 100644 --- a/include/mscclpp/utils.hpp +++ b/include/mscclpp/utils.hpp @@ -4,10 +4,21 @@ #ifndef MSCCLPP_UTILS_HPP_ #define MSCCLPP_UTILS_HPP_ +#include #include #include namespace mscclpp { +namespace detail { + +// Refer https://www.boost.org/doc/libs/1_86_0/libs/container_hash/doc/html/hash.html#combine +template +inline void hashCombine(std::size_t& seed, const T& value) { + std::hash hasher; + seed ^= hasher(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +} // namespace detail /// Get the host name of the system. /// @param maxlen The maximum length of the returned string. diff --git a/python/csrc/core_py.cpp b/python/csrc/core_py.cpp index b95bd982..3ea591e7 100644 --- a/python/csrc/core_py.cpp +++ b/python/csrc/core_py.cpp @@ -9,6 +9,7 @@ #include #include +#include namespace nb = nanobind; using namespace mscclpp; @@ -122,7 +123,11 @@ void register_core(nb::module_& m) { .def(nb::init(), nb::arg("type"), nb::arg("id") = -1) .def_rw("type", &Device::type) .def_rw("id", &Device::id) - .def("__str__", [](const Device& self) { return std::to_string(self); }); + .def("__str__", [](const Device& self) { + std::stringstream ss; + ss << self; + return ss.str(); + }); nb::class_(m, "EndpointConfigIb") .def(nb::init<>()) diff --git a/src/context.cc b/src/context.cc index d725486d..9bf299d3 100644 --- a/src/context.cc +++ b/src/context.cc @@ -90,8 +90,8 @@ MSCCLPP_API_CPP Connection Context::connect(const Endpoint &localEndpoint, const if (localTransport != remoteTransport && !(AllIBTransports.has(localTransport) && AllIBTransports.has(remoteTransport))) { std::stringstream ss; - ss << "Transport mismatch between local (" << std::to_string(localTransport) << ") and remote (" - << std::to_string(remoteEndpoint.transport()) << ") endpoints"; + ss << "Transport mismatch between local (" << localTransport << ") and remote (" << remoteEndpoint.transport() + << ") endpoints"; throw Error(ss.str(), ErrorCode::InvalidUsage); } std::shared_ptr conn; diff --git a/src/core.cc b/src/core.cc index e6005815..2d67b988 100644 --- a/src/core.cc +++ b/src/core.cc @@ -89,27 +89,22 @@ const TransportFlags AllIBTransports = Transport::IB0 | Transport::IB1 | Transpo const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc | Transport::Ethernet; +std::ostream& operator<<(std::ostream& os, const Transport& transport) { + static const std::string TransportNames[] = {"UNK", "IPC", "IB0", "IB1", "IB2", "IB3", + "IB4", "IB5", "IB6", "IB7", "ETH", "NUM"}; + os << TransportNames[static_cast(transport)]; + return os; +} + +std::ostream& operator<<(std::ostream& os, const DeviceType& deviceType) { + static const std::string DeviceTypeNames[] = {"Unknown", "CPU", "GPU"}; + os << DeviceTypeNames[static_cast(deviceType)]; + return os; +} + +std::ostream& operator<<(std::ostream& os, const Device& device) { + os << "Device(type=" << device.type << ", id=" << device.id << ")"; + return os; +} + } // namespace mscclpp - -namespace std { - -std::string to_string(const mscclpp::Transport& transport) { - static const std::string TransportNames[] = {"UNK", "IPC", "NVLS", "IB0", "IB1", "IB2", "IB3", - "IB4", "IB5", "IB6", "IB7", "ETH", "NUM"}; - return TransportNames[static_cast(transport)]; -} - -std::string to_string(const mscclpp::Device& device) { - std::stringstream ss; - ss << "Device(type=" << to_string(device.type) << ", id=" << device.id << ")"; - return ss.str(); -} - -template <> -struct hash { - size_t operator()(const mscclpp::TransportFlags& flags) const { - return hash()(flags.toBitset()); - } -}; - -} // namespace std diff --git a/src/executor/executor.cc b/src/executor/executor.cc index 0669d1ea..bf2caf97 100644 --- a/src/executor/executor.cc +++ b/src/executor/executor.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include "debug.h" #include "execution_kernel.hpp" @@ -55,19 +56,12 @@ struct DeviceExecutionPlanKey { namespace std { -// Refer https://www.boost.org/doc/libs/1_86_0/libs/container_hash/doc/html/hash.html#combine -template -inline void hash_combine(std::size_t& seed, const T& value) { - std::hash hasher; - seed ^= hasher(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); -} - template <> struct hash> { std::size_t operator()(const std::pair& key) const { std::size_t seed = 42; - hash_combine(seed, static_cast(key.first)); - hash_combine(seed, key.second); + mscclpp::detail::hashCombine(seed, static_cast(key.first)); + mscclpp::detail::hashCombine(seed, key.second); return seed; } }; @@ -76,11 +70,11 @@ template <> struct hash { std::size_t operator()(const mscclpp::ExecutionContextKey& key) const { size_t seed = 42; - hash_combine(seed, key.sendBuff); - hash_combine(seed, key.recvBuff); - hash_combine(seed, key.sendBuffSize); - hash_combine(seed, key.recvBuffSize); - hash_combine(seed, key.plan); + mscclpp::detail::hashCombine(seed, key.sendBuff); + mscclpp::detail::hashCombine(seed, key.recvBuff); + mscclpp::detail::hashCombine(seed, key.sendBuffSize); + mscclpp::detail::hashCombine(seed, key.recvBuffSize); + mscclpp::detail::hashCombine(seed, key.plan); return seed; } }; @@ -89,10 +83,10 @@ template <> struct hash { std::size_t operator()(const mscclpp::DeviceExecutionPlanKey& key) const { std::size_t seed = 42; - hash_combine(seed, key.inputMessageSize); - hash_combine(seed, key.outputMessageSize); - hash_combine(seed, key.constSrcOffset); - hash_combine(seed, key.constDstOffset); + mscclpp::detail::hashCombine(seed, key.inputMessageSize); + mscclpp::detail::hashCombine(seed, key.outputMessageSize); + mscclpp::detail::hashCombine(seed, key.constSrcOffset); + mscclpp::detail::hashCombine(seed, key.constDstOffset); return seed; } }; diff --git a/src/include/execution_plan.hpp b/src/include/execution_plan.hpp index 53edb8bd..fccded02 100644 --- a/src/include/execution_plan.hpp +++ b/src/include/execution_plan.hpp @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -33,17 +34,20 @@ namespace std { template <> struct hash { std::size_t operator()(const mscclpp::ChannelKey& key) const { - return std::hash()(static_cast(key.bufferType)) ^ std::hash()(static_cast(key.channelType)); + std::size_t seed = 0; + mscclpp::detail::hashCombine(seed, static_cast(key.bufferType)); + mscclpp::detail::hashCombine(seed, static_cast(key.channelType)); + return seed; } }; template <> struct hash> { std::size_t operator()(const std::pair& key) const { - std::size_t h1 = std::hash()(key.first); - std::size_t h2 = std::hash()(static_cast(key.second)); - // Refer hash_combine from boost - return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2)); + std::size_t seed = 0; + mscclpp::detail::hashCombine(seed, key.first); + mscclpp::detail::hashCombine(seed, static_cast(key.second)); + return seed; } }; } // namespace std