diff --git a/python/src/_py_mscclpp.cpp b/python/src/_py_mscclpp.cpp index 2c3727e0..bf8cc77d 100644 --- a/python/src/_py_mscclpp.cpp +++ b/python/src/_py_mscclpp.cpp @@ -1,7 +1,9 @@ #include #include #include +#include +#include #include #include #include @@ -52,12 +54,13 @@ template void checkResult(mscclppResult_t status, const std:: case mscclppRemoteError: case mscclppInProgress: case mscclppNumResults: + // throw std::runtime_error(string_format(format, args...) + " : " + std::string(mscclppGetErrorString(status))); throw std::runtime_error(string_format(format, args...)); case mscclppInvalidArgument: case mscclppInvalidUsage: default: - throw std::invalid_argument(string_format(format, args...)); + throw std::invalid_argument(string_format(format, args...)); } } @@ -72,6 +75,8 @@ Val maybe(mscclppResult_t status, Val val, const std::string& format, Args... ar // Wrapper around connection state. struct MscclppComm { + int _rank; + int _world_size; mscclppComm_t _handle; bool _is_open = false; @@ -137,6 +142,8 @@ NB_MODULE(_py_mscclpp, m) "init_rank_from_address", [](const std::string& address, int rank, int world_size) { MscclppComm comm = {0}; + comm._rank = rank; + comm._world_size = world_size; comm._is_open = true; return maybe(mscclppCommInitRank(&comm._handle, world_size, address.c_str(), rank), comm, "Failed to initialize comms: %s rank=%d world_size=%d", address, rank, world_size); @@ -147,6 +154,8 @@ NB_MODULE(_py_mscclpp, m) "init_rank_from_id", [](const mscclppUniqueId& id, int rank, int world_size) { MscclppComm comm = {0}; + comm._rank = rank; + comm._world_size = world_size; comm._is_open = true; return maybe(mscclppCommInitRankFromId(&comm._handle, world_size, id, rank), comm, "Failed to initialize comms: %02X%s rank=%d world_size=%d", id.internal, rank, world_size); @@ -157,22 +166,8 @@ NB_MODULE(_py_mscclpp, m) "opened", [](MscclppComm& comm) { return comm._is_open; }, "Is this comm object opened?") .def( "closed", [](MscclppComm& comm) { 
return !comm._is_open; }, "Is this comm object closed?") - .def( - "rank", - [](MscclppComm& comm) { - comm.check_open(); - int rank; - return maybe(mscclppCommRank(comm._handle, &rank), rank, "Failed to retrieve MSCCLPP rank"); - }, - nb::call_guard(), "The rank of this node.") - .def( - "size", - [](MscclppComm& comm) { - comm.check_open(); - int size; - return maybe(mscclppCommSize(comm._handle, &size), size, "Failed to retrieve MSCCLPP world size"); - }, - nb::call_guard(), "The world size of this node.") + .def_ro( "rank", &MscclppComm::_rank) + .def_ro( "world_size", &MscclppComm::_world_size) .def( "connection_setup", [](MscclppComm& comm) { @@ -196,6 +191,22 @@ NB_MODULE(_py_mscclpp, m) nb::call_guard(), "Start the MSCCLPP proxy.") .def("close", &MscclppComm::close, nb::call_guard()) .def("__del__", &MscclppComm::close, nb::call_guard()) + .def("connection_setup", + [](MscclppComm& comm) -> void { + comm.check_open(); + checkResult(mscclppConnectionSetup(comm._handle), "Connection Setup Failed"); + }, + nb::call_guard()) + .def( + "bootstrap_all_gather_int", + [](MscclppComm& comm, int val) -> std::vector { + std::vector buf(comm._world_size); + buf[comm._rank] = val; + // this call segfaults; disabling this call does not. 
+ checkResult(mscclppBootstrapAllGather(comm._handle, buf.data(), sizeof(int)), "All Gather Failed"); + return buf; + }, + nb::call_guard()) .def( "bootstrap_all_gather", [](MscclppComm& comm, void* data, int size) { diff --git a/python/src/mscclpp/test_mscclpp.py b/python/src/mscclpp/test_mscclpp.py index a77707d8..d7b23dee 100644 --- a/python/src/mscclpp/test_mscclpp.py +++ b/python/src/mscclpp/test_mscclpp.py @@ -1,9 +1,14 @@ +import os +import sys import concurrent.futures import unittest import hamcrest +import subprocess import mscclpp +MOD_DIR = os.path.dirname(__file__) +TESTS_DIR = os.path.join(MOD_DIR, "tests") class UniqueIdTest(unittest.TestCase): def test_no_constructor(self) -> None: @@ -39,41 +44,34 @@ class UniqueIdTest(unittest.TestCase): ), ) -def all_gather_task(rank: int, world_size: int) -> None: - comm_options = dict( - address="127.0.0.1:50000", - rank=rank, - world_size=world_size, - ) - print(f'{comm_options=}', flush=True) - - comm = mscclpp.MscclppComm.init_rank_from_address(**comm_options) - - buf = bytearray(world_size) - buf[rank] = rank - - if False: - # crashes, bad call structure.. 
- comm.bootstrap_all_gather(memoryview(buf), world_size) - hamcrest.assert_that( - buf, - hamcrest.equal_to(b'\000\002'), - ) - - comm.close() - - class CommsTest(unittest.TestCase): def test_all_gather(self) -> None: world_size = 2 tasks: list[concurrent.futures.Future[None]] = [] - with concurrent.futures.ProcessPoolExecutor(max_workers=world_size) as pool: + with concurrent.futures.ThreadPoolExecutor(max_workers=world_size) as pool: for rank in range(world_size): - tasks.append(pool.submit(all_gather_task, rank, world_size)) + tasks.append(pool.submit( + subprocess.check_output, + [ + "python", + "-m", + "mscclpp.tests.bootstrap_test", + f"--rank={rank}", + f"--world_size={world_size}", + ], + stderr=subprocess.STDOUT, + )) - for f in concurrent.futures.as_completed(tasks): - f.result() + errors = [] + for rank, f in enumerate(tasks): + try: + f.result() + except subprocess.CalledProcessError as e: + errors.append(f"{rank=}: " + e.output.decode('utf-8')) + + if errors: + raise AssertionError("\n\n".join(errors)) diff --git a/python/src/mscclpp/tests/__init__.py b/python/src/mscclpp/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/src/mscclpp/tests/bootstrap_test.py b/python/src/mscclpp/tests/bootstrap_test.py new file mode 100644 index 00000000..ac7d2be5 --- /dev/null +++ b/python/src/mscclpp/tests/bootstrap_test.py @@ -0,0 +1,39 @@ +import argparse +import hamcrest +import mscclpp + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--rank", type=int, required=True) + p.add_argument("--world_size", type=int, required=True) + p.add_argument("--port", default=50000) + options = p.parse_args() + + comm_options = dict( + address=f"127.0.0.1:{options.port}", + rank=options.rank, + world_size=options.world_size, + ) + print(f'{comm_options=}', flush=True) + + comm = mscclpp.MscclppComm.init_rank_from_address(**comm_options) + # comm.connection_setup() + + hamcrest.assert_that(comm.rank, hamcrest.equal_to(options.rank)) + 
hamcrest.assert_that(comm.world_size, hamcrest.equal_to(options.world_size)) + + hamcrest.assert_that( + comm.bootstrap_all_gather_int(options.rank + 42), + hamcrest.equal_to([ + 42, + 43, + ]), + ) + + buf = bytearray(options.world_size) + buf[options.rank] = options.rank + + comm.close() + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/python/test.sh b/python/test.sh index 32b19bda..87fc5a8c 100755 --- a/python/test.sh +++ b/python/test.sh @@ -8,4 +8,5 @@ fi cmake --build build -pytest build/mscclpp +cd build +MSCCLPP_DEBUG=INFO pytest -s mscclpp