diff --git a/python/src/_py_mscclpp.cpp b/python/src/_py_mscclpp.cpp index ad4e183f..39bc52e3 100644 --- a/python/src/_py_mscclpp.cpp +++ b/python/src/_py_mscclpp.cpp @@ -94,16 +94,21 @@ struct _Comm { int _world_size; mscclppComm_t _handle; bool _is_open; + bool _proxies_running; public: _Comm(int rank, int world_size, mscclppComm_t handle) - : _rank(rank), _world_size(world_size), _handle(handle), _is_open(true) {} + : _rank(rank), _world_size(world_size), _handle(handle), _is_open(true), _proxies_running(false) {} ~_Comm() { close(); } // Close should be safe to call on a closed handle. void close() { if (_is_open) { + if (_proxies_running) { + mscclppProxyStop(_handle); + _proxies_running = false; + } checkResult(mscclppCommDestroy(_handle), "Failed to close comm channel"); _handle = NULL; _is_open = false; @@ -138,20 +143,15 @@ NB_MODULE(_py_mscclpp, m) { m.attr("MSCCLPP_UNIQUE_ID_BYTES") = MSCCLPP_UNIQUE_ID_BYTES; - m.def("_bind_log_handler", [](nb::callable cb) { + m.def("_bind_log_handler", [](nb::callable cb) -> void { _log_callback = nb::borrow(cb); mscclppSetLogHandler(_LogHandler); }); - m.def("_release_log_handler", []() { + m.def("_release_log_handler", []() -> void { _log_callback.reset(); mscclppSetLogHandler(mscclppDefaultLogHandler); }); - m.def("_setup", []() { - int device; - cudaGetDevice(&device); - }); - nb::enum_(m, "TransportType") .value("P2P", mscclppTransport_t::mscclppTransportP2P) .value("SHM", mscclppTransport_t::mscclppTransportSHM) @@ -161,7 +161,7 @@ NB_MODULE(_py_mscclpp, m) { .def_ro_static("__doc__", &DOC_MscclppUniqueId) .def_static( "from_context", - []() { + []() -> mscclppUniqueId { mscclppUniqueId uniqueId; return maybe( mscclppGetUniqueId(&uniqueId), @@ -171,7 +171,7 @@ NB_MODULE(_py_mscclpp, m) { nb::call_guard()) .def_static( "from_bytes", - [](nb::bytes source) { + [](nb::bytes source) -> mscclppUniqueId { if (source.size() != MSCCLPP_UNIQUE_ID_BYTES) { throw std::invalid_argument(string_format( "Requires exactly %d bytes; found %d", @@ -192,7 +192,7 @@ NB_MODULE(_py_mscclpp, m) { .def_ro_static("__doc__", &DOC__Comm) .def_static( "init_rank_from_address", - [](const std::string& address, int rank, int world_size) { + [](const std::string& address, int rank, int world_size) -> _Comm* { mscclppComm_t handle; checkResult( mscclppCommInitRank(&handle, world_size, address.c_str(), rank), @@ -210,7 +210,7 @@ NB_MODULE(_py_mscclpp, m) { "Initialize comms given an IP address, rank, and world_size") .def_static( "init_rank_from_id", - [](const mscclppUniqueId& id, int rank, int world_size) { + [](const mscclppUniqueId& id, int rank, int world_size) -> _Comm* { mscclppComm_t handle; checkResult( mscclppCommInitRankFromId(&handle, world_size, id, rank), @@ -228,25 +228,28 @@ NB_MODULE(_py_mscclpp, m) { "Initialize comms given u UniqueID, rank, and world_size") .def( "opened", - [](_Comm& comm) { return comm._is_open; }, + [](_Comm& comm) -> bool { return comm._is_open; }, "Is this comm object opened?") .def( "closed", - [](_Comm& comm) { return !comm._is_open; }, + [](_Comm& comm) -> bool { return !comm._is_open; }, "Is this comm object closed?") .def_ro("rank", &_Comm::_rank) .def_ro("world_size", &_Comm::_world_size) .def( "connect", - [](_Comm& self, + [](_Comm& comm, int remote_rank, int tag, uint64_t local_buff, uint64_t buff_size, mscclppTransport_t transport_type) -> void { + if (comm._proxies_running) { + throw std::invalid_argument("Proxy Threads Already Running"); + } RETRY( mscclppConnect( - self._handle, + comm._handle, remote_rank, tag, reinterpret_cast(local_buff), @@ -273,24 +276,29 @@ NB_MODULE(_py_mscclpp, m) { nb::call_guard(), "Run connection setup for MSCCLPP.") .def( - "launch_proxy", - [](_Comm& comm) { + "launch_proxies", + [](_Comm& comm) -> void { comm.check_open(); - return maybe( + if (comm._proxies_running) { + throw std::invalid_argument("Proxy Threads Already Running"); + } + checkResult( mscclppProxyLaunch(comm._handle), - true, "Failed to launch MSCCLPP proxy"); + comm._proxies_running = true; }, nb::call_guard(), "Start the MSCCLPP proxy.") .def( - "stop_proxy", - [](_Comm& comm) { + "stop_proxies", + [](_Comm& comm) -> void { comm.check_open(); - return maybe( - mscclppProxyStop(comm._handle), - true, - "Failed to stop MSCCLPP proxy"); + if (comm._proxies_running) { + checkResult( + mscclppProxyStop(comm._handle), + "Failed to stop MSCCLPP proxy"); + comm._proxies_running = false; + } }, nb::call_guard(), "Start the MSCCLPP proxy.") @@ -318,7 +326,7 @@ NB_MODULE(_py_mscclpp, m) { "all-gather ints over the bootstrap connection.") .def( "all_gather_bytes", - [](_Comm& comm, nb::bytes& item) { + [](_Comm& comm, nb::bytes& item) -> std::vector { // First, all-gather the sizes of all bytes. std::vector sizes(comm._world_size); sizes[comm._rank] = item.size(); diff --git a/python/src/mscclpp/__init__.py b/python/src/mscclpp/__init__.py index 4943fa5e..cbb84c2c 100644 --- a/python/src/mscclpp/__init__.py +++ b/python/src/mscclpp/__init__.py @@ -10,8 +10,6 @@ logger = logging.getLogger(__file__) from . import _py_mscclpp -_py_mscclpp._setup() - __all__ = ( "Comm", "MscclppUniqueId", @@ -156,7 +154,12 @@ class Comm: return [pickle.loads(b) for b in self.all_gather_bytes(pickle.dumps(item))] def connect( - self, remote_rank: int, tag: int, data_ptr, data_size, transport: int + self, + remote_rank: int, + tag: int, + data_ptr, + data_size: int, + transport: int, ) -> None: self._comm.connect( remote_rank, @@ -168,3 +171,9 @@ class Comm: def connection_setup(self) -> None: self._comm.connection_setup() + + def launch_proxies(self) -> None: + self._comm.launch_proxies() + + def stop_proxies(self) -> None: + self._comm.stop_proxies() diff --git a/python/src/mscclpp/tests/bootstrap_test.py b/python/src/mscclpp/tests/bootstrap_test.py index 0b14f001..6f5c5ec7 100644 --- a/python/src/mscclpp/tests/bootstrap_test.py +++ b/python/src/mscclpp/tests/bootstrap_test.py @@ -99,6 +99,10 @@ def _test_p2p_connect(options: argparse.Namespace, comm: mscclpp.Comm): comm.connection_setup() + comm.launch_proxies() + comm.stop_proxies() + + def main(): p = argparse.ArgumentParser()