basic init test

This commit is contained in:
Crutcher Dunnavant
2023-03-24 23:45:29 +00:00
parent 8b6e35d5e0
commit 3b1abaaad1
2 changed files with 48 additions and 8 deletions

View File

@@ -247,5 +247,16 @@ NB_MODULE(_py_mscclpp, m) {
// Bind Python's `__del__` to MscclppComm::close; the call guard releases the
// GIL for the duration of the call.
.def(
"__del__",
&MscclppComm::close,
nb::call_guard<nb::gil_scoped_release>())
// Bind `bootstrap_all_gather(data, size)`: verifies the communicator is open,
// then forwards to mscclppBootstrapAllGather with the GIL released.
// NOTE(review): `data` is taken as a raw `void *`; the Python test in this
// commit reports that passing a memoryview here crashes, so the argument
// type probably needs a buffer-aware caster — confirm before enabling callers.
.def(
    "bootstrap_all_gather",
    [](MscclppComm &comm, void *data, int size) {
      comm.check_open();
      // The message must describe this operation; it previously read
      // "Failed to stop MSCCLPP proxy", copied from another binding.
      return maybe(
          mscclppBootstrapAllGather(comm._handle, data, size),
          true,
          "Failed to run MSCCLPP bootstrap all-gather");
    },
    nb::call_guard<nb::gil_scoped_release>());
}

View File

@@ -1,3 +1,4 @@
import concurrent.futures
import unittest
import hamcrest
@@ -38,13 +39,41 @@ class UniqueIdTest(unittest.TestCase):
),
)
def all_gather_task(rank: int, world_size: int) -> None:
    """Join the communicator as ``rank`` and exercise the bootstrap all-gather.

    Designed to be submitted to a process pool, one call per rank, so that
    all ranks connect to the same bootstrap address concurrently.

    Args:
        rank: Index of this participant; used as an index into a buffer of
            ``world_size`` bytes, so it must lie in ``[0, world_size)``.
        world_size: Total number of participating ranks.
    """
    comm_options = dict(
        address="127.0.0.1:50000",
        rank=rank,
        world_size=world_size,
    )
    print(f'{comm_options=}', flush=True)

    comm = mscclpp.MscclppComm.init_rank_from_address(**comm_options)
    try:
        # One byte per rank; each rank fills in only its own slot.
        buf = bytearray(world_size)
        buf[rank] = rank

        if False:
            # crashes, bad call structure..
            # NOTE(review): kept disabled until the binding accepts a Python
            # buffer. The `size` argument (whole buffer vs. per-rank chunk)
            # should also be confirmed against the native API.
            comm.bootstrap_all_gather(memoryview(buf), world_size)

            # After gathering, slot i holds rank i's byte, i.e. the value i.
            hamcrest.assert_that(
                buf,
                hamcrest.equal_to(bytes(range(world_size))),
            )
    finally:
        # Always release the communicator, even if a step above raises.
        comm.close()
class CommsTest(unittest.TestCase):
    """Multi-process tests for ``MscclppComm`` initialisation and collectives."""

    def _test(self) -> None:
        # Not collected by unittest (no ``test`` prefix).
        # this hangs forever
        options = {
            "address": "127.0.0.1:50000",
            "rank": 0,
            "world_size": 2,
        }
        mscclpp.MscclppComm.init_rank_from_address(**options).close()

    def test_all_gather(self) -> None:
        # Launch one worker process per rank so every rank joins concurrently.
        world_size = 2

        with concurrent.futures.ProcessPoolExecutor(
            max_workers=world_size,
        ) as pool:
            pending: list[concurrent.futures.Future[None]] = [
                pool.submit(all_gather_task, rank, world_size)
                for rank in range(world_size)
            ]

        # Calling result() re-raises any exception from a worker process.
        for finished in concurrent.futures.as_completed(pending):
            finished.result()