#include #include #include #include #include #include #include #include #include #include #include #include namespace nb = nanobind; using namespace nb::literals; // This is a poorman's substitute for std::format, which is a C++20 feature. template std::string string_format(const std::string& format, Args... args) { // Shutup format warning. #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wformat-security" // Dry-run to the get the buffer size: // Extra space for '\0' int size_s = std::snprintf(nullptr, 0, format.c_str(), args...) + 1; if (size_s <= 0) { throw std::runtime_error("Error during formatting."); } // allocate buffer auto size = static_cast(size_s); std::unique_ptr buf(new char[size]); // actually format std::snprintf(buf.get(), size, format.c_str(), args...); // Bulid the return string. // We don't want the '\0' inside return std::string(buf.get(), buf.get() + size - 1); #pragma GCC diagnostic pop } // Maybe return the value, maybe throw an exception. template void checkResult( mscclppResult_t status, const std::string& format, Args... args) { switch (status) { case mscclppSuccess: return; case mscclppUnhandledCudaError: case mscclppSystemError: case mscclppInternalError: case mscclppRemoteError: case mscclppInProgress: case mscclppNumResults: throw std::runtime_error( string_format(format, args...) + " : " + std::string(mscclppGetErrorString(status))); case mscclppInvalidArgument: case mscclppInvalidUsage: default: throw std::invalid_argument( string_format(format, args...) + " : " + std::string(mscclppGetErrorString(status))); } } #define RETRY(C, ...) \ { \ mscclppResult_t res; \ do { \ res = (C); \ } while (res == mscclppInProgress); \ checkResult(res, __VA_ARGS__); \ } // Maybe return the value, maybe throw an exception. template Val maybe( mscclppResult_t status, Val val, const std::string& format, Args... args) { checkResult(status, format, args...); return val; } // Wrapper around connection state. struct _Comm { int _rank; int _world_size; mscclppComm_t _handle; bool _is_open; bool _proxies_running; public: _Comm(int rank, int world_size, mscclppComm_t handle) : _rank(rank), _world_size(world_size), _handle(handle), _is_open(true), _proxies_running(false) {} ~_Comm() { close(); } // Close should be safe to call on a closed handle. void close() { if (_is_open) { if (_proxies_running) { mscclppProxyStop(_handle); _proxies_running = false; } checkResult(mscclppCommDestroy(_handle), "Failed to close comm channel"); _handle = NULL; _is_open = false; _rank = -1; _world_size = -1; } } void check_open() { if (!_is_open) { throw std::invalid_argument("_Comm is not open"); } } }; struct _P2PHandle { struct mscclppRegisteredMemoryP2P _rmP2P; struct mscclppIbMr _ibmr; _P2PHandle() : _rmP2P({0}), _ibmr({0}) {} _P2PHandle(const mscclppRegisteredMemoryP2P& p2p) : _ibmr({0}) { _rmP2P = p2p; if (_rmP2P.IbMr != nullptr) { _ibmr = *_rmP2P.IbMr; _rmP2P.IbMr = &_ibmr; } } }; nb::callable _log_callback; void _LogHandler(const char* msg) { if (_log_callback) { nb::gil_scoped_acquire guard; _log_callback(msg); } } static const std::string DOC_MscclppUniqueId = "MSCCLPP Unique Id; used by the MPI Interface"; static const std::string DOC__Comm = "MSCCLPP Communications Handle"; static const std::string DOC__P2PHandle = "MSCCLPP P2P MR Handle"; NB_MODULE(_py_mscclpp, m) { m.doc() = "Python bindings for MSCCLPP: which is not NCCL"; m.attr("MSCCLPP_UNIQUE_ID_BYTES") = MSCCLPP_UNIQUE_ID_BYTES; m.def("_bind_log_handler", [](nb::callable cb) -> void { _log_callback = nb::borrow(cb); mscclppSetLogHandler(_LogHandler); }); m.def("_release_log_handler", []() -> void { _log_callback.reset(); mscclppSetLogHandler(mscclppDefaultLogHandler); }); nb::enum_(m, "TransportType") .value("P2P", mscclppTransport_t::mscclppTransportP2P) .value("SHM", mscclppTransport_t::mscclppTransportSHM) .value("IB", mscclppTransport_t::mscclppTransportIB); nb::class_(m, "MscclppUniqueId") .def_ro_static("__doc__", &DOC_MscclppUniqueId) .def_static( "from_context", []() -> mscclppUniqueId { mscclppUniqueId uniqueId; return maybe( mscclppGetUniqueId(&uniqueId), uniqueId, "Failed to get MSCCLP Unique Id."); }, nb::call_guard()) .def_static( "from_bytes", [](nb::bytes source) -> mscclppUniqueId { if (source.size() != MSCCLPP_UNIQUE_ID_BYTES) { throw std::invalid_argument(string_format( "Requires exactly %d bytes; found %d", MSCCLPP_UNIQUE_ID_BYTES, source.size())); } mscclppUniqueId uniqueId; std::memcpy( uniqueId.internal, source.c_str(), sizeof(uniqueId.internal)); return uniqueId; }) .def("bytes", [](mscclppUniqueId id) { return nb::bytes(id.internal, sizeof(id.internal)); }); nb::class_<_P2PHandle>(m, "_P2PHandle") .def_ro_static("__doc__", &DOC__P2PHandle); nb::class_<_Comm>(m, "_Comm") .def_ro_static("__doc__", &DOC__Comm) .def_static( "init_rank_from_address", [](const std::string& address, int rank, int world_size) -> _Comm* { mscclppComm_t handle; checkResult( mscclppCommInitRank(&handle, world_size, address.c_str(), rank), "Failed to initialize comms: %s rank=%d world_size=%d", address, rank, world_size); return new _Comm(rank, world_size, handle); }, nb::rv_policy::take_ownership, nb::call_guard(), "address"_a, "rank"_a, "world_size"_a, "Initialize comms given an IP address, rank, and world_size") .def_static( "init_rank_from_id", [](const mscclppUniqueId& id, int rank, int world_size) -> _Comm* { mscclppComm_t handle; checkResult( mscclppCommInitRankFromId(&handle, world_size, id, rank), "Failed to initialize comms: %02X%s rank=%d world_size=%d", id.internal, rank, world_size); return new _Comm(rank, world_size, handle); }, nb::rv_policy::take_ownership, nb::call_guard(), "id"_a, "rank"_a, "world_size"_a, "Initialize comms given u UniqueID, rank, and world_size") .def( "opened", [](_Comm& comm) -> bool { return comm._is_open; }, "Is this comm object opened?") .def( "closed", [](_Comm& comm) -> bool { return !comm._is_open; }, "Is this comm object closed?") .def_ro("rank", &_Comm::_rank) .def_ro("world_size", &_Comm::_world_size) .def( "register_buffer", [](_Comm& comm, uint64_t local_buff, uint64_t buff_size) -> std::vector<_P2PHandle> { comm.check_open(); mscclppRegisteredMemory regMem; checkResult( mscclppRegisterBuffer( comm._handle, reinterpret_cast(local_buff), buff_size, ®Mem), "Registering buffer failed"); std::vector<_P2PHandle> handles; for (const auto& p2p : regMem.p2p) { handles.push_back(_P2PHandle(p2p)); } return handles; }, "local_buf"_a, "buff_size"_a, nb::call_guard(), "Register a buffer for P2P transfers.") .def( "connect", [](_Comm& comm, int remote_rank, int tag, uint64_t local_buff, uint64_t buff_size, mscclppTransport_t transport_type) -> void { comm.check_open(); RETRY( mscclppConnect( comm._handle, remote_rank, tag, reinterpret_cast(local_buff), buff_size, transport_type, NULL // ibDev ), "Connect failed"); }, "remote_rank"_a, "tag"_a, "local_buf"_a, "buff_size"_a, "transport_type"_a, nb::call_guard(), "Attach a local buffer to a remote connection.") .def( "connection_setup", [](_Comm& comm) -> void { comm.check_open(); RETRY( mscclppConnectionSetup(comm._handle), "Failed to setup MSCCLPP connection"); }, nb::call_guard(), "Run connection setup for MSCCLPP.") .def( "launch_proxies", [](_Comm& comm) -> void { comm.check_open(); if (comm._proxies_running) { throw std::invalid_argument("Proxy Threads Already Running"); } checkResult( mscclppProxyLaunch(comm._handle), "Failed to launch MSCCLPP proxy"); comm._proxies_running = true; }, nb::call_guard(), "Start the MSCCLPP proxy.") .def( "stop_proxies", [](_Comm& comm) -> void { comm.check_open(); if (comm._proxies_running) { checkResult( mscclppProxyStop(comm._handle), "Failed to stop MSCCLPP proxy"); comm._proxies_running = false; } }, nb::call_guard(), "Start the MSCCLPP proxy.") .def("close", &_Comm::close, nb::call_guard()) .def("__del__", &_Comm::close, nb::call_guard()) .def( "bootstrap_all_gather_int", [](_Comm& comm, int val) -> std::vector { std::vector buf(comm._world_size); buf[comm._rank] = val; mscclppBootstrapAllGather(comm._handle, buf.data(), sizeof(int)); return buf; }, nb::call_guard(), "val"_a, "all-gather ints over the bootstrap connection.") .def( "all_gather_bytes", [](_Comm& comm, nb::bytes& item) -> std::vector { // First, all-gather the sizes of all bytes. std::vector sizes(comm._world_size); sizes[comm._rank] = item.size(); checkResult( mscclppBootstrapAllGather( comm._handle, sizes.data(), sizeof(size_t)), "bootstrapAllGather failed."); // Next, find the largest message to send. size_t max_size = *std::max_element(sizes.begin(), sizes.end()); // Allocate an all-gather buffer large enough for max * world_size. std::shared_ptr data_buf( new char[max_size * comm._world_size]); // Copy the local item into the buffer. std::memcpy( &data_buf[comm._rank * max_size], item.c_str(), item.size()); // all-gather the data buffer. checkResult( mscclppBootstrapAllGather( comm._handle, data_buf.get(), max_size), "bootstrapAllGather failed."); // Build a response vector. std::vector ret; for (int i = 0; i < comm._world_size; ++i) { // Copy out the relevant range of each item. ret.push_back(nb::bytes(&data_buf[i * max_size], sizes[i])); } return ret; }, nb::call_guard(), "item"_a, "all-gather bytes over the bootstrap connection; sizes do not need " "to match."); }