diff --git a/python/src/_py_mscclpp.cpp b/python/src/_py_mscclpp.cpp index 55c5848f..095ff2cf 100644 --- a/python/src/_py_mscclpp.cpp +++ b/python/src/_py_mscclpp.cpp @@ -247,5 +247,16 @@ NB_MODULE(_py_mscclpp, m) { .def( "__del__", &MscclppComm::close, + nb::call_guard()) + .def( + "bootstrap_all_gather", + [](MscclppComm &comm, void *data, int size) { + comm.check_open(); + return maybe( + mscclppBootstrapAllGather(comm._handle, data, size), + true, + "Failed to stop MSCCLPP proxy"); + }, nb::call_guard()); + } diff --git a/python/src/mscclpp/test_mscclpp.py b/python/src/mscclpp/test_mscclpp.py index e67f2770..a77707d8 100644 --- a/python/src/mscclpp/test_mscclpp.py +++ b/python/src/mscclpp/test_mscclpp.py @@ -1,3 +1,4 @@ +import concurrent.futures import unittest import hamcrest @@ -38,13 +39,41 @@ class UniqueIdTest(unittest.TestCase): ), ) +def all_gather_task(rank: int, world_size: int) -> None: + comm_options = dict( + address="127.0.0.1:50000", + rank=rank, + world_size=world_size, + ) + print(f'{comm_options=}', flush=True) + + comm = mscclpp.MscclppComm.init_rank_from_address(**comm_options) + + buf = bytearray(world_size) + buf[rank] = rank + + if False: + # crashes, bad call structure.. + comm.bootstrap_all_gather(memoryview(buf), world_size) + hamcrest.assert_that( + buf, + hamcrest.equal_to(b'\000\002'), + ) + + comm.close() + class CommsTest(unittest.TestCase): - def _test(self) -> None: - # this hangs forever - comm = mscclpp.MscclppComm.init_rank_from_address( - address="127.0.0.1:50000", - rank=0, - world_size=2, - ) - comm.close() + def test_all_gather(self) -> None: + world_size = 2 + + tasks: list[concurrent.futures.Future[None]] = [] + + with concurrent.futures.ProcessPoolExecutor(max_workers=world_size) as pool: + for rank in range(world_size): + tasks.append(pool.submit(all_gather_task, rank, world_size)) + + for f in concurrent.futures.as_completed(tasks): + f.result() + +