diff --git a/python/requirements.txt b/python/requirements.txt index c6704cfd..e08aa01b 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -5,3 +5,5 @@ pytest PyHamcrest nanobind + +torch diff --git a/python/src/_py_mscclpp.cpp b/python/src/_py_mscclpp.cpp index 60306b16..2163ee70 100644 --- a/python/src/_py_mscclpp.cpp +++ b/python/src/_py_mscclpp.cpp @@ -136,6 +136,11 @@ NB_MODULE(_py_mscclpp, m) { mscclppSetLogHandler(mscclppDefaultLogHandler); }); + nb::enum_(m, "TransportType") + .value("P2P", mscclppTransport_t::mscclppTransportP2P) + .value("SHM", mscclppTransport_t::mscclppTransportSHM) + .value("IB", mscclppTransport_t::mscclppTransportIB); + nb::class_(m, "MscclppUniqueId") .def_ro_static("__doc__", &DOC_MscclppUniqueId) .def_static( @@ -215,6 +220,33 @@ NB_MODULE(_py_mscclpp, m) { "Is this comm object closed?") .def_ro("rank", &_Comm::_rank) .def_ro("world_size", &_Comm::_world_size) + .def( + "connect", + [](_Comm& self, + int remote_rank, + int tag, + uint64_t local_buff, + uint64_t buff_size, + mscclppTransport_t transport_type) -> void { + checkResult( + mscclppConnect( + self._handle, + remote_rank, + tag, + reinterpret_cast(local_buff), + buff_size, + transport_type, + 0 // ibDev + ), + "Connect failed"); + }, + "remote_rank"_a, + "tag"_a, + "local_buf"_a, + "buff_size"_a, + "transport_type"_a, + nb::call_guard(), + "Attach a local buffer to a remote connection.") .def( "connection_setup", [](_Comm& comm) { @@ -222,7 +254,7 @@ NB_MODULE(_py_mscclpp, m) { return maybe( mscclppConnectionSetup(comm._handle), true, - "Failed to settup MSCCLPP connection"); + "Failed to setup MSCCLPP connection"); }, nb::call_guard(), "Run connection setup for MSCCLPP.") diff --git a/python/src/mscclpp/__init__.py b/python/src/mscclpp/__init__.py index c44cc967..a677b4c0 100644 --- a/python/src/mscclpp/__init__.py +++ b/python/src/mscclpp/__init__.py @@ -11,11 +11,14 @@ logger = logging.getLogger(__file__) from . import _py_mscclpp __all__ = ( + "Comm", "MscclppUniqueId", "MSCCLPP_UNIQUE_ID_BYTES", + "TransportType", ) _Comm = _py_mscclpp._Comm +TransportType = _py_mscclpp.TransportType MscclppUniqueId = _py_mscclpp.MscclppUniqueId MSCCLPP_UNIQUE_ID_BYTES = _py_mscclpp.MSCCLPP_UNIQUE_ID_BYTES @@ -100,8 +103,9 @@ class Comm: def close(self) -> None: """Close the connection.""" - self._comm.close() - self._comm = None + if self._comm: + self._comm.close() + self._comm = None @property def rank(self) -> int: @@ -149,3 +153,17 @@ class Comm: :return: a list of de-pickled objects. Note, the ret[rank] item will be a new copy. """ return [pickle.loads(b) for b in self.all_gather_bytes(pickle.dumps(item))] + + def connect( + self, remote_rank: int, tag: int, data_ptr, data_size, transport: int + ) -> None: + self._comm.connect( + remote_rank, + tag, + data_ptr, + data_size, + transport, + ) + + def connection_setup(self) -> None: + self._comm.connection_setup() diff --git a/python/src/mscclpp/tests/bootstrap_test.py b/python/src/mscclpp/tests/bootstrap_test.py index 061ebdb4..242bb2eb 100644 --- a/python/src/mscclpp/tests/bootstrap_test.py +++ b/python/src/mscclpp/tests/bootstrap_test.py @@ -1,7 +1,9 @@ import argparse +import os from dataclasses import dataclass import hamcrest +import torch import mscclpp @@ -11,26 +13,7 @@ class Example: rank: int -def main(): - p = argparse.ArgumentParser() - p.add_argument("--rank", type=int, required=True) - p.add_argument("--world_size", type=int, required=True) - p.add_argument("--port", default=50000) - options = p.parse_args() - - comm_options = dict( - address=f"127.0.0.1:{options.port}", - rank=options.rank, - world_size=options.world_size, - ) - print(f"{comm_options=}", flush=True) - - comm = mscclpp.Comm.init_rank_from_address(**comm_options) - # comm.connection_setup() - - hamcrest.assert_that(comm.rank, hamcrest.equal_to(options.rank)) - hamcrest.assert_that(comm.world_size, hamcrest.equal_to(options.world_size)) - +def _test_allgather_int(options: argparse.Namespace, comm: mscclpp.Comm): hamcrest.assert_that( comm.bootstrap_all_gather_int(options.rank + 42), hamcrest.equal_to( @@ -41,6 +24,8 @@ def main(): ), ) + +def _test_allgather_bytes(options: argparse.Namespace, comm: mscclpp.Comm): hamcrest.assert_that( comm.all_gather_bytes(b"abc" * (1 + options.rank)), hamcrest.equal_to( @@ -51,6 +36,8 @@ def main(): ), ) + +def _test_allgather_json(options: argparse.Namespace, comm: mscclpp.Comm): hamcrest.assert_that( comm.all_gather_json({"rank": options.rank}), hamcrest.equal_to( @@ -71,6 +58,8 @@ def main(): ), ) + +def _test_allgather_pickle(options: argparse.Namespace, comm: mscclpp.Comm): hamcrest.assert_that( comm.all_gather_pickle(Example(rank=options.rank)), hamcrest.equal_to( @@ -81,7 +70,57 @@ def main(): ), ) - comm.close() + comm.connection_setup() + + +def _test_allgather_torch(options: argparse.Namespace, comm: mscclpp.Comm): + buf = torch.zeros( + [options.world_size], dtype=torch.int64, device="cuda" + ).contiguous() + rank = options.rank + tag = 0 + remote_rank = (options.rank + 1) % options.world_size + comm.connect( + remote_rank, + tag, + buf.data_ptr(), + buf.element_size() * buf.numel(), + mscclpp._py_mscclpp.TransportType.P2P, + ) + + comm.connection_setup() + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--rank", type=int, required=True) + p.add_argument("--world_size", type=int, required=True) + p.add_argument("--port", default=50000) + options = p.parse_args() + + os.environ["CUDA_VISIBLE_DEVICES"] = str(options.rank) + + comm_options = dict( + address=f"127.0.0.1:{options.port}", + rank=options.rank, + world_size=options.world_size, + ) + print(f"{comm_options=}", flush=True) + + comm = mscclpp.Comm.init_rank_from_address(**comm_options) + # comm.connection_setup() + + hamcrest.assert_that(comm.rank, hamcrest.equal_to(options.rank)) + hamcrest.assert_that(comm.world_size, hamcrest.equal_to(options.world_size)) + + try: + _test_allgather_int(options, comm) + _test_allgather_bytes(options, comm) + _test_allgather_json(options, comm) + _test_allgather_pickle(options, comm) + _test_allgather_torch(options, comm) + finally: + comm.close() if __name__ == "__main__":