diff --git a/python/src/_py_mscclpp.cpp b/python/src/_py_mscclpp.cpp index 3171122d..a0df9f28 100644 --- a/python/src/_py_mscclpp.cpp +++ b/python/src/_py_mscclpp.cpp @@ -128,16 +128,24 @@ struct _Comm { }; struct _P2PHandle { - struct mscclppRegisteredMemoryP2P _rmP2P; - struct mscclppIbMr _ibmr; + struct mscclppRegisteredMemoryP2P _rm; + struct mscclppIbMr _local_ibmr; - _P2PHandle() : _rmP2P({0}), _ibmr({0}) {} + _P2PHandle() : _rm({0}), _local_ibmr({0}) {} - _P2PHandle(const mscclppRegisteredMemoryP2P& p2p) : _ibmr({0}) { - _rmP2P = p2p; - if (_rmP2P.IbMr != nullptr) { - _ibmr = *_rmP2P.IbMr; - _rmP2P.IbMr = &_ibmr; + _P2PHandle(const mscclppRegisteredMemoryP2P& p2p) : _local_ibmr({0}) { + _rm = p2p; + if (_rm.IbMr != nullptr) { + _local_ibmr = *_rm.IbMr; + _rm.IbMr = &_local_ibmr; + } + } + + mscclppTransport_t transport() const { + if (_rm.IbMr != nullptr) { + return mscclppTransport_t::mscclppTransportIB; + } else { + return mscclppTransport_t::mscclppTransportP2P; } } }; @@ -209,7 +217,60 @@ NB_MODULE(_py_mscclpp, m) { }); nb::class_<_P2PHandle>(m, "_P2PHandle") - .def_ro_static("__doc__", &DOC__P2PHandle); + .def_ro_static("__doc__", &DOC__P2PHandle) + .def( + "transport", + &_P2PHandle::transport, + "Get the transport type of the handle") + .def( + "data_ptr", + [](const _P2PHandle& self) -> uint64_t { + if (self.transport() == mscclppTransport_t::mscclppTransportP2P) { + return reinterpret_cast(self._rm.remoteBuff); + } + throw std::invalid_argument( + "IB transport does not have a local data ptr"); + }, + "Get the local data pointer, only for P2P handles"); + + nb::class_(m, "_RegisteredMemory") + .def( + "handles", + [](const mscclppRegisteredMemory& self) -> std::vector<_P2PHandle> { + std::vector<_P2PHandle> handles; + for (const auto& p2p : self.p2p) { + handles.push_back(_P2PHandle(p2p)); + } + return handles; + }, + "Get the P2P handle for this memory") + .def( + "write_all", + [](const mscclppRegisteredMemory& self, + const _Comm& comm, + 
uint64_t src_data, + size_t size, + uint32_t src_offset = 0, + uint32_t dst_offset = 0, + int64_t stream = 0) -> void { + checkResult( + mscclppRegisteredBufferWrite( + comm._handle, + const_cast(&self), + reinterpret_cast(src_data), + size, + src_offset, + dst_offset, + stream), + "Failed to write to registered memory"); + }, + "comm"_a, + "src_data"_a, + "size"_a, + "src_offset"_a = 0, + "dst_offset"_a = 0, + "stream"_a = 0, + "Write to all bound targets in the buffer"); nb::class_<_Comm>(m, "_Comm") .def_ro_static("__doc__", &DOC__Comm) @@ -248,58 +309,54 @@ NB_MODULE(_py_mscclpp, m) { "id"_a, "rank"_a, "world_size"_a, - "Initialize comms given u UniqueID, rank, and world_size") + "Initialize comms given UniqueID, rank, and world_size") .def( "opened", - [](_Comm& comm) -> bool { return comm._is_open; }, + [](_Comm& self) -> bool { return self._is_open; }, "Is this comm object opened?") .def( "closed", - [](_Comm& comm) -> bool { return !comm._is_open; }, + [](_Comm& self) -> bool { return !self._is_open; }, "Is this comm object closed?") .def_ro("rank", &_Comm::_rank) .def_ro("world_size", &_Comm::_world_size) .def( "register_buffer", - [](_Comm& comm, - uint64_t local_buff, - uint64_t buff_size) -> std::vector<_P2PHandle> { - comm.check_open(); + [](_Comm& self, + uint64_t data_ptr, + uint64_t size) -> mscclppRegisteredMemory { + self.check_open(); mscclppRegisteredMemory regMem; checkResult( mscclppRegisterBuffer( - comm._handle, - reinterpret_cast(local_buff), - buff_size, + self._handle, + reinterpret_cast(data_ptr), + size, &regMem), "Registering buffer failed"); - - std::vector<_P2PHandle> handles; - for (const auto& p2p : regMem.p2p) { - handles.push_back(_P2PHandle(p2p)); - } - return handles; + return regMem; + ; }, - "local_buf"_a, - "buff_size"_a, + "data_ptr"_a, + "size"_a, nb::call_guard(), "Register a buffer for P2P transfers.") .def( "connect", - [](_Comm& comm, + [](_Comm& self, int remote_rank, int tag, - uint64_t local_buff, - uint64_t buff_size, 
+ uint64_t data_ptr, + uint64_t size, mscclppTransport_t transport_type) -> void { - comm.check_open(); + self.check_open(); RETRY( mscclppConnect( - comm._handle, + self._handle, remote_rank, tag, - reinterpret_cast(local_buff), - buff_size, + reinterpret_cast(data_ptr), + size, transport_type, NULL // ibDev ), @@ -307,44 +364,44 @@ NB_MODULE(_py_mscclpp, m) { }, "remote_rank"_a, "tag"_a, - "local_buf"_a, - "buff_size"_a, + "data_ptr"_a, + "size"_a, "transport_type"_a, nb::call_guard(), "Attach a local buffer to a remote connection.") .def( "connection_setup", - [](_Comm& comm) -> void { - comm.check_open(); + [](_Comm& self) -> void { + self.check_open(); RETRY( - mscclppConnectionSetup(comm._handle), + mscclppConnectionSetup(self._handle), "Failed to setup MSCCLPP connection"); }, nb::call_guard(), "Run connection setup for MSCCLPP.") .def( "launch_proxies", - [](_Comm& comm) -> void { - comm.check_open(); - if (comm._proxies_running) { + [](_Comm& self) -> void { + self.check_open(); + if (self._proxies_running) { throw std::invalid_argument("Proxy Threads Already Running"); } checkResult( - mscclppProxyLaunch(comm._handle), + mscclppProxyLaunch(self._handle), "Failed to launch MSCCLPP proxy"); - comm._proxies_running = true; + self._proxies_running = true; }, nb::call_guard(), "Start the MSCCLPP proxy.") .def( "stop_proxies", - [](_Comm& comm) -> void { - comm.check_open(); - if (comm._proxies_running) { + [](_Comm& self) -> void { + self.check_open(); + if (self._proxies_running) { checkResult( - mscclppProxyStop(comm._handle), + mscclppProxyStop(self._handle), "Failed to stop MSCCLPP proxy"); - comm._proxies_running = false; + self._proxies_running = false; } }, nb::call_guard(), @@ -353,10 +410,10 @@ NB_MODULE(_py_mscclpp, m) { .def("__del__", &_Comm::close, nb::call_guard()) .def( "bootstrap_all_gather_int", - [](_Comm& comm, int val) -> std::vector { - std::vector buf(comm._world_size); - buf[comm._rank] = val; - mscclppBootstrapAllGather(comm._handle, 
buf.data(), sizeof(int)); + [](_Comm& self, int val) -> std::vector { + std::vector buf(self._world_size); + buf[self._rank] = val; + mscclppBootstrapAllGather(self._handle, buf.data(), sizeof(int)); return buf; }, nb::call_guard(), @@ -364,13 +421,13 @@ NB_MODULE(_py_mscclpp, m) { "all-gather ints over the bootstrap connection.") .def( "all_gather_bytes", - [](_Comm& comm, nb::bytes& item) -> std::vector { + [](_Comm& self, nb::bytes& item) -> std::vector { // First, all-gather the sizes of all bytes. - std::vector sizes(comm._world_size); - sizes[comm._rank] = item.size(); + std::vector sizes(self._world_size); + sizes[self._rank] = item.size(); checkResult( mscclppBootstrapAllGather( - comm._handle, sizes.data(), sizeof(size_t)), + self._handle, sizes.data(), sizeof(size_t)), "bootstrapAllGather failed."); // Next, find the largest message to send. @@ -378,21 +435,21 @@ NB_MODULE(_py_mscclpp, m) { // Allocate an all-gather buffer large enough for max * world_size. std::shared_ptr data_buf( - new char[max_size * comm._world_size]); + new char[max_size * self._world_size]); // Copy the local item into the buffer. std::memcpy( - &data_buf[comm._rank * max_size], item.c_str(), item.size()); + &data_buf[self._rank * max_size], item.c_str(), item.size()); // all-gather the data buffer. checkResult( mscclppBootstrapAllGather( - comm._handle, data_buf.get(), max_size), + self._handle, data_buf.get(), max_size), "bootstrapAllGather failed."); // Build a response vector. std::vector ret; - for (int i = 0; i < comm._world_size; ++i) { + for (int i = 0; i < self._world_size; ++i) { // Copy out the relevant range of each item. ret.push_back(nb::bytes(&data_buf[i * max_size], sizes[i])); } diff --git a/python/src/mscclpp/__init__.py b/python/src/mscclpp/__init__.py index 51c564d8..defec50f 100644 --- a/python/src/mscclpp/__init__.py +++ b/python/src/mscclpp/__init__.py @@ -10,14 +10,8 @@ logger = logging.getLogger(__file__) from . 
import _py_mscclpp -__all__ = ( - "Comm", - "MscclppUniqueId", - "MSCCLPP_UNIQUE_ID_BYTES", - "TransportType", -) - _Comm = _py_mscclpp._Comm +_RegisteredMemory = _py_mscclpp._RegisteredMemory _P2PHandle = _py_mscclpp._P2PHandle TransportType = _py_mscclpp.TransportType @@ -67,7 +61,7 @@ _setup_logging() class Comm: """Comm object; represents a mscclpp connection.""" - _comm: _Comm + _c_comm: _Comm @staticmethod def init_rank_from_address( @@ -97,16 +91,16 @@ class Comm: def __init__(self, *, _comm: _Comm): """Construct a Comm object wrapping an internal _Comm handle.""" - self._comm = _comm + self._c_comm = _comm def __del__(self) -> None: self.close() def close(self) -> None: """Close the connection.""" - if self._comm: - self._comm.close() - self._comm = None + if self._c_comm: + self._c_comm.close() + self._c_comm = None @property def rank(self) -> int: @@ -114,7 +108,7 @@ class Comm: Assumes the Comm is open. """ - return self._comm.rank + return self._c_comm.rank @property def world_size(self) -> int: @@ -122,11 +116,11 @@ class Comm: Assumes the Comm is open. """ - return self._comm.world_size + return self._c_comm.world_size def bootstrap_all_gather_int(self, val: int) -> list[int]: """AllGather an int value through the bootstrap interface.""" - return self._comm.bootstrap_all_gather_int(val) + return self._c_comm.bootstrap_all_gather_int(val) def all_gather_bytes(self, item: bytes) -> list[bytes]: """AllGather bytes (of different sizes) through the bootstrap interface. @@ -134,7 +128,7 @@ class Comm: :param item: the bytes object for this rank. :return: a list of bytes objects; the ret[rank] object will be a new copy. """ - return self._comm.all_gather_bytes(item) + return self._c_comm.all_gather_bytes(item) def all_gather_json(self, item: Any) -> list[Any]: """AllGather JSON objects through the bootstrap interface. 
@@ -163,7 +157,7 @@ class Comm: data_size: int, transport: int, ) -> None: - self._comm.connect( + self._c_comm.connect( remote_rank, tag, data_ptr, @@ -172,32 +166,79 @@ class Comm: ) def connection_setup(self) -> None: - self._comm.connection_setup() + self._c_comm.connection_setup() def launch_proxies(self) -> None: - self._comm.launch_proxies() + self._c_comm.launch_proxies() def stop_proxies(self) -> None: - self._comm.stop_proxies() + self._c_comm.stop_proxies() def register_buffer( self, - data_ptr, - data_size: int, - ) -> list[_P2PHandle]: + data_ptr: int, + size: int, + ) -> "RegisteredMemory": + return RegisteredMemory( + comm=self._c_comm, + rm=self._c_comm.register_buffer( + data_ptr=data_ptr, + size=size, + ), + ) + + +class RegisteredMemory: + _comm: Comm + _c_rm: _RegisteredMemory + + def __init__( + self, + *, + comm: Comm, + rm: _RegisteredMemory, + ): + self._comm = comm + self._c_rm = rm + + def handles(self) -> list["P2PHandle"]: return [ - P2PHandle(self, h) - for h in self._comm.register_buffer( - data_ptr, - data_size, + P2PHandle( + comm=self._comm, + handle=h, ) + for h in self._c_rm.handles() ] + def _write( + self, + src_ptr: int, + size: int, + *, + src_offset: int = 0, + dst_offset: int = 0, + stream: int = 0, + ) -> None: + self._c_rm.write_all( + comm=self._comm._c_comm, + src_data=src_ptr, + size=size, + src_offset=src_offset, + dst_offset=dst_offset, + stream=stream, + ) + class P2PHandle: _comm: Comm - _handle: _P2PHandle + _c_handle: _P2PHandle - def __init__(self, comm: Comm, handle: _P2PHandle): + def __init__(self, *, comm: Comm, handle: _P2PHandle): self._comm = comm - self._handle = handle + self._c_handle = handle + + def transport(self) -> TransportType: + return self._c_handle.transport() + + def data_ptr(self) -> int: + return self._c_handle.data_ptr() diff --git a/python/src/mscclpp/tests/bootstrap_test.py b/python/src/mscclpp/tests/bootstrap_test.py index 0d5d8aa1..9915e7be 100644 --- 
a/python/src/mscclpp/tests/bootstrap_test.py +++ b/python/src/mscclpp/tests/bootstrap_test.py @@ -95,11 +95,22 @@ def _test_rm(options: argparse.Namespace, comm: mscclpp.Comm): mscclpp.TransportType.P2P, ) - handles = comm.register_buffer(buf.data_ptr(), buf.element_size() * buf.numel()) + rm = comm.register_buffer(buf.data_ptr(), buf.element_size() * buf.numel()) + handles = rm.handles() hamcrest.assert_that( handles, hamcrest.has_length(options.world_size - 1), ) + for handle in handles: + hamcrest.assert_that( + handle.transport(), + hamcrest.equal_to(mscclpp.TransportType.P2P), + ) + # assuming P2P ... + hamcrest.assert_that( + handle.data_ptr(), + hamcrest.greater_than(0), + ) torch.cuda.synchronize()