diff --git a/python/src/mscclpp/tests/bootstrap_test.py b/python/src/mscclpp/tests/bootstrap_test.py index 4073287d..0b14f001 100644 --- a/python/src/mscclpp/tests/bootstrap_test.py +++ b/python/src/mscclpp/tests/bootstrap_test.py @@ -6,7 +6,6 @@ import hamcrest import torch import mscclpp -import time @dataclass @@ -14,7 +13,7 @@ class Example: rank: int -def _test_allgather_int(options: argparse.Namespace, comm: mscclpp.Comm): +def _test_bootstrap_allgather_int(options: argparse.Namespace, comm: mscclpp.Comm): hamcrest.assert_that( comm.bootstrap_all_gather_int(options.rank + 42), hamcrest.equal_to( @@ -26,7 +25,7 @@ def _test_allgather_int(options: argparse.Namespace, comm: mscclpp.Comm): ) -def _test_allgather_bytes(options: argparse.Namespace, comm: mscclpp.Comm): +def _test_bootstrap_allgather_bytes(options: argparse.Namespace, comm: mscclpp.Comm): hamcrest.assert_that( comm.all_gather_bytes(b"abc" * (1 + options.rank)), hamcrest.equal_to( @@ -38,7 +37,7 @@ def _test_allgather_bytes(options: argparse.Namespace, comm: mscclpp.Comm): ) -def _test_allgather_json(options: argparse.Namespace, comm: mscclpp.Comm): +def _test_bootstrap_allgather_json(options: argparse.Namespace, comm: mscclpp.Comm): hamcrest.assert_that( comm.all_gather_json({"rank": options.rank}), hamcrest.equal_to( @@ -60,7 +59,7 @@ def _test_allgather_json(options: argparse.Namespace, comm: mscclpp.Comm): ) -def _test_allgather_pickle(options: argparse.Namespace, comm: mscclpp.Comm): +def _test_bootstrap_allgather_pickle(options: argparse.Namespace, comm: mscclpp.Comm): hamcrest.assert_that( comm.all_gather_pickle(Example(rank=options.rank)), hamcrest.equal_to( @@ -74,7 +73,7 @@ def _test_allgather_pickle(options: argparse.Namespace, comm: mscclpp.Comm): comm.connection_setup() -def _test_allgather_torch(options: argparse.Namespace, comm: mscclpp.Comm): +def _test_p2p_connect(options: argparse.Namespace, comm: mscclpp.Comm): rank = options.rank buf = torch.zeros([options.world_size], dtype=torch.int64) @@ -97,7 +96,6 @@ def _test_allgather_torch(options: argparse.Namespace, comm: mscclpp.Comm): ) torch.cuda.synchronize() - # time.sleep(3) comm.connection_setup() @@ -125,11 +123,11 @@ def main(): hamcrest.assert_that(comm.world_size, hamcrest.equal_to(options.world_size)) try: - _test_allgather_int(options, comm) - _test_allgather_bytes(options, comm) - _test_allgather_json(options, comm) - _test_allgather_pickle(options, comm) - _test_allgather_torch(options, comm) + _test_bootstrap_allgather_int(options, comm) + _test_bootstrap_allgather_bytes(options, comm) + _test_bootstrap_allgather_json(options, comm) + _test_bootstrap_allgather_pickle(options, comm) + _test_p2p_connect(options, comm) finally: comm.close()