mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 17:26:04 +00:00
basic init test
This commit is contained in:
@@ -247,5 +247,16 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
.def(
|
||||
"__del__",
|
||||
&MscclppComm::close,
|
||||
nb::call_guard<nb::gil_scoped_release>())
|
||||
.def(
|
||||
"bootstrap_all_gather",
|
||||
[](MscclppComm &comm, void *data, int size) {
|
||||
comm.check_open();
|
||||
return maybe(
|
||||
mscclppBootstrapAllGather(comm._handle, data, size),
|
||||
true,
|
||||
"Failed to stop MSCCLPP proxy");
|
||||
},
|
||||
nb::call_guard<nb::gil_scoped_release>());
|
||||
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import concurrent.futures
|
||||
import unittest
|
||||
import hamcrest
|
||||
|
||||
@@ -38,13 +39,41 @@ class UniqueIdTest(unittest.TestCase):
|
||||
),
|
||||
)
|
||||
|
||||
def all_gather_task(rank: int, world_size: int) -> None:
|
||||
comm_options = dict(
|
||||
address="127.0.0.1:50000",
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
print(f'{comm_options=}', flush=True)
|
||||
|
||||
comm = mscclpp.MscclppComm.init_rank_from_address(**comm_options)
|
||||
|
||||
buf = bytearray(world_size)
|
||||
buf[rank] = rank
|
||||
|
||||
if False:
|
||||
# crashes, bad call structure..
|
||||
comm.bootstrap_all_gather(memoryview(buf), world_size)
|
||||
hamcrest.assert_that(
|
||||
buf,
|
||||
hamcrest.equal_to(b'\000\002'),
|
||||
)
|
||||
|
||||
comm.close()
|
||||
|
||||
|
||||
class CommsTest(unittest.TestCase):
|
||||
def _test(self) -> None:
|
||||
# this hangs forever
|
||||
comm = mscclpp.MscclppComm.init_rank_from_address(
|
||||
address="127.0.0.1:50000",
|
||||
rank=0,
|
||||
world_size=2,
|
||||
)
|
||||
comm.close()
|
||||
def test_all_gather(self) -> None:
|
||||
world_size = 2
|
||||
|
||||
tasks: list[concurrent.futures.Future[None]] = []
|
||||
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=world_size) as pool:
|
||||
for rank in range(world_size):
|
||||
tasks.append(pool.submit(all_gather_task, rank, world_size))
|
||||
|
||||
for f in concurrent.futures.as_completed(tasks):
|
||||
f.result()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user